From 71f38d0c2f0ccb872b7c70a002afcfcf37d2b5be Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 20 Mar 2025 19:08:27 +0100 Subject: [PATCH 1/2] fix: handle constant `ConstantOrCache` with Enzyme --- DifferentiationInterface/Project.toml | 2 +- .../ext/DifferentiationInterfaceEnzymeExt/utils.jl | 14 +++++++++++++- DifferentiationInterface/test/Back/Enzyme/test.jl | 11 +++++++++++ 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index c6c880c80..289cfee51 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.6.47" +version = "0.6.48" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index eb71d0aec..9c5a3c669 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -54,8 +54,9 @@ force_annotation(f::F) where {F} = Const(f) end @inline function _translate( - backend::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Cache,DI.GeneralizedConstantOrCache} + backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache ) where {B} + # important to keep make_zero here for ConstantOrCache instead of similar if B == 1 return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c))) else @@ -63,6 +64,17 @@ end end end +@inline function _translate( + backend::AutoEnzyme, mode::Mode, valB::Val{B}, c::DI.GeneralizedConstantOrCache +) where {B} + IA = guess_activity(typeof(DI.unwrap(c)), mode) + if IA <: Const + return _translate(backend, mode, valB, DI.Constant(DI.unwrap(c))) + else + return _translate(backend, mode, valB, DI.Cache(DI.unwrap(c))) + end +end + @inline function _translate( backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext ) where {B} diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index 5100e7b43..fbe68ed83 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -136,3 +136,14 @@ end logging=LOGGING, ) end + +@testset "Coverage" begin + # ConstantOrCache without cache + f_nocontext(x, p) = x + @test I == DifferentiationInterface.jacobian( + f_nocontext, AutoEnzyme(; mode=Enzyme.Forward), rand(10), ConstantOrCache(nothing) + ) + @test I == DifferentiationInterface.jacobian( + f_nocontext, AutoEnzyme(; mode=Enzyme.Reverse), rand(10), ConstantOrCache(nothing) + ) +end From 6de1319c038b218cc8a3f4387d9f308e4e4073d5 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 20 Mar 2025 20:20:34 +0100 Subject: [PATCH 2/2] Fixes --- .../utils.jl | 2 +- .../utils.jl | 8 +-- ...ionInterfaceSparseConnectivityTracerExt.jl | 4 +- .../src/second_order/hvp.jl | 62 +++++++++---------- .../src/second_order/second_derivative.jl | 10 +-- DifferentiationInterface/src/utils/context.jl | 23 +------ .../test/Back/Enzyme/test.jl | 1 + .../test/Core/SimpleFiniteDiff/test.jl | 5 +- 8 files changed, 49 insertions(+), 66 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index 9c5a3c669..64a27e2e5 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -65,7 +65,7 @@ end end @inline function _translate( - backend::AutoEnzyme, mode::Mode, valB::Val{B}, c::DI.GeneralizedConstantOrCache + backend::AutoEnzyme, mode::Mode, valB::Val{B}, c::DI.ConstantOrCache ) where {B} IA = guess_activity(typeof(DI.unwrap(c)), mode) if IA <: Const diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl index dd78760a3..14b17f2dd 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl @@ -83,7 +83,7 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B} end function _translate( - ::Type{D}, c::Union{DI.GeneralizedConstant,DI.GeneralizedConstantOrCache} + ::Type{D}, c::Union{DI.GeneralizedConstant,DI.ConstantOrCache} ) where {D<:Dual} return DI.unwrap(c) end @@ -100,7 +100,7 @@ function translate(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C} end function _translate_toprep( - ::Type{D}, c::Union{DI.GeneralizedConstant,DI.GeneralizedConstantOrCache} + ::Type{D}, c::Union{DI.GeneralizedConstant,DI.ConstantOrCache} ) where {D<:Dual} return nothing end @@ -116,9 +116,7 @@ function translate_toprep(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:D return new_contexts end -function _translate_prepared( - c::Union{DI.GeneralizedConstant,DI.GeneralizedConstantOrCache}, _pc -) +function _translate_prepared(c::Union{DI.GeneralizedConstant,DI.ConstantOrCache}, _pc) return DI.unwrap(c) end _translate_prepared(_c::DI.Cache, pc) = pc diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl index 796af1bb0..2bb96701a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl @@ -5,7 +5,9 @@ import DifferentiationInterface as DI using SparseConnectivityTracer: TracerSparsityDetector, TracerLocalSparsityDetector, jacobian_buffer, hessian_buffer -@inline _translate(::Type, c::DI.Constant) = DI.unwrap(c) +@inline function _translate(::Type, c::Union{DI.GeneralizedConstant,DI.ConstantOrCache}) + return DI.unwrap(c) +end @inline function _translate(::Type{T}, c::DI.Cache) where {T} return DI.recursive_similar(DI.unwrap(c), T) end diff --git a/DifferentiationInterface/src/second_order/hvp.jl b/DifferentiationInterface/src/second_order/hvp.jl index 11faa2548..a8d31f8dd 100644 --- a/DifferentiationInterface/src/second_order/hvp.jl +++ b/DifferentiationInterface/src/second_order/hvp.jl @@ -117,7 +117,7 @@ function _prepare_hvp_aux( rewrap = Rewrap(contexts...) # Outer pushforward new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... ) outer_pushforward_prep = prepare_pushforward_nokwarg( strict, shuffled_gradient, outer(backend), x, tx, new_contexts... @@ -161,15 +161,15 @@ function _prepare_hvp_aux( # Outer pushforward new_contexts = ( FunctionContext(f), - PrepContext(inner_gradient_prep), - BackendContext(inner(backend)), + ConstantOrCache(inner_gradient_prep), + Constant(inner(backend)), Constant(rewrap), contexts..., ) new_contexts_in = ( FunctionContext(f), - PrepContext(inner_gradient_in_prep), - BackendContext(inner(backend)), + ConstantOrCache(inner_gradient_in_prep), + Constant(inner(backend)), Constant(rewrap), contexts..., ) @@ -228,15 +228,15 @@ function _prepare_hvp_aux( # Outer pushforward new_contexts = ( FunctionContext(f), - PrepContext(inner_gradient_prep), - BackendContext(inner(backend)), + ConstantOrCache(inner_gradient_prep), + Constant(inner(backend)), Constant(rewrap), contexts..., ) new_contexts_in = ( FunctionContext(f), - PrepContext(inner_gradient_in_prep), - BackendContext(inner(backend)), + ConstantOrCache(inner_gradient_in_prep), + Constant(inner(backend)), Constant(rewrap), contexts..., ) @@ -279,8 +279,8 @@ function hvp( rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), - map(PrepContext, maybe_inner_gradient_prep)..., - BackendContext(inner(backend)), + map(ConstantOrCache, maybe_inner_gradient_prep)..., + Constant(inner(backend)), Constant(rewrap), contexts..., ) @@ -318,8 +318,8 @@ function _hvp_aux!( rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), - map(PrepContext, maybe_inner_gradient_in_prep)..., - BackendContext(inner(backend)), + map(ConstantOrCache, maybe_inner_gradient_in_prep)..., + Constant(inner(backend)), Constant(rewrap), contexts..., ) @@ -349,8 +349,8 @@ function _hvp_aux!( rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), - map(PrepContext, maybe_inner_gradient_prep)..., - BackendContext(inner(backend)), + map(ConstantOrCache, maybe_inner_gradient_prep)..., + Constant(inner(backend)), Constant(rewrap), contexts..., ) @@ -378,8 +378,8 @@ function gradient_and_hvp( rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), - map(PrepContext, maybe_inner_gradient_prep)..., - BackendContext(inner(backend)), + map(ConstantOrCache, maybe_inner_gradient_prep)..., + Constant(inner(backend)), Constant(rewrap), contexts..., ) @@ -419,8 +419,8 @@ function _gradient_and_hvp_aux!( rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), - map(PrepContext, maybe_inner_gradient_in_prep)..., - BackendContext(inner(backend)), + map(ConstantOrCache, maybe_inner_gradient_in_prep)..., + Constant(inner(backend)), Constant(rewrap), contexts..., ) @@ -452,8 +452,8 @@ function _gradient_and_hvp_aux!( rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), - map(PrepContext, maybe_inner_gradient_prep)..., - BackendContext(inner(backend)), + map(ConstantOrCache, maybe_inner_gradient_prep)..., + Constant(inner(backend)), Constant(rewrap), contexts..., ) @@ -492,7 +492,7 @@ function _prepare_hvp_aux( rewrap = Rewrap(contexts...) new_contexts = ( FunctionContext(f), - BackendContext(inner(backend)), + Constant(inner(backend)), Constant(first(tx)), Constant(rewrap), contexts..., @@ -522,7 +522,7 @@ function hvp( outer(backend), x, FunctionContext(f), - BackendContext(inner(backend)), + Constant(inner(backend)), Constant(dx), Constant(rewrap), contexts..., @@ -551,7 +551,7 @@ function hvp!( outer(backend), x, FunctionContext(f), - BackendContext(inner(backend)), + Constant(inner(backend)), Constant(tx[b]), Constant(rewrap), contexts..., @@ -613,7 +613,7 @@ function _prepare_hvp_aux( _sig = signature(f, backend, x, tx, contexts...; strict) rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... ) grad_buffer = similar(x) outer_pullback_prep = prepare_pullback_nokwarg( @@ -649,7 +649,7 @@ function hvp( (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... ) return pullback( shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... @@ -684,7 +684,7 @@ function _hvp_aux!( (; grad_buffer, outer_pullback_in_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... ) return pullback!( shuffled_gradient!, @@ -711,7 +711,7 @@ function _hvp_aux!( (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... ) return pullback!( shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... @@ -730,7 +730,7 @@ function gradient_and_hvp( (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... ) return value_and_pullback( shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts... @@ -767,7 +767,7 @@ function _gradient_and_hvp_aux!( (; outer_pullback_in_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... ) new_grad, _ = value_and_pullback!( shuffled_gradient!, @@ -796,7 +796,7 @@ function _gradient_and_hvp_aux!( (; outer_pullback_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... ) new_grad, _ = value_and_pullback!( shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts... diff --git a/DifferentiationInterface/src/second_order/second_derivative.jl b/DifferentiationInterface/src/second_order/second_derivative.jl index 09e81fbac..6a9814b40 100644 --- a/DifferentiationInterface/src/second_order/second_derivative.jl +++ b/DifferentiationInterface/src/second_order/second_derivative.jl @@ -67,7 +67,7 @@ function prepare_second_derivative_nokwarg( _sig = signature(f, backend, x, contexts...; strict) rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... ) outer_derivative_prep = prepare_derivative_nokwarg( strict, shuffled_derivative, outer(backend), x, new_contexts... @@ -88,7 +88,7 @@ function second_derivative( (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... ) return derivative( shuffled_derivative, outer_derivative_prep, outer(backend), x, new_contexts... @@ -106,7 +106,7 @@ function value_derivative_and_second_derivative( (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... ) y = f(x, map(unwrap, contexts)...) der, der2 = value_and_derivative( @@ -127,7 +127,7 @@ function second_derivative!( (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... ) return derivative!( shuffled_derivative, der2, outer_derivative_prep, outer(backend), x, new_contexts... @@ -147,7 +147,7 @@ function value_derivative_and_second_derivative!( (; outer_derivative_prep) = prep rewrap = Rewrap(contexts...) new_contexts = ( - FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts... + FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts... ) y = f(x, map(unwrap, contexts)...) new_der, _ = value_and_derivative!( diff --git a/DifferentiationInterface/src/utils/context.jl b/DifferentiationInterface/src/utils/context.jl index a9ab2ea47..74c59efd7 100644 --- a/DifferentiationInterface/src/utils/context.jl +++ b/DifferentiationInterface/src/utils/context.jl @@ -12,7 +12,6 @@ Abstract supertype for additional context arguments, which can be passed to diff abstract type Context end abstract type GeneralizedConstant <: Context end -abstract type GeneralizedConstantOrCache <: Context end unwrap(c::Context) = c.data Base.:(==)(c1::Context, c2::Context) = unwrap(c1) == unwrap(c2) @@ -102,7 +101,7 @@ Concrete type of [`Context`](@ref) argument which can contain a mixture of const Unlike for [`Cache`](@ref), it is up to the user to ensure that the internal storage can adapt to the required element types, for instance by using [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl) directly. """ -struct ConstantOrCache{T} <: GeneralizedConstantOrCache +struct ConstantOrCache{T} <: Context data::T end @@ -123,26 +122,6 @@ struct FunctionContext{T} <: GeneralizedConstant data::T end -""" - BackendContext - -Private type of [`Context`](@ref) argument used for passing backends inside second-order differentiation. -""" -struct BackendContext{T} <: GeneralizedConstant - data::T -end - -""" - PrepContext - -Private type of [`Context`](@ref) argument used for passing preparation results inside second-order differentiation. - -Conceptually similar to [`ConstantOrCache`](@ref) because we assume that preparation was performed with the right types so we don't change anything. -""" -struct PrepContext{T} <: GeneralizedConstantOrCache - data::T -end - ## Context manipulation """ diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index fbe68ed83..a772e48ba 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -5,6 +5,7 @@ using ADTypes: ADTypes using DifferentiationInterface, DifferentiationInterfaceTest import DifferentiationInterfaceTest as DIT using Enzyme: Enzyme +using LinearAlgebra using StaticArrays using Test diff --git a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl index 48d84e022..15cf4d723 100644 --- a/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl +++ b/DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl @@ -88,7 +88,10 @@ end vcat(adaptive_backends, MixedMode(adaptive_backends[1], adaptive_backends[2])) ), sparse_scenarios(; - include_constantified=true, include_cachified=true, use_tuples=true + include_constantified=true, + include_cachified=true, + include_constantorcachified=true, + use_tuples=true, ); sparsity=true, logging=LOGGING,