Skip to content

Commit b8f82b0

Browse files
authored
Batched HVP (#330)
* Batched hvp * Coverage
1 parent 86f1e02 commit b8f82b0

3 files changed

Lines changed: 160 additions & 68 deletions

File tree

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -140,29 +140,25 @@ end
140140

141141
### Batched
142142

143-
function prepare_pullback_batched(
144-
f::F, backend::AbstractADType, x, dy::Batch{B}
145-
) where {F,B}
143+
function prepare_pullback_batched(f::F, backend::AbstractADType, x, dy::Batch) where {F}
146144
return prepare_pullback(f, backend, x, first(dy.elements))
147145
end
148146

149-
function prepare_pullback_batched(
150-
f!::F, y, backend::AbstractADType, x, dy::Batch{B}
151-
) where {F,B}
147+
function prepare_pullback_batched(f!::F, y, backend::AbstractADType, x, dy::Batch) where {F}
152148
return prepare_pullback(f!, y, backend, x, first(dy.elements))
153149
end
154150

155151
### Batched, same point
156152

157153
function prepare_pullback_batched_same_point(
158-
f::F, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras
159-
) where {F,B}
154+
f::F, backend::AbstractADType, x, dy::Batch, extras::PullbackExtras
155+
) where {F}
160156
return prepare_pullback_same_point(f, backend, x, first(dy.elements), extras)
161157
end
162158

163159
function prepare_pullback_batched_same_point(
164-
f!::F, y, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras
165-
) where {F,B}
160+
f!::F, y, backend::AbstractADType, x, dy::Batch, extras::PullbackExtras
161+
) where {F}
166162
return prepare_pullback_same_point(f!, y, backend, x, first(dy.elements), extras)
167163
end
168164

@@ -229,17 +225,17 @@ end
229225
function pullback_batched(
230226
f::F, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras
231227
) where {F,B}
232-
dx_elements = ntuple(Val(B)) do l
233-
pullback(f, backend, x, dy.elements[l], extras)
228+
dx_elements = ntuple(Val(B)) do b
229+
pullback(f, backend, x, dy.elements[b], extras)
234230
end
235231
return Batch(dx_elements)
236232
end
237233

238234
function pullback_batched!(
239-
f::F, dx::Batch{B}, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras
240-
) where {F,B}
241-
for l in 1:B
242-
pullback!(f, dx.elements[l], backend, x, dy.elements[l], extras)
235+
f::F, dx::Batch, backend::AbstractADType, x, dy::Batch, extras::PullbackExtras
236+
) where {F}
237+
for b in eachindex(dx.elements, dy.elements)
238+
pullback!(f, dx.elements[b], backend, x, dy.elements[b], extras)
243239
end
244240
return dx
245241
end
@@ -307,17 +303,17 @@ end
307303
function pullback_batched(
308304
f!::F, y, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras
309305
) where {F,B}
310-
dx_elements = ntuple(Val(B)) do l
311-
pullback(f!, y, backend, x, dy.elements[l], extras)
306+
dx_elements = ntuple(Val(B)) do b
307+
pullback(f!, y, backend, x, dy.elements[b], extras)
312308
end
313309
return Batch(dx_elements)
314310
end
315311

316312
function pullback_batched!(
317-
f!::F, y, dx::Batch{B}, backend::AbstractADType, x, dy::Batch{B}, extras::PullbackExtras
318-
) where {F,B}
319-
for l in 1:B
320-
pullback!(f!, y, dx.elements[l], backend, x, dy.elements[l], extras)
313+
f!::F, y, dx::Batch, backend::AbstractADType, x, dy::Batch, extras::PullbackExtras
314+
) where {F}
315+
for b in eachindex(dx.elements, dy.elements)
316+
pullback!(f!, y, dx.elements[b], backend, x, dy.elements[b], extras)
321317
end
322318
return dx
323319
end

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -141,29 +141,27 @@ end
141141

142142
### Batched
143143

144-
function prepare_pushforward_batched(
145-
f::F, backend::AbstractADType, x, dx::Batch{B}
146-
) where {F,B}
144+
function prepare_pushforward_batched(f::F, backend::AbstractADType, x, dx::Batch) where {F}
147145
return prepare_pushforward(f, backend, x, first(dx.elements))
148146
end
149147

150148
function prepare_pushforward_batched(
151-
f!::F, y, backend::AbstractADType, x, dx::Batch{B}
152-
) where {F,B}
149+
f!::F, y, backend::AbstractADType, x, dx::Batch
150+
) where {F}
153151
return prepare_pushforward(f!, y, backend, x, first(dx.elements))
154152
end
155153

156154
### Batched, same point
157155

158156
function prepare_pushforward_batched_same_point(
159-
f::F, backend::AbstractADType, x, dx::Batch{B}, extras::PushforwardExtras
160-
) where {F,B}
157+
f::F, backend::AbstractADType, x, dx::Batch, extras::PushforwardExtras
158+
) where {F}
161159
return prepare_pushforward_same_point(f, backend, x, first(dx.elements), extras)
162160
end
163161

164162
function prepare_pushforward_batched_same_point(
165-
f!::F, y, backend::AbstractADType, x, dx::Batch{B}, extras::PushforwardExtras
166-
) where {F,B}
163+
f!::F, y, backend::AbstractADType, x, dx::Batch, extras::PushforwardExtras
164+
) where {F}
167165
return prepare_pushforward_same_point(f!, y, backend, x, first(dx.elements), extras)
168166
end
169167

@@ -234,17 +232,17 @@ end
234232
function pushforward_batched(
235233
f::F, backend::AbstractADType, x, dx::Batch{B}, extras::PushforwardExtras
236234
) where {F,B}
237-
dy_elements = ntuple(Val(B)) do l
238-
pushforward(f, backend, x, dx.elements[l], extras)
235+
dy_elements = ntuple(Val(B)) do b
236+
pushforward(f, backend, x, dx.elements[b], extras)
239237
end
240238
return Batch(dy_elements)
241239
end
242240

243241
function pushforward_batched!(
244-
f::F, dy::Batch{B}, backend::AbstractADType, x, dx::Batch{B}, extras::PushforwardExtras
245-
) where {F,B}
246-
for l in 1:B
247-
pushforward!(f, dy.elements[l], backend, x, dx.elements[l], extras)
242+
f::F, dy::Batch, backend::AbstractADType, x, dx::Batch, extras::PushforwardExtras
243+
) where {F}
244+
for b in eachindex(dy.elements, dx.elements)
245+
pushforward!(f, dy.elements[b], backend, x, dx.elements[b], extras)
248246
end
249247
return dy
250248
end
@@ -316,23 +314,17 @@ end
316314
function pushforward_batched(
317315
f!::F, y, backend::AbstractADType, x, dx::Batch{B}, extras::PushforwardExtras
318316
) where {F,B}
319-
dy_elements = ntuple(Val(B)) do l
320-
pushforward(f!, y, backend, x, dx.elements[l], extras)
317+
dy_elements = ntuple(Val(B)) do b
318+
pushforward(f!, y, backend, x, dx.elements[b], extras)
321319
end
322320
return Batch(dy_elements)
323321
end
324322

325323
function pushforward_batched!(
326-
f!::F,
327-
y,
328-
dy::Batch{B},
329-
backend::AbstractADType,
330-
x,
331-
dx::Batch{B},
332-
extras::PushforwardExtras,
333-
) where {F,B}
334-
for l in 1:B
335-
pushforward!(f!, y, dy.elements[l], backend, x, dx.elements[l], extras)
324+
f!::F, y, dy::Batch, backend::AbstractADType, x, dx::Batch, extras::PushforwardExtras
325+
) where {F}
326+
for b in eachindex(dy.elements, dx.elements)
327+
pushforward!(f!, y, dy.elements[b], backend, x, dx.elements[b], extras)
336328
end
337329
return dy
338330
end

DifferentiationInterface/src/second_order/hvp.jl

Lines changed: 123 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -102,33 +102,33 @@ function prepare_hvp(f::F, backend::AbstractADType, x, dx) where {F}
102102
end
103103

104104
function prepare_hvp(f::F, backend::SecondOrder, x, dx) where {F}
105-
return prepare_hvp(f, backend, x, dx, hvp_mode(backend))
105+
return prepare_hvp_aux(f, backend, x, dx, hvp_mode(backend))
106106
end
107107

108-
function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ForwardOverForward) where {F}
108+
function prepare_hvp_aux(f::F, backend::SecondOrder, x, dx, ::ForwardOverForward) where {F}
109109
# pushforward of many pushforwards in theory, but pushforward of gradient in practice
110110
inner_gradient = InnerGradient(f, nested(inner(backend)))
111111
outer_pushforward_extras = prepare_pushforward(inner_gradient, outer(backend), x, dx)
112112
return ForwardOverForwardHVPExtras(inner_gradient, outer_pushforward_extras)
113113
end
114114

115-
function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ForwardOverReverse) where {F}
115+
function prepare_hvp_aux(f::F, backend::SecondOrder, x, dx, ::ForwardOverReverse) where {F}
116116
# pushforward of gradient
117117
inner_gradient = InnerGradient(f, nested(inner(backend)))
118118
outer_pushforward_extras = prepare_pushforward(inner_gradient, outer(backend), x, dx)
119119
return ForwardOverReverseHVPExtras(inner_gradient, outer_pushforward_extras)
120120
end
121121

122-
function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ReverseOverForward) where {F}
122+
function prepare_hvp_aux(f::F, backend::SecondOrder, x, dx, ::ReverseOverForward) where {F}
123123
# gradient of pushforward
124124
# uses dx in the closure so it can't be stored
125125
inner_pushforward = InnerPushforwardFixedSeed(f, nested(inner(backend)), dx)
126126
outer_gradient_extras = prepare_gradient(inner_pushforward, outer(backend), x)
127127
return ReverseOverForwardHVPExtras(outer_gradient_extras)
128128
end
129129

130-
function prepare_hvp(f::F, backend::SecondOrder, x, dx, ::ReverseOverReverse) where {F}
131-
# pullback of the gradient
130+
function prepare_hvp_aux(f::F, backend::SecondOrder, x, dx, ::ReverseOverReverse) where {F}
131+
# pullback of gradient
132132
inner_gradient = InnerGradient(f, nested(inner(backend)))
133133
outer_pullback_extras = prepare_pullback(inner_gradient, outer(backend), x, dx)
134134
return ReverseOverReverseHVPExtras(inner_gradient, outer_pullback_extras)
@@ -149,16 +149,58 @@ end
149149

150150
### Batched
151151

152-
function prepare_hvp_batched(f::F, backend::AbstractADType, x, dx::Batch{B}) where {F,B}
153-
return prepare_hvp(f, backend, x, first(dx.elements))
152+
function prepare_hvp_batched(f::F, backend::AbstractADType, x, dx::Batch) where {F}
153+
return prepare_hvp_batched(f, SecondOrder(backend, backend), x, dx)
154+
end
155+
156+
function prepare_hvp_batched(f::F, backend::SecondOrder, x, dx::Batch) where {F}
157+
return prepare_hvp_batched_aux(f, backend, x, dx, hvp_mode(backend))
158+
end
159+
160+
function prepare_hvp_batched_aux(
161+
f::F, backend::SecondOrder, x, dx::Batch, ::ForwardOverForward
162+
) where {F}
163+
# batched pushforward of gradient
164+
inner_gradient = InnerGradient(f, nested(inner(backend)))
165+
outer_pushforward_extras = prepare_pushforward_batched(
166+
inner_gradient, outer(backend), x, dx
167+
)
168+
return ForwardOverForwardHVPExtras(inner_gradient, outer_pushforward_extras)
169+
end
170+
171+
function prepare_hvp_batched_aux(
172+
f::F, backend::SecondOrder, x, dx::Batch, ::ForwardOverReverse
173+
) where {F}
174+
# batched pushforward of gradient
175+
inner_gradient = InnerGradient(f, nested(inner(backend)))
176+
outer_pushforward_extras = prepare_pushforward_batched(
177+
inner_gradient, outer(backend), x, dx
178+
)
179+
return ForwardOverReverseHVPExtras(inner_gradient, outer_pushforward_extras)
180+
end
181+
182+
function prepare_hvp_batched_aux(
183+
f::F, backend::SecondOrder, x, dx::Batch, ::ReverseOverForward
184+
) where {F}
185+
# TODO: batched version replacing the outer gradient with a pullback
186+
return prepare_hvp_aux(f, backend, x, first(dx.elements), ReverseOverForward())
187+
end
188+
189+
function prepare_hvp_batched_aux(
190+
f::F, backend::SecondOrder, x, dx::Batch, ::ReverseOverReverse
191+
) where {F}
192+
# batched pullback of gradient
193+
inner_gradient = InnerGradient(f, nested(inner(backend)))
194+
outer_pullback_extras = prepare_pullback_batched(inner_gradient, outer(backend), x, dx)
195+
return ReverseOverReverseHVPExtras(inner_gradient, outer_pullback_extras)
154196
end
155197

156198
### Batched, same point
157199

158200
function prepare_hvp_batched_same_point(
159-
f::F, backend::AbstractADType, x, dx::Batch{B}, extras::HVPExtras
160-
) where {F,B}
161-
return prepare_hvp_same_point(f, backend, x, first(dx.elements), extras)
201+
f::F, backend::AbstractADType, x, dx::Batch, extras::HVPExtras
202+
) where {F}
203+
return extras
162204
end
163205

164206
## One argument
@@ -241,27 +283,89 @@ end
241283

242284
### Batched
243285

244-
function hvp_batched(f::F, backend::AbstractADType, x, dx, extras::HVPExtras) where {F}
286+
function hvp_batched(
287+
f::F, backend::AbstractADType, x, dx::Batch, extras::HVPExtras
288+
) where {F}
245289
return hvp_batched(f, SecondOrder(backend, backend), x, dx, extras)
246290
end
247291

248292
function hvp_batched(
249-
f::F, backend::SecondOrder, x, dx::Batch{B}, extras::HVPExtras
293+
f::F, backend::SecondOrder, x, dx::Batch, extras::ForwardOverForwardHVPExtras
294+
) where {F}
295+
@compat (; inner_gradient, outer_pushforward_extras) = extras
296+
return pushforward_batched(
297+
inner_gradient, outer(backend), x, dx, outer_pushforward_extras
298+
)
299+
end
300+
301+
function hvp_batched(
302+
f::F, backend::SecondOrder, x, dx::Batch, extras::ForwardOverReverseHVPExtras
303+
) where {F}
304+
@compat (; inner_gradient, outer_pushforward_extras) = extras
305+
return pushforward_batched(
306+
inner_gradient, outer(backend), x, dx, outer_pushforward_extras
307+
)
308+
end
309+
310+
function hvp_batched(
311+
f::F, backend::SecondOrder, x, dx::Batch{B}, extras::ReverseOverForwardHVPExtras
250312
) where {F,B}
251-
dg_elements = ntuple(Val(B)) do l
252-
hvp(f, backend, x, dx.elements[l], extras)
313+
dg_elements = ntuple(Val(B)) do b
314+
hvp(f, backend, x, dx.elements[b], extras)
253315
end
254316
return Batch(dg_elements)
255317
end
256318

257-
function hvp_batched!(f::F, dg, backend::AbstractADType, x, dx, extras::HVPExtras) where {F}
319+
function hvp_batched(
320+
f::F, backend::SecondOrder, x, dx::Batch, extras::ReverseOverReverseHVPExtras
321+
) where {F}
322+
@compat (; inner_gradient, outer_pullback_extras) = extras
323+
return pullback_batched(inner_gradient, outer(backend), x, dx, outer_pullback_extras)
324+
end
325+
326+
function hvp_batched!(
327+
f::F, dg::Batch, backend::AbstractADType, x, dx::Batch, extras::HVPExtras
328+
) where {F}
258329
return hvp_batched!(f, dg, SecondOrder(backend, backend), x, dx, extras)
259330
end
260331

261332
function hvp_batched!(
262-
f::F, dg::Batch{B}, backend::SecondOrder, x, dx::Batch{B}, extras::HVPExtras
333+
f::F, dg::Batch, backend::SecondOrder, x, dx::Batch, extras::ForwardOverForwardHVPExtras
334+
) where {F}
335+
@compat (; inner_gradient, outer_pushforward_extras) = extras
336+
return pushforward_batched!(
337+
inner_gradient, dg, outer(backend), x, dx, outer_pushforward_extras
338+
)
339+
end
340+
341+
function hvp_batched!(
342+
f::F, dg::Batch, backend::SecondOrder, x, dx::Batch, extras::ForwardOverReverseHVPExtras
343+
) where {F}
344+
@compat (; inner_gradient, outer_pushforward_extras) = extras
345+
return pushforward_batched!(
346+
inner_gradient, dg, outer(backend), x, dx, outer_pushforward_extras
347+
)
348+
end
349+
350+
function hvp_batched!(
351+
f::F,
352+
dg::Batch{B},
353+
backend::SecondOrder,
354+
x,
355+
dx::Batch{B},
356+
extras::ReverseOverForwardHVPExtras,
263357
) where {F,B}
264-
for l in 1:B
265-
hvp!(f, dg.elements[l], backend, x, dx.elements[l], extras)
358+
for b in eachindex(dg.elements, dx.elements)
359+
hvp!(f, dg.elements[b], backend, x, dx.elements[b], extras)
266360
end
361+
return dg
362+
end
363+
364+
function hvp_batched!(
365+
f::F, dg::Batch, backend::SecondOrder, x, dx::Batch, extras::ReverseOverReverseHVPExtras
366+
) where {F}
367+
@compat (; inner_gradient, outer_pullback_extras) = extras
368+
return pullback_batched!(
369+
inner_gradient, dg, outer(backend), x, dx, outer_pullback_extras
370+
)
267371
end

0 commit comments

Comments
 (0)