Skip to content

Commit c982f46

Browse files
committed
Simplify Mooncake rule tests, add ChainRules rule tests
1 parent d94f146 commit c982f46

3 files changed

Lines changed: 106 additions & 219 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x)
22
(; f, backend) = dw
33
y = f(x)
4-
prep_same = DI.prepare_pullback_same_point_nokwarg(Val(true), f, backend, x, (y,))
4+
prep_same = DI.prepare_pullback_same_point_nokwarg(Val(false), f, backend, x, (y,))
55
function pullbackfunc(dy)
66
tx = DI.pullback(f, prep_same, backend, x, (dy,))
77
return (NoTangent(), only(tx))
Lines changed: 24 additions & 209 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
1-
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray,Tuple}}
1+
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Any}
2+
3+
struct MooncakeDifferentiateWithError <: Exception
4+
F::Type
5+
X::Type
6+
Y::Type
7+
function MooncakeDifferentiateWithError(::F, ::X, ::Y) where {F,X,Y}
8+
return new(F, X, Y)
9+
end
10+
end
11+
12+
function Base.showerror(io::IO, e::MooncakeDifferentiateWithError)
13+
return print(
14+
io,
15+
"MooncakeDifferentiateWithError: For the function type $(e.F) and input type $(e.X), the output type $(e.Y) is currently not supported.",
16+
)
17+
end
218

3-
# nested vectors (eg. [[1.0]]), Tuples (eg. ((1.0,),)) or similar (eg. [(1.0,)]) primal types are not supported by DI yet !
4-
# This is because basis construction (DI.basis) does not have overloads for these types.
519
# For details, refer commented out test cases to see where the pullback creation fails.
6-
function Mooncake.rrule!!(
7-
dw::CoDual{<:DI.DifferentiateWith}, x::Union{CoDual{<:Number},CoDual{<:Tuple}}
8-
)
20+
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number})
921
primal_func = primal(dw)
1022
primal_x = primal(x)
1123
(; f, backend) = primal_func
@@ -25,31 +37,12 @@ function Mooncake.rrule!!(
2537
return NoRData(), rdata(only(tx))
2638
end
2739

28-
# output is a Tuple, NTuple
29-
function pullback_tuple!!(dy::Tuple)
30-
tx = DI.pullback(f, backend, primal_x, (dy,))
31-
@assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
32-
return NoRData(), rdata(only(tx))
33-
end
34-
35-
# inputs are non Differentiable
36-
function pullback_nodiff!!(dy::NoRData)
37-
@assert tangent_type(typeof(primal(x))) <: NoTangent
38-
return NoRData(), dy
39-
end
40-
41-
pullback = if tangent_type(typeof(primal(x))) <: NoTangent
42-
pullback_nodiff!!
43-
elseif primal(y) isa Number
40+
pullback = if primal(y) isa Number
4441
pullback_scalar!!
45-
elseif primal(y) <: AbstractArray
42+
elseif primal(y) isa AbstractArray
4643
pullback_array!!
47-
elseif primal(y) <: Tuple
48-
pullback_tuple!!
4944
else
50-
error(
51-
"For the function type $(typeof(primal_func)) and input type $(typeof(primal_x)), the primal type $(typeof(primal(y))) is currently not supported.",
52-
)
45+
throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y)))
5346
end
5447

5548
return y, pullback
@@ -78,191 +71,13 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra
7871
return NoRData(), NoRData()
7972
end
8073

81-
# output is a Tuple, NTuple
82-
function pullback_tuple!!(dy::Tuple)
83-
tx = DI.pullback(f, backend, primal_x, (dy,))
84-
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
85-
fdata_arg .+= only(tx)
86-
return NoRData(), NoRData()
87-
end
88-
89-
# inputs are non Differentiable
90-
function pullback_nodiff!!(dy::NoRData)
91-
@assert tangent_type(typeof(primal(x))) <: Vector{NoTangent}
92-
return NoRData(), dy
93-
end
94-
95-
pullback = if tangent_type(typeof(primal(x))) <: Vector{NoTangent}
96-
pullback_nodiff!!
97-
elseif primal(y) isa Number
74+
pullback = if primal(y) isa Number
9875
pullback_scalar!!
99-
elseif primal(y) <: AbstractArray
76+
elseif primal(y) isa AbstractArray
10077
pullback_array!!
101-
elseif primal(y) <: Tuple
102-
pullback_tuple!!
10378
else
104-
error(
105-
"For the function type $(typeof(primal_func)) and input type $(typeof(primal_x)), the primal type $(typeof(primal(y))) is currently not supported.",
106-
)
79+
throw(MooncakeDifferentiateWithError(primal_func, primal_x, primal(y)))
10780
end
10881

10982
return y, pullback
11083
end
111-
112-
function Mooncake.generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:diffwith})
113-
return Any[], Any[]
114-
end
115-
116-
function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diffwith})
117-
test_cases = reduce(
118-
vcat,
119-
map([(x) -> DI.DifferentiateWith(x, DI.AutoFiniteDiff())]) do F
120-
map([Float64, Float32]) do P
121-
return Any[
122-
# (false, :none, nothing, F(identity), ((1.0,),)), # (DI.basis fails for this, correct it!)
123-
# (false, :none, nothing, F(identity), [[1.0]]), # (DI.basis fails for this, correct it!)
124-
(false, :stability_and_allocs, nothing, F(cosh), P(0.3)),
125-
(false, :stability_and_allocs, nothing, F(sinh), P(0.3)),
126-
(
127-
false,
128-
:stability_and_allocs,
129-
nothing,
130-
F(Base.FastMath.exp10_fast),
131-
P(0.5),
132-
),
133-
(
134-
false,
135-
:stability_and_allocs,
136-
nothing,
137-
F(Base.FastMath.exp2_fast),
138-
P(0.5),
139-
),
140-
(
141-
false,
142-
:stability_and_allocs,
143-
nothing,
144-
F(Base.FastMath.exp_fast),
145-
P(5.0),
146-
),
147-
(false, :stability, nothing, F(copy), rand(Int32, 5)),
148-
]
149-
end
150-
end...,
151-
)
152-
153-
map([(x) -> DI.DifferentiateWith(x, DI.AutoFiniteDiff())]) do F
154-
push!(
155-
test_cases,
156-
Any[
157-
(false, :stability, nothing, copy, randn(5, 4)),
158-
(
159-
# Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent.
160-
false,
161-
:stability,
162-
nothing,
163-
F(x -> +(x...)),
164-
randn(33),
165-
),
166-
(
167-
false,
168-
:stability,
169-
nothing,
170-
(F(
171-
function (x)
172-
rx = Ref(x)
173-
return Base.pointerref(
174-
Base.bitcast(Ptr{Float64}, pointer_from_objref(rx)), 1, 1
175-
)
176-
end,
177-
)),
178-
5.0,
179-
),
180-
# (false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[(1.0,)]), # (DI.basis fails for this, correct it!)
181-
(
182-
false,
183-
:stability_and_allocs,
184-
nothing,
185-
F(Mooncake.IntrinsicsWrappers.ctlz_int),
186-
5,
187-
),
188-
(
189-
false,
190-
:stability_and_allocs,
191-
nothing,
192-
F(Mooncake.IntrinsicsWrappers.ctpop_int),
193-
5,
194-
),
195-
(
196-
false,
197-
:stability_and_allocs,
198-
nothing,
199-
F(Mooncake.IntrinsicsWrappers.cttz_int),
200-
5,
201-
),
202-
(
203-
false,
204-
:stability_and_allocs,
205-
nothing,
206-
F(Mooncake.IntrinsicsWrappers.abs_float),
207-
5.0f0,
208-
),
209-
(false, :stability_and_allocs, nothing, F(deepcopy), 5.0),
210-
(false, :stability, nothing, F(deepcopy), randn(5)),
211-
(false, :stability_and_allocs, nothing, F(sin), 1.1),
212-
(false, :stability_and_allocs, nothing, F(sin), 1.0f1),
213-
(false, :stability_and_allocs, nothing, F(cos), 1.1),
214-
(false, :stability_and_allocs, nothing, F(cos), 1.0f1),
215-
(false, :stability_and_allocs, nothing, F(exp), 1.1),
216-
(false, :stability_and_allocs, nothing, F(exp), 1.0f1),
217-
]...,
218-
)
219-
end
220-
221-
map([(x) -> DI.DifferentiateWith(x, DI.AutoForwardDiff())]) do F
222-
map([Float64, Float32]) do P
223-
push!(
224-
test_cases,
225-
Any[
226-
(
227-
false,
228-
:stability_and_allocs,
229-
nothing,
230-
F(Base.FastMath.sincos),
231-
P(3.0),
232-
),
233-
(false, :none, nothing, F(Mooncake.__vec_to_tuple), [P(1.0)]),
234-
]...,
235-
)
236-
end
237-
238-
push!(
239-
test_cases,
240-
Any[
241-
(
242-
false,
243-
:stability_and_allocs,
244-
nothing,
245-
F(Mooncake.IntrinsicsWrappers.ctlz_int),
246-
5,
247-
),
248-
(
249-
false,
250-
:stability_and_allocs,
251-
nothing,
252-
F(Mooncake.IntrinsicsWrappers.ctpop_int),
253-
5,
254-
),
255-
(
256-
false,
257-
:stability_and_allocs,
258-
nothing,
259-
F(Mooncake.IntrinsicsWrappers.cttz_int),
260-
5,
261-
),
262-
]...,
263-
)
264-
end
265-
266-
memory = Any[]
267-
return test_cases, memory
268-
end
Lines changed: 81 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,41 @@
11
using Pkg
2-
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"])
2+
Pkg.add(["ChainRulesTestUtils", "FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"])
33

4+
using ChainRulesTestUtils: ChainRulesTestUtils
45
using DifferentiationInterface, DifferentiationInterfaceTest
56
import DifferentiationInterfaceTest as DIT
67
using FiniteDiff: FiniteDiff
78
using ForwardDiff: ForwardDiff
89
using Zygote: Zygote
910
using Mooncake: Mooncake
10-
using StableRNGs, Test
11+
using StableRNGs
12+
using Test
1113

1214
LOGGING = get(ENV, "CI", "false") == "false"
1315

16+
struct ADBreaker{F}
17+
f::F
18+
end
19+
20+
function (adb::ADBreaker)(x::Number)
21+
copyto!(Float64[0], x) # break ForwardDiff and Zygote
22+
return adb.f(x)
23+
end
24+
25+
function (adb::ADBreaker)(x::AbstractArray)
26+
copyto!(similar(x, Float64), x) # break ForwardDiff and Zygote
27+
return adb.f(x)
28+
end
29+
1430
function differentiatewith_scenarios()
15-
bad_scens = # these closurified scenarios have mutation and type constraints
16-
filter(
17-
DIT.default_scenarios(; include_normal=false, include_closurified=true)
18-
) do scen
19-
DIT.function_place(scen) == :out
20-
end
31+
outofplace_scens = filter(DIT.default_scenarios()) do scen
32+
DIT.function_place(scen) == :out
33+
end
34+
# with bad_scens, everything would break
35+
bad_scens = map(outofplace_scens) do scen
36+
DIT.change_function(scen, ADBreaker(scen.f))
37+
end
38+
# with good_scens, everything is fixed
2139
good_scens = map(bad_scens) do scen
2240
DIT.change_function(scen, DifferentiateWith(scen.f, AutoFiniteDiff()))
2341
end
@@ -29,8 +47,62 @@ test_differentiation(
2947
differentiatewith_scenarios();
3048
excluded=SECOND_ORDER,
3149
logging=LOGGING,
50+
testset_name="DI tests",
3251
)
3352

53+
@testset "ChainRules tests" begin
54+
@testset for scen in filter(differentiatewith_scenarios()) do scen
55+
DIT.operator(scen) == :pullback
56+
end
57+
ChainRulesTestUtils.test_rrule(scen.f, scen.x; rtol=1e-4)
58+
end
59+
end;
60+
3461
@testset "Mooncake tests" begin
35-
Mooncake.TestUtils.run_rrule!!_test_cases(StableRNG, Val(:diffwith))
62+
@testset for scen in filter(differentiatewith_scenarios()) do scen
63+
DIT.operator(scen) == :pullback
64+
end
65+
Mooncake.TestUtils.test_rule(StableRNG(0), scen.f, scen.x; is_primitive=true)
66+
end
67+
end;
68+
69+
@testset "Mooncake errors" begin
70+
MooncakeDifferentiateWithError =
71+
Base.get_extension(
72+
DifferentiationInterface, :DifferentiationInterfaceMooncakeExt
73+
).MooncakeDifferentiateWithError
74+
75+
e = MooncakeDifferentiateWithError(identity, 1.0, 2.0)
76+
@test sprint(showerror, e) ==
77+
"MooncakeDifferentiateWithError: For the function type typeof(identity) and input type Float64, the output type Float64 is currently not supported."
78+
79+
f_num2tup(x::Number) = (x,)
80+
f_vec2tup(x::Vector) = (first(x),)
81+
f_tup2num(x::Tuple{<:Number}) = only(x)
82+
f_tup2vec(x::Tuple{<:Number}) = [only(x)]
83+
84+
@test_throws MooncakeDifferentiateWithError pullback(
85+
DifferentiateWith(f_num2tup, AutoFiniteDiff()),
86+
AutoMooncake(; config=nothing),
87+
1.0,
88+
((2.0,),),
89+
)
90+
@test_throws MooncakeDifferentiateWithError pullback(
91+
DifferentiateWith(f_vec2tup, AutoFiniteDiff()),
92+
AutoMooncake(; config=nothing),
93+
[1.0],
94+
((2.0,),),
95+
)
96+
@test_throws MethodError pullback(
97+
DifferentiateWith(f_tup2num, AutoFiniteDiff()),
98+
AutoMooncake(; config=nothing),
99+
(1.0,),
100+
(2.0,),
101+
)
102+
@test_throws MethodError pullback(
103+
DifferentiateWith(f_tup2vec, AutoFiniteDiff()),
104+
AutoMooncake(; config=nothing),
105+
(1.0,),
106+
([2.0],),
107+
)
36108
end

0 commit comments

Comments
 (0)