diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index fc7347b53..486b91a5b 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.7.1" +version = "0.7.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl index f89da5598..991796bb1 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl @@ -1,6 +1,11 @@ -const AnyDuplicated = Union{Duplicated,MixedDuplicated,BatchDuplicated,BatchMixedDuplicated} - -const AnyDuplicatedNoNeed = Union{DuplicatedNoNeed,BatchDuplicatedNoNeed} +const AnyDuplicated = Union{ + Duplicated, + MixedDuplicated, + BatchDuplicated, + BatchMixedDuplicated, + DuplicatedNoNeed, + BatchDuplicatedNoNeed, +} # until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged function DI.pick_batchsize(::AutoEnzyme, N::Integer) @@ -35,25 +40,13 @@ function get_f_and_df_prepared!( end end -function get_f_and_df_prepared!( - df, f::F, ::AutoEnzyme{M,<:AnyDuplicatedNoNeed}, ::Val{B} -) where {F,M,B} - if B == 1 - return DuplicatedNoNeed(f, df) - else - return BatchDuplicatedNoNeed(f, df) - end -end - function function_shadow( ::F, ::AutoEnzyme{M,<:Union{Const,Nothing}}, ::Val{B} ) where {M,B,F} return nothing end -function function_shadow( - f::F, ::AutoEnzyme{M,<:Union{AnyDuplicated,AnyDuplicatedNoNeed}}, ::Val{B} -) where {F,M,B} +function function_shadow(f::F, ::AutoEnzyme{M,<:AnyDuplicated}, ::Val{B}) where {F,M,B} if B == 1 return make_zero(f) else diff --git a/DifferentiationInterface/test/Back/Enzyme/test.jl b/DifferentiationInterface/test/Back/Enzyme/test.jl index 561dc1966..7e1b02451 100644 --- a/DifferentiationInterface/test/Back/Enzyme/test.jl +++ b/DifferentiationInterface/test/Back/Enzyme/test.jl @@ -32,7 +32,7 @@ backends = [ duplicated_backends = [ AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Duplicated), - AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=Enzyme.DuplicatedNoNeed), + AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=Enzyme.Duplicated), ] @testset "Checks" begin