@@ -58,123 +58,126 @@ struct ReverseOverReverseHVPExtras{G<:Gradient,E<:PullbackExtras} <: HVPExtras
5858end
5959
# prepare_hvp(f, backend, x, tx): entry point for preparation. After this change the
# method taking any `AbstractADType` calls `_prepare_hvp_aux` directly, passing
# `hvp_mode(backend)` as the last argument so the matching mode-specific method is chosen.
# Diff note: the removed lines wrapped `backend` as `SecondOrder(backend, backend)` and
# forwarded to a separate `SecondOrder` method; that two-step indirection is dropped.
6060function prepare_hvp (f:: F , backend:: AbstractADType , x, tx:: Tangents ) where {F}
61- return prepare_hvp (f, SecondOrder (backend, backend), x, tx)
62- end
63-
64- function prepare_hvp (f:: F , backend:: SecondOrder , x, tx:: Tangents ) where {F}
6561 return _prepare_hvp_aux (f, backend, x, tx, hvp_mode (backend))
6662end
6763
# _prepare_hvp_aux for the `ForwardOverForward` mode.
# Builds a `Gradient` wrapper around `f` using the nested inner backend, prepares a
# pushforward of that wrapper with the outer backend, and returns both packed in a
# `ForwardOverForwardHVPExtras`.
# Diff note: the `backend` annotation is widened from `SecondOrder` to `AbstractADType`,
# and `inner`/`outer` become `maybe_inner`/`maybe_outer`.
# NOTE(review): `maybe_inner`/`maybe_outer` are not defined in this view; presumably they
# return the backend itself when it is not a `SecondOrder` — confirm against their definitions.
6864function _prepare_hvp_aux (
69- f:: F , backend:: SecondOrder , x, tx:: Tangents , :: ForwardOverForward
65+ f:: F , backend:: AbstractADType , x, tx:: Tangents , :: ForwardOverForward
7066) where {F}
7167 # pushforward of many pushforwards in theory, but pushforward of gradient in practice
72- inner_gradient = Gradient (f, nested (inner (backend)))
73- outer_pushforward_extras = prepare_pushforward (inner_gradient, outer (backend), x, tx)
68+ inner_gradient = Gradient (f, nested (maybe_inner (backend)))
69+ outer_pushforward_extras = prepare_pushforward (
70+ inner_gradient, maybe_outer (backend), x, tx
71+ )
7472 return ForwardOverForwardHVPExtras (inner_gradient, outer_pushforward_extras)
7573end
7674
# _prepare_hvp_aux for the `ForwardOverReverse` mode.
# Same shape as the other pushforward-based preparation: wrap `f` in a `Gradient` built
# on the nested inner backend, prepare a pushforward of it with the outer backend, and
# return a `ForwardOverReverseHVPExtras` holding both.
# Diff note: `backend` is now any `AbstractADType`; `inner`/`outer` are replaced by
# `maybe_inner`/`maybe_outer` (defined outside this view — semantics to be confirmed).
7775function _prepare_hvp_aux (
78- f:: F , backend:: SecondOrder , x, tx:: Tangents , :: ForwardOverReverse
76+ f:: F , backend:: AbstractADType , x, tx:: Tangents , :: ForwardOverReverse
7977) where {F}
8078 # pushforward of gradient
81- inner_gradient = Gradient (f, nested (inner (backend)))
82- outer_pushforward_extras = prepare_pushforward (inner_gradient, outer (backend), x, tx)
79+ inner_gradient = Gradient (f, nested (maybe_inner (backend)))
80+ outer_pushforward_extras = prepare_pushforward (
81+ inner_gradient, maybe_outer (backend), x, tx
82+ )
8383 return ForwardOverReverseHVPExtras (inner_gradient, outer_pushforward_extras)
8484end
8585
# _prepare_hvp_aux for the `ReverseOverForward` mode.
# Does no preparation work: it ignores `f`, `backend`, `x` and `tx` and returns an empty
# `ReverseOverForwardHVPExtras()`. The existing comment below gives the reason.
# Diff note: only the `backend` annotation changes (`SecondOrder` -> `AbstractADType`).
8686function _prepare_hvp_aux (
87- f:: F , backend:: SecondOrder , x, tx:: Tangents , :: ReverseOverForward
87+ f:: F , backend:: AbstractADType , x, tx:: Tangents , :: ReverseOverForward
8888) where {F}
8989 # gradient of pushforward
9090 # uses dx in the closure so it can't be prepared
9191 return ReverseOverForwardHVPExtras ()
9292end
9393
# _prepare_hvp_aux for the `ReverseOverReverse` mode.
# Wraps `f` in a `Gradient` built on the nested inner backend, prepares a pullback of
# that wrapper with the outer backend, and returns a `ReverseOverReverseHVPExtras`
# holding both.
# Diff note: `backend` is now any `AbstractADType`; `inner`/`outer` are replaced by
# `maybe_inner`/`maybe_outer` (defined outside this view — semantics to be confirmed).
9494function _prepare_hvp_aux (
95- f:: F , backend:: SecondOrder , x, tx:: Tangents , :: ReverseOverReverse
95+ f:: F , backend:: AbstractADType , x, tx:: Tangents , :: ReverseOverReverse
9696) where {F}
9797 # pullback of gradient
98- inner_gradient = Gradient (f, nested (inner (backend)))
99- outer_pullback_extras = prepare_pullback (inner_gradient, outer (backend), x, tx)
98+ inner_gradient = Gradient (f, nested (maybe_inner (backend)))
99+ outer_pullback_extras = prepare_pullback (inner_gradient, maybe_outer (backend), x, tx)
100100 return ReverseOverReverseHVPExtras (inner_gradient, outer_pullback_extras)
101101end
102102
103103# # One argument
104104
# Removed by this diff: generic `hvp` fallback that wrapped any `AbstractADType` as
# `SecondOrder(backend, backend)` and re-dispatched. The mode-specific `hvp` methods
# now accept `AbstractADType` directly, so this forwarding method is deleted.
105- function hvp (f:: F , extras:: HVPExtras , backend:: AbstractADType , x, tx:: Tangents ) where {F}
106- return hvp (f, extras, SecondOrder (backend, backend), x, tx)
107- end
108-
# hvp for `ForwardOverForwardHVPExtras`.
# Unpacks the prepared `inner_gradient` and `outer_pushforward_extras` from `extras` and
# returns the pushforward of `inner_gradient` at `x` along `tx`, using the outer backend.
# `f` itself is not used here; it is already captured inside `inner_gradient` at
# preparation time.
# Diff note: `backend` widened to `AbstractADType`; `outer` replaced by `maybe_outer`.
109105function hvp (
110- f:: F , extras:: ForwardOverForwardHVPExtras , backend:: SecondOrder , x, tx:: Tangents
106+ f:: F , extras:: ForwardOverForwardHVPExtras , backend:: AbstractADType , x, tx:: Tangents
111107) where {F}
112108 @compat (; inner_gradient, outer_pushforward_extras) = extras
113- return pushforward (inner_gradient, outer_pushforward_extras, outer (backend), x, tx)
109+ return pushforward (
110+ inner_gradient, outer_pushforward_extras, maybe_outer (backend), x, tx
111+ )
114112end
115113
# hvp for `ForwardOverReverseHVPExtras`.
# Unpacks `inner_gradient` and `outer_pushforward_extras` from `extras` and returns the
# pushforward of `inner_gradient` at `x` along `tx`, using the outer backend.
# Diff note: `backend` widened to `AbstractADType`; `outer` replaced by `maybe_outer`.
116114function hvp (
117- f:: F , extras:: ForwardOverReverseHVPExtras , backend:: SecondOrder , x, tx:: Tangents
115+ f:: F , extras:: ForwardOverReverseHVPExtras , backend:: AbstractADType , x, tx:: Tangents
118116) where {F}
119117 @compat (; inner_gradient, outer_pushforward_extras) = extras
120- return pushforward (inner_gradient, outer_pushforward_extras, outer (backend), x, tx)
118+ return pushforward (
119+ inner_gradient, outer_pushforward_extras, maybe_outer (backend), x, tx
120+ )
121121end
122122
# hvp for `ReverseOverForwardHVPExtras` (the extras argument carries nothing and is unnamed).
# For each `dx` in `tx.d`, builds a `PushforwardFixedSeed` of `f` seeded with
# `Tangents(dx)` on the nested inner backend, then takes the gradient of
# `only ∘ inner_pushforward` at `x` with the outer backend. The per-seed gradients are
# collected and splatted into a new `Tangents`.
# Diff note: `backend` widened to `AbstractADType`; `inner`/`outer` replaced by
# `maybe_inner`/`maybe_outer`.
123123function hvp (
124- f:: F , :: ReverseOverForwardHVPExtras , backend:: SecondOrder , x, tx:: Tangents
124+ f:: F , :: ReverseOverForwardHVPExtras , backend:: AbstractADType , x, tx:: Tangents
125125) where {F}
126126 dgs = map (tx. d) do dx
127- inner_pushforward = PushforwardFixedSeed (f, nested (inner (backend)), Tangents (dx))
128- gradient (only ∘ inner_pushforward, outer (backend), x)
127+ inner_pushforward = PushforwardFixedSeed (f, nested (maybe_inner (backend)), Tangents (dx))
128+ gradient (only ∘ inner_pushforward, maybe_outer (backend), x)
129129 end
130130 return Tangents (dgs... )
131131end
132132
# hvp for `ReverseOverReverseHVPExtras`.
# Unpacks `inner_gradient` and `outer_pullback_extras` from `extras` and returns the
# pullback of `inner_gradient` at `x` along `tx`, using the outer backend.
# Diff note: besides the `SecondOrder` -> `AbstractADType` / `outer` -> `maybe_outer`
# change, the removed lines in the middle of this hunk also delete the generic `hvp!`
# fallback that wrapped `backend` as `SecondOrder(backend, backend)`.
133133function hvp (
134- f:: F , extras:: ReverseOverReverseHVPExtras , backend:: SecondOrder , x, tx:: Tangents
134+ f:: F , extras:: ReverseOverReverseHVPExtras , backend:: AbstractADType , x, tx:: Tangents
135135) where {F}
136136 @compat (; inner_gradient, outer_pullback_extras) = extras
137- return pullback (inner_gradient, outer_pullback_extras, outer (backend), x, tx)
138- end
139-
140- function hvp! (
141- f:: F , tg:: Tangents , extras:: HVPExtras , backend:: AbstractADType , x, tx:: Tangents
142- ) where {F}
143- return hvp! (f, tg, extras, SecondOrder (backend, backend), x, tx)
137+ return pullback (inner_gradient, outer_pullback_extras, maybe_outer (backend), x, tx)
144138end
145139
# hvp! for `ForwardOverForwardHVPExtras`.
# Unpacks `inner_gradient` and `outer_pushforward_extras` from `extras` and delegates to
# `pushforward!` with `tg` as the destination argument, using the outer backend.
# Returns whatever `pushforward!` returns.
# Diff note: `backend` widened to `AbstractADType`; `outer` replaced by `maybe_outer`.
146140function hvp! (
147141 f:: F ,
148142 tg:: Tangents ,
149143 extras:: ForwardOverForwardHVPExtras ,
150- backend:: SecondOrder ,
144+ backend:: AbstractADType ,
151145 x,
152146 tx:: Tangents ,
153147) where {F}
154148 @compat (; inner_gradient, outer_pushforward_extras) = extras
155- return pushforward! (inner_gradient, tg, outer_pushforward_extras, outer (backend), x, tx)
149+ return pushforward! (
150+ inner_gradient, tg, outer_pushforward_extras, maybe_outer (backend), x, tx
151+ )
156152end
157153
# hvp! for `ForwardOverReverseHVPExtras`.
# Unpacks `inner_gradient` and `outer_pushforward_extras` from `extras` and delegates to
# `pushforward!` with `tg` as the destination argument, using the outer backend.
# Returns whatever `pushforward!` returns.
# Diff note: `backend` widened to `AbstractADType`; `outer` replaced by `maybe_outer`.
158154function hvp! (
159155 f:: F ,
160156 tg:: Tangents ,
161157 extras:: ForwardOverReverseHVPExtras ,
162- backend:: SecondOrder ,
158+ backend:: AbstractADType ,
163159 x,
164160 tx:: Tangents ,
165161) where {F}
166162 @compat (; inner_gradient, outer_pushforward_extras) = extras
167- return pushforward! (inner_gradient, tg, outer_pushforward_extras, outer (backend), x, tx)
163+ return pushforward! (
164+ inner_gradient, tg, outer_pushforward_extras, maybe_outer (backend), x, tx
165+ )
168166end
169167
# hvp! for `ReverseOverForwardHVPExtras` (the extras argument carries nothing and is unnamed).
# Iterates over the shared indices of `tx.d` and `tg.d`. For each index `b` it builds a
# `PushforwardFixedSeed` of `f` seeded with `Tangents(tx.d[b])` on the nested inner
# backend, then calls `gradient!` on `only ∘ inner_pushforward` at `x` with the outer
# backend, passing `tg.d[b]` as the destination. Returns `tg`.
# Diff note: the one-line signature is reformatted to one argument per line, `backend`
# is widened to `AbstractADType`, and `inner`/`outer` become `maybe_inner`/`maybe_outer`.
170168function hvp! (
171- f:: F , tg:: Tangents , :: ReverseOverForwardHVPExtras , backend:: SecondOrder , x, tx:: Tangents
169+ f:: F ,
170+ tg:: Tangents ,
171+ :: ReverseOverForwardHVPExtras ,
172+ backend:: AbstractADType ,
173+ x,
174+ tx:: Tangents ,
172175) where {F}
173176 for b in eachindex (tx. d, tg. d)
174177 inner_pushforward = PushforwardFixedSeed (
175- f, nested (inner (backend)), Tangents (tx. d[b])
178+ f, nested (maybe_inner (backend)), Tangents (tx. d[b])
176179 )
177- gradient! (only ∘ inner_pushforward, tg. d[b], outer (backend), x)
180+ gradient! (only ∘ inner_pushforward, tg. d[b], maybe_outer (backend), x)
178181 end
179182 return tg
180183end
@@ -183,10 +186,10 @@ function hvp!(
183186 f:: F ,
184187 tg:: Tangents ,
185188 extras:: ReverseOverReverseHVPExtras ,
186- backend:: SecondOrder ,
189+ backend:: AbstractADType ,
187190 x,
188191 tx:: Tangents ,
189192) where {F}
190193 @compat (; inner_gradient, outer_pullback_extras) = extras
191- return pullback! (inner_gradient, tg, outer_pullback_extras, outer (backend), x, tx)
194+ return pullback! (inner_gradient, tg, outer_pullback_extras, maybe_outer (backend), x, tx)
192195end
0 commit comments