diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 1a1f819dd..c6c880c80 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.46" +version = "0.6.47" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/docs/src/api.md b/DifferentiationInterface/docs/src/api.md index ace752251..486a5830f 100644 --- a/DifferentiationInterface/docs/src/api.md +++ b/DifferentiationInterface/docs/src/api.md @@ -14,6 +14,7 @@ DifferentiationInterface Context Constant Cache +ConstantOrCache ``` ## First order diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 5575cbb3e..eb71d0aec 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -54,7 +54,7 @@ force_annotation(f::F) where {F} = Const(f) end @inline function _translate( - backend::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Cache,DI.PrepContext} + backend::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Cache,DI.GeneralizedConstantOrCache} ) where {B} if B == 1 return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c))) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl index fa3d73aa2..cf298f263 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/onearg.jl @@ -12,7 +12,7 @@ function DI.prepare_pushforward_nokwarg( strict::Val, f, backend::AutoFiniteDiff, x, tx::NTuple, contexts::Vararg{DI.Context,C}; ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) cache = if x isa Number || y isa Number nothing @@ -89,7 +89,7 @@ function DI.pushforward( ) where {SIG,C} DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) ty = map(tx) do dx finite_difference_jvp(fc, x, dx, prep.cache; relstep, absstep, dir) end @@ -106,7 +106,7 @@ function DI.value_and_pushforward( ) where {SIG,C} DI.check_prep(f, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) ty = map(tx) do dx finite_difference_jvp(fc, x, dx, prep.cache, y; relstep, absstep, dir) @@ -128,7 +128,7 @@ function DI.prepare_derivative_nokwarg( strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) cache = if y isa Number nothing @@ -161,7 +161,7 @@ function DI.derivative( ) where {SIG,C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return finite_difference_derivative(fc, x, fdtype(backend); relstep, absstep, dir) end @@ -174,7 +174,7 @@ function DI.value_and_derivative( ) where {SIG,C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) return ( y, @@ -195,7 +195,7 @@ function DI.derivative( ) where {SIG,C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir) end @@ -209,7 +209,7 @@ function DI.derivative!( ) where {SIG,C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return finite_difference_gradient!(der, fc, x, prep.cache; relstep, absstep, dir) end @@ -221,7 +221,7 @@ function DI.value_and_derivative( contexts::Vararg{DI.Context,C}, ) where {SIG,C} DI.check_prep(f, prep, backend, x, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) (; relstep, absstep, dir) = prep y = fc(x) return (y, finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir)) @@ -237,7 +237,7 @@ function DI.value_and_derivative!( ) where {SIG,C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return ( fc(x), finite_difference_gradient!(der, fc, x, prep.cache; relstep, absstep, dir) ) @@ -257,7 +257,7 @@ function DI.prepare_gradient_nokwarg( strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) df = zero(y) .* x cache = GradientCache(df, x, fdtype(backend)) @@ -284,7 +284,7 @@ function DI.gradient( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir) end @@ -297,7 +297,7 @@ function DI.value_and_gradient( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return fc(x), finite_difference_gradient(fc, x, prep.cache; relstep, absstep, dir) end @@ -311,7 +311,7 @@ function DI.gradient!( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return finite_difference_gradient!(grad, fc, x, prep.cache; relstep, absstep, dir) end @@ -325,7 +325,7 @@ function DI.value_and_gradient!( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return ( fc(x), finite_difference_gradient!(grad, fc, x, prep.cache; relstep, absstep, dir) ) @@ -345,7 +345,7 @@ function DI.prepare_jacobian_nokwarg( strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) x1 = similar(x) fx = similar(y) @@ -374,7 +374,7 @@ function DI.jacobian( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return finite_difference_jacobian(fc, x, prep.cache; relstep, absstep, dir) end @@ -386,7 +386,7 @@ function DI.value_and_jacobian( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) (; relstep, absstep, dir) = prep y = fc(x) return (y, finite_difference_jacobian(fc, x, prep.cache, y; relstep, absstep, dir)) @@ -402,7 +402,7 @@ function DI.jacobian!( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return copyto!( jac, finite_difference_jacobian( @@ -421,7 +421,7 @@ function DI.value_and_jacobian!( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) return ( y, @@ -450,7 +450,7 @@ function DI.prepare_hessian_nokwarg( strict::Val, f, backend::AutoFiniteDiff, x, contexts::Vararg{DI.Context,C} ) where {C} _sig = DI.signature(f, backend, x, contexts...; strict) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) df = zero(y) .* x gradient_cache = GradientCache(df, x, fdtype(backend)) @@ -481,7 +481,7 @@ function DI.hessian( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep_h, absstep_h) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return finite_difference_hessian( fc, x, prep.hessian_cache; relstep=relstep_h, absstep=absstep_h ) @@ -497,7 +497,7 @@ function DI.hessian!( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep_h, absstep_h) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return finite_difference_hessian!( hess, fc, x, prep.hessian_cache; relstep=relstep_h, absstep=absstep_h ) @@ -512,7 +512,7 @@ function DI.value_gradient_and_hessian( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep_g, absstep_g, relstep_h, absstep_h) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) grad = finite_difference_gradient( fc, x, prep.gradient_cache; relstep=relstep_g, absstep=absstep_g ) @@ -533,7 +533,7 @@ function DI.value_gradient_and_hessian!( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) (; relstep_g, absstep_g, relstep_h, absstep_h) = prep - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) finite_difference_gradient!( grad, fc, x, prep.gradient_cache; relstep=relstep_g, absstep=absstep_g ) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl index 11bbcfbb9..35e2fb047 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDiffExt/twoarg.jl @@ -80,7 +80,7 @@ function DI.pushforward( ) where {SIG,C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) ty = map(tx) do dx dy = similar(y) finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep, dir) @@ -100,7 +100,7 @@ function DI.value_and_pushforward( ) where {SIG,C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) ty = map(tx) do dx dy = similar(y) finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep, dir) @@ -122,7 +122,7 @@ function DI.pushforward!( ) where {SIG,C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep, dir) @@ -142,7 +142,7 @@ function DI.value_and_pushforward!( ) where {SIG,C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) (; relstep, absstep, dir) = prep - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] finite_difference_jvp!(dy, fc!, x, dx, prep.cache; relstep, absstep, dir) @@ -214,7 +214,7 @@ function DI.value_and_derivative( ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) fc!(y, x) der = finite_difference_gradient(fc!, x, prep.cache; relstep, absstep, dir) return y, der @@ -231,7 +231,7 @@ function DI.value_and_derivative!( ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) fc!(y, x) finite_difference_gradient!(der, fc!, x, prep.cache; relstep, absstep, dir) return y, der @@ -247,7 +247,7 @@ function DI.derivative( ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) fc!(y, x) der = finite_difference_gradient(fc!, x, prep.cache; relstep, absstep, dir) return der @@ -264,7 +264,7 @@ function DI.derivative!( ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) finite_difference_gradient!(der, fc!, x, prep.cache; relstep, absstep, dir) return der end @@ -336,7 +336,7 @@ function DI.value_and_jacobian( ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) jac = similar(y, length(y), length(x)) finite_difference_jacobian!(jac, fc!, x, prep.cache; relstep, absstep, dir) fc!(y, x) @@ -354,7 +354,7 @@ function DI.value_and_jacobian!( ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) finite_difference_jacobian!(jac, fc!, x, prep.cache; relstep, absstep, dir) fc!(y, x) return y, jac @@ -370,7 +370,7 @@ function DI.jacobian( ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) jac = similar(y, length(y), length(x)) finite_difference_jacobian!(jac, fc!, x, prep.cache; relstep, absstep, dir) return jac @@ -387,7 +387,7 @@ function DI.jacobian!( ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) (; relstep, absstep, dir) = prep - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) finite_difference_jacobian!(jac, fc!, x, prep.cache; relstep, absstep, dir) return jac end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl index 68bd2918f..366c6ba31 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl @@ -32,7 +32,7 @@ function DI.pushforward( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) ty = map(tx) do dx jvp(backend.fdm, fc, (x, dx)) end @@ -75,7 +75,7 @@ function DI.pullback( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) tx = map(ty) do dy only(j′vp(backend.fdm, fc, dy, x)) end @@ -112,7 +112,7 @@ function DI.gradient( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return only(grad(backend.fdm, fc, x)) end @@ -169,7 +169,7 @@ function DI.jacobian( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return only(jacobian(backend.fdm, fc, x)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index 30f769417..bdfcb54a9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -292,7 +292,7 @@ function DI.value_and_gradient!( T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant} ) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) result = DiffResult(zero(eltype(x)), (grad,)) result = gradient!(result, fc, x) y = DR.value(result) @@ -312,7 +312,7 @@ function DI.value_and_gradient( T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant} ) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) result = GradientResult(x) result = gradient!(result, fc, x) return DR.value(result), DR.gradient(result) @@ -330,7 +330,7 @@ function DI.gradient!( T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant} ) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return gradient!(grad, fc, x) else prep = DI.prepare_gradient_nokwarg(Val(true), f, backend, x, contexts...) @@ -346,7 +346,7 @@ function DI.gradient( T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant} ) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return gradient(fc, x) else prep = DI.prepare_gradient_nokwarg(Val(true), f, backend, x, contexts...) @@ -387,7 +387,7 @@ function DI.value_and_gradient!( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc = DI.FixTail(f, contexts_dual...) + fc = DI.fix_tail(f, contexts_dual...) result = DiffResult(zero(eltype(x)), (grad,)) CHK = tag_type(backend) === Nothing if CHK @@ -408,7 +408,7 @@ function DI.value_and_gradient( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc = DI.FixTail(f, contexts_dual...) + fc = DI.fix_tail(f, contexts_dual...) result = GradientResult(x) CHK = tag_type(backend) === Nothing if CHK @@ -428,7 +428,7 @@ function DI.gradient!( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc = DI.FixTail(f, contexts_dual...) + fc = DI.fix_tail(f, contexts_dual...) CHK = tag_type(backend) === Nothing if CHK checktag(prep.config, f, x) @@ -445,7 +445,7 @@ function DI.gradient( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc = DI.FixTail(f, contexts_dual...) + fc = DI.fix_tail(f, contexts_dual...) CHK = tag_type(backend) === Nothing if CHK checktag(prep.config, f, x) @@ -465,7 +465,7 @@ function DI.value_and_jacobian!( T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant} ) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) result = DiffResult(y, (jac,)) result = jacobian!(result, fc, x) @@ -486,7 +486,7 @@ function DI.value_and_jacobian( T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant} ) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return fc(x), jacobian(fc, x) else prep = DI.prepare_jacobian_nokwarg(Val(true), f, backend, x, contexts...) @@ -502,7 +502,7 @@ function DI.jacobian!( T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant} ) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return jacobian!(jac, fc, x) else prep = DI.prepare_jacobian_nokwarg(Val(true), f, backend, x, contexts...) @@ -518,7 +518,7 @@ function DI.jacobian( T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant} ) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return jacobian(fc, x) else prep = DI.prepare_jacobian_nokwarg(Val(true), f, backend, x, contexts...) @@ -555,7 +555,7 @@ function DI.value_and_jacobian!( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc = DI.FixTail(f, contexts_dual...) + fc = DI.fix_tail(f, contexts_dual...) y = fc(x) result = DiffResult(y, (jac,)) CHK = tag_type(backend) === Nothing @@ -577,7 +577,7 @@ function DI.value_and_jacobian( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc = DI.FixTail(f, contexts_dual...) + fc = DI.fix_tail(f, contexts_dual...) CHK = tag_type(backend) === Nothing if CHK checktag(prep.config, f, x) @@ -595,7 +595,7 @@ function DI.jacobian!( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc = DI.FixTail(f, contexts_dual...) + fc = DI.fix_tail(f, contexts_dual...) CHK = tag_type(backend) === Nothing if CHK checktag(prep.config, f, x) @@ -612,7 +612,7 @@ function DI.jacobian( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc = DI.FixTail(f, contexts_dual...) + fc = DI.fix_tail(f, contexts_dual...) CHK = tag_type(backend) === Nothing if CHK checktag(prep.config, f, x) @@ -718,7 +718,7 @@ function DI.hessian!( T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant} ) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return hessian!(hess, fc, x) else prep = DI.prepare_hessian_nokwarg(Val(true), f, backend, x, contexts...) @@ -734,7 +734,7 @@ function DI.hessian( T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant} ) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return hessian(fc, x) else prep = DI.prepare_hessian_nokwarg(Val(true), f, backend, x, contexts...) @@ -755,7 +755,7 @@ function DI.value_gradient_and_hessian!( T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant} ) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) result = DiffResult(one(eltype(x)), (grad, hess)) result = hessian!(result, fc, x) y = DR.value(result) @@ -776,7 +776,7 @@ function DI.value_gradient_and_hessian( T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant} ) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) result = HessianResult(x) result = hessian!(result, fc, x) return (DR.value(result), DR.gradient(result), DR.hessian(result)) @@ -818,7 +818,7 @@ function DI.hessian!( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc = DI.FixTail(f, contexts_dual...) + fc = DI.fix_tail(f, contexts_dual...) CHK = tag_type(backend) === Nothing if CHK checktag(prep.array_config, f, x) @@ -835,7 +835,7 @@ function DI.hessian( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc = DI.FixTail(f, contexts_dual...) + fc = DI.fix_tail(f, contexts_dual...) CHK = tag_type(backend) === Nothing if CHK checktag(prep.array_config, f, x) @@ -854,7 +854,7 @@ function DI.value_gradient_and_hessian!( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc = DI.FixTail(f, contexts_dual...) + fc = DI.fix_tail(f, contexts_dual...) result = DiffResult(one(eltype(x)), (grad, hess)) CHK = tag_type(backend) === Nothing if CHK @@ -876,7 +876,7 @@ function DI.value_gradient_and_hessian( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc = DI.FixTail(f, contexts_dual...) + fc = DI.fix_tail(f, contexts_dual...) result = HessianResult(x) CHK = tag_type(backend) === Nothing if CHK diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index 5073838b4..29d6cc1ab 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -132,7 +132,7 @@ function DI.value_and_derivative( f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} ) where {F,C,chunksize,T} if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant}) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) result = MutableDiffResult(y, (similar(y),)) result = derivative!(result, fc!, y, x) return DiffResults.value(result), DiffResults.derivative(result) @@ -146,7 +146,7 @@ function DI.value_and_derivative!( f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} ) where {F,C,chunksize,T} if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant}) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) result = MutableDiffResult(y, (der,)) result = derivative!(result, fc!, y, x) return DiffResults.value(result), DiffResults.derivative(result) @@ -160,7 +160,7 @@ function DI.derivative( f!::F, y, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} ) where {F,C,chunksize,T} if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant}) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) return derivative(fc!, y, x) else prep = DI.prepare_derivative_nokwarg(Val(true), f!, y, backend, x, contexts...) @@ -172,7 +172,7 @@ function DI.derivative!( f!::F, y, der, backend::AutoForwardDiff{chunksize,T}, x, contexts::Vararg{DI.Context,C} ) where {F,C,chunksize,T} if (T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant}) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) return derivative!(der, fc!, y, x) else prep = DI.prepare_derivative_nokwarg(Val(true), f!, y, backend, x, contexts...) @@ -228,7 +228,7 @@ function DI.value_and_derivative( ) where {F,C} DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc! = DI.FixTail(f!, contexts_dual...) + fc! = DI.fix_tail(f!, contexts_dual...) result = MutableDiffResult(y, (similar(y),)) CHK = tag_type(backend) === Nothing if CHK @@ -249,7 +249,7 @@ function DI.value_and_derivative!( ) where {F,C} DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc! = DI.FixTail(f!, contexts_dual...) + fc! = DI.fix_tail(f!, contexts_dual...) result = MutableDiffResult(y, (der,)) CHK = tag_type(backend) === Nothing if CHK @@ -269,7 +269,7 @@ function DI.derivative( ) where {F,C} DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc! = DI.FixTail(f!, contexts_dual...) + fc! = DI.fix_tail(f!, contexts_dual...) CHK = tag_type(backend) === Nothing if CHK checktag(prep.config, f!, x) @@ -288,7 +288,7 @@ function DI.derivative!( ) where {F,C} DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc! = DI.FixTail(f!, contexts_dual...) + fc! = DI.fix_tail(f!, contexts_dual...) CHK = tag_type(backend) === Nothing if CHK checktag(prep.config, f!, x) @@ -308,7 +308,7 @@ function DI.value_and_jacobian( T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant} ) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) jac = similar(y, length(y), length(x)) result = MutableDiffResult(y, (jac,)) result = jacobian!(result, fc!, y, x) @@ -327,7 +327,7 @@ function DI.value_and_jacobian!( T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant} ) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) result = MutableDiffResult(y, (jac,)) result = jacobian!(result, fc!, y, x) return DiffResults.value(result), DiffResults.jacobian(result) @@ -345,7 +345,7 @@ function DI.jacobian( T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant} ) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) return jacobian(fc!, y, x) else prep = DI.prepare_jacobian_nokwarg(Val(true), f!, y, backend, x, contexts...) @@ -361,7 +361,7 @@ function DI.jacobian!( T === Nothing && contexts isa NTuple{C,DI.GeneralizedConstant} ) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) return jacobian!(jac, fc!, y, x) else prep = DI.prepare_jacobian_nokwarg(Val(true), f!, y, backend, x, contexts...) @@ -420,7 +420,7 @@ function DI.value_and_jacobian( ) where {F,C} DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc! = DI.FixTail(f!, contexts_dual...) + fc! = DI.fix_tail(f!, contexts_dual...) jac = similar(y, length(y), length(x)) result = MutableDiffResult(y, (jac,)) CHK = tag_type(backend) === Nothing @@ -442,7 +442,7 @@ function DI.value_and_jacobian!( ) where {F,C} DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc! = DI.FixTail(f!, contexts_dual...) + fc! = DI.fix_tail(f!, contexts_dual...) result = MutableDiffResult(y, (jac,)) CHK = tag_type(backend) === Nothing if CHK @@ -462,7 +462,7 @@ function DI.jacobian( ) where {F,C} DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc! = DI.FixTail(f!, contexts_dual...) + fc! = DI.fix_tail(f!, contexts_dual...) CHK = tag_type(backend) === Nothing if CHK checktag(prep.config, f!, x) @@ -481,7 +481,7 @@ function DI.jacobian!( ) where {F,C} DI.check_prep(f!, y, prep, backend, x, contexts...) contexts_dual = translate_prepared(contexts, prep.contexts_dual) - fc! = DI.FixTail(f!, contexts_dual...) + fc! = DI.fix_tail(f!, contexts_dual...) CHK = tag_type(backend) === Nothing if CHK checktag(prep.config, f!, x) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index 44afdd9ce..dd78760a3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -83,7 +83,7 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B} end function _translate( - ::Type{D}, c::Union{DI.GeneralizedConstant,DI.PrepContext} + ::Type{D}, c::Union{DI.GeneralizedConstant,DI.GeneralizedConstantOrCache} ) where {D<:Dual} return DI.unwrap(c) end @@ -100,7 +100,7 @@ function translate(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C} end function _translate_toprep( - ::Type{D}, c::Union{DI.GeneralizedConstant,DI.PrepContext} + ::Type{D}, c::Union{DI.GeneralizedConstant,DI.GeneralizedConstantOrCache} ) where {D<:Dual} return nothing end @@ -116,7 +116,11 @@ function translate_toprep(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:D return new_contexts end -_translate_prepared(c::Union{DI.GeneralizedConstant,DI.PrepContext}, _pc) = DI.unwrap(c) +function _translate_prepared( + c::Union{DI.GeneralizedConstant,DI.GeneralizedConstantOrCache}, _pc +) + return DI.unwrap(c) +end _translate_prepared(_c::DI.Cache, pc) = pc function translate_prepared( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl index 514550d26..d20b31f4b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl @@ -44,7 +44,7 @@ function DI.pushforward( contexts::Vararg{DI.Constant,C}, ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) ty = map(tx) do dx foreach((t, xi, dxi) -> (t[0] = xi; t[1] = dxi), prep.xt, x, dx) yt = fc(prep.xt) @@ -68,7 +68,7 @@ function DI.pushforward!( contexts::Vararg{DI.Constant,C}, ) where {C} DI.check_prep(f, prep, backend, x, tx, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] foreach((t, xi, dxi) -> (t[0] = xi; t[1] = dxi), prep.xt, x, dx) @@ -139,7 +139,7 @@ function DI.gradient( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part (slopes set in prepare) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) grad = similar(x, GTPSA.numtype(yt)) GTPSA.gradient!(grad, yt; include_params=true, unsafe_inbounds=true) @@ -156,7 +156,7 @@ function DI.gradient!( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) GTPSA.gradient!(grad, yt; include_params=true, unsafe_inbounds=true) return grad @@ -167,7 +167,7 @@ function DI.value_and_gradient( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part (slopes set in prepare) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) grad = similar(x, GTPSA.numtype(yt)) GTPSA.gradient!(grad, yt; include_params=true, unsafe_inbounds=true) @@ -184,7 +184,7 @@ function DI.value_and_gradient!( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part (slopes set in prepare) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) GTPSA.gradient!(grad, yt; include_params=true, unsafe_inbounds=true) return yt[0], grad @@ -224,7 +224,7 @@ function DI.jacobian( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) jac = similar(x, GTPSA.numtype(eltype(yt)), (length(yt), length(x))) GTPSA.jacobian!(jac, yt; include_params=true, unsafe_inbounds=true) @@ -241,7 +241,7 @@ function DI.jacobian!( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) GTPSA.jacobian!(jac, yt; include_params=true, unsafe_inbounds=true) return jac @@ -252,7 +252,7 @@ function DI.value_and_jacobian( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) jac = similar(x, GTPSA.numtype(eltype(yt)), (length(yt), length(x))) GTPSA.jacobian!(jac, yt; include_params=true, unsafe_inbounds=true) @@ -270,7 +270,7 @@ function DI.value_and_jacobian!( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) GTPSA.jacobian!(jac, yt; include_params=true, unsafe_inbounds=true) y = map(t -> t[0], yt) @@ -307,7 +307,7 @@ function DI.second_derivative( ) where {D,C} DI.check_prep(f, prep, backend, x, contexts...) prep.xt[0] = x - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) if D == Nothing idx2 = 2 @@ -336,7 +336,7 @@ function DI.second_derivative!( ) where {D,C} DI.check_prep(f, prep, backend, x, contexts...) prep.xt[0] = x - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) if D == Nothing idx2 = 2 @@ -358,7 +358,7 @@ function DI.value_derivative_and_second_derivative( ) where {D,C} DI.check_prep(f, prep, backend, x, contexts...) prep.xt[0] = x - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) if D == Nothing idx2 = 2 @@ -390,7 +390,7 @@ function DI.value_derivative_and_second_derivative!( ) where {D,C} DI.check_prep(f, prep, backend, x, contexts...) prep.xt[0] = x - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) y = map(t -> t[0], yt) if D == Nothing @@ -452,7 +452,7 @@ function DI.hessian( ) where {D,C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) hess = similar(x, GTPSA.numtype(yt), (length(x), length(x))) unsafe_fast = D == Nothing ? true : false @@ -477,7 +477,7 @@ function DI.hessian!( ) where {D,C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) unsafe_fast = D == Nothing ? true : false GTPSA.hessian!( @@ -500,7 +500,7 @@ function DI.value_gradient_and_hessian( ) where {D,C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) grad = similar(x, GTPSA.numtype(yt)) GTPSA.gradient!(grad, yt; include_params=true, unsafe_inbounds=true) @@ -528,7 +528,7 @@ function DI.value_gradient_and_hessian!( ) where {D,C} DI.check_prep(f, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) yt = fc(prep.xt) GTPSA.gradient!(grad, yt; include_params=true, unsafe_inbounds=true) unsafe_fast = D == Nothing ? true : false @@ -554,7 +554,7 @@ function DI.prepare_hvp_nokwarg( ) where {C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) hessprep = DI.prepare_hessian_nokwarg(strict, f, backend, x, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) hess = similar(x, typeof(fc(x)), (length(x), length(x))) return GTPSAOneArgHVPPrep(_sig, hessprep, hess) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl index b106923b5..9f8cc3cf3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl @@ -54,7 +54,7 @@ function DI.pushforward( contexts::Vararg{DI.Constant,C}, ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) ty = map(tx) do dx foreach((t, xi, dxi) -> (t[0] = xi; t[1] = dxi), prep.xt, x, dx) fc!(prep.yt, prep.xt) @@ -76,7 +76,7 @@ function DI.pushforward!( contexts::Vararg{DI.Constant,C}, ) where {C} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) for b in eachindex(tx, ty) dx, dy = tx[b], ty[b] foreach((t, xi, dxi) -> (t[0] = xi; t[1] = dxi), prep.xt, x, dx) @@ -163,7 +163,7 @@ function DI.jacobian( ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) fc!(prep.yt, prep.xt) jac = similar(x, GTPSA.numtype(eltype(prep.yt)), (length(prep.yt), length(x))) GTPSA.jacobian!(jac, prep.yt; include_params=true, unsafe_inbounds=true) @@ -182,7 +182,7 @@ function DI.jacobian!( ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) foreach((t, xi) -> t[0] = xi, prep.xt, x) # Set the scalar part - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) fc!(prep.yt, prep.xt) GTPSA.jacobian!(jac, prep.yt; include_params=true, unsafe_inbounds=true) map!(t -> t[0], y, prep.yt) diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl index a646488c5..44af4237c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/onearg.jl @@ -187,7 +187,7 @@ function DI.value_and_gradient!( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) threaded_gradient!(fc, grad, x, prep.chunk) return fc(x), grad else @@ -208,7 +208,7 @@ function DI.gradient!( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) threaded_gradient!(fc, grad, x, prep.chunk) return grad else @@ -278,7 +278,7 @@ function DI.value_and_jacobian!( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return fc(x), threaded_jacobian!(fc, jac, x, prep.chunk) else return DI.value_and_jacobian!( @@ -297,7 +297,7 @@ function DI.jacobian!( ) where {C} DI.check_prep(f, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return threaded_jacobian!(fc, jac, x, prep.chunk) else return DI.jacobian!( diff --git a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl index 78609ea0e..2b270e4cd 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfacePolyesterForwardDiffExt/twoarg.jl @@ -196,7 +196,7 @@ function DI.value_and_jacobian( ) where {K,C} DI.check_prep(f!, y, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) jac = similar(y, length(y), length(x)) threaded_jacobian!(fc!, y, jac, x, prep.chunk) fc!(y, x) @@ -219,7 +219,7 @@ function DI.value_and_jacobian!( ) where {K,C} DI.check_prep(f!, y, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) threaded_jacobian!(fc!, y, jac, x, prep.chunk) fc!(y, x) return y, jac @@ -240,7 +240,7 @@ function DI.jacobian( ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) jac = similar(y, length(y), length(x)) threaded_jacobian!(fc!, y, jac, x, prep.chunk) return jac @@ -262,7 +262,7 @@ function DI.jacobian!( ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) if contexts isa NTuple{C,DI.GeneralizedConstant} - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) threaded_jacobian!(fc!, y, jac, x, prep.chunk) return jac else diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl index c9d10490a..753e67b20 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/onearg.jl @@ -16,7 +16,7 @@ function DI.value_and_pullback( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) dotclosure(z, dy) = dot(fc(z), dy) tx = map(ty) do dy @@ -39,7 +39,7 @@ function DI.value_and_pullback!( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, ty, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) dotclosure(z, dy) = dot(fc(z), dy) for b in eachindex(tx, ty) @@ -160,7 +160,7 @@ function DI.value_and_gradient!( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) result = MutableDiffResult(zero(eltype(x)), (grad,)) # ReverseDiff#251 result = gradient!(result, fc, x, prep.config) return DR.value(result), grad # ReverseDiff#269 @@ -174,7 +174,7 @@ function DI.value_and_gradient( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) # GradientResult tries to mutate an SArray result = MutableDiffResult(zero(eltype(x)), (similar(x),)) result = gradient!(result, fc, x, prep.config) @@ -190,7 +190,7 @@ function DI.gradient!( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return gradient!(grad, fc, x, prep.config) end @@ -202,7 +202,7 @@ function DI.gradient( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return gradient(fc, x, prep.config) end @@ -297,7 +297,7 @@ function DI.value_and_jacobian!( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) y = fc(x) result = DiffResult(y, (jac,)) result = jacobian!(result, fc, x, prep.config) @@ -314,7 +314,7 @@ function DI.value_and_jacobian( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return fc(x), jacobian(fc, x, prep.config) end @@ -327,7 +327,7 @@ function DI.jacobian!( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return jacobian!(jac, fc, x, prep.config) end @@ -339,7 +339,7 @@ function DI.jacobian( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return jacobian(fc, x, prep.config) end @@ -430,7 +430,7 @@ function DI.hessian!( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return hessian!(hess, fc, x, prep.hessian_config) end @@ -442,7 +442,7 @@ function DI.hessian( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) return hessian(fc, x, prep.hessian_config) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl index 40553d456..b58ba7dd4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceReverseDiffExt/twoarg.jl @@ -25,7 +25,7 @@ function DI.value_and_pullback( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) function dotclosure(x, dy) y_copy = similar(y, eltype(x)) fc!(y_copy, x) @@ -49,7 +49,7 @@ function DI.value_and_pullback!( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) function dotclosure(x, dy) y_copy = similar(y, eltype(x)) fc!(y_copy, x) @@ -73,7 +73,7 @@ function DI.pullback( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) function dotclosure(x, dy) y_copy = similar(y, eltype(x)) fc!(y_copy, x) @@ -96,7 +96,7 @@ function DI.pullback!( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) function dotclosure(x, dy) y_copy = similar(y, eltype(x)) fc!(y_copy, x) @@ -222,7 +222,7 @@ function DI.value_and_jacobian( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) jac = similar(y, length(y), length(x)) result = MutableDiffResult(y, (jac,)) result = jacobian!(result, fc!, y, x, prep.config) @@ -239,7 +239,7 @@ function DI.value_and_jacobian!( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) result = MutableDiffResult(y, (jac,)) result = jacobian!(result, fc!, y, x, prep.config) return DiffResults.value(result), DiffResults.derivative(result) @@ -254,7 +254,7 @@ function DI.jacobian( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) jac = jacobian(fc!, y, x, prep.config) return jac end @@ -269,7 +269,7 @@ function DI.jacobian!( contexts::Vararg{DI.Context,C}, ) where {C} DI.check_prep(f!, y, prep, backend, x, contexts...) - fc! = DI.with_contexts(f!, contexts...) + fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...) jac = jacobian!(jac, fc!, y, x, prep.config) return jac end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl index a01a804ef..796af1bb0 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl @@ -33,7 +33,7 @@ function DI.jacobian_sparsity_with_contexts( contexts::Vararg{DI.Context,C}, ) where {F,C} contexts_tracer = jacobian_translate(detector, x, contexts...) - fc = DI.FixTail(f, contexts_tracer...) + fc = DI.fix_tail(f, contexts_tracer...) return jacobian_sparsity(fc, x, detector) end @@ -45,7 +45,7 @@ function DI.jacobian_sparsity_with_contexts( contexts::Vararg{DI.Context,C}, ) where {F,C} contexts_tracer = jacobian_translate(detector, x, contexts...) - fc! = DI.FixTail(f!, contexts_tracer...) + fc! = DI.fix_tail(f!, contexts_tracer...) return jacobian_sparsity(fc!, y, x, detector) end @@ -56,7 +56,7 @@ function DI.hessian_sparsity_with_contexts( contexts::Vararg{DI.Context,C}, ) where {F,C} contexts_tracer = hessian_translate(detector, x, contexts...) - fc = DI.FixTail(f, contexts_tracer...) + fc = DI.fix_tail(f, contexts_tracer...) return hessian_sparsity(fc, x, detector) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl index 72763eb6a..0843c5c05 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl @@ -283,7 +283,7 @@ function DI.hessian( contexts::Vararg{DI.GeneralizedConstant,C}, ) where {C} DI.check_prep(f, prep, backend, x, contexts...) - fc = DI.with_contexts(f, contexts...) + fc = DI.fix_tail(f, map(DI.unwrap, contexts)...) hess = hessian(fc, x) return hess end diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 984a194f7..ab7ff9aea 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -73,7 +73,7 @@ include("misc/overloading.jl") ## Exported -export Context, Constant, Cache +export Context, Constant, Cache, ConstantOrCache export MixedMode, SecondOrder export value_and_pushforward!, value_and_pushforward diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index 201c0ae2e..4bca5f1e2 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -1,15 +1,3 @@ -struct FixTail{F,A<:Tuple} - f::F - tail_args::A - function FixTail(f::F, tail_args::Vararg{Any,N}) where {F,N} - return new{F,typeof(tail_args)}(f, tail_args) - end -end - -function (ft::FixTail)(args::Vararg{Any,N}) where {N} - return ft.f(args..., ft.tail_args...) -end - """ Context @@ -19,10 +7,12 @@ Abstract supertype for additional context arguments, which can be passed to diff - [`Constant`](@ref) - [`Cache`](@ref) +- [`ConstantOrCache`](@ref) """ abstract type Context end abstract type GeneralizedConstant <: Context end +abstract type GeneralizedConstantOrCache <: Context end unwrap(c::Context) = c.data Base.:(==)(c1::Context, c2::Context) = unwrap(c1) == unwrap(c2) @@ -66,6 +56,7 @@ end constant_maker(c) = Constant(c) maker(::Constant) = constant_maker +adapt_eltype(c::Constant, ::Type) = c """ Cache @@ -102,23 +93,65 @@ end cache_maker(c) = Cache(c) maker(::Cache) = cache_maker +adapt_eltype(c::Cache, ::Type{T}) where {T} = Cache(recursive_similar(unwrap(c), T)) + +""" + ConstantOrCache + +Concrete type of [`Context`](@ref) argument which can contain a mixture of constants and caches, passed along to the backend without modification. + +Unlike for [`Cache`](@ref), it is up to the user to ensure that the internal storage can adapt to the required element types, for instance by using [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl) directly. +""" +struct ConstantOrCache{T} <: GeneralizedConstantOrCache + data::T +end + +constantorcache_maker(c) = ConstantOrCache(c) +maker(::ConstantOrCache) = constantorcache_maker +adapt_eltype(c::ConstantOrCache, ::Type) = c ## Internal contexts for passing stuff around +""" + FunctionContext + +Private type of [`Context`](@ref) argument used for passing functions inside second-order differentiation. + +Behaves differently for Enzyme only, where the function can be annotated. +""" struct FunctionContext{T} <: GeneralizedConstant data::T end +""" + BackendContext + +Private type of [`Context`](@ref) argument used for passing backends inside second-order differentiation. +""" struct BackendContext{T} <: GeneralizedConstant data::T end -struct PrepContext{T} <: Context +""" + PrepContext + +Private type of [`Context`](@ref) argument used for passing preparation results inside second-order differentiation. + +Conceptually similar to [`ConstantOrCache`](@ref) because we assume that preparation was performed with the right types so we don't change anything. +""" +struct PrepContext{T} <: GeneralizedConstantOrCache data::T end ## Context manipulation +""" + Rewrap + +Utility for recording context types of additional arguments (e.g. `Constant` or `Cache`) and re-wrapping them into their types after they have been unwrapped. + +Useful for second-order differentiation. +""" struct Rewrap{C,T} context_makers::T function Rewrap(contexts::Vararg{Context,C}) where {C} @@ -135,12 +168,32 @@ function (r::Rewrap{C,T})(unannotated_contexts::Vararg{Any,C}) where {C,T} end end -with_contexts(f) = f +## Closures + +""" + FixTail + +Closure around a function `f` and a set of tail argument `tail_args` such that +``` +(ft::FixTail)(args...) = ft.f(args..., ft.tail_args...) +``` +""" +struct FixTail{F,A<:Tuple} + f::F + tail_args::A + function FixTail(f::F, tail_args::Vararg{Any,N}) where {F,N} + return new{F,typeof(tail_args)}(f, tail_args) + end +end -function with_contexts(f::F, contexts::Vararg{Context,N}) where {F,N} - tail_args = map(unwrap, contexts) - return FixTail(f, tail_args...) +function (ft::FixTail)(args::Vararg{Any,N}) where {N} + return ft.f(args..., ft.tail_args...) end -adapt_eltype(c::Constant, ::Type) = c -adapt_eltype(c::Cache, ::Type{T}) where {T} = Cache(recursive_similar(unwrap(c), T)) +""" + fix_tail(f, tail_args...) + +Convenience for constructing a [`FixTail`](@ref), with a shortcut when there are no tail arguments. +""" +@inline fix_tail(f::F) where {F} = f +fix_tail(f::F, args::Vararg{Any,N}) where {F,N} = FixTail(f, args...) diff --git a/DifferentiationInterface/src/utils/sparse.jl b/DifferentiationInterface/src/utils/sparse.jl index 68fe591e1..f79650503 100644 --- a/DifferentiationInterface/src/utils/sparse.jl +++ b/DifferentiationInterface/src/utils/sparse.jl @@ -7,13 +7,13 @@ Wrapper around [`ADTypes.jacobian_sparsity`](@extref ADTypes.jacobian_sparsity) function jacobian_sparsity_with_contexts( f::F, detector::AbstractSparsityDetector, x, contexts::Vararg{Context,C} ) where {F,C} - return jacobian_sparsity(with_contexts(f, contexts...), x, detector) + return jacobian_sparsity(fix_tail(f, map(unwrap, contexts)...), x, detector) end function jacobian_sparsity_with_contexts( f!::F, y, detector::AbstractSparsityDetector, x, contexts::Vararg{Context,C} ) where {F,C} - return jacobian_sparsity(with_contexts(f!, contexts...), y, x, detector) + return jacobian_sparsity(fix_tail(f!, map(unwrap, contexts)...), y, x, detector) end """ @@ -24,5 +24,5 @@ Wrapper around [`ADTypes.hessian_sparsity`](@extref ADTypes.hessian_sparsity) en function hessian_sparsity_with_contexts( f::F, detector::AbstractSparsityDetector, x, contexts::Vararg{Context,C} ) where {F,C} - return hessian_sparsity(with_contexts(f, contexts...), x, detector) + return hessian_sparsity(fix_tail(f, map(unwrap, contexts)...), x, detector) end diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index 2aa2c3268..5100e7b43 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -55,7 +55,12 @@ end; test_differentiation( backends[2], - default_scenarios(; include_normal=false, include_cachified=true, use_tuples=true); + default_scenarios(; + include_normal=false, + include_cachified=true, + include_constantorcachified=true, + use_tuples=true, + ); excluded=SECOND_ORDER, logging=LOGGING, ) diff --git a/DifferentiationInterface/test/Back/FiniteDiff/test.jl b/DifferentiationInterface/test/Back/FiniteDiff/test.jl index dc111f45f..e4be24ebc 100644 --- a/DifferentiationInterface/test/Back/FiniteDiff/test.jl +++ b/DifferentiationInterface/test/Back/FiniteDiff/test.jl @@ -23,7 +23,10 @@ end test_differentiation( AutoFiniteDiff(), default_scenarios(; - include_constantified=true, include_cachified=true, use_tuples=true + include_constantified=true, + include_cachified=true, + include_constantorcachified=true, + use_tuples=true, ); excluded=[:second_derivative, :hvp], logging=LOGGING, diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index 0b9ff0d50..ada83edf2 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -39,6 +39,7 @@ end include_normal=false, include_batchified=false, include_cachified=true, + include_constantorcachified=true, use_tuples=true, ); logging=LOGGING, diff --git a/DifferentiationInterface/test/Core/Internals/context.jl b/DifferentiationInterface/test/Core/Internals/context.jl index fa898b6fe..6ac8be0f6 100644 --- a/DifferentiationInterface/test/Core/Internals/context.jl +++ b/DifferentiationInterface/test/Core/Internals/context.jl @@ -1,13 +1,13 @@ using DifferentiationInterface -using DifferentiationInterface: Rewrap, with_contexts +using DifferentiationInterface: Rewrap, fix_tail using Test f1(x) = x -g1 = @inferred with_contexts(f1) +g1 = @inferred fix_tail(f1) @test @inferred g1(4) == 4 f2(x, a, b) = a * x + b -g2 = @inferred with_contexts(f2, Constant(2), Constant(3)) +g2 = @inferred fix_tail(f2, 2, 3) @test @inferred g2(4) == 2 * 4 + 3 contexts = () diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index 34d93c16a..48d84e022 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -67,7 +67,8 @@ end ) test_differentiation( - second_order_hvp_backends; + second_order_hvp_backends, + default_scenarios(; include_constantorcachified=true); excluded=vcat(FIRST_ORDER, :hessian, :second_derivative), logging=LOGGING, ) diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index a40645fbf..6a804f33f 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -89,7 +89,7 @@ using DifferentiationInterface: inplace_support, pushforward_performance, pullback_performance -using DifferentiationInterface: Rewrap, Context, Constant, Cache, unwrap +using DifferentiationInterface: Rewrap, Context, Constant, Cache, ConstantOrCache, unwrap using DifferentiationInterface: PreparationMismatchError using DocStringExtensions: TYPEDFIELDS, TYPEDSIGNATURES using JET: @test_opt diff --git a/DifferentiationInterfaceTest/src/scenarios/default.jl b/DifferentiationInterfaceTest/src/scenarios/default.jl index 7a44f7390..7523365c8 100644 --- a/DifferentiationInterfaceTest/src/scenarios/default.jl +++ b/DifferentiationInterfaceTest/src/scenarios/default.jl @@ -559,6 +559,7 @@ function default_scenarios(; include_closurified=false, include_constantified=false, include_cachified=false, + include_constantorcachified=false, use_tuples=false, ) x_ = 0.42 @@ -637,6 +638,7 @@ function default_scenarios(; include_closurified && append!(final_scens, closurify(scens)) include_constantified && append!(final_scens, constantify(scens)) include_cachified && append!(final_scens, cachify(scens; use_tuples=use_tuples)) + include_constantorcachified && append!(final_scens, constantorcachify(scens)) return final_scens end diff --git a/DifferentiationInterfaceTest/src/scenarios/modify.jl b/DifferentiationInterfaceTest/src/scenarios/modify.jl index e991885f8..cd1fb521e 100644 --- a/DifferentiationInterfaceTest/src/scenarios/modify.jl +++ b/DifferentiationInterfaceTest/src/scenarios/modify.jl @@ -44,6 +44,22 @@ function change_function( ) end +function set_smaller( + scen::Scenario{op,pl_op,pl_fun}, smaller::Scenario +) where {op,pl_op,pl_fun} + @assert scen.f == smaller.f + return Scenario{op,pl_op,pl_fun}( + scen.f; + x=scen.x, + y=scen.y, + tang=scen.tang, + contexts=scen.contexts, + res1=scen.res1, + res2=scen.res2, + smaller=smaller, + ) +end + """ batchify(scen::Scenario) @@ -183,7 +199,7 @@ Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))") (sc::StoreInCache{:out})(x, y_cache::Tuple) = sc(x, first(y_cache)) (sc::StoreInCache{:in})(y, x, y_cache::Tuple) = sc(y, x, first(y_cache)) -function (sc::StoreInCache{:out})(x, y_cache) +function (sc::StoreInCache{:out})(x, y_cache) # no annotation otherwise Zygote.Buffer cries y = sc.f(x) if y isa Number y_cache[1] = y @@ -237,27 +253,86 @@ function cachify(scen::Scenario{op,pl_op,pl_fun}; use_tuples) where {op,pl_op,pl ) end -function batchify(scens::AbstractVector{<:Scenario}) - batchifiable_scens = filter(s -> operator(s) in (:pushforward, :pullback, :hvp), scens) - return batchify.(batchifiable_scens) +struct MultiplyByConstantAndStoreInCache{pl_fun,F} <: FunctionModifier + f::F end -closurify(scens::AbstractVector{<:Scenario}) = closurify.(scens) -constantify(scens::AbstractVector{<:Scenario}) = constantify.(scens) -cachify(scens::AbstractVector{<:Scenario}; use_tuples) = cachify.(scens; use_tuples) +function MultiplyByConstantAndStoreInCache{pl_fun}(f::F) where {pl_fun,F} + return MultiplyByConstantAndStoreInCache{pl_fun,F}(f) +end -function set_smaller( - scen::Scenario{op,pl_op,pl_fun}, smaller::Scenario -) where {op,pl_op,pl_fun} - @assert scen.f == smaller.f +function Base.show(io::IO, f::MultiplyByConstantAndStoreInCache) + return print(io, "MultiplyByConstantAndStoreInCache($(f.f))") +end + +function (sc::MultiplyByConstantAndStoreInCache{:out})(x, constantorcache) + (; constant, cache) = constantorcache + y = constant * sc.f(x) + if eltype(y) == eltype(cache) + newcache = cache + else + # poor man's PreallocationTools + newcache = similar(cache, eltype(y)) + end + if y isa Number + newcache[1] = y + return newcache[1] + else + copyto!(newcache, y) + return copy(newcache) + end +end + +function (sc::MultiplyByConstantAndStoreInCache{:in})(y, x, constantorcache) + (; constant, cache) = constantorcache + if eltype(y) == eltype(cache) + newcache = cache + else + # poor man's PreallocationTools + newcache = similar(cache, eltype(y)) + end + sc.f(newcache, x) + newcache .*= constant + copyto!(y, newcache) + return nothing +end + +""" + constantorcachify(scen::Scenario) + +Return a new `Scenario` identical to `scen` except for the function `f`, which is made to accept an additional "constant or cache" argument. +""" +function constantorcachify(scen::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} + (; f,) = scen + @assert isempty(scen.contexts) + constantorcache_f = MultiplyByConstantAndStoreInCache{pl_fun}(f) + a = 3.0 + constantorcache = if scen.y isa Number + (; cache=[myzero(scen.y)], constant=a) + else + (; cache=mysimilar(scen.y), constant=a) + end return Scenario{op,pl_op,pl_fun}( - scen.f; + constantorcache_f; x=scen.x, - y=scen.y, + y=mymultiply(scen.y, a), tang=scen.tang, - contexts=scen.contexts, - res1=scen.res1, - res2=scen.res2, - smaller=smaller, + contexts=(ConstantOrCache(constantorcache),), + res1=mymultiply(scen.res1, a), + res2=mymultiply(scen.res2, a), + smaller=isnothing(scen.smaller) ? nothing : constantorcachify(scen.smaller), + name=isnothing(scen.name) ? nothing : scen.name * " [constantorcachified]", ) end + +## Group functions + +function batchify(scens::AbstractVector{<:Scenario}) + batchifiable_scens = filter(s -> operator(s) in (:pushforward, :pullback, :hvp), scens) + return batchify.(batchifiable_scens) +end + +closurify(scens::AbstractVector{<:Scenario}) = closurify.(scens) +constantify(scens::AbstractVector{<:Scenario}) = constantify.(scens) +cachify(scens::AbstractVector{<:Scenario}; use_tuples) = cachify.(scens; use_tuples) +constantorcachify(scens::AbstractVector{<:Scenario}) = constantorcachify.(scens) diff --git a/DifferentiationInterfaceTest/src/scenarios/scenario.jl b/DifferentiationInterfaceTest/src/scenarios/scenario.jl index 7b92848d1..9d1c6a69f 100644 --- a/DifferentiationInterfaceTest/src/scenarios/scenario.jl +++ b/DifferentiationInterfaceTest/src/scenarios/scenario.jl @@ -101,7 +101,7 @@ function Base.:(==)( eq_tang = scen1.tang == scen2.tang eq_contexts = all( map(scen1.contexts, scen2.contexts) do c1, c2 - if c1 isa Cache || c2 isa Cache + if c1 isa Union{Cache,ConstantOrCache} || c2 isa Union{Cache,ConstantOrCache} return true else return c1 == c2 diff --git a/DifferentiationInterfaceTest/src/scenarios/sparse.jl b/DifferentiationInterfaceTest/src/scenarios/sparse.jl index 99ec02140..d3821c7e4 100644 --- a/DifferentiationInterfaceTest/src/scenarios/sparse.jl +++ b/DifferentiationInterfaceTest/src/scenarios/sparse.jl @@ -328,6 +328,7 @@ function sparse_scenarios(; band_sizes=[5, 10, 20], include_constantified=false, include_cachified=false, + include_constantorcachified=false, use_tuples=false, ) x_6 = float.(1:6) @@ -351,5 +352,6 @@ function sparse_scenarios(; append!(final_scens, scens) include_constantified && append!(final_scens, constantify(scens)) include_cachified && append!(final_scens, cachify(scens; use_tuples)) + include_constantorcachified && append!(final_scens, constantorcachify(scens)) return final_scens end diff --git a/DifferentiationInterfaceTest/src/utils.jl b/DifferentiationInterfaceTest/src/utils.jl index ada6fa2d6..81369e1ac 100644 --- a/DifferentiationInterfaceTest/src/utils.jl +++ b/DifferentiationInterfaceTest/src/utils.jl @@ -1,16 +1,16 @@ myzero(x::Number) = zero(x) myzero(x::AbstractArray) = zero(x) -myzero(x::NTuple) = map(myzero, x) +myzero(x::Union{Tuple,NamedTuple}) = map(myzero, x) myzero(::Nothing) = nothing mysimilar(x::Number) = one(x) mysimilar(x::AbstractArray) = similar(x) -mysimilar(x::NTuple) = map(mysimilar, x) +mysimilar(x::Union{Tuple,NamedTuple}) = map(mysimilar, x) mysimilar(x) = deepcopy(x) myrandom(rng::AbstractRNG, x::Number) = randn(rng, typeof(x)) myrandom(rng::AbstractRNG, x::AbstractArray) = map(Base.Fix1(myrandom, rng), x) -myrandom(rng::AbstractRNG, x::NTuple) = map(Base.Fix1(myrandom, rng), x) +myrandom(rng::AbstractRNG, x::Union{Tuple,NamedTuple}) = map(Base.Fix1(myrandom, rng), x) myrandom(rng::AbstractRNG, x) = deepcopy(x) myrandom(x) = myrandom(default_rng(), x) @@ -21,7 +21,7 @@ mysize(x) = missing mymultiply(x::Number, a::Number) = a * x mymultiply(x::AbstractArray, a::Number) = a .* x -mymultiply(x::NTuple, a::Number) = map(Base.Fix2(mymultiply, a), x) +mymultiply(x::Union{Tuple,NamedTuple}, a::Number) = map(Base.Fix2(mymultiply, a), x) mymultiply(::Nothing, a::Number) = nothing mynnz(A::AbstractMatrix) = count(!iszero, A) diff --git a/DifferentiationInterfaceTest/test/standard.jl b/DifferentiationInterfaceTest/test/standard.jl index 5324dd580..d85ab0603 100644 --- a/DifferentiationInterfaceTest/test/standard.jl +++ b/DifferentiationInterfaceTest/test/standard.jl @@ -17,10 +17,12 @@ test_differentiation( logging=LOGGING, ) -## Complex - test_differentiation( - AutoFiniteDiff(), vcat(complex_scenarios(), complex_sparse_scenarios()); logging=LOGGING + [AutoForwardDiff(), AutoFiniteDiff(; relstep=1e-5)], + default_scenarios(; + include_batchified=false, include_normal=false, include_constantorcachified=true + ); + logging=LOGGING, ) ## Sparse @@ -37,3 +39,9 @@ test_differentiation( sparsity=true, logging=LOGGING, ) + +## Complex + +test_differentiation( + AutoFiniteDiff(), vcat(complex_scenarios(), complex_sparse_scenarios()); logging=LOGGING +)