Skip to content

Commit 08b176a

Browse files
Mooncake Wrapper for substitute backends
1 parent 1a389a6 commit 08b176a

3 files changed

Lines changed: 22 additions & 13 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
3838
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
3939
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
4040
DifferentiationInterfaceGTPSAExt = "GTPSA"
41-
DifferentiationInterfaceMooncakeExt = "Mooncake"
41+
DifferentiationInterfaceMooncakeExt = ["ChainRulesCore", "Mooncake"]
4242
DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"]
4343
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
4444
DifferentiationInterfaceSparseArraysExt = "SparseArrays"

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@ using Mooncake:
1111
tangent_type,
1212
value_and_gradient!!,
1313
value_and_pullback!!,
14-
@from_rrule,
14+
zero_tangent,
15+
@is_primitive,
16+
zero_fcodual,
1517
MinimalCtx,
16-
NoFData
18+
NoRData
1719

1820
using ChainRulesCore: ChainRulesCore, rrule
1921

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
1-
function define_rule!(primal_func, primal_args)
2-
return eval(:(@from_rrule MinimalCtx Tuple{$primal_func,$primal_args...}))
3-
end
1+
@is_primitive MinimalCtx Tuple{CoDual{<:DI.DifferentiateWith},CoDual{<:AbstractArray}}
2+
@is_primitive MinimalCtx Tuple{CoDual{<:DI.DifferentiateWith},CoDual{<:Number}}
43

54
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, args::CoDual...)
6-
primal_func = typeof(Mooncake.primal(dw))
7-
primal_args = typeof.(map(arg -> Mooncake.primal(arg), args))
8-
# use the DI.chainrule wrapper inside @from_rrule to create a custom rrule!!
5+
primal_func = Mooncake.primal(dw)
6+
primal_args = map(arg -> Mooncake.primal(arg), args)
7+
8+
(; f, backend) = primal_func
9+
y = f(primal_args...)
10+
11+
prep_same = DI.prepare_pullback_same_point_nokwarg(
12+
Val(true), f, backend, primal_args..., (y,)
13+
)
914

10-
# macro evaluation in global scope with more specialized types (@fromrrule requires non generic types)
11-
define_rule!(primal_func, primal_args)
15+
function pullback!!(dy)
16+
tx = DI.pullback(f, prep_same, backend, primal_args, (dy,))
17+
args_rdata = map((x) -> (x, Mooncake.zero_rdata(x)), only(tx))
18+
return NoRData(), args_rdata...
19+
end
1220

13-
# Use the ChainRuleCore rrule mapping with backends, calling Mooncake rule!! that now wraps around that ChainRulesCore rrule.
14-
return Base.invokelatest(Mooncake.rrule!!, CoDual(primal_func, dw.dx), args...)
21+
return zero_fcodual(y), pullback!!
1522
end

0 commit comments

Comments
 (0)