Skip to content

Commit c389a80

Browse files
extensive tests, diffwith for tuples
1 parent d2b5a8c commit c389a80

2 files changed

Lines changed: 138 additions & 5 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Lines changed: 133 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray}}
1+
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray,Tuple}}
22

3-
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number})
3+
function Mooncake.rrule!!(
4+
dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{Union{<:Number,<:Tuple}}
5+
)
46
primal_func = primal(dw)
57
primal_x = primal(x)
68
(; f, backend) = primal_func
@@ -20,7 +22,22 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number
2022
return NoRData(), only(tx)
2123
end
2224

23-
return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!!
25+
# output is a Tuple, NTuple
26+
function pullback_tuple!!(dy::Tuple)
27+
tx = DI.pullback(f, backend, primal_x, (dy,))
28+
@assert only(tx) isa rdata_type(typeof(primal_x))
29+
return NoRData(), only(tx)
30+
end
31+
32+
pullback = if typeof(primal(y)) <: Number
33+
pullback_scalar!!
34+
elseif typeof(primal(y)) <: Array
35+
pullback_array!!
36+
else
37+
pullback_tuple!!
38+
end
39+
40+
return y, pullback
2441
end
2542

2643
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray})
@@ -46,5 +63,117 @@ function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Abstra
4663
return NoRData(), NoRData()
4764
end
4865

49-
return y, typeof(primal(y)) <: Number ? pullback_scalar!! : pullback_array!!
66+
# output is a Tuple, NTuple
67+
function pullback_tuple!!(dy::Tuple)
68+
tx = DI.pullback(f, backend, primal_x, (dy,))
69+
@assert first(only(tx)) isa rdata_type(typeof(first(primal_x)))
70+
fdata_arg .+= only(tx)
71+
return NoRData(), NoRData()
72+
end
73+
74+
pullback = if typeof(primal(y)) <: Number
75+
pullback_scalar!!
76+
elseif typeof(primal(y)) <: Array
77+
pullback_array!!
78+
else
79+
pullback_tuple!!
80+
end
81+
82+
return y, pullback
83+
end
84+
85+
function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:diffwith})
86+
return Any[], Any[]
87+
end
88+
89+
function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diffwith})
90+
test_cases = reduce(
91+
vcat,
92+
map([Float64, Float32]) do P
93+
return Any[
94+
(false, :stability_and_allocs, nothing, cosh, P(0.3)),
95+
(false, :stability_and_allocs, nothing, sinh, P(0.3)),
96+
(false, :stability_and_allocs, nothing, Base.FastMath.exp10_fast, P(0.5)),
97+
(false, :stability_and_allocs, nothing, Base.FastMath.exp2_fast, P(0.5)),
98+
(false, :stability_and_allocs, nothing, Base.FastMath.exp_fast, P(5.0)),
99+
(false, :stability_and_allocs, nothing, Base.FastMath.sincos, P(3.0)),
100+
]
101+
end,
102+
)
103+
push!(test_cases, (false, :stability, nothing, copy, randn(5, 4)))
104+
push!(test_cases, (
105+
# Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent.
106+
false,
107+
:none,
108+
nothing,
109+
x -> +(x...),
110+
randn(33),
111+
))
112+
push!(
113+
test_cases,
114+
(
115+
false,
116+
:none,
117+
nothing,
118+
(
119+
function (x)
120+
rx = Ref(x)
121+
return Base.pointerref(
122+
Base.bitcast(Ptr{Float64}, pointer_from_objref(rx)), 1, 1
123+
)
124+
end
125+
),
126+
5.0,
127+
),
128+
)
129+
push!(
130+
test_cases,
131+
(
132+
false,
133+
:none,
134+
nothing,
135+
x -> (pointerset(pointer(x), UInt8(3), 2, 1); x),
136+
rand(UInt8, 5),
137+
),
138+
)
139+
push!(test_cases, (false, :none, nothing, Mooncake.__vec_to_tuple, [1.0]))
140+
push!(test_cases, (false, :none, nothing, Mooncake.__vec_to_tuple, Any[1.0]))
141+
push!(test_cases, (false, :none, nothing, Mooncake.__vec_to_tuple, Any[[1.0]]))
142+
push!(test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.ctlz_int, 5))
143+
push!(
144+
test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.ctpop_int, 5)
145+
)
146+
push!(test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.cttz_int, 5))
147+
push!(
148+
test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.abs_float, 5.0)
149+
)
150+
push!(
151+
test_cases,
152+
(false, :stability, nothing, Mooncake.IntrinsicsWrappers.abs_float, 5.0f0),
153+
)
154+
push!(test_cases, (false, :stability, nothing, deepcopy, 5.0))
155+
push!(test_cases, (false, :stability, nothing, deepcopy, randn(5)))
156+
push!(test_cases, (false, :stability_and_allocs, nothing, sin, 1.1))
157+
push!(test_cases, (false, :stability_and_allocs, nothing, sin, 1.0f1))
158+
push!(test_cases, (false, :stability_and_allocs, nothing, cos, 1.1))
159+
push!(test_cases, (false, :stability_and_allocs, nothing, cos, 1.0f1))
160+
push!(test_cases, (false, :stability_and_allocs, nothing, exp, 1.1))
161+
push!(test_cases, (false, :stability_and_allocs, nothing, exp, 1.0f1))
162+
163+
# additional_test_set = Mooncake.tangent_test_cases()
164+
# function is_valid(f)
165+
# try
166+
# isa(f([1.0, 2.0]), Union{<:Number,<:AbstractArray})
167+
# catch
168+
# false
169+
# end
170+
# end
171+
# for test in additional_test_set
172+
# if is_valid(test[2])
173+
# push!(test_cases, test)
174+
# end
175+
# end
176+
177+
memory = Any[]
178+
return test_cases, memory
50179
end

DifferentiationInterface/test/Back/DifferentiateWith/test.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using FiniteDiff: FiniteDiff
77
using ForwardDiff: ForwardDiff
88
using Zygote: Zygote
99
using Mooncake: Mooncake
10-
using Test
10+
using StableRNGs, Test
1111

1212
LOGGING = get(ENV, "CI", "false") == "false"
1313

@@ -30,3 +30,7 @@ test_differentiation(
3030
excluded=SECOND_ORDER,
3131
logging=LOGGING,
3232
)
33+
34+
@testset "new" begin
35+
Mooncake.TestUtils.run_rrule!!_test_cases(StableRNG, Val(:diffwith))
36+
end

0 commit comments

Comments
 (0)