Skip to content

Commit 960ea1f

Browse files
committed
Fixes
1 parent de6d614 commit 960ea1f

3 files changed

Lines changed: 56 additions & 39 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -88,11 +88,11 @@ end
8888

8989
function compute_ydual_onearg(
9090
f::F,
91-
prep::ForwardDiffOneArgPushforwardPrep{T},
91+
prep::ForwardDiffOneArgPushforwardPrep{SIG,T},
9292
x::Number,
9393
tx::NTuple{B},
9494
contexts::Vararg{DI.Context,C},
95-
) where {F,T,B,C}
95+
) where {F,SIG,T,B,C}
9696
xdual = make_dual(T, x, tx)
9797
contexts_dual = translate_prepared(contexts, prep.contexts_dual)
9898
ydual = f(xdual, contexts_dual...)
@@ -101,11 +101,11 @@ end
101101

102102
function compute_ydual_onearg(
103103
f::F,
104-
prep::ForwardDiffOneArgPushforwardPrep{T},
104+
prep::ForwardDiffOneArgPushforwardPrep{SIG,T},
105105
x,
106106
tx::NTuple{B},
107107
contexts::Vararg{DI.Context,C},
108-
) where {F,T,B,C}
108+
) where {F,SIG,T,B,C}
109109
if DI.ismutable_array(x)
110110
make_dual!(T, prep.xdual_tmp, x, tx)
111111
xdual_tmp = prep.xdual_tmp
@@ -119,12 +119,12 @@ end
119119

120120
function DI.value_and_pushforward(
121121
f::F,
122-
prep::ForwardDiffOneArgPushforwardPrep{T},
122+
prep::ForwardDiffOneArgPushforwardPrep{SIG,T},
123123
backend::AutoForwardDiff,
124124
x,
125125
tx::NTuple{B},
126126
contexts::Vararg{DI.Context,C},
127-
) where {F,T,B,C}
127+
) where {F,SIG,T,B,C}
128128
DI.check_prep(f, prep, backend, x, tx, contexts...)
129129
ydual = compute_ydual_onearg(f, prep, x, tx, contexts...)
130130
y = myvalue(T, ydual)
@@ -135,12 +135,12 @@ end
135135
function DI.value_and_pushforward!(
136136
f::F,
137137
ty::NTuple,
138-
prep::ForwardDiffOneArgPushforwardPrep{T},
138+
prep::ForwardDiffOneArgPushforwardPrep{SIG,T},
139139
backend::AutoForwardDiff,
140140
x,
141141
tx::NTuple,
142142
contexts::Vararg{DI.Context,C},
143-
) where {F,T,C}
143+
) where {F,SIG,T,C}
144144
DI.check_prep(f, prep, backend, x, tx, contexts...)
145145
ydual = compute_ydual_onearg(f, prep, x, tx, contexts...)
146146
y = myvalue(T, ydual)
@@ -150,12 +150,12 @@ end
150150

151151
function DI.pushforward(
152152
f::F,
153-
prep::ForwardDiffOneArgPushforwardPrep{T},
153+
prep::ForwardDiffOneArgPushforwardPrep{SIG,T},
154154
backend::AutoForwardDiff,
155155
x,
156156
tx::NTuple{B},
157157
contexts::Vararg{DI.Context,C},
158-
) where {F,T,B,C}
158+
) where {F,SIG,T,B,C}
159159
DI.check_prep(f, prep, backend, x, tx, contexts...)
160160
ydual = compute_ydual_onearg(f, prep, x, tx, contexts...)
161161
ty = mypartials(T, Val(B), ydual)
@@ -165,12 +165,12 @@ end
165165
function DI.pushforward!(
166166
f::F,
167167
ty::NTuple,
168-
prep::ForwardDiffOneArgPushforwardPrep{T},
168+
prep::ForwardDiffOneArgPushforwardPrep{SIG,T},
169169
backend::AutoForwardDiff,
170170
x,
171171
tx::NTuple,
172172
contexts::Vararg{DI.Context,C},
173-
) where {F,T,C}
173+
) where {F,SIG,T,C}
174174
DI.check_prep(f, prep, backend, x, tx, contexts...)
175175
ydual = compute_ydual_onearg(f, prep, x, tx, contexts...)
176176
mypartials!(T, ty, ydual)

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ end
3030
function compute_ydual_twoarg(
3131
f!::F,
3232
y,
33-
prep::ForwardDiffTwoArgPushforwardPrep{T},
33+
prep::ForwardDiffTwoArgPushforwardPrep{SIG,T},
3434
x::Number,
3535
tx::NTuple{B},
3636
contexts::Vararg{DI.Context,C},
37-
) where {F,T,B,C}
37+
) where {F,SIG,T,B,C}
3838
(; ydual_tmp) = prep
3939
xdual_tmp = make_dual(T, x, tx)
4040
contexts_dual = translate_prepared(contexts, prep.contexts_dual)
@@ -45,11 +45,11 @@ end
4545
function compute_ydual_twoarg(
4646
f!::F,
4747
y,
48-
prep::ForwardDiffTwoArgPushforwardPrep{T},
48+
prep::ForwardDiffTwoArgPushforwardPrep{SIG,T},
4949
x,
5050
tx::NTuple{B},
5151
contexts::Vararg{DI.Context,C},
52-
) where {F,T,B,C}
52+
) where {F,SIG,T,B,C}
5353
(; xdual_tmp, ydual_tmp) = prep
5454
make_dual!(T, xdual_tmp, x, tx)
5555
contexts_dual = translate_prepared(contexts, prep.contexts_dual)
@@ -60,12 +60,12 @@ end
6060
function DI.value_and_pushforward(
6161
f!::F,
6262
y,
63-
prep::ForwardDiffTwoArgPushforwardPrep{T},
63+
prep::ForwardDiffTwoArgPushforwardPrep{SIG,T},
6464
backend::AutoForwardDiff,
6565
x,
6666
tx::NTuple{B},
6767
contexts::Vararg{DI.Context,C},
68-
) where {F,T,B,C}
68+
) where {F,SIG,T,B,C}
6969
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
7070
ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...)
7171
myvalue!(T, y, ydual_tmp)
@@ -77,12 +77,12 @@ function DI.value_and_pushforward!(
7777
f!::F,
7878
y,
7979
ty::NTuple,
80-
prep::ForwardDiffTwoArgPushforwardPrep{T},
80+
prep::ForwardDiffTwoArgPushforwardPrep{SIG,T},
8181
backend::AutoForwardDiff,
8282
x,
8383
tx::NTuple,
8484
contexts::Vararg{DI.Context,C},
85-
) where {F,T,C}
85+
) where {F,SIG,T,C}
8686
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
8787
ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...)
8888
myvalue!(T, y, ydual_tmp)
@@ -93,12 +93,12 @@ end
9393
function DI.pushforward(
9494
f!::F,
9595
y,
96-
prep::ForwardDiffTwoArgPushforwardPrep{T},
96+
prep::ForwardDiffTwoArgPushforwardPrep{SIG,T},
9797
backend::AutoForwardDiff,
9898
x,
9999
tx::NTuple{B},
100100
contexts::Vararg{DI.Context,C},
101-
) where {F,T,B,C}
101+
) where {F,SIG,T,B,C}
102102
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
103103
ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...)
104104
ty = mypartials(T, Val(B), ydual_tmp)
@@ -109,12 +109,12 @@ function DI.pushforward!(
109109
f!::F,
110110
y,
111111
ty::NTuple,
112-
prep::ForwardDiffTwoArgPushforwardPrep{T},
112+
prep::ForwardDiffTwoArgPushforwardPrep{SIG,T},
113113
backend::AutoForwardDiff,
114114
x,
115115
tx::NTuple,
116116
contexts::Vararg{DI.Context,C},
117-
) where {F,T,C}
117+
) where {F,SIG,T,C}
118118
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
119119
ydual_tmp = compute_ydual_twoarg(f!, y, prep, x, tx, contexts...)
120120
mypartials!(T, ty, ydual_tmp)
@@ -425,7 +425,7 @@ function DI.value_and_jacobian(
425425
x,
426426
contexts::Vararg{DI.Context,C},
427427
) where {F,C}
428-
DI.check_prep(f!, y, old_prep, backend, x, contexts...)
428+
DI.check_prep(f!, y, prep, backend, x, contexts...)
429429
contexts_dual = translate_prepared(contexts, prep.contexts_dual)
430430
fc! = DI.FixTail(f!, contexts_dual...)
431431
jac = similar(y, length(y), length(x))
@@ -447,7 +447,7 @@ function DI.value_and_jacobian!(
447447
x,
448448
contexts::Vararg{DI.Context,C},
449449
) where {F,C}
450-
DI.check_prep(f!, y, old_prep, backend, x, contexts...)
450+
DI.check_prep(f!, y, prep, backend, x, contexts...)
451451
contexts_dual = translate_prepared(contexts, prep.contexts_dual)
452452
fc! = DI.FixTail(f!, contexts_dual...)
453453
result = MutableDiffResult(y, (jac,))
@@ -467,7 +467,7 @@ function DI.jacobian(
467467
x,
468468
contexts::Vararg{DI.Context,C},
469469
) where {F,C}
470-
DI.check_prep(f!, y, old_prep, backend, x, contexts...)
470+
DI.check_prep(f!, y, prep, backend, x, contexts...)
471471
contexts_dual = translate_prepared(contexts, prep.contexts_dual)
472472
fc! = DI.FixTail(f!, contexts_dual...)
473473
CHK = tag_type(backend) === Nothing
@@ -486,7 +486,7 @@ function DI.jacobian!(
486486
x,
487487
contexts::Vararg{DI.Context,C},
488488
) where {F,C}
489-
DI.check_prep(f!, y, old_prep, backend, x, contexts...)
489+
DI.check_prep(f!, y, prep, backend, x, contexts...)
490490
contexts_dual = translate_prepared(contexts, prep.contexts_dual)
491491
fc! = DI.FixTail(f!, contexts_dual...)
492492
CHK = tag_type(backend) === Nothing

DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ function DI.prepare_pullback_same_point(
4646
x,
4747
ty::NTuple,
4848
contexts::Vararg{DI.Context,C};
49+
strict::Bool=false,
4950
) where {C}
5051
DI.check_prep(f, prep, backend, x, ty, contexts...)
5152
SIG = DI.signature(f, backend, x, ty, contexts...; strict)
@@ -187,6 +188,11 @@ end
187188

188189
# Beware, this uses ForwardDiff for the inner differentiation
189190

191+
struct ZygoteHVPPrep{SIG,P} <: DI.HVPPrep{SIG}
192+
_sig::Type{SIG}
193+
fd_prep::P
194+
end
195+
190196
function DI.prepare_hvp(
191197
f,
192198
backend::AutoZygote,
@@ -195,65 +201,76 @@ function DI.prepare_hvp(
195201
contexts::Vararg{DI.Context,C};
196202
strict::Bool=false,
197203
) where {C}
198-
return DI.prepare_hvp(
204+
SIG = DI.signature(f, backend, x, tx, contexts...; strict)
205+
fd_prep = DI.prepare_hvp(
199206
f, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...; strict
200207
)
208+
return ZygoteHVPPrep(SIG, fd_prep)
201209
end
202210

203211
function DI.hvp(
204212
f,
205-
prep::DI.ForwardOverAnythingHVPPrep,
213+
prep::ZygoteHVPPrep,
206214
backend::AutoZygote,
207215
x,
208216
tx::NTuple,
209217
contexts::Vararg{DI.Context,C},
210218
) where {C}
211219
DI.check_prep(f, prep, backend, x, tx, contexts...)
212-
return DI.hvp(f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...)
220+
return DI.hvp(
221+
f, prep.fd_prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...
222+
)
213223
end
214224

215225
function DI.hvp!(
216226
f,
217227
tg::NTuple,
218-
prep::DI.ForwardOverAnythingHVPPrep,
228+
prep::ZygoteHVPPrep,
219229
backend::AutoZygote,
220230
x,
221231
tx::NTuple,
222232
contexts::Vararg{DI.Context,C},
223233
) where {C}
224234
DI.check_prep(f, prep, backend, x, tx, contexts...)
225235
return DI.hvp!(
226-
f, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...
236+
f, tg, prep.fd_prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...
227237
)
228238
end
229239

230240
function DI.gradient_and_hvp(
231241
f,
232-
prep::DI.ForwardOverAnythingHVPPrep,
242+
prep::ZygoteHVPPrep,
233243
backend::AutoZygote,
234244
x,
235245
tx::NTuple,
236246
contexts::Vararg{DI.Context,C},
237247
) where {C}
238248
DI.check_prep(f, prep, backend, x, tx, contexts...)
239249
return DI.gradient_and_hvp(
240-
f, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...
250+
f, prep.fd_prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...
241251
)
242252
end
243253

244254
function DI.gradient_and_hvp!(
245255
f,
246256
grad,
247257
tg::NTuple,
248-
prep::DI.ForwardOverAnythingHVPPrep,
258+
prep::ZygoteHVPPrep,
249259
backend::AutoZygote,
250260
x,
251261
tx::NTuple,
252262
contexts::Vararg{DI.Context,C},
253263
) where {C}
254264
DI.check_prep(f, prep, backend, x, tx, contexts...)
255265
return DI.gradient_and_hvp!(
256-
f, grad, tg, prep, DI.SecondOrder(AutoForwardDiff(), backend), x, tx, contexts...
266+
f,
267+
grad,
268+
tg,
269+
prep.fd_prep,
270+
DI.SecondOrder(AutoForwardDiff(), backend),
271+
x,
272+
tx,
273+
contexts...,
257274
)
258275
end
259276

@@ -303,7 +320,7 @@ function DI.value_gradient_and_hessian(
303320
contexts::Vararg{DI.GeneralizedConstant,C},
304321
) where {C}
305322
DI.check_prep(f, prep, backend, x, contexts...)
306-
y, grad = DI.value_and_gradient(f, DI.NoGradientPrep(), backend, x, contexts...)
323+
y, grad = DI.value_and_gradient(f, backend, x, contexts...)
307324
hess = DI.hessian(f, prep, backend, x, contexts...)
308325
return y, grad, hess
309326
end
@@ -318,7 +335,7 @@ function DI.value_gradient_and_hessian!(
318335
contexts::Vararg{DI.GeneralizedConstant,C},
319336
) where {C}
320337
DI.check_prep(f, prep, backend, x, contexts...)
321-
y, _ = DI.value_and_gradient!(f, grad, DI.NoGradientPrep(), backend, x, contexts...)
338+
y, _ = DI.value_and_gradient!(f, grad, backend, x, contexts...)
322339
DI.hessian!(f, hess, prep, backend, x, contexts...)
323340
return y, grad, hess
324341
end

0 commit comments

Comments
 (0)