Skip to content

Commit 94f9bc5

Browse files
authored
Simplify static test scenarios (#581)
* Simplify static scenarios * No conversion * Exclude derivative for Zygote * Fix JLArrays * Unskip
1 parent a97b432 commit 94f9bc5

3 files changed

Lines changed: 14 additions & 31 deletions

File tree

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,24 @@ module DifferentiationInterfaceTestJLArraysExt
33
import DifferentiationInterface as DI
44
using DifferentiationInterfaceTest
55
import DifferentiationInterfaceTest as DIT
6-
using JLArrays: JLArray, jl
6+
using JLArrays: JLArray, JLVector, JLMatrix, jl
77
using Random: AbstractRNG, default_rng
88

99
myjl(f::Function) = f
1010
function myjl(::DIT.NumToArr{A}) where {T,N,A<:AbstractArray{T,N}}
1111
return DIT.NumToArr(JLArray{T,N})
1212
end
1313

14+
function (f::DIT.NumToArr{JLVector{T}})(x::Number) where {T}
15+
a = JLVector{T}(Vector(1:6)) # avoid mutation
16+
return sin.(x .* a)
17+
end
18+
19+
function (f::DIT.NumToArr{JLMatrix{T}})(x::Number) where {T}
20+
a = JLMatrix{T}(Matrix(reshape(1:6, 2, 3))) # avoid mutation
21+
return sin.(x .* a)
22+
end
23+
1424
myjl(f::DIT.MultiplyByConstant) = f
1525
myjl(f::DIT.WritableClosure) = f
1626

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl

Lines changed: 1 addition & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,10 @@ using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm
88
using StaticArrays: MArray, MMatrix, MVector, SArray, SMatrix, SVector
99

1010
mySArray(f::Function) = f
11-
myMArray(f::Function) = f
12-
1311
mySArray(::DIT.NumToArr{A}) where {T,A<:AbstractVector{T}} = DIT.NumToArr(SVector{6,T})
14-
myMArray(::DIT.NumToArr{A}) where {T,A<:AbstractVector{T}} = DIT.NumToArr(MVector{6,T})
15-
1612
mySArray(::DIT.NumToArr{A}) where {T,A<:AbstractMatrix{T}} = DIT.NumToArr(SMatrix{2,3,T,6})
17-
myMArray(::DIT.NumToArr{A}) where {T,A<:AbstractMatrix{T}} = DIT.NumToArr(MMatrix{2,3,T,6})
18-
1913
mySArray(f::DIT.MultiplyByConstant) = f
20-
myMArray(f::DIT.MultiplyByConstant) = f
21-
2214
mySArray(f::DIT.WritableClosure) = f
23-
myMArray(f::DIT.WritableClosure) = f
2415

2516
mySArray(x::Number) = x
2617
myMArray(x::Number) = x
@@ -36,13 +27,8 @@ function myMArray(x::AbstractMatrix{T}) where {T}
3627
end
3728

3829
mySArray(x::Tuple) = map(mySArray, x)
39-
myMArray(x::Tuple) = map(myMArray, x)
40-
4130
mySArray(x::DI.Constant) = DI.Constant(mySArray(DI.unwrap(x)))
42-
myMArray(x::DI.Constant) = DI.Constant(myMArray(DI.unwrap(x)))
43-
4431
mySArray(::Nothing) = nothing
45-
myMArray(::Nothing) = nothing
4632

4733
function mySArray(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
4834
(; f, x, y, tang, contexts, res1, res2) = scen
@@ -57,22 +43,9 @@ function mySArray(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
5743
)
5844
end
5945

60-
function myMArray(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
61-
(; f, x, y, tang, contexts, res1, res2) = scen
62-
return Scenario{op,pl_op,pl_fun}(
63-
myMArray(f);
64-
x=myMArray(x),
65-
y=pl_fun == :in ? myMArray(y) : myMArray(y),
66-
tang=myMArray(tang),
67-
contexts=myMArray(contexts),
68-
res1=myMArray(res1),
69-
res2=myMArray(res2),
70-
)
71-
end
72-
7346
function DIT.static_scenarios(args...; kwargs...)
7447
scens = DIT.default_scenarios(args...; kwargs...)
75-
return vcat(mySArray.(scens), myMArray.(scens))
48+
return mySArray.(scens)
7649
end
7750

7851
end

DifferentiationInterfaceTest/src/scenarios/default.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ end
7171

7272
## Number to array
7373

74-
multiplicator(::Type{A}) where {A<:AbstractVector} = convert(A, float.(1:6))
75-
multiplicator(::Type{A}) where {A<:AbstractMatrix} = convert(A, reshape(float.(1:6), 2, 3))
74+
multiplicator(::Type{A}) where {A<:AbstractVector} = A(1:6)
75+
multiplicator(::Type{A}) where {A<:AbstractMatrix} = A(reshape(1:6, 2, 3))
7676

7777
struct NumToArr{A} end
7878
NumToArr(::Type{A}) where {A} = NumToArr{A}()

0 commit comments

Comments
 (0)