Skip to content

Commit 1a389a6

Browse files
Handles backend switching for Mooncake using ChainRules
1 parent 51c56e8 commit 1a389a6

2 files changed

Lines changed: 22 additions & 1 deletion

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,19 @@ module DifferentiationInterfaceMooncakeExt
33
using ADTypes: ADTypes, AutoMooncake
44
import DifferentiationInterface as DI
55
using Mooncake:
6+
Mooncake,
67
CoDual,
78
Config,
89
prepare_gradient_cache,
910
prepare_pullback_cache,
1011
tangent_type,
1112
value_and_gradient!!,
1213
value_and_pullback!!,
13-
zero_tangent
14+
@from_rrule,
15+
MinimalCtx,
16+
NoFData
17+
18+
using ChainRulesCore: ChainRulesCore, rrule
1419

1520
DI.check_available(::AutoMooncake) = true
1621

@@ -26,5 +31,6 @@ mycopy(x) = deepcopy(x)
2631

2732
include("onearg.jl")
2833
include("twoarg.jl")
34+
include("differentiate_with.jl")
2935

3036
end
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
function define_rule!(primal_func, primal_args)
2+
return eval(:(@from_rrule MinimalCtx Tuple{$primal_func,$primal_args...}))
3+
end
4+
5+
function Mooncake.rrule!!(dw::CoDual{<:DI.DifferentiateWith}, args::CoDual...)
6+
primal_func = typeof(Mooncake.primal(dw))
7+
primal_args = typeof.(map(arg -> Mooncake.primal(arg), args))
8+
# use the DI.chainrule wrapper inside @from_rrule to create a custom rrule!!
9+
10+
# macro evaluation in global scope with more specialized types (@fromrrule requires non generic types)
11+
define_rule!(primal_func, primal_args)
12+
13+
# Use the ChainRuleCore rrule mapping with backends, calling Mooncake rule!! that now wraps around that ChainRulesCore rrule.
14+
return Base.invokelatest(Mooncake.rrule!!, CoDual(primal_func, dw.dx), args...)
15+
end

0 commit comments

Comments
 (0)