@@ -3,7 +3,7 @@ module DifferentiationInterfaceZygoteExt
33using ADTypes: AutoForwardDiff, AutoZygote
44import DifferentiationInterface as DI
55using DifferentiationInterface:
6- Context ,
6+ Constant ,
77 HVPExtras,
88 NoGradientExtras,
99 NoHessianExtras,
@@ -27,19 +27,19 @@ struct ZygotePullbackExtrasSamePoint{Y,PB} <: PullbackExtras
2727 pb:: PB
2828end
2929
30- function DI. prepare_pullback (f, :: AutoZygote , x, ty:: Tangents , contexts:: Vararg{Context } )
30+ function DI. prepare_pullback (f, :: AutoZygote , x, ty:: Tangents , contexts:: Vararg{Constant } )
3131 return NoPullbackExtras ()
3232end
3333
3434function DI. prepare_pullback_same_point (
35- f, :: NoPullbackExtras , :: AutoZygote , x, ty:: Tangents , contexts:: Vararg{Context }
35+ f, :: NoPullbackExtras , :: AutoZygote , x, ty:: Tangents , contexts:: Vararg{Constant }
3636)
3737 y, pb = pullback (f, x, map (unwrap, contexts)... )
3838 return ZygotePullbackExtrasSamePoint (y, pb)
3939end
4040
4141function DI. value_and_pullback (
42- f, :: NoPullbackExtras , :: AutoZygote , x, ty:: Tangents , contexts:: Vararg{Context }
42+ f, :: NoPullbackExtras , :: AutoZygote , x, ty:: Tangents , contexts:: Vararg{Constant }
4343)
4444 y, pb = pullback (f, x, map (unwrap, contexts)... )
4545 tx = map (ty) do dy
@@ -54,7 +54,7 @@ function DI.value_and_pullback(
5454 :: AutoZygote ,
5555 x,
5656 ty:: Tangents ,
57- contexts:: Vararg{Context } ,
57+ contexts:: Vararg{Constant } ,
5858)
5959 @compat (; y, pb) = extras
6060 tx = map (ty) do dy
@@ -69,7 +69,7 @@ function DI.pullback(
6969 :: AutoZygote ,
7070 x,
7171 ty:: Tangents ,
72- contexts:: Vararg{Context } ,
72+ contexts:: Vararg{Constant } ,
7373)
7474 @compat (; pb) = extras
7575 tx = map (ty) do dy
8080
8181# # Gradient
8282
83- DI. prepare_gradient (f, :: AutoZygote , x, contexts:: Vararg{Context } ) = NoGradientExtras ()
83+ DI. prepare_gradient (f, :: AutoZygote , x, contexts:: Vararg{Constant } ) = NoGradientExtras ()
8484
8585function DI. value_and_gradient (
86- f, :: NoGradientExtras , :: AutoZygote , x, contexts:: Vararg{Context }
86+ f, :: NoGradientExtras , :: AutoZygote , x, contexts:: Vararg{Constant }
8787)
8888 @compat (; val, grad) = withgradient (f, x, map (unwrap, contexts)... )
8989 return val, first (grad)
9090end
9191
92- function DI. gradient (f, :: NoGradientExtras , :: AutoZygote , x, contexts:: Vararg{Context } )
92+ function DI. gradient (f, :: NoGradientExtras , :: AutoZygote , x, contexts:: Vararg{Constant } )
9393 return first (gradient (f, x, map (unwrap, contexts)... ))
9494end
9595
9696function DI. value_and_gradient! (
97- f, grad, extras:: NoGradientExtras , backend:: AutoZygote , x, contexts:: Vararg{Context }
97+ f, grad, extras:: NoGradientExtras , backend:: AutoZygote , x, contexts:: Vararg{Constant }
9898)
9999 y, new_grad = DI. value_and_gradient (f, extras, backend, x, contexts... )
100100 return y, copyto! (grad, new_grad)
101101end
102102
103103function DI. gradient! (
104- f, grad, extras:: NoGradientExtras , backend:: AutoZygote , x, contexts:: Vararg{Context }
104+ f, grad, extras:: NoGradientExtras , backend:: AutoZygote , x, contexts:: Vararg{Constant }
105105)
106106 return copyto! (grad, DI. gradient (f, extras, backend, x, contexts... ))
107107end
0 commit comments