Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.47"
version = "0.6.48"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,27 @@
end

@inline function _translate(
backend::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Cache,DI.GeneralizedConstantOrCache}
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache
) where {B}
# important to keep make_zero here for ConstantOrCache instead of similar
if B == 1
return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c)))
else
return BatchDuplicated(DI.unwrap(c), ntuple(_ -> make_zero(DI.unwrap(c)), Val(B)))
end
end

@inline function _translate(

Check warning on line 67 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl#L67

Added line #L67 was not covered by tests
backend::AutoEnzyme, mode::Mode, valB::Val{B}, c::DI.ConstantOrCache
) where {B}
IA = guess_activity(typeof(DI.unwrap(c)), mode)
if IA <: Const
return _translate(backend, mode, valB, DI.Constant(DI.unwrap(c)))

Check warning on line 72 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl#L70-L72

Added lines #L70 - L72 were not covered by tests
else
return _translate(backend, mode, valB, DI.Cache(DI.unwrap(c)))

Check warning on line 74 in DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl#L74

Added line #L74 was not covered by tests
end
end

@inline function _translate(
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext
) where {B}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ function mypartials!(::Type{T}, ty::NTuple{B}, ydual) where {T,B}
end

function _translate(
::Type{D}, c::Union{DI.GeneralizedConstant,DI.GeneralizedConstantOrCache}
::Type{D}, c::Union{DI.GeneralizedConstant,DI.ConstantOrCache}
) where {D<:Dual}
return DI.unwrap(c)
end
Expand All @@ -100,7 +100,7 @@ function translate(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C}
end

function _translate_toprep(
::Type{D}, c::Union{DI.GeneralizedConstant,DI.GeneralizedConstantOrCache}
::Type{D}, c::Union{DI.GeneralizedConstant,DI.ConstantOrCache}
) where {D<:Dual}
return nothing
end
Expand All @@ -116,9 +116,7 @@ function translate_toprep(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:D
return new_contexts
end

function _translate_prepared(
c::Union{DI.GeneralizedConstant,DI.GeneralizedConstantOrCache}, _pc
)
function _translate_prepared(c::Union{DI.GeneralizedConstant,DI.ConstantOrCache}, _pc)
return DI.unwrap(c)
end
_translate_prepared(_c::DI.Cache, pc) = pc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import DifferentiationInterface as DI
using SparseConnectivityTracer:
TracerSparsityDetector, TracerLocalSparsityDetector, jacobian_buffer, hessian_buffer

@inline _translate(::Type, c::DI.Constant) = DI.unwrap(c)
@inline function _translate(::Type, c::Union{DI.GeneralizedConstant,DI.ConstantOrCache})
return DI.unwrap(c)
end
@inline function _translate(::Type{T}, c::DI.Cache) where {T}
return DI.recursive_similar(DI.unwrap(c), T)
end
Expand Down
62 changes: 31 additions & 31 deletions DifferentiationInterface/src/second_order/hvp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ function _prepare_hvp_aux(
rewrap = Rewrap(contexts...)
# Outer pushforward
new_contexts = (
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
)
outer_pushforward_prep = prepare_pushforward_nokwarg(
strict, shuffled_gradient, outer(backend), x, tx, new_contexts...
Expand Down Expand Up @@ -161,15 +161,15 @@ function _prepare_hvp_aux(
# Outer pushforward
new_contexts = (
FunctionContext(f),
PrepContext(inner_gradient_prep),
BackendContext(inner(backend)),
ConstantOrCache(inner_gradient_prep),
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
new_contexts_in = (
FunctionContext(f),
PrepContext(inner_gradient_in_prep),
BackendContext(inner(backend)),
ConstantOrCache(inner_gradient_in_prep),
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
Expand Down Expand Up @@ -228,15 +228,15 @@ function _prepare_hvp_aux(
# Outer pushforward
new_contexts = (
FunctionContext(f),
PrepContext(inner_gradient_prep),
BackendContext(inner(backend)),
ConstantOrCache(inner_gradient_prep),
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
new_contexts_in = (
FunctionContext(f),
PrepContext(inner_gradient_in_prep),
BackendContext(inner(backend)),
ConstantOrCache(inner_gradient_in_prep),
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
Expand Down Expand Up @@ -279,8 +279,8 @@ function hvp(
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f),
map(PrepContext, maybe_inner_gradient_prep)...,
BackendContext(inner(backend)),
map(ConstantOrCache, maybe_inner_gradient_prep)...,
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
Expand Down Expand Up @@ -318,8 +318,8 @@ function _hvp_aux!(
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f),
map(PrepContext, maybe_inner_gradient_in_prep)...,
BackendContext(inner(backend)),
map(ConstantOrCache, maybe_inner_gradient_in_prep)...,
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
Expand Down Expand Up @@ -349,8 +349,8 @@ function _hvp_aux!(
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f),
map(PrepContext, maybe_inner_gradient_prep)...,
BackendContext(inner(backend)),
map(ConstantOrCache, maybe_inner_gradient_prep)...,
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
Expand Down Expand Up @@ -378,8 +378,8 @@ function gradient_and_hvp(
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f),
map(PrepContext, maybe_inner_gradient_prep)...,
BackendContext(inner(backend)),
map(ConstantOrCache, maybe_inner_gradient_prep)...,
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
Expand Down Expand Up @@ -419,8 +419,8 @@ function _gradient_and_hvp_aux!(
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f),
map(PrepContext, maybe_inner_gradient_in_prep)...,
BackendContext(inner(backend)),
map(ConstantOrCache, maybe_inner_gradient_in_prep)...,
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
Expand Down Expand Up @@ -452,8 +452,8 @@ function _gradient_and_hvp_aux!(
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f),
map(PrepContext, maybe_inner_gradient_prep)...,
BackendContext(inner(backend)),
map(ConstantOrCache, maybe_inner_gradient_prep)...,
Constant(inner(backend)),
Constant(rewrap),
contexts...,
)
Expand Down Expand Up @@ -492,7 +492,7 @@ function _prepare_hvp_aux(
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f),
BackendContext(inner(backend)),
Constant(inner(backend)),
Constant(first(tx)),
Constant(rewrap),
contexts...,
Expand Down Expand Up @@ -522,7 +522,7 @@ function hvp(
outer(backend),
x,
FunctionContext(f),
BackendContext(inner(backend)),
Constant(inner(backend)),
Constant(dx),
Constant(rewrap),
contexts...,
Expand Down Expand Up @@ -551,7 +551,7 @@ function hvp!(
outer(backend),
x,
FunctionContext(f),
BackendContext(inner(backend)),
Constant(inner(backend)),
Constant(tx[b]),
Constant(rewrap),
contexts...,
Expand Down Expand Up @@ -613,7 +613,7 @@ function _prepare_hvp_aux(
_sig = signature(f, backend, x, tx, contexts...; strict)
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
)
grad_buffer = similar(x)
outer_pullback_prep = prepare_pullback_nokwarg(
Expand Down Expand Up @@ -649,7 +649,7 @@ function hvp(
(; outer_pullback_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
)
return pullback(
shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts...
Expand Down Expand Up @@ -684,7 +684,7 @@ function _hvp_aux!(
(; grad_buffer, outer_pullback_in_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
)
return pullback!(
shuffled_gradient!,
Expand All @@ -711,7 +711,7 @@ function _hvp_aux!(
(; outer_pullback_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
)
return pullback!(
shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts...
Expand All @@ -730,7 +730,7 @@ function gradient_and_hvp(
(; outer_pullback_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
)
return value_and_pullback(
shuffled_gradient, outer_pullback_prep, outer(backend), x, tx, new_contexts...
Expand Down Expand Up @@ -767,7 +767,7 @@ function _gradient_and_hvp_aux!(
(; outer_pullback_in_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
)
new_grad, _ = value_and_pullback!(
shuffled_gradient!,
Expand Down Expand Up @@ -796,7 +796,7 @@ function _gradient_and_hvp_aux!(
(; outer_pullback_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
)
new_grad, _ = value_and_pullback!(
shuffled_gradient, tg, outer_pullback_prep, outer(backend), x, tx, new_contexts...
Expand Down
10 changes: 5 additions & 5 deletions DifferentiationInterface/src/second_order/second_derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ function prepare_second_derivative_nokwarg(
_sig = signature(f, backend, x, contexts...; strict)
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
)
outer_derivative_prep = prepare_derivative_nokwarg(
strict, shuffled_derivative, outer(backend), x, new_contexts...
Expand All @@ -88,7 +88,7 @@ function second_derivative(
(; outer_derivative_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
)
return derivative(
shuffled_derivative, outer_derivative_prep, outer(backend), x, new_contexts...
Expand All @@ -106,7 +106,7 @@ function value_derivative_and_second_derivative(
(; outer_derivative_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
)
y = f(x, map(unwrap, contexts)...)
der, der2 = value_and_derivative(
Expand All @@ -127,7 +127,7 @@ function second_derivative!(
(; outer_derivative_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
)
return derivative!(
shuffled_derivative, der2, outer_derivative_prep, outer(backend), x, new_contexts...
Expand All @@ -147,7 +147,7 @@ function value_derivative_and_second_derivative!(
(; outer_derivative_prep) = prep
rewrap = Rewrap(contexts...)
new_contexts = (
FunctionContext(f), BackendContext(inner(backend)), Constant(rewrap), contexts...
FunctionContext(f), Constant(inner(backend)), Constant(rewrap), contexts...
)
y = f(x, map(unwrap, contexts)...)
new_der, _ = value_and_derivative!(
Expand Down
23 changes: 1 addition & 22 deletions DifferentiationInterface/src/utils/context.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ Abstract supertype for additional context arguments, which can be passed to diff
abstract type Context end

abstract type GeneralizedConstant <: Context end
abstract type GeneralizedConstantOrCache <: Context end

unwrap(c::Context) = c.data
Base.:(==)(c1::Context, c2::Context) = unwrap(c1) == unwrap(c2)
Expand Down Expand Up @@ -102,7 +101,7 @@ Concrete type of [`Context`](@ref) argument which can contain a mixture of const

Unlike for [`Cache`](@ref), it is up to the user to ensure that the internal storage can adapt to the required element types, for instance by using [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl) directly.
"""
struct ConstantOrCache{T} <: GeneralizedConstantOrCache
struct ConstantOrCache{T} <: Context
data::T
end

Expand All @@ -123,26 +122,6 @@ struct FunctionContext{T} <: GeneralizedConstant
data::T
end

"""
BackendContext

Private type of [`Context`](@ref) argument used for passing backends inside second-order differentiation.
"""
struct BackendContext{T} <: GeneralizedConstant
data::T
end

"""
PrepContext

Private type of [`Context`](@ref) argument used for passing preparation results inside second-order differentiation.

Conceptually similar to [`ConstantOrCache`](@ref) because we assume that preparation was performed with the right types so we don't change anything.
"""
struct PrepContext{T} <: GeneralizedConstantOrCache
data::T
end

## Context manipulation

"""
Expand Down
12 changes: 12 additions & 0 deletions DifferentiationInterface/test/Back/Enzyme/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using Enzyme: Enzyme
using LinearAlgebra
using StaticArrays
using Test

Expand Down Expand Up @@ -136,3 +137,14 @@
logging=LOGGING,
)
end

@testset "Coverage" begin
# ConstantOrCache without cache
f_nocontext(x, p) = x

Check warning on line 143 in DifferentiationInterface/test/Back/Enzyme/test.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/test/Back/Enzyme/test.jl#L143

Added line #L143 was not covered by tests
@test I == DifferentiationInterface.jacobian(
f_nocontext, AutoEnzyme(; mode=Enzyme.Forward), rand(10), ConstantOrCache(nothing)
)
@test I == DifferentiationInterface.jacobian(
f_nocontext, AutoEnzyme(; mode=Enzyme.Reverse), rand(10), ConstantOrCache(nothing)
)
end
Loading
Loading