Skip to content

Commit 0a0a943

Browse files
authored
Remove use of custom isequal to compare scenarios (#486)
* Remove use of custom `isequal` to compare scenarios * Fix * Scenario intact toggle * No testing equality of f
1 parent 1369010 commit 0a0a943

8 files changed

Lines changed: 47 additions & 46 deletions

File tree

DifferentiationInterface/test/Down/Flux/test.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ test_differentiation(
2020
# AutoEnzyme() # TODO: fix
2121
],
2222
DIT.flux_scenarios();
23-
isequal=DIT.flux_isequal,
2423
isapprox=DIT.flux_isapprox,
2524
rtol=1e-2,
2625
atol=1e-6,
26+
scenario_intact=false, # TODO: why?
27+
logging=LOGGING,
2728
)

DifferentiationInterface/test/Down/Lux/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,9 @@ Random.seed!(0)
1616
test_differentiation(
1717
AutoZygote(),
1818
DIT.lux_scenarios(Random.Xoshiro(63));
19-
isequal=DIT.lux_isequal,
2019
isapprox=DIT.lux_isapprox,
2120
rtol=1.0f-2,
2221
atol=1.0f-3,
22+
scenario_intact=false, # TODO: why?
2323
logging=LOGGING,
2424
)

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestFluxExt/DifferentiationInterfaceTestFluxExt.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,6 @@ function gradient_finite_differences(loss, model)
2323
return re(only(gs))
2424
end
2525

26-
function DIT.flux_isequal(a, b)
27-
return all(isequal.(fleaves(a), fleaves(b)))
28-
end
29-
3026
function DIT.flux_isapprox(a, b; atol, rtol)
3127
isapprox_results = fmapstructure_with_path(a, b) do kp, x, y
3228
if :state in kp # ignore RNN and LSTM state

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestLuxExt/DifferentiationInterfaceTestLuxExt.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,6 @@ Relevant discussions:
1717
- https://github.com/LuxDL/Lux.jl/issues/769
1818
=#
1919

20-
function DIT.lux_isequal(a, b)
21-
return check_approx(a, b; atol=0, rtol=0)
22-
end
23-
2420
function DIT.lux_isapprox(a, b; atol, rtol)
2521
return check_approx(a, b; atol, rtol)
2622
end

DifferentiationInterfaceTest/src/scenarios/scenario.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,21 @@ function Scenario{op,pl_op}(
7272
return Scenario{op,pl_op,:in}(f!; x, y, tang, contexts, res1, res2)
7373
end
7474

75+
Base.:(==)(scen1::Scenario, scen2::Scenario) = false
76+
77+
function Base.:(==)(
78+
scen1::Scenario{op,pl_op,pl_fun}, scen2::Scenario{op,pl_op,pl_fun}
79+
) where {op,pl_op,pl_fun}
80+
eq_f = scen1.f == scen2.f
81+
eq_x = scen1.x == scen2.x
82+
eq_y = scen1.y == scen2.y
83+
eq_tang = scen1.tang == scen2.tang
84+
eq_contexts = scen1.contexts == scen2.contexts
85+
eq_res1 = scen1.res1 == scen2.res1
86+
eq_res2 = scen1.res2 == scen2.res2
87+
return (eq_x && eq_y && eq_tang && eq_contexts && eq_res1 && eq_res2)
88+
end
89+
7590
operator(::Scenario{op}) where {op} = op
7691
operator_place(::Scenario{op,pl_op}) where {op,pl_op} = pl_op
7792
function_place(::Scenario{op,pl_op,pl_fun}) where {op,pl_op,pl_fun} = pl_fun

DifferentiationInterfaceTest/src/test_differentiation.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,10 @@ Filtering:
4141
Options:
4242
4343
- `logging=false`: whether to log progress
44-
- `isequal=isequal`: function used to compare objects exactly, with the standard signature `isequal(x, y)`
4544
- `isapprox=isapprox`: function used to compare objects approximately, with the standard signature `isapprox(x, y; atol, rtol)`
4645
- `atol=0`: absolute precision for correctness testing (when comparing to the reference outputs)
4746
- `rtol=1e-3`: relative precision for correctness testing (when comparing to the reference outputs)
47+
- `scenario_intact=true`: whether to check that the scenario remains unchanged after the operators are applied
4848
"""
4949
function test_differentiation(
5050
backends::Vector{<:AbstractADType},
@@ -63,10 +63,10 @@ function test_differentiation(
6363
excluded::Vector{Symbol}=Symbol[],
6464
# options
6565
logging::Bool=false,
66-
isequal=isequal,
6766
isapprox=isapprox,
6867
atol::Real=0,
6968
rtol::Real=1e-3,
69+
scenario_intact::Bool=true,
7070
)
7171
scenarios = filter_scenarios(
7272
scenarios; first_order, second_order, input_type, output_type, excluded
@@ -109,7 +109,7 @@ function test_differentiation(
109109
],
110110
)
111111
correctness && @testset "Correctness" begin
112-
test_correctness(backend, scen; isequal, isapprox, atol, rtol)
112+
test_correctness(backend, scen; isapprox, atol, rtol, scenario_intact)
113113
end
114114
type_stability && @testset "Type stability" begin
115115
@static if VERSION >= v"1.7"

DifferentiationInterfaceTest/src/tests/correctness_eval.jl

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
function test_scen_intact(new_scen, scen; isequal)
2-
for n in fieldnames(typeof(scen))
3-
n == :f && continue
4-
@test isequal(getfield(new_scen, n), getfield(scen, n))
5-
end
6-
end
7-
81
for op in [
92
:derivative,
103
:gradient,
@@ -55,10 +48,10 @@ for op in [
5548
@eval function test_correctness(
5649
ba::AbstractADType,
5750
scen::$S1out;
58-
isequal::Function,
5951
isapprox::Function,
6052
atol::Real,
6153
rtol::Real,
54+
scenario_intact::Bool,
6255
)
6356
@compat (; f, x, y, res1, contexts) = new_scen = deepcopy(scen)
6457
xrand = myrandom(x)
@@ -80,17 +73,17 @@ for op in [
8073
@test res1_out2_noval scen.res1
8174
end
8275
end
83-
test_scen_intact(new_scen, scen; isequal)
76+
scenario_intact && @test new_scen == scen
8477
return nothing
8578
end
8679

8780
@eval function test_correctness(
8881
ba::AbstractADType,
8982
scen::$S1in;
90-
isequal::Function,
9183
isapprox::Function,
9284
atol::Real,
9385
rtol::Real,
86+
scenario_intact::Bool,
9487
)
9588
@compat (; f, x, y, res1, contexts) = new_scen = deepcopy(scen)
9689
xrand = myrandom(x)
@@ -128,7 +121,7 @@ for op in [
128121
@test res1_out2_noval scen.res1
129122
end
130123
end
131-
test_scen_intact(new_scen, scen; isequal)
124+
scenario_intact && @test new_scen == scen
132125
return nothing
133126
end
134127

@@ -137,10 +130,10 @@ for op in [
137130
@eval function test_correctness(
138131
ba::AbstractADType,
139132
scen::$S2out;
140-
isequal::Function,
141133
isapprox::Function,
142134
atol::Real,
143135
rtol::Real,
136+
scenario_intact::Bool,
144137
)
145138
@compat (; f, x, y, res1, contexts) = new_scen = deepcopy(scen)
146139
xrand, yrand = myrandom(x), myrandom(y)
@@ -172,17 +165,17 @@ for op in [
172165
@test res1_out2_noval scen.res1
173166
end
174167
end
175-
test_scen_intact(new_scen, scen; isequal)
168+
scenario_intact && @test new_scen == scen
176169
return nothing
177170
end
178171

179172
@eval function test_correctness(
180173
ba::AbstractADType,
181174
scen::$S2in;
182-
isequal::Function,
183175
isapprox::Function,
184176
atol::Real,
185177
rtol::Real,
178+
scenario_intact::Bool,
186179
)
187180
@compat (; f, x, y, res1, contexts) = new_scen = deepcopy(scen)
188181
xrand, yrand = myrandom(x), myrandom(y)
@@ -222,18 +215,18 @@ for op in [
222215
@test res1_out2_noval scen.res1
223216
end
224217
end
225-
test_scen_intact(new_scen, scen; isequal)
218+
scenario_intact && @test new_scen == scen
226219
return nothing
227220
end
228221

229222
elseif op in [:second_derivative, :hessian]
230223
@eval function test_correctness(
231224
ba::AbstractADType,
232225
scen::$S1out;
233-
isequal::Function,
234226
isapprox::Function,
235227
atol::Real,
236228
rtol::Real,
229+
scenario_intact::Bool,
237230
)
238231
@compat (; f, x, y, res1, res2, contexts) = new_scen = deepcopy(scen)
239232
xrand = myrandom(x)
@@ -261,17 +254,17 @@ for op in [
261254
@test res2_out2_noval scen.res2
262255
end
263256
end
264-
test_scen_intact(new_scen, scen; isequal)
257+
scenario_intact && @test new_scen == scen
265258
return nothing
266259
end
267260

268261
@eval function test_correctness(
269262
ba::AbstractADType,
270263
scen::$S1in;
271-
isequal::Function,
272264
isapprox::Function,
273265
atol::Real,
274266
rtol::Real,
267+
scenario_intact::Bool,
275268
)
276269
@compat (; f, x, y, res1, res2, contexts) = new_scen = deepcopy(scen)
277270
xrand = myrandom(x)
@@ -313,18 +306,18 @@ for op in [
313306
@test res2_out2_noval scen.res2
314307
end
315308
end
316-
test_scen_intact(new_scen, scen; isequal)
309+
scenario_intact && @test new_scen == scen
317310
return nothing
318311
end
319312

320313
elseif op in [:pushforward, :pullback]
321314
@eval function test_correctness(
322315
ba::AbstractADType,
323316
scen::$S1out;
324-
isequal::Function,
325317
isapprox::Function,
326318
atol::Real,
327319
rtol::Real,
320+
scenario_intact::Bool,
328321
)
329322
@compat (; f, x, y, tang, res1, contexts) = new_scen = deepcopy(scen)
330323
xrand, tangrand = myrandom(x), myrandom(tang)
@@ -354,17 +347,17 @@ for op in [
354347
@test res1_out2_noval scen.res1
355348
end
356349
end
357-
test_scen_intact(new_scen, scen; isequal)
350+
scenario_intact && @test new_scen == scen
358351
return nothing
359352
end
360353

361354
@eval function test_correctness(
362355
ba::AbstractADType,
363356
scen::$S1in;
364-
isequal::Function,
365357
isapprox::Function,
366358
atol::Real,
367359
rtol::Real,
360+
scenario_intact::Bool,
368361
)
369362
@compat (; f, x, y, tang, res1, contexts) = new_scen = deepcopy(scen)
370363
xrand, tangrand = myrandom(x), myrandom(tang)
@@ -406,17 +399,17 @@ for op in [
406399
@test res1_out2_noval scen.res1
407400
end
408401
end
409-
test_scen_intact(new_scen, scen; isequal)
402+
scenario_intact && @test new_scen == scen
410403
return nothing
411404
end
412405

413406
@eval function test_correctness(
414407
ba::AbstractADType,
415408
scen::$S2out;
416-
isequal::Function,
417409
isapprox::Function,
418410
atol::Real,
419411
rtol::Real,
412+
scenario_intact::Bool,
420413
)
421414
@compat (; f, x, y, tang, res1, contexts) = new_scen = deepcopy(scen)
422415
xrand, yrand, tangrand = myrandom(x), myrandom(y), myrandom(tang)
@@ -456,17 +449,17 @@ for op in [
456449
@test res1_out2_noval scen.res1
457450
end
458451
end
459-
test_scen_intact(new_scen, scen; isequal)
452+
scenario_intact && @test new_scen == scen
460453
return nothing
461454
end
462455

463456
@eval function test_correctness(
464457
ba::AbstractADType,
465458
scen::$S2in;
466-
isequal::Function,
467459
isapprox::Function,
468460
atol::Real,
469461
rtol::Real,
462+
scenario_intact::Bool,
470463
)
471464
@compat (; f, x, y, tang, res1, contexts) = new_scen = deepcopy(scen)
472465
xrand, yrand, tangrand = myrandom(x), myrandom(y), myrandom(tang)
@@ -510,18 +503,18 @@ for op in [
510503
@test res1_out2_noval scen.res1
511504
end
512505
end
513-
test_scen_intact(new_scen, scen; isequal)
506+
scenario_intact && @test new_scen == scen
514507
return nothing
515508
end
516509

517510
elseif op in [:hvp]
518511
@eval function test_correctness(
519512
ba::AbstractADType,
520513
scen::$S1out;
521-
isequal::Function,
522514
isapprox::Function,
523515
atol::Real,
524516
rtol::Real,
517+
scenario_intact::Bool,
525518
)
526519
@compat (; f, x, y, tang, res2, contexts) = new_scen = deepcopy(scen)
527520
xrand, tangrand = myrandom(x), myrandom(tang)
@@ -539,17 +532,17 @@ for op in [
539532
@test res2_out2_noval scen.res2
540533
end
541534
end
542-
test_scen_intact(new_scen, scen; isequal)
535+
scenario_intact && @test new_scen == scen
543536
return nothing
544537
end
545538

546539
@eval function test_correctness(
547540
ba::AbstractADType,
548541
scen::$S1in;
549-
isequal::Function,
550542
isapprox::Function,
551543
atol::Real,
552544
rtol::Real,
545+
scenario_intact::Bool,
553546
)
554547
@compat (; f, x, y, tang, res2, contexts) = new_scen = deepcopy(scen)
555548
xrand, tangrand = myrandom(x), myrandom(tang)
@@ -575,7 +568,7 @@ for op in [
575568
@test res2_out2_noval scen.res2
576569
end
577570
end
578-
test_scen_intact(new_scen, scen; isequal)
571+
scenario_intact && @test new_scen == scen
579572
return nothing
580573
end
581574
end

DifferentiationInterfaceTest/test/weird.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,21 +46,21 @@ Random.seed!(0)
4646
test_differentiation(
4747
AutoZygote(),
4848
DIT.flux_scenarios();
49-
isequal=DIT.flux_isequal,
5049
isapprox=DIT.flux_isapprox,
5150
rtol=1e-2,
5251
atol=1e-6,
52+
scenario_intact=false,
5353
logging=LOGGING,
5454
)
5555

5656
#=
5757
test_differentiation(
5858
AutoZygote(),
5959
DIT.lux_scenarios(Random.Xoshiro(63));
60-
isequal=DIT.lux_isequal,
6160
isapprox=DIT.lux_isapprox,
6261
rtol=1.0f-2,
6362
atol=1.0f-3,
63+
scenario_intact=false,
6464
logging=LOGGING,
6565
)
6666
=#

0 commit comments

Comments
 (0)