11# # Pushforward
22
33function DI. prepare_pushforward (
4- f, :: AnyAutoEnzyme {<:Union{ForwardMode,Nothing}} , x, tx:: Tangents
4+ f, :: AutoEnzyme {<:Union{ForwardMode,Nothing}} , x, tx:: Tangents
55)
66 return NoPushforwardExtras ()
77end
88
99function DI. value_and_pushforward (
1010 f,
11- extras :: NoPushforwardExtras ,
12- backend:: AnyAutoEnzyme {<:Union{ForwardMode,Nothing}} ,
11+ :: NoPushforwardExtras ,
12+ backend:: AutoEnzyme {<:Union{ForwardMode,Nothing}} ,
1313 x,
14- tx:: Tangents ,
14+ tx:: Tangents{1} ,
1515)
16- ty = map (tx) do dx
17- only (DI . pushforward (f, extras, backend, x, Tangents (dx) ))
18- end
19- y = f (x )
20- return y, ty
16+ f_and_df = get_f_and_df (f, backend)
17+ dx_sametype = convert ( typeof (x), only (tx ))
18+ x_and_dx = Duplicated (x, dx_sametype)
19+ dy, y = autodiff ( forward_mode_withprimal (backend), f_and_df, x_and_dx )
20+ return y, Tangents (dy)
2121end
2222
2323function DI. value_and_pushforward (
2424 f,
2525 :: NoPushforwardExtras ,
26- backend:: AnyAutoEnzyme {<:Union{ForwardMode,Nothing}} ,
26+ backend:: AutoEnzyme {<:Union{ForwardMode,Nothing}} ,
2727 x,
28- tx:: Tangents{1} ,
29- )
30- dx = only (tx)
31- f_and_df = get_f_and_df (f, backend)
32- dx_sametype = convert (typeof (x), dx)
33- x_and_dx = Duplicated (x, dx_sametype)
34- y, new_dy = if backend isa AutoDeferredEnzyme
35- autodiff_deferred (forward_mode (backend), f_and_df, Duplicated, x_and_dx)
36- else
37- autodiff (forward_mode (backend), f_and_df, Duplicated, x_and_dx)
38- end
39- return y, Tangents (new_dy)
28+ tx:: Tangents{B} ,
29+ ) where {B}
30+ f_and_df = get_f_and_df (f, backend, Val (B))
31+ dxs_sametype = map (Fix1 (convert, typeof (x)), tx. d)
32+ x_and_dxs = BatchDuplicated (x, dxs_sametype)
33+ dys, y = autodiff (forward_mode_withprimal (backend), f_and_df, x_and_dxs)
34+ return y, Tangents (dys... )
4035end
4136
4237function DI. pushforward (
4338 f,
4439 :: NoPushforwardExtras ,
45- backend:: AnyAutoEnzyme {<:Union{ForwardMode,Nothing}} ,
40+ backend:: AutoEnzyme {<:Union{ForwardMode,Nothing}} ,
4641 x,
4742 tx:: Tangents{1} ,
4843)
49- dx = only (tx)
5044 f_and_df = get_f_and_df (f, backend)
51- dx_sametype = convert (typeof (x), dx )
45+ dx_sametype = convert (typeof (x), only (tx) )
5246 x_and_dx = Duplicated (x, dx_sametype)
53- new_dy = if backend isa AutoDeferredEnzyme
54- only (autodiff_deferred (forward_mode (backend), f_and_df, DuplicatedNoNeed, x_and_dx))
55- else
56- only (autodiff (forward_mode (backend), f_and_df, DuplicatedNoNeed, x_and_dx))
57- end
58- return Tangents (new_dy)
47+ dy = only (autodiff (forward_mode_noprimal (backend), f_and_df, x_and_dx))
48+ return Tangents (dy)
49+ end
50+
51+ function DI. pushforward (
52+ f,
53+ :: NoPushforwardExtras ,
54+ backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
55+ x,
56+ tx:: Tangents{B} ,
57+ ) where {B}
58+ f_and_df = get_f_and_df (f, backend, Val (B))
59+ dxs_sametype = map (Fix1 (convert, typeof (x)), tx. d)
60+ x_and_dxs = BatchDuplicated (x, dxs_sametype)
61+ dys = only (autodiff (forward_mode_noprimal (backend), f_and_df, x_and_dxs))
62+ return Tangents (dys... )
5963end
6064
6165function DI. value_and_pushforward! (
6266 f,
6367 ty:: Tangents ,
6468 extras:: NoPushforwardExtras ,
65- backend:: AnyAutoEnzyme {<:Union{ForwardMode,Nothing}} ,
69+ backend:: AutoEnzyme {<:Union{ForwardMode,Nothing}} ,
6670 x,
6771 tx:: Tangents ,
6872)
@@ -75,7 +79,7 @@ function DI.pushforward!(
7579 f,
7680 ty:: Tangents ,
7781 extras:: NoPushforwardExtras ,
78- backend:: AnyAutoEnzyme {<:Union{ForwardMode,Nothing}} ,
82+ backend:: AutoEnzyme {<:Union{ForwardMode,Nothing}} ,
7983 x,
8084 tx:: Tangents ,
8185)
8690# # Gradient
8791
8892struct EnzymeForwardGradientExtras{B,O} <: GradientExtras
89- shadow :: O
93+ shadows :: O
9094end
9195
9296function DI. prepare_gradient (
9397 f, backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} , x
9498)
9599 B = pick_batchsize (backend, length (x))
96- shadow = chunkedonehot (x, Val (B))
97- return EnzymeForwardGradientExtras {B,typeof(shadow )} (shadow )
100+ shadows = create_shadows ( Val (B), x )
101+ return EnzymeForwardGradientExtras {B,typeof(shadows )} (shadows )
98102end
99103
100104function DI. gradient (
@@ -104,17 +108,23 @@ function DI.gradient(
104108 x,
105109) where {B}
106110 f_and_df = get_f_and_df (f, backend)
107- grad_tup = gradient (forward_mode (backend), f_and_df, x, Val (B); shadow= extras. shadow)
108- return reshape (collect (grad_tup), size (x))
111+ derivs = gradient (
112+ forward_mode_noprimal (backend), f_and_df, x; chunk= Val (B), shadows= extras. shadows
113+ )
114+ return only (derivs)
109115end
110116
111117function DI. value_and_gradient (
112118 f,
113- extras:: EnzymeForwardGradientExtras ,
119+ extras:: EnzymeForwardGradientExtras{B} ,
114120 backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
115121 x,
116- )
117- return f (x), DI. gradient (f, extras, backend, x)
122+ ) where {B}
123+ f_and_df = get_f_and_df (f, backend)
124+ (; derivs, val) = gradient (
125+ forward_mode_withprimal (backend), f_and_df, x; chunk= Val (B), shadows= extras. shadows
126+ )
127+ return val, only (derivs)
118128end
119129
120130function DI. gradient! (
@@ -124,9 +134,7 @@ function DI.gradient!(
124134 backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
125135 x,
126136) where {B}
127- f_and_df = get_f_and_df (f, backend)
128- grad_tup = gradient (forward_mode (backend), f_and_df, x, Val (B); shadow= extras. shadow)
129- return copyto! (grad, grad_tup)
137+ return copyto! (grad, DI. gradient (f, extras, backend, x))
130138end
131139
132140function DI. value_and_gradient! (
@@ -136,27 +144,24 @@ function DI.value_and_gradient!(
136144 backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
137145 x,
138146) where {B}
139- f_and_df = get_f_and_df (f, backend)
140- grad_tup = gradient (forward_mode (backend), f_and_df, x, Val (B); shadow= extras. shadow)
141- return f (x), copyto! (grad, grad_tup)
147+ y, new_grad = DI. value_and_gradient (f, extras, backend, x)
148+ return y, copyto! (grad, new_grad)
142149end
143150
144151# # Jacobian
145152
146153struct EnzymeForwardOneArgJacobianExtras{B,O} <: JacobianExtras
147- shadow:: O
154+ shadows:: O
155+ output_length:: Int
148156end
149157
150158function DI. prepare_jacobian (
151159 f, backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} , x
152160)
161+ y = f (x)
153162 B = pick_batchsize (backend, length (x))
154- if B == 1
155- shadow = onehot (x)
156- else
157- shadow = chunkedonehot (x, Val (B))
158- end
159- return EnzymeForwardOneArgJacobianExtras {B,typeof(shadow)} (shadow)
163+ shadows = create_shadows (Val (B), x)
164+ return EnzymeForwardOneArgJacobianExtras {B,typeof(shadows)} (shadows, length (y))
160165end
161166
162167function DI. jacobian (
@@ -166,21 +171,25 @@ function DI.jacobian(
166171 x,
167172) where {B}
168173 f_and_df = get_f_and_df (f, backend)
169- jac_wrongshape = jacobian (
170- forward_mode (backend), f_and_df, x, Val (B); shadow = extras. shadow
174+ derivs = jacobian (
175+ forward_mode_noprimal (backend), f_and_df, x; chunk = Val (B), shadows = extras. shadows
171176 )
172- nx = length (x)
173- ny = length (jac_wrongshape) ÷ length (x)
174- return reshape (jac_wrongshape, ny, nx)
177+ jac_tensor = only (derivs)
178+ return maybe_reshape (jac_tensor, extras. output_length, length (x))
175179end
176180
177181function DI. value_and_jacobian (
178182 f,
179- extras:: EnzymeForwardOneArgJacobianExtras ,
183+ extras:: EnzymeForwardOneArgJacobianExtras{B} ,
180184 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
181185 x,
182- )
183- return f (x), DI. jacobian (f, extras, backend, x)
186+ ) where {B}
187+ f_and_df = get_f_and_df (f, backend)
188+ (; derivs, val) = jacobian (
189+ forward_mode_withprimal (backend), f_and_df, x; chunk= Val (B), shadows= extras. shadows
190+ )
191+ jac_tensor = only (derivs)
192+ return val, maybe_reshape (jac_tensor, extras. output_length, length (x))
184193end
185194
186195function DI. jacobian! (
0 commit comments