|
| 1 | +module DifferentiationInterfaceTestStaticArraysExt |
| 2 | + |
| 3 | +using DifferentiationInterfaceTest |
| 4 | +import DifferentiationInterfaceTest as DIT |
| 5 | +using Random: AbstractRNG, default_rng |
| 6 | +using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm |
| 7 | +using StaticArrays: MArray, MMatrix, MVector, SArray, SMatrix, SVector |
| 8 | + |
| 9 | +num_to_arr_svector(x) = DIT.num_to_arr(x, SVector{6,Float64}) |
| 10 | +num_to_arr_smatrix(x) = DIT.num_to_arr(x, SMatrix{2,3,Float64,6}) |
| 11 | + |
| 12 | +DIT.pick_num_to_arr(::Type{<:SVector}) = num_to_arr_svector |
| 13 | +DIT.pick_num_to_arr(::Type{<:SMatrix}) = num_to_arr_smatrix |
| 14 | + |
| 15 | +function DIT.static_scenarios(rng::AbstractRNG=default_rng(); linalg=true) |
| 16 | + x_ = rand(rng) |
| 17 | + dx_ = rand(rng) |
| 18 | + dy_ = rand(rng) |
| 19 | + |
| 20 | + x_6 = rand(rng, 6) |
| 21 | + dx_6 = rand(rng, 6) |
| 22 | + |
| 23 | + x_2_3 = rand(rng, 2, 3) |
| 24 | + dx_2_3 = rand(rng, 2, 3) |
| 25 | + |
| 26 | + dy_6 = rand(rng, 6) |
| 27 | + dy_12 = rand(rng, 12) |
| 28 | + dy_2_3 = rand(rng, 2, 3) |
| 29 | + dy_6_2 = rand(rng, 6, 2) |
| 30 | + |
| 31 | + SV_6 = SVector{6} |
| 32 | + MV_6 = MVector{6} |
| 33 | + SV_12 = SVector{12} |
| 34 | + MV_12 = MVector{12} |
| 35 | + |
| 36 | + SM_2_3 = SMatrix{2,3} |
| 37 | + MM_2_3 = MMatrix{2,3} |
| 38 | + SM_6_2 = SMatrix{6,2} |
| 39 | + MM_6_2 = MMatrix{6,2} |
| 40 | + |
| 41 | + scens = vcat( |
| 42 | + # one argument |
| 43 | + DIT.num_to_arr_scenarios_onearg(x_, SV_6; dx=dx_, dy=SV_6(dy_6)), |
| 44 | + DIT.num_to_arr_scenarios_onearg(x_, SM_2_3; dx=dx_, dy=SM_2_3(dy_2_3)), |
| 45 | + DIT.arr_to_num_scenarios_onearg(SV_6(x_6); dx=SV_6(dx_6), dy=dy_, linalg), |
| 46 | + DIT.arr_to_num_scenarios_onearg(SM_2_3(x_2_3); dx=SM_2_3(dx_2_3), dy=dy_, linalg), |
| 47 | + DIT.vec_to_vec_scenarios_onearg(SV_6(x_6); dx=SV_6(dx_6), dy=SV_12(dy_12)), |
| 48 | + DIT.vec_to_mat_scenarios_onearg(SV_6(x_6); dx=SV_6(dx_6), dy=SM_6_2(dy_6_2)), |
| 49 | + DIT.mat_to_vec_scenarios_onearg(SM_2_3(x_2_3); dx=SM_2_3(dx_2_3), dy=SV_12(dy_12)), |
| 50 | + DIT.mat_to_mat_scenarios_onearg( |
| 51 | + SM_2_3(x_2_3); dx=SM_2_3(dx_2_3), dy=SM_6_2(dy_6_2) |
| 52 | + ), |
| 53 | + # two arguments |
| 54 | + DIT.num_to_arr_scenarios_twoarg(x_, MV_6; dx=dx_, dy=MV_6(dy_6)), |
| 55 | + DIT.num_to_arr_scenarios_twoarg(x_, MM_2_3; dx=dx_, dy=MM_2_3(dy_2_3)), |
| 56 | + DIT.vec_to_vec_scenarios_twoarg(MV_6(x_6); dx=MV_6(dx_6), dy=MV_12(dy_12)), |
| 57 | + DIT.vec_to_mat_scenarios_twoarg(MV_6(x_6); dx=MV_6(dx_6), dy=MM_6_2(dy_6_2)), |
| 58 | + DIT.mat_to_vec_scenarios_twoarg(MM_2_3(x_2_3); dx=MM_2_3(dx_2_3), dy=MV_12(dy_12)), |
| 59 | + DIT.mat_to_mat_scenarios_twoarg( |
| 60 | + MM_2_3(x_2_3); dx=MM_2_3(dx_2_3), dy=MM_6_2(dy_6_2) |
| 61 | + ), |
| 62 | + ) |
| 63 | + scens = filter(scens) do s |
| 64 | + DIT.place(s) == :outofplace || s.x isa Union{Number,MArray} |
| 65 | + end |
| 66 | + return scens |
| 67 | +end |
| 68 | + |
| 69 | +end |
0 commit comments