Skip to content

Commit ec4b75d

Browse files
tests, inc primal handling
1 parent b4fe0f8 commit ec4b75d

4 files changed

Lines changed: 130 additions & 103 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ JET = "0.9"
7272
JLArrays = "0.2.0"
7373
JuliaFormatter = "1,2"
7474
LinearAlgebra = "1"
75-
Mooncake = "0.4.88"
75+
Mooncake = "0.4.121"
7676
Pkg = "1"
7777
PolyesterForwardDiff = "0.1.2"
7878
Random = "1"

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@ using Mooncake:
1313
value_and_pullback!!,
1414
zero_tangent,
1515
rdata_type,
16+
fdata,
17+
rdata,
18+
tangent_type,
19+
NoTangent,
1620
@is_primitive,
1721
zero_fcodual,
1822
MinimalCtx,
1923
NoRData,
20-
fdata,
2124
primal
2225

2326
DI.check_available(::AutoMooncake) = true

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Lines changed: 124 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray,Tuple}}
22

3+
# nested vectors, similar are not supported
34
function Mooncake.rrule!!(
45
dw::CoDual{<:DI.DifferentiateWith}, x::Union{CoDual{<:Number},CoDual{<:Tuple}}
56
)
@@ -10,31 +11,41 @@ function Mooncake.rrule!!(
1011

1112
# output is a vector, so we need to use the vector pullback
1213
function pullback_array!!(dy::NoRData)
13-
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
14-
@assert only(tx) isa rdata_type(typeof(primal_x))
15-
return NoRData(), only(tx)
14+
tx = DI.pullback(f, backend, primal_x, (y.dx,))
15+
@assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
16+
return NoRData(), rdata(only(tx))
1617
end
1718

1819
# output is a scalar, so we can use the scalar pullback
1920
function pullback_scalar!!(dy::Number)
2021
tx = DI.pullback(f, backend, primal_x, (dy,))
21-
@assert only(tx) isa rdata_type(typeof(primal_x))
22-
return NoRData(), only(tx)
22+
@assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
23+
return NoRData(), rdata(only(tx))
2324
end
2425

2526
# output is a Tuple, NTuple
2627
function pullback_tuple!!(dy::Tuple)
2728
tx = DI.pullback(f, backend, primal_x, (dy,))
28-
@assert only(tx) isa rdata_type(typeof(primal_x))
29-
return NoRData(), only(tx)
29+
@assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
30+
return NoRData(), rdata(only(tx))
3031
end
3132

32-
pullback = if typeof(primal(y)) <: Number
33+
# inputs are non Differentiable
34+
function pullback_nodiff!!(dy::NoRData)
35+
@assert tangent_type(typeof(primal(x))) <: NoTangent
36+
return NoRData(), dy
37+
end
38+
39+
pullback = if tangent_type(typeof(primal(x))) <: NoTangent
40+
pullback_nodiff!!
41+
elseif typeof(primal(y)) <: Number
3342
pullback_scalar!!
3443
elseif typeof(primal(y)) <: Array
3544
pullback_array!!
36-
else
45+
elseif typeof(primal(y)) <: Tuple
3746
pullback_tuple!!
47+
else
48+
error("$(typeof(primal(y))) primal type currently not supported.")
3849
end
3950

4051
return y, pullback
@@ -43,40 +54,50 @@ end
4354
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray})
4455
primal_func = primal(dw)
4556
primal_x = primal(x)
46-
fdata_arg = fdata(x.dx)
57+
fdata_arg = x.dx
4758
(; f, backend) = primal_func
4859
y = zero_fcodual(f(primal_x))
4960

5061
# output is a vector, so we need to use the vector pullback
5162
function pullback_array!!(dy::NoRData)
52-
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
53-
@assert first(only(tx)) isa rdata_type(typeof(first(primal_x)))
63+
tx = DI.pullback(f, backend, primal_x, (y.dx,))
64+
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
5465
fdata_arg .+= only(tx)
5566
return NoRData(), dy
5667
end
5768

5869
# output is a scalar, so we can use the scalar pullback
5970
function pullback_scalar!!(dy::Number)
6071
tx = DI.pullback(f, backend, primal_x, (dy,))
61-
@assert first(only(tx)) isa rdata_type(typeof(first(primal_x)))
72+
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
6273
fdata_arg .+= only(tx)
6374
return NoRData(), NoRData()
6475
end
6576

6677
# output is a Tuple, NTuple
6778
function pullback_tuple!!(dy::Tuple)
6879
tx = DI.pullback(f, backend, primal_x, (dy,))
69-
@assert first(only(tx)) isa rdata_type(typeof(first(primal_x)))
80+
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
7081
fdata_arg .+= only(tx)
7182
return NoRData(), NoRData()
7283
end
7384

74-
pullback = if typeof(primal(y)) <: Number
85+
# inputs are non Differentiable
86+
function pullback_nodiff!!(dy::NoRData)
87+
@assert tangent_type(typeof(primal(x))) <: Vector{NoTangent}
88+
return NoRData(), dy
89+
end
90+
91+
pullback = if tangent_type(typeof(primal(x))) <: Vector{NoTangent}
92+
pullback_nodiff!!
93+
elseif typeof(primal(y)) <: Number
7594
pullback_scalar!!
76-
elseif typeof(primal(y)) <: Array
95+
elseif typeof(primal(y)) <: AbstractArray
7796
pullback_array!!
78-
else
97+
elseif typeof(primal(y)) <: Tuple
7998
pullback_tuple!!
99+
else
100+
error("$(typeof(primal(y))) primal type currently not supported.")
80101
end
81102

82103
return y, pullback
@@ -89,90 +110,93 @@ end
89110
function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diffwith})
90111
test_cases = reduce(
91112
vcat,
92-
map([Float64, Float32]) do P
93-
return Any[
94-
(false, :stability_and_allocs, nothing, cosh, P(0.3)),
95-
(false, :stability_and_allocs, nothing, sinh, P(0.3)),
96-
(false, :stability_and_allocs, nothing, Base.FastMath.exp10_fast, P(0.5)),
97-
(false, :stability_and_allocs, nothing, Base.FastMath.exp2_fast, P(0.5)),
98-
(false, :stability_and_allocs, nothing, Base.FastMath.exp_fast, P(5.0)),
99-
(false, :stability_and_allocs, nothing, Base.FastMath.sincos, P(3.0)),
100-
]
101-
end,
113+
map([(x) -> DI.DifferentiateWith(x, DI.AutoFiniteDiff())]) do F
114+
map([Float64, Float32]) do P
115+
return Any[
116+
(false, :stability, nothing, F(cosh), P(0.3)),
117+
(false, :stability, nothing, F(sinh), P(0.3)),
118+
(false, :stability, nothing, F(Base.FastMath.exp10_fast), P(0.5)),
119+
(false, :stability, nothing, F(Base.FastMath.exp2_fast), P(0.5)),
120+
(false, :stability, nothing, F(Base.FastMath.exp_fast), P(5.0)),
121+
(false, :none, nothing, F(copy), rand(Int32, 5)),
122+
]
123+
end
124+
end...,
102125
)
103-
push!(test_cases, (false, :stability, nothing, copy, randn(5, 4)))
104-
push!(test_cases, (
105-
# Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent.
106-
false,
107-
:none,
108-
nothing,
109-
x -> +(x...),
110-
randn(33),
111-
))
112-
push!(
113-
test_cases,
114-
(
115-
false,
116-
:none,
117-
nothing,
118-
(
119-
function (x)
120-
rx = Ref(x)
121-
return Base.pointerref(
122-
Base.bitcast(Ptr{Float64}, pointer_from_objref(rx)), 1, 1
123-
)
124-
end
125-
),
126-
5.0,
127-
),
128-
)
129-
push!(
130-
test_cases,
131-
(
132-
false,
133-
:none,
134-
nothing,
135-
x -> (pointerset(pointer(x), UInt8(3), 2, 1); x),
136-
rand(UInt8, 5),
137-
),
138-
)
139-
push!(test_cases, (false, :none, nothing, Mooncake.__vec_to_tuple, [1.0]))
140-
push!(test_cases, (false, :none, nothing, Mooncake.__vec_to_tuple, Any[1.0]))
141-
push!(test_cases, (false, :none, nothing, Mooncake.__vec_to_tuple, Any[[1.0]]))
142-
push!(test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.ctlz_int, 5))
143-
push!(
144-
test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.ctpop_int, 5)
145-
)
146-
push!(test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.cttz_int, 5))
147-
push!(
148-
test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.abs_float, 5.0)
149-
)
150-
push!(
151-
test_cases,
152-
(false, :stability, nothing, Mooncake.IntrinsicsWrappers.abs_float, 5.0f0),
153-
)
154-
push!(test_cases, (false, :stability, nothing, deepcopy, 5.0))
155-
push!(test_cases, (false, :stability, nothing, deepcopy, randn(5)))
156-
push!(test_cases, (false, :stability_and_allocs, nothing, sin, 1.1))
157-
push!(test_cases, (false, :stability_and_allocs, nothing, sin, 1.0f1))
158-
push!(test_cases, (false, :stability_and_allocs, nothing, cos, 1.1))
159-
push!(test_cases, (false, :stability_and_allocs, nothing, cos, 1.0f1))
160-
push!(test_cases, (false, :stability_and_allocs, nothing, exp, 1.1))
161-
push!(test_cases, (false, :stability_and_allocs, nothing, exp, 1.0f1))
162-
163-
# additional_test_set = Mooncake.tangent_test_cases()
164-
# function is_valid(f)
165-
# try
166-
# isa(f([1.0, 2.0]), Union{<:Number,<:AbstractArray})
167-
# catch
168-
# false
169-
# end
170-
# end
171-
# for test in additional_test_set
172-
# if is_valid(test[2])
173-
# push!(test_cases, test)
174-
# end
175-
# end
126+
127+
map([(x) -> DI.DifferentiateWith(x, DI.AutoZygote())]) do F
128+
map([Float64, Float32]) do P
129+
push!(
130+
test_cases,
131+
Any[
132+
(false, :stability, nothing, F(Base.FastMath.sincos), P(3.0)),
133+
(false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[P(1.0)]),
134+
]...,
135+
)
136+
end
137+
end
138+
139+
map([(x) -> DI.DifferentiateWith(x, DI.AutoZygote())]) do F
140+
push!(
141+
test_cases,
142+
Any[
143+
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctlz_int), 5),
144+
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctpop_int), 5),
145+
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.cttz_int), 5),
146+
]...,
147+
)
148+
end
149+
150+
map([(x) -> DI.DifferentiateWith(x, DI.AutoFiniteDiff())]) do F
151+
push!(
152+
test_cases,
153+
Any[
154+
(false, :stability, nothing, copy, randn(5, 4)),
155+
(
156+
# Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent.
157+
false,
158+
:none,
159+
nothing,
160+
F(x -> +(x...)),
161+
randn(33),
162+
),
163+
(
164+
false,
165+
:none,
166+
nothing,
167+
(F(
168+
function (x)
169+
rx = Ref(x)
170+
return Base.pointerref(
171+
Base.bitcast(Ptr{Float64}, pointer_from_objref(rx)), 1, 1
172+
)
173+
end,
174+
)),
175+
5.0,
176+
),
177+
(false, :none, nothing, F(Mooncake.__vec_to_tuple), [1.0]),
178+
# (false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[(1.0,)]), DI.basis fails for this, correct it!
179+
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctlz_int), 5),
180+
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctpop_int), 5),
181+
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.cttz_int), 5),
182+
(
183+
false,
184+
:stability,
185+
nothing,
186+
F(Mooncake.IntrinsicsWrappers.abs_float),
187+
5.0f0,
188+
),
189+
(false, :stability, nothing, F(deepcopy), 5.0),
190+
(false, :stability, nothing, F(deepcopy), randn(5)),
191+
(false, :stability_and_allocs, nothing, F(sin), 1.1),
192+
(false, :stability_and_allocs, nothing, F(sin), 1.0f1),
193+
(false, :stability_and_allocs, nothing, F(cos), 1.1),
194+
(false, :stability_and_allocs, nothing, F(cos), 1.0f1),
195+
(false, :stability_and_allocs, nothing, F(exp), 1.1),
196+
(false, :stability_and_allocs, nothing, F(exp), 1.0f1),
197+
]...,
198+
)
199+
end
176200

177201
memory = Any[]
178202
return test_cases, memory

DifferentiationInterface/test/Back/DifferentiateWith/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,6 @@ test_differentiation(
3131
logging=LOGGING,
3232
)
3333

34-
@testset "new" begin
34+
@testset "Mooncake tests" begin
3535
Mooncake.TestUtils.run_rrule!!_test_cases(StableRNG, Val(:diffwith))
3636
end

0 commit comments

Comments
 (0)