@@ -107,45 +107,61 @@ end
107107
108108# # Gradient
109109
110- struct PolyesterForwardDiffGradientPrep{chunksize} <: DI.GradientPrep
110+ struct PolyesterForwardDiffGradientPrep{chunksize,P } <: DI.GradientPrep
111111 chunk:: Chunk{chunksize}
112+ single_threaded_prep:: P
112113end
113114
114115function DI. prepare_gradient (
115- f, :: AutoPolyesterForwardDiff{chunksize} , x, contexts:: Vararg{DI.Context,C}
116+ f, backend :: AutoPolyesterForwardDiff{chunksize} , x, contexts:: Vararg{DI.Context,C}
116117) where {chunksize,C}
117118 if isnothing (chunksize)
118119 chunk = Chunk (x)
119120 else
120121 chunk = Chunk {chunksize} ()
121122 end
122- return PolyesterForwardDiffGradientPrep (chunk)
123+ single_threaded_prep = DI. prepare_gradient (f, single_threaded (backend), x, contexts... )
124+ return PolyesterForwardDiffGradientPrep (chunk, single_threaded_prep)
123125end
124126
125127function DI. value_and_gradient! (
126128 f,
127129 grad,
128130 prep:: PolyesterForwardDiffGradientPrep ,
129- :: AutoPolyesterForwardDiff ,
131+ backend :: AutoPolyesterForwardDiff ,
130132 x,
131133 contexts:: Vararg{DI.Context,C} ,
132134) where {C}
133- fc = DI. with_contexts (f, contexts... )
134- threaded_gradient! (fc, grad, x, prep. chunk)
135- return fc (x), grad
135+ if contexts isa NTuple{C,DI. GeneralizedConstant}
136+ fc = DI. with_contexts (f, contexts... )
137+ threaded_gradient! (fc, grad, x, prep. chunk)
138+ return fc (x), grad
139+ else
140+ # TODO : optimize
141+ return DI. value_and_gradient! (
142+ f, grad, prep. single_threaded_prep, single_threaded (backend), x, contexts...
143+ )
144+ end
136145end
137146
138147function DI. gradient! (
139148 f,
140149 grad,
141150 prep:: PolyesterForwardDiffGradientPrep ,
142- :: AutoPolyesterForwardDiff ,
151+ backend :: AutoPolyesterForwardDiff ,
143152 x,
144153 contexts:: Vararg{DI.Context,C} ,
145154) where {C}
146- fc = DI. with_contexts (f, contexts... )
147- threaded_gradient! (fc, grad, x, prep. chunk)
148- return grad
155+ if contexts isa NTuple{C,DI. GeneralizedConstant}
156+ fc = DI. with_contexts (f, contexts... )
157+ threaded_gradient! (fc, grad, x, prep. chunk)
158+ return grad
159+ else
160+ # TODO : optimize
161+ return DI. gradient! (
162+ f, grad, prep. single_threaded_prep, single_threaded (backend), x, contexts...
163+ )
164+ end
149165end
150166
151167function DI. value_and_gradient (
@@ -170,43 +186,57 @@ end
170186
171187# # Jacobian
172188
173- struct PolyesterForwardDiffOneArgJacobianPrep{chunksize} <: DI.JacobianPrep
189+ struct PolyesterForwardDiffOneArgJacobianPrep{chunksize,P } <: DI.JacobianPrep
174190 chunk:: Chunk{chunksize}
191+ single_threaded_prep:: P
175192end
176193
177194function DI. prepare_jacobian (
178- f, :: AutoPolyesterForwardDiff{chunksize} , x, contexts:: Vararg{DI.Context,C}
195+ f, backend :: AutoPolyesterForwardDiff{chunksize} , x, contexts:: Vararg{DI.Context,C}
179196) where {chunksize,C}
180197 if isnothing (chunksize)
181198 chunk = Chunk (x)
182199 else
183200 chunk = Chunk {chunksize} ()
184201 end
185- return PolyesterForwardDiffOneArgJacobianPrep (chunk)
202+ single_threaded_prep = DI. prepare_jacobian (f, single_threaded (backend), x, contexts... )
203+ return PolyesterForwardDiffOneArgJacobianPrep (chunk, single_threaded_prep)
186204end
187205
188206function DI. value_and_jacobian! (
189207 f,
190208 jac,
191209 prep:: PolyesterForwardDiffOneArgJacobianPrep ,
192- :: AutoPolyesterForwardDiff ,
210+ backend :: AutoPolyesterForwardDiff ,
193211 x,
194212 contexts:: Vararg{DI.Context,C} ,
195213) where {C}
196- fc = DI. with_contexts (f, contexts... )
197- return fc (x), threaded_jacobian! (fc, jac, x, prep. chunk)
214+ if contexts isa NTuple{C,DI. GeneralizedConstant}
215+ fc = DI. with_contexts (f, contexts... )
216+ return fc (x), threaded_jacobian! (fc, jac, x, prep. chunk)
217+ else
218+ return DI. value_and_jacobian! (
219+ f, jac, prep. single_threaded_prep, single_threaded (backend), x, contexts...
220+ )
221+ end
198222end
199223
200224function DI. jacobian! (
201225 f,
202226 jac,
203227 prep:: PolyesterForwardDiffOneArgJacobianPrep ,
204- :: AutoPolyesterForwardDiff ,
228+ backend :: AutoPolyesterForwardDiff ,
205229 x,
206230 contexts:: Vararg{DI.Context,C} ,
207231) where {C}
208- fc = DI. with_contexts (f, contexts... )
209- return threaded_jacobian! (fc, jac, x, prep. chunk)
232+ if contexts isa NTuple{C,DI. GeneralizedConstant}
233+ fc = DI. with_contexts (f, contexts... )
234+ return threaded_jacobian! (fc, jac, x, prep. chunk)
235+ else
236+ return DI. jacobian! (
237+ f, jac, prep. single_threaded_prep, single_threaded (backend), x, contexts...
238+ )
239+ end
210240end
211241
212242function DI. value_and_jacobian (
@@ -217,9 +247,8 @@ function DI.value_and_jacobian(
217247 contexts:: Vararg{DI.Context,C} ,
218248) where {C}
219249 y = f (x, map (DI. unwrap, contexts)... )
220- return DI. value_and_jacobian! (
221- f, similar (y, length (y), length (x)), prep, backend, x, contexts...
222- )
250+ jac = similar (y, length (y), length (x))
251+ return DI. value_and_jacobian! (f, jac, prep, backend, x, contexts... )
223252end
224253
225254function DI. jacobian (
@@ -230,7 +259,8 @@ function DI.jacobian(
230259 contexts:: Vararg{DI.Context,C} ,
231260) where {C}
232261 y = f (x, map (DI. unwrap, contexts)... )
233- return DI. jacobian! (f, similar (y, length (y), length (x)), prep, backend, x, contexts... )
262+ jac = similar (y, length (y), length (x))
263+ return DI. jacobian! (f, jac, prep, backend, x, contexts... )
234264end
235265
236266# # Hessian
299329
300330function DI.hvp(
301331 f,
302- prep::DI.HVPPrep ,
332+ prep::DI.ForwardOverAnythingHVPPrep ,
303333 backend::AutoPolyesterForwardDiff,
304334 x,
305335 tx::NTuple,
313343function DI.hvp!(
314344 f,
315345 tg::NTuple,
316- prep::DI.HVPPrep ,
346+ prep::DI.ForwardOverAnythingHVPPrep ,
317347 backend::AutoPolyesterForwardDiff,
318348 x,
319349 tx::NTuple,
326356
327357function DI.gradient_and_hvp(
328358 f,
329- prep::DI.HVPPrep ,
359+ prep::DI.ForwardOverAnythingHVPPrep ,
330360 backend::AutoPolyesterForwardDiff,
331361 x,
332362 tx::NTuple,
@@ -341,7 +371,7 @@ function DI.gradient_and_hvp!(
341371 f,
342372 grad,
343373 tg::NTuple,
344- prep::DI.HVPPrep ,
374+ prep::DI.ForwardOverAnythingHVPPrep ,
345375 backend::AutoPolyesterForwardDiff,
346376 x,
347377 tx::NTuple,
0 commit comments