forked from JuliaDiff/DifferentiationInterface.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.jl
More file actions
106 lines (93 loc) · 3.18 KB
/
test.jl
File metadata and controls
106 lines (93 loc) · 3.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
using Pkg
# Install the AD backends and test utilities exercised below into the active environment.
Pkg.add(["ChainRulesTestUtils", "FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"])
using ChainRulesTestUtils: ChainRulesTestUtils
using DifferentiationInterface, DifferentiationInterfaceTest
# Short alias used to qualify DifferentiationInterfaceTest helpers (e.g. `DIT.default_scenarios`).
import DifferentiationInterfaceTest as DIT
# Backend packages are loaded only by module name; presumably this is what activates the
# matching DifferentiationInterface package extensions (the Mooncake one is looked up
# with `Base.get_extension` further down) — confirm against the package's extension setup.
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Zygote: Zygote
using Mooncake: Mooncake
using StableRNGs
using Test
# Passed as `logging` to `test_differentiation`: `true` when the `CI` environment
# variable is unset or equal to "false", `false` for any other value.
LOGGING = get(ENV, "CI", "false") == "false"
"""
    ADBreaker(f)

Callable wrapper that returns `f(x)`, but first copies `x` into a freshly allocated
`Float64` buffer that is then thrown away. That copy exists only to make
differentiation through the wrapper fail for ForwardDiff (its dual numbers cannot be
stored in a `Float64` buffer) and Zygote (which rejects array mutation).

Both scalar (`Number`) and array (`AbstractArray`) inputs are supported.
"""
struct ADBreaker{F}
    f::F
end

function (breaker::ADBreaker)(x::Number)
    scratch = zeros(Float64, 1)
    copyto!(scratch, x)  # break ForwardDiff and Zygote; result is intentionally unused
    return breaker.f(x)
end

function (breaker::ADBreaker)(x::AbstractArray)
    scratch = similar(x, Float64)
    copyto!(scratch, x)  # break ForwardDiff and Zygote; result is intentionally unused
    return breaker.f(x)
end
"""
    differentiatewith_scenarios()

Return the out-of-place scenarios from `DIT.default_scenarios()`, with every function
`f` replaced by `DifferentiateWith(ADBreaker(f), AutoFiniteDiff())`.

Wrapping in `ADBreaker` alone would make differentiation break. The outer
`DifferentiateWith(..., AutoFiniteDiff())` is what is expected to fix it, so these
scenarios only pass through that fallback path.
"""
function differentiatewith_scenarios()
    # keep only scenarios whose function is out-of-place
    is_outofplace(scen) = DIT.function_place(scen) == :out
    # step 1: guard the function so direct AD through it fails
    break_ad(scen) = DIT.change_function(scen, ADBreaker(scen.f))
    # step 2: delegate differentiation of the guarded function to FiniteDiff
    fix_with_finitediff(scen) =
        DIT.change_function(scen, DifferentiateWith(scen.f, AutoFiniteDiff()))

    outofplace_scens = filter(is_outofplace, DIT.default_scenarios())
    broken_scens = map(break_ad, outofplace_scens)
    return map(fix_with_finitediff, broken_scens)
end
# Run the DifferentiationInterfaceTest suite on the `DifferentiateWith` scenarios with
# several outer backends, skipping the operators listed in `SECOND_ORDER`.
let backends = [AutoForwardDiff(), AutoZygote(), AutoMooncake(; config=nothing)]
    test_differentiation(
        backends,
        differentiatewith_scenarios();
        excluded=SECOND_ORDER,
        logging=LOGGING,
        testset_name="DI tests",
    )
end
# Check the ChainRules `rrule` of each `DifferentiateWith`-wrapped function on the
# pullback scenarios. The loose `rtol` presumably accounts for the finite-difference
# inner backend.
@testset "ChainRules tests" begin
    pullback_scens = filter(s -> DIT.operator(s) == :pullback, differentiatewith_scenarios())
    @testset for scen in pullback_scens
        ChainRulesTestUtils.test_rrule(scen.f, scen.x; rtol=1e-4)
    end
end;
# Run Mooncake's rule test harness on each pullback scenario, declaring the
# `DifferentiateWith`-wrapped function a primitive (`is_primitive=true`).
@testset "Mooncake tests" begin
    pullback_scens = filter(s -> DIT.operator(s) == :pullback, differentiatewith_scenarios())
    @testset for scen in pullback_scens
        rng = StableRNG(0)  # fresh, identically seeded RNG for every scenario
        Mooncake.TestUtils.test_rule(rng, scen.f, scen.x; is_primitive=true)
    end
end;
# Check how the Mooncake extension reacts to `DifferentiateWith` functions whose
# input/output types it does not handle.
@testset "Mooncake errors" begin
    mooncake_ext = Base.get_extension(
        DifferentiationInterface, :DifferentiationInterfaceMooncakeExt
    )
    MooncakeDifferentiateWithError = mooncake_ext.MooncakeDifferentiateWithError

    # The message names the function type, the input type and the output type.
    err = MooncakeDifferentiateWithError(identity, 1.0, 2.0)
    @test sprint(showerror, err) ==
        "MooncakeDifferentiateWithError: For the function type typeof(identity) and input type Float64, the output type Float64 is currently not supported."

    f_num2tup(x::Number) = (x,)
    f_vec2tup(x::Vector) = (first(x),)
    f_tup2num(x::Tuple{<:Number}) = only(x)
    f_tup2vec(x::Tuple{<:Number}) = [only(x)]

    backend = AutoMooncake(; config=nothing)
    # Each row: (expected exception type, function, input x, tuple of output tangents)
    cases = (
        (MooncakeDifferentiateWithError, f_num2tup, 1.0, ((2.0,),)),
        (MooncakeDifferentiateWithError, f_vec2tup, [1.0], ((2.0,),)),
        (MethodError, f_tup2num, (1.0,), (2.0,)),
        (MethodError, f_tup2vec, (1.0,), ([2.0],)),
    )
    for (expected, f, x, ty) in cases
        @test_throws expected pullback(
            DifferentiateWith(f, AutoFiniteDiff()), backend, x, ty
        )
    end
end