Skip to content

Commit 50d3770

Browse files
committed
Fixes
1 parent 1251518 commit 50d3770

5 files changed

Lines changed: 14 additions & 12 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ end
8787
function DI.prepare_gradient(
8888
f, backend::AutoReverseDiff{compile}, x; strict::Val=Val(false)
8989
) where {compile}
90-
_sig = DI.signature(f, backend, x)
90+
_sig = DI.signature(f, backend, x; strict)
9191
if compile
9292
tape = ReverseDiff.compile(GradientTape(f, x))
9393
return ReverseDiffGradientPrep(_sig, nothing, tape)

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ function _prepare_sparse_hessian_aux(
6262
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for a in 1:A
6363
]
6464
batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
65-
hvp_prep = DI.prepare_hvp(f, dense_backend, x, batched_seeds[1], contexts...)
66-
gradient_prep = DI.prepare_gradient(f, DI.inner(dense_backend), x, contexts...)
65+
hvp_prep = DI.prepare_hvp(f, dense_backend, x, batched_seeds[1], contexts...; strict)
66+
gradient_prep = DI.prepare_gradient(f, DI.inner(dense_backend), x, contexts...; strict)
6767
return SparseHessianPrep(
68-
SIG,
68+
_sig,
6969
batch_size_settings,
7070
coloring_result,
7171
compressed_matrix,

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,10 @@ function _prepare_sparse_jacobian_aux_aux(
109109
]
110110
batched_results = [ntuple(b -> similar(y), Val(B)) for _ in batched_seeds]
111111
pushforward_prep = DI.prepare_pushforward(
112-
f_or_f!y..., dense_backend, x, batched_seeds[1], contexts...
112+
f_or_f!y..., dense_backend, x, batched_seeds[1], contexts...; strict
113113
)
114114
return PushforwardSparseJacobianPrep(
115-
SIG,
115+
_sig,
116116
batch_size_settings,
117117
coloring_result,
118118
compressed_matrix,
@@ -143,10 +143,10 @@ function _prepare_sparse_jacobian_aux_aux(
143143
]
144144
batched_results = [ntuple(b -> similar(x), Val(B)) for _ in batched_seeds]
145145
pullback_prep = DI.prepare_pullback(
146-
f_or_f!y..., dense_backend, x, batched_seeds[1], contexts...
146+
f_or_f!y..., dense_backend, x, batched_seeds[1], contexts...; strict
147147
)
148148
return PullbackSparseJacobianPrep(
149-
SIG,
149+
_sig,
150150
batch_size_settings,
151151
coloring_result,
152152
compressed_matrix,

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian_mixed.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,18 +132,20 @@ function _prepare_mixed_sparse_jacobian_aux_aux(
132132
DI.forward_backend(dense_backend),
133133
x,
134134
batched_seeds_forward[1],
135-
contexts...,
135+
contexts...;
136+
strict,
136137
)
137138
pullback_prep = DI.prepare_pullback(
138139
f_or_f!y...,
139140
DI.reverse_backend(dense_backend),
140141
x,
141142
batched_seeds_reverse[1],
142-
contexts...,
143+
contexts...;
144+
strict,
143145
)
144146

145147
return MixedModeSparseJacobianPrep(
146-
SIG,
148+
_sig,
147149
batch_size_settings_forward,
148150
batch_size_settings_reverse,
149151
coloring_result,

DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function DI.prepare_pullback_same_point(
3535
ty::NTuple,
3636
contexts::Vararg{DI.GeneralizedConstant,C},
3737
) where {C}
38-
_sig = DI.signature(f, prep, backend, x, ty, contexts...)
38+
_sig = DI.signature(f, prep, backend, x, ty, contexts...; strict)
3939
DI.check_prep(f, prep, backend, x, ty, contexts...)
4040
y, pb = forward(f, x, map(DI.unwrap, contexts)...)
4141
return TrackerPullbackPrepSamePoint(_sig, y, pb)

0 commit comments

Comments
 (0)