Skip to content

Commit 43bb14a

Browse files
authored
Add more sparse AD test scenarios (#578)
* Add more sparse AD test scenarios * No test with fixed matrices on Zygote * Increase max bands beyond ForwardDiff max chunk * More bandwidths, fix Enzyme * Speed up symbolic
1 parent aa1b63e commit 43bb14a

7 files changed

Lines changed: 124 additions & 8 deletions

File tree

DifferentiationInterface/test/Back/Enzyme/test.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,5 +90,8 @@ test_differentiation(
9090
## Sparse
9191

9292
test_differentiation(
93-
MyAutoSparse.(AutoEnzyme()), sparse_scenarios(); sparsity=true, logging=LOGGING
93+
MyAutoSparse.(AutoEnzyme(; function_annotation=Enzyme.Const)),
94+
sparse_scenarios();
95+
sparsity=true,
96+
logging=LOGGING,
9497
);

DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ test_differentiation(AutoFastDifferentiation(); logging=LOGGING);
1616

1717
test_differentiation(
1818
AutoSparse(AutoFastDifferentiation()),
19-
sparse_scenarios();
19+
sparse_scenarios(; band_sizes=0:-1);
2020
sparsity=true,
2121
logging=LOGGING,
2222
);

DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,8 @@ end
1515
test_differentiation(AutoSymbolics(); logging=LOGGING);
1616

1717
test_differentiation(
18-
AutoSparse(AutoSymbolics()), sparse_scenarios(); sparsity=true, logging=LOGGING
18+
AutoSparse(AutoSymbolics()),
19+
sparse_scenarios(; band_sizes=0:-1);
20+
sparsity=true,
21+
logging=LOGGING,
1922
);

DifferentiationInterface/test/Back/Zygote/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ test_differentiation(
4141

4242
test_differentiation(
4343
MyAutoSparse.(vcat(backends, second_order_backends)),
44-
sparse_scenarios();
44+
sparse_scenarios(; band_sizes=0:-1);
4545
sparsity=true,
4646
logging=LOGGING,
4747
)

DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ using JET: @test_opt
5656
using LinearAlgebra: Adjoint, Diagonal, Transpose, dot, parent
5757
using ProgressMeter: ProgressUnknown, next!
5858
using Random: AbstractRNG, default_rng, rand!
59-
using SparseArrays: SparseArrays, AbstractSparseMatrix, SparseMatrixCSC, nnz, spdiagm
59+
using SparseArrays:
60+
SparseArrays, AbstractSparseMatrix, SparseMatrixCSC, nnz, sparse, spdiagm
6061
using Test: @testset, @test
6162

6263
"""

DifferentiationInterfaceTest/src/scenarios/sparse.jl

Lines changed: 111 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,6 @@ end
178178

179179
function sparse_vec_to_num_scenarios(x::AbstractVector)
180180
f = sumdiffcube
181-
y = f(x)
182181
grad = sumdiffcube_gradient(x)
183182
hess = sumdiffcube_hessian(x)
184183

@@ -203,7 +202,6 @@ end
203202

204203
function sparse_mat_to_num_scenarios(x::AbstractMatrix)
205204
f = sumdiffcube_mat
206-
y = f(x)
207205
grad = sumdiffcube_mat_gradient(x)
208206
hess = sumdiffcube_mat_hessian(x)
209207

@@ -214,14 +212,120 @@ function sparse_mat_to_num_scenarios(x::AbstractMatrix)
214212
return scens
215213
end
216214

215+
## Various matrices
216+
217+
function banded_matrix(m, n, b)
218+
pairs = [k => rand(min(m, n) - k) for k in 0:b]
219+
return spdiagm(m, n, pairs...)
220+
end
221+
222+
### Linear map
223+
224+
struct SquareLinearMap{M<:AbstractMatrix}
225+
A::M
226+
end
227+
228+
function Base.show(io::IO, s::SquareLinearMap{M}) where {M}
229+
return print(io, "SquareLinearMap{$M - $(size(s.A)) with $(mynnz(s.A)) nonzeros}")
230+
end
231+
232+
function (s::SquareLinearMap)(x::AbstractArray)
233+
return s.A * abs2.(vec(x))
234+
end
235+
236+
function (s::SquareLinearMap)(y::AbstractArray, x::AbstractArray)
237+
vec(y) .= s.A * abs2.(vec(x))
238+
return nothing
239+
end
240+
241+
function squarelinearmap_jacobian(x::AbstractArray, A::AbstractMatrix)
242+
return 2 .* A .* transpose(vec(x))
243+
end
244+
245+
function squarelinearmap_scenarios(x::AbstractVector, band_sizes)
246+
n = length(x)
247+
scens = Scenario[]
248+
for A in vcat(banded_matrix.(2n, n, band_sizes), banded_matrix.(n ÷ 2, n, band_sizes))
249+
f = SquareLinearMap(A)
250+
f! = f
251+
y = f(x)
252+
jac = sparse(squarelinearmap_jacobian(x, A))
253+
for pl_op in (:out, :in)
254+
append!(
255+
scens,
256+
[
257+
Scenario{:jacobian,pl_op}(f, x; res1=jac),
258+
Scenario{:jacobian,pl_op}(f!, y, x; res1=jac),
259+
],
260+
)
261+
end
262+
end
263+
return scens
264+
end
265+
266+
### Quadratic form
267+
268+
struct SquareQuadraticForm{M<:AbstractMatrix}
269+
A::M
270+
end
271+
272+
function Base.show(io::IO, s::SquareQuadraticForm{M}) where {M}
273+
return print(io, "SquareQuadraticForm{$M - $(size(s.A)) with $(mynnz(s.A)) nonzeros}")
274+
end
275+
276+
function (s::SquareQuadraticForm)(x::AbstractArray)
277+
v = abs2.(vec(x))
278+
return dot(v, s.A, v)
279+
end
280+
281+
function squarequadraticform_gradient(x::AbstractArray, A::AbstractMatrix)
282+
g = similar(x)
283+
for i in eachindex(g)
284+
g[i] =
285+
4 * A[i, i] * x[i]^3 +
286+
2 * sum((A[i, j] + A[j, i]) * x[i] * x[j]^2 for j in eachindex(g) if j != i)
287+
end
288+
return g
289+
end
290+
291+
function squarequadraticform_hessian(x::AbstractArray, A::AbstractMatrix)
292+
H = similar(x, length(x), length(x))
293+
for i in axes(H, 1), j in axes(H, 2)
294+
if i == j
295+
H[i, i] =
296+
12 * A[i, i] * x[i]^2 +
297+
2 * sum((A[i, j2] + A[j2, i]) * x[j2]^2 for j2 in axes(H, 2) if j2 != i)
298+
else
299+
H[i, j] = 4 * (A[i, j] + A[j, i]) * x[i] * x[j]
300+
end
301+
end
302+
return H
303+
end
304+
305+
function squarequadraticform_scenarios(x::AbstractVector, band_sizes)
306+
n = length(x)
307+
scens = Scenario[]
308+
for A in banded_matrix.(n, n, band_sizes)
309+
f = SquareQuadraticForm(A)
310+
grad = squarequadraticform_gradient(x, A)
311+
hess = sparse(squarequadraticform_hessian(x, A))
312+
for pl_op in (:out, :in)
313+
push!(scens, Scenario{:hessian,pl_op}(f, x; res1=grad, res2=hess))
314+
end
315+
end
316+
return scens
317+
end
318+
217319
## Gather
218320

219321
"""
220322
sparse_scenarios(rng=Random.default_rng())
221323
222324
Create a vector of [`Scenario`](@ref)s with sparse array types, focused on sparse Jacobians and Hessians.
223325
"""
224-
function sparse_scenarios(rng::AbstractRNG=default_rng(); include_constantified=false)
326+
function sparse_scenarios(
327+
rng::AbstractRNG=default_rng(); band_sizes=0:4:36, include_constantified=false
328+
)
225329
scens = vcat(
226330
sparse_vec_to_vec_scenarios(rand(rng, 6)),
227331
sparse_vec_to_mat_scenarios(rand(rng, 6)),
@@ -230,6 +334,10 @@ function sparse_scenarios(rng::AbstractRNG=default_rng(); include_constantified=
230334
sparse_vec_to_num_scenarios(rand(rng, 6)),
231335
sparse_mat_to_num_scenarios(rand(rng, 2, 3)),
232336
)
337+
if !isempty(band_sizes)
338+
append!(scens, squarelinearmap_scenarios(rand(rng, 100), band_sizes))
339+
append!(scens, squarequadraticform_scenarios(rand(rng, 100), band_sizes))
340+
end
233341
include_constantified && append!(scens, constantify(scens))
234342
return scens
235343
end

DifferentiationInterfaceTest/test/standard.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ sparse_backend = AutoSparse(
2323
sparsity_detector=TracerSparsityDetector(),
2424
coloring_algorithm=GreedyColoringAlgorithm(),
2525
)
26+
2627
test_differentiation(
2728
sparse_backend,
2829
sparse_scenarios(; include_constantified=true);

0 commit comments

Comments
 (0)