22
33function DI. prepare_pushforward (
44 f:: F ,
5- :: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
5+ backend :: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
66 x,
77 tx:: NTuple ,
8- contexts:: Vararg{DI.Context,C} ,
8+ contexts:: Vararg{DI.Context,C} ;
9+ strict:: Bool = false ,
910) where {F,C}
10- return DI. NoPushforwardPrep ()
11+ SIG = DI. signature (f, backend, x, tx, contexts... ; strict)
12+ return DI. NoPushforwardPrep {SIG} ()
1113end
1214
1315function DI. value_and_pushforward (
1416 f:: F ,
15- :: DI.NoPushforwardPrep ,
17+ prep :: DI.NoPushforwardPrep ,
1618 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
1719 x,
1820 tx:: NTuple{1} ,
1921 contexts:: Vararg{DI.Context,C} ,
2022) where {F,C}
23+ DI. check_prep (f, prep, backend, x, tx, contexts... )
2124 mode = forward_withprimal (backend)
2225 f_and_df = get_f_and_df (f, backend, mode)
2326 dx = only (tx)
2932
3033function DI. value_and_pushforward (
3134 f:: F ,
32- :: DI.NoPushforwardPrep ,
35+ prep :: DI.NoPushforwardPrep ,
3336 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
3437 x,
3538 tx:: NTuple{B} ,
3639 contexts:: Vararg{DI.Context,C} ,
3740) where {F,B,C}
41+ DI. check_prep (f, prep, backend, x, tx, contexts... )
3842 mode = forward_withprimal (backend)
3943 f_and_df = get_f_and_df (f, backend, mode, Val (B))
4044 x_and_tx = BatchDuplicated (x, tx)
4549
4650function DI. pushforward (
4751 f:: F ,
48- :: DI.NoPushforwardPrep ,
52+ prep :: DI.NoPushforwardPrep ,
4953 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
5054 x,
5155 tx:: NTuple{1} ,
5256 contexts:: Vararg{DI.Context,C} ,
5357) where {F,C}
58+ DI. check_prep (f, prep, backend, x, tx, contexts... )
5459 mode = forward_noprimal (backend)
5560 f_and_df = get_f_and_df (f, backend, mode)
5661 dx = only (tx)
6267
6368function DI. pushforward (
6469 f:: F ,
65- :: DI.NoPushforwardPrep ,
70+ prep :: DI.NoPushforwardPrep ,
6671 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing}} ,
6772 x,
6873 tx:: NTuple{B} ,
6974 contexts:: Vararg{DI.Context,C} ,
7075) where {F,B,C}
76+ DI. check_prep (f, prep, backend, x, tx, contexts... )
7177 mode = forward_noprimal (backend)
7278 f_and_df = get_f_and_df (f, backend, mode, Val (B))
7379 x_and_tx = BatchDuplicated (x, tx)
@@ -85,6 +91,7 @@ function DI.value_and_pushforward!(
8591 tx:: NTuple ,
8692 contexts:: Vararg{DI.Context,C} ,
8793) where {F,C}
94+ DI. check_prep (f, prep, backend, x, tx, contexts... )
8895 # dy cannot be passed anyway
8996 y, new_ty = DI. value_and_pushforward (f, prep, backend, x, tx, contexts... )
9097 foreach (copyto!, ty, new_ty)
@@ -100,6 +107,7 @@ function DI.pushforward!(
100107 tx:: NTuple ,
101108 contexts:: Vararg{DI.Context,C} ,
102109) where {F,C}
110+ DI. check_prep (f, prep, backend, x, tx, contexts... )
103111 # dy cannot be passed anyway
104112 new_ty = DI. pushforward (f, prep, backend, x, tx, contexts... )
105113 foreach (copyto!, ty, new_ty)
@@ -108,23 +116,25 @@ end
108116
109117# # Gradient
110118
111- struct EnzymeForwardGradientPrep{B,O} <: DI.GradientPrep
119+ struct EnzymeForwardGradientPrep{SIG, B,O} <: DI.GradientPrep{SIG}
112120 shadows:: O
113121end
114122
115- function EnzymeForwardGradientPrep (:: Val{B} , shadows:: O ) where {B,O}
116- return EnzymeForwardGradientPrep {B,O} (shadows)
123+ function EnzymeForwardGradientPrep (:: Type{SIG} , :: Val{B} , shadows:: O ) where {SIG, B,O}
124+ return EnzymeForwardGradientPrep {SIG, B,O} (shadows)
117125end
118126
119127function DI. prepare_gradient (
120128 f:: F ,
121129 backend:: AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}} ,
122130 x,
123- contexts:: Vararg{DI.Constant,C} ,
131+ contexts:: Vararg{DI.Constant,C} ;
132+ strict:: Bool = false ,
124133) where {F,C}
134+ SIG = DI. signature (f, backend, x, contexts... ; strict)
125135 valB = to_val (DI. pick_batchsize (backend, x))
126136 shadows = create_shadows (valB, x)
127- return EnzymeForwardGradientPrep (valB, shadows)
137+ return EnzymeForwardGradientPrep (SIG, valB, shadows)
128138end
129139
130140function DI. gradient (
@@ -134,6 +144,7 @@ function DI.gradient(
134144 x,
135145 contexts:: Vararg{DI.Constant,C} ,
136146) where {F,B,C}
147+ DI. check_prep (f, prep, backend, x, contexts... )
137148 mode = forward_noprimal (backend)
138149 f_and_df = get_f_and_df (f, backend, mode)
139150 annotated_contexts = translate (backend, mode, Val (B), contexts... )
@@ -150,6 +161,7 @@ function DI.value_and_gradient(
150161 x,
151162 contexts:: Vararg{DI.Constant,C} ,
152163) where {F,B,C}
164+ DI. check_prep (f, prep, backend, x, contexts... )
153165 mode = forward_withprimal (backend)
154166 f_and_df = get_f_and_df (f, backend, mode)
155167 annotated_contexts = translate (backend, mode, Val (B), contexts... )
@@ -167,6 +179,7 @@ function DI.gradient!(
167179 x,
168180 contexts:: Vararg{DI.Constant,C} ,
169181) where {F,B,C}
182+ DI. check_prep (f, prep, backend, x, contexts... )
170183 return copyto! (grad, DI. gradient (f, prep, backend, x, contexts... ))
171184end
172185
@@ -178,33 +191,36 @@ function DI.value_and_gradient!(
178191 x,
179192 contexts:: Vararg{DI.Constant,C} ,
180193) where {F,B,C}
194+ DI. check_prep (f, prep, backend, x, contexts... )
181195 y, new_grad = DI. value_and_gradient (f, prep, backend, x, contexts... )
182196 return y, copyto! (grad, new_grad)
183197end
184198
185199# # Jacobian
186200
187- struct EnzymeForwardOneArgJacobianPrep{B,O} <: DI.JacobianPrep
201+ struct EnzymeForwardOneArgJacobianPrep{SIG, B,O} <: DI.JacobianPrep{SIG}
188202 shadows:: O
189203 output_length:: Int
190204end
191205
192206function EnzymeForwardOneArgJacobianPrep (
193- :: Val{B} , shadows:: O , output_length:: Integer
194- ) where {B,O}
195- return EnzymeForwardOneArgJacobianPrep {B,O} (shadows, output_length)
207+ :: Type{SIG} , :: Val{B} , shadows:: O , output_length:: Integer
208+ ) where {SIG, B,O}
209+ return EnzymeForwardOneArgJacobianPrep {SIG, B,O} (shadows, output_length)
196210end
197211
198212function DI. prepare_jacobian (
199213 f:: F ,
200214 backend:: AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}} ,
201215 x,
202- contexts:: Vararg{DI.Constant,C} ,
216+ contexts:: Vararg{DI.Constant,C} ;
217+ strict:: Bool = false ,
203218) where {F,C}
219+ SIG = DI. signature (f, backend, x, contexts... ; strict)
204220 y = f (x, map (DI. unwrap, contexts)... )
205221 valB = to_val (DI. pick_batchsize (backend, x))
206222 shadows = create_shadows (valB, x)
207- return EnzymeForwardOneArgJacobianPrep (valB, shadows, length (y))
223+ return EnzymeForwardOneArgJacobianPrep (SIG, valB, shadows, length (y))
208224end
209225
210226function DI. jacobian (
@@ -214,6 +230,7 @@ function DI.jacobian(
214230 x,
215231 contexts:: Vararg{DI.Constant,C} ,
216232) where {F,B,C}
233+ DI. check_prep (f, prep, backend, contexts... )
217234 mode = forward_noprimal (backend)
218235 f_and_df = get_f_and_df (f, backend, mode)
219236 annotated_contexts = translate (backend, mode, Val (B), contexts... )
@@ -231,6 +248,7 @@ function DI.value_and_jacobian(
231248 x,
232249 contexts:: Vararg{DI.Constant,C} ,
233250) where {F,B,C}
251+ DI. check_prep (f, prep, backend, contexts... )
234252 mode = forward_withprimal (backend)
235253 f_and_df = get_f_and_df (f, backend, mode)
236254 annotated_contexts = translate (backend, mode, Val (B), contexts... )
@@ -249,6 +267,7 @@ function DI.jacobian!(
249267 x,
250268 contexts:: Vararg{DI.Constant,C} ,
251269) where {F,C}
270+ DI. check_prep (f, prep, backend, contexts... )
252271 return copyto! (jac, DI. jacobian (f, prep, backend, x, contexts... ))
253272end
254273
@@ -260,6 +279,7 @@ function DI.value_and_jacobian!(
260279 x,
261280 contexts:: Vararg{DI.Constant,C} ,
262281) where {F,C}
282+ DI. check_prep (f, prep, backend, contexts... )
263283 y, new_jac = DI. value_and_jacobian (f, prep, backend, x, contexts... )
264284 return y, copyto! (jac, new_jac)
265285end
0 commit comments