|
1 | 1 | using Pkg |
2 | | -Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote"]) |
| 2 | +Pkg.add(["ChainRulesTestUtils", "FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"]) |
3 | 3 |
|
| 4 | +using ChainRulesTestUtils: ChainRulesTestUtils |
4 | 5 | using DifferentiationInterface, DifferentiationInterfaceTest |
5 | 6 | import DifferentiationInterfaceTest as DIT |
6 | 7 | using FiniteDiff: FiniteDiff |
7 | 8 | using ForwardDiff: ForwardDiff |
8 | 9 | using Zygote: Zygote |
| 10 | +using Mooncake: Mooncake |
| 11 | +using StableRNGs |
9 | 12 | using Test |
10 | 13 |
|
11 | 14 | LOGGING = get(ENV, "CI", "false") == "false" |
12 | 15 |
|
| 16 | +struct ADBreaker{F} |
| 17 | + f::F |
| 18 | +end |
| 19 | + |
| 20 | +function (adb::ADBreaker)(x::Number) |
| 21 | + copyto!(Float64[0], x) # break ForwardDiff and Zygote |
| 22 | + return adb.f(x) |
| 23 | +end |
| 24 | + |
| 25 | +function (adb::ADBreaker)(x::AbstractArray) |
| 26 | + copyto!(similar(x, Float64), x) # break ForwardDiff and Zygote |
| 27 | + return adb.f(x) |
| 28 | +end |
| 29 | + |
13 | 30 | function differentiatewith_scenarios() |
14 | | - bad_scens = # these closurified scenarios have mutation and type constraints |
15 | | - filter(default_scenarios(; include_normal=false, include_closurified=true)) do scen |
16 | | - DIT.function_place(scen) == :out |
17 | | - end |
| 31 | + outofplace_scens = filter(DIT.default_scenarios()) do scen |
| 32 | + DIT.function_place(scen) == :out |
| 33 | + end |
| 34 | + # with bad_scens, everything would break |
| 35 | + bad_scens = map(outofplace_scens) do scen |
| 36 | + DIT.change_function(scen, ADBreaker(scen.f)) |
| 37 | + end |
| 38 | + # with good_scens, everything is fixed |
18 | 39 | good_scens = map(bad_scens) do scen |
19 | 40 | DIT.change_function(scen, DifferentiateWith(scen.f, AutoFiniteDiff())) |
20 | 41 | end |
21 | 42 | return good_scens |
22 | 43 | end |
23 | 44 |
|
24 | 45 | test_differentiation( |
25 | | - [AutoForwardDiff(), AutoZygote()], |
| 46 | + [AutoForwardDiff(), AutoZygote(), AutoMooncake(; config=nothing)], |
26 | 47 | differentiatewith_scenarios(); |
27 | 48 | excluded=SECOND_ORDER, |
28 | 49 | logging=LOGGING, |
| 50 | + testset_name="DI tests", |
29 | 51 | ) |
| 52 | + |
| 53 | +@testset "ChainRules tests" begin |
| 54 | + @testset for scen in filter(differentiatewith_scenarios()) do scen |
| 55 | + DIT.operator(scen) == :pullback |
| 56 | + end |
| 57 | + ChainRulesTestUtils.test_rrule(scen.f, scen.x; rtol=1e-4) |
| 58 | + end |
| 59 | +end; |
| 60 | + |
| 61 | +@testset "Mooncake tests" begin |
| 62 | + @testset for scen in filter(differentiatewith_scenarios()) do scen |
| 63 | + DIT.operator(scen) == :pullback |
| 64 | + end |
| 65 | + Mooncake.TestUtils.test_rule(StableRNG(0), scen.f, scen.x; is_primitive=true) |
| 66 | + end |
| 67 | +end; |
| 68 | + |
| 69 | +@testset "Mooncake errors" begin |
| 70 | + MooncakeDifferentiateWithError = |
| 71 | + Base.get_extension(DifferentiationInterface, :DifferentiationInterfaceMooncakeExt).MooncakeDifferentiateWithError |
| 72 | + |
| 73 | + e = MooncakeDifferentiateWithError(identity, 1.0, 2.0) |
| 74 | + @test sprint(showerror, e) == |
| 75 | + "MooncakeDifferentiateWithError: For the function type typeof(identity) and input type Float64, the output type Float64 is currently not supported." |
| 76 | + |
| 77 | + f_num2tup(x::Number) = (x,) |
| 78 | + f_vec2tup(x::Vector) = (first(x),) |
| 79 | + f_tup2num(x::Tuple{<:Number}) = only(x) |
| 80 | + f_tup2vec(x::Tuple{<:Number}) = [only(x)] |
| 81 | + |
| 82 | + @test_throws MooncakeDifferentiateWithError pullback( |
| 83 | + DifferentiateWith(f_num2tup, AutoFiniteDiff()), |
| 84 | + AutoMooncake(; config=nothing), |
| 85 | + 1.0, |
| 86 | + ((2.0,),), |
| 87 | + ) |
| 88 | + @test_throws MooncakeDifferentiateWithError pullback( |
| 89 | + DifferentiateWith(f_vec2tup, AutoFiniteDiff()), |
| 90 | + AutoMooncake(; config=nothing), |
| 91 | + [1.0], |
| 92 | + ((2.0,),), |
| 93 | + ) |
| 94 | + @test_throws MethodError pullback( |
| 95 | + DifferentiateWith(f_tup2num, AutoFiniteDiff()), |
| 96 | + AutoMooncake(; config=nothing), |
| 97 | + (1.0,), |
| 98 | + (2.0,), |
| 99 | + ) |
| 100 | + @test_throws MethodError pullback( |
| 101 | + DifferentiateWith(f_tup2vec, AutoFiniteDiff()), |
| 102 | + AutoMooncake(; config=nothing), |
| 103 | + (1.0,), |
| 104 | + ([2.0],), |
| 105 | + ) |
| 106 | +end |
0 commit comments