4949
5050# # Pullback
5151
52+ struct EnzymeReverseOneArgPullbackPrep{Y} <: DI.PullbackPrep
53+ y_example:: Y # useful to create return activity
54+ end
55+
5256function DI. prepare_pullback (
5357 f:: F ,
5458 :: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
5559 x,
5660 ty:: NTuple ,
5761 contexts:: Vararg{DI.Context,C} ,
5862) where {F,C}
59- return DI. NoPullbackPrep ()
63+ y = f (x, map (DI. unwrap, contexts)... )
64+ return EnzymeReverseOneArgPullbackPrep (y)
6065end
6166
6267# ## Out-of-place
6368
6469function DI. value_and_pullback (
6570 f:: F ,
66- :: DI.NoPullbackPrep ,
71+ prep :: EnzymeReverseOneArgPullbackPrep ,
6772 backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
6873 x,
6974 ty:: NTuple{1} ,
@@ -72,7 +77,7 @@ function DI.value_and_pullback(
7277 mode = reverse_split_withprimal (backend)
7378 f_and_df = force_annotation (get_f_and_df (f, backend, mode))
7479 IA = guess_activity (typeof (x), mode)
75- RA = guess_activity (eltype (ty ), mode)
80+ RA = guess_activity (typeof (prep . y_example ), mode)
7681 dx = make_zero (x)
7782 annotated_contexts = translate (backend, mode, Val (1 ), contexts... )
7883 dinputs, result = seeded_autodiff_thunk (
8893
8994function DI. value_and_pullback (
9095 f:: F ,
91- :: DI.NoPullbackPrep ,
96+ prep :: EnzymeReverseOneArgPullbackPrep ,
9297 backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
9398 x,
9499 ty:: NTuple{B} ,
@@ -97,7 +102,7 @@ function DI.value_and_pullback(
97102 mode = reverse_split_withprimal (backend)
98103 f_and_df = force_annotation (get_f_and_df (f, backend, mode, Val (B)))
99104 IA = batchify_activity (guess_activity (typeof (x), mode), Val (B))
100- RA = batchify_activity (guess_activity (eltype (ty ), mode), Val (B))
105+ RA = batchify_activity (guess_activity (typeof (prep . y_example ), mode), Val (B))
101106 tx = ntuple (_ -> make_zero (x), Val (B))
102107 annotated_contexts = translate (backend, mode, Val (B), contexts... )
103108 dinputs, result = batch_seeded_autodiff_thunk (
113118
114119function DI. pullback (
115120 f:: F ,
116- prep:: DI.NoPullbackPrep ,
121+ prep:: EnzymeReverseOneArgPullbackPrep ,
117122 backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
118123 x,
119124 ty:: NTuple ,
@@ -127,51 +132,51 @@ end
127132function DI. value_and_pullback! (
128133 f:: F ,
129134 tx:: NTuple{1} ,
130- :: DI.NoPullbackPrep ,
135+ prep :: EnzymeReverseOneArgPullbackPrep ,
131136 backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
132137 x,
133138 ty:: NTuple{1} ,
134139 contexts:: Vararg{DI.Context,C} ,
135140) where {F,C}
136141 mode = reverse_split_withprimal (backend)
137142 f_and_df = force_annotation (get_f_and_df (f, backend, mode))
138- RA = guess_activity (eltype (ty ), mode)
143+ RA = guess_activity (typeof (prep . y_example ), mode)
139144 dx_righttype = convert (typeof (x), only (tx))
140145 make_zero! (dx_righttype)
141146 annotated_contexts = translate (backend, mode, Val (1 ), contexts... )
142147 _, result = seeded_autodiff_thunk (
143148 mode, only (ty), f_and_df, RA, Duplicated (x, dx_righttype), annotated_contexts...
144149 )
145- only (tx) === dx_righttype || copyto ! (only (tx), dx_righttype)
150+ copyto_if_different_addresses ! (only (tx), dx_righttype)
146151 return result, tx
147152end
148153
149154function DI. value_and_pullback! (
150155 f:: F ,
151156 tx:: NTuple{B} ,
152- :: DI.NoPullbackPrep ,
157+ prep :: EnzymeReverseOneArgPullbackPrep ,
153158 backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
154159 x,
155160 ty:: NTuple{B} ,
156161 contexts:: Vararg{DI.Context,C} ,
157162) where {F,B,C}
158163 mode = reverse_split_withprimal (backend)
159164 f_and_df = force_annotation (get_f_and_df (f, backend, mode, Val (B)))
160- RA = batchify_activity (guess_activity (eltype (ty ), mode), Val (B))
165+ RA = batchify_activity (guess_activity (typeof (prep . y_example ), mode), Val (B))
161166 tx_righttype = map (Fix1 (convert, typeof (x)), tx)
162167 make_zero! (tx_righttype)
163168 annotated_contexts = translate (backend, mode, Val (B), contexts... )
164169 _, result = batch_seeded_autodiff_thunk (
165170 mode, ty, f_and_df, RA, BatchDuplicated (x, tx_righttype), annotated_contexts...
166171 )
167- foreach (copyto !, tx, tx_righttype)
172+ foreach (copyto_if_different_addresses !, tx, tx_righttype)
168173 return result, tx
169174end
170175
171176function DI. pullback! (
172177 f:: F ,
173178 tx:: NTuple ,
174- prep:: DI.NoPullbackPrep ,
179+ prep:: EnzymeReverseOneArgPullbackPrep ,
175180 backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
176181 x,
177182 ty:: NTuple ,
@@ -265,7 +270,7 @@ function DI.gradient!(
265270 make_zero! (grad_righttype)
266271 annotated_contexts = translate (backend, mode, Val (1 ), contexts... )
267272 autodiff (mode, f_and_df, Active, Duplicated (x, grad_righttype), annotated_contexts... )
268- grad === grad_righttype || copyto ! (grad, grad_righttype)
273+ copyto_if_different_addresses ! (grad, grad_righttype)
269274 return grad
270275end
271276
@@ -295,70 +300,6 @@ function DI.value_and_gradient!(
295300 _, y = autodiff (
296301 mode, f_and_df, Active, Duplicated (x, grad_righttype), annotated_contexts...
297302 )
298- grad === grad_righttype || copyto ! (grad, grad_righttype)
303+ copyto_if_different_addresses ! (grad, grad_righttype)
299304 return y, grad
300305end
301-
302- # # Jacobian
303-
304- # TODO : does not support static arrays
305-
306- #=
307- struct EnzymeReverseOneArgJacobianPrep{Sy,B} <:DI.JacobianPrep end
308-
309- function EnzymeReverseOneArgJacobianPrep(::Val{Sy}, ::Val{B}) where {Sy,B}
310- return EnzymeReverseOneArgJacobianPrep{Sy,B}()
311- end
312-
313- function DI.prepare_jacobian(f::F, backend::AutoEnzyme{<:ReverseMode,Nothing}, x) where {F}
314- y = f(x)
315- Sy = size(y)
316- valB = to_val(DI.pick_batchsize(backend, y))
317- return EnzymeReverseOneArgJacobianPrep(Val(Sy), valB)
318- end
319-
320- function DI.jacobian(
321- f::F,
322- ::EnzymeReverseOneArgJacobianPrep{Sy,B},
323- backend::AutoEnzyme{<:ReverseMode,Nothing},
324- x,
325- ) where {F,Sy,B}
326- derivs = jacobian(reverse_noprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B))
327- jac_tensor = only(derivs)
328- return maybe_reshape(jac_tensor, prod(Sy), length(x))
329- end
330-
331- function DI.value_and_jacobian(
332- f::F,
333- ::EnzymeReverseOneArgJacobianPrep{Sy,B},
334- backend::AutoEnzyme{<:ReverseMode,Nothing},
335- x,
336- ) where {F,Sy,B}
337- (; derivs, val) = jacobian(
338- reverse_withprimal(backend), f, x; n_outs=Val(Sy), chunk=Val(B)
339- )
340- jac_tensor = only(derivs)
341- return val, maybe_reshape(jac_tensor, prod(Sy), length(x))
342- end
343-
344- function DI.jacobian!(
345- f::F,
346- jac,
347- prep::EnzymeReverseOneArgJacobianPrep,
348- backend::AutoEnzyme{<:ReverseMode,Nothing},
349- x,
350- ) where {F}
351- return copyto!(jac, DI.jacobian(f, prep, backend, x))
352- end
353-
354- function DI.value_and_jacobian!(
355- f::F,
356- jac,
357- prep::EnzymeReverseOneArgJacobianPrep,
358- backend::AutoEnzyme{<:ReverseMode,Nothing},
359- x,
360- ) where {F}
361- y, new_jac = DI.value_and_jacobian(f, prep, backend, x)
362- return y, copyto!(jac, new_jac)
363- end
364- =#
0 commit comments