Skip to content

Commit 3d17483

Browse files
committed
Fixes
1 parent 0aabc2b commit 3d17483

3 files changed

Lines changed: 19 additions & 21 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,11 @@ end
137137

138138
function DI.gradient(
139139
f::F,
140-
prep::EnzymeForwardGradientPrep{B},
140+
prep::EnzymeForwardGradientPrep{SIG,B},
141141
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
142142
x,
143143
contexts::Vararg{DI.Constant,C},
144-
) where {F,B,C}
144+
) where {F,SIG,B,C}
145145
DI.check_prep(f, prep, backend, x, contexts...)
146146
mode = forward_noprimal(backend)
147147
f_and_df = get_f_and_df(f, backend, mode)
@@ -154,11 +154,11 @@ end
154154

155155
function DI.value_and_gradient(
156156
f::F,
157-
prep::EnzymeForwardGradientPrep{B},
157+
prep::EnzymeForwardGradientPrep{SIG,B},
158158
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
159159
x,
160160
contexts::Vararg{DI.Constant,C},
161-
) where {F,B,C}
161+
) where {F,SIG,B,C}
162162
DI.check_prep(f, prep, backend, x, contexts...)
163163
mode = forward_withprimal(backend)
164164
f_and_df = get_f_and_df(f, backend, mode)
@@ -172,23 +172,23 @@ end
172172
function DI.gradient!(
173173
f::F,
174174
grad,
175-
prep::EnzymeForwardGradientPrep{B},
175+
prep::EnzymeForwardGradientPrep{SIG,B},
176176
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
177177
x,
178178
contexts::Vararg{DI.Constant,C},
179-
) where {F,B,C}
179+
) where {F,SIG,B,C}
180180
DI.check_prep(f, prep, backend, x, contexts...)
181181
return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...))
182182
end
183183

184184
function DI.value_and_gradient!(
185185
f::F,
186186
grad,
187-
prep::EnzymeForwardGradientPrep{B},
187+
prep::EnzymeForwardGradientPrep{SIG,B},
188188
backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}},
189189
x,
190190
contexts::Vararg{DI.Constant,C},
191-
) where {F,B,C}
191+
) where {F,SIG,B,C}
192192
DI.check_prep(f, prep, backend, x, contexts...)
193193
y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...)
194194
return y, copyto!(grad, new_grad)
@@ -219,11 +219,11 @@ end
219219

220220
function DI.jacobian(
221221
f::F,
222-
prep::EnzymeForwardOneArgJacobianPrep{B},
222+
prep::EnzymeForwardOneArgJacobianPrep{SIG,B},
223223
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
224224
x,
225225
contexts::Vararg{DI.Constant,C},
226-
) where {F,B,C}
226+
) where {F,SIG,B,C}
227227
DI.check_prep(f, prep, backend, x, contexts...)
228228
mode = forward_noprimal(backend)
229229
f_and_df = get_f_and_df(f, backend, mode)
@@ -237,11 +237,11 @@ end
237237

238238
function DI.value_and_jacobian(
239239
f::F,
240-
prep::EnzymeForwardOneArgJacobianPrep{B},
240+
prep::EnzymeForwardOneArgJacobianPrep{SIG,B},
241241
backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}},
242242
x,
243243
contexts::Vararg{DI.Constant,C},
244-
) where {F,B,C}
244+
) where {F,SIG,B,C}
245245
DI.check_prep(f, prep, backend, x, contexts...)
246246
mode = forward_withprimal(backend)
247247
f_and_df = get_f_and_df(f, backend, mode)

DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -87,9 +87,8 @@ function DI.value_and_pushforward(
8787
contexts::Vararg{DI.Constant,C},
8888
) where {C}
8989
DI.check_prep(f, prep, backend, x, tx, contexts...)
90-
fc = DI.with_contexts(f, contexts...)
91-
ty = DI.pushforward(fc, prep, backend, x, tx)
92-
y = fc(x) # TO-DO: optimize
90+
ty = DI.pushforward(f, prep, backend, x, tx, contexts...)
91+
y = f(x, map(DI.unwrap, contexts)...) # TODO: optimize
9392
return y, ty
9493
end
9594

@@ -103,9 +102,8 @@ function DI.value_and_pushforward!(
103102
contexts::Vararg{DI.Constant,C},
104103
) where {C}
105104
DI.check_prep(f, prep, backend, x, tx, contexts...)
106-
fc = DI.with_contexts(f, contexts...)
107-
DI.pushforward!(fc, ty, prep, backend, x, tx)
108-
y = fc(x) # TO-DO: optimize
105+
DI.pushforward!(f, ty, prep, backend, x, tx, contexts...)
106+
y = f(x, map(DI.unwrap, contexts)...) # TODO: optimize
109107
return y, ty
110108
end
111109

DifferentiationInterface/ext/DifferentiationInterfaceSymbolicsExt/onearg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ function DI.hvp!(
460460
tx::NTuple,
461461
contexts::Vararg{DI.Context,C},
462462
) where {C}
463-
DI.check_prep(f, prep, backend, x, contexts...)
463+
DI.check_prep(f, prep, backend, x, tx, contexts...)
464464
for b in eachindex(tx, tg)
465465
dx, dg = tx[b], tg[b]
466466
prep.hvp_exe!(vec(dg), vec(x), vec(dx), map(DI.unwrap, contexts)...)
@@ -476,7 +476,7 @@ function DI.gradient_and_hvp(
476476
tx::NTuple,
477477
contexts::Vararg{DI.Context,C},
478478
) where {C}
479-
DI.check_prep(f, prep, backend, x, contexts...)
479+
DI.check_prep(f, prep, backend, x, tx, contexts...)
480480
tg = DI.hvp(f, prep, backend, x, tx, contexts...)
481481
grad = DI.gradient(f, prep.gradient_prep, backend, x, contexts...)
482482
return grad, tg
@@ -492,7 +492,7 @@ function DI.gradient_and_hvp!(
492492
tx::NTuple,
493493
contexts::Vararg{DI.Context,C},
494494
) where {C}
495-
DI.check_prep(f, prep, backend, x, contexts...)
495+
DI.check_prep(f, prep, backend, x, tx, contexts...)
496496
DI.hvp!(f, tg, prep, backend, x, tx, contexts...)
497497
DI.gradient!(f, grad, prep.gradient_prep, backend, x, contexts...)
498498
return grad, tg

0 commit comments

Comments
 (0)