"""
    ForwardDiffOverSomethingHVPWrapper{F}

Wrapper storing a function in its single field `f`.

It is passed as the first argument of `ForwardDiff.Tag` when `tag_backend_hvp` builds a tag linked to `f`.
"""
struct ForwardDiffOverSomethingHVPWrapper{F}
    f::F
end
4-
"""
    tag_backend_hvp(f, ::AutoForwardDiff, x)

Return a new `AutoForwardDiff` backend with a fixed tag linked to `f`, so that we know how to prepare the inner gradient of the HVP without depending on what that gradient closure looks like.

Only backends matching `AutoForwardDiff{chunksize,Nothing}` get a new tag; any other `AutoForwardDiff` backend is returned unchanged.
"""
function tag_backend_hvp(f, backend::AutoForwardDiff, x)
    # Fallback method: hand the backend back as it is.
    return backend
end

function tag_backend_hvp(
    f::F, ::AutoForwardDiff{chunksize,Nothing}, x
) where {F,chunksize}
    # The tag is built from a wrapper around `f` and the element type of `x`.
    wrapped_f = ForwardDiffOverSomethingHVPWrapper(f)
    hvp_tag = ForwardDiff.Tag(wrapped_f, eltype(x))
    return AutoForwardDiff{chunksize,typeof(hvp_tag)}(hvp_tag)
end
16-
"""
    ForwardDiffOverSomethingHVPPrep{E1<:GradientPrep,E2<:PushforwardPrep}

HVP preparation object storing two sub-preparations:

- `inner_gradient_prep`: the result of `DI.prepare_gradient` for the inner backend.
- `outer_pushforward_prep`: the result of `DI.prepare_pushforward` for the outer backend.
"""
struct ForwardDiffOverSomethingHVPPrep{E1<:GradientPrep,E2<:PushforwardPrep} <: HVPPrep
    inner_gradient_prep::E1
    outer_pushforward_prep::E2
end
225
# Prepare Hessian-vector products for a `SecondOrder` backend whose outer part is
# `AutoForwardDiff`.
#
# Two preparations are performed and stored in a `ForwardDiffOverSomethingHVPPrep`:
#   1. the gradient of `f` with the inner backend, evaluated at `make_dual(T, x, tx)`;
#   2. the pushforward of `shuffled_gradient` with the outer backend at `x`, `tx`.
#
# NOTE(review): the parameters `f`, `backend` and `x` were elided by the diff hunk
# boundary and are reconstructed from the body and the sibling methods; confirm
# against the full file.
function DI.prepare_hvp(
    f::F,
    backend::SecondOrder{<:AutoForwardDiff},
    x,
    tx::NTuple,
    contexts::Vararg{Context,C},
) where {F,C}
    outer_backend = outer(backend)
    inner_backend = inner(backend)
    T = tag_type(shuffled_gradient, outer_backend, x)
    xdual = make_dual(T, x, tx)
    inner_gradient_prep = DI.prepare_gradient(f, inner_backend, xdual, contexts...)
    rewrap = Rewrap(contexts...)
    # Everything `shuffled_gradient` receives besides `x` is passed as contexts;
    # the user-provided contexts come last.
    shuffled_contexts = (
        Constant(f),
        PrepContext(inner_gradient_prep),
        Constant(inner_backend),
        Constant(rewrap),
        contexts...,
    )
    outer_pushforward_prep = DI.prepare_pushforward(
        shuffled_gradient, outer_backend, x, tx, shuffled_contexts...
    )
    return ForwardDiffOverSomethingHVPPrep(inner_gradient_prep, outer_pushforward_prep)
end
4729
# Hessian-vector product using a `ForwardDiffOverSomethingHVPPrep`.
#
# Returns the result of `DI.pushforward` applied to `shuffled_gradient` with the outer
# backend at `x` in the directions `tx`.
function DI.hvp(
    f::F,
    prep::ForwardDiffOverSomethingHVPPrep,
    backend::SecondOrder{<:AutoForwardDiff},
    x,
    tx::NTuple,
    contexts::Vararg{Context,C},
) where {F,C}
    grad_prep = prep.inner_gradient_prep
    pf_prep = prep.outer_pushforward_prep
    rewrap = Rewrap(contexts...)
    # Same context layout as the tuple built in `DI.prepare_hvp` when
    # `outer_pushforward_prep` was created.
    shuffled_contexts = (
        Constant(f),
        PrepContext(grad_prep),
        Constant(inner(backend)),
        Constant(rewrap),
        contexts...,
    )
    return DI.pushforward(
        shuffled_gradient, pf_prep, outer(backend), x, tx, shuffled_contexts...
    )
end
6151
# In-place Hessian-vector product using a `ForwardDiffOverSomethingHVPPrep`.
#
# Calls `DI.pushforward!` on `shuffled_gradient` with the outer backend, passing `tg`
# as the output argument, and returns whatever `DI.pushforward!` returns.
#
# The previous version had a second `return tg` after the `return DI.pushforward!(...)`
# statement; it could never execute and has been removed.
function DI.hvp!(
    f::F,
    tg::NTuple,
    prep::ForwardDiffOverSomethingHVPPrep,
    backend::SecondOrder{<:AutoForwardDiff},
    x,
    tx::NTuple,
    contexts::Vararg{Context,C},
) where {F,C}
    (; inner_gradient_prep, outer_pushforward_prep) = prep
    rewrap = Rewrap(contexts...)
    # Same context layout as the tuple built in `DI.prepare_hvp` when
    # `outer_pushforward_prep` was created.
    new_contexts = (
        Constant(f),
        PrepContext(inner_gradient_prep),
        Constant(inner(backend)),
        Constant(rewrap),
        contexts...,
    )
    return DI.pushforward!(
        shuffled_gradient,
        tg,
        outer_pushforward_prep,
        outer(backend),
        x,
        tx,
        new_contexts...,
    )
end
7781
# Combined evaluation using a `ForwardDiffOverSomethingHVPPrep`.
#
# Returns the result of `DI.value_and_pushforward` applied to `shuffled_gradient` with
# the outer backend at `x` in the directions `tx`.
# NOTE(review): the "value" part is presumably the gradient of `f`, as the function
# name suggests; `shuffled_gradient` is defined elsewhere, so confirm there.
function DI.gradient_and_hvp(
    f::F,
    prep::ForwardDiffOverSomethingHVPPrep,
    backend::SecondOrder{<:AutoForwardDiff},
    x,
    tx::NTuple,
    contexts::Vararg{Context,C},
) where {F,C}
    grad_prep = prep.inner_gradient_prep
    pf_prep = prep.outer_pushforward_prep
    rewrap = Rewrap(contexts...)
    # Same context layout as the tuple built in `DI.prepare_hvp` when
    # `outer_pushforward_prep` was created.
    shuffled_contexts = (
        Constant(f),
        PrepContext(grad_prep),
        Constant(inner(backend)),
        Constant(rewrap),
        contexts...,
    )
    return DI.value_and_pushforward(
        shuffled_gradient, pf_prep, outer(backend), x, tx, shuffled_contexts...
    )
end
91103
# In-place combined evaluation using a `ForwardDiffOverSomethingHVPPrep`.
#
# Calls `DI.value_and_pushforward!` on `shuffled_gradient` with `tg` as the output
# argument, copies the first returned element into `grad`, and returns `(grad, tg)`.
#
# NOTE(review): the leading parameter `f::F` was elided by the diff hunk boundary and
# is reconstructed from the `where {F,C}` clause and the sibling methods; confirm
# against the full file.
function DI.gradient_and_hvp!(
    f::F,
    grad,
    tg::NTuple,
    prep::ForwardDiffOverSomethingHVPPrep,
    backend::SecondOrder{<:AutoForwardDiff},
    x,
    tx::NTuple,
    contexts::Vararg{Context,C},
) where {F,C}
    grad_prep = prep.inner_gradient_prep
    pf_prep = prep.outer_pushforward_prep
    rewrap = Rewrap(contexts...)
    # Same context layout as the tuple built in `DI.prepare_hvp` when
    # `outer_pushforward_prep` was created.
    shuffled_contexts = (
        Constant(f),
        PrepContext(grad_prep),
        Constant(inner(backend)),
        Constant(rewrap),
        contexts...,
    )
    # Only the first element of the result is used; the caller's `tg` is returned.
    new_grad, _ = DI.value_and_pushforward!(
        shuffled_gradient, tg, pf_prep, outer(backend), x, tx, shuffled_contexts...
    )
    return copyto!(grad, new_grad), tg
end
0 commit comments