Skip to content

Commit a1d9c26

Browse files
authored
[BREAKING] Clean up tests (#470)
* Clean up tests * Fixes * Fix sparse autodiff * Remove options * Typo * Remove exports * Fixes * Docs * Fixes * Fix * Fix order of scenario modificators
1 parent d51fc0a commit a1d9c26

36 files changed

Lines changed: 635 additions & 789 deletions

File tree

.github/workflows/Test.yml

Lines changed: 36 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ on:
44
push:
55
branches:
66
- main
7-
tags: ['*']
7+
tags: ["*"]
88
pull_request:
99
types: [opened, reopened, synchronize, ready_for_review]
1010
workflow_dispatch:
@@ -16,7 +16,6 @@ concurrency:
1616
cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }}
1717

1818
jobs:
19-
2019
test-DI:
2120
name: ${{ matrix.version }} - DI (${{ matrix.group }})
2221
runs-on: ubuntu-latest
@@ -28,9 +27,9 @@ jobs:
2827
fail-fast: false
2928
matrix:
3029
version:
31-
- '1'
32-
- 'lts'
33-
- 'pre'
30+
- "1"
31+
- "lts"
32+
- "pre"
3433
group:
3534
- Formalities
3635
- Internals
@@ -48,51 +47,53 @@ jobs:
4847
- Back/Tapir
4948
- Back/Tracker
5049
- Back/Zygote
51-
- Down/Detector
52-
- Down/DifferentiateWith
50+
- Misc/DifferentiateWith
51+
- Misc/FromPrimitive
52+
- Misc/SparsityDetector
53+
- Misc/ZeroBackends
5354
- Down/Flux
5455
- Down/Lux
5556
exclude:
5657
# lts
57-
- version: 'lts'
58+
- version: "lts"
5859
group: Formalities
59-
- version: 'lts'
60+
- version: "lts"
6061
group: Back/ChainRulesCore
61-
- version: 'lts'
62+
- version: "lts"
6263
group: Back/Diffractor
63-
- version: 'lts'
64+
- version: "lts"
6465
group: Back/Enzyme
65-
- version: 'lts'
66+
- version: "lts"
6667
group: Back/FiniteDiff
67-
- version: 'lts'
68+
- version: "lts"
6869
group: Back/FastDifferentiation
69-
- version: 'lts'
70+
- version: "lts"
7071
group: Back/PolyesterForwardDiff
71-
- version: 'lts'
72+
- version: "lts"
7273
group: Back/SecondOrder
73-
- version: 'lts'
74+
- version: "lts"
7475
group: Back/Symbolics
75-
- version: 'lts'
76+
- version: "lts"
7677
group: Back/Tapir
77-
- version: 'lts'
78-
group: Down/Detector
79-
- version: 'lts'
78+
- version: "lts"
79+
group: Misc/SparsityDetector
80+
- version: "lts"
8081
group: Down/Flux
81-
- version: 'lts'
82+
- version: "lts"
8283
group: Down/Lux
8384
# pre-release
84-
- version: 'pre'
85+
- version: "pre"
8586
group: Formalities
86-
- version: 'pre'
87+
- version: "pre"
8788
group: Back/ChainRulesCore
88-
- version: 'pre'
89+
- version: "pre"
8990
group: Back/Enzyme
90-
- version: 'pre'
91+
- version: "pre"
9192
group: Back/Tapir
92-
- version: 'pre'
93+
- version: "pre"
9394
group: Back/SecondOrder
94-
- version: 'pre'
95-
group: Down/Detector
95+
- version: "pre"
96+
group: Misc/SparsityDetector
9697
env:
9798
SHOULDRUN: ${{ matrix.version == '1' || !github.event.pull_request.draft }}
9899
JULIA_DI_TEST_GROUP: ${{ matrix.group }}
@@ -138,18 +139,18 @@ jobs:
138139
fail-fast: false
139140
matrix:
140141
version:
141-
- '1'
142-
- 'lts'
143-
- 'pre'
142+
- "1"
143+
- "lts"
144+
- "pre"
144145
group:
145146
- Formalities
146147
- Zero
147-
- ForwardDiff
148+
- Standard
148149
- Weird
149150
exclude:
150-
- version: 'lts'
151+
- version: "lts"
151152
group: Formalities
152-
- version: 'lts'
153+
- version: "lts"
153154
group: Weird
154155
env:
155156
SHOULDRUN: ${{ matrix.version == '1' || !github.event.pull_request.draft }}
@@ -183,4 +184,4 @@ jobs:
183184
flags: DIT
184185
name: ${{ matrix.version }} - DIT (${{ matrix.group }})
185186
token: ${{ secrets.CODECOV_TOKEN }}
186-
fail_ci_if_error: true
187+
fail_ci_if_error: true

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/DifferentiationInterfaceSparseMatrixColoringsExt.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ using SparseMatrixColorings:
4141
decompress,
4242
decompress!
4343

44+
with_context(f, contexts::Vararg{Context,C}) where {C} = (DI.with_context(f, contexts...),)
45+
46+
function with_context(f!, y, contexts::Vararg{Context,C}) where {C}
47+
return (DI.with_context(f!, contexts...), y)
48+
end
49+
4450
include("jacobian.jl")
4551
include("hessian.jl")
4652

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/hessian.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ function DI.prepare_hessian(
3939
f::F, backend::AutoSparse, x, contexts::Vararg{Context,C}
4040
) where {F,C}
4141
dense_backend = dense_ad(backend)
42-
sparsity = hessian_sparsity(f, x, sparsity_detector(backend))
42+
sparsity = hessian_sparsity(
43+
with_context(f, contexts...)..., x, sparsity_detector(backend)
44+
)
4345
problem = ColoringProblem{:symmetric,:column}()
4446
coloring_result = coloring(
4547
sparsity, problem, coloring_algorithm(backend); decompression_eltype=eltype(x)

DifferentiationInterface/ext/DifferentiationInterfaceSparseMatrixColoringsExt/jacobian.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,10 @@ function _prepare_sparse_jacobian_aux(
8181
::PushforwardFast, y, f_or_f!y::FY, backend::AutoSparse, x, contexts::Vararg{Context,C}
8282
) where {FY,C}
8383
dense_backend = dense_ad(backend)
84-
sparsity = jacobian_sparsity(f_or_f!y..., x, sparsity_detector(backend))
84+
85+
sparsity = jacobian_sparsity(
86+
with_context(f_or_f!y..., contexts...)..., x, sparsity_detector(backend)
87+
)
8588
problem = ColoringProblem{:nonsymmetric,:column}()
8689
coloring_result = coloring(
8790
sparsity,
@@ -115,7 +118,9 @@ function _prepare_sparse_jacobian_aux(
115118
::PushforwardSlow, y, f_or_f!y::FY, backend::AutoSparse, x, contexts::Vararg{Context,C}
116119
) where {FY,C}
117120
dense_backend = dense_ad(backend)
118-
sparsity = jacobian_sparsity(f_or_f!y..., x, sparsity_detector(backend))
121+
sparsity = jacobian_sparsity(
122+
with_context(f_or_f!y..., contexts...)..., x, sparsity_detector(backend)
123+
)
119124
problem = ColoringProblem{:nonsymmetric,:row}()
120125
coloring_result = coloring(
121126
sparsity,
@@ -163,7 +168,7 @@ end
163168
function DI.value_and_jacobian(
164169
f::F, extras::SparseJacobianExtras, backend::AutoSparse, x, contexts::Vararg{Context,C}
165170
) where {F,C}
166-
return f(x), jacobian(f, extras, backend, x, contexts...)
171+
return f(x, map(unwrap, contexts)...), jacobian(f, extras, backend, x, contexts...)
167172
end
168173

169174
function DI.value_and_jacobian!(

DifferentiationInterface/test/Back/Enzyme/test.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ test_differentiation(
5555

5656
test_differentiation(
5757
duplicated_function_backends,
58-
DIT.make_closure.(default_scenarios());
58+
default_scenarios(; include_normal=false, include_closurified=true);
5959
second_order=false,
6060
logging=LOGGING,
6161
);
@@ -81,7 +81,7 @@ test_differentiation(
8181

8282
test_differentiation(
8383
AutoEnzyme(; mode=Enzyme.Forward), # TODO: add more
84-
DIT.remove_batched(default_scenarios());
84+
default_scenarios(; include_batchified=false);
8585
correctness=false,
8686
type_stability=true,
8787
second_order=false,

DifferentiationInterface/test/Back/ForwardDiff/test.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ test_differentiation(dense_backends, default_scenarios(); logging=LOGGING);
3131

3232
test_differentiation(
3333
dense_backends,
34-
default_scenarios();
34+
default_scenarios(; include_constantified=true);
3535
correctness=false,
3636
type_stability=true,
3737
second_order=false,
@@ -57,4 +57,9 @@ test_differentiation(
5757
logging=LOGGING,
5858
);
5959

60-
test_differentiation(sparse_backends, sparse_scenarios(); sparsity=true, logging=LOGGING);
60+
test_differentiation(
61+
sparse_backends,
62+
sparse_scenarios(; include_constantified=true);
63+
sparsity=true,
64+
logging=LOGGING,
65+
);

DifferentiationInterface/test/Down/DifferentiateWith/test.jl renamed to DifferentiationInterface/test/Misc/DifferentiateWith/test.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ using Test
1010
LOGGING = get(ENV, "CI", "false") == "false"
1111

1212
function zygote_breaking_scenarios()
13-
onearg_scens = filter(default_scenarios()) do scen
14-
DIT.nb_args(scen) == 1
13+
outofplace_scens = filter(default_scenarios()) do scen
14+
DIT.operator_place(scen) == :out
1515
end
16-
bad_onearg_scens = map(onearg_scens) do scen
16+
bad_outofplace_scens = map(outofplace_scens) do scen
1717
function bad_f(x)
1818
a = Vector{eltype(x)}(undef, 1)
1919
a[1] = sum(x)
@@ -23,7 +23,7 @@ function zygote_breaking_scenarios()
2323
bad_scen = DIT.change_function(scen, wrapped_bad_f)
2424
return bad_scen
2525
end
26-
return bad_onearg_scens
26+
return bad_outofplace_scens
2727
end
2828

2929
test_differentiation(

DifferentiationInterface/test/Internals/from_primitive.jl renamed to DifferentiationInterface/test/Misc/FromPrimitive/test.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
using DifferentiationInterface, DifferentiationInterfaceTest
22
using DifferentiationInterface: AutoForwardFromPrimitive, AutoReverseFromPrimitive
33
using DifferentiationInterfaceTest
4-
using DifferentiationInterfaceTest: insert_context
54
using ForwardDiff: ForwardDiff
65
using Test
76

@@ -17,8 +16,6 @@ for backend in vcat(fromprimitive_backends)
1716
@test check_inplace(backend)
1817
end
1918

20-
test_differentiation(fromprimitive_backends, default_scenarios(); logging=LOGGING);
21-
2219
test_differentiation(
23-
fromprimitive_backends, insert_context.(default_scenarios()); logging=LOGGING
20+
fromprimitive_backends, default_scenarios(; include_constantified=true); logging=LOGGING
2421
);

DifferentiationInterface/test/Down/Detector/test.jl renamed to DifferentiationInterface/test/Misc/SparsityDetector/test.jl

File renamed without changes.

DifferentiationInterface/test/Internals/zero_backends.jl renamed to DifferentiationInterface/test/Misc/ZeroBackends/test.jl

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
using DifferentiationInterface
22
using DifferentiationInterface: AutoZeroForward, AutoZeroReverse
33
using DifferentiationInterfaceTest
4-
using DifferentiationInterfaceTest: insert_context
54
using ComponentArrays: ComponentArrays
65
using JLArrays: JLArrays
76
using StaticArrays: StaticArrays
@@ -20,7 +19,7 @@ end
2019

2120
test_differentiation(
2221
zero_backends,
23-
zero.(default_scenarios());
22+
zero.(default_scenarios(; include_constantified=true));
2423
correctness=true,
2524
type_stability=true,
2625
excluded=[:second_derivative],
@@ -39,15 +38,6 @@ test_differentiation(
3938
logging=LOGGING,
4039
)
4140

42-
## Contexts
43-
44-
test_differentiation(
45-
zero_backends,
46-
insert_context.(zero.(default_scenarios()));
47-
correctness=true,
48-
logging=LOGGING,
49-
)
50-
5141
## Weird arrays
5242

5343
test_differentiation(

0 commit comments

Comments
 (0)