11# # Pushforward
22
# Preparation result for the allocating pushforward.
#
# Fields:
# - `y_prototype`: the value `f(x)` computed in `prepare_pushforward`; `pushforward`
#   checks `isa Number` / `size` on it to decide how to shape its result.
# - `jvp_exe`: the callable built by `make_function` from the `jacobian_times_v`
#   expression (out-of-place, `in_place=false`).
struct FastDifferentiationAllocatingPushforwardExtras{Y,E} <: PushforwardExtras
    y_prototype::Y
    jvp_exe::E
end
67
78function DI. prepare_pushforward (f, :: AnyAutoFastDifferentiation , x)
9+ y_prototype = f (x)
810 x_var = if x isa Number
911 only (make_variables (:x ))
1012 else
@@ -16,7 +18,7 @@ function DI.prepare_pushforward(f, ::AnyAutoFastDifferentiation, x)
1618 y_vec_var = y_var isa Number ? [y_var] : vec (y_var)
1719 jv_vec_var, v_vec_var = jacobian_times_v (y_vec_var, x_vec_var)
1820 jvp_exe = make_function (jv_vec_var, [x_vec_var; v_vec_var]; in_place= false )
19- return FastDifferentiationAllocatingPushforwardExtras (jvp_exe)
21+ return FastDifferentiationAllocatingPushforwardExtras (y_prototype, jvp_exe)
2022end
2123
2224function DI. value_and_pushforward (
@@ -36,6 +38,22 @@ function DI.value_and_pushforward(
3638 end
3739end
3840
# Pushforward (Jacobian-vector product) using a prepared executable.
#
# `x` and `dx` are flattened with `myvec` and concatenated into the single input
# vector that `extras.jvp_exe` is called with. The flat result is returned as a
# scalar (via `only`) when the stored `y_prototype` is a `Number`, and otherwise
# reshaped to `size(extras.y_prototype)`. `f` is not called here.
function DI.pushforward(
    f,
    ::AnyAutoFastDifferentiation,
    x,
    dx,
    extras::FastDifferentiationAllocatingPushforwardExtras,
)
    proto = extras.y_prototype
    flat_jvp = extras.jvp_exe(vcat(myvec(x), myvec(dx)))
    # Scalar output: the executable yields a one-element vector.
    proto isa Number && return only(flat_jvp)
    return reshape(flat_jvp, size(proto))
end
56+
3957# # Pullback
4058
4159# TODO : this only fails for scalar -> matrix, not sure why
82100
83101# # Derivative
84102
# Preparation result for the allocating derivative.
#
# Fields:
# - `y_prototype`: the value `f(x)` computed in `prepare_derivative`; `derivative`
#   checks `isa Number` / `size` on it to decide how to shape its result.
# - `der_exe`: the callable built by `make_function` from the symbolic derivative
#   (out-of-place, `in_place=false`).
struct FastDifferentiationAllocatingDerivativeExtras{Y,E} <: DerivativeExtras
    y_prototype::Y
    der_exe::E
end
88107
# Build the derivative executable for a function of one scalar variable.
#
# `f` is evaluated twice: once on the actual `x` to record `y_prototype`, and once
# on a symbolic variable to obtain the expression that is differentiated.
# Returns a `FastDifferentiationAllocatingDerivativeExtras(y_prototype, der_exe)`.
function DI.prepare_derivative(f, ::AnyAutoFastDifferentiation, x)
    # Concrete output, kept so later calls can shape results without calling `f`.
    y_prototype = f(x)

    x_var = only(make_variables(:x))
    y_var = f(x_var)

    x_vec_var = [x_var]
    # A scalar symbolic output is wrapped so `derivative` always receives a vector.
    y_vec_var = y_var isa Number ? [y_var] : vec(y_var)
    der_vec_var = derivative(y_vec_var, x_var)
    der_exe = make_function(der_vec_var, x_vec_var; in_place=false)
    return FastDifferentiationAllocatingDerivativeExtras(y_prototype, der_exe)
end
99119
100120function DI. value_and_derivative (
@@ -128,7 +148,12 @@ function DI.derivative(
128148 x,
129149 extras:: FastDifferentiationAllocatingDerivativeExtras ,
130150)
131- return DI. value_and_derivative (f, backend, x, extras)[2 ]
151+ der_vec = extras. der_exe ([x])
152+ if extras. y_prototype isa Number
153+ return only (der_vec)
154+ else
155+ return reshape (der_vec, size (extras. y_prototype))
156+ end
132157end
133158
134159function DI. derivative!! (
@@ -143,20 +168,25 @@ end
143168
144169# # Jacobian
145170
# Preparation result for the allocating Jacobian.
#
# Fields:
# - `y_prototype`: the value `f(x)` computed in `prepare_jacobian`.
# - `jac_exe`: the callable built by `make_function` from the symbolic Jacobian
#   (dense or sparse, depending on `issparse(backend)` at preparation time).
struct FastDifferentiationAllocatingJacobianExtras{Y,E} <: JacobianExtras
    y_prototype::Y
    jac_exe::E
end
149175
# Build the Jacobian executable for `f` at inputs shaped like `x`.
#
# `f` is evaluated on the actual `x` to record `y_prototype`, and on an array of
# symbolic variables with `size(x)` to obtain the expression to differentiate.
# Input and output are flattened once with `vec`; the flattened forms are reused
# for both the Jacobian construction and `make_function`.
# Returns a `FastDifferentiationAllocatingJacobianExtras(y_prototype, jac_exe)`.
function DI.prepare_jacobian(f, backend::AnyAutoFastDifferentiation, x)
    y_prototype = f(x)
    x_var = make_variables(:x, size(x)...)
    y_var = f(x_var)

    x_vec_var = vec(x_var)
    y_vec_var = vec(y_var)
    # The backend decides between the sparse and the dense symbolic Jacobian.
    jac_var = if issparse(backend)
        sparse_jacobian(y_vec_var, x_vec_var)
    else
        jacobian(y_vec_var, x_vec_var)
    end
    jac_exe = make_function(jac_var, x_vec_var; in_place=false)
    return FastDifferentiationAllocatingJacobianExtras(y_prototype, jac_exe)
end
161191
162192function DI. jacobian (
@@ -199,19 +229,21 @@ end
199229
200230# # Second derivative
201231
# Preparation result for the allocating second derivative.
#
# Subtypes `SecondDerivativeExtras` (not `DerivativeExtras`).
#
# Fields:
# - `y_prototype`: the value `f(x)` computed in `prepare_second_derivative`;
#   `second_derivative` checks `isa Number` / `size` on it to shape its result.
# - `der2_exe`: the callable built by `make_function` from the symbolic second
#   derivative (out-of-place, `in_place=false`).
struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,E} <: SecondDerivativeExtras
    y_prototype::Y
    der2_exe::E
end
205236
# Build the second-derivative executable for a function of one scalar variable.
#
# `f` is evaluated on the actual `x` to record `y_prototype`, and on a symbolic
# variable to obtain the expression, which is differentiated twice with respect
# to that same variable.
# Returns a `FastDifferentiationAllocatingSecondDerivativeExtras(y_prototype, der2_exe)`.
function DI.prepare_second_derivative(f, ::AnyAutoFastDifferentiation, x)
    # Concrete output, kept so `second_derivative` need not call `f` again.
    y_prototype = f(x)

    x_var = only(make_variables(:x))
    y_var = f(x_var)

    x_vec_var = [x_var]
    # A scalar symbolic output is wrapped so `derivative` always receives a vector.
    y_vec_var = y_var isa Number ? [y_var] : vec(y_var)
    der2_vec_var = derivative(y_vec_var, x_var, x_var)
    der2_exe = make_function(der2_vec_var, x_vec_var; in_place=false)
    return FastDifferentiationAllocatingSecondDerivativeExtras(y_prototype, der2_exe)
end
216248
217249function DI. second_derivative (
@@ -220,12 +252,11 @@ function DI.second_derivative(
220252 x,
221253 extras:: FastDifferentiationAllocatingSecondDerivativeExtras ,
222254)
223- y = f (x)
224255 der2_vec = extras. der2_exe ([x])
225- if y isa Number
256+ if extras . y_prototype isa Number
226257 return only (der2_vec)
227258 else
228- return reshape (der2_vec, size (y ))
259+ return reshape (der2_vec, size (extras . y_prototype ))
229260 end
230261end
231262
@@ -253,7 +284,7 @@ function DI.prepare_hvp(f, ::AnyAutoFastDifferentiation, x, v)
253284 end
254285 y_var = f (x_var)
255286
256- x_vec_var = x_var isa Number ? [x_var] : vec (x_var)
287+ x_vec_var = vec (x_var)
257288 hv_vec_var, v_vec_var = hessian_times_v (y_var, x_vec_var)
258289 hvp_exe = make_function (hv_vec_var, [x_vec_var; v_vec_var]; in_place= false )
259290 return FastDifferentiationHVPExtras (hvp_exe)
# Hessian-vector product using a prepared executable.
#
# `x` and `v` are flattened with `myvec` and concatenated into the single input
# vector that `extras.hvp_exe` is called with; the flat result is reshaped to
# `size(x)`. There is no scalar-`x` special case: the result is always an array.
# NOTE(review): `prepare_hvp` calls `vec(x_var)` unconditionally, which suggests
# `x` is expected to be an array here — confirm against callers.
function DI.hvp(f, ::AnyAutoFastDifferentiation, x, v, extras::FastDifferentiationHVPExtras)
    v_vec = vcat(myvec(x), myvec(v))
    hv_vec = extras.hvp_exe(v_vec)
    return reshape(hv_vec, size(x))
end
271298
272299function DI. hvp!! (
0 commit comments