forked from JuliaDiff/DifferentiationInterface.jl
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.jl
More file actions
36 lines (31 loc) · 1.06 KB
/
test.jl
File metadata and controls
36 lines (31 loc) · 1.06 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
using Pkg
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"])
using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Zygote: Zygote
using Mooncake: Mooncake
using StableRNGs, Test
# Enable test logging unless the `CI` environment variable holds something other than
# "false"; an unset `CI` is treated as "false", so logging is on by default.
LOGGING = let ci_flag = get(ENV, "CI", "false")
    ci_flag == "false"
end
"""
    differentiatewith_scenarios()

Return the out-of-place closurified default scenarios from DifferentiationInterfaceTest,
each with its function replaced by `DifferentiateWith(f, AutoFiniteDiff())`.
"""
function differentiatewith_scenarios()
    # Start from the closurified scenarios only; the normal ones are not requested.
    closurified = DIT.default_scenarios(;
        include_normal=false, include_closurified=true
    )
    # These closurified scenarios have mutation and type constraints, so only the
    # out-of-place (`:out`) variants are kept.
    out_of_place = filter(s -> DIT.function_place(s) == :out, closurified)
    # Wrapping `f` in `DifferentiateWith(f, AutoFiniteDiff())` presumably routes its
    # differentiation through FiniteDiff so the backends under test can handle these
    # scenarios; NOTE(review): confirm against the `DifferentiateWith` docs.
    wrap_with_finitediff(s) =
        DIT.change_function(s, DifferentiateWith(s.f, AutoFiniteDiff()))
    return map(wrap_with_finitediff, out_of_place)
end
# Check ForwardDiff, Zygote and Mooncake on the DifferentiateWith-wrapped scenarios,
# excluding the operators listed in `SECOND_ORDER`; verbosity follows `LOGGING`.
let backends = [AutoForwardDiff(), AutoZygote(), AutoMooncake(; config=nothing)]
    wrapped_scenarios = differentiatewith_scenarios()
    test_differentiation(
        backends, wrapped_scenarios; excluded=SECOND_ORDER, logging=LOGGING
    )
end
@testset "Mooncake tests" begin
    # Hand Mooncake's rrule!! test runner the `StableRNG` constructor and the
    # `:diffwith` selector. NOTE(review): `Val(:diffwith)` presumably picks the
    # DifferentiateWith-specific cases defined in the Mooncake extension — confirm.
    diffwith_cases = Val(:diffwith)
    Mooncake.TestUtils.run_rrule!!_test_cases(StableRNG, diffwith_cases)
end