@@ -8,23 +8,18 @@ struct ReactantGradientPrep{SIG, XR, GR, CG, CG!, CVG, CVG!} <: DI.GradientPrep{
88 compiled_value_and_gradient!:: CVG!
99end
1010
11- function DI. prepare_gradient_nokwarg (strict:: Val , f:: F , rebackend:: AutoReactant , x) where {F}
11+ function DI. prepare_gradient_nokwarg (
12+ strict:: Val , f:: F , rebackend:: AutoReactant , x, contexts:: Vararg{DI.Context, C}
13+ ) where {F, C}
1214 _sig = DI. signature (f, rebackend, x; strict)
1315 backend = rebackend. mode
14- xr = to_rarray (x)
15- gr = to_rarray (similar (x))
16- _gradient (_xr) = DI. gradient (f, backend, _xr)
17- _gradient! (_gr, _xr) = copy! (_gr, DI. gradient (f, backend, _xr))
18- _value_and_gradient (_xr) = DI. value_and_gradient (f, backend, _xr)
19- function _value_and_gradient! (_gr, _xr)
20- y, __gr = DI. value_and_gradient (f, backend, _xr)
21- copy! (_gr, __gr)
22- return y, _gr
23- end
24- compiled_gradient = @compile _gradient (xr)
25- compiled_gradient! = @compile _gradient! (gr, xr)
26- compiled_value_and_gradient = @compile _value_and_gradient (xr)
27- compiled_value_and_gradient! = @compile _value_and_gradient! (gr, xr)
16+ xr = to_reac (x)
17+ gr = to_reac (similar (x))
18+ contextsr = map (to_reac, contexts)
19+ compiled_gradient = @compile DI. gradient (f, backend, xr, contextsr... )
20+ compiled_gradient! = @compile DI. gradient! (f, gr, backend, xr, contextsr... )
21+ compiled_value_and_gradient = @compile DI. value_and_gradient (f, backend, xr, contextsr... )
22+ compiled_value_and_gradient! = @compile DI. value_and_gradient! (f, gr, backend, xr, contextsr... )
2823 return ReactantGradientPrep (
2924 _sig,
3025 xr,
@@ -37,45 +32,49 @@ function DI.prepare_gradient_nokwarg(strict::Val, f::F, rebackend::AutoReactant,
3732end
3833
3934function DI. gradient (
40- f:: F , prep:: ReactantGradientPrep , rebackend:: AutoReactant , x
41- ) where {F}
35+ f:: F , prep:: ReactantGradientPrep , rebackend:: AutoReactant , x, contexts :: Vararg{DI.Context, C}
36+ ) where {F, C }
4237 DI. check_prep (f, prep, rebackend, x)
38+ backend = rebackend. mode
4339 (; xr, compiled_gradient) = prep
44- copy ! (xr, x)
45- gr = compiled_gradient (xr )
46- g = convert ( typeof (x), gr )
47- return g
40+ copyto ! (xr, x)
41+ contextsr = map (to_reac, contexts )
42+ gr = compiled_gradient (f, backend, xr, contextsr ... )
43+ return gr
4844end
4945
5046function DI. value_and_gradient (
51- f:: F , prep:: ReactantGradientPrep , rebackend:: AutoReactant , x
52- ) where {F}
47+ f:: F , prep:: ReactantGradientPrep , rebackend:: AutoReactant , x, contexts :: Vararg{DI.Context, C}
48+ ) where {F, C }
5349 DI. check_prep (f, prep, rebackend, x)
50+ backend = rebackend. mode
5451 (; xr, compiled_value_and_gradient) = prep
55- copy! (xr, x)
56- yr, gr = compiled_value_and_gradient (xr)
57- y = convert (eltype (x), yr)
58- g = convert (typeof (x), gr)
59- return y, g
52+ copyto! (xr, x)
53+ contextsr = map (to_reac, contexts)
54+ yr, gr = compiled_value_and_gradient (f, backend, xr, contextsr... )
55+ return yr, gr
6056end
6157
6258function DI. gradient! (
63- f:: F , grad, prep:: ReactantGradientPrep , rebackend:: AutoReactant , x
64- ) where {F}
59+ f:: F , grad, prep:: ReactantGradientPrep , rebackend:: AutoReactant , x, contexts :: Vararg{DI.Context, C}
60+ ) where {F, C }
6561 DI. check_prep (f, prep, rebackend, x)
62+ backend = rebackend. mode
6663 (; xr, gr, compiled_gradient!) = prep
67- copy! (xr, x)
68- compiled_gradient! (gr, xr)
69- return copy! (grad, gr)
64+ copyto! (xr, x)
65+ contextsr = map (to_reac, contexts)
66+ compiled_gradient! (f, gr, backend, xr, contextsr... )
67+ return copyto! (grad, gr)
7068end
7169
7270function DI. value_and_gradient! (
73- f:: F , grad, prep:: ReactantGradientPrep , rebackend:: AutoReactant , x
74- ) where {F}
71+ f:: F , grad, prep:: ReactantGradientPrep , rebackend:: AutoReactant , x, contexts :: Vararg{DI.Context, C}
72+ ) where {F, C }
7573 DI. check_prep (f, prep, rebackend, x)
74+ backend = rebackend. mode
7675 (; xr, gr, compiled_value_and_gradient!) = prep
77- copy ! (xr, x)
78- yr, gr = compiled_value_and_gradient! (gr, xr )
79- y = convert ( eltype (x), yr )
80- return y, copy ! (grad, gr)
76+ copyto ! (xr, x)
77+ contextsr = map (to_reac, contexts )
78+ yr, gr = compiled_value_and_gradient! (f, gr, backend, xr, contextsr ... )
79+ return yr, copyto ! (grad, gr)
8180end
0 commit comments