@@ -26,7 +26,7 @@ function DI.value_and_pullback(
2626 f,
2727 backend:: AnyAutoEnzyme{<:Union{ReverseMode,Nothing}} ,
2828 x:: Number ,
29- dy:: AbstractArray ,
29+ dy,
3030 :: NoPullbackExtras ,
3131)
3232 tf, tx = typeof (f), typeof (x)
4040function DI. value_and_pullback (
4141 f,
4242 backend:: AnyAutoEnzyme{<:Union{ReverseMode,Nothing}} ,
43- x:: AbstractArray ,
44- dy,
45- extras :: NoPullbackExtras ,
43+ x,
44+ dy:: Number ,
45+ :: NoPullbackExtras ,
4646)
47- dx = similar (x)
47+ dx_sametype = make_zero (x)
48+ x_and_dx = Duplicated (x, dx_sametype)
49+ _, y = if backend isa AutoDeferredEnzyme
50+ autodiff_deferred (ReverseWithPrimal, Const (f), Active, x_and_dx)
51+ else
52+ autodiff (ReverseWithPrimal, Const (f), Active, x_and_dx)
53+ end
54+ if ! isone (dy)
55+ # TODO : generalize beyond Arrays?
56+ dx_sametype .*= dy
57+ end
58+ return y, dx_sametype
59+ end
60+
61+ function DI. value_and_pullback (
62+ f, backend:: AnyAutoEnzyme{<:Union{ReverseMode,Nothing}} , x, dy, extras:: NoPullbackExtras
63+ )
64+ dx = make_zero (x)
4865 return DI. value_and_pullback! (f, dx, backend, x, dy, extras)
4966end
5067
@@ -60,36 +77,34 @@ function DI.value_and_pullback!(
6077 f,
6178 dx,
6279 backend:: AnyAutoEnzyme{<:Union{ReverseMode,Nothing}} ,
63- x:: AbstractArray ,
80+ x,
6481 dy:: Number ,
6582 :: NoPullbackExtras ,
6683)
6784 dx_sametype = convert (typeof (x), dx)
68- dx_sametype . = zero ( eltype (x) )
85+ make_zero! (dx_sametype )
6986 x_and_dx = Duplicated (x, dx_sametype)
7087 _, y = if backend isa AutoDeferredEnzyme
7188 autodiff_deferred (ReverseWithPrimal, Const (f), Active, x_and_dx)
7289 else
7390 autodiff (ReverseWithPrimal, Const (f), Active, x_and_dx)
7491 end
75- dx_sametype .*= dy
92+ if ! isone (dy)
93+ # TODO : generalize beyond Arrays?
94+ dx_sametype .*= dy
95+ end
7696 return y, copyto! (dx, dx_sametype)
7797end
7898
7999function DI. value_and_pullback! (
80- f,
81- dx,
82- backend:: AnyAutoEnzyme{<:Union{ReverseMode,Nothing}} ,
83- x:: AbstractArray ,
84- dy:: AbstractArray ,
85- :: NoPullbackExtras ,
100+ f, dx, backend:: AnyAutoEnzyme{<:Union{ReverseMode,Nothing}} , x, dy, :: NoPullbackExtras
86101)
87102 tf, tx = typeof (f), typeof (x)
88103 forw, rev = autodiff_thunk (
89104 ReverseSplitWithPrimal, Const{tf}, Duplicated, Duplicated{tx}
90105 )
91106 dx_sametype = convert (typeof (x), dx)
92- dx_sametype . = zero ( eltype (x) )
107+ make_zero! (dx_sametype )
93108 tape, y, new_dy = forw (Const (f), Duplicated (x, dx_sametype))
94109 copyto! (new_dy, dy)
95110 rev (Const (f), Duplicated (x, dx_sametype), tape)
@@ -133,7 +148,7 @@ function DI.gradient!(
133148 extras:: NoGradientExtras ,
134149)
135150 grad_sametype = convert (typeof (x), grad)
136- grad_sametype . = zero ( eltype (x) )
151+ make_zero! (grad_sametype )
137152 if backend isa AutoDeferredEnzyme
138153 autodiff_deferred (reverse_mode (backend), f, Active, Duplicated (x, grad_sametype))
139154 else
@@ -145,13 +160,13 @@ end
145160function DI. value_and_gradient (
146161 f, backend:: AnyAutoEnzyme{<:Union{ReverseMode,Nothing}} , x, :: NoGradientExtras
147162)
148- return DI. value_and_pullback (f, backend, x, one ( eltype (x)) , NoPullbackExtras ())
163+ return DI. value_and_pullback (f, backend, x, true , NoPullbackExtras ())
149164end
150165
151166function DI. value_and_gradient! (
152167 f, grad, backend:: AnyAutoEnzyme{<:Union{ReverseMode,Nothing}} , x, :: NoGradientExtras
153168)
154- return DI. value_and_pullback! (f, grad, backend, x, one ( eltype (x)) , NoPullbackExtras ())
169+ return DI. value_and_pullback! (f, grad, backend, x, true , NoPullbackExtras ())
155170end
156171
157172# # Jacobian
0 commit comments