Skip to content

Commit fb52441

Browse files
committed
Inner HVP preparation
1 parent 2cc261f commit fb52441

15 files changed

Lines changed: 278 additions & 89 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl

Lines changed: 5 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -6,11 +6,7 @@ struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
66
end
77

88
function DI.prepare_pullback(
9-
f,
10-
::AutoReverseChainRules,
11-
x,
12-
ty::NTuple,
13-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
9+
f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}
1410
) where {C}
1511
return DI.NoPullbackPrep()
1612
end
@@ -21,7 +17,7 @@ function DI.prepare_pullback_same_point(
2117
backend::AutoReverseChainRules,
2218
x,
2319
ty::NTuple,
24-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
20+
contexts::Vararg{DI.GeneralizedConstant,C},
2521
) where {C}
2622
rc = ruleconfig(backend)
2723
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
@@ -34,7 +30,7 @@ function DI.value_and_pullback(
3430
backend::AutoReverseChainRules,
3531
x,
3632
ty::NTuple,
37-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
33+
contexts::Vararg{DI.GeneralizedConstant,C},
3834
) where {C}
3935
rc = ruleconfig(backend)
4036
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
@@ -50,7 +46,7 @@ function DI.value_and_pullback(
5046
::AutoReverseChainRules,
5147
x,
5248
ty::NTuple,
53-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
49+
contexts::Vararg{DI.GeneralizedConstant,C},
5450
) where {C}
5551
(; y, pb) = prep
5652
tx = map(ty) do dy
@@ -65,7 +61,7 @@ function DI.pullback(
6561
::AutoReverseChainRules,
6662
x,
6763
ty::NTuple,
68-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
64+
contexts::Vararg{DI.GeneralizedConstant,C},
6965
) where {C}
7066
(; pb) = prep
7167
tx = map(ty) do dy

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ using ForwardDiff:
2828
value
2929

3030
DI.check_available(::AutoForwardDiff) = true
31+
DI.inner_preparation_behavior(::AutoForwardDiff) = DI.PrepareInnerOverload()
3132

3233
include("utils.jl")
3334
include("onearg.jl")

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,33 @@
22
DI.overloaded_input_type(prep::ForwardDiffOneArgPushforwardPrep) = typeof(prep.xdual_tmp)
33
DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep) = typeof(prep.xdual_tmp)
44

5+
function DI.overloaded_input(
6+
::typeof(DI.pushforward),
7+
f::F,
8+
backend::AutoForwardDiff,
9+
x,
10+
tx::NTuple{B},
11+
contexts::Vararg{DI.Context,C},
12+
) where {F,B,C}
13+
T = tag_type(f, backend, x)
14+
xdual = make_dual(T, x, tx)
15+
return xdual
16+
end
17+
18+
function DI.overloaded_input(
19+
::typeof(DI.pushforward),
20+
f!::F,
21+
y,
22+
backend::AutoForwardDiff,
23+
x,
24+
tx::NTuple{B},
25+
contexts::Vararg{DI.Context,C},
26+
) where {F,B,C}
27+
T = tag_type(f!, backend, x)
28+
xdual = make_dual(T, x, tx)
29+
return xdual
30+
end
31+
532
## Derivative
633
function DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep)
734
return DI.overloaded_input_type(prep.pushforward_prep)

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ function DI.value_and_gradient!(
272272
if (
273273
isnothing(chunksize) &&
274274
T === Nothing &&
275-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
275+
contexts isa NTuple{C,DI.GeneralizedConstant}
276276
)
277277
fc = DI.with_contexts(f, contexts...)
278278
result = DiffResult(zero(eltype(x)), (grad,))
@@ -292,7 +292,7 @@ function DI.value_and_gradient(
292292
if (
293293
isnothing(chunksize) &&
294294
T === Nothing &&
295-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
295+
contexts isa NTuple{C,DI.GeneralizedConstant}
296296
)
297297
fc = DI.with_contexts(f, contexts...)
298298
result = GradientResult(x)
@@ -310,7 +310,7 @@ function DI.gradient!(
310310
if (
311311
isnothing(chunksize) &&
312312
T === Nothing &&
313-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
313+
contexts isa NTuple{C,DI.GeneralizedConstant}
314314
)
315315
fc = DI.with_contexts(f, contexts...)
316316
return gradient!(grad, fc, x)
@@ -326,7 +326,7 @@ function DI.gradient(
326326
if (
327327
isnothing(chunksize) &&
328328
T === Nothing &&
329-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
329+
contexts isa NTuple{C,DI.GeneralizedConstant}
330330
)
331331
fc = DI.with_contexts(f, contexts...)
332332
return gradient(fc, x)
@@ -435,7 +435,7 @@ function DI.value_and_jacobian!(
435435
if (
436436
isnothing(chunksize) &&
437437
T === Nothing &&
438-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
438+
contexts isa NTuple{C,DI.GeneralizedConstant}
439439
)
440440
fc = DI.with_contexts(f, contexts...)
441441
y = fc(x)
@@ -456,7 +456,7 @@ function DI.value_and_jacobian(
456456
if (
457457
isnothing(chunksize) &&
458458
T === Nothing &&
459-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
459+
contexts isa NTuple{C,DI.GeneralizedConstant}
460460
)
461461
fc = DI.with_contexts(f, contexts...)
462462
return fc(x), jacobian(fc, x)
@@ -472,7 +472,7 @@ function DI.jacobian!(
472472
if (
473473
isnothing(chunksize) &&
474474
T === Nothing &&
475-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
475+
contexts isa NTuple{C,DI.GeneralizedConstant}
476476
)
477477
fc = DI.with_contexts(f, contexts...)
478478
return jacobian!(jac, fc, x)
@@ -488,7 +488,7 @@ function DI.jacobian(
488488
if (
489489
isnothing(chunksize) &&
490490
T === Nothing &&
491-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
491+
contexts isa NTuple{C,DI.GeneralizedConstant}
492492
)
493493
fc = DI.with_contexts(f, contexts...)
494494
return jacobian(fc, x)
@@ -738,7 +738,7 @@ function DI.hessian!(
738738
if (
739739
isnothing(chunksize) &&
740740
T === Nothing &&
741-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
741+
contexts isa NTuple{C,DI.GeneralizedConstant}
742742
)
743743
fc = DI.with_contexts(f, contexts...)
744744
return hessian!(hess, fc, x)
@@ -754,7 +754,7 @@ function DI.hessian(
754754
if (
755755
isnothing(chunksize) &&
756756
T === Nothing &&
757-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
757+
contexts isa NTuple{C,DI.GeneralizedConstant}
758758
)
759759
fc = DI.with_contexts(f, contexts...)
760760
return hessian(fc, x)
@@ -775,7 +775,7 @@ function DI.value_gradient_and_hessian!(
775775
if (
776776
isnothing(chunksize) &&
777777
T === Nothing &&
778-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
778+
contexts isa NTuple{C,DI.GeneralizedConstant}
779779
)
780780
fc = DI.with_contexts(f, contexts...)
781781
result = DiffResult(one(eltype(x)), (grad, hess))
@@ -796,7 +796,7 @@ function DI.value_gradient_and_hessian(
796796
if (
797797
isnothing(chunksize) &&
798798
T === Nothing &&
799-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
799+
contexts isa NTuple{C,DI.GeneralizedConstant}
800800
)
801801
fc = DI.with_contexts(f, contexts...)
802802
result = HessianResult(x)

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ end
117117
function DI.value_and_derivative(
118118
f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C}
119119
) where {F,C,chunksize,T}
120-
if (T === Nothing && contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend})
120+
if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant})
121121
fc! = DI.with_contexts(f!, contexts...)
122122
result = MutableDiffResult(y, (similar(y),))
123123
result = derivative!(result, fc!, y, x)
@@ -131,7 +131,7 @@ end
131131
function DI.value_and_derivative!(
132132
f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C}
133133
) where {F,C,chunksize,T}
134-
if (T === Nothing && contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend})
134+
if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant})
135135
fc! = DI.with_contexts(f!, contexts...)
136136
result = MutableDiffResult(y, (der,))
137137
result = derivative!(result, fc!, y, x)
@@ -145,7 +145,7 @@ end
145145
function DI.derivative(
146146
f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C}
147147
) where {F,C,chunksize,T}
148-
if (T === Nothing && contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend})
148+
if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant})
149149
fc! = DI.with_contexts(f!, contexts...)
150150
return derivative(fc!, y, x)
151151
else
@@ -157,7 +157,7 @@ end
157157
function DI.derivative!(
158158
f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C}
159159
) where {F,C,chunksize,T}
160-
if (T === Nothing && contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend})
160+
if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant})
161161
fc! = DI.with_contexts(f!, contexts...)
162162
return derivative!(der, fc!, y, x)
163163
else
@@ -188,7 +188,7 @@ function DI.prepare!_derivative(
188188
old_prep::ForwardDiffTwoArgDerivativePrep,
189189
backend::AutoForwardDiff,
190190
x,
191-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
191+
contexts::Vararg{DI.GeneralizedConstant,C},
192192
) where {F,C}
193193
if y isa Vector
194194
(; config) = old_prep
@@ -283,7 +283,7 @@ function DI.value_and_jacobian(
283283
if (
284284
isnothing(chunksize) &&
285285
T === Nothing &&
286-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
286+
contexts isa NTuple{C,DI.GeneralizedConstant}
287287
)
288288
fc! = DI.with_contexts(f!, contexts...)
289289
jac = similar(y, length(y), length(x))
@@ -302,7 +302,7 @@ function DI.value_and_jacobian!(
302302
if (
303303
isnothing(chunksize) &&
304304
T === Nothing &&
305-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
305+
contexts isa NTuple{C,DI.GeneralizedConstant}
306306
)
307307
fc! = DI.with_contexts(f!, contexts...)
308308
result = MutableDiffResult(y, (jac,))
@@ -320,7 +320,7 @@ function DI.jacobian(
320320
if (
321321
isnothing(chunksize) &&
322322
T === Nothing &&
323-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
323+
contexts isa NTuple{C,DI.GeneralizedConstant}
324324
)
325325
fc! = DI.with_contexts(f!, contexts...)
326326
return jacobian(fc!, y, x)
@@ -336,7 +336,7 @@ function DI.jacobian!(
336336
if (
337337
isnothing(chunksize) &&
338338
T === Nothing &&
339-
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
339+
contexts isa NTuple{C,DI.GeneralizedConstant}
340340
)
341341
fc! = DI.with_contexts(f!, contexts...)
342342
return jacobian!(jac, fc!, y, x)
@@ -369,7 +369,7 @@ function DI.prepare!_jacobian(
369369
old_prep::ForwardDiffTwoArgJacobianPrep,
370370
backend::AutoForwardDiff,
371371
x,
372-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
372+
contexts::Vararg{DI.GeneralizedConstant,C},
373373
) where {F,C}
374374
if x isa Vector && y isa Vector
375375
(; config) = old_prep

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ struct PrepContext{T<:DI.Prep} <: DI.Context
8787
data::T
8888
end
8989

90-
NotCache = Union{DI.ConstantOrFunctionOrBackend,PrepContext}
90+
NotCache = Union{DI.GeneralizedConstant,PrepContext}
9191

9292
_translate(::Type{D}, c::NotCache) where {D<:Dual} = DI.unwrap(c)
9393
function _translate(::Type{D}, c::DI.Cache) where {D<:Dual}

DifferentiationInterface/ext/DifferentiationInterfaceTrackerExt/DifferentiationInterfaceTrackerExt.jl

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ struct TrackerPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
1515
end
1616

1717
function DI.prepare_pullback(
18-
f, ::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}
18+
f, ::AutoTracker, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}
1919
) where {C}
2020
return DI.NoPullbackPrep()
2121
end
@@ -26,7 +26,7 @@ function DI.prepare_pullback_same_point(
2626
::AutoTracker,
2727
x,
2828
ty::NTuple,
29-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
29+
contexts::Vararg{DI.GeneralizedConstant,C},
3030
) where {C}
3131
y, pb = forward(f, x, map(DI.unwrap, contexts)...)
3232
return TrackerPullbackPrepSamePoint(y, pb)
@@ -38,7 +38,7 @@ function DI.value_and_pullback(
3838
::AutoTracker,
3939
x,
4040
ty::NTuple,
41-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
41+
contexts::Vararg{DI.GeneralizedConstant,C},
4242
) where {C}
4343
y, pb = forward(f, x, map(DI.unwrap, contexts)...)
4444
tx = map(ty) do dy
@@ -53,7 +53,7 @@ function DI.value_and_pullback(
5353
::AutoTracker,
5454
x,
5555
ty::NTuple,
56-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
56+
contexts::Vararg{DI.GeneralizedConstant,C},
5757
) where {C}
5858
(; y, pb) = prep
5959
tx = map(ty) do dy
@@ -68,7 +68,7 @@ function DI.pullback(
6868
::AutoTracker,
6969
x,
7070
ty::NTuple,
71-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
71+
contexts::Vararg{DI.GeneralizedConstant,C},
7272
) where {C}
7373
(; pb) = prep
7474
tx = map(ty) do dy
@@ -80,28 +80,20 @@ end
8080
## Gradient
8181

8282
function DI.prepare_gradient(
83-
f, ::AutoTracker, x, contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}
83+
f, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C}
8484
) where {C}
8585
return DI.NoGradientPrep()
8686
end
8787

8888
function DI.value_and_gradient(
89-
f,
90-
::DI.NoGradientPrep,
91-
::AutoTracker,
92-
x,
93-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
89+
f, ::DI.NoGradientPrep, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C}
9490
) where {C}
9591
(; val, grad) = withgradient(f, x, map(DI.unwrap, contexts)...)
9692
return val, data(first(grad))
9793
end
9894

9995
function DI.gradient(
100-
f,
101-
::DI.NoGradientPrep,
102-
::AutoTracker,
103-
x,
104-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
96+
f, ::DI.NoGradientPrep, ::AutoTracker, x, contexts::Vararg{DI.GeneralizedConstant,C}
10597
) where {C}
10698
(; grad) = withgradient(f, x, map(DI.unwrap, contexts)...)
10799
return data(first(grad))
@@ -113,7 +105,7 @@ function DI.value_and_gradient!(
113105
prep::DI.NoGradientPrep,
114106
backend::AutoTracker,
115107
x,
116-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
108+
contexts::Vararg{DI.GeneralizedConstant,C},
117109
) where {C}
118110
y, new_grad = DI.value_and_gradient(f, prep, backend, x, contexts...)
119111
return y, copyto!(grad, new_grad)
@@ -125,7 +117,7 @@ function DI.gradient!(
125117
prep::DI.NoGradientPrep,
126118
backend::AutoTracker,
127119
x,
128-
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
120+
contexts::Vararg{DI.GeneralizedConstant,C},
129121
) where {C}
130122
return copyto!(grad, DI.gradient(f, prep, backend, x, contexts...))
131123
end

0 commit comments

Comments
 (0)