@@ -3,13 +3,15 @@ module DifferentiationInterfaceZygoteExt
33using ADTypes: AutoForwardDiff, AutoZygote
44import DifferentiationInterface as DI
55using DifferentiationInterface:
6+ Context,
67 HVPExtras,
78 NoGradientExtras,
89 NoHessianExtras,
910 NoJacobianExtras,
1011 NoPullbackExtras,
1112 PullbackExtras,
12- Tangents
13+ Tangents,
14+ unwrap
1315using ForwardDiff: ForwardDiff
1416using Zygote:
1517 ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian
@@ -25,63 +27,83 @@ struct ZygotePullbackExtrasSamePoint{Y,PB} <: PullbackExtras
2527 pb:: PB
2628end
2729
28- DI. prepare_pullback (f, :: AutoZygote , x, ty:: Tangents ) = NoPullbackExtras ()
30+ function DI. prepare_pullback (f, :: AutoZygote , x, ty:: Tangents , contexts:: Vararg{Context} )
31+ return NoPullbackExtras ()
32+ end
2933
3034function DI. prepare_pullback_same_point (
31- f, :: NoPullbackExtras , :: AutoZygote , x, ty:: Tangents
35+ f, :: NoPullbackExtras , :: AutoZygote , x, ty:: Tangents , contexts :: Vararg{Context}
3236)
33- y, pb = pullback (f, x)
37+ y, pb = pullback (f, x, map (unwrap, contexts) ... )
3438 return ZygotePullbackExtrasSamePoint (y, pb)
3539end
3640
37- function DI. value_and_pullback (f, :: NoPullbackExtras , :: AutoZygote , x, ty:: Tangents )
38- y, pb = pullback (f, x)
41+ function DI. value_and_pullback (
42+ f, :: NoPullbackExtras , :: AutoZygote , x, ty:: Tangents , contexts:: Vararg{Context}
43+ )
44+ y, pb = pullback (f, x, map (unwrap, contexts)... )
3945 tx = map (ty) do dy
40- only (pb (dy))
46+ first (pb (dy))
4147 end
4248 return y, tx
4349end
4450
4551function DI. value_and_pullback (
46- f, extras:: ZygotePullbackExtrasSamePoint , :: AutoZygote , x, ty:: Tangents
52+ f,
53+ extras:: ZygotePullbackExtrasSamePoint ,
54+ :: AutoZygote ,
55+ x,
56+ ty:: Tangents ,
57+ contexts:: Vararg{Context} ,
4758)
4859 @compat (; y, pb) = extras
4960 tx = map (ty) do dy
50- only (pb (dy))
61+ first (pb (dy))
5162 end
5263 return copy (y), tx
5364end
5465
5566function DI. pullback (
56- f, extras:: ZygotePullbackExtrasSamePoint , :: AutoZygote , x, ty:: Tangents
67+ f,
68+ extras:: ZygotePullbackExtrasSamePoint ,
69+ :: AutoZygote ,
70+ x,
71+ ty:: Tangents ,
72+ contexts:: Vararg{Context} ,
5773)
5874 @compat (; pb) = extras
5975 tx = map (ty) do dy
60- only (pb (dy))
76+ first (pb (dy))
6177 end
6278 return tx
6379end
6480
6581# # Gradient
6682
67- DI. prepare_gradient (f, :: AutoZygote , x) = NoGradientExtras ()
83+ DI. prepare_gradient (f, :: AutoZygote , x, contexts :: Vararg{Context} ) = NoGradientExtras ()
6884
69- function DI. value_and_gradient (f, :: NoGradientExtras , :: AutoZygote , x)
70- @compat (; val, grad) = withgradient (f, x)
71- return val, only (grad)
85+ function DI. value_and_gradient (
86+ f, :: NoGradientExtras , :: AutoZygote , x, contexts:: Vararg{Context}
87+ )
88+ @compat (; val, grad) = withgradient (f, x, map (unwrap, contexts)... )
89+ return val, first (grad)
7290end
7391
74- function DI. gradient (f, :: NoGradientExtras , :: AutoZygote , x)
75- return only (gradient (f, x))
92+ function DI. gradient (f, :: NoGradientExtras , :: AutoZygote , x, contexts :: Vararg{Context} )
93+ return first (gradient (f, x, map (unwrap, contexts) ... ))
7694end
7795
78- function DI. value_and_gradient! (f, grad, extras:: NoGradientExtras , backend:: AutoZygote , x)
79- y, new_grad = DI. value_and_gradient (f, extras, backend, x)
96+ function DI. value_and_gradient! (
97+ f, grad, extras:: NoGradientExtras , backend:: AutoZygote , x, contexts:: Vararg{Context}
98+ )
99+ y, new_grad = DI. value_and_gradient (f, extras, backend, x, contexts... )
80100 return y, copyto! (grad, new_grad)
81101end
82102
83- function DI. gradient! (f, grad, extras:: NoGradientExtras , backend:: AutoZygote , x)
84- return copyto! (grad, DI. gradient (f, extras, backend, x))
103+ function DI. gradient! (
104+ f, grad, extras:: NoGradientExtras , backend:: AutoZygote , x, contexts:: Vararg{Context}
105+ )
106+ return copyto! (grad, DI. gradient (f, extras, backend, x, contexts... ))
85107end
86108
87109# # Jacobian
0 commit comments