Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
1a389a6
Handles backend switching for Mooncake using ChainRules
AstitvaAggarwal Apr 1, 2025
08b176a
Mooncake Wrapper for substitute backends
AstitvaAggarwal Apr 2, 2025
ba0c9e6
Merge branch 'JuliaDiff:main' into develop
AstitvaAggarwal Apr 10, 2025
1340d92
added rules
AstitvaAggarwal Apr 10, 2025
2ce1ee2
Merge branch 'develop' of https://github.com/AstitvaAggarwal/Differen…
AstitvaAggarwal Apr 10, 2025
08de6df
config
AstitvaAggarwal Apr 10, 2025
84f27c9
splatting for dy
AstitvaAggarwal Apr 10, 2025
2e95299
brackets
AstitvaAggarwal Apr 10, 2025
13233e5
too easy
AstitvaAggarwal Apr 11, 2025
1e8df98
changes from reviews, Docs
AstitvaAggarwal Apr 12, 2025
afdddd4
changes from reviews - 2
AstitvaAggarwal Apr 18, 2025
233c312
Merge branch 'JuliaDiff:main' into develop
AstitvaAggarwal Apr 18, 2025
7a07127
changes from reviews-1
AstitvaAggarwal May 16, 2025
f3e436d
conflicts
AstitvaAggarwal May 16, 2025
6a0d937
conflicts-2
AstitvaAggarwal May 16, 2025
e543958
Update differentiate_with.jl
AstitvaAggarwal May 16, 2025
2472ecc
Merge branch 'JuliaDiff:main' into develop
AstitvaAggarwal May 16, 2025
c63c956
typecheck for array rule.
AstitvaAggarwal May 18, 2025
36da036
assertion for array inputs
AstitvaAggarwal May 18, 2025
d2b5a8c
Merge branch 'JuliaDiff:main' into develop
AstitvaAggarwal May 29, 2025
c389a80
extensive tests, diffwith for tuples
AstitvaAggarwal May 29, 2025
b4fe0f8
tests.
AstitvaAggarwal May 29, 2025
ec4b75d
tests, inc primal handling
AstitvaAggarwal May 31, 2025
0f0b9fc
changes from reviews
AstitvaAggarwal Jun 6, 2025
3c5f99e
Merge branch 'main' into develop
yebai Jun 13, 2025
d94f146
Apply suggestions from code review
gdalle Jun 13, 2025
c982f46
Simplify Mooncake rule tests, add ChainRules rule tests
gdalle Jun 13, 2025
749fea5
Format
gdalle Jun 13, 2025
9e5ecfd
Update differentiate_with.jl
gdalle Jun 14, 2025
1e85f17
Restrict to array of numbers
gdalle Jun 14, 2025
ff5c4e2
Update DifferentiationInterface/ext/DifferentiationInterfaceMooncakeE…
gdalle Jun 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ JET = "0.9"
JLArrays = "0.2.0"
JuliaFormatter = "1,2"
LinearAlgebra = "1"
Mooncake = "0.4.88"
Mooncake = "0.4.121"
Pkg = "1"
PolyesterForwardDiff = "0.1.2"
Random = "1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@ using Mooncake:
value_and_pullback!!,
zero_tangent,
rdata_type,
fdata,
rdata,
tangent_type,
NoTangent,
@is_primitive,
zero_fcodual,
MinimalCtx,
NoRData,
fdata,
primal

DI.check_available(::AutoMooncake) = true
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Union{Number,AbstractArray,Tuple}}
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I find it a bit weird to have this union of Number, AbstractArray (two types which are theoretically supported for x inputs in DI) and then just Tuple (which is not officially part of the supported inputs). Why not also NamedTuple for instance? Is it better if we just say Any? Or restrict to Number and AbstractArray for the time being?


# nested vectors, similar are not supported
Comment thread
yebai marked this conversation as resolved.
Outdated
function Mooncake.rrule!!(
dw::CoDual{<:DI.DifferentiateWith}, x::Union{CoDual{<:Number},CoDual{<:Tuple}}
)
Expand All @@ -10,31 +11,41 @@

# output is a vector, so we need to use the vector pullback
Comment thread
AstitvaAggarwal marked this conversation as resolved.
function pullback_array!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
@assert only(tx) isa rdata_type(typeof(primal_x))
return NoRData(), only(tx)
tx = DI.pullback(f, backend, primal_x, (y.dx,))
@assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
return NoRData(), rdata(only(tx))
end

# output is a scalar, so we can use the scalar pullback
function pullback_scalar!!(dy::Number)
tx = DI.pullback(f, backend, primal_x, (dy,))
@assert only(tx) isa rdata_type(typeof(primal_x))
return NoRData(), only(tx)
@assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
return NoRData(), rdata(only(tx))
end

# output is a Tuple, NTuple
function pullback_tuple!!(dy::Tuple)
tx = DI.pullback(f, backend, primal_x, (dy,))
@assert only(tx) isa rdata_type(typeof(primal_x))
return NoRData(), only(tx)
@assert rdata(only(tx)) isa rdata_type(tangent_type(typeof(primal_x)))
return NoRData(), rdata(only(tx))
end

pullback = if typeof(primal(y)) <: Number
# inputs are non Differentiable
function pullback_nodiff!!(dy::NoRData)
@assert tangent_type(typeof(primal(x))) <: NoTangent
return NoRData(), dy
end

pullback = if tangent_type(typeof(primal(x))) <: NoTangent
pullback_nodiff!!
elseif typeof(primal(y)) <: Number
Comment thread
gdalle marked this conversation as resolved.
Outdated
pullback_scalar!!
elseif typeof(primal(y)) <: Array
Comment thread
gdalle marked this conversation as resolved.
Outdated
pullback_array!!
else
elseif typeof(primal(y)) <: Tuple
Comment thread
gdalle marked this conversation as resolved.
Outdated
pullback_tuple!!
else
error("$(typeof(primal(y))) primal type currently not supported.")

Check warning on line 48 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L48

Added line #L48 was not covered by tests
Comment thread
AstitvaAggarwal marked this conversation as resolved.
Outdated
end

return y, pullback
Expand All @@ -43,40 +54,50 @@
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray})
primal_func = primal(dw)
primal_x = primal(x)
fdata_arg = fdata(x.dx)
fdata_arg = x.dx
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))

# output is a vector, so we need to use the vector pullback
function pullback_array!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
@assert first(only(tx)) isa rdata_type(typeof(first(primal_x)))
tx = DI.pullback(f, backend, primal_x, (y.dx,))
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
fdata_arg .+= only(tx)
return NoRData(), dy
end

# output is a scalar, so we can use the scalar pullback
function pullback_scalar!!(dy::Number)
tx = DI.pullback(f, backend, primal_x, (dy,))
@assert first(only(tx)) isa rdata_type(typeof(first(primal_x)))
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
fdata_arg .+= only(tx)
return NoRData(), NoRData()
end

# output is a Tuple, NTuple
function pullback_tuple!!(dy::Tuple)
tx = DI.pullback(f, backend, primal_x, (dy,))
@assert first(only(tx)) isa rdata_type(typeof(first(primal_x)))
@assert rdata(first(only(tx))) isa rdata_type(tangent_type(typeof(first(primal_x))))
fdata_arg .+= only(tx)
return NoRData(), NoRData()
end

pullback = if typeof(primal(y)) <: Number
# inputs are non Differentiable
function pullback_nodiff!!(dy::NoRData)
@assert tangent_type(typeof(primal(x))) <: Vector{NoTangent}
return NoRData(), dy
end

pullback = if tangent_type(typeof(primal(x))) <: Vector{NoTangent}
pullback_nodiff!!
elseif typeof(primal(y)) <: Number
Comment thread
gdalle marked this conversation as resolved.
Outdated
pullback_scalar!!
elseif typeof(primal(y)) <: Array
elseif typeof(primal(y)) <: AbstractArray
Comment thread
gdalle marked this conversation as resolved.
Outdated
pullback_array!!
else
elseif typeof(primal(y)) <: Tuple
Comment thread
gdalle marked this conversation as resolved.
Outdated
pullback_tuple!!
else
error("$(typeof(primal(y))) primal type currently not supported.")

Check warning on line 100 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L100

Added line #L100 was not covered by tests
Comment thread
AstitvaAggarwal marked this conversation as resolved.
Outdated
end

return y, pullback
Expand All @@ -89,90 +110,93 @@
function Mooncake.generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:diffwith})
test_cases = reduce(
vcat,
map([Float64, Float32]) do P
return Any[
(false, :stability_and_allocs, nothing, cosh, P(0.3)),
(false, :stability_and_allocs, nothing, sinh, P(0.3)),
(false, :stability_and_allocs, nothing, Base.FastMath.exp10_fast, P(0.5)),
(false, :stability_and_allocs, nothing, Base.FastMath.exp2_fast, P(0.5)),
(false, :stability_and_allocs, nothing, Base.FastMath.exp_fast, P(5.0)),
(false, :stability_and_allocs, nothing, Base.FastMath.sincos, P(3.0)),
]
end,
map([(x) -> DI.DifferentiateWith(x, DI.AutoFiniteDiff())]) do F
map([Float64, Float32]) do P
return Any[
(false, :stability, nothing, F(cosh), P(0.3)),
(false, :stability, nothing, F(sinh), P(0.3)),
(false, :stability, nothing, F(Base.FastMath.exp10_fast), P(0.5)),
(false, :stability, nothing, F(Base.FastMath.exp2_fast), P(0.5)),
(false, :stability, nothing, F(Base.FastMath.exp_fast), P(5.0)),
(false, :none, nothing, F(copy), rand(Int32, 5)),
]
end
end...,
)
push!(test_cases, (false, :stability, nothing, copy, randn(5, 4)))
push!(test_cases, (
# Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent.
false,
:none,
nothing,
x -> +(x...),
randn(33),
))
push!(
test_cases,
(
false,
:none,
nothing,
(
function (x)
rx = Ref(x)
return Base.pointerref(
Base.bitcast(Ptr{Float64}, pointer_from_objref(rx)), 1, 1
)
end
),
5.0,
),
)
push!(
test_cases,
(
false,
:none,
nothing,
x -> (pointerset(pointer(x), UInt8(3), 2, 1); x),
rand(UInt8, 5),
),
)
push!(test_cases, (false, :none, nothing, Mooncake.__vec_to_tuple, [1.0]))
push!(test_cases, (false, :none, nothing, Mooncake.__vec_to_tuple, Any[1.0]))
push!(test_cases, (false, :none, nothing, Mooncake.__vec_to_tuple, Any[[1.0]]))
push!(test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.ctlz_int, 5))
push!(
test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.ctpop_int, 5)
)
push!(test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.cttz_int, 5))
push!(
test_cases, (false, :stability, nothing, Mooncake.IntrinsicsWrappers.abs_float, 5.0)
)
push!(
test_cases,
(false, :stability, nothing, Mooncake.IntrinsicsWrappers.abs_float, 5.0f0),
)
push!(test_cases, (false, :stability, nothing, deepcopy, 5.0))
push!(test_cases, (false, :stability, nothing, deepcopy, randn(5)))
push!(test_cases, (false, :stability_and_allocs, nothing, sin, 1.1))
push!(test_cases, (false, :stability_and_allocs, nothing, sin, 1.0f1))
push!(test_cases, (false, :stability_and_allocs, nothing, cos, 1.1))
push!(test_cases, (false, :stability_and_allocs, nothing, cos, 1.0f1))
push!(test_cases, (false, :stability_and_allocs, nothing, exp, 1.1))
push!(test_cases, (false, :stability_and_allocs, nothing, exp, 1.0f1))

# additional_test_set = Mooncake.tangent_test_cases()
# function is_valid(f)
# try
# isa(f([1.0, 2.0]), Union{<:Number,<:AbstractArray})
# catch
# false
# end
# end
# for test in additional_test_set
# if is_valid(test[2])
# push!(test_cases, test)
# end
# end

map([(x) -> DI.DifferentiateWith(x, DI.AutoZygote())]) do F
Comment thread
AstitvaAggarwal marked this conversation as resolved.
Outdated
map([Float64, Float32]) do P
push!(
test_cases,
Any[
(false, :stability, nothing, F(Base.FastMath.sincos), P(3.0)),
(false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[P(1.0)]),
]...,
)
end
end

map([(x) -> DI.DifferentiateWith(x, DI.AutoZygote())]) do F
Comment thread
AstitvaAggarwal marked this conversation as resolved.
Outdated
push!(
test_cases,
Any[
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctlz_int), 5),
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctpop_int), 5),
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.cttz_int), 5),
]...,
)
end

map([(x) -> DI.DifferentiateWith(x, DI.AutoFiniteDiff())]) do F
push!(
test_cases,
Any[
(false, :stability, nothing, copy, randn(5, 4)),
(
# Check that Core._apply_iterate gets lifted to _apply_iterate_equivalent.
false,
:none,
nothing,
F(x -> +(x...)),
randn(33),
),
(
false,
:none,
nothing,
(F(
function (x)
rx = Ref(x)
return Base.pointerref(
Base.bitcast(Ptr{Float64}, pointer_from_objref(rx)), 1, 1
)
end,
)),
5.0,
),
(false, :none, nothing, F(Mooncake.__vec_to_tuple), [1.0]),
# (false, :none, nothing, F(Mooncake.__vec_to_tuple), Any[(1.0,)]), DI.basis fails for this, correct it!
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctlz_int), 5),
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.ctpop_int), 5),
(false, :stability, nothing, F(Mooncake.IntrinsicsWrappers.cttz_int), 5),
(
false,
:stability,
nothing,
F(Mooncake.IntrinsicsWrappers.abs_float),
5.0f0,
),
(false, :stability, nothing, F(deepcopy), 5.0),
(false, :stability, nothing, F(deepcopy), randn(5)),
(false, :stability_and_allocs, nothing, F(sin), 1.1),
(false, :stability_and_allocs, nothing, F(sin), 1.0f1),
(false, :stability_and_allocs, nothing, F(cos), 1.1),
(false, :stability_and_allocs, nothing, F(cos), 1.0f1),
(false, :stability_and_allocs, nothing, F(exp), 1.1),
(false, :stability_and_allocs, nothing, F(exp), 1.0f1),
]...,
)
end

memory = Any[]
return test_cases, memory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ test_differentiation(
logging=LOGGING,
)

@testset "new" begin
@testset "Mooncake tests" begin
Mooncake.TestUtils.run_rrule!!_test_cases(StableRNG, Val(:diffwith))
Comment thread
AstitvaAggarwal marked this conversation as resolved.
Outdated
end
Loading