Skip to content
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.45"
version = "0.6.46"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ force_annotation(f::F) where {F} = Const(f)
end

@inline function _translate(
backend::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedCache
backend::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Cache,DI.PrepContext}
) where {B}
if B == 1
return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ function _translate(
end
# Allocate a Dual-typed workspace mirroring the user-supplied cache, so the
# operator can write Dual numbers without clobbering the original buffer.
# The stale `similar`-based return line (superseded by `recursive_similar`,
# which also supports (named) tuples of arrays) is removed.
function _translate(::Type{D}, c::DI.Cache) where {D<:Dual}
    c0 = DI.unwrap(c)
    return DI.recursive_similar(c0, D)
end

function translate(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C}
Expand All @@ -106,7 +106,7 @@ function _translate_toprep(
end
# Preparation-time counterpart of `_translate`: build a Dual-typed cache of
# the same (possibly nested) structure as the user's cache.
# The stale `similar`-based return line is removed in favor of
# `recursive_similar`, which also handles (named) tuples of arrays.
function _translate_toprep(::Type{D}, c::DI.Cache) where {D<:Dual}
    c0 = DI.unwrap(c)
    return DI.recursive_similar(c0, D)
end

function translate_toprep(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,23 @@ import DifferentiationInterface as DI
using SparseConnectivityTracer:
TracerSparsityDetector, TracerLocalSparsityDetector, jacobian_buffer, hessian_buffer

# Translate DI contexts for sparsity detection: constants pass through
# unchanged, while caches are re-allocated with tracer element type `T`
# (works for arrays as well as (named) tuples of arrays).
# The obsolete `_jacobian_translate` remnants that were interleaved here are
# removed; both Jacobian and Hessian paths now share `_translate`.
@inline _translate(::Type, c::DI.Constant) = DI.unwrap(c)
@inline function _translate(::Type{T}, c::DI.Cache) where {T}
    return DI.recursive_similar(DI.unwrap(c), T)
end

# Rebuild every context with the tracer element type that `detector` would use
# for the primal input `x` (taken from the eltype of its Jacobian buffer).
# The stale keyword-free signature and old `_jacobian_translate` call that were
# interleaved by the diff are removed.
function jacobian_translate(detector, x, contexts::Vararg{DI.Context,C}) where {C}
    T = eltype(jacobian_buffer(x, detector))
    new_contexts = map(contexts) do c
        _translate(T, c)
    end
    return new_contexts
end

# Rebuild every context with the tracer element type that `detector` would use
# for a Hessian of the primal input `x`.
# The superseded `_hessian_translate` helper methods (replaced by the shared
# `_translate`) and the old signature without `x` are removed.
function hessian_translate(detector, x, contexts::Vararg{DI.Context,C}) where {C}
    T = eltype(hessian_buffer(x, detector))
    new_contexts = map(contexts) do c
        _translate(T, c)
    end
    return new_contexts
end
Expand All @@ -35,7 +32,7 @@ function DI.jacobian_sparsity_with_contexts(
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
contexts_tracer = jacobian_translate(detector, contexts...)
contexts_tracer = jacobian_translate(detector, x, contexts...)
fc = DI.FixTail(f, contexts_tracer...)
return jacobian_sparsity(fc, x, detector)
end
Expand All @@ -47,7 +44,7 @@ function DI.jacobian_sparsity_with_contexts(
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
contexts_tracer = jacobian_translate(detector, contexts...)
contexts_tracer = jacobian_translate(detector, x, contexts...)
fc! = DI.FixTail(f!, contexts_tracer...)
return jacobian_sparsity(fc!, y, x, detector)
end
Expand All @@ -58,7 +55,7 @@ function DI.hessian_sparsity_with_contexts(
x,
contexts::Vararg{DI.Context,C},
) where {F,C}
contexts_tracer = hessian_translate(detector, contexts...)
contexts_tracer = hessian_translate(detector, x, contexts...)
fc = DI.FixTail(f, contexts_tracer...)
return hessian_sparsity(fc, x, detector)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ DI.check_available(::AutoZygote) = true
DI.inplace_support(::AutoZygote) = DI.InPlaceNotSupported()

# Translate DI contexts for Zygote: generic contexts are unwrapped, array
# caches become `Zygote.Buffer`s (mutable inside differentiated code).
# The stale unconstrained `translate(c::DI.Cache)` method is removed: it would
# apply `Buffer` to non-array cache contents and is superseded by the two
# constrained methods below.
translate(c::DI.Context) = DI.unwrap(c)
translate(c::DI.Cache{<:AbstractArray}) = Buffer(DI.unwrap(c))
# A (named) tuple cache is translated elementwise into a tuple of Buffers.
function translate(c::DI.Cache{<:Union{Tuple,NamedTuple}})
    return map(translate, map(DI.Cache, DI.unwrap(c)))
end

## Pullback

Expand Down
11 changes: 4 additions & 7 deletions DifferentiationInterface/src/utils/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ Abstract supertype for additional context arguments, which can be passed to diff
abstract type Context end

abstract type GeneralizedConstant <: Context end
abstract type GeneralizedCache <: Context end

unwrap(c::Context) = c.data
Base.:(==)(c1::Context, c2::Context) = unwrap(c1) == unwrap(c2)
Expand Down Expand Up @@ -78,7 +77,7 @@ The initial values present inside the cache do not matter.
For some backends, preparation allocates the required memory for `Cache` contexts with the right element type, similar to [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl).

!!! warning
    Some backends require any `Cache` context to be an `AbstractArray`, others accept nested (named) tuples of `AbstractArray`s.

# Example

Expand All @@ -97,7 +96,7 @@ julia> gradient(f, prep, AutoForwardDiff(), [3.0, 4.0], Cache(zeros(2)))
1.0
````
"""
# Duplicate struct header removed: the `GeneralizedCache` abstract supertype
# was deleted elsewhere in this changeset, so `Cache` subtypes `Context`.
struct Cache{T} <: Context
    data::T
end

Expand All @@ -114,12 +113,10 @@ struct BackendContext{T} <: GeneralizedConstant
data::T
end

# Duplicate struct header removed: with `GeneralizedCache` deleted,
# `PrepContext` subtypes `Context` directly.
# NOTE(review): wraps preparation data passed as a context — confirm role
# against callers outside this view.
struct PrepContext{T} <: Context
    data::T
end

struct UnknownContext <: Context end

## Context manipulation

struct Rewrap{C,T}
Expand All @@ -146,4 +143,4 @@ function with_contexts(f::F, contexts::Vararg{Context,N}) where {F,N}
end

# Constants are element-type agnostic; caches are re-allocated with eltype `T`,
# recursing through (named) tuples of arrays. The stale `similar`-based method
# line (superseded by `recursive_similar`) is removed.
adapt_eltype(c::Constant, ::Type) = c
adapt_eltype(c::Cache, ::Type{T}) where {T} = Cache(recursive_similar(unwrap(c), T))
12 changes: 12 additions & 0 deletions DifferentiationInterface/src/utils/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,15 @@ At the moment, this only returns `false` for `StaticArrays.SArray`.
"""
ismutable_array(::Type) = true
ismutable_array(x) = ismutable_array(typeof(x))

"""
recursive_similar(x, T)

Apply `similar(_, T)` recursively to `x` or its components.

Works if `x` is an `AbstractArray` or a (nested) `NTuple` / `NamedTuple` of `AbstractArray`s.
"""
# Leaf case: a single array is rebuilt with the requested element type.
recursive_similar(x::AbstractArray, ::Type{T}) where {T} = similar(x, T)
# Container case: descend into (named) tuples, preserving their structure;
# `Base.Fix2` partially applies the element type to the recursive call.
function recursive_similar(x::Union{Tuple,NamedTuple}, ::Type{T}) where {T}
    return map(Base.Fix2(recursive_similar, T), x)
end
2 changes: 1 addition & 1 deletion DifferentiationInterface/test/Back/Enzyme/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ end;

test_differentiation(
backends[2],
default_scenarios(; include_normal=false, include_cachified=true);
default_scenarios(; include_normal=false, include_cachified=true, use_tuples=true);
excluded=SECOND_ORDER,
logging=LOGGING,
)
Expand Down
4 changes: 3 additions & 1 deletion DifferentiationInterface/test/Back/FiniteDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ end
@testset "Dense" begin
test_differentiation(
AutoFiniteDiff(),
default_scenarios(; include_constantified=true, include_cachified=true);
default_scenarios(;
include_constantified=true, include_cachified=true, use_tuples=true
);
excluded=[:second_derivative, :hvp],
logging=LOGGING,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ end

test_differentiation(
AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)),
default_scenarios(; include_constantified=true, include_cachified=true);
default_scenarios(;
include_constantified=true, include_cachified=true, use_tuples=true
);
excluded=SECOND_ORDER,
logging=LOGGING,
);
5 changes: 4 additions & 1 deletion DifferentiationInterface/test/Back/ForwardDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ end
test_differentiation(
AutoForwardDiff(),
default_scenarios(;
include_normal=false, include_batchified=false, include_cachified=true
include_normal=false,
include_batchified=false,
include_cachified=true,
use_tuples=true,
);
logging=LOGGING,
)
Expand Down
4 changes: 3 additions & 1 deletion DifferentiationInterface/test/Back/Mooncake/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ end

test_differentiation(
backends,
default_scenarios(; include_constantified=true, include_cachified=true);
default_scenarios(;
include_constantified=true, include_cachified=true, use_tuples=true
);
excluded=SECOND_ORDER,
logging=LOGGING,
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@ end

test_differentiation(
backends,
default_scenarios(; include_constantified=true, include_cachified=true);
default_scenarios(;
include_constantified=true, include_cachified=true, use_tuples=true
);
logging=LOGGING,
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ end

test_differentiation(
AutoFastDifferentiation(),
default_scenarios(; include_constantified=true, include_cachified=true);
default_scenarios(;
include_constantified=true, include_cachified=true, use_tuples=false
);
logging=LOGGING,
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ test_differentiation(

test_differentiation(
AutoSymbolics(),
default_scenarios(; include_normal=false, include_cachified=true);
default_scenarios(; include_normal=false, include_cachified=true, use_tuples=false);
excluded=[:jacobian], # TODO: figure out why this fails
logging=LOGGING,
);
Expand Down
4 changes: 3 additions & 1 deletion DifferentiationInterface/test/Back/Zygote/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ end
@testset "Dense" begin
test_differentiation(
backends,
default_scenarios(; include_constantified=true, include_cachified=true);
default_scenarios(;
include_constantified=true, include_cachified=true, use_tuples=true
);
excluded=[:second_derivative],
logging=LOGGING,
)
Expand Down
9 changes: 9 additions & 0 deletions DifferentiationInterface/test/Core/Internals/linalg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
using DifferentiationInterface: recursive_similar
using Test

# Leaf case: a plain array gets the requested element type.
@test recursive_similar(ones(Int, 2), Float32) isa Vector{Float32}
# Tuples of arrays are converted elementwise, preserving structure.
@test recursive_similar((ones(Int, 2), ones(Bool, 3, 4)), Float32) isa
Tuple{Vector{Float32},Matrix{Float32}}
# Nested named tuples recurse all the way down to the array leaves.
@test recursive_similar((a=ones(Int, 2), b=(ones(Bool, 3, 4),)), Float32) isa
@NamedTuple{a::Vector{Float32}, b::Tuple{Matrix{Float32}}}
# Scalars are unsupported on purpose: only arrays and (named) tuples qualify.
@test_throws MethodError recursive_similar(1, Float32)
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ end
MyAutoSparse.(
vcat(adaptive_backends, MixedMode(adaptive_backends[1], adaptive_backends[2]))
),
sparse_scenarios(; include_constantified=true, include_cachified=true);
sparse_scenarios(;
include_constantified=true, include_cachified=true, use_tuples=true
);
sparsity=true,
logging=LOGGING,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ myjl(x::Number) = x
myjl(x::AbstractArray) = jl(x)
myjl(x::Tuple) = map(myjl, x)
myjl(x::DI.Constant) = DI.Constant(myjl(DI.unwrap(x)))
# The stale unconstrained `myjl(x::DI.Cache)` method is removed; it is
# superseded by the two constrained methods below.
# A cache wrapping a single array is converted and re-wrapped in a Cache.
myjl(x::DI.Cache{<:AbstractArray}) = DI.Cache(myjl(DI.unwrap(x)))
# A (named) tuple cache is split into one converted Cache per component.
myjl(x::DI.Cache{<:Union{Tuple,NamedTuple}}) = map(myjl, map(DI.Cache, DI.unwrap(x)))
myjl(::Nothing) = nothing

function myjl(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,10 @@ end

mystatic(x::Tuple) = map(mystatic, x)
mystatic(x::DI.Constant) = DI.Constant(mystatic(DI.unwrap(x)))
# The stale unconstrained `mystatic(x::DI.Cache)` method is removed; it is
# superseded by the two constrained methods below.
# Array caches must stay writable, hence `mymutablestatic` rather than
# an immutable static conversion.
mystatic(x::DI.Cache{<:AbstractArray}) = DI.Cache(mymutablestatic(DI.unwrap(x)))
# A (named) tuple cache is split into one converted Cache per component.
function mystatic(x::DI.Cache{<:Union{Tuple,NamedTuple}})
    return map(mystatic, map(DI.Cache, DI.unwrap(x)))
end
mystatic(::Nothing) = nothing

function mystatic(scen::DIT.Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
Expand Down
3 changes: 2 additions & 1 deletion DifferentiationInterfaceTest/src/scenarios/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -559,6 +559,7 @@ function default_scenarios(;
include_closurified=false,
include_constantified=false,
include_cachified=false,
use_tuples=false,
)
x_ = 0.42
dx_ = 3.14
Expand Down Expand Up @@ -635,7 +636,7 @@ function default_scenarios(;
include_normal && append!(final_scens, scens)
include_closurified && append!(final_scens, closurify(scens))
include_constantified && append!(final_scens, constantify(scens))
include_cachified && append!(final_scens, cachify(scens))
include_cachified && append!(final_scens, cachify(scens; use_tuples=use_tuples))

return final_scens
end
31 changes: 23 additions & 8 deletions DifferentiationInterfaceTest/src/scenarios/modify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ end
"""
constantify(scen::Scenario)

Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional constant argument by which the output is multiplied.
The output and result fields are updated accordingly.
"""
function constantify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
Expand Down Expand Up @@ -178,6 +178,11 @@ end

Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))")

# Named-tuple caches carry a deliberately unused buffer alongside the real one
# (see the cachified scenarios); forward only the `useful_cache` component.
(sc::StoreInCache{:out})(x, y_cache::NamedTuple) = sc(x, y_cache.useful_cache)
(sc::StoreInCache{:in})(y, x, y_cache::NamedTuple) = sc(y, x, y_cache.useful_cache)
# For a plain tuple cache the working buffer is the first element.
(sc::StoreInCache{:out})(x, y_cache::Tuple) = sc(x, first(y_cache))
(sc::StoreInCache{:in})(y, x, y_cache::Tuple) = sc(y, x, first(y_cache))

function (sc::StoreInCache{:out})(x, y_cache)
y = sc.f(x)
if y isa Number
Expand All @@ -198,16 +203,26 @@ end
"""
cachify(scen::Scenario)

Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional cache argument to store the result before it is returned.

If `use_tuples=true` the cache is a (named) tuple of arrays, otherwise just an array.
"""
function cachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
function cachify(scen::Scenario{op,pl_op,pl_fun}; use_tuples) where {op,pl_op,pl_fun}
(; f,) = scen
@assert isempty(scen.contexts)
cache_f = StoreInCache{pl_fun}(f)
y_cache = if scen.y isa Number
[myzero(scen.y)]
if use_tuples
y_cache = if scen.y isa Number
(; useful_cache=([myzero(scen.y)],), useless_cache=[myzero(scen.y)])
else
(; useful_cache=(mysimilar(scen.y),), useless_cache=mysimilar(scen.y))
end
else
mysimilar(scen.y)
y_cache = if scen.y isa Number
[myzero(scen.y)]
else
mysimilar(scen.y)
end
end
return Scenario{op,pl_op,pl_fun}(
cache_f;
Expand All @@ -217,7 +232,7 @@ function cachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun}
contexts=(Cache(y_cache),),
res1=scen.res1,
res2=scen.res2,
smaller=isnothing(scen.smaller) ? nothing : cachify(scen.smaller),
smaller=isnothing(scen.smaller) ? nothing : cachify(scen.smaller; use_tuples),
name=isnothing(scen.name) ? nothing : scen.name * " [cachified]",
)
end
Expand All @@ -229,7 +244,7 @@ end

closurify(scens::AbstractVector{<:Scenario}) = closurify.(scens)
constantify(scens::AbstractVector{<:Scenario}) = constantify.(scens)
# Vectorized variant: cachify each scenario, forwarding the `use_tuples` flag.
# The stale keyword-free method line is removed — it duplicates this one and
# no longer matches the scalar `cachify(scen; use_tuples)` signature.
cachify(scens::AbstractVector{<:Scenario}; use_tuples) = cachify.(scens; use_tuples)

function set_smaller(
scen::Scenario{op,pl_op,pl_fun}, smaller::Scenario
Expand Down
Loading