Skip to content

Commit 0db3be8

Browse files
authored
Weird array test scenarios in DIT extensions (#359)
1 parent 1299a5e commit 0db3be8

16 files changed

Lines changed: 228 additions & 141 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ julia = "1.6"
7272
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
7373
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
7474
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
75+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
7576
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
7677
# DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
7778
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
@@ -81,6 +82,7 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
8182
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
8283
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
8384
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
85+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
8486
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
8587
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
8688
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
@@ -89,6 +91,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
8991
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
9092
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
9193
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
94+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
9295
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
9396
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
9497
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
@@ -99,14 +102,17 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
99102
test = [
100103
"ADTypes",
101104
"Aqua",
105+
"ComponentArrays",
102106
"DataFrames",
103107
# "DifferentiationInterfaceTest",
104108
"JET",
109+
"JLArrays",
105110
"JuliaFormatter",
106111
"Pkg",
107112
"SparseArrays",
108113
"SparseConnectivityTracer",
109114
"SparseMatrixColorings",
110115
"StableRNGs",
116+
"StaticArrays",
111117
"Test",
112118
]

DifferentiationInterface/test/Back/ForwardDiff/test.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
using ComponentArrays: ComponentArrays
12
using DifferentiationInterface, DifferentiationInterfaceTest
23
using DifferentiationInterfaceTest: add_batchified!
34
using ForwardDiff: ForwardDiff
45
using SparseConnectivityTracer, SparseMatrixColorings
6+
using StaticArrays: StaticArrays
57
using Test
68

79
dense_backends = [AutoForwardDiff(), AutoForwardDiff(; chunksize=5, tag=:hello)]

DifferentiationInterface/test/Back/Zygote/test.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
using ComponentArrays: ComponentArrays
12
using DifferentiationInterface, DifferentiationInterfaceTest
3+
using JLArrays: JLArrays
24
using SparseConnectivityTracer, SparseMatrixColorings
5+
using StaticArrays: StaticArrays
36
using Test
47
using Zygote: Zygote
58

DifferentiationInterfaceTest/Project.toml

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,33 @@
11
name = "DifferentiationInterfaceTest"
22
uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.5.0"
4+
version = "0.6.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
99
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
10-
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
1110
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1211
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1312
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1413
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
15-
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
1614
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
15+
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
1716
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1817
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1918
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
20-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2119
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2220

21+
[weakdeps]
22+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
23+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
24+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
25+
26+
[extensions]
27+
DifferentiationInterfaceTestComponentArraysExt = "ComponentArrays"
28+
DifferentiationInterfaceTestJLArraysExt = "JLArrays"
29+
DifferentiationInterfaceTestStaticArraysExt = "StaticArrays"
30+
2331
[compat]
2432
ADTypes = "1.0.0"
2533
Chairmarks = "1.2.1"
@@ -42,17 +50,37 @@ julia = "1.6"
4250
[extras]
4351
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
4452
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
53+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
4554
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
4655
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
4756
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
4857
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
58+
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
4959
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
5060
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
5161
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
5262
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
5363
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
64+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
5465
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5566
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5667

5768
[targets]
58-
test = ["ADTypes", "Aqua", "DataFrames", "DifferentiationInterface", "ForwardDiff", "JET", "JuliaFormatter", "Pkg", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "Test", "Zygote"]
69+
test = [
70+
"ADTypes",
71+
"Aqua",
72+
"ComponentArrays",
73+
"DataFrames",
74+
"DifferentiationInterface",
75+
"ForwardDiff",
76+
"JET",
77+
"JLArrays",
78+
"JuliaFormatter",
79+
"Pkg",
80+
"SparseArrays",
81+
"SparseConnectivityTracer",
82+
"SparseMatrixColorings",
83+
"StaticArrays",
84+
"Test",
85+
"Zygote",
86+
]

DifferentiationInterfaceTest/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ Make it easy to know, for a given function:
2626
- Type stability tests
2727
- Count calls to the function
2828
- Benchmark runtime and allocations
29-
- Weird array types (GPU, static, components)
29+
- Scenarios with weird array types (GPU, static, components) in package extensions
3030

3131
## Installation
3232

DifferentiationInterfaceTest/src/scenarios/component.jl renamed to DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestComponentArraysExt/DifferentiationInterfaceTestComponentArraysExt.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
module DifferentiationInterfaceTestComponentArraysExt
2+
3+
using ComponentArrays: ComponentVector
4+
using DifferentiationInterfaceTest
5+
import DifferentiationInterfaceTest as DIT
6+
using LinearAlgebra: dot
7+
using Random: AbstractRNG, default_rng
8+
19
## Vector to scalar
210

311
function comp_to_num(x::ComponentVector)::Number
@@ -42,12 +50,7 @@ end
4250

4351
## Gather
4452

45-
"""
46-
component_scenarios(rng=Random.default_rng())
47-
48-
Create a vector of [`Scenario`](@ref)s with component array types from [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl).
49-
"""
50-
function component_scenarios(rng::AbstractRNG=default_rng())
53+
function DIT.component_scenarios(rng::AbstractRNG=default_rng())
5154
dy_ = rand(rng)
5255

5356
x_comp = ComponentVector(; a=randn(rng, 4), b=randn(rng, 2))
@@ -60,3 +63,5 @@ function component_scenarios(rng::AbstractRNG=default_rng())
6063
)
6164
return scens
6265
end
66+
67+
end
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
module DifferentiationInterfaceTestJLArraysExt
2+
3+
using DifferentiationInterfaceTest
4+
import DifferentiationInterfaceTest as DIT
5+
using JLArrays: JLArray, jl
6+
using Random: AbstractRNG, default_rng
7+
8+
num_to_arr_jlvector(x) = DIT.num_to_arr(x, JLArray{Float64,1})
9+
num_to_arr_jlmatrix(x) = DIT.num_to_arr(x, JLArray{Float64,2})
10+
11+
DIT.pick_num_to_arr(::Type{<:JLArray{<:Real,1}}) = num_to_arr_jlvector
12+
DIT.pick_num_to_arr(::Type{<:JLArray{<:Real,2}}) = num_to_arr_jlmatrix
13+
14+
function DIT.gpu_scenarios(rng::AbstractRNG=default_rng(); linalg=true)
15+
x_ = rand(rng)
16+
dx_ = rand(rng)
17+
dy_ = rand(rng)
18+
19+
x_6 = jl(rand(rng, 6))
20+
dx_6 = jl(rand(rng, 6))
21+
22+
x_2_3 = jl(rand(rng, 2, 3))
23+
dx_2_3 = jl(rand(rng, 2, 3))
24+
25+
dy_12 = jl(rand(rng, 12))
26+
dy_6_2 = jl(rand(rng, 6, 2))
27+
dy_6 = jl(rand(rng, 6))
28+
dy_2_3 = jl(rand(rng, 2, 3))
29+
30+
V = typeof(dy_6)
31+
M = typeof(dy_2_3)
32+
33+
scens = vcat(
34+
# one argument
35+
DIT.num_to_num_scenarios_onearg(x_; dx=dx_, dy=dy_),
36+
DIT.num_to_arr_scenarios_onearg(x_, V; dx=dx_, dy=dy_6),
37+
DIT.num_to_arr_scenarios_onearg(x_, M; dx=dx_, dy=dy_2_3),
38+
DIT.arr_to_num_scenarios_onearg(x_6; dx=dx_6, dy=dy_, linalg),
39+
DIT.arr_to_num_scenarios_onearg(x_2_3; dx=dx_2_3, dy=dy_, linalg),
40+
DIT.vec_to_vec_scenarios_onearg(x_6; dx=dx_6, dy=dy_12),
41+
DIT.vec_to_mat_scenarios_onearg(x_6; dx=dx_6, dy=dy_6_2),
42+
DIT.mat_to_vec_scenarios_onearg(x_2_3; dx=dx_2_3, dy=dy_12),
43+
DIT.mat_to_mat_scenarios_onearg(x_2_3; dx=dx_2_3, dy=dy_6_2),
44+
# two arguments
45+
DIT.num_to_arr_scenarios_twoarg(x_, V; dx=dx_, dy=dy_6),
46+
DIT.num_to_arr_scenarios_twoarg(x_, M; dx=dx_, dy=dy_2_3),
47+
DIT.vec_to_vec_scenarios_twoarg(x_6; dx=dx_6, dy=dy_12),
48+
DIT.vec_to_mat_scenarios_twoarg(x_6; dx=dx_6, dy=dy_6_2),
49+
DIT.mat_to_vec_scenarios_twoarg(x_2_3; dx=dx_2_3, dy=dy_12),
50+
DIT.mat_to_mat_scenarios_twoarg(x_2_3; dx=dx_2_3, dy=dy_6_2),
51+
)
52+
return scens
53+
end
54+
55+
end
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
module DifferentiationInterfaceTestStaticArraysExt
2+
3+
using DifferentiationInterfaceTest
4+
import DifferentiationInterfaceTest as DIT
5+
using Random: AbstractRNG, default_rng
6+
using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm
7+
using StaticArrays: MArray, MMatrix, MVector, SArray, SMatrix, SVector
8+
9+
num_to_arr_svector(x) = DIT.num_to_arr(x, SVector{6,Float64})
10+
num_to_arr_smatrix(x) = DIT.num_to_arr(x, SMatrix{2,3,Float64,6})
11+
12+
DIT.pick_num_to_arr(::Type{<:SVector}) = num_to_arr_svector
13+
DIT.pick_num_to_arr(::Type{<:SMatrix}) = num_to_arr_smatrix
14+
15+
function DIT.static_scenarios(rng::AbstractRNG=default_rng(); linalg=true)
16+
x_ = rand(rng)
17+
dx_ = rand(rng)
18+
dy_ = rand(rng)
19+
20+
x_6 = rand(rng, 6)
21+
dx_6 = rand(rng, 6)
22+
23+
x_2_3 = rand(rng, 2, 3)
24+
dx_2_3 = rand(rng, 2, 3)
25+
26+
dy_6 = rand(rng, 6)
27+
dy_12 = rand(rng, 12)
28+
dy_2_3 = rand(rng, 2, 3)
29+
dy_6_2 = rand(rng, 6, 2)
30+
31+
SV_6 = SVector{6}
32+
MV_6 = MVector{6}
33+
SV_12 = SVector{12}
34+
MV_12 = MVector{12}
35+
36+
SM_2_3 = SMatrix{2,3}
37+
MM_2_3 = MMatrix{2,3}
38+
SM_6_2 = SMatrix{6,2}
39+
MM_6_2 = MMatrix{6,2}
40+
41+
scens = vcat(
42+
# one argument
43+
DIT.num_to_arr_scenarios_onearg(x_, SV_6; dx=dx_, dy=SV_6(dy_6)),
44+
DIT.num_to_arr_scenarios_onearg(x_, SM_2_3; dx=dx_, dy=SM_2_3(dy_2_3)),
45+
DIT.arr_to_num_scenarios_onearg(SV_6(x_6); dx=SV_6(dx_6), dy=dy_, linalg),
46+
DIT.arr_to_num_scenarios_onearg(SM_2_3(x_2_3); dx=SM_2_3(dx_2_3), dy=dy_, linalg),
47+
DIT.vec_to_vec_scenarios_onearg(SV_6(x_6); dx=SV_6(dx_6), dy=SV_12(dy_12)),
48+
DIT.vec_to_mat_scenarios_onearg(SV_6(x_6); dx=SV_6(dx_6), dy=SM_6_2(dy_6_2)),
49+
DIT.mat_to_vec_scenarios_onearg(SM_2_3(x_2_3); dx=SM_2_3(dx_2_3), dy=SV_12(dy_12)),
50+
DIT.mat_to_mat_scenarios_onearg(
51+
SM_2_3(x_2_3); dx=SM_2_3(dx_2_3), dy=SM_6_2(dy_6_2)
52+
),
53+
# two arguments
54+
DIT.num_to_arr_scenarios_twoarg(x_, MV_6; dx=dx_, dy=MV_6(dy_6)),
55+
DIT.num_to_arr_scenarios_twoarg(x_, MM_2_3; dx=dx_, dy=MM_2_3(dy_2_3)),
56+
DIT.vec_to_vec_scenarios_twoarg(MV_6(x_6); dx=MV_6(dx_6), dy=MV_12(dy_12)),
57+
DIT.vec_to_mat_scenarios_twoarg(MV_6(x_6); dx=MV_6(dx_6), dy=MM_6_2(dy_6_2)),
58+
DIT.mat_to_vec_scenarios_twoarg(MM_2_3(x_2_3); dx=MM_2_3(dx_2_3), dy=MV_12(dy_12)),
59+
DIT.mat_to_mat_scenarios_twoarg(
60+
MM_2_3(x_2_3); dx=MM_2_3(dx_2_3), dy=MM_6_2(dy_6_2)
61+
),
62+
)
63+
scens = filter(scens) do s
64+
DIT.place(s) == :outofplace || s.x isa Union{Number,MArray}
65+
end
66+
return scens
67+
end
68+
69+
end

DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ using ADTypes:
1919
SymbolicMode
2020
using Chairmarks: @be, Benchmark, Sample
2121
using Compat
22-
using ComponentArrays: ComponentVector
2322
using DataFrames: DataFrame
2423
using DifferentiationInterface
2524
using DifferentiationInterface:
@@ -63,22 +62,19 @@ using DifferentiationInterface:
6362
using DocStringExtensions
6463
import DifferentiationInterface as DI
6564
using JET: JET
66-
using JLArrays: JLArray, jl
6765
using LinearAlgebra: Adjoint, Diagonal, Transpose, dot, parent
66+
using PackageExtensionCompat: @require_extensions
6867
using ProgressMeter: ProgressUnknown, next!
6968
using Random: AbstractRNG, default_rng, rand!
7069
using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm
71-
using StaticArrays: MArray, MMatrix, MVector, SArray, SMatrix, SVector
7270
using Test: @testset, @test
7371

7472
include("scenarios/scenario.jl")
7573
include("scenarios/batchify.jl")
7674
include("scenarios/default.jl")
7775
include("scenarios/sparse.jl")
78-
include("scenarios/static.jl")
79-
include("scenarios/component.jl")
80-
include("scenarios/gpu.jl")
8176
include("scenarios/allocfree.jl")
77+
include("scenarios/extensions.jl")
8278

8379
include("utils/zero_backends.jl")
8480
include("utils/misc.jl")
@@ -92,6 +88,10 @@ include("tests/sparsity.jl")
9288
include("tests/benchmark.jl")
9389
include("test_differentiation.jl")
9490

91+
function __init__()
92+
@require_extensions
93+
end
94+
9595
export Scenario
9696
export PushforwardScenario,
9797
PullbackScenario,
@@ -102,8 +102,11 @@ export PushforwardScenario,
102102
HVPScenario,
103103
HessianScenario
104104
export default_scenarios, sparse_scenarios
105-
export static_scenarios, component_scenarios, gpu_scenarios
106105
export test_differentiation, benchmark_differentiation
107106
export DifferentiationBenchmarkDataRow
107+
# extensions
108+
export static_scenarios
109+
export component_scenarios
110+
export gpu_scenarios
108111

109112
end

DifferentiationInterfaceTest/src/scenarios/default.jl

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -47,30 +47,10 @@ function num_to_arr(x::Number, ::Type{A}) where {A<:AbstractArray}
4747
end
4848

4949
num_to_arr_vector(x) = num_to_arr(x, Vector{Float64})
50-
num_to_arr_svector(x) = num_to_arr(x, SVector{6,Float64})
51-
num_to_arr_jlvector(x) = num_to_arr(x, JLArray{Float64,1})
52-
5350
num_to_arr_matrix(x) = num_to_arr(x, Matrix{Float64})
54-
num_to_arr_smatrix(x) = num_to_arr(x, SMatrix{2,3,Float64,6})
55-
num_to_arr_jlmatrix(x) = num_to_arr(x, JLArray{Float64,2})
56-
57-
function pick_num_to_arr(::Type{A}) where {A<:AbstractArray}
58-
if A <: Vector
59-
return num_to_arr_vector
60-
elseif A <: SVector
61-
return num_to_arr_svector
62-
elseif A <: JLArray{<:Any,1}
63-
return num_to_arr_jlvector
64-
elseif A <: Matrix
65-
return num_to_arr_matrix
66-
elseif A <: SMatrix
67-
return num_to_arr_smatrix
68-
elseif A <: JLArray{<:Any,2}
69-
return num_to_arr_jlmatrix
70-
else
71-
throw(ArgumentError("Array type $A not supported"))
72-
end
73-
end
51+
52+
pick_num_to_arr(::Type{<:Vector}) = num_to_arr_vector
53+
pick_num_to_arr(::Type{<:Matrix}) = num_to_arr_matrix
7454

7555
function num_to_arr!(y::AbstractArray, x::Number)::Nothing
7656
a = multiplicator(typeof(y))

0 commit comments

Comments
 (0)