Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
1a389a6
Handles backend switching for Mooncake using ChainRules
AstitvaAggarwal Apr 1, 2025
08b176a
Mooncake Wrapper for substitute backends
AstitvaAggarwal Apr 2, 2025
ba0c9e6
Merge branch 'JuliaDiff:main' into develop
AstitvaAggarwal Apr 10, 2025
1340d92
added rules
AstitvaAggarwal Apr 10, 2025
2ce1ee2
Merge branch 'develop' of https://github.com/AstitvaAggarwal/Differen…
AstitvaAggarwal Apr 10, 2025
08de6df
config
AstitvaAggarwal Apr 10, 2025
84f27c9
splatting for dy
AstitvaAggarwal Apr 10, 2025
2e95299
brackets
AstitvaAggarwal Apr 10, 2025
13233e5
too easy
AstitvaAggarwal Apr 11, 2025
1e8df98
changes from reviews, Docs
AstitvaAggarwal Apr 12, 2025
afdddd4
changes from reviews - 2
AstitvaAggarwal Apr 18, 2025
233c312
Merge branch 'JuliaDiff:main' into develop
AstitvaAggarwal Apr 18, 2025
7a07127
changes from reviews-1
AstitvaAggarwal May 16, 2025
f3e436d
conflicts
AstitvaAggarwal May 16, 2025
6a0d937
conflicts-2
AstitvaAggarwal May 16, 2025
e543958
Update differentiate_with.jl
AstitvaAggarwal May 16, 2025
2472ecc
Merge branch 'JuliaDiff:main' into develop
AstitvaAggarwal May 16, 2025
c63c956
typecheck for array rule.
AstitvaAggarwal May 18, 2025
36da036
assertion for array inputs
AstitvaAggarwal May 18, 2025
d2b5a8c
Merge branch 'JuliaDiff:main' into develop
AstitvaAggarwal May 29, 2025
c389a80
extensive tests, diffwith for tuples
AstitvaAggarwal May 29, 2025
b4fe0f8
tests.
AstitvaAggarwal May 29, 2025
ec4b75d
tests, inc primal handling
AstitvaAggarwal May 31, 2025
0f0b9fc
changes from reviews
AstitvaAggarwal Jun 6, 2025
3c5f99e
Merge branch 'main' into develop
yebai Jun 13, 2025
d94f146
Apply suggestions from code review
gdalle Jun 13, 2025
c982f46
Simplify Mooncake rule tests, add ChainRules rule tests
gdalle Jun 13, 2025
749fea5
Format
gdalle Jun 13, 2025
9e5ecfd
Update differentiate_with.jl
gdalle Jun 14, 2025
1e85f17
Restrict to array of numbers
gdalle Jun 14, 2025
ff5c4e2
Update DifferentiationInterface/ext/DifferentiationInterfaceMooncakeE…
gdalle Jun 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@ module DifferentiationInterfaceMooncakeExt
using ADTypes: ADTypes, AutoMooncake
import DifferentiationInterface as DI
using Mooncake:
Mooncake,
CoDual,
Config,
prepare_gradient_cache,
prepare_pullback_cache,
tangent_type,
value_and_gradient!!,
value_and_pullback!!,
zero_tangent
zero_tangent,
@is_primitive,
zero_fcodual,
MinimalCtx,
NoRData,
fdata,
primal

DI.check_available(::AutoMooncake) = true

Expand All @@ -26,5 +33,6 @@ mycopy(x) = deepcopy(x)

include("onearg.jl")
include("twoarg.jl")
include("differentiate_with.jl")

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:AbstractArray}
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Number}

function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number})
primal_func = primal(dw)
Comment thread
AstitvaAggarwal marked this conversation as resolved.
primal_x = primal(x)
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))

Check warning on line 8 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L4-L8

Added lines #L4 - L8 were not covered by tests
Comment thread
gdalle marked this conversation as resolved.

# output is a vector, so we need to use the vector pullback
Comment thread
AstitvaAggarwal marked this conversation as resolved.
function pullback!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
return NoRData(), only(tx)

Check warning on line 13 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L11-L13

Added lines #L11 - L13 were not covered by tests
Comment thread
AstitvaAggarwal marked this conversation as resolved.
Outdated
Comment thread
AstitvaAggarwal marked this conversation as resolved.
Outdated
end

# output is a scalar, so we can use the scalar pullback
function pullback!!(dy)
Comment thread
AstitvaAggarwal marked this conversation as resolved.
Outdated
tx = DI.pullback(f, backend, primal_x, (dy,))
return NoRData(), only(tx)

Check warning on line 19 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L17-L19

Added lines #L17 - L19 were not covered by tests
end

return y, pullback!!

Check warning on line 22 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L22

Added line #L22 was not covered by tests
end

function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray})
primal_func = primal(dw)
primal_x = primal(x)
fdata_arg = fdata(x.dx)
(; f, backend) = primal_func
y = zero_fcodual(f(primal_x))

Check warning on line 30 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L25-L30

Added lines #L25 - L30 were not covered by tests

# output is a vector, so we need to use the vector pullback
function pullback!!(dy::NoRData)
tx = DI.pullback(f, backend, primal_x, (fdata(y.dx),))
fdata_arg .+= only(tx)
return NoRData(), dy

Check warning on line 36 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L33-L36

Added lines #L33 - L36 were not covered by tests
end

# output is a scalar, so we can use the scalar pullback
function pullback!!(dy)
tx = DI.pullback(f, backend, primal_x, (dy,))
fdata_arg .+= only(tx)
return NoRData(), NoRData()

Check warning on line 43 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L40-L43

Added lines #L40 - L43 were not covered by tests
end

# in case x is mutated when passed into f
x = CoDual(primal_x, x.dx)
return y, pullback!!

Check warning on line 48 in DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl#L47-L48

Added lines #L47 - L48 were not covered by tests
end
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
using Pkg
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote"])
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"])

using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using FiniteDiff: FiniteDiff
using ForwardDiff: ForwardDiff
using Zygote: Zygote
using Mooncake: Mooncake
Comment thread
AstitvaAggarwal marked this conversation as resolved.
using Test

LOGGING = get(ENV, "CI", "false") == "false"
Expand All @@ -24,7 +25,7 @@ function differentiatewith_scenarios()
end

test_differentiation(
[AutoForwardDiff(), AutoZygote()],
[AutoForwardDiff(), AutoZygote(), AutoMooncake(; config=nothing)],
differentiatewith_scenarios();
excluded=SECOND_ORDER,
logging=LOGGING,
Expand Down