Skip to content

Commit 3d3f551

Browse files
committed
Nesting in test scens
1 parent 0e5b11e commit 3d3f551

3 files changed

Lines changed: 7 additions & 3 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ DI.inplace_support(::AutoZygote) = DI.InPlaceNotSupported()
1818

1919
translate(c::DI.Context) = DI.unwrap(c)
2020
translate(c::DI.Cache{<:AbstractArray}) = Buffer(DI.unwrap(c))
21-
function translate(c::DI.Cache{<:Union{NTuple,NamedTuple}})
21+
function translate(c::DI.Cache{<:Union{Tuple,NamedTuple}})
2222
return map(translate, map(DI.Cache, DI.unwrap(c)))
2323
end
2424

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ myjl(x::Number) = x
1818
myjl(x::AbstractArray) = jl(x)
1919
myjl(x::Tuple) = map(myjl, x)
2020
myjl(x::DI.Constant) = DI.Constant(myjl(DI.unwrap(x)))
21-
myjl(x::DI.Cache) = DI.Cache(myjl(DI.unwrap(x)))
21+
myjl(x::DI.Cache{<:AbstractArray}) = DI.Cache(myjl(DI.unwrap(x)))
22+
myjl(x::DI.Cache{<:Union{Tuple,NamedTuple}}) = map(myjl, map(DI.Cache, DI.unwrap(x)))
2223
myjl(::Nothing) = nothing
2324

2425
function myjl(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ end
2929

3030
mystatic(x::Tuple) = map(mystatic, x)
3131
mystatic(x::DI.Constant) = DI.Constant(mystatic(DI.unwrap(x)))
32-
mystatic(x::DI.Cache) = DI.Cache(mymutablestatic(DI.unwrap(x)))
32+
mystatic(x::DI.Cache{<:AbstractArray}) = DI.Cache(mymutablestatic(DI.unwrap(x)))
33+
function mystatic(x::DI.Cache{<:Union{Tuple,NamedTuple}})
34+
return map(mystatic, map(DI.Cache, DI.unwrap(x)))
35+
end
3336
mystatic(::Nothing) = nothing
3437

3538
function mystatic(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}

0 commit comments

Comments
 (0)