Skip to content

Commit 1340d92

Browse files
added rules
1 parent 08b176a commit 1340d92

4 files changed

Lines changed: 30 additions & 19 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
3838
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
3939
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
4040
DifferentiationInterfaceGTPSAExt = "GTPSA"
41-
DifferentiationInterfaceMooncakeExt = ["ChainRulesCore", "Mooncake"]
41+
DifferentiationInterfaceMooncakeExt = "Mooncake"
4242
DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"]
4343
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
4444
DifferentiationInterfaceSparseArraysExt = "SparseArrays"

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ using Mooncake:
1515
@is_primitive,
1616
zero_fcodual,
1717
MinimalCtx,
18-
NoRData
19-
20-
using ChainRulesCore: ChainRulesCore, rrule
18+
NoRData,
19+
fdata,
20+
primal
2121

2222
DI.check_available(::AutoMooncake) = true
2323

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/differentiate_with.jl

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,31 @@
1-
@is_primitive MinimalCtx Tuple{CoDual{<:DI.DifferentiateWith},CoDual{<:AbstractArray}}
2-
@is_primitive MinimalCtx Tuple{CoDual{<:DI.DifferentiateWith},CoDual{<:Number}}
3-
4-
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, args::CoDual...)
5-
primal_func = Mooncake.primal(dw)
6-
primal_args = map(arg -> Mooncake.primal(arg), args)
1+
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:AbstractArray}
2+
@is_primitive MinimalCtx Tuple{DI.DifferentiateWith,<:Number}
73

4+
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:Number})
5+
primal_func = primal(dw)
6+
primal_x = primal(x)
87
(; f, backend) = primal_func
9-
y = f(primal_args...)
8+
y = f(primal_x)
9+
10+
function pullback!!(dy)
11+
tx = DI.pullback(f, backend, primal_x, (dy,))
12+
return NoRData(), only(tx)
13+
end
1014

11-
prep_same = DI.prepare_pullback_same_point_nokwarg(
12-
Val(true), f, backend, primal_args..., (y,)
13-
)
15+
return zero_fcodual(y), pullback!!
16+
end
17+
18+
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, x::CoDual{<:AbstractArray})
19+
primal_func = primal(dw)
20+
primal_x = primal(x)
21+
fdata_arg = fdata(x.dx)
22+
(; f, backend) = primal_func
23+
y = f(primal_x)
1424

1525
function pullback!!(dy)
16-
tx = DI.pullback(f, prep_same, backend, primal_args, (dy,))
17-
args_rdata = map((x) -> (x, Mooncake.zero_rdata(x)), only(tx))
18-
return NoRData(), args_rdata...
26+
tx = DI.pullback(f, backend, primal_x, (dy,))
27+
fdata_arg .+= only(tx)
28+
return NoRData(), NoRData()
1929
end
2030

2131
return zero_fcodual(y), pullback!!

DifferentiationInterface/test/Back/DifferentiateWith/test.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
using Pkg
2-
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote"])
2+
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote", "Mooncake"])
33

44
using DifferentiationInterface, DifferentiationInterfaceTest
55
import DifferentiationInterfaceTest as DIT
66
using FiniteDiff: FiniteDiff
77
using ForwardDiff: ForwardDiff
88
using Zygote: Zygote
9+
using Mooncake: Mooncake
910
using Test
1011

1112
LOGGING = get(ENV, "CI", "false") == "false"
@@ -24,7 +25,7 @@ function differentiatewith_scenarios()
2425
end
2526

2627
test_differentiation(
27-
[AutoForwardDiff(), AutoZygote()],
28+
[AutoForwardDiff(), AutoZygote(), AutoMooncake()],
2829
differentiatewith_scenarios();
2930
excluded=SECOND_ORDER,
3031
logging=LOGGING,

0 commit comments

Comments
 (0)