Skip to content

Commit 71f38d0

Browse files
committed
fix: handle constant ConstantOrCache with Enzyme
1 parent cfab84d commit 71f38d0

3 files changed

Lines changed: 25 additions & 2 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.47"
4+
version = "0.6.48"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,15 +54,27 @@ force_annotation(f::F) where {F} = Const(f)
5454
end
5555

5656
@inline function _translate(
57-
backend::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Cache,DI.GeneralizedConstantOrCache}
57+
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.Cache
5858
) where {B}
59+
# important to keep make_zero here for ConstantOrCache instead of similar
5960
if B == 1
6061
return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c)))
6162
else
6263
return BatchDuplicated(DI.unwrap(c), ntuple(_ -> make_zero(DI.unwrap(c)), Val(B)))
6364
end
6465
end
6566

67+
@inline function _translate(
68+
backend::AutoEnzyme, mode::Mode, valB::Val{B}, c::DI.GeneralizedConstantOrCache
69+
) where {B}
70+
IA = guess_activity(typeof(DI.unwrap(c)), mode)
71+
if IA <: Const
72+
return _translate(backend, mode, valB, DI.Constant(DI.unwrap(c)))
73+
else
74+
return _translate(backend, mode, valB, DI.Cache(DI.unwrap(c)))
75+
end
76+
end
77+
6678
@inline function _translate(
6779
backend::AutoEnzyme, mode::Mode, ::Val{B}, c::DI.FunctionContext
6880
) where {B}

DifferentiationInterface/test/Back/Enzyme/test.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,14 @@ end
136136
logging=LOGGING,
137137
)
138138
end
139+
140+
@testset "Coverage" begin
141+
# ConstantOrCache without cache
142+
f_nocontext(x, p) = x
143+
@test I == DifferentiationInterface.jacobian(
144+
f_nocontext, AutoEnzyme(; mode=Enzyme.Forward), rand(10), ConstantOrCache(nothing)
145+
)
146+
@test I == DifferentiationInterface.jacobian(
147+
f_nocontext, AutoEnzyme(; mode=Enzyme.Reverse), rand(10), ConstantOrCache(nothing)
148+
)
149+
end

0 commit comments

Comments
 (0)