Skip to content

Commit 9ee81a4

Browse files
authored
mapreduce instead of map + reduce for Jacobians & Hessians (#565)
* Improve type stability tests and benchmarking * Remove `first_order` and `second_order` * Docs * Zero allocs * Fixes * Call count * Fix * Fix * Add count calls * Default count calls * Fix * Custom stacking for StaticArrays * Bump * Clearer modulo * Woops * Undo mo1 * Mapreduce * Add function filter to type stability checks
1 parent 45fbdd6 commit 9ee81a4

4 files changed

Lines changed: 210 additions & 170 deletions

File tree

DifferentiationInterface/src/first_order/jacobian.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ function _jacobian_aux(
232232
f_or_f!y..., pushforward_prep, backend, x, batched_seeds[1], contexts...
233233
)
234234

235-
jac_blocks = map(eachindex(batched_seeds)) do a
235+
jac = mapreduce(hcat, eachindex(batched_seeds)) do a
236236
dy_batch = pushforward(
237237
f_or_f!y...,
238238
pushforward_prep_same,
@@ -247,8 +247,6 @@ function _jacobian_aux(
247247
end
248248
block
249249
end
250-
251-
jac = reduce(hcat, jac_blocks)
252250
return jac
253251
end
254252

@@ -265,7 +263,7 @@ function _jacobian_aux(
265263
f_or_f!y..., prep.pullback_prep, backend, x, batched_seeds[1], contexts...
266264
)
267265

268-
jac_blocks = map(eachindex(batched_seeds)) do a
266+
jac = mapreduce(vcat, eachindex(batched_seeds)) do a
269267
dx_batch = pullback(
270268
f_or_f!y..., pullback_prep_same, backend, x, batched_seeds[a], contexts...
271269
)
@@ -275,8 +273,6 @@ function _jacobian_aux(
275273
end
276274
block
277275
end
278-
279-
jac = reduce(vcat, jac_blocks)
280276
return jac
281277
end
282278

DifferentiationInterface/src/second_order/hessian.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,16 +111,14 @@ function hessian(
111111
f, hvp_prep, backend, x, batched_seeds[1], contexts...
112112
)
113113

114-
hess_blocks = map(eachindex(batched_seeds)) do a
114+
hess = mapreduce(hcat, eachindex(batched_seeds)) do a
115115
dg_batch = hvp(f, hvp_prep_same, backend, x, batched_seeds[a], contexts...)
116116
block = stack_vec_col(dg_batch)
117117
if N % B != 0 && a == lastindex(batched_seeds)
118118
block = block[:, 1:(N - (a - 1) * B)]
119119
end
120120
block
121121
end
122-
123-
hess = reduce(hcat, hess_blocks)
124122
return hess
125123
end
126124

DifferentiationInterfaceTest/src/test_differentiation.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ For `type_stability` and `benchmark`, the possible values are `:none`, `:prepare
4848
**Type stability options:**
4949
5050
- `ignored_modules=nothing`: list of modules that JET.jl should ignore
51+
- `function_filter`: filter for functions that JET.jl should ignore (with a reasonable default)
5152
5253
**Benchmark options:**
5354
@@ -72,6 +73,11 @@ function test_differentiation(
7273
sparsity::Bool=false,
7374
# type stability options
7475
ignored_modules=nothing,
76+
function_filter=if VERSION >= v"1.11"
77+
@nospecialize(f) -> true
78+
else
79+
@nospecialize(f) -> f != Base.mapreduce_empty # fix for `mapreduce` in jacobian and hessian
80+
end,
7581
# benchmark options
7682
count_calls::Bool=true,
7783
)
@@ -136,7 +142,8 @@ function test_differentiation(
136142
adapted_backend,
137143
scen;
138144
subset=type_stability,
139-
ignored_modules=ignored_modules,
145+
ignored_modules,
146+
function_filter,
140147
)
141148
end
142149
yield()

0 commit comments

Comments
 (0)