Skip to content

Commit 8dc755b

Browse files
authored
Better seed handling in Jacobian and Hessian (#334)
1 parent 15e089d commit 8dc755b

7 files changed

Lines changed: 329 additions & 318 deletions

File tree

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 36 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -56,13 +56,13 @@ abstract type JacobianExtras <: Extras end
5656
struct NoJacobianExtras <: JacobianExtras end
5757

5858
struct PushforwardJacobianExtras{B,D,E<:PushforwardExtras,Y} <: JacobianExtras
59-
seeds::D
59+
batched_seeds::Vector{Batch{B,D}}
6060
pushforward_batched_extras::E
6161
y_example::Y
6262
end
6363

6464
struct PullbackJacobianExtras{B,D,E<:PullbackExtras,Y} <: JacobianExtras
65-
seeds::D
65+
batched_seeds::Vector{Batch{B,D}}
6666
pullback_batched_extras::E
6767
y_example::Y
6868
end
@@ -80,26 +80,38 @@ function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardFast) wh
8080
N = length(x)
8181
B = pick_batchsize(backend, N)
8282
seeds = [basis(backend, x, ind) for ind in CartesianIndices(x)]
83+
batched_seeds =
84+
Batch.([
85+
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B)) for
86+
a in 1:div(N, B, RoundUp)
87+
])
8388
pushforward_batched_extras = prepare_pushforward_batched(
84-
f_or_f!y..., backend, x, Batch(ntuple(Returns(seeds[1]), Val(B)))
89+
f_or_f!y..., backend, x, batched_seeds[1]
8590
)
86-
D = typeof(seeds)
91+
D = eltype(seeds)
8792
E = typeof(pushforward_batched_extras)
8893
Y = typeof(y)
89-
return PushforwardJacobianExtras{B,D,E,Y}(seeds, pushforward_batched_extras, copy(y))
94+
return PushforwardJacobianExtras{B,D,E,Y}(
95+
batched_seeds, pushforward_batched_extras, copy(y)
96+
)
9097
end
9198

9299
function prepare_jacobian_aux(f_or_f!y::FY, backend, x, y, ::PushforwardSlow) where {FY}
93100
M = length(y)
94101
B = pick_batchsize(backend, M)
95102
seeds = [basis(backend, y, ind) for ind in CartesianIndices(y)]
103+
batched_seeds =
104+
Batch.([
105+
ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % M], Val(B)) for
106+
a in 1:div(M, B, RoundUp)
107+
])
96108
pullback_batched_extras = prepare_pullback_batched(
97-
f_or_f!y..., backend, x, Batch(ntuple(Returns(seeds[1]), Val(B)))
109+
f_or_f!y..., backend, x, batched_seeds[1]
98110
)
99-
D = typeof(seeds)
111+
D = eltype(seeds)
100112
E = typeof(pullback_batched_extras)
101113
Y = typeof(y)
102-
return PullbackJacobianExtras{B,D,E,Y}(seeds, pullback_batched_extras, copy(y))
114+
return PullbackJacobianExtras{B,D,E,Y}(batched_seeds, pullback_batched_extras, copy(y))
103115
end
104116

105117
## One argument
@@ -197,27 +209,16 @@ end
197209
function jacobian_aux(
198210
f_or_f!y::FY, backend, x::AbstractArray, extras::PushforwardJacobianExtras{B}
199211
) where {FY,B}
200-
@compat (; seeds, pushforward_batched_extras, y_example) = extras
212+
@compat (; batched_seeds, pushforward_batched_extras, y_example) = extras
201213
N = length(x)
202214

203215
pushforward_batched_extras_same = prepare_pushforward_batched_same_point(
204-
f_or_f!y...,
205-
backend,
206-
x,
207-
Batch(ntuple(Returns(seeds[1]), Val(B))),
208-
pushforward_batched_extras,
216+
f_or_f!y..., backend, x, batched_seeds[1], pushforward_batched_extras
209217
)
210218

211-
jac_blocks = map(1:div(N, B, RoundUp)) do a
212-
dx_batch_elements = ntuple(Val(B)) do b
213-
seeds[1 + ((a - 1) * B + (b - 1)) % N]
214-
end
219+
jac_blocks = map(eachindex(batched_seeds)) do a
215220
dy_batch = pushforward_batched(
216-
f_or_f!y...,
217-
backend,
218-
x,
219-
Batch(dx_batch_elements),
220-
pushforward_batched_extras_same,
221+
f_or_f!y..., backend, x, batched_seeds[a], pushforward_batched_extras_same
221222
)
222223
stack(vec, dy_batch.elements; dims=2)
223224
end
@@ -232,27 +233,16 @@ end
232233
function jacobian_aux(
233234
f_or_f!y::FY, backend, x::AbstractArray, extras::PullbackJacobianExtras{B}
234235
) where {FY,B}
235-
@compat (; seeds, pullback_batched_extras, y_example) = extras
236+
@compat (; batched_seeds, pullback_batched_extras, y_example) = extras
236237
M = length(y_example)
237238

238239
pullback_batched_extras_same = prepare_pullback_batched_same_point(
239-
f_or_f!y...,
240-
backend,
241-
x,
242-
Batch(ntuple(Returns(seeds[1]), Val(B))),
243-
extras.pullback_batched_extras,
240+
f_or_f!y..., backend, x, batched_seeds[1], extras.pullback_batched_extras
244241
)
245242

246-
jac_blocks = map(1:div(M, B, RoundUp)) do a
247-
dy_batch_elements = ntuple(Val(B)) do b
248-
seeds[1 + ((a - 1) * B + (b - 1)) % M]
249-
end
243+
jac_blocks = map(eachindex(batched_seeds)) do a
250244
dx_batch = pullback_batched(
251-
f_or_f!y...,
252-
backend,
253-
x,
254-
Batch(dy_batch_elements),
255-
pullback_batched_extras_same,
245+
f_or_f!y..., backend, x, batched_seeds[a], pullback_batched_extras_same
256246
)
257247
stack(vec, dx_batch.elements; dims=1)
258248
end
@@ -271,21 +261,14 @@ function jacobian_aux!(
271261
x::AbstractArray,
272262
extras::PushforwardJacobianExtras{B},
273263
) where {FY,B}
274-
@compat (; seeds, pushforward_batched_extras, y_example) = extras
264+
@compat (; batched_seeds, pushforward_batched_extras, y_example) = extras
275265
N = length(x)
276266

277267
pushforward_batched_extras_same = prepare_pushforward_batched_same_point(
278-
f_or_f!y...,
279-
backend,
280-
x,
281-
Batch(ntuple(Returns(seeds[1]), Val(B))),
282-
pushforward_batched_extras,
268+
f_or_f!y..., backend, x, batched_seeds[1], pushforward_batched_extras
283269
)
284270

285-
for a in 1:div(N, B, RoundUp)
286-
dx_batch_elements = ntuple(Val(B)) do b
287-
seeds[1 + ((a - 1) * B + (b - 1)) % N]
288-
end
271+
for a in eachindex(batched_seeds)
289272
dy_batch_elements = ntuple(Val(B)) do b
290273
reshape(view(jac, :, 1 + ((a - 1) * B + (b - 1)) % N), size(y_example))
291274
end
@@ -294,7 +277,7 @@ function jacobian_aux!(
294277
Batch(dy_batch_elements),
295278
backend,
296279
x,
297-
Batch(dx_batch_elements),
280+
batched_seeds[a],
298281
pushforward_batched_extras_same,
299282
)
300283
end
@@ -309,21 +292,14 @@ function jacobian_aux!(
309292
x::AbstractArray,
310293
extras::PullbackJacobianExtras{B},
311294
) where {FY,B}
312-
@compat (; seeds, pullback_batched_extras, y_example) = extras
295+
@compat (; batched_seeds, pullback_batched_extras, y_example) = extras
313296
M = length(y_example)
314297

315298
pullback_batched_extras_same = prepare_pullback_batched_same_point(
316-
f_or_f!y...,
317-
backend,
318-
x,
319-
Batch(ntuple(Returns(seeds[1]), Val(B))),
320-
extras.pullback_batched_extras,
299+
f_or_f!y..., backend, x, batched_seeds[1], extras.pullback_batched_extras
321300
)
322301

323-
for a in 1:div(M, B, RoundUp)
324-
dy_batch_elements = ntuple(Val(B)) do b
325-
seeds[1 + ((a - 1) * B + (b - 1)) % M]
326-
end
302+
for a in eachindex(batched_seeds)
327303
dx_batch_elements = ntuple(Val(B)) do b
328304
reshape(view(jac, 1 + ((a - 1) * B + (b - 1)) % M, :), size(x))
329305
end
@@ -332,7 +308,7 @@ function jacobian_aux!(
332308
Batch(dx_batch_elements),
333309
backend,
334310
x,
335-
Batch(dy_batch_elements),
311+
batched_seeds[a],
336312
pullback_batched_extras_same,
337313
)
338314
end

0 commit comments

Comments
 (0)