Skip to content

Commit 3abcee1

Browse files
authored
fix: better handling of Enzyme split mode + run Enzyme tests on 1.11 (#654)
* Test Enzyme on 1.11 * Remove matrix inputs * Explicit EnzymeCore dep * Skip step 6 * Custom diff implem * Even more manual diff * Filter sparse scenarios
1 parent 3c6e5dc commit 3abcee1

5 files changed

Lines changed: 91 additions & 77 deletions

File tree

.github/workflows/Test.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
actions: write
2626
contents: read
2727
strategy:
28-
fail-fast: false # TODO: toggle
28+
fail-fast: true # TODO: toggle
2929
matrix:
3030
version:
3131
- "1.10"
@@ -57,8 +57,6 @@ jobs:
5757
# version: "1.10"
5858
- version: "1"
5959
group: Back/ChainRules
60-
- version: "1"
61-
group: Back/Enzyme
6260
env:
6361
JULIA_DI_TEST_GROUP: ${{ matrix.group }}
6462
steps:

DifferentiationInterface/Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.25"
4+
version = "0.6.26"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -12,6 +12,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1212
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
1313
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
1414
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
15+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1516
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
1617
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1718
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
@@ -29,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2930
[extensions]
3031
DifferentiationInterfaceChainRulesCoreExt = "ChainRulesCore"
3132
DifferentiationInterfaceDiffractorExt = "Diffractor"
32-
DifferentiationInterfaceEnzymeExt = "Enzyme"
33+
DifferentiationInterfaceEnzymeExt = ["EnzymeCore", "Enzyme"]
3334
DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
3435
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
3536
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
@@ -49,7 +50,7 @@ ADTypes = "1.9.0"
4950
ChainRulesCore = "1.23.0"
5051
DiffResults = "1.1.0"
5152
Diffractor = "=0.2.6"
52-
Enzyme = "0.13.6"
53+
Enzyme = "0.13.17"
5354
ExplicitImports = "1.10.1"
5455
FastDifferentiation = "0.4.1"
5556
FiniteDiff = "2.23.1"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ module DifferentiationInterfaceEnzymeExt
33
using ADTypes: ADTypes, AutoEnzyme
44
using Base: Fix1
55
import DifferentiationInterface as DI
6-
using Enzyme:
6+
using EnzymeCore:
77
Active,
88
Annotation,
99
BatchDuplicated,
1010
BatchMixedDuplicated,
11+
Combined,
1112
Const,
1213
Duplicated,
1314
DuplicatedNoNeed,
@@ -25,7 +26,9 @@ using Enzyme:
2526
ReverseSplitWidth,
2627
ReverseSplitWithPrimal,
2728
ReverseWithPrimal,
28-
WithPrimal,
29+
Split,
30+
WithPrimal
31+
using Enzyme:
2932
autodiff,
3033
autodiff_thunk,
3134
create_shadows,

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,12 @@ reverse_noprimal(::AutoEnzyme{Nothing}) = Reverse
5858
reverse_withprimal(backend::AutoEnzyme{<:ReverseMode}) = WithPrimal(backend.mode)
5959
reverse_withprimal(::AutoEnzyme{Nothing}) = ReverseWithPrimal
6060

61-
function reverse_split_withprimal(backend::AutoEnzyme)
62-
mode = ReverseSplitWithPrimal
63-
return set_err(mode, backend)
61+
function reverse_split_withprimal(backend::AutoEnzyme{<:ReverseMode})
62+
return set_err(WithPrimal(Split(backend.mode)), backend)
63+
end
64+
65+
function reverse_split_withprimal(backend::AutoEnzyme{Nothing})
66+
return set_err(ReverseSplitWithPrimal, backend)
6467
end
6568

6669
set_err(mode::Mode, ::AutoEnzyme{<:Any,Nothing}) = EnzymeCore.set_err_if_func_written(mode)

DifferentiationInterface/test/Back/Enzyme/test.jl

Lines changed: 75 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,20 @@ check_no_implicit_imports(DifferentiationInterface)
1313

1414
LOGGING = get(ENV, "CI", "false") == "false"
1515

16+
function remove_matrix_inputs(scens::Vector{<:Scenario}) # TODO: remove
17+
if VERSION < v"1.11"
18+
return scens
19+
else
20+
# for https://github.com/EnzymeAD/Enzyme.jl/issues/2071
21+
return filter(s -> s.x isa Union{Number,AbstractVector}, scens)
22+
end
23+
end
24+
1625
backends = [
1726
AutoEnzyme(; mode=nothing),
1827
AutoEnzyme(; mode=Enzyme.Forward),
1928
AutoEnzyme(; mode=Enzyme.Reverse),
20-
AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const),
21-
AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=Enzyme.Const),
29+
AutoEnzyme(; mode=nothing, function_annotation=Enzyme.Const),
2230
]
2331

2432
duplicated_backends = [
@@ -33,27 +41,25 @@ duplicated_backends = [
3341
end
3442
end;
3543

36-
## First order
37-
38-
test_differentiation(backends, default_scenarios(); excluded=SECOND_ORDER, logging=LOGGING);
39-
40-
test_differentiation(
41-
backends[1:3],
42-
default_scenarios(; include_normal=false, include_constantified=true);
43-
excluded=SECOND_ORDER,
44-
logging=LOGGING,
45-
);
46-
47-
#=
48-
# TODO: reactivate closurified tests once Enzyme#2056 is fixed
49-
50-
test_differentiation(
51-
duplicated_backends,
52-
default_scenarios(; include_normal=false, include_closurified=true);
53-
excluded=SECOND_ORDER,
54-
logging=LOGGING,
55-
);
56-
=#
44+
@testset "First order" begin
45+
test_differentiation(
46+
backends, default_scenarios(); excluded=SECOND_ORDER, logging=LOGGING
47+
)
48+
49+
test_differentiation(
50+
backends[1:3],
51+
default_scenarios(; include_normal=false, include_constantified=true);
52+
excluded=SECOND_ORDER,
53+
logging=LOGGING,
54+
)
55+
56+
test_differentiation(
57+
duplicated_backends,
58+
default_scenarios(; include_normal=false, include_closurified=true);
59+
excluded=SECOND_ORDER,
60+
logging=LOGGING,
61+
)
62+
end
5763

5864
#=
5965
# TODO: reactivate type stability tests
@@ -68,50 +74,53 @@ test_differentiation(
6874
);
6975
=#
7076

71-
## Second order
72-
73-
test_differentiation(
74-
AutoEnzyme(),
75-
default_scenarios(; include_constantified=true);
76-
excluded=FIRST_ORDER,
77-
logging=LOGGING,
78-
);
79-
80-
test_differentiation(
81-
AutoEnzyme(; mode=Enzyme.Forward);
82-
excluded=vcat(FIRST_ORDER, [:hessian, :hvp]),
83-
logging=LOGGING,
84-
);
85-
86-
test_differentiation(
87-
AutoEnzyme(; mode=Enzyme.Reverse);
88-
excluded=vcat(FIRST_ORDER, [:second_derivative]),
89-
logging=LOGGING,
90-
);
91-
92-
test_differentiation(
93-
SecondOrder(AutoEnzyme(; mode=Enzyme.Reverse), AutoEnzyme(; mode=Enzyme.Forward));
94-
logging=LOGGING,
95-
);
96-
97-
## Sparse
77+
@testset "Second order" begin
78+
test_differentiation(
79+
[
80+
AutoEnzyme(),
81+
SecondOrder(
82+
AutoEnzyme(; mode=Enzyme.Reverse), AutoEnzyme(; mode=Enzyme.Forward)
83+
),
84+
],
85+
remove_matrix_inputs(default_scenarios(; include_constantified=true));
86+
excluded=FIRST_ORDER,
87+
logging=LOGGING,
88+
)
89+
90+
test_differentiation(
91+
AutoEnzyme(; mode=Enzyme.Forward);
92+
excluded=vcat(FIRST_ORDER, [:hessian, :hvp]),
93+
logging=LOGGING,
94+
)
95+
end
9896

99-
test_differentiation(
100-
MyAutoSparse.(AutoEnzyme(; function_annotation=Enzyme.Const)),
101-
sparse_scenarios();
102-
sparsity=true,
103-
logging=LOGGING,
104-
);
97+
@testset "Sparse" begin
98+
test_differentiation(
99+
MyAutoSparse.(AutoEnzyme(; function_annotation=Enzyme.Const)),
100+
if VERSION < v"1.11"
101+
sparse_scenarios()
102+
else
103+
filter(sparse_scenarios()) do s
104+
# for https://github.com/EnzymeAD/Enzyme.jl/issues/2168
105+
(s.x isa AbstractVector) &&
106+
(s.f != DIT.sumdiffcube) &&
107+
(s.f != DIT.sumdiffcube_mat)
108+
end
109+
end;
110+
sparsity=true,
111+
logging=LOGGING,
112+
)
113+
end
105114

106-
##
115+
@testset "Static" begin
116+
filtered_static_scenarios = filter(static_scenarios()) do s
117+
DIT.operator_place(s) == :out && DIT.function_place(s) == :out
118+
end
107119

108-
filtered_static_scenarios = filter(static_scenarios()) do s
109-
DIT.operator_place(s) == :out && DIT.function_place(s) == :out
120+
test_differentiation(
121+
[AutoEnzyme(; mode=Enzyme.Forward), AutoEnzyme(; mode=Enzyme.Reverse)],
122+
filtered_static_scenarios;
123+
excluded=SECOND_ORDER,
124+
logging=LOGGING,
125+
)
110126
end
111-
112-
test_differentiation(
113-
[AutoEnzyme(; mode=Enzyme.Forward), AutoEnzyme(; mode=Enzyme.Reverse)],
114-
filtered_static_scenarios;
115-
excluded=SECOND_ORDER,
116-
logging=LOGGING,
117-
)

0 commit comments

Comments
 (0)