diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index e29a979c0..1a1f819dd 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.45" +version = "0.6.46" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 8b1550532..5575cbb3e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -54,7 +54,7 @@ force_annotation(f::F) where {F} = Const(f) end @inline function _translate( - backend::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedCache + backend::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Cache,DI.PrepContext} ) where {B} if B == 1 return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c))) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index 2dd4409cc..44afdd9ce 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -89,7 +89,7 @@ function _translate( end function _translate(::Type{D}, c::DI.Cache) where {D<:Dual} c0 = DI.unwrap(c) - return similar(c0, D) + return DI.recursive_similar(c0, D) end function translate(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C} @@ -106,7 +106,7 @@ function _translate_toprep( end function _translate_toprep(::Type{D}, c::DI.Cache) where {D<:Dual} c0 = DI.unwrap(c) - return similar(c0, D) + return DI.recursive_similar(c0, D) end function translate_toprep(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C} diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl index f5a315c3a..a01a804ef 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl @@ -5,26 +5,23 @@ import DifferentiationInterface as DI using SparseConnectivityTracer: TracerSparsityDetector, TracerLocalSparsityDetector, jacobian_buffer, hessian_buffer -@inline _jacobian_translate(detector, c::DI.Constant) = DI.unwrap(c) -@inline function _jacobian_translate(detector, c::DI.Cache{<:AbstractArray}) - return jacobian_buffer(DI.unwrap(c), detector) +@inline _translate(::Type, c::DI.Constant) = DI.unwrap(c) +@inline function _translate(::Type{T}, c::DI.Cache) where {T} + return DI.recursive_similar(DI.unwrap(c), T) end -function jacobian_translate(detector, contexts::Vararg{DI.Context,C}) where {C} +function jacobian_translate(detector, x, contexts::Vararg{DI.Context,C}) where {C} + T = eltype(jacobian_buffer(x, detector)) new_contexts = map(contexts) do c - _jacobian_translate(detector, c) + _translate(T, c) end return new_contexts end -@inline _hessian_translate(detector, c::DI.Constant) = DI.unwrap(c) -@inline function _hessian_translate(detector, c::DI.Cache{<:AbstractArray}) - return hessian_buffer(DI.unwrap(c), detector) -end - -function hessian_translate(detector, contexts::Vararg{DI.Context,C}) where {C} +function hessian_translate(detector, x, contexts::Vararg{DI.Context,C}) where {C} + T = eltype(hessian_buffer(x, detector)) new_contexts = map(contexts) do c - _hessian_translate(detector, c) + _translate(T, c) end return new_contexts end @@ -35,7 +32,7 @@ function DI.jacobian_sparsity_with_contexts( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - contexts_tracer = jacobian_translate(detector, contexts...) + contexts_tracer = jacobian_translate(detector, x, contexts...) fc = DI.FixTail(f, contexts_tracer...) return jacobian_sparsity(fc, x, detector) end @@ -47,7 +44,7 @@ function DI.jacobian_sparsity_with_contexts( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - contexts_tracer = jacobian_translate(detector, contexts...) + contexts_tracer = jacobian_translate(detector, x, contexts...) fc! = DI.FixTail(f!, contexts_tracer...) return jacobian_sparsity(fc!, y, x, detector) end @@ -58,7 +55,7 @@ function DI.hessian_sparsity_with_contexts( x, contexts::Vararg{DI.Context,C}, ) where {F,C} - contexts_tracer = hessian_translate(detector, contexts...) + contexts_tracer = hessian_translate(detector, x, contexts...) fc = DI.FixTail(f, contexts_tracer...) return hessian_sparsity(fc, x, detector) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index dc400b03e..72763eb6a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -17,7 +17,10 @@ DI.check_available(::AutoZygote) = true DI.inplace_support(::AutoZygote) = DI.InPlaceNotSupported() translate(c::DI.Context) = DI.unwrap(c) -translate(c::DI.Cache) = Buffer(DI.unwrap(c)) +translate(c::DI.Cache{<:AbstractArray}) = Buffer(DI.unwrap(c)) +function translate(c::DI.Cache{<:Union{Tuple,NamedTuple}}) + return map(translate, map(DI.Cache, DI.unwrap(c))) +end ## Pullback diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 65edbde0f..201c0ae2e 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -23,7 +23,6 @@ Abstract supertype for additional context arguments, which can be passed to diff abstract type Context end abstract type GeneralizedConstant <: Context end -abstract type GeneralizedCache <: Context end unwrap(c::Context) = c.data Base.:(==)(c1::Context, c2::Context) = unwrap(c1) == unwrap(c2) @@ -78,7 +77,7 @@ The initial values present inside the cache do not matter. For some backends, preparation allocates the required memory for `Cache` contexts with the right element type, similar to [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl). !!! warning - Most backends require any `Cache` context to be an `AbstractArray`. + Some backends require any `Cache` context to be an `AbstractArray`, others accept nested (named) tuples of `AbstractArray`s. # Example @@ -97,7 +96,7 @@ julia> gradient(f, prep, AutoForwardDiff(), [3.0, 4.0], Cache(zeros(2))) 1.0 ```` """ -struct Cache{T} <: GeneralizedCache +struct Cache{T} <: Context data::T end @@ -114,12 +113,10 @@ struct BackendContext{T} <: GeneralizedConstant data::T end -struct PrepContext{T} <: GeneralizedCache +struct PrepContext{T} <: Context data::T end -struct UnknownContext <: Context end - ## Context manipulation struct Rewrap{C,T} @@ -146,4 +143,4 @@ function with_contexts(f::F, contexts::Vararg{Context,N}) where {F,N} end adapt_eltype(c::Constant, ::Type) = c -adapt_eltype(c::Cache, ::Type{T}) where {T} = Cache(similar(unwrap(c), T)) +adapt_eltype(c::Cache, ::Type{T}) where {T} = Cache(recursive_similar(unwrap(c), T)) diff --git a/DifferentiationInterface/src/utils/linalg.jl b/DifferentiationInterface/src/utils/linalg.jl index fcabdb143..b7dd4a42a 100644 --- a/DifferentiationInterface/src/utils/linalg.jl +++ b/DifferentiationInterface/src/utils/linalg.jl @@ -10,3 +10,15 @@ At the moment, this only returns `false` for `StaticArrays.SArray`. """ ismutable_array(::Type) = true ismutable_array(x) = ismutable_array(typeof(x)) + +""" + recursive_similar(x, T) + +Apply `similar(_, T)` recursively to `x` or its components. + +Works if `x` is an `AbstractArray` or a (nested) `NTuple` / `NamedTuple` of `AbstractArray`s. +""" +recursive_similar(x::AbstractArray, ::Type{T}) where {T} = similar(x, T) +function recursive_similar(x::Union{Tuple,NamedTuple}, ::Type{T}) where {T} + return map(xi -> recursive_similar(xi, T), x) +end diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index 4fbb40fa0..2aa2c3268 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -55,7 +55,7 @@ end; test_differentiation( backends[2], - default_scenarios(; include_normal=false, include_cachified=true); + default_scenarios(; include_normal=false, include_cachified=true, use_tuples=true); excluded=SECOND_ORDER, logging=LOGGING, ) diff --git a/DifferentiationInterface/test/Back/FiniteDiff/test.jl b/DifferentiationInterface/test/Back/FiniteDiff/test.jl index aa92743b6..dc111f45f 100644 --- a/DifferentiationInterface/test/Back/FiniteDiff/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDiff/test.jl @@ -22,7 +22,9 @@ end @testset "Dense" begin test_differentiation( AutoFiniteDiff(), - default_scenarios(; include_constantified=true, include_cachified=true); + default_scenarios(; + include_constantified=true, include_cachified=true, use_tuples=true + ); excluded=[:second_derivative, :hvp], logging=LOGGING, ) diff --git a/DifferentiationInterface/test/Back/FiniteDifferences/test.jl b/DifferentiationInterface/test/Back/FiniteDifferences/test.jl index d512ee8d4..c55ee3bd2 100644 --- a/DifferentiationInterface/test/Back/FiniteDifferences/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDifferences/test.jl @@ -19,7 +19,9 @@ end test_differentiation( AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)), - default_scenarios(; include_constantified=true, include_cachified=true); + default_scenarios(; + include_constantified=true, include_cachified=true, use_tuples=true + ); excluded=SECOND_ORDER, logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index f4024afbd..0b9ff0d50 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -36,7 +36,10 @@ end test_differentiation( AutoForwardDiff(), default_scenarios(; - include_normal=false, include_batchified=false, include_cachified=true + include_normal=false, + include_batchified=false, + include_cachified=true, + use_tuples=true, ); logging=LOGGING, ) diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 1bc485cdf..8c9ab839a 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -19,7 +19,9 @@ end test_differentiation( backends, - default_scenarios(; include_constantified=true, include_cachified=true); + default_scenarios(; + include_constantified=true, include_cachified=true, use_tuples=true + ); excluded=SECOND_ORDER, logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl index 4f38af5b1..34b59d46a 100644 --- a/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/PolyesterForwardDiff/test.jl @@ -28,7 +28,9 @@ end test_differentiation( backends, - default_scenarios(; include_constantified=true, include_cachified=true); + default_scenarios(; + include_constantified=true, include_cachified=true, use_tuples=true + ); logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl b/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl index c8efbdaa8..db6d2215b 100644 --- a/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl +++ b/DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl @@ -17,7 +17,9 @@ end test_differentiation( AutoFastDifferentiation(), - default_scenarios(; include_constantified=true, include_cachified=true); + default_scenarios(; + include_constantified=true, include_cachified=true, use_tuples=false + ); logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl b/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl index 31f8316a0..91625b700 100644 --- a/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl +++ b/DifferentiationInterface/test/Back/SymbolicBackends/symbolics.jl @@ -21,7 +21,7 @@ test_differentiation( test_differentiation( AutoSymbolics(), - default_scenarios(; include_normal=false, include_cachified=true); + default_scenarios(; include_normal=false, include_cachified=true, use_tuples=false); excluded=[:jacobian], # TODO: figure out why this fails logging=LOGGING, ); diff --git a/DifferentiationInterface/test/Back/Zygote/test.jl b/DifferentiationInterface/test/Back/Zygote/test.jl index 6b30e924b..882777e20 100644 --- a/DifferentiationInterface/test/Back/Zygote/test.jl +++ b/DifferentiationInterface/test/Back/Zygote/test.jl @@ -27,7 +27,9 @@ end @testset "Dense" begin test_differentiation( backends, - default_scenarios(; include_constantified=true, include_cachified=true); + default_scenarios(; + include_constantified=true, include_cachified=true, use_tuples=true + ); excluded=[:second_derivative], logging=LOGGING, ) diff --git a/DifferentiationInterface/test/Core/Internals/linalg.jl b/DifferentiationInterface/test/Core/Internals/linalg.jl new file mode 100644 index 000000000..03798da87 --- /dev/null +++ b/DifferentiationInterface/test/Core/Internals/linalg.jl @@ -0,0 +1,9 @@ +using DifferentiationInterface: recursive_similar +using Test + +@test recursive_similar(ones(Int, 2), Float32) isa Vector{Float32} +@test recursive_similar((ones(Int, 2), ones(Bool, 3, 4)), Float32) isa + Tuple{Vector{Float32},Matrix{Float32}} +@test recursive_similar((a=ones(Int, 2), b=(ones(Bool, 3, 4),)), Float32) isa + @NamedTuple{a::Vector{Float32}, b::Tuple{Matrix{Float32}}} +@test_throws MethodError recursive_similar(1, Float32) diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index d716f136b..34d93c16a 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -86,7 +86,9 @@ end MyAutoSparse.( vcat(adaptive_backends, MixedMode(adaptive_backends[1], adaptive_backends[2])) ), - sparse_scenarios(; include_constantified=true, include_cachified=true); + sparse_scenarios(; + include_constantified=true, include_cachified=true, use_tuples=true + ); sparsity=true, logging=LOGGING, ) diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl index 305bee48b..11dca2cc5 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestJLArraysExt/DifferentiationInterfaceTestJLArraysExt.jl @@ -18,7 +18,8 @@ myjl(x::Number) = x myjl(x::AbstractArray) = jl(x) myjl(x::Tuple) = map(myjl, x) myjl(x::DI.Constant) = DI.Constant(myjl(DI.unwrap(x))) -myjl(x::DI.Cache) = DI.Cache(myjl(DI.unwrap(x))) +myjl(x::DI.Cache{<:AbstractArray}) = DI.Cache(myjl(DI.unwrap(x))) +myjl(x::DI.Cache{<:Union{Tuple,NamedTuple}}) = map(myjl, map(DI.Cache, DI.unwrap(x))) myjl(::Nothing) = nothing function myjl(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} diff --git a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl index c52849ad7..a7620b516 100644 --- a/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl +++ b/DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestStaticArraysExt/DifferentiationInterfaceTestStaticArraysExt.jl @@ -29,7 +29,10 @@ end mystatic(x::Tuple) = map(mystatic, x) mystatic(x::DI.Constant) = DI.Constant(mystatic(DI.unwrap(x))) -mystatic(x::DI.Cache) = DI.Cache(mymutablestatic(DI.unwrap(x))) +mystatic(x::DI.Cache{<:AbstractArray}) = DI.Cache(mymutablestatic(DI.unwrap(x))) +function mystatic(x::DI.Cache{<:Union{Tuple,NamedTuple}}) + return map(mystatic, map(DI.Cache, DI.unwrap(x))) +end mystatic(::Nothing) = nothing function mystatic(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} diff --git a/DifferentiationInterfaceTest/src/scenarios/default.jl b/DifferentiationInterfaceTest/src/scenarios/default.jl index ef9767df4..7a44f7390 100644 --- a/DifferentiationInterfaceTest/src/scenarios/default.jl +++ b/DifferentiationInterfaceTest/src/scenarios/default.jl @@ -559,6 +559,7 @@ function default_scenarios(; include_closurified=false, include_constantified=false, include_cachified=false, + use_tuples=false, ) x_ = 0.42 dx_ = 3.14 @@ -635,7 +636,7 @@ function default_scenarios(; include_normal && append!(final_scens, scens) include_closurified && append!(final_scens, closurify(scens)) include_constantified && append!(final_scens, constantify(scens)) - include_cachified && append!(final_scens, cachify(scens)) + include_cachified && append!(final_scens, cachify(scens; use_tuples=use_tuples)) return final_scens end diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index 6f306dc45..e991885f8 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -147,7 +147,7 @@ end """ constantify(scen::Scenario) -Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional constant argument `a` by which the output is multiplied. +Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional constant argument by which the output is multiplied. The output and result fields are updated accordingly. """ function constantify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} @@ -178,6 +178,11 @@ end Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))") +(sc::StoreInCache{:out})(x, y_cache::NamedTuple) = sc(x, y_cache.useful_cache) +(sc::StoreInCache{:in})(y, x, y_cache::NamedTuple) = sc(y, x, y_cache.useful_cache) +(sc::StoreInCache{:out})(x, y_cache::Tuple) = sc(x, first(y_cache)) +(sc::StoreInCache{:in})(y, x, y_cache::Tuple) = sc(y, x, first(y_cache)) + function (sc::StoreInCache{:out})(x, y_cache) y = sc.f(x) if y isa Number @@ -198,16 +203,26 @@ end """ cachify(scen::Scenario) -Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional cache argument `a` to store the result before it is returned. +Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional cache argument to store the result before it is returned. + +If `tup=true` the cache is a tuple of arrays, otherwise just an array. """ -function cachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} +function cachify(scen::Scenario{op,pl_op,pl_fun}; use_tuples) where {op,pl_op,pl_fun} (; f,) = scen @assert isempty(scen.contexts) cache_f = StoreInCache{pl_fun}(f) - y_cache = if scen.y isa Number - [myzero(scen.y)] + if use_tuples + y_cache = if scen.y isa Number + (; useful_cache=([myzero(scen.y)],), useless_cache=[myzero(scen.y)]) + else + (; useful_cache=(mysimilar(scen.y),), useless_cache=mysimilar(scen.y)) + end else - mysimilar(scen.y) + y_cache = if scen.y isa Number + [myzero(scen.y)] + else + mysimilar(scen.y) + end end return Scenario{op,pl_op,pl_fun}( cache_f; @@ -217,7 +232,7 @@ function cachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} contexts=(Cache(y_cache),), res1=scen.res1, res2=scen.res2, - smaller=isnothing(scen.smaller) ? nothing : cachify(scen.smaller), + smaller=isnothing(scen.smaller) ? nothing : cachify(scen.smaller; use_tuples), name=isnothing(scen.name) ? nothing : scen.name * " [cachified]", ) end @@ -229,7 +244,7 @@ end closurify(scens::AbstractVector{<:Scenario}) = closurify.(scens) constantify(scens::AbstractVector{<:Scenario}) = constantify.(scens) -cachify(scens::AbstractVector{<:Scenario}) = cachify.(scens) +cachify(scens::AbstractVector{<:Scenario}; use_tuples) = cachify.(scens; use_tuples) function set_smaller( scen::Scenario{op,pl_op,pl_fun}, smaller::Scenario diff --git a/DifferentiationInterfaceTest/src/scenarios/sparse.jl b/DifferentiationInterfaceTest/src/scenarios/sparse.jl index 5f1cad7fa..99ec02140 100644 --- a/DifferentiationInterfaceTest/src/scenarios/sparse.jl +++ b/DifferentiationInterfaceTest/src/scenarios/sparse.jl @@ -325,7 +325,10 @@ end Create a vector of [`Scenario`](@ref)s with sparse array types, focused on sparse Jacobians and Hessians. """ function sparse_scenarios(; - band_sizes=[5, 10, 20], include_constantified=false, include_cachified=false + band_sizes=[5, 10, 20], + include_constantified=false, + include_cachified=false, + use_tuples=false, ) x_6 = float.(1:6) x_2_3 = float.(reshape(1:6, 2, 3)) @@ -347,6 +350,6 @@ function sparse_scenarios(; final_scens = Scenario[] append!(final_scens, scens) include_constantified && append!(final_scens, constantify(scens)) - include_cachified && append!(final_scens, cachify(scens)) + include_cachified && append!(final_scens, cachify(scens; use_tuples)) return final_scens end diff --git a/DifferentiationInterfaceTest/test/standard.jl b/DifferentiationInterfaceTest/test/standard.jl index f06f3cd75..5324dd580 100644 --- a/DifferentiationInterfaceTest/test/standard.jl +++ b/DifferentiationInterfaceTest/test/standard.jl @@ -33,7 +33,7 @@ sparse_backend = AutoSparse( test_differentiation( sparse_backend, - sparse_scenarios(; include_constantified=true); + sparse_scenarios(; include_cachified=true, use_tuples=true); sparsity=true, logging=LOGGING, ) diff --git a/DifferentiationInterfaceTest/test/weird.jl b/DifferentiationInterfaceTest/test/weird.jl index b2b27a8df..f5e88e1ff 100644 --- a/DifferentiationInterfaceTest/test/weird.jl +++ b/DifferentiationInterfaceTest/test/weird.jl @@ -28,12 +28,14 @@ gpu_scenarios(; include_closurified=true, include_batchified=true, include_cachified=true, + use_tuples=true, ) static_scenarios(; include_constantified=true, include_closurified=true, include_batchified=true, include_cachified=true, + use_tuples=true, ) ## Weird arrays @@ -54,7 +56,10 @@ test_differentiation(AutoZygote(), gpu_scenarios(); excluded=SECOND_ORDER, loggi test_differentiation( AutoFiniteDiff(), default_scenarios(; - include_normal=false, include_closurified=true, include_cachified=true + include_normal=false, + include_closurified=true, + include_cachified=true, + use_tuples=true, ); excluded=SECOND_ORDER, logging=LOGGING,