@@ -4,58 +4,58 @@ struct ForwardDiffOneArgPushforwardExtras{T,X} <: PushforwardExtras
44 xdual_tmp:: X
55end
66
7- function DI. prepare_pushforward (f, backend:: AutoForwardDiff , x, dx)
7+ function DI. prepare_pushforward (f:: F , backend:: AutoForwardDiff , x, dx) where {F}
88 T = tag_type (f, backend, x)
99 xdual_tmp = make_dual (T, x, dx)
1010 return ForwardDiffOneArgPushforwardExtras {T,typeof(xdual_tmp)} (xdual_tmp)
1111end
1212
1313function compute_ydual_onearg (
14- f, x:: Number , dx, extras:: ForwardDiffOneArgPushforwardExtras{T}
15- ) where {T}
14+ f:: F , x:: Number , dx, extras:: ForwardDiffOneArgPushforwardExtras{T}
15+ ) where {F, T}
1616 xdual_tmp = make_dual (T, x, dx)
1717 ydual = f (xdual_tmp)
1818 return ydual
1919end
2020
2121function compute_ydual_onearg (
22- f, x, dx, extras:: ForwardDiffOneArgPushforwardExtras{T}
23- ) where {T}
22+ f:: F , x, dx, extras:: ForwardDiffOneArgPushforwardExtras{T}
23+ ) where {F, T}
2424 (; xdual_tmp) = extras
2525 make_dual! (T, xdual_tmp, x, dx)
2626 ydual = f (xdual_tmp)
2727 return ydual
2828end
2929
3030function DI. value_and_pushforward (
31- f, :: AutoForwardDiff , x, dx, extras:: ForwardDiffOneArgPushforwardExtras{T}
32- ) where {T}
31+ f:: F , :: AutoForwardDiff , x, dx, extras:: ForwardDiffOneArgPushforwardExtras{T}
32+ ) where {F, T}
3333 ydual = compute_ydual_onearg (f, x, dx, extras)
3434 y = myvalue (T, ydual)
3535 new_dy = myderivative (T, ydual)
3636 return y, new_dy
3737end
3838
3939function DI. pushforward (
40- f, :: AutoForwardDiff , x, dx, extras:: ForwardDiffOneArgPushforwardExtras{T}
41- ) where {T}
40+ f:: F , :: AutoForwardDiff , x, dx, extras:: ForwardDiffOneArgPushforwardExtras{T}
41+ ) where {F, T}
4242 ydual = compute_ydual_onearg (f, x, dx, extras)
4343 new_dy = myderivative (T, ydual)
4444 return new_dy
4545end
4646
4747function DI. value_and_pushforward! (
48- f, dy, :: AutoForwardDiff , x, dx, extras:: ForwardDiffOneArgPushforwardExtras{T}
49- ) where {T}
48+ f:: F , dy, :: AutoForwardDiff , x, dx, extras:: ForwardDiffOneArgPushforwardExtras{T}
49+ ) where {F, T}
5050 ydual = compute_ydual_onearg (f, x, dx, extras)
5151 y = myvalue (T, ydual)
5252 myderivative! (T, dy, ydual)
5353 return y, dy
5454end
5555
5656function DI. pushforward! (
57- f, dy, :: AutoForwardDiff , x, dx, extras:: ForwardDiffOneArgPushforwardExtras{T}
58- ) where {T}
57+ f:: F , dy, :: AutoForwardDiff , x, dx, extras:: ForwardDiffOneArgPushforwardExtras{T}
58+ ) where {F, T}
5959 ydual = compute_ydual_onearg (f, x, dx, extras)
6060 myderivative! (T, dy, ydual)
6161 return dy
@@ -67,42 +67,34 @@ struct ForwardDiffGradientExtras{C} <: GradientExtras
6767 config:: C
6868end
6969
70- function DI. prepare_gradient (f, backend:: AutoForwardDiff , x:: AbstractArray )
70+ function DI. prepare_gradient (f:: F , backend:: AutoForwardDiff , x:: AbstractArray ) where {F}
7171 return ForwardDiffGradientExtras (GradientConfig (f, x, choose_chunk (backend, x)))
7272end
7373
7474function DI. value_and_gradient! (
75- f,
76- grad:: AbstractArray ,
77- :: AutoForwardDiff ,
78- x:: AbstractArray ,
79- extras:: ForwardDiffGradientExtras ,
80- )
75+ f:: F , grad, :: AutoForwardDiff , x, extras:: ForwardDiffGradientExtras
76+ ) where {F}
8177 result = MutableDiffResult (zero (eltype (x)), (grad,))
8278 result = gradient! (result, f, x, extras. config)
8379 return DiffResults. value (result), DiffResults. gradient (result)
8480end
8581
8682function DI. value_and_gradient (
87- f, backend:: AutoForwardDiff , x:: AbstractArray , extras:: ForwardDiffGradientExtras
88- )
83+ f:: F , backend:: AutoForwardDiff , x, extras:: ForwardDiffGradientExtras
84+ ) where {F}
8985 grad = similar (x)
9086 return DI. value_and_gradient! (f, grad, backend, x, extras)
9187end
9288
9389function DI. gradient! (
94- f,
95- grad:: AbstractArray ,
96- :: AutoForwardDiff ,
97- x:: AbstractArray ,
98- extras:: ForwardDiffGradientExtras ,
99- )
90+ f:: F , grad, :: AutoForwardDiff , x, extras:: ForwardDiffGradientExtras
91+ ) where {F}
10092 return gradient! (grad, f, x, extras. config)
10193end
10294
10395function DI. gradient (
104- f, :: AutoForwardDiff , x:: AbstractArray , extras:: ForwardDiffGradientExtras
105- )
96+ f:: F , :: AutoForwardDiff , x, extras:: ForwardDiffGradientExtras
97+ ) where {F}
10698 return gradient (f, x, extras. config)
10799end
108100
@@ -112,42 +104,34 @@ struct ForwardDiffOneArgJacobianExtras{C} <: JacobianExtras
112104 config:: C
113105end
114106
115- function DI. prepare_jacobian (f, backend:: AutoForwardDiff , x:: AbstractArray )
107+ function DI. prepare_jacobian (f, backend:: AutoForwardDiff , x)
116108 return ForwardDiffOneArgJacobianExtras (JacobianConfig (f, x, choose_chunk (backend, x)))
117109end
118110
119111function DI. value_and_jacobian! (
120- f,
121- jac:: AbstractMatrix ,
122- :: AutoForwardDiff ,
123- x:: AbstractArray ,
124- extras:: ForwardDiffOneArgJacobianExtras ,
125- )
112+ f:: F , jac, :: AutoForwardDiff , x, extras:: ForwardDiffOneArgJacobianExtras
113+ ) where {F}
126114 y = f (x)
127115 result = MutableDiffResult (y, (jac,))
128116 result = jacobian! (result, f, x, extras. config)
129117 return DiffResults. value (result), DiffResults. jacobian (result)
130118end
131119
132120function DI. value_and_jacobian (
133- f, :: AutoForwardDiff , x:: AbstractArray , extras:: ForwardDiffOneArgJacobianExtras
134- )
121+ f:: F , :: AutoForwardDiff , x, extras:: ForwardDiffOneArgJacobianExtras
122+ ) where {F}
135123 return f (x), jacobian (f, x, extras. config)
136124end
137125
138126function DI. jacobian! (
139- f,
140- jac:: AbstractMatrix ,
141- :: AutoForwardDiff ,
142- x:: AbstractArray ,
143- extras:: ForwardDiffOneArgJacobianExtras ,
144- )
127+ f:: F , jac, :: AutoForwardDiff , x, extras:: ForwardDiffOneArgJacobianExtras
128+ ) where {F}
145129 return jacobian! (jac, f, x, extras. config)
146130end
147131
148132function DI. jacobian (
149- f, :: AutoForwardDiff , x:: AbstractArray , extras:: ForwardDiffOneArgJacobianExtras
150- )
133+ f:: F , :: AutoForwardDiff , x, extras:: ForwardDiffOneArgJacobianExtras
134+ ) where {F}
151135 return jacobian (f, x, extras. config)
152136end
153137
@@ -157,22 +141,16 @@ struct ForwardDiffHessianExtras{C} <: HessianExtras
157141 config:: C
158142end
159143
160- function DI. prepare_hessian (f, backend:: AutoForwardDiff , x:: AbstractArray )
144+ function DI. prepare_hessian (f, backend:: AutoForwardDiff , x)
161145 return ForwardDiffHessianExtras (HessianConfig (f, x, choose_chunk (backend, x)))
162146end
163147
164148function DI. hessian! (
165- f,
166- hess:: AbstractMatrix ,
167- :: AutoForwardDiff ,
168- x:: AbstractArray ,
169- extras:: ForwardDiffHessianExtras ,
170- )
149+ f:: F , hess, :: AutoForwardDiff , x, extras:: ForwardDiffHessianExtras
150+ ) where {F}
171151 return hessian! (hess, f, x, extras. config)
172152end
173153
174- function DI. hessian (
175- f, :: AutoForwardDiff , x:: AbstractArray , extras:: ForwardDiffHessianExtras
176- )
154+ function DI. hessian (f:: F , :: AutoForwardDiff , x, extras:: ForwardDiffHessianExtras ) where {F}
177155 return hessian (f, x, extras. config)
178156end
0 commit comments