-
Notifications
You must be signed in to change notification settings - Fork 32
Expand file tree
/
Copy pathDifferentiationInterfaceTestJLArraysExt.jl
More file actions
50 lines (42 loc) · 1.45 KB
/
DifferentiationInterfaceTestJLArraysExt.jl
File metadata and controls
50 lines (42 loc) · 1.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
module DifferentiationInterfaceTestJLArraysExt
import DifferentiationInterface as DI
import DifferentiationInterfaceTest as DIT
using JLArrays: JLArray, JLVector, JLMatrix, jl
using PrecompileTools: @compile_workload
"""
    jl_num_to_vec(x::Number)

JLArray counterpart of `DIT.num_to_vec`: return `sin.([1, 2] .* x)` as a length-2
`JLArray`.
"""
function jl_num_to_vec(x::Number)
    coefficients = jl([1, 2])
    return sin.(coefficients .* x)
end

"""
    jl_num_to_mat(x::Number)

JLArray counterpart of `DIT.num_to_mat`: return the 2×2 `JLArray` whose columns are
`jl_num_to_vec(x)` and `jl_num_to_vec(3x)`.
"""
function jl_num_to_mat(x::Number)
    first_column = jl_num_to_vec(x)
    second_column = jl_num_to_vec(3x)
    return hcat(first_column, second_column)
end
# Singleton types of `DIT.num_to_vec` / `DIT.num_to_mat`, so that `myjl` can dispatch on
# these two specific functions rather than on `Function` in general.
const NTV = typeof(DIT.num_to_vec)
const NTM = typeof(DIT.num_to_mat)

"""
    myjl(obj)

Return a version of the scenario ingredient `obj` in which arrays are `JLArray`s.

- `DIT.num_to_vec` and `DIT.num_to_mat` are replaced by `jl_num_to_vec` and
  `jl_num_to_mat`. Every other `Function`, and any `DIT.FunctionModifier`, is returned
  as-is.
- Numbers and `nothing` are returned as-is.
- Arrays are converted with `jl`.
- `Tuple`s and `NamedTuple`s are converted element-wise, keeping their structure.
- `DI.Constant` and `DI.Cache` contexts are unwrapped, converted and re-wrapped in the
  same kind of context.
"""
function myjl end

myjl(f::Function) = f
myjl(::NTV) = jl_num_to_vec
myjl(::NTM) = jl_num_to_mat
myjl(f::DIT.FunctionModifier) = f
myjl(x::Number) = x
myjl(x::AbstractArray) = jl(x)
myjl(x::Tuple) = map(myjl, x)
# Named tuples (e.g. the payload of a `DI.Constant` or `DI.Cache`) are converted
# field-wise, exactly like plain tuples.
myjl(x::NamedTuple) = map(myjl, x)
myjl(x::DI.Constant) = DI.Constant(myjl(DI.unwrap(x)))
myjl(x::DI.Cache{<:AbstractArray}) = DI.Cache(myjl(DI.unwrap(x)))
# Convert the wrapped collection element-wise but keep ONE `Cache` around it. The result
# becomes a single entry of a scenario's `contexts` tuple, so it must remain a context
# object; mapping `DI.Cache` over the elements would return a bare (named) tuple instead.
myjl(x::DI.Cache{<:Union{Tuple,NamedTuple}}) = DI.Cache(map(myjl, DI.unwrap(x)))
myjl(::Nothing) = nothing
"""
    myjl(scen::DIT.Scenario)

Rebuild `scen` with `myjl` applied to its function, input, output, tangents, contexts,
each entry of `prep_args`, and both reference results. The type parameters
`op`, `pl_op`, `pl_fun` and the scenario `name` are carried over unchanged.
"""
function myjl(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
    converted = (
        f=myjl(scen.f),
        x=myjl(scen.x),
        y=myjl(scen.y),
        t=myjl(scen.t),
        contexts=myjl(scen.contexts),
        prep_args=map(myjl, scen.prep_args),
        res1=myjl(scen.res1),
        res2=myjl(scen.res2),
    )
    return DIT.Scenario{op,pl_op,pl_fun}(; converted..., name=scen.name)
end
# JLArray implementation of `DIT.gpu_scenarios`: build the default scenarios with the same
# positional and keyword arguments, then convert each one with `myjl`.
function DIT.gpu_scenarios(args...; kwargs...)
    cpu_scenarios = DIT.default_scenarios(args...; kwargs...)
    return broadcast(myjl, cpu_scenarios)
end
# PrecompileTools workload: build the JLArray scenarios, with the `include_constantified`
# and `include_cachified` options switched on, while this extension is precompiled, so the
# methods used to construct them are compiled ahead of time.
@compile_workload begin
DIT.gpu_scenarios(; include_constantified=true, include_cachified=true)
end
end