11# # Pushforward
22
3+ # ## Unprepared (avoid working on `similar(x)`)
4+
5+ function DI. value_and_pushforward (
6+ f:: F , backend:: AutoForwardDiff , x, tx:: NTuple{B} , contexts:: Vararg{Context,C}
7+ ) where {F,B,C}
8+ T = tag_type (f, backend, x)
9+ xdual_tmp = make_dual (T, x, tx)
10+ ydual = f (xdual_tmp, map (unwrap, contexts)... )
11+ y = myvalue (T, ydual)
12+ ty = mypartials (T, Val (B), ydual)
13+ return y, ty
14+ end
15+
16+ function DI. value_and_pushforward! (
17+ f:: F , ty:: NTuple , backend:: AutoForwardDiff , x, tx:: NTuple , contexts:: Vararg{Context,C}
18+ ) where {F,C}
19+ T = tag_type (f, backend, x)
20+ xdual_tmp = make_dual (T, x, tx)
21+ ydual = f (xdual_tmp, map (unwrap, contexts)... )
22+ y = myvalue (T, ydual)
23+ mypartials! (T, ty, ydual)
24+ return y, ty
25+ end
26+
27+ function DI. pushforward (
28+ f:: F , backend:: AutoForwardDiff , x, tx:: NTuple{B} , contexts:: Vararg{Context,C}
29+ ) where {F,B,C}
30+ T = tag_type (f, backend, x)
31+ xdual_tmp = make_dual (T, x, tx)
32+ ydual = f (xdual_tmp, map (unwrap, contexts)... )
33+ ty = mypartials (T, Val (B), ydual)
34+ return ty
35+ end
36+
37+ function DI. pushforward! (
38+ f:: F , ty:: NTuple , backend:: AutoForwardDiff , x, tx:: NTuple , contexts:: Vararg{Context,C}
39+ ) where {F,C}
40+ T = tag_type (f, backend, x)
41+ xdual_tmp = make_dual (T, x, tx)
42+ ydual = f (xdual_tmp, map (unwrap, contexts)... )
43+ mypartials! (T, ty, ydual)
44+ return ty
45+ end
46+
47+ # ## Prepared
48+
349struct ForwardDiffOneArgPushforwardPrep{T,X} <: PushforwardPrep
450 xdual_tmp:: X
551end
@@ -159,12 +205,12 @@ end
159205
160206# # Gradient
161207
162- # ## Unprepared, only when chunk size not specified
208+ # ## Unprepared, only when chunk size and tag are not specified
163209
164210function DI. value_and_gradient! (
165- f:: F , grad, backend:: AutoForwardDiff{chunksize} , x, contexts:: Vararg{Context,C}
166- ) where {F,C,chunksize}
167- if isnothing (chunksize)
211+ f:: F , grad, backend:: AutoForwardDiff{chunksize,T } , x, contexts:: Vararg{Context,C}
212+ ) where {F,C,chunksize,T }
213+ if isnothing (chunksize) && T === Nothing
168214 fc = with_contexts (f, contexts... )
169215 result = DiffResult (zero (eltype (x)), (grad,))
170216 result = gradient! (result, fc, x)
@@ -178,9 +224,9 @@ function DI.value_and_gradient!(
178224end
179225
180226function DI. value_and_gradient (
181- f:: F , backend:: AutoForwardDiff{chunksize} , x, contexts:: Vararg{Context,C}
182- ) where {F,C,chunksize}
183- if isnothing (chunksize)
227+ f:: F , backend:: AutoForwardDiff{chunksize,T } , x, contexts:: Vararg{Context,C}
228+ ) where {F,C,chunksize,T }
229+ if isnothing (chunksize) && T === Nothing
184230 fc = with_contexts (f, contexts... )
185231 result = GradientResult (x)
186232 result = gradient! (result, fc, x)
@@ -192,9 +238,9 @@ function DI.value_and_gradient(
192238end
193239
194240function DI. gradient! (
195- f:: F , grad, backend:: AutoForwardDiff{chunksize} , x, contexts:: Vararg{Context,C}
196- ) where {F,C,chunksize}
197- if isnothing (chunksize)
241+ f:: F , grad, backend:: AutoForwardDiff{chunksize,T } , x, contexts:: Vararg{Context,C}
242+ ) where {F,C,chunksize,T }
243+ if isnothing (chunksize) && T === Nothing
198244 fc = with_contexts (f, contexts... )
199245 return gradient! (grad, fc, x)
200246 else
@@ -204,9 +250,9 @@ function DI.gradient!(
204250end
205251
206252function DI. gradient (
207- f:: F , backend:: AutoForwardDiff{chunksize} , x, contexts:: Vararg{Context,C}
208- ) where {F,C,chunksize}
209- if isnothing (chunksize)
253+ f:: F , backend:: AutoForwardDiff{chunksize,T } , x, contexts:: Vararg{Context,C}
254+ ) where {F,C,chunksize,T }
255+ if isnothing (chunksize) && T === Nothing
210256 fc = with_contexts (f, contexts... )
211257 return gradient (fc, x)
212258 else
@@ -225,7 +271,10 @@ function DI.prepare_gradient(
225271 f:: F , backend:: AutoForwardDiff , x:: AbstractArray , contexts:: Vararg{Context,C}
226272) where {F,C}
227273 fc = with_contexts (f, contexts... )
228- return ForwardDiffGradientPrep (GradientConfig (fc, x, choose_chunk (backend, x)))
274+ chunk = choose_chunk (backend, x)
275+ tag = get_tag (fc, backend, x)
276+ config = GradientConfig (fc, x, chunk, tag)
277+ return ForwardDiffGradientPrep (config)
229278end
230279
231280function DI. value_and_gradient! (
@@ -274,12 +323,12 @@ end
274323
275324# # Jacobian
276325
277- # ## Unprepared, only when chunk size not specified
326+ # ## Unprepared, only when chunk size and tag are not specified
278327
279328function DI. value_and_jacobian! (
280- f:: F , jac, backend:: AutoForwardDiff{chunksize} , x, contexts:: Vararg{Context,C}
281- ) where {F,C,chunksize}
282- if isnothing (chunksize)
329+ f:: F , jac, backend:: AutoForwardDiff{chunksize,T } , x, contexts:: Vararg{Context,C}
330+ ) where {F,C,chunksize,T }
331+ if isnothing (chunksize) && T === Nothing
283332 fc = with_contexts (f, contexts... )
284333 y = fc (x)
285334 result = DiffResult (y, (jac,))
@@ -294,9 +343,9 @@ function DI.value_and_jacobian!(
294343end
295344
296345function DI. value_and_jacobian (
297- f:: F , backend:: AutoForwardDiff{chunksize} , x, contexts:: Vararg{Context,C}
298- ) where {F,C,chunksize}
299- if isnothing (chunksize)
346+ f:: F , backend:: AutoForwardDiff{chunksize,T } , x, contexts:: Vararg{Context,C}
347+ ) where {F,C,chunksize,T }
348+ if isnothing (chunksize) && T === Nothing
300349 fc = with_contexts (f, contexts... )
301350 return fc (x), jacobian (fc, x)
302351 else
@@ -306,9 +355,9 @@ function DI.value_and_jacobian(
306355end
307356
308357function DI. jacobian! (
309- f:: F , jac, backend:: AutoForwardDiff{chunksize} , x, contexts:: Vararg{Context,C}
310- ) where {F,C,chunksize}
311- if isnothing (chunksize)
358+ f:: F , jac, backend:: AutoForwardDiff{chunksize,T } , x, contexts:: Vararg{Context,C}
359+ ) where {F,C,chunksize,T }
360+ if isnothing (chunksize) && T === Nothing
312361 fc = with_contexts (f, contexts... )
313362 return jacobian! (jac, fc, x)
314363 else
@@ -318,9 +367,9 @@ function DI.jacobian!(
318367end
319368
320369function DI. jacobian (
321- f:: F , backend:: AutoForwardDiff{chunksize} , x, contexts:: Vararg{Context,C}
322- ) where {F,C,chunksize}
323- if isnothing (chunksize)
370+ f:: F , backend:: AutoForwardDiff{chunksize,T } , x, contexts:: Vararg{Context,C}
371+ ) where {F,C,chunksize,T }
372+ if isnothing (chunksize) && T === Nothing
324373 fc = with_contexts (f, contexts... )
325374 return jacobian (fc, x)
326375 else
@@ -339,7 +388,10 @@ function DI.prepare_jacobian(
339388 f:: F , backend:: AutoForwardDiff , x, contexts:: Vararg{Context,C}
340389) where {F,C}
341390 fc = with_contexts (f, contexts... )
342- return ForwardDiffOneArgJacobianPrep (JacobianConfig (fc, x, choose_chunk (backend, x)))
391+ chunk = choose_chunk (backend, x)
392+ tag = get_tag (fc, backend, x)
393+ config = JacobianConfig (fc, x, chunk, tag)
394+ return ForwardDiffOneArgJacobianPrep (config)
343395end
344396
345397function DI. value_and_jacobian! (
@@ -491,62 +543,80 @@ end
491543
492544# # Hessian
493545
494- # ## Unprepared
546+ # ## Unprepared, only when chunk size and tag are not specified
495547
496548function DI. hessian! (
497- f:: F , hess, :: AutoForwardDiff , x, contexts:: Vararg{Context,C}
498- ) where {F,C}
499- fc = with_contexts (f, contexts... )
500- return hessian! (hess, fc, x)
549+ f:: F , hess, backend:: AutoForwardDiff{chunksize,T} , x, contexts:: Vararg{Context,C}
550+ ) where {F,C,chunksize,T}
551+ if isnothing (chunksize) && T === Nothing
552+ fc = with_contexts (f, contexts... )
553+ return hessian! (hess, fc, x)
554+ else
555+ prep = DI. prepare_hessian (f, backend, x, contexts... )
556+ return DI. hessian! (f, hess, prep, backend, x, contexts... )
557+ end
501558end
502559
503- function DI. hessian (f:: F , :: AutoForwardDiff , x, contexts:: Vararg{Context,C} ) where {F,C}
504- fc = with_contexts (f, contexts... )
505- return hessian (fc, x)
560+ function DI. hessian (
561+ f:: F , backend:: AutoForwardDiff{chunksize,T} , x, contexts:: Vararg{Context,C}
562+ ) where {F,C,chunksize,T}
563+ if isnothing (chunksize) && T === Nothing
564+ fc = with_contexts (f, contexts... )
565+ return hessian (fc, x)
566+ else
567+ prep = DI. prepare_hessian (f, backend, x, contexts... )
568+ return DI. hessian (f, prep, backend, x, contexts... )
569+ end
506570end
507571
508572function DI. value_gradient_and_hessian! (
509- f:: F , grad, hess, :: AutoForwardDiff , x, contexts:: Vararg{Context,C}
510- ) where {F,C}
511- fc = with_contexts (f, contexts... )
512- result = DiffResult (one (eltype (x)), (grad, hess))
513- result = hessian! (result, fc, x)
514- y = DR. value (result)
515- grad === DR. gradient (result) || copyto! (grad, DR. gradient (result))
516- hess === DR. hessian (result) || copyto! (hess, DR. hessian (result))
517- return (y, grad, hess)
573+ f:: F , grad, hess, backend:: AutoForwardDiff{chunksize,T} , x, contexts:: Vararg{Context,C}
574+ ) where {F,C,chunksize,T}
575+ if isnothing (chunksize) && T === Nothing
576+ fc = with_contexts (f, contexts... )
577+ result = DiffResult (one (eltype (x)), (grad, hess))
578+ result = hessian! (result, fc, x)
579+ y = DR. value (result)
580+ grad === DR. gradient (result) || copyto! (grad, DR. gradient (result))
581+ hess === DR. hessian (result) || copyto! (hess, DR. hessian (result))
582+ return (y, grad, hess)
583+ else
584+ prep = DI. prepare_hessian (f, backend, x, contexts... )
585+ return DI. value_gradient_and_hessian! (f, grad, hess, prep, backend, x, contexts... )
586+ end
518587end
519588
520589function DI. value_gradient_and_hessian (
521- f:: F , :: AutoForwardDiff , x, contexts:: Vararg{Context,C}
522- ) where {F,C}
523- fc = with_contexts (f, contexts... )
524- result = HessianResult (x)
525- result = hessian! (result, fc, x)
526- return (DR. value (result), DR. gradient (result), DR. hessian (result))
590+ f:: F , backend:: AutoForwardDiff{chunksize,T} , x, contexts:: Vararg{Context,C}
591+ ) where {F,C,chunksize,T}
592+ if isnothing (chunksize) && T === Nothing
593+ fc = with_contexts (f, contexts... )
594+ result = HessianResult (x)
595+ result = hessian! (result, fc, x)
596+ return (DR. value (result), DR. gradient (result), DR. hessian (result))
597+ else
598+ prep = DI. prepare_hessian (f, backend, x, contexts... )
599+ return DI. value_gradient_and_hessian (f, prep, backend, x, contexts... )
600+ end
527601end
528602
529603# ## Prepared
530604
531- struct ForwardDiffHessianPrep{C1,C2,C3 } <: HessianPrep
605+ struct ForwardDiffHessianPrep{C1,C2} <: HessianPrep
532606 array_config:: C1
533- manual_result_config:: C2
534- auto_result_config:: C3
607+ result_config:: C2
535608end
536609
537610function DI. prepare_hessian (
538611 f:: F , backend:: AutoForwardDiff , x, contexts:: Vararg{Context,C}
539612) where {F,C}
540613 fc = with_contexts (f, contexts... )
541- manual_result = MutableDiffResult (
542- one (eltype (x)), (similar (x), similar (x, length (x), length (x)))
543- )
544- auto_result = HessianResult (x)
545614 chunk = choose_chunk (backend, x)
546- array_config = HessianConfig (fc, x, chunk)
547- manual_result_config = HessianConfig (fc, manual_result, x, chunk)
548- auto_result_config = HessianConfig (fc, auto_result, x, chunk)
549- return ForwardDiffHessianPrep (array_config, manual_result_config, auto_result_config)
615+ tag = get_tag (fc, backend, x)
616+ result = HessianResult (x)
617+ array_config = HessianConfig (fc, x, chunk, tag)
618+ result_config = HessianConfig (fc, result, x, chunk, tag)
619+ return ForwardDiffHessianPrep (array_config, result_config)
550620end
551621
552622function DI. hessian! (
@@ -579,7 +649,7 @@ function DI.value_gradient_and_hessian!(
579649) where {F,C}
580650 fc = with_contexts (f, contexts... )
581651 result = DiffResult (one (eltype (x)), (grad, hess))
582- result = hessian! (result, fc, x, prep. manual_result_config )
652+ result = hessian! (result, fc, x, prep. result_config )
583653 y = DR. value (result)
584654 grad === DR. gradient (result) || copyto! (grad, DR. gradient (result))
585655 hess === DR. hessian (result) || copyto! (hess, DR. hessian (result))
@@ -591,6 +661,6 @@ function DI.value_gradient_and_hessian(
591661) where {F,C}
592662 fc = with_contexts (f, contexts... )
593663 result = HessianResult (x)
594- result = hessian! (result, fc, x, prep. auto_result_config )
664+ result = hessian! (result, fc, x, prep. result_config )
595665 return (DR. value (result), DR. gradient (result), DR. hessian (result))
596666end
0 commit comments