@@ -6,7 +6,7 @@ struct FastDifferentiationOneArgPushforwardExtras{Y,E1,E2} <: PushforwardExtras
66 jvp_exe!:: E2
77end
88
9- function DI. prepare_pushforward (f, :: AnyAutoFastDifferentiation , x, dx)
9+ function DI. prepare_pushforward (f, :: AutoFastDifferentiation , x, dx)
1010 y_prototype = f (x)
1111 x_var = if x isa Number
1212 only (make_variables (:x ))
@@ -24,11 +24,7 @@ function DI.prepare_pushforward(f, ::AnyAutoFastDifferentiation, x, dx)
2424end
2525
2626function DI. pushforward (
27- f,
28- :: AnyAutoFastDifferentiation ,
29- x,
30- dx,
31- extras:: FastDifferentiationOneArgPushforwardExtras ,
27+ f, :: AutoFastDifferentiation , x, dx, extras:: FastDifferentiationOneArgPushforwardExtras
3228)
3329 v_vec = vcat (myvec (x), myvec (dx))
3430 if extras. y_prototype isa Number
4137function DI. pushforward! (
4238 f,
4339 dy,
44- :: AnyAutoFastDifferentiation ,
40+ :: AutoFastDifferentiation ,
4541 x,
4642 dx,
4743 extras:: FastDifferentiationOneArgPushforwardExtras ,
5349
5450function DI. value_and_pushforward (
5551 f,
56- backend:: AnyAutoFastDifferentiation ,
52+ backend:: AutoFastDifferentiation ,
5753 x,
5854 dx,
5955 extras:: FastDifferentiationOneArgPushforwardExtras ,
6460function DI. value_and_pushforward! (
6561 f,
6662 dy,
67- backend:: AnyAutoFastDifferentiation ,
63+ backend:: AutoFastDifferentiation ,
6864 x,
6965 dx,
7066 extras:: FastDifferentiationOneArgPushforwardExtras ,
@@ -84,7 +80,7 @@ struct FastDifferentiationOneArgDerivativeExtras{Y,E1,E2} <: DerivativeExtras
8480 der_exe!:: E2
8581end
8682
87- function DI. prepare_derivative (f, :: AnyAutoFastDifferentiation , x)
83+ function DI. prepare_derivative (f, :: AutoFastDifferentiation , x)
8884 y_prototype = f (x)
8985 x_var = only (make_variables (:x ))
9086 y_var = f (x_var)
@@ -98,7 +94,7 @@ function DI.prepare_derivative(f, ::AnyAutoFastDifferentiation, x)
9894end
9995
10096function DI. derivative (
101- f, :: AnyAutoFastDifferentiation , x, extras:: FastDifferentiationOneArgDerivativeExtras
97+ f, :: AutoFastDifferentiation , x, extras:: FastDifferentiationOneArgDerivativeExtras
10298)
10399 if extras. y_prototype isa Number
104100 return only (extras. der_exe (monovec (x)))
@@ -108,19 +104,15 @@ function DI.derivative(
108104end
109105
110106function DI. derivative! (
111- f,
112- der,
113- :: AnyAutoFastDifferentiation ,
114- x,
115- extras:: FastDifferentiationOneArgDerivativeExtras ,
107+ f, der, :: AutoFastDifferentiation , x, extras:: FastDifferentiationOneArgDerivativeExtras
116108)
117109 extras. der_exe! (vec (der), monovec (x))
118110 return der
119111end
120112
121113function DI. value_and_derivative (
122114 f,
123- backend:: AnyAutoFastDifferentiation ,
115+ backend:: AutoFastDifferentiation ,
124116 x,
125117 extras:: FastDifferentiationOneArgDerivativeExtras ,
126118)
130122function DI. value_and_derivative! (
131123 f,
132124 der,
133- backend:: AnyAutoFastDifferentiation ,
125+ backend:: AutoFastDifferentiation ,
134126 x,
135127 extras:: FastDifferentiationOneArgDerivativeExtras ,
136128)
@@ -144,7 +136,7 @@ struct FastDifferentiationOneArgGradientExtras{E1,E2} <: GradientExtras
144136 jac_exe!:: E2
145137end
146138
147- function DI. prepare_gradient (f, backend:: AnyAutoFastDifferentiation , x)
139+ function DI. prepare_gradient (f, backend:: AutoFastDifferentiation , x)
148140 y_prototype = f (x)
149141 x_var = make_variables (:x , size (x)... )
150142 y_var = f (x_var)
@@ -158,37 +150,30 @@ function DI.prepare_gradient(f, backend::AnyAutoFastDifferentiation, x)
158150end
159151
160152function DI. gradient (
161- f, :: AnyAutoFastDifferentiation , x, extras:: FastDifferentiationOneArgGradientExtras
153+ f, :: AutoFastDifferentiation , x, extras:: FastDifferentiationOneArgGradientExtras
162154)
163155 jac = extras. jac_exe (vec (x))
164156 grad_vec = @view jac[1 , :]
165157 return reshape (grad_vec, size (x))
166158end
167159
168160function DI. gradient! (
169- f,
170- grad,
171- :: AnyAutoFastDifferentiation ,
172- x,
173- extras:: FastDifferentiationOneArgGradientExtras ,
161+ f, grad, :: AutoFastDifferentiation , x, extras:: FastDifferentiationOneArgGradientExtras
174162)
175163 extras. jac_exe! (reshape (grad, 1 , length (grad)), vec (x))
176164 return grad
177165end
178166
179167function DI. value_and_gradient (
180- f,
181- backend:: AnyAutoFastDifferentiation ,
182- x,
183- extras:: FastDifferentiationOneArgGradientExtras ,
168+ f, backend:: AutoFastDifferentiation , x, extras:: FastDifferentiationOneArgGradientExtras
184169)
185170 return f (x), DI. gradient (f, backend, x, extras)
186171end
187172
188173function DI. value_and_gradient! (
189174 f,
190175 grad,
191- backend:: AnyAutoFastDifferentiation ,
176+ backend:: AutoFastDifferentiation ,
192177 x,
193178 extras:: FastDifferentiationOneArgGradientExtras ,
194179)
@@ -261,7 +246,7 @@ struct FastDifferentiationAllocatingSecondDerivativeExtras{Y,E1,E2} <:
261246 der2_exe!:: E2
262247end
263248
264- function DI. prepare_second_derivative (f, :: AnyAutoFastDifferentiation , x)
249+ function DI. prepare_second_derivative (f, :: AutoFastDifferentiation , x)
265250 y_prototype = f (x)
266251 x_var = only (make_variables (:x ))
267252 y_var = f (x_var)
278263
279264function DI. second_derivative (
280265 f,
281- :: AnyAutoFastDifferentiation ,
266+ :: AutoFastDifferentiation ,
282267 x,
283268 extras:: FastDifferentiationAllocatingSecondDerivativeExtras ,
284269)
292277function DI. second_derivative! (
293278 f,
294279 der2,
295- backend:: AnyAutoFastDifferentiation ,
280+ backend:: AutoFastDifferentiation ,
296281 x,
297282 extras:: FastDifferentiationAllocatingSecondDerivativeExtras ,
298283)
@@ -307,7 +292,7 @@ struct FastDifferentiationHVPExtras{E1,E2} <: HVPExtras
307292 hvp_exe!:: E2
308293end
309294
310- function DI. prepare_hvp (f, :: AnyAutoFastDifferentiation , x, v)
295+ function DI. prepare_hvp (f, :: AutoFastDifferentiation , x, v)
311296 x_var = make_variables (:x , size (x)... )
312297 y_var = f (x_var)
313298
@@ -318,14 +303,14 @@ function DI.prepare_hvp(f, ::AnyAutoFastDifferentiation, x, v)
318303 return FastDifferentiationHVPExtras (hvp_exe, hvp_exe!)
319304end
320305
321- function DI. hvp (f, :: AnyAutoFastDifferentiation , x, v, extras:: FastDifferentiationHVPExtras )
306+ function DI. hvp (f, :: AutoFastDifferentiation , x, v, extras:: FastDifferentiationHVPExtras )
322307 v_vec = vcat (vec (x), vec (v))
323308 hv_vec = extras. hvp_exe (v_vec)
324309 return reshape (hv_vec, size (x))
325310end
326311
327312function DI. hvp! (
328- f, p, :: AnyAutoFastDifferentiation , x, v, extras:: FastDifferentiationHVPExtras
313+ f, p, :: AutoFastDifferentiation , x, v, extras:: FastDifferentiationHVPExtras
329314)
330315 v_vec = vcat (vec (x), vec (v))
331316 extras. hvp_exe! (p, v_vec)
0 commit comments