-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathDifferentiationInterfaceTestJLArraysExt.jl
More file actions
43 lines (36 loc) · 1.24 KB
/
DifferentiationInterfaceTestJLArraysExt.jl
File metadata and controls
43 lines (36 loc) · 1.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
module DifferentiationInterfaceTestJLArraysExt

import DifferentiationInterface as DI
import DifferentiationInterfaceTest as DIT
using JLArrays: JLArray, JLVector, JLMatrix, jl

# JLArray-producing counterparts of `DIT.num_to_vec` / `DIT.num_to_mat`: the constant
# array `[1, 2]` is wrapped with `jl` so the outputs are JLArrays instead of CPU arrays.
jl_num_to_vec(x::Number) = sin.(jl([1, 2]) .* x)
jl_num_to_mat(x::Number) = hcat(jl_num_to_vec(x), jl_num_to_vec(3x))

const NTV = typeof(DIT.num_to_vec)
const NTM = typeof(DIT.num_to_mat)

"""
    myjl(obj)

Return a copy of `obj` (function, number, array, tuple / named tuple, DI context,
`nothing`, or `DIT.Scenario`) in which every array has been converted with `jl`,
keeping the surrounding structure (tuples, context wrappers, scenario type) intact.
"""
function myjl end

# Functions pass through unchanged, except the two scalar-to-array test functions,
# which are swapped for the JLArray versions defined above.
myjl(f::Function) = f
myjl(::NTV) = jl_num_to_vec
myjl(::NTM) = jl_num_to_mat
myjl(f::DIT.FunctionModifier) = f

myjl(x::Number) = x
myjl(x::AbstractArray) = jl(x)
myjl(x::Union{Tuple,NamedTuple}) = map(myjl, x)
myjl(::Nothing) = nothing

myjl(x::DI.Constant) = DI.Constant(myjl(DI.unwrap(x)))
myjl(x::DI.Cache{<:AbstractArray}) = DI.Cache(myjl(DI.unwrap(x)))
# Rewrap in `Cache` so the result is still a single context, exactly like the other
# context methods above (previously a bare tuple/NamedTuple of `Cache`s was returned).
myjl(x::DI.Cache{<:Union{Tuple,NamedTuple}}) = DI.Cache(_jl_cache_content(DI.unwrap(x)))

# Convert the content of a `Cache`: arrays go through `myjl`, (nested) tuples and
# named tuples are mapped element-wise so their structure is preserved.
_jl_cache_content(c::AbstractArray) = myjl(c)
_jl_cache_content(c::Union{Tuple,NamedTuple}) = map(_jl_cache_content, c)

# Rebuild a scenario with the same type parameters, converting the function and every
# data field (input, output, tangents, contexts, reference results) with `myjl`.
function myjl(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
    (; f, x, y, tang, contexts, res1, res2) = scen
    return DIT.Scenario{op,pl_op,pl_fun}(
        myjl(f);
        x=myjl(x),
        y=myjl(y),
        tang=myjl(tang),
        contexts=myjl(contexts),
        res1=myjl(res1),
        res2=myjl(res2),
    )
end

# GPU scenarios are the default scenarios (same arguments and keywords) with all array
# data converted to JLArrays.
function DIT.gpu_scenarios(args...; kwargs...)
    scens = DIT.default_scenarios(args...; kwargs...)
    return myjl.(scens)
end

end