Skip to content
21 changes: 19 additions & 2 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
DifferentiationInterfaceGTPSAExt = "GTPSA"
DifferentiationInterfaceMooncakeExt = "Mooncake"
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"]
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer"
Expand Down Expand Up @@ -109,4 +109,21 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "ExplicitImports", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test"]
test = [
"ADTypes",
"Aqua",
"ComponentArrays",
"DataFrames",
"ExplicitImports",
"JET",
"JLArrays",
"JuliaFormatter",
"Pkg",
"Random",
"SparseArrays",
"SparseConnectivityTracer",
"SparseMatrixColorings",
"StableRNGs",
"StaticArrays",
"Test",
]
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,7 @@ struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
end

function DI.prepare_pullback(
f,
::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.GeneralizedConstant,C}
) where {C}
return DI.NoPullbackPrep()
end
Expand All @@ -21,7 +17,7 @@ function DI.prepare_pullback_same_point(
backend::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
contexts::Vararg{DI.GeneralizedConstant,C},
) where {C}
rc = ruleconfig(backend)
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
Expand All @@ -34,7 +30,7 @@ function DI.value_and_pullback(
backend::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
contexts::Vararg{DI.GeneralizedConstant,C},
) where {C}
rc = ruleconfig(backend)
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
Expand All @@ -50,7 +46,7 @@ function DI.value_and_pullback(
::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
contexts::Vararg{DI.GeneralizedConstant,C},
) where {C}
(; y, pb) = prep
tx = map(ty) do dy
Expand All @@ -65,7 +61,7 @@ function DI.pullback(
::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
contexts::Vararg{DI.GeneralizedConstant,C},
) where {C}
(; pb) = prep
tx = map(ty) do dy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ force_annotation(f::F) where {F<:Annotation} = f
force_annotation(f::F) where {F} = Const(f)

@inline function _translate(
::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Constant,DI.BackendContext}
::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedConstant
) where {B}
return Const(DI.unwrap(c))
end

@inline function _translate(
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache
::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedCache
) where {B}
if B == 1
return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ using FiniteDiff:
using LinearAlgebra: dot, mul!

DI.check_available(::AutoFiniteDiff) = true
DI.inner_preparation_behavior(::AutoFiniteDiff) = DI.PrepareInnerSimple()

# see https://github.com/SciML/ADTypes.jl/issues/33

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using LinearAlgebra: dot

DI.check_available(::AutoFiniteDifferences) = true
DI.inplace_support(::AutoFiniteDifferences) = DI.InPlaceNotSupported()
DI.inner_preparation_behavior(::AutoFiniteDifferences) = DI.PrepareInnerSimple()

## Pushforward

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ using ForwardDiff:
value

DI.check_available(::AutoForwardDiff) = true
DI.inner_preparation_behavior(::AutoForwardDiff) = DI.PrepareInnerOverload()

include("utils.jl")
include("onearg.jl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,22 @@
DI.overloaded_input_type(prep::ForwardDiffOneArgPushforwardPrep) = typeof(prep.xdual_tmp)
DI.overloaded_input_type(prep::ForwardDiffTwoArgPushforwardPrep) = typeof(prep.xdual_tmp)

function DI.overloaded_input(
::typeof(DI.pushforward), f::F, backend::AutoForwardDiff, x, tx::NTuple{B}
) where {F,B}
T = tag_type(f, backend, x)
xdual = make_dual(T, x, tx)
return xdual
end

function DI.overloaded_input(
::typeof(DI.pushforward), f!::F, y, backend::AutoForwardDiff, x, tx::NTuple{B}
) where {F,B}
T = tag_type(f!, backend, x)
xdual = make_dual(T, x, tx)
return xdual
end

## Derivative
function DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep)
return DI.overloaded_input_type(prep.pushforward_prep)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ function DI.value_and_gradient!(
if (
isnothing(chunksize) &&
T === Nothing &&
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
contexts isa NTuple{C,DI.GeneralizedConstant}
)
fc = DI.with_contexts(f, contexts...)
result = DiffResult(zero(eltype(x)), (grad,))
Expand All @@ -292,7 +292,7 @@ function DI.value_and_gradient(
if (
isnothing(chunksize) &&
T === Nothing &&
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
contexts isa NTuple{C,DI.GeneralizedConstant}
)
fc = DI.with_contexts(f, contexts...)
result = GradientResult(x)
Expand All @@ -310,7 +310,7 @@ function DI.gradient!(
if (
isnothing(chunksize) &&
T === Nothing &&
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
contexts isa NTuple{C,DI.GeneralizedConstant}
)
fc = DI.with_contexts(f, contexts...)
return gradient!(grad, fc, x)
Expand All @@ -326,7 +326,7 @@ function DI.gradient(
if (
isnothing(chunksize) &&
T === Nothing &&
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
contexts isa NTuple{C,DI.GeneralizedConstant}
)
fc = DI.with_contexts(f, contexts...)
return gradient(fc, x)
Expand Down Expand Up @@ -435,7 +435,7 @@ function DI.value_and_jacobian!(
if (
isnothing(chunksize) &&
T === Nothing &&
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
contexts isa NTuple{C,DI.GeneralizedConstant}
)
fc = DI.with_contexts(f, contexts...)
y = fc(x)
Expand All @@ -456,7 +456,7 @@ function DI.value_and_jacobian(
if (
isnothing(chunksize) &&
T === Nothing &&
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
contexts isa NTuple{C,DI.GeneralizedConstant}
)
fc = DI.with_contexts(f, contexts...)
return fc(x), jacobian(fc, x)
Expand All @@ -472,7 +472,7 @@ function DI.jacobian!(
if (
isnothing(chunksize) &&
T === Nothing &&
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
contexts isa NTuple{C,DI.GeneralizedConstant}
)
fc = DI.with_contexts(f, contexts...)
return jacobian!(jac, fc, x)
Expand All @@ -488,7 +488,7 @@ function DI.jacobian(
if (
isnothing(chunksize) &&
T === Nothing &&
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
contexts isa NTuple{C,DI.GeneralizedConstant}
)
fc = DI.with_contexts(f, contexts...)
return jacobian(fc, x)
Expand Down Expand Up @@ -738,7 +738,7 @@ function DI.hessian!(
if (
isnothing(chunksize) &&
T === Nothing &&
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
contexts isa NTuple{C,DI.GeneralizedConstant}
)
fc = DI.with_contexts(f, contexts...)
return hessian!(hess, fc, x)
Expand All @@ -754,7 +754,7 @@ function DI.hessian(
if (
isnothing(chunksize) &&
T === Nothing &&
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
contexts isa NTuple{C,DI.GeneralizedConstant}
)
fc = DI.with_contexts(f, contexts...)
return hessian(fc, x)
Expand All @@ -775,7 +775,7 @@ function DI.value_gradient_and_hessian!(
if (
isnothing(chunksize) &&
T === Nothing &&
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
contexts isa NTuple{C,DI.GeneralizedConstant}
)
fc = DI.with_contexts(f, contexts...)
result = DiffResult(one(eltype(x)), (grad, hess))
Expand All @@ -796,7 +796,7 @@ function DI.value_gradient_and_hessian(
if (
isnothing(chunksize) &&
T === Nothing &&
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
contexts isa NTuple{C,DI.GeneralizedConstant}
)
fc = DI.with_contexts(f, contexts...)
result = HessianResult(x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ end
function DI.value_and_derivative(
f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C}
) where {F,C,chunksize,T}
if (T === Nothing && contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend})
if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant})
fc! = DI.with_contexts(f!, contexts...)
result = MutableDiffResult(y, (similar(y),))
result = derivative!(result, fc!, y, x)
Expand All @@ -131,7 +131,7 @@ end
function DI.value_and_derivative!(
f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C}
) where {F,C,chunksize,T}
if (T === Nothing && contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend})
if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant})
fc! = DI.with_contexts(f!, contexts...)
result = MutableDiffResult(y, (der,))
result = derivative!(result, fc!, y, x)
Expand All @@ -145,7 +145,7 @@ end
function DI.derivative(
f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C}
) where {F,C,chunksize,T}
if (T === Nothing && contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend})
if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant})
fc! = DI.with_contexts(f!, contexts...)
return derivative(fc!, y, x)
else
Expand All @@ -157,7 +157,7 @@ end
function DI.derivative!(
f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C}
) where {F,C,chunksize,T}
if (T === Nothing && contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend})
if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant})
fc! = DI.with_contexts(f!, contexts...)
return derivative!(der, fc!, y, x)
else
Expand Down Expand Up @@ -188,7 +188,7 @@ function DI.prepare!_derivative(
old_prep::ForwardDiffTwoArgDerivativePrep,
backend::AutoForwardDiff,
x,
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
contexts::Vararg{DI.GeneralizedConstant,C},
) where {F,C}
if y isa Vector
(; config) = old_prep
Expand Down Expand Up @@ -283,7 +283,7 @@ function DI.value_and_jacobian(
if (
isnothing(chunksize) &&
T === Nothing &&
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
contexts isa NTuple{C,DI.GeneralizedConstant}
)
fc! = DI.with_contexts(f!, contexts...)
jac = similar(y, length(y), length(x))
Expand All @@ -302,7 +302,7 @@ function DI.value_and_jacobian!(
if (
isnothing(chunksize) &&
T === Nothing &&
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
contexts isa NTuple{C,DI.GeneralizedConstant}
)
fc! = DI.with_contexts(f!, contexts...)
result = MutableDiffResult(y, (jac,))
Expand All @@ -320,7 +320,7 @@ function DI.jacobian(
if (
isnothing(chunksize) &&
T === Nothing &&
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
contexts isa NTuple{C,DI.GeneralizedConstant}
)
fc! = DI.with_contexts(f!, contexts...)
return jacobian(fc!, y, x)
Expand All @@ -336,7 +336,7 @@ function DI.jacobian!(
if (
isnothing(chunksize) &&
T === Nothing &&
contexts isa NTuple{C,DI.ConstantOrFunctionOrBackend}
contexts isa NTuple{C,DI.GeneralizedConstant}
)
fc! = DI.with_contexts(f!, contexts...)
return jacobian!(jac, fc!, y, x)
Expand Down Expand Up @@ -369,7 +369,7 @@ function DI.prepare!_jacobian(
old_prep::ForwardDiffTwoArgJacobianPrep,
backend::AutoForwardDiff,
x,
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
contexts::Vararg{DI.GeneralizedConstant,C},
) where {F,C}
if x isa Vector && y isa Vector
(; config) = old_prep
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,11 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B}
return ty
end

# store preparation result with the right input eltype
struct PrepContext{T<:DI.Prep} <: DI.Context
data::T
function _translate(
::Type{D}, c::Union{DI.GeneralizedConstant,DI.PrepContext}
) where {D<:Dual}
return DI.unwrap(c)
end

NotCache = Union{DI.ConstantOrFunctionOrBackend,PrepContext}

_translate(::Type{D}, c::NotCache) where {D<:Dual} = DI.unwrap(c)
function _translate(::Type{D}, c::DI.Cache) where {D<:Dual}
c0 = DI.unwrap(c)
return similar(c0, D)
Expand All @@ -102,7 +99,11 @@ function translate(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C}
return new_contexts
end

_translate_toprep(::Type{D}, c::NotCache) where {D<:Dual} = nothing
function _translate_toprep(
::Type{D}, c::Union{DI.GeneralizedConstant,DI.PrepContext}
) where {D<:Dual}
return nothing
end
function _translate_toprep(::Type{D}, c::DI.Cache) where {D<:Dual}
c0 = DI.unwrap(c)
return similar(c0, D)
Expand All @@ -115,7 +116,7 @@ function translate_toprep(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:D
return new_contexts
end

_translate_prepared(c::NotCache, _pc) = DI.unwrap(c)
_translate_prepared(c::Union{DI.GeneralizedConstant,DI.PrepContext}, _pc) = DI.unwrap(c)
_translate_prepared(_c::DI.Cache, pc) = pc

function translate_prepared(
Expand Down
Loading