Skip to content

Commit 35d9a60

Browse files
committed
fix: fix type stability
1 parent 65d2efd commit 35d9a60

3 files changed

Lines changed: 35 additions & 26 deletions

File tree

.github/workflows/Test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
actions: write
2626
contents: read
2727
strategy:
28-
fail-fast: true # TODO: toggle
28+
fail-fast: false # TODO: toggle
2929
matrix:
3030
version:
3131
- "1.10"

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,14 @@ struct PushforwardJacobianPrep{
138138
BS<:BatchSizeSettings,
139139
S<:AbstractVector{<:NTuple},
140140
R<:AbstractVector{<:NTuple},
141+
SE<:NTuple,
141142
E<:PushforwardPrep,
142143
} <: StandardJacobianPrep{SIG}
143144
_sig::Val{SIG}
144145
batch_size_settings::BS
145146
batched_seeds::S
146147
batched_results::R
148+
seed_example::SE
147149
pushforward_prep::E
148150
end
149151

@@ -152,15 +154,15 @@ struct PullbackJacobianPrep{
152154
BS<:BatchSizeSettings,
153155
S<:AbstractVector{<:NTuple},
154156
R<:AbstractVector{<:NTuple},
157+
SE<:NTuple,
155158
E<:PullbackPrep,
156-
Y,
157159
} <: StandardJacobianPrep{SIG}
158160
_sig::Val{SIG}
159161
batch_size_settings::BS
160162
batched_seeds::S
161163
batched_results::R
164+
seed_example::SE
162165
pullback_prep::E
163-
y_example::Y
164166
end
165167

166168
function prepare_jacobian_nokwarg(
@@ -213,11 +215,17 @@ function _prepare_jacobian_aux(
213215
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
214216
]
215217
batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds]
218+
seed_example = ntuple(b -> zero(x), Val(B))
216219
pushforward_prep = prepare_pushforward_nokwarg(
217-
strict, f_or_f!y..., backend, x, ntuple(b -> zero(x), Val(B)), contexts...
220+
strict, f_or_f!y..., backend, x, seed_example, contexts...
218221
)
219222
return PushforwardJacobianPrep(
220-
_sig, batch_size_settings, batched_seeds, batched_results, pushforward_prep
223+
_sig,
224+
batch_size_settings,
225+
batched_seeds,
226+
batched_results,
227+
seed_example,
228+
pushforward_prep,
221229
)
222230
end
223231

@@ -238,11 +246,17 @@ function _prepare_jacobian_aux(
238246
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
239247
]
240248
batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
249+
seed_example = ntuple(b -> zero(y), Val(B))
241250
pullback_prep = prepare_pullback_nokwarg(
242-
strict, f_or_f!y..., backend, x, ntuple(b -> zero(y), Val(B)), contexts...
251+
strict, f_or_f!y..., backend, x, seed_example, contexts...
243252
)
244253
return PullbackJacobianPrep(
245-
_sig, batch_size_settings, batched_seeds, batched_results, pullback_prep, y
254+
_sig,
255+
batch_size_settings,
256+
batched_seeds,
257+
batched_results,
258+
seed_example,
259+
pullback_prep,
246260
)
247261
end
248262

@@ -365,11 +379,11 @@ function _jacobian_aux(
365379
x,
366380
contexts::Vararg{Context,C},
367381
) where {FY,SIG,B,aligned,C}
368-
(; batch_size_settings, batched_seeds, pushforward_prep) = prep
382+
(; batch_size_settings, batched_seeds, seed_example, pushforward_prep) = prep
369383
(; A, B_last) = batch_size_settings
370384

371385
pushforward_prep_same = prepare_pushforward_same_point(
372-
f_or_f!y..., pushforward_prep, backend, x, ntuple(b -> zero(x), Val(B)), contexts...
386+
f_or_f!y..., pushforward_prep, backend, x, seed_example, contexts...
373387
)
374388

375389
jac = mapreduce(hcat, eachindex(batched_seeds)) do a
@@ -421,16 +435,11 @@ function _jacobian_aux(
421435
x,
422436
contexts::Vararg{Context,C},
423437
) where {FY,SIG,B,aligned,C}
424-
(; batch_size_settings, batched_seeds, pullback_prep, y_example) = prep
438+
(; batch_size_settings, batched_seeds, seed_example, pullback_prep) = prep
425439
(; A, B_last) = batch_size_settings
426440

427441
pullback_prep_same = prepare_pullback_same_point(
428-
f_or_f!y...,
429-
pullback_prep,
430-
backend,
431-
x,
432-
ntuple(b -> zero(y_example), Val(B)),
433-
contexts...,
442+
f_or_f!y..., pullback_prep, backend, x, seed_example, contexts...
434443
)
435444

436445
jac = mapreduce(vcat, eachindex(batched_seeds)) do a
@@ -458,11 +467,13 @@ function _jacobian_aux!(
458467
x,
459468
contexts::Vararg{Context,C},
460469
) where {FY,SIG,B,C}
461-
(; batch_size_settings, batched_seeds, batched_results, pushforward_prep) = prep
470+
(;
471+
batch_size_settings, batched_seeds, batched_results, seed_example, pushforward_prep
472+
) = prep
462473
(; N) = batch_size_settings
463474

464475
pushforward_prep_same = prepare_pushforward_same_point(
465-
f_or_f!y..., pushforward_prep, backend, x, batched_seeds[1], contexts...
476+
f_or_f!y..., pushforward_prep, backend, x, seed_example, contexts...
466477
)
467478

468479
for a in eachindex(batched_seeds, batched_results)
@@ -494,16 +505,12 @@ function _jacobian_aux!(
494505
x,
495506
contexts::Vararg{Context,C},
496507
) where {FY,SIG,B,C}
497-
(; batch_size_settings, batched_seeds, batched_results, pullback_prep, y_example) = prep
508+
(; batch_size_settings, batched_seeds, batched_results, seed_example, pullback_prep) =
509+
prep
498510
(; N) = batch_size_settings
499511

500512
pullback_prep_same = prepare_pullback_same_point(
501-
f_or_f!y...,
502-
pullback_prep,
503-
backend,
504-
x,
505-
ntuple(b -> zero(y_example), Val(B)),
506-
contexts...,
513+
f_or_f!y..., pullback_prep, backend, x, seed_example, contexts...
507514
)
508515

509516
for a in eachindex(batched_seeds, batched_results)

DifferentiationInterfaceTest/test/zero_backends.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ end
8686
logging=LOGGING,
8787
),
8888
)
89-
@test all(iszero, data_allocfree[!, :allocs])
89+
@testset "$(collect(row[1:4]))" for row in collect(eachrow(data_allocfree))
90+
@test row[:allocs] == 0
91+
end
9092
end
9193

9294
test_differentiation(

0 commit comments

Comments
 (0)