@@ -4,7 +4,14 @@ using ADTypes: AutoForwardDiff, AutoZygote
44import DifferentiationInterface as DI
55using ForwardDiff: ForwardDiff
66using Zygote:
7- ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian
7+ Buffer,
8+ ZygoteRuleConfig,
9+ gradient,
10+ hessian,
11+ jacobian,
12+ pullback,
13+ withgradient,
14+ withjacobian
815
916struct ZygoteNothingError <: Exception
1017 f
@@ -27,6 +34,9 @@ check_nothing(::Any, f, x, contexts) = nothing
2734DI. check_available (:: AutoZygote ) = true
2835DI. inplace_support (:: AutoZygote ) = DI. InPlaceNotSupported ()
2936
37+ translate (c:: DI.Context ) = DI. unwrap (c)
38+ translate (c:: DI.Cache ) = Buffer (DI. unwrap (c))
39+
3040# # Pullback
3141
3242struct ZygotePullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
@@ -35,32 +45,22 @@ struct ZygotePullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
3545end
3646
3747function DI. prepare_pullback (
38- f, :: AutoZygote , x, ty:: NTuple , contexts:: Vararg{DI.ConstantOrFunctionOrBackend ,C}
48+ f, :: AutoZygote , x, ty:: NTuple , contexts:: Vararg{DI.Context ,C}
3949) where {C}
4050 return DI. NoPullbackPrep ()
4151end
4252
4353function DI. prepare_pullback_same_point (
44- f,
45- :: DI.NoPullbackPrep ,
46- :: AutoZygote ,
47- x,
48- ty:: NTuple ,
49- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
54+ f, :: DI.NoPullbackPrep , :: AutoZygote , x, ty:: NTuple , contexts:: Vararg{DI.Context,C}
5055) where {C}
51- y, pb = pullback (f, x, map (DI . unwrap , contexts)... )
56+ y, pb = pullback (f, x, map (translate , contexts)... )
5257 return ZygotePullbackPrepSamePoint (y, pb)
5358end
5459
5560function DI. value_and_pullback (
56- f,
57- :: DI.NoPullbackPrep ,
58- :: AutoZygote ,
59- x,
60- ty:: NTuple ,
61- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
61+ f, :: DI.NoPullbackPrep , :: AutoZygote , x, ty:: NTuple , contexts:: Vararg{DI.Context,C}
6262) where {C}
63- y, pb = pullback (f, x, map (DI . unwrap , contexts)... )
63+ y, pb = pullback (f, x, map (translate , contexts)... )
6464 tx = map (ty) do dy
6565 first (pb (dy))
6666 end
@@ -74,7 +74,7 @@ function DI.value_and_pullback(
7474 :: AutoZygote ,
7575 x,
7676 ty:: NTuple ,
77- contexts:: Vararg{DI.ConstantOrFunctionOrBackend ,C} ,
77+ contexts:: Vararg{DI.Context ,C} ,
7878) where {C}
7979 (; y, pb) = prep
8080 tx = map (ty) do dy
@@ -90,7 +90,7 @@ function DI.pullback(
9090 :: AutoZygote ,
9191 x,
9292 ty:: NTuple ,
93- contexts:: Vararg{DI.ConstantOrFunctionOrBackend ,C} ,
93+ contexts:: Vararg{DI.Context ,C} ,
9494) where {C}
9595 (; pb) = prep
9696 tx = map (ty) do dy
@@ -102,112 +102,72 @@ end
102102
103103# # Gradient
104104
105- function DI. prepare_gradient (
106- f, :: AutoZygote , x, contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C}
107- ) where {C}
105+ function DI. prepare_gradient (f, :: AutoZygote , x, contexts:: Vararg{DI.Context,C} ) where {C}
108106 return DI. NoGradientPrep ()
109107end
110108
111109function DI. value_and_gradient (
112- f,
113- :: DI.NoGradientPrep ,
114- :: AutoZygote ,
115- x,
116- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
110+ f, :: DI.NoGradientPrep , :: AutoZygote , x, contexts:: Vararg{DI.Context,C}
117111) where {C}
118- (; val, grad) = withgradient (f, x, map (DI . unwrap , contexts)... )
112+ (; val, grad) = withgradient (f, x, map (translate , contexts)... )
119113 check_nothing (first (grad), f, x, contexts)
120114 return val, first (grad)
121115end
122116
123117function DI. gradient (
124- f,
125- :: DI.NoGradientPrep ,
126- :: AutoZygote ,
127- x,
128- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
118+ f, :: DI.NoGradientPrep , :: AutoZygote , x, contexts:: Vararg{DI.Context,C}
129119) where {C}
130- grad = gradient (f, x, map (DI . unwrap , contexts)... )
120+ grad = gradient (f, x, map (translate , contexts)... )
131121 check_nothing (first (grad), f, x, contexts)
132122 return first (grad)
133123end
134124
135125function DI. value_and_gradient! (
136- f,
137- grad,
138- prep:: DI.NoGradientPrep ,
139- backend:: AutoZygote ,
140- x,
141- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
126+ f, grad, prep:: DI.NoGradientPrep , backend:: AutoZygote , x, contexts:: Vararg{DI.Context,C}
142127) where {C}
143128 y, new_grad = DI. value_and_gradient (f, prep, backend, x, contexts... )
144129 return y, copyto! (grad, new_grad)
145130end
146131
147132function DI. gradient! (
148- f,
149- grad,
150- prep:: DI.NoGradientPrep ,
151- backend:: AutoZygote ,
152- x,
153- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
133+ f, grad, prep:: DI.NoGradientPrep , backend:: AutoZygote , x, contexts:: Vararg{DI.Context,C}
154134) where {C}
155135 return copyto! (grad, DI. gradient (f, prep, backend, x, contexts... ))
156136end
157137
158138# # Jacobian
159139
160- function DI. prepare_jacobian (
161- f, :: AutoZygote , x, contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C}
162- ) where {C}
140+ function DI. prepare_jacobian (f, :: AutoZygote , x, contexts:: Vararg{DI.Context,C} ) where {C}
163141 return DI. NoJacobianPrep ()
164142end
165143
166144function DI. value_and_jacobian (
167- f,
168- :: DI.NoJacobianPrep ,
169- :: AutoZygote ,
170- x,
171- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
145+ f, :: DI.NoJacobianPrep , :: AutoZygote , x, contexts:: Vararg{DI.Context,C}
172146) where {C}
173- y = f (x, map (DI . unwrap , contexts)... )
147+ y = f (x, map (translate , contexts)... )
174148 # https://github.com/FluxML/Zygote.jl/issues/1506
175- jac = jacobian (f, x, map (DI . unwrap , contexts)... )
149+ jac = jacobian (f, x, map (translate , contexts)... )
176150 check_nothing (first (jac), f, x, contexts)
177151 return y, first (jac)
178152end
179153
180154function DI. jacobian (
181- f,
182- :: DI.NoJacobianPrep ,
183- :: AutoZygote ,
184- x,
185- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
155+ f, :: DI.NoJacobianPrep , :: AutoZygote , x, contexts:: Vararg{DI.Context,C}
186156) where {C}
187- jac = jacobian (f, x, map (DI . unwrap , contexts)... )
157+ jac = jacobian (f, x, map (translate , contexts)... )
188158 check_nothing (first (jac), f, x, contexts)
189159 return first (jac)
190160end
191161
192162function DI. value_and_jacobian! (
193- f,
194- jac,
195- prep:: DI.NoJacobianPrep ,
196- backend:: AutoZygote ,
197- x,
198- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
163+ f, jac, prep:: DI.NoJacobianPrep , backend:: AutoZygote , x, contexts:: Vararg{DI.Context,C}
199164) where {C}
200165 y, new_jac = DI. value_and_jacobian (f, prep, backend, x, contexts... )
201166 return y, copyto! (jac, new_jac)
202167end
203168
204169function DI. jacobian! (
205- f,
206- jac,
207- prep:: DI.NoJacobianPrep ,
208- backend:: AutoZygote ,
209- x,
210- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
170+ f, jac, prep:: DI.NoJacobianPrep , backend:: AutoZygote , x, contexts:: Vararg{DI.Context,C}
211171) where {C}
212172 return copyto! (jac, DI. jacobian (f, prep, backend, x, contexts... ))
213173end
@@ -217,22 +177,13 @@ end
217177# Beware, this uses ForwardDiff for the inner differentiation
218178
219179function DI. prepare_hvp (
220- f,
221- backend:: AutoZygote ,
222- x,
223- tx:: NTuple ,
224- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
180+ f, backend:: AutoZygote , x, tx:: NTuple , contexts:: Vararg{DI.Context,C}
225181) where {C}
226182 return DI. prepare_hvp (f, DI. SecondOrder (AutoForwardDiff (), backend), x, tx, contexts... )
227183end
228184
229185function DI. hvp (
230- f,
231- prep:: DI.HVPPrep ,
232- backend:: AutoZygote ,
233- x,
234- tx:: NTuple ,
235- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
186+ f, prep:: DI.HVPPrep , backend:: AutoZygote , x, tx:: NTuple , contexts:: Vararg{DI.Context,C}
236187) where {C}
237188 return DI. hvp (f, prep, DI. SecondOrder (AutoForwardDiff (), backend), x, tx, contexts... )
238189end
@@ -244,20 +195,15 @@ function DI.hvp!(
244195 backend:: AutoZygote ,
245196 x,
246197 tx:: NTuple ,
247- contexts:: Vararg{DI.ConstantOrFunctionOrBackend ,C} ,
198+ contexts:: Vararg{DI.Context ,C} ,
248199) where {C}
249200 return DI. hvp! (
250201 f, tg, prep, DI. SecondOrder (AutoForwardDiff (), backend), x, tx, contexts...
251202 )
252203end
253204
254205function DI. gradient_and_hvp (
255- f,
256- prep:: DI.HVPPrep ,
257- backend:: AutoZygote ,
258- x,
259- tx:: NTuple ,
260- contexts:: Vararg{DI.ConstantOrFunctionOrBackend,C} ,
206+ f, prep:: DI.HVPPrep , backend:: AutoZygote , x, tx:: NTuple , contexts:: Vararg{DI.Context,C}
261207) where {C}
262208 return DI. gradient_and_hvp (
263209 f, prep, DI. SecondOrder (AutoForwardDiff (), backend), x, tx, contexts...
@@ -272,7 +218,7 @@ function DI.gradient_and_hvp!(
272218 backend:: AutoZygote ,
273219 x,
274220 tx:: NTuple ,
275- contexts:: Vararg{DI.ConstantOrFunctionOrBackend ,C} ,
221+ contexts:: Vararg{DI.Context ,C} ,
276222) where {C}
277223 return DI. gradient_and_hvp! (
278224 f, grad, tg, prep, DI. SecondOrder (AutoForwardDiff (), backend), x, tx, contexts...
0 commit comments