11# # Pushforward
22
3- DI. prepare_pushforward (f, :: AutoForwardOrNothingEnzyme , x, dx) = NoPushforwardExtras ()
3+ function DI. prepare_pushforward (f, :: AnyAutoEnzyme{<:Union{ForwardMode,Nothing}} , x, dx)
4+ return NoPushforwardExtras ()
5+ end
46
57function DI. value_and_pushforward (
6- f, backend:: AutoForwardOrNothingEnzyme , x, dx, :: NoPushforwardExtras
8+ f, backend:: AnyAutoEnzyme{<:Union{ForwardMode,Nothing}} , x, dx, :: NoPushforwardExtras
79)
810 dx_sametype = convert (typeof (x), dx)
9- y, new_dy = autodiff (
10- forward_mode (backend), Const (f), Duplicated, Duplicated (x, dx_sametype)
11- )
11+ x_and_dx = Duplicated (x, dx_sametype)
12+ y, new_dy = if backend isa AutoDeferredEnzyme
13+ autodiff_deferred (forward_mode (backend), f, Duplicated, x_and_dx)
14+ else
15+ autodiff (forward_mode (backend), Const (f), Duplicated, x_and_dx)
16+ end
1217 return y, new_dy
1318end
1419
1520function DI. pushforward (
16- f, backend:: AutoForwardOrNothingEnzyme , x, dx, :: NoPushforwardExtras
21+ f, backend:: AnyAutoEnzyme{<:Union{ForwardMode,Nothing}} , x, dx, :: NoPushforwardExtras
1722)
1823 dx_sametype = convert (typeof (x), dx)
19- new_dy = only (
20- autodiff (
21- forward_mode (backend), Const (f), DuplicatedNoNeed, Duplicated (x, dx_sametype)
22- ),
23- )
24+ x_and_dx = Duplicated (x, dx_sametype)
25+ new_dy = if backend isa AutoDeferredEnzyme
26+ only (autodiff_deferred (forward_mode (backend), f, DuplicatedNoNeed, x_and_dx))
27+ else
28+ only (autodiff (forward_mode (backend), Const (f), DuplicatedNoNeed, x_and_dx))
29+ end
2430 return new_dy
2531end
2632
2733function DI. value_and_pushforward! (
28- f, dy, backend:: AutoForwardOrNothingEnzyme , x, dx, extras:: NoPushforwardExtras
34+ f,
35+ dy,
36+ backend:: AnyAutoEnzyme{<:Union{ForwardMode,Nothing}} ,
37+ x,
38+ dx,
39+ extras:: NoPushforwardExtras ,
2940)
3041 # dy cannot be passed anyway
3142 y, new_dy = DI. value_and_pushforward (f, backend, x, dx, extras)
3243 return y, copyto! (dy, new_dy)
3344end
3445
3546function DI. pushforward! (
36- f, dy, backend:: AutoForwardOrNothingEnzyme , x, dx, extras:: NoPushforwardExtras
47+ f,
48+ dy,
49+ backend:: AnyAutoEnzyme{<:Union{ForwardMode,Nothing}} ,
50+ x,
51+ dx,
52+ extras:: NoPushforwardExtras ,
3753)
3854 # dy cannot be passed anyway
3955 return copyto! (dy, DI. pushforward (f, backend, x, dx, extras))
@@ -45,34 +61,34 @@ struct EnzymeForwardGradientExtras{C,O} <: GradientExtras
4561 shadow:: O
4662end
4763
48- function DI. prepare_gradient (f, :: AutoForwardEnzyme , x)
64+ function DI. prepare_gradient (f, :: AutoEnzyme{<:ForwardMode} , x)
4965 C = pick_chunksize (length (x))
5066 shadow = chunkedonehot (x, Val (C))
5167 return EnzymeForwardGradientExtras {C,typeof(shadow)} (shadow)
5268end
5369
5470function DI. gradient (
55- f, backend:: AutoForwardEnzyme , x, extras:: EnzymeForwardGradientExtras{C}
71+ f, backend:: AutoEnzyme{<:ForwardMode} , x, extras:: EnzymeForwardGradientExtras{C}
5672) where {C}
5773 grad_tup = gradient (forward_mode (backend), f, x, Val {C} (); shadow= extras. shadow)
5874 return reshape (collect (grad_tup), size (x))
5975end
6076
6177function DI. value_and_gradient (
62- f, backend:: AutoForwardEnzyme , x, extras:: EnzymeForwardGradientExtras
78+ f, backend:: AutoEnzyme{<:ForwardMode} , x, extras:: EnzymeForwardGradientExtras
6379)
6480 return f (x), DI. gradient (f, backend, x, extras)
6581end
6682
6783function DI. gradient! (
68- f, grad, backend:: AutoForwardEnzyme , x, extras:: EnzymeForwardGradientExtras{C}
84+ f, grad, backend:: AutoEnzyme{<:ForwardMode} , x, extras:: EnzymeForwardGradientExtras{C}
6985) where {C}
7086 grad_tup = gradient (forward_mode (backend), f, x, Val {C} (); shadow= extras. shadow)
7187 return copyto! (grad, grad_tup)
7288end
7389
7490function DI. value_and_gradient! (
75- f, grad, backend:: AutoForwardEnzyme , x, extras:: EnzymeForwardGradientExtras{C}
91+ f, grad, backend:: AutoEnzyme{<:ForwardMode} , x, extras:: EnzymeForwardGradientExtras{C}
7692) where {C}
7793 grad_tup = gradient (forward_mode (backend), f, x, Val {C} (); shadow= extras. shadow)
7894 return f (x), copyto! (grad, grad_tup)
@@ -84,14 +100,17 @@ struct EnzymeForwardOneArgJacobianExtras{C,O} <: JacobianExtras
84100 shadow:: O
85101end
86102
87- function DI. prepare_jacobian (f, :: AutoForwardOrNothingEnzyme , x)
103+ function DI. prepare_jacobian (f, :: AutoEnzyme{<:Union{ForwardMode,Nothing}} , x)
88104 C = pick_chunksize (length (x))
89105 shadow = chunkedonehot (x, Val (C))
90106 return EnzymeForwardOneArgJacobianExtras {C,typeof(shadow)} (shadow)
91107end
92108
93109function DI. jacobian (
94- f, backend:: AutoForwardOrNothingEnzyme , x, extras:: EnzymeForwardOneArgJacobianExtras{C}
110+ f,
111+ backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
112+ x,
113+ extras:: EnzymeForwardOneArgJacobianExtras{C} ,
95114) where {C}
96115 jac_wrongshape = jacobian (forward_mode (backend), f, x, Val {C} (); shadow= extras. shadow)
97116 nx = length (x)
@@ -100,15 +119,18 @@ function DI.jacobian(
100119end
101120
102121function DI. value_and_jacobian (
103- f, backend:: AutoForwardOrNothingEnzyme , x, extras:: EnzymeForwardOneArgJacobianExtras
122+ f,
123+ backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
124+ x,
125+ extras:: EnzymeForwardOneArgJacobianExtras ,
104126)
105127 return f (x), DI. jacobian (f, backend, x, extras)
106128end
107129
108130function DI. jacobian! (
109131 f,
110132 jac,
111- backend:: AutoForwardOrNothingEnzyme ,
133+ backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
112134 x,
113135 extras:: EnzymeForwardOneArgJacobianExtras ,
114136)
118140function DI. value_and_jacobian! (
119141 f,
120142 jac,
121- backend:: AutoForwardOrNothingEnzyme ,
143+ backend:: AnyAutoEnzyme{<:Union{ForwardMode,Nothing}} ,
122144 x,
123145 extras:: EnzymeForwardOneArgJacobianExtras ,
124146)
0 commit comments