11# # Pushforward
22
3+ struct EnzymeOneArgPushforwardPrep{SIG,DF,DC} <: DI.PushforwardPrep{SIG}
4+ _sig:: Val{SIG}
5+ df:: DF
6+ context_shadows:: DC
7+ end
8+
39function DI. prepare_pushforward_nokwarg (
410 strict:: Val ,
511 f:: F ,
612 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
713 x,
8- tx:: NTuple ,
14+ tx:: NTuple{B} ,
915 contexts:: Vararg{DI.Context,C} ;
10- ) where {F,C}
16+ ) where {F,C,B }
1117 _sig = DI. signature (f, backend, x, tx, contexts... ; strict)
12- return DI. NoPushforwardPrep (_sig)
18+ df = function_shadow (f, backend, Val (B))
19+ mode = forward_withprimal (backend)
20+ context_shadows = make_context_shadows (backend, mode, Val (B), contexts... )
21+ return EnzymeOneArgPushforwardPrep (_sig, df, context_shadows)
1322end
1423
1524function DI. value_and_pushforward (
1625 f:: F ,
17- prep:: DI.NoPushforwardPrep ,
26+ prep:: EnzymeOneArgPushforwardPrep ,
1827 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
1928 x,
2029 tx:: NTuple{1} ,
2130 contexts:: Vararg{DI.Context,C} ,
2231) where {F,C}
2332 DI. check_prep (f, prep, backend, x, tx, contexts... )
33+ (; df, context_shadows) = prep
2434 mode = forward_withprimal (backend)
25- f_and_df = get_f_and_df ( f, backend, mode )
35+ f_and_df = get_f_and_df_prepared! (df, f, backend, Val ( 1 ) )
2636 dx = only (tx)
2737 x_and_dx = Duplicated (x, dx)
28- annotated_contexts = translate (backend, mode , Val (1 ), contexts ... )
38+ annotated_contexts = translate_prepared! (context_shadows, contexts , Val (1 ))
2939 dy, y = autodiff (mode, f_and_df, x_and_dx, annotated_contexts... )
3040 return y, (dy,)
3141end
3242
3343function DI. value_and_pushforward (
3444 f:: F ,
35- prep:: DI.NoPushforwardPrep ,
45+ prep:: EnzymeOneArgPushforwardPrep ,
3646 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
3747 x,
3848 tx:: NTuple{B} ,
3949 contexts:: Vararg{DI.Context,C} ,
4050) where {F,B,C}
4151 DI. check_prep (f, prep, backend, x, tx, contexts... )
52+ (; df, context_shadows) = prep
4253 mode = forward_withprimal (backend)
43- f_and_df = get_f_and_df (f, backend, mode , Val (B))
54+ f_and_df = get_f_and_df_prepared! (df, f, backend , Val (B))
4455 x_and_tx = BatchDuplicated (x, tx)
45- annotated_contexts = translate (backend, mode , Val (B), contexts ... )
56+ annotated_contexts = translate_prepared! (context_shadows, contexts , Val (B))
4657 ty, y = autodiff (mode, f_and_df, x_and_tx, annotated_contexts... )
4758 return y, values (ty)
4859end
4960
5061function DI. pushforward (
5162 f:: F ,
52- prep:: DI.NoPushforwardPrep ,
63+ prep:: EnzymeOneArgPushforwardPrep ,
5364 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
5465 x,
5566 tx:: NTuple{1} ,
5667 contexts:: Vararg{DI.Context,C} ,
5768) where {F,C}
5869 DI. check_prep (f, prep, backend, x, tx, contexts... )
70+ (; df, context_shadows) = prep
5971 mode = forward_noprimal (backend)
60- f_and_df = get_f_and_df ( f, backend, mode )
72+ f_and_df = get_f_and_df_prepared! (df, f, backend, Val ( 1 ) )
6173 dx = only (tx)
6274 x_and_dx = Duplicated (x, dx)
63- annotated_contexts = translate (backend, mode , Val (1 ), contexts ... )
75+ annotated_contexts = translate_prepared! (context_shadows, contexts , Val (1 ))
6476 dy = only (autodiff (mode, f_and_df, x_and_dx, annotated_contexts... ))
6577 return (dy,)
6678end
6779
6880function DI. pushforward (
6981 f:: F ,
70- prep:: DI.NoPushforwardPrep ,
82+ prep:: EnzymeOneArgPushforwardPrep ,
7183 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
7284 x,
7385 tx:: NTuple{B} ,
7486 contexts:: Vararg{DI.Context,C} ,
7587) where {F,B,C}
7688 DI. check_prep (f, prep, backend, x, tx, contexts... )
89+ (; df, context_shadows) = prep
7790 mode = forward_noprimal (backend)
78- f_and_df = get_f_and_df (f, backend, mode , Val (B))
91+ f_and_df = get_f_and_df_prepared! (df, f, backend , Val (B))
7992 x_and_tx = BatchDuplicated (x, tx)
80- annotated_contexts = translate (backend, mode , Val (B), contexts ... )
93+ annotated_contexts = translate_prepared! (context_shadows, contexts , Val (B))
8194 ty = only (autodiff (mode, f_and_df, x_and_tx, annotated_contexts... ))
8295 return values (ty)
8396end
8497
8598function DI. value_and_pushforward! (
8699 f:: F ,
87100 ty:: NTuple ,
88- prep:: DI.NoPushforwardPrep ,
101+ prep:: EnzymeOneArgPushforwardPrep ,
89102 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
90103 x,
91104 tx:: NTuple ,
101114function DI. pushforward! (
102115 f:: F ,
103116 ty:: NTuple ,
104- prep:: DI.NoPushforwardPrep ,
117+ prep:: EnzymeOneArgPushforwardPrep ,
105118 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
106119 x,
107120 tx:: NTuple ,
@@ -116,10 +129,12 @@ end
116129
117130# # Gradient
118131
119- struct EnzymeForwardGradientPrep{SIG,B,O} <: DI.GradientPrep{SIG}
132+ struct EnzymeForwardGradientPrep{SIG,B,DF,DC, O} <: DI.GradientPrep{SIG}
120133 _sig:: Val{SIG}
121134 _valB:: Val{B}
122- shadows:: O
135+ df:: DF
136+ context_shadows:: DC
137+ basis_shadows:: O
123138end
124139
125140function DI. prepare_gradient_nokwarg (
@@ -131,8 +146,11 @@ function DI.prepare_gradient_nokwarg(
131146) where {F,C}
132147 _sig = DI. signature (f, backend, x, contexts... ; strict)
133148 valB = to_val (DI. pick_batchsize (backend, x))
134- shadows = create_shadows (valB, x)
135- return EnzymeForwardGradientPrep (_sig, valB, shadows)
149+ df = function_shadow (f, backend, valB)
150+ mode = forward_withprimal (backend)
151+ context_shadows = make_context_shadows (backend, mode, valB, contexts... )
152+ basis_shadows = create_shadows (valB, x)
153+ return EnzymeForwardGradientPrep (_sig, valB, df, context_shadows, basis_shadows)
136154end
137155
138156function DI. gradient (
@@ -143,11 +161,12 @@ function DI.gradient(
143161 contexts:: Vararg{DI.Constant,C} ,
144162) where {F,SIG,B,C}
145163 DI. check_prep (f, prep, backend, x, contexts... )
164+ (; df, context_shadows, basis_shadows) = prep
146165 mode = forward_noprimal (backend)
147- f_and_df = get_f_and_df ( f, backend, mode )
148- annotated_contexts = translate (backend, mode , Val (B), contexts ... )
166+ f_and_df = get_f_and_df_prepared! (df, f, backend, Val (B) )
167+ annotated_contexts = translate_prepared! (context_shadows, contexts , Val (B))
149168 derivs = gradient (
150- mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= prep . shadows
169+ mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= basis_shadows
151170 )
152171 return first (derivs)
153172end
@@ -160,11 +179,12 @@ function DI.value_and_gradient(
160179 contexts:: Vararg{DI.Constant,C} ,
161180) where {F,SIG,B,C}
162181 DI. check_prep (f, prep, backend, x, contexts... )
182+ (; df, context_shadows, basis_shadows) = prep
163183 mode = forward_withprimal (backend)
164- f_and_df = get_f_and_df ( f, backend, mode )
165- annotated_contexts = translate (backend, mode , Val (B), contexts ... )
184+ f_and_df = get_f_and_df_prepared! (df, f, backend, Val (B) )
185+ annotated_contexts = translate_prepared! (context_shadows, contexts , Val (B))
166186 (; derivs, val) = gradient (
167- mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= prep . shadows
187+ mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= basis_shadows
168188 )
169189 return val, first (derivs)
170190end
@@ -196,10 +216,12 @@ end
196216
197217# # Jacobian
198218
199- struct EnzymeForwardOneArgJacobianPrep{SIG,B,O} <: DI.JacobianPrep{SIG}
219+ struct EnzymeForwardOneArgJacobianPrep{SIG,B,DF,DC, O} <: DI.JacobianPrep{SIG}
200220 _sig:: Val{SIG}
201221 _valB:: Val{B}
202- shadows:: O
222+ df:: DF
223+ context_shadows:: DC
224+ basis_shadows:: O
203225 output_length:: Int
204226end
205227
@@ -213,8 +235,13 @@ function DI.prepare_jacobian_nokwarg(
213235 _sig = DI. signature (f, backend, x, contexts... ; strict)
214236 y = f (x, map (DI. unwrap, contexts)... )
215237 valB = to_val (DI. pick_batchsize (backend, x))
216- shadows = create_shadows (valB, x)
217- return EnzymeForwardOneArgJacobianPrep (_sig, valB, shadows, length (y))
238+ mode = forward_withprimal (backend)
239+ df = function_shadow (f, backend, valB)
240+ context_shadows = make_context_shadows (backend, mode, valB, contexts... )
241+ basis_shadows = create_shadows (valB, x)
242+ return EnzymeForwardOneArgJacobianPrep (
243+ _sig, valB, df, context_shadows, basis_shadows, length (y)
244+ )
218245end
219246
220247function DI. jacobian (
@@ -225,14 +252,15 @@ function DI.jacobian(
225252 contexts:: Vararg{DI.Constant,C} ,
226253) where {F,SIG,B,C}
227254 DI. check_prep (f, prep, backend, x, contexts... )
255+ (; df, context_shadows, basis_shadows, output_length) = prep
228256 mode = forward_noprimal (backend)
229- f_and_df = get_f_and_df ( f, backend, mode )
230- annotated_contexts = translate (backend, mode , Val (B), contexts ... )
257+ f_and_df = get_f_and_df_prepared! (df, f, backend, Val (B) )
258+ annotated_contexts = translate_prepared! (context_shadows, contexts , Val (B))
231259 derivs = jacobian (
232- mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= prep . shadows
260+ mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= basis_shadows
233261 )
234262 jac_tensor = first (derivs)
235- return maybe_reshape (jac_tensor, prep . output_length, length (x))
263+ return maybe_reshape (jac_tensor, output_length, length (x))
236264end
237265
238266function DI. value_and_jacobian (
@@ -243,14 +271,15 @@ function DI.value_and_jacobian(
243271 contexts:: Vararg{DI.Constant,C} ,
244272) where {F,SIG,B,C}
245273 DI. check_prep (f, prep, backend, x, contexts... )
274+ (; df, context_shadows, basis_shadows, output_length) = prep
246275 mode = forward_withprimal (backend)
247- f_and_df = get_f_and_df ( f, backend, mode )
248- annotated_contexts = translate (backend, mode , Val (B), contexts ... )
276+ f_and_df = get_f_and_df_prepared! (df, f, backend, Val (B) )
277+ annotated_contexts = translate_prepared! (context_shadows, contexts , Val (B))
249278 (; derivs, val) = jacobian (
250- mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= prep . shadows
279+ mode, f_and_df, x, annotated_contexts... ; chunk= Val (B), shadows= basis_shadows
251280 )
252281 jac_tensor = first (derivs)
253- return val, maybe_reshape (jac_tensor, prep . output_length, length (x))
282+ return val, maybe_reshape (jac_tensor, output_length, length (x))
254283end
255284
256285function DI. jacobian! (
0 commit comments