@@ -40,19 +40,19 @@ function hvp! end
4040
4141# # Preparation
4242
43- struct ForwardOverForwardHVPExtras{G<: Gradient ,E<: PushforwardExtras } <: HVPExtras
43+ struct ForwardOverForwardHVPExtras{G,E<: PushforwardExtras } <: HVPExtras
4444 inner_gradient:: G
4545 outer_pushforward_extras:: E
4646end
4747
48- struct ForwardOverReverseHVPExtras{G<: Gradient ,E<: PushforwardExtras } <: HVPExtras
48+ struct ForwardOverReverseHVPExtras{G,E<: PushforwardExtras } <: HVPExtras
4949 inner_gradient:: G
5050 outer_pushforward_extras:: E
5151end
5252
5353struct ReverseOverForwardHVPExtras <: HVPExtras end
5454
55- struct ReverseOverReverseHVPExtras{G<: Gradient ,E<: PullbackExtras } <: HVPExtras
55+ struct ReverseOverReverseHVPExtras{G,E<: PullbackExtras } <: HVPExtras
5656 inner_gradient:: G
5757 outer_pullback_extras:: E
5858end
@@ -65,7 +65,7 @@ function _prepare_hvp_aux(
6565 f:: F , backend:: AbstractADType , x, tx:: Tangents , :: ForwardOverForward
6666) where {F}
6767 # pushforward of many pushforwards in theory, but pushforward of gradient in practice
68- inner_gradient = Gradient (f, nested (maybe_inner (backend)))
68+ inner_gradient (x) = gradient (f, nested (maybe_inner (backend)), x )
6969 outer_pushforward_extras = prepare_pushforward (
7070 inner_gradient, maybe_outer (backend), x, tx
7171 )
@@ -76,7 +76,7 @@ function _prepare_hvp_aux(
7676 f:: F , backend:: AbstractADType , x, tx:: Tangents , :: ForwardOverReverse
7777) where {F}
7878 # pushforward of gradient
79- inner_gradient = Gradient (f, nested (maybe_inner (backend)))
79+ inner_gradient (x) = gradient (f, nested (maybe_inner (backend)), x )
8080 outer_pushforward_extras = prepare_pushforward (
8181 inner_gradient, maybe_outer (backend), x, tx
8282 )
@@ -95,7 +95,7 @@ function _prepare_hvp_aux(
9595 f:: F , backend:: AbstractADType , x, tx:: Tangents , :: ReverseOverReverse
9696) where {F}
9797 # pullback of gradient
98- inner_gradient = Gradient (f, nested (maybe_inner (backend)))
98+ inner_gradient (x) = gradient (f, nested (maybe_inner (backend)), x )
9999 outer_pullback_extras = prepare_pullback (inner_gradient, maybe_outer (backend), x, tx)
100100 return ReverseOverReverseHVPExtras (inner_gradient, outer_pullback_extras)
101101end
@@ -123,11 +123,13 @@ end
123123function hvp (
124124 f:: F , :: ReverseOverForwardHVPExtras , backend:: AbstractADType , x, tx:: Tangents
125125) where {F}
126- dgs = map (tx. d) do dx
127- inner_pushforward = PushforwardFixedSeed (f, nested (maybe_inner (backend)), Tangents (dx))
126+ tg = map (tx) do dx
127+ function inner_pushforward (x)
128+ return only (pushforward (f, nested (maybe_inner (backend)), x, Tangents (dx)))
129+ end
128130 gradient (only ∘ inner_pushforward, maybe_outer (backend), x)
129131 end
130- return Tangents (dgs ... )
132+ return tg
131133end
132134
133135function hvp (
@@ -174,9 +176,9 @@ function hvp!(
174176 tx:: Tangents ,
175177) where {F}
176178 for b in eachindex (tx. d, tg. d)
177- inner_pushforward = PushforwardFixedSeed (
178- f, nested (maybe_inner (backend)), Tangents (tx. d[b])
179- )
179+ function inner_pushforward (x)
180+ return only ( pushforward ( f, nested (maybe_inner (backend)), x, Tangents (tx. d[b])) )
181+ end
180182 gradient! (only ∘ inner_pushforward, tg. d[b], maybe_outer (backend), x)
181183 end
182184 return tg
0 commit comments