11# # Pullback
22
3- DI. prepare_pullback (f, :: AutoReverseDiff , x, ty:: NTuple ) = NoPullbackPrep ()
3+ function DI. prepare_pullback (
4+ f, :: AutoReverseDiff , x, ty:: NTuple , contexts:: Vararg{Context,C}
5+ ) where {C}
6+ return NoPullbackPrep ()
7+ end
48
59function DI. value_and_pullback (
6- f, :: NoPullbackPrep , :: AutoReverseDiff , x:: AbstractArray , ty:: NTuple
7- )
8- y = f (x)
10+ f,
11+ :: NoPullbackPrep ,
12+ :: AutoReverseDiff ,
13+ x:: AbstractArray ,
14+ ty:: NTuple ,
15+ contexts:: Vararg{Context,C} ,
16+ ) where {C}
17+ fc = with_contexts (f, contexts... )
18+ y = fc (x)
19+ dotclosure (z, dy) = dot (fc (z), dy)
920 tx = map (ty) do dy
1021 if y isa Number
11- dy .* gradient (f , x)
22+ dy .* gradient (fc , x)
1223 elseif y isa AbstractArray
13- gradient (z -> dot ( f (z) , dy), x)
24+ gradient (Fix2 (dotclosure , dy), x)
1425 end
1526 end
1627 return y, tx
1728end
1829
1930function DI. value_and_pullback! (
20- f, :: NoPullbackPrep , tx:: NTuple , :: AutoReverseDiff , x:: AbstractArray , ty:: NTuple
21- )
22- y = f (x)
31+ f,
32+ :: NoPullbackPrep ,
33+ tx:: NTuple ,
34+ :: AutoReverseDiff ,
35+ x:: AbstractArray ,
36+ ty:: NTuple ,
37+ contexts:: Vararg{Context,C} ,
38+ ) where {C}
39+ fc = with_contexts (f, contexts... )
40+ y = fc (x)
41+ dotclosure (z, dy) = dot (fc (z), dy)
2342 for b in eachindex (tx, ty)
2443 dx, dy = tx[b], ty[b]
2544 if y isa Number
26- dx = gradient! (dx, f , x)
45+ dx = gradient! (dx, fc , x)
2746 dx .*= dy
2847 elseif y isa AbstractArray
29- gradient! (dx, z -> dot ( f (z) , dy), x)
48+ gradient! (dx, Fix2 (dotclosure , dy), x)
3049 end
3150 end
3251 return y, tx
3352end
3453
3554function DI. value_and_pullback (
36- f, :: NoPullbackPrep , backend:: AutoReverseDiff , x:: Number , ty:: NTuple
37- )
55+ f,
56+ :: NoPullbackPrep ,
57+ backend:: AutoReverseDiff ,
58+ x:: Number ,
59+ ty:: NTuple ,
60+ contexts:: Vararg{Context,C} ,
61+ ) where {C}
3862 x_array = [x]
39- f_array = f ∘ only
40- y, tx_array = DI. value_and_pullback (f_array, backend, x_array, ty)
63+ f_array (x_array, args ... ) = f ( only (x_array), args ... )
64+ y, tx_array = DI. value_and_pullback (f_array, backend, x_array, ty, contexts ... )
4165 return y, only .(tx_array)
4266end
4367
4468# # Gradient
4569
70+ # ## Without contexts
71+
4672struct ReverseDiffGradientPrep{T} <: GradientPrep
4773 tape:: T
4874end
@@ -56,7 +82,7 @@ function DI.prepare_gradient(f, ::AutoReverseDiff{Compile}, x) where {Compile}
5682end
5783
5884function DI. value_and_gradient! (
59- f, grad:: AbstractArray , prep:: ReverseDiffGradientPrep , :: AutoReverseDiff , x
85+ f, grad, prep:: ReverseDiffGradientPrep , :: AutoReverseDiff , x
6086)
6187 y = f (x) # TODO : remove once ReverseDiff#251 is fixed
6288 result = MutableDiffResult (y, (grad,))
@@ -71,18 +97,55 @@ function DI.value_and_gradient(
7197 return DI. value_and_gradient! (f, grad, prep, backend, x)
7298end
7399
74- function DI. gradient! (
75- _f, grad, prep:: ReverseDiffGradientPrep , :: AutoReverseDiff , x:: AbstractArray
76- )
100+ function DI. gradient! (_f, grad, prep:: ReverseDiffGradientPrep , :: AutoReverseDiff , x)
77101 return gradient! (grad, prep. tape, x)
78102end
79103
80104function DI. gradient (_f, prep:: ReverseDiffGradientPrep , :: AutoReverseDiff , x)
81105 return gradient! (prep. tape, x)
82106end
83107
108+ # ## With contexts
109+
110+ function DI. prepare_gradient (f, :: AutoReverseDiff , x, contexts:: Vararg{Context,C} ) where {C}
111+ return NoGradientPrep ()
112+ end
113+
114+ function DI. value_and_gradient! (
115+ f, grad, :: NoGradientPrep , :: AutoReverseDiff , x, contexts:: Vararg{Context,C}
116+ ) where {C}
117+ fc = with_contexts (f, contexts... )
118+ y = fc (x) # TODO : remove once ReverseDiff#251 is fixed
119+ result = MutableDiffResult (y, (grad,))
120+ result = gradient! (result, fc, x)
121+ return DiffResults. value (result), DiffResults. derivative (result)
122+ end
123+
124+ function DI. value_and_gradient (
125+ f, prep:: NoGradientPrep , backend:: AutoReverseDiff , x, contexts:: Vararg{Context,C}
126+ ) where {C}
127+ grad = similar (x)
128+ return DI. value_and_gradient! (f, grad, prep, backend, x, contexts... )
129+ end
130+
131+ function DI. gradient! (
132+ f, grad, :: NoGradientPrep , :: AutoReverseDiff , x, contexts:: Vararg{Context,C}
133+ ) where {C}
134+ fc = with_contexts (f, contexts... )
135+ return gradient! (grad, fc, x)
136+ end
137+
138+ function DI. gradient (
139+ f, :: NoGradientPrep , :: AutoReverseDiff , x, contexts:: Vararg{Context,C}
140+ ) where {C}
141+ fc = with_contexts (f, contexts... )
142+ return gradient (fc, x)
143+ end
144+
84145# # Jacobian
85146
147+ # ## Without contexts
148+
86149struct ReverseDiffOneArgJacobianPrep{T} <: JacobianPrep
87150 tape:: T
88151end
@@ -116,8 +179,47 @@ function DI.jacobian(f, prep::ReverseDiffOneArgJacobianPrep, ::AutoReverseDiff,
116179 return jacobian! (prep. tape, x)
117180end
118181
182+ # ## With contexts
183+
184+ function DI. prepare_jacobian (f, :: AutoReverseDiff , x, contexts:: Vararg{Context,C} ) where {C}
185+ return NoJacobianPrep ()
186+ end
187+
188+ function DI. value_and_jacobian! (
189+ f, jac, :: NoJacobianPrep , :: AutoReverseDiff , x, contexts:: Vararg{Context,C}
190+ ) where {C}
191+ fc = with_contexts (f, contexts... )
192+ y = fc (x)
193+ result = MutableDiffResult (y, (jac,))
194+ result = jacobian! (result, fc, x)
195+ return DiffResults. value (result), DiffResults. derivative (result)
196+ end
197+
198+ function DI. value_and_jacobian (
199+ f, :: NoJacobianPrep , :: AutoReverseDiff , x, contexts:: Vararg{Context,C}
200+ ) where {C}
201+ fc = with_contexts (f, contexts... )
202+ return fc (x), jacobian (fc, x)
203+ end
204+
205+ function DI. jacobian! (
206+ f, jac, :: NoJacobianPrep , :: AutoReverseDiff , x, contexts:: Vararg{Context,C}
207+ ) where {C}
208+ fc = with_contexts (f, contexts... )
209+ return jacobian! (jac, fc, x)
210+ end
211+
212+ function DI. jacobian (
213+ f, :: NoJacobianPrep , :: AutoReverseDiff , x, contexts:: Vararg{Context,C}
214+ ) where {C}
215+ fc = with_contexts (f, contexts... )
216+ return jacobian (fc, x)
217+ end
218+
119219# # Hessian
120220
221+ # ## Without contexts
222+
121223struct ReverseDiffHessianPrep{T} <: HessianPrep
122224 tape:: T
123225end
@@ -152,11 +254,54 @@ end
152254function DI. value_gradient_and_hessian (
153255 f, prep:: ReverseDiffHessianPrep , :: AutoReverseDiff , x
154256)
155- result = MutableDiffResult (
156- one (eltype (x)), (similar (x), similar (x, length (x), length (x)))
157- )
257+ y = f (x) # TODO : remove once ReverseDiff#251 is fixed
258+ result = MutableDiffResult (y, (similar (x), similar (x, length (x), length (x))))
158259 result = hessian! (result, prep. tape, x)
159260 return (
160261 DiffResults. value (result), DiffResults. gradient (result), DiffResults. hessian (result)
161262 )
162263end
264+
265+ # ## With contexts
266+
267+ function DI. prepare_hessian (f, :: AutoReverseDiff , x, contexts:: Vararg{Context,C} ) where {C}
268+ return NoHessianPrep ()
269+ end
270+
271+ function DI. hessian! (
272+ f, hess, :: NoHessianPrep , :: AutoReverseDiff , x, contexts:: Vararg{Context,C}
273+ ) where {C}
274+ fc = with_contexts (f, contexts... )
275+ return hessian! (hess, fc, x)
276+ end
277+
278+ function DI. hessian (
279+ f, :: NoHessianPrep , :: AutoReverseDiff , x, contexts:: Vararg{Context,C}
280+ ) where {C}
281+ fc = with_contexts (f, contexts... )
282+ return hessian (fc, x)
283+ end
284+
285+ function DI. value_gradient_and_hessian! (
286+ f, grad, hess, :: NoHessianPrep , :: AutoReverseDiff , x, contexts:: Vararg{Context,C}
287+ ) where {C}
288+ fc = with_contexts (f, contexts... )
289+ y = fc (x) # TODO : remove once ReverseDiff#251 is fixed
290+ result = MutableDiffResult (y, (grad, hess))
291+ result = hessian! (result, fc, x)
292+ return (
293+ DiffResults. value (result), DiffResults. gradient (result), DiffResults. hessian (result)
294+ )
295+ end
296+
297+ function DI. value_gradient_and_hessian (
298+ f, :: NoHessianPrep , :: AutoReverseDiff , x, contexts:: Vararg{Context,C}
299+ ) where {C}
300+ fc = with_contexts (f, contexts... )
301+ y = fc (x) # TODO : remove once ReverseDiff#251 is fixed
302+ result = MutableDiffResult (y, (similar (x), similar (x, length (x), length (x))))
303+ result = hessian! (result, fc, x)
304+ return (
305+ DiffResults. value (result), DiffResults. gradient (result), DiffResults. hessian (result)
306+ )
307+ end
0 commit comments