@@ -21,6 +21,31 @@ function seeded_autodiff_thunk(
2121 end
2222end
2323
24+ function batch_seeded_autodiff_thunk (
25+ rmode:: ReverseModeSplit{ReturnPrimal} ,
26+ dresults:: NTuple ,
27+ f:: FA ,
28+ :: Type{RA} ,
29+ args:: Vararg{Annotation,N} ,
30+ ) where {ReturnPrimal,FA<: Annotation ,RA<: Annotation ,N}
31+ forward, reverse = autodiff_thunk (rmode, FA, RA, typeof .(args)... )
32+ tape, result, shadow_results = forward (f, args... )
33+ if RA <: Active
34+ dresults_righttype = map (Fix1 (convert, typeof (result)), dresults)
35+ dinputs = only (reverse (f, args... , dresults_righttype, tape))
36+ else
37+ foreach (shadow_results, dresults) do d0, d
38+ d0 .+ = d # use recursive_add here?
39+ end
40+ dinputs = only (reverse (f, args... , tape))
41+ end
42+ if ReturnPrimal
43+ return (dinputs, result)
44+ else
45+ return (dinputs,)
46+ end
47+ end
48+
2449# # Pullback
2550
2651function DI. prepare_pullback (
3560
3661# ## Out-of-place
3762
38- function DI. value_and_pullback (
39- f:: F ,
40- prep:: NoPullbackPrep ,
41- backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
42- x,
43- ty:: Tangents ,
44- contexts:: Vararg{Context,C} ,
45- ) where {F,C}
46- ys_and_dxs = map (ty. d) do dy
47- y, tx = DI. value_and_pullback (f, prep, backend, x, Tangents (dy), contexts... )
48- y, only (tx)
49- end
50- y = first (ys_and_dxs[1 ])
51- dxs = last .(ys_and_dxs)
52- tx = Tangents (dxs... )
53- return y, tx
54- end
55-
5663function DI. value_and_pullback (
5764 f:: F ,
5865 :: NoPullbackPrep ,
@@ -70,6 +77,24 @@ function DI.value_and_pullback(
7077 return result, Tangents (first (dinputs))
7178end
7279
80+ function DI. value_and_pullback (
81+ f:: F ,
82+ prep:: NoPullbackPrep ,
83+ backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
84+ x:: Number ,
85+ ty:: Tangents{B} ,
86+ contexts:: Vararg{Context,C} ,
87+ ) where {F,B,C}
88+ # TODO : improve
89+ ys_and_dxs = map (ty. d) do dy
90+ y, tx = DI. value_and_pullback (f, prep, backend, x, Tangents (dy), contexts... )
91+ y, only (tx)
92+ end
93+ y = first (ys_and_dxs[1 ])
94+ dxs = last .(ys_and_dxs)
95+ return y, Tangents (dxs... )
96+ end
97+
7398function DI. value_and_pullback (
7499 f:: F ,
75100 :: NoPullbackPrep ,
@@ -88,53 +113,37 @@ function DI.value_and_pullback(
88113 return result, Tangents (dx)
89114end
90115
91- function DI. pullback (
92- f:: F ,
93- prep:: NoPullbackPrep ,
94- backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
95- x:: Number ,
96- ty:: Tangents{1} ,
97- contexts:: Vararg{Context,C} ,
98- ) where {F,C}
99- return last (DI. value_and_pullback (f, prep, backend, x, ty, contexts... ))
100- end
101-
102- # ## In-place
103-
104- function DI. value_and_pullback! (
116+ function DI. value_and_pullback (
105117 f:: F ,
106- tx:: Tangents ,
107- prep:: NoPullbackPrep ,
118+ :: NoPullbackPrep ,
108119 backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
109120 x,
110- ty:: Tangents ,
121+ ty:: Tangents{B} ,
111122 contexts:: Vararg{Context,C} ,
112- ) where {F,C}
113- ys = map (tx . d, ty . d) do dx, dy
114- y, _ = DI . value_and_pullback! (
115- f, Tangents (dx), prep, backend, x, Tangents (dy), contexts ...
116- )
117- y
118- end
119- y = first (ys )
120- return y, tx
123+ ) where {F,B, C}
124+ f_and_df = force_annotation ( get_f_and_df (f, backend, Val (B)))
125+ mode = reverse_mode_split_withprimal (backend)
126+ RA = eltype (ty) <: Number ? Active : BatchDuplicated
127+ dxs = ntuple (_ -> make_zero (x), Val (B) )
128+ _, result = batch_seeded_autodiff_thunk (
129+ mode, NTuple (ty), f_and_df, RA, BatchDuplicated (x, dxs), map (translate, contexts) ...
130+ )
131+ return result, Tangents (dxs ... )
121132end
122133
123- function DI. pullback! (
134+ function DI. pullback (
124135 f:: F ,
125- tx:: Tangents ,
126136 prep:: NoPullbackPrep ,
127137 backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
128138 x,
129139 ty:: Tangents ,
130140 contexts:: Vararg{Context,C} ,
131141) where {F,C}
132- for b in eachindex (tx. d, ty. d)
133- DI. pullback! (f, Tangents (tx. d[b]), prep, backend, x, Tangents (ty. d[b]), contexts... )
134- end
135- return tx
142+ return last (DI. value_and_pullback (f, prep, backend, x, ty, contexts... ))
136143end
137144
145+ # ## In-place
146+
138147function DI. value_and_pullback! (
139148 f:: F ,
140149 tx:: Tangents{1} ,
@@ -161,13 +170,39 @@ function DI.value_and_pullback!(
161170 return result, tx
162171end
163172
173+ function DI. value_and_pullback! (
174+ f:: F ,
175+ tx:: Tangents{B} ,
176+ :: NoPullbackPrep ,
177+ backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
178+ x,
179+ ty:: Tangents{B} ,
180+ contexts:: Vararg{Context,C} ,
181+ ) where {F,B,C}
182+ f_and_df = force_annotation (get_f_and_df (f, backend, Val (B)))
183+ mode = reverse_mode_split_withprimal (backend)
184+ RA = eltype (ty) <: Number ? Active : BatchDuplicated
185+ dxs_righttype = map (Fix1 (convert, typeof (x)), NTuple (tx))
186+ make_zero! (dxs_righttype)
187+ _, result = batch_seeded_autodiff_thunk (
188+ mode,
189+ NTuple (ty),
190+ f_and_df,
191+ RA,
192+ BatchDuplicated (x, dxs_righttype),
193+ map (translate, contexts)... ,
194+ )
195+ foreach (copyto!, NTuple (tx), dxs_righttype)
196+ return result, tx
197+ end
198+
164199function DI. pullback! (
165200 f:: F ,
166- tx:: Tangents{1} ,
201+ tx:: Tangents ,
167202 prep:: NoPullbackPrep ,
168203 backend:: AutoEnzyme{<:Union{ReverseMode,Nothing}} ,
169204 x,
170- ty:: Tangents{1} ,
205+ ty:: Tangents ,
171206 contexts:: Vararg{Context,C} ,
172207) where {F,C}
173208 return last (DI. value_and_pullback! (f, tx, prep, backend, x, ty, contexts... ))
0 commit comments