@@ -53,34 +53,34 @@ struct PullbackJacobianExtras{E<:PullbackExtras} <: JacobianExtras
5353 pullback_extras:: E
5454end
5555
56- function prepare_jacobian (f, backend:: AbstractADType , x)
56+ function prepare_jacobian (f:: F , backend:: AbstractADType , x) where {F}
5757 return prepare_jacobian_aux (f, backend, x, pushforward_performance (backend))
5858end
5959
60- function prepare_jacobian (f!, y, backend:: AbstractADType , x)
60+ function prepare_jacobian (f!:: F , y, backend:: AbstractADType , x) where {F}
6161 return prepare_jacobian_aux (f!, y, backend, x, pushforward_performance (backend))
6262end
6363
64- function prepare_jacobian_aux (f, backend, x, :: PushforwardFast )
64+ function prepare_jacobian_aux (f:: F , backend, x, :: PushforwardFast ) where {F}
6565 dx = basis (backend, x, first (CartesianIndices (x)))
6666 pushforward_extras = prepare_pushforward (f, backend, x, dx)
6767 return PushforwardJacobianExtras (pushforward_extras)
6868end
6969
70- function prepare_jacobian_aux (f!, y, backend, x, :: PushforwardFast )
70+ function prepare_jacobian_aux (f!:: F , y, backend, x, :: PushforwardFast ) where {F}
7171 dx = basis (backend, x, first (CartesianIndices (x)))
7272 pushforward_extras = prepare_pushforward (f!, y, backend, x, dx)
7373 return PushforwardJacobianExtras (pushforward_extras)
7474end
7575
76- function prepare_jacobian_aux (f, backend, x, :: PushforwardSlow )
76+ function prepare_jacobian_aux (f:: F , backend, x, :: PushforwardSlow ) where {F}
7777 y = f (x)
7878 dy = basis (backend, y, first (CartesianIndices (y)))
7979 pullback_extras = prepare_pullback (f, backend, x, dy)
8080 return PullbackJacobianExtras (pullback_extras)
8181end
8282
83- function prepare_jacobian_aux (f!, y, backend, x, :: PushforwardSlow )
83+ function prepare_jacobian_aux (f!:: F , y, backend, x, :: PushforwardSlow ) where {F}
8484 dy = basis (backend, y, first (CartesianIndices (y)))
8585 pullback_extras = prepare_pullback (f!, y, backend, x, dy)
8686 return PullbackJacobianExtras (pullback_extras)
8989# # One argument
9090
9191function value_and_jacobian (
92- f, backend:: AbstractADType , x, extras:: JacobianExtras = prepare_jacobian (f, backend, x)
93- )
92+ f:: F , backend:: AbstractADType , x, extras:: JacobianExtras = prepare_jacobian (f, backend, x)
93+ ) where {F}
9494 return value_and_jacobian_onearg_aux (f, backend, x, extras)
9595end
9696
9797function value_and_jacobian_onearg_aux (
98- f, backend, x:: AbstractArray , extras:: PushforwardJacobianExtras
99- )
98+ f:: F , backend, x:: AbstractArray , extras:: PushforwardJacobianExtras
99+ ) where {F}
100100 y = f (x)
101101 jac = stack (CartesianIndices (x); dims= 2 ) do j
102102 dx_j = basis (backend, x, j)
@@ -107,8 +107,8 @@ function value_and_jacobian_onearg_aux(
107107end
108108
109109function value_and_jacobian_onearg_aux (
110- f, backend, x:: AbstractArray , extras:: PullbackJacobianExtras
111- )
110+ f:: F , backend, x:: AbstractArray , extras:: PullbackJacobianExtras
111+ ) where {F}
112112 y, pullbackfunc = value_and_pullback_split (f, backend, x, extras. pullback_extras)
113113 jac = stack (CartesianIndices (y); dims= 1 ) do i
114114 dy_i = basis (backend, y, i)
@@ -119,18 +119,18 @@ function value_and_jacobian_onearg_aux(
119119end
120120
121121function value_and_jacobian! (
122- f,
122+ f:: F ,
123123 jac,
124124 backend:: AbstractADType ,
125125 x,
126126 extras:: JacobianExtras = prepare_jacobian (f, backend, x),
127- )
127+ ) where {F}
128128 return value_and_jacobian_onearg_aux! (f, jac, backend, x, extras)
129129end
130130
131131function value_and_jacobian_onearg_aux! (
132- f, jac:: AbstractMatrix , backend, x:: AbstractArray , extras:: PushforwardJacobianExtras
133- )
132+ f:: F , jac:: AbstractMatrix , backend, x:: AbstractArray , extras:: PushforwardJacobianExtras
133+ ) where {F}
134134 y = f (x)
135135 for (k, j) in enumerate (CartesianIndices (x))
136136 dx_j = basis (backend, x, j)
@@ -141,8 +141,8 @@ function value_and_jacobian_onearg_aux!(
141141end
142142
143143function value_and_jacobian_onearg_aux! (
144- f, jac:: AbstractMatrix , backend, x:: AbstractArray , extras:: PullbackJacobianExtras
145- )
144+ f:: F , jac:: AbstractMatrix , backend, x:: AbstractArray , extras:: PullbackJacobianExtras
145+ ) where {F}
146146 y, pullbackfunc! = value_and_pullback!_split (f, backend, x, extras. pullback_extras)
147147 for (k, i) in enumerate (CartesianIndices (y))
148148 dy_i = basis (backend, y, i)
@@ -153,36 +153,36 @@ function value_and_jacobian_onearg_aux!(
153153end
154154
155155function jacobian (
156- f, backend:: AbstractADType , x, extras:: JacobianExtras = prepare_jacobian (f, backend, x)
157- )
156+ f:: F , backend:: AbstractADType , x, extras:: JacobianExtras = prepare_jacobian (f, backend, x)
157+ ) where {F}
158158 return value_and_jacobian (f, backend, x, extras)[2 ]
159159end
160160
161161function jacobian! (
162- f,
162+ f:: F ,
163163 jac,
164164 backend:: AbstractADType ,
165165 x,
166166 extras:: JacobianExtras = prepare_jacobian (f, backend, x),
167- )
167+ ) where {F}
168168 return value_and_jacobian! (f, jac, backend, x, extras)[2 ]
169169end
170170
171171# # Two arguments
172172
173173function value_and_jacobian (
174- f!,
174+ f!:: F ,
175175 y,
176176 backend:: AbstractADType ,
177177 x,
178178 extras:: JacobianExtras = prepare_jacobian (f!, y, backend, x),
179- )
179+ ) where {F}
180180 return value_and_jacobian_twoarg_aux (f!, y, backend, x, extras)
181181end
182182
183183function value_and_jacobian_twoarg_aux (
184- f!, y, backend, x:: AbstractArray , extras:: PushforwardJacobianExtras
185- )
184+ f!:: F , y, backend, x:: AbstractArray , extras:: PushforwardJacobianExtras
185+ ) where {F}
186186 jac = stack (CartesianIndices (x); dims= 2 ) do j
187187 dx_j = basis (backend, x, j)
188188 jac_col_j = pushforward (f!, y, backend, x, dx_j, extras. pushforward_extras)
@@ -193,8 +193,8 @@ function value_and_jacobian_twoarg_aux(
193193end
194194
195195function value_and_jacobian_twoarg_aux (
196- f!, y, backend, x:: AbstractArray , extras:: PullbackJacobianExtras
197- )
196+ f!:: F , y, backend, x:: AbstractArray , extras:: PullbackJacobianExtras
197+ ) where {F}
198198 y, pullbackfunc = value_and_pullback_split (f!, y, backend, x, extras. pullback_extras)
199199 jac = stack (CartesianIndices (y); dims= 1 ) do i
200200 dy_i = basis (backend, y, i)
@@ -206,19 +206,24 @@ function value_and_jacobian_twoarg_aux(
206206end
207207
208208function value_and_jacobian! (
209- f!,
209+ f!:: F ,
210210 y,
211211 jac,
212212 backend:: AbstractADType ,
213213 x,
214214 extras:: JacobianExtras = prepare_jacobian (f!, y, backend, x),
215- )
215+ ) where {F}
216216 return value_and_jacobian_twoarg_aux! (f!, y, jac, backend, x, extras)
217217end
218218
219219function value_and_jacobian_twoarg_aux! (
220- f!, y, jac:: AbstractMatrix , backend, x:: AbstractArray , extras:: PushforwardJacobianExtras
221- )
220+ f!:: F ,
221+ y,
222+ jac:: AbstractMatrix ,
223+ backend,
224+ x:: AbstractArray ,
225+ extras:: PushforwardJacobianExtras ,
226+ ) where {F}
222227 for (k, j) in enumerate (CartesianIndices (x))
223228 dx_j = basis (backend, x, j)
224229 jac_col_j = reshape (view (jac, :, k), size (y))
@@ -229,8 +234,8 @@ function value_and_jacobian_twoarg_aux!(
229234end
230235
231236function value_and_jacobian_twoarg_aux! (
232- f!, y, jac:: AbstractMatrix , backend, x:: AbstractArray , extras:: PullbackJacobianExtras
233- )
237+ f!:: F , y, jac:: AbstractMatrix , backend, x:: AbstractArray , extras:: PullbackJacobianExtras
238+ ) where {F}
234239 y, pullbackfunc! = value_and_pullback!_split (f!, y, backend, x, extras. pullback_extras)
235240 for (k, i) in enumerate (CartesianIndices (y))
236241 dy_i = basis (backend, y, i)
@@ -242,22 +247,22 @@ function value_and_jacobian_twoarg_aux!(
242247end
243248
244249function jacobian (
245- f!,
250+ f!:: F ,
246251 y,
247252 backend:: AbstractADType ,
248253 x,
249254 extras:: JacobianExtras = prepare_jacobian (f!, y, backend, x),
250- )
255+ ) where {F}
251256 return value_and_jacobian (f!, y, backend, x, extras)[2 ]
252257end
253258
254259function jacobian! (
255- f!,
260+ f!:: F ,
256261 y,
257262 jac,
258263 backend:: AbstractADType ,
259264 x,
260265 extras:: JacobianExtras = prepare_jacobian (f!, y, backend, x),
261- )
266+ ) where {F}
262267 return value_and_jacobian! (f!, y, jac, backend, x, extras)[2 ]
263268end
0 commit comments