Skip to content

Commit 944355a

Browse files
committed
Fix static arrays
1 parent 27fbc23 commit 944355a

1 file changed

Lines changed: 8 additions & 2 deletions

File tree

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module DifferentiationInterfaceTestStaticArraysExt
33
import DifferentiationInterface as DI
44
import DifferentiationInterfaceTest as DIT
55
using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm
6-
using StaticArrays: MArray, MMatrix, MVector, SArray, SMatrix, SVector
6+
using StaticArrays: StaticArray, MArray, MMatrix, MVector, SArray, SMatrix, SVector
77

88
static_num_to_vec(x::Number) = sin.(SVector(1, 2) .* x)
99
static_num_to_mat(x::Number) = hcat(static_num_to_vec(x), static_num_to_vec(3x))
@@ -37,13 +37,19 @@ mystatic(::Nothing) = nothing
3737

3838
function mystatic(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
3939
(; f, x, y, t, contexts, prep_args, res1, res2, name) = scen
40+
new_prep_args = (;
41+
x=mystatic(prep_args.x), contexts=map(mystatic, prep_args.contexts), t=mystatic(t)
42+
)
43+
if pl_fun == :in
44+
new_prep_args = (; new_prep_args..., y=mymutablestatic(prep_args.y))
45+
end
4046
return DIT.Scenario{op,pl_op,pl_fun}(;
4147
f=mystatic(f),
4248
x=mystatic(x),
4349
y=pl_fun == :in ? mymutablestatic(y) : mystatic(y),
4450
t=mystatic(t),
4551
contexts=mystatic(contexts),
46-
prep_args=map(mystatic, prep_args),
52+
prep_args=new_prep_args,
4753
res1=mystatic(res1),
4854
res2=mystatic(res2),
4955
name=name,

0 commit comments

Comments
 (0)