Skip to content

Commit fd3dc14

Browse files
authored
Mooncake Upgrades (#645)
1 parent cc8818a commit fd3dc14

4 files changed

Lines changed: 41 additions & 92 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,13 +55,13 @@ FiniteDifferences = "0.12.31"
5555
ForwardDiff = "0.10.36"
5656
JuliaFormatter = "1"
5757
LinearAlgebra = "<0.0.1,1"
58-
Mooncake = "0.4.11"
58+
Mooncake = "0.4.52"
5959
PolyesterForwardDiff = "0.1.2"
6060
ReverseDiff = "1.15.1"
6161
SparseArrays = "<0.0.1,1"
6262
SparseConnectivityTracer = "0.5.0,0.6"
63-
StaticArrays = "1.9.7"
6463
SparseMatrixColorings = "0.4.9"
64+
StaticArrays = "1.9.7"
6565
Symbolics = "5.27.1, 6"
6666
Tracker = "0.2.33"
6767
Zygote = "0.6.69"

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,13 @@ using DifferentiationInterface: Context, PullbackPrep, unwrap
66
using Mooncake:
77
CoDual,
88
Config,
9-
NoRData,
10-
NoTangent,
11-
build_rrule,
12-
fdata,
13-
get_interpreter,
14-
increment!!,
159
primal,
16-
rdata,
17-
set_to_zero!!,
1810
tangent,
1911
tangent_type,
2012
value_and_pullback!!,
21-
zero_codual,
22-
zero_fcodual,
2313
zero_tangent,
24-
__value_and_pullback!!
14+
prepare_pullback_cache,
15+
Mooncake
2516

2617
DI.check_available(::AutoMooncake) = true
2718

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,65 +1,49 @@
1-
struct MooncakeOneArgPullbackPrep{Y,R,DX,DY} <: PullbackPrep
2-
y_prototype::Y
3-
rrule::R
4-
dx_righttype::DX
1+
struct MooncakeOneArgPullbackPrep{Tcache,DY} <: PullbackPrep
2+
cache::Tcache
53
dy_righttype::DY
64
end
75

86
function DI.prepare_pullback(
97
f, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{Context,C}
108
) where {C}
11-
y = f(x, map(unwrap, contexts)...)
129
config = get_config(backend)
13-
rrule = build_rrule(
14-
get_interpreter(),
15-
Tuple{typeof(f),typeof(x),typeof.(map(unwrap, contexts))...};
16-
debug_mode=config.debug_mode,
17-
silence_debug_messages=config.silence_debug_messages,
10+
cache = prepare_pullback_cache(
11+
f, x, map(unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
1812
)
19-
dx_righttype = zero_tangent(x)
13+
y = f(x, map(unwrap, contexts)...)
2014
dy_righttype = zero_tangent(y)
21-
prep = MooncakeOneArgPullbackPrep(y, rrule, dx_righttype, dy_righttype)
22-
DI.value_and_pullback(f, prep, backend, x, ty, contexts...) # warm up
15+
prep = MooncakeOneArgPullbackPrep(cache, dy_righttype)
16+
DI.value_and_pullback(f, prep, backend, x, ty, contexts...)
2317
return prep
2418
end
2519

2620
function DI.value_and_pullback(
27-
f,
21+
f::F,
2822
prep::MooncakeOneArgPullbackPrep{Y},
2923
::AutoMooncake,
3024
x,
3125
ty::NTuple{1},
3226
contexts::Vararg{Context,C},
33-
) where {Y,C}
27+
) where {F,Y,C}
3428
dy = only(ty)
3529
dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)
36-
new_y, (_, new_dx) = value_and_pullback!!(
37-
prep.rrule, dy_righttype, f, x, map(unwrap, contexts)...
30+
new_y, (_, new_dx) = Mooncake.value_and_pullback!!(
31+
prep.cache, dy_righttype, f, x, map(unwrap, contexts)...
3832
)
39-
return new_y, (new_dx,)
33+
return new_y, (copy(new_dx),)
4034
end
4135

4236
function DI.value_and_pullback!(
4337
f,
4438
tx::NTuple{1},
4539
prep::MooncakeOneArgPullbackPrep{Y},
46-
::AutoMooncake,
40+
backend::AutoMooncake,
4741
x,
4842
ty::NTuple{1},
4943
contexts::Vararg{Context,C},
5044
) where {Y,C}
51-
dx, dy = only(tx), only(ty)
52-
dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)
53-
dx_righttype = set_to_zero!!(prep.dx_righttype)
54-
contexts_coduals = map(zero_codual ∘ unwrap, contexts)
55-
y, (_, new_dx) = __value_and_pullback!!(
56-
prep.rrule,
57-
dy_righttype,
58-
zero_codual(f),
59-
CoDual(x, dx_righttype),
60-
contexts_coduals...,
61-
)
62-
copyto!(dx, new_dx)
45+
y, (new_dx,) = DI.value_and_pullback(f, prep, backend, x, ty, contexts...)
46+
copyto!(only(tx), new_dx)
6347
return y, tx
6448
end
6549

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl

Lines changed: 22 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,28 @@
1-
struct MooncakeTwoArgPullbackPrep{R,F,Y,DX,DY} <: PullbackPrep
2-
rrule::R
3-
df!::F
4-
y_copy::Y
5-
dx_righttype::DX
1+
struct MooncakeTwoArgPullbackPrep{Tcache,DY,F} <: PullbackPrep
2+
cache::Tcache
63
dy_righttype::DY
7-
dy_righttype_after::DY
4+
target_function::F
85
end
96

107
function DI.prepare_pullback(
118
f!, y, backend::AutoMooncake, x, ty::NTuple, contexts::Vararg{Context,C}
129
) where {C}
10+
target_function = function (f!, y, x, contexts...)
11+
f!(y, x, contexts...)
12+
return y
13+
end
1314
config = get_config(backend)
14-
rrule = build_rrule(
15-
get_interpreter(),
16-
Tuple{typeof(f!),typeof(y),typeof(x),typeof.(map(unwrap, contexts))...};
15+
cache = prepare_pullback_cache(
16+
target_function,
17+
f!,
18+
y,
19+
x,
20+
map(unwrap, contexts)...;
1721
debug_mode=config.debug_mode,
1822
silence_debug_messages=config.silence_debug_messages,
1923
)
20-
df! = zero_tangent(f!)
21-
y_copy = copy(y)
22-
dx_righttype = zero_tangent(x)
23-
dy_righttype = zero_tangent(y)
2424
dy_righttype_after = zero_tangent(y)
25-
prep = MooncakeTwoArgPullbackPrep(
26-
rrule, df!, y_copy, dx_righttype, dy_righttype, dy_righttype_after
27-
)
28-
DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...) # warm up
29-
return prep
25+
return MooncakeTwoArgPullbackPrep(cache, dy_righttype_after, target_function)
3026
end
3127

3228
function DI.value_and_pullback(
@@ -38,39 +34,17 @@ function DI.value_and_pullback(
3834
ty::NTuple{1},
3935
contexts::Vararg{Context,C},
4036
) where {C}
41-
dy = only(ty)
42-
43-
# Set all tangent storage to zero.
44-
df! = set_to_zero!!(prep.df!)
45-
dx_righttype = set_to_zero!!(prep.dx_righttype)
46-
dy_righttype = set_to_zero!!(prep.dy_righttype)
47-
4837
# Prepare cotangent to add after the forward pass.
49-
dy_righttype_after = copyto!(prep.dy_righttype_after, dy)
50-
51-
contexts_coduals = map(zero_fcodual ∘ unwrap, contexts)
38+
dy = only(ty)
39+
dy_righttype_after = copyto!(prep.dy_righttype, dy)
5240

53-
# Run the forward pass
54-
out, pb!! = prep.rrule(
55-
CoDual(f!, fdata(df!)),
56-
CoDual(prep.y_copy, fdata(dy_righttype)),
57-
CoDual(x, fdata(dx_righttype)),
58-
contexts_coduals...,
41+
# Run the reverse-pass and return the results.
42+
contexts = map(unwrap, contexts)
43+
y_after, (_, _, _, dx) = Mooncake.value_and_pullback!!(
44+
prep.cache, dy_righttype_after, prep.target_function, f!, y, x, contexts...
5945
)
60-
61-
# Verify that the output is non-differentiable.
62-
@assert primal(out) === nothing
63-
64-
# Increment the desired cotangent dy.
65-
dy_righttype = increment!!(dy_righttype, dy_righttype_after)
66-
67-
# Record the state of y before running the reverse pass.
68-
y = copyto!(y, prep.y_copy)
69-
70-
# Run the reverse pass.
71-
_, _, new_dx = pb!!(NoRData())
72-
73-
return y, (tangent(copy(fdata(dx_righttype)), new_dx),) # TODO: remove this allocation in `value_and_pullback!`
46+
copyto!(y, y_after)
47+
return y, (copy(dx),) # TODO: remove this allocation in `value_and_pullback!`
7448
end
7549

7650
function DI.value_and_pullback(

0 commit comments

Comments
 (0)