Skip to content

Commit 592e299

Browse files
committed
Code coverage
1 parent a7e4de6 commit 592e299

6 files changed

Lines changed: 31 additions & 89 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ function DI.prepare_pushforward_nokwarg(
1717
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
1818
df = function_shadow(f, backend, Val(B))
1919
mode = forward_withprimal(backend)
20-
context_shadows = shadows(backend, mode, Val(B), contexts...)
20+
context_shadows = make_context_shadows(backend, mode, Val(B), contexts...)
2121
return EnzymeOneArgPushforwardPrep(_sig, df, context_shadows)
2222
end
2323

@@ -148,7 +148,7 @@ function DI.prepare_gradient_nokwarg(
148148
valB = to_val(DI.pick_batchsize(backend, x))
149149
df = function_shadow(f, backend, valB)
150150
mode = forward_withprimal(backend)
151-
context_shadows = shadows(backend, mode, valB, contexts...)
151+
context_shadows = make_context_shadows(backend, mode, valB, contexts...)
152152
basis_shadows = create_shadows(valB, x)
153153
return EnzymeForwardGradientPrep(_sig, valB, df, context_shadows, basis_shadows)
154154
end
@@ -237,7 +237,7 @@ function DI.prepare_jacobian_nokwarg(
237237
valB = to_val(DI.pick_batchsize(backend, x))
238238
mode = forward_withprimal(backend)
239239
df = function_shadow(f, backend, valB)
240-
context_shadows = shadows(backend, mode, valB, contexts...)
240+
context_shadows = make_context_shadows(backend, mode, valB, contexts...)
241241
basis_shadows = create_shadows(valB, x)
242242
return EnzymeForwardOneArgJacobianPrep(
243243
_sig, valB, df, context_shadows, basis_shadows, length(y)

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_twoarg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function DI.prepare_pushforward_nokwarg(
1818
_sig = DI.signature(f!, y, backend, x, tx, contexts...; strict)
1919
df! = function_shadow(f!, backend, Val(B))
2020
mode = forward_noprimal(backend)
21-
context_shadows = shadows(backend, mode, Val(B), contexts...)
21+
context_shadows = make_context_shadows(backend, mode, Val(B), contexts...)
2222
return EnzymeTwoArgPushforwardPrep(_sig, df!, context_shadows)
2323
end
2424

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ function DI.prepare_pullback_nokwarg(
6565
_sig = DI.signature(f, backend, x, ty, contexts...; strict)
6666
df = function_shadow(f, backend, Val(B))
6767
mode = reverse_split_withprimal(backend)
68-
context_shadows = shadows(backend, mode, Val(B), contexts...)
68+
context_shadows = make_context_shadows(backend, mode, Val(B), contexts...)
6969
y = f(x, map(DI.unwrap, contexts)...)
7070
return EnzymeReverseOneArgPullbackPrep(_sig, df, context_shadows, y)
7171
end
@@ -216,7 +216,7 @@ function DI.prepare_gradient_nokwarg(
216216
_sig = DI.signature(f, backend, x, contexts...; strict)
217217
df = function_shadow(f, backend, Val(1))
218218
mode = reverse_withprimal(backend)
219-
context_shadows = shadows(backend, mode, Val(1), contexts...)
219+
context_shadows = make_context_shadows(backend, mode, Val(1), contexts...)
220220
return EnzymeGradientPrep(_sig, df, context_shadows)
221221
end
222222

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_twoarg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function DI.prepare_pullback_nokwarg(
1919
_sig = DI.signature(f!, y, backend, x, ty, contexts...; strict)
2020
df! = function_shadow(f!, backend, Val(B))
2121
mode = reverse_noprimal(backend)
22-
context_shadows = shadows(backend, mode, Val(B), contexts...)
22+
context_shadows = make_context_shadows(backend, mode, Val(B), contexts...)
2323
ty_copy = map(copy, ty)
2424
return EnzymeReverseTwoArgPullbackPrep(_sig, df!, context_shadows, ty_copy)
2525
end

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 22 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
const AnyDuplicated = Union{
2+
Duplicated,
3+
MixedDuplicated,
4+
BatchDuplicated,
5+
BatchMixedDuplicated,
6+
DuplicatedNoNeed,
7+
BatchDuplicatedNoNeed,
8+
}
9+
110
# until https://github.com/EnzymeAD/Enzyme.jl/pull/1545 is merged
211
function DI.pick_batchsize(::AutoEnzyme, N::Integer)
312
B = DI.reasonable_batchsize(N, 16)
@@ -8,26 +17,17 @@ to_val(::DI.BatchSizeSettings{B}) where {B} = Val(B)
817

918
## Annotations
1019

11-
const AnyDuplicated = Union{
12-
Duplicated,
13-
MixedDuplicated,
14-
BatchDuplicated,
15-
BatchMixedDuplicated,
16-
DuplicatedNoNeed,
17-
BatchDuplicatedNoNeed,
18-
}
19-
20-
function get_f_and_df(f::F, ::AutoEnzyme{M,Nothing}, ::Val{B}) where {F,M,B}
20+
function get_f_and_df_prepared!(_df, f::F, ::AutoEnzyme{M,Nothing}, ::Val{B}) where {F,M,B}
2121
return f
2222
end
2323

24-
function get_f_and_df(f::F, ::AutoEnzyme{M,<:Const}, ::Val{B}) where {F,M,B}
24+
function get_f_and_df_prepared!(_df, f::F, ::AutoEnzyme{M,<:Const}, ::Val{B}) where {F,M,B}
2525
return Const(f)
2626
end
2727

28-
function get_f_and_df(f::F, backend::AutoEnzyme{M,<:AnyDuplicated}, ::Val{B}) where {F,M,B}
29-
# TODO: needs more sophistication for mixed activities
30-
df = function_shadow(f, backend, Val(B))
28+
function get_f_and_df_prepared!(
29+
df, f::F, ::AutoEnzyme{M,<:AnyDuplicated}, ::Val{B}
30+
) where {F,M,B}
3131
if B == 1
3232
return Duplicated(f, df)
3333
else
@@ -49,71 +49,9 @@ function function_shadow(f::F, ::AutoEnzyme{M,<:AnyDuplicated}, ::Val{B}) where
4949
end
5050
end
5151

52-
function get_f_and_df_prepared!(_df, f::F, ::AutoEnzyme{M,Nothing}, ::Val{B}) where {F,M,B}
53-
return f
54-
end
55-
56-
function get_f_and_df_prepared!(_df, f::F, ::AutoEnzyme{M,<:Const}, ::Val{B}) where {F,M,B}
57-
return Const(f)
58-
end
59-
60-
function get_f_and_df_prepared!(
61-
df, f::F, ::AutoEnzyme{M,<:AnyDuplicated}, ::Val{B}
62-
) where {F,M,B}
63-
if B == 1
64-
return Duplicated(f, df)
65-
else
66-
return BatchDuplicated(f, df)
67-
end
68-
end
69-
7052
force_annotation(f::F) where {F<:Annotation} = f
7153
force_annotation(f::F) where {F} = Const(f)
7254

73-
function _translate(::AutoEnzyme, ::Mode, ::Val{B}, c_wrapped::DI.Constant) where {B}
74-
c = DI.unwrap(c_wrapped)
75-
return Const(c)
76-
end
77-
78-
function _translate(::AutoEnzyme, ::Mode, ::Val{B}, c_wrapped::DI.Cache) where {B}
79-
c = DI.unwrap(c_wrapped)
80-
if B == 1
81-
dc = make_zero(c)
82-
return Duplicated(c, dc)
83-
else
84-
dc = ntuple(_ -> make_zero(c), Val(B))
85-
return BatchDuplicated(c, dc)
86-
end
87-
end
88-
89-
function _translate(
90-
backend::AutoEnzyme, mode::Mode, ::Val{B}, c_wrapped::DI.ConstantOrCache
91-
) where {B}
92-
c = DI.unwrap(c_wrapped)
93-
IA = guess_activity(typeof(c), mode)
94-
if IA <: Const
95-
return _translate(backend, mode, Val(B), DI.Constant(c))
96-
else
97-
return _translate(backend, mode, Val(B), DI.Cache(c))
98-
end
99-
end
100-
101-
function _translate(
102-
backend::AutoEnzyme, ::Mode, ::Val{B}, c_wrapped::DI.FunctionContext
103-
) where {B}
104-
f = DI.unwrap(c_wrapped)
105-
return force_annotation(get_f_and_df(f, backend, Val(B)))
106-
end
107-
108-
function translate(
109-
backend::AutoEnzyme, mode::Mode, ::Val{B}, contexts::Vararg{DI.Context,C}
110-
) where {B,C}
111-
new_contexts = map(contexts) do c_wrapped
112-
_translate(backend, mode, Val(B), c_wrapped)
113-
end
114-
return new_contexts
115-
end
116-
11755
function _shadow(::AutoEnzyme, ::Mode, ::Val{B}, c_wrapped::DI.Constant) where {B}
11856
return nothing
11957
end
@@ -128,14 +66,18 @@ function _shadow(::AutoEnzyme, ::Mode, ::Val{B}, c_wrapped::DI.Cache) where {B}
12866
end
12967

13068
function _shadow(
131-
backend::AutoEnzyme, mode::Mode, valB::Val{B}, c_wrapped::DI.ConstantOrCache
69+
::AutoEnzyme, mode::Mode, valB::Val{B}, c_wrapped::DI.ConstantOrCache
13270
) where {B}
13371
c = DI.unwrap(c_wrapped)
13472
IA = guess_activity(typeof(c), mode)
13573
if IA <: Const
136-
return _shadow(backend, mode, valB, DI.Constant(c))
74+
nothing
13775
else
138-
return _shadow(backend, mode, valB, DI.Cache(c))
76+
if B == 1
77+
return make_zero(c)
78+
else
79+
return ntuple(_ -> make_zero(c), Val(B))
80+
end
13981
end
14082
end
14183

@@ -149,7 +91,7 @@ function _shadow(
14991
return function_shadow(f, backend, Val(B))
15092
end
15193

152-
function shadows(
94+
function make_context_shadows(
15395
backend::AutoEnzyme, mode::Mode, ::Val{B}, contexts::Vararg{DI.Context,C}
15496
) where {B,C}
15597
context_shadows = map(contexts) do c_wrapped

DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ import DifferentiationInterface as DI
55
using SparseConnectivityTracer:
66
TracerSparsityDetector, TracerLocalSparsityDetector, jacobian_buffer, hessian_buffer
77

8-
function _translate(::Type, c::Union{DI.GeneralizedConstant,DI.ConstantOrCache})
8+
@inline function _translate(::Type, c::Union{DI.GeneralizedConstant,DI.ConstantOrCache})
99
return DI.unwrap(c)
1010
end
11-
function _translate(::Type{T}, c::DI.Cache) where {T}
11+
@inline function _translate(::Type{T}, c::DI.Cache) where {T}
1212
return DI.recursive_similar(DI.unwrap(c), T)
1313
end
1414

0 commit comments

Comments
 (0)