Skip to content

Commit 7ee5859

Browse files
authored
Improve Mooncake caching (#513)
* Improve Mooncake caching * Fix
1 parent 3f29b61 commit 7ee5859

3 files changed

Lines changed: 41 additions & 30 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ using Mooncake:
2525

2626
DI.check_available(::AutoMooncake) = true
2727

28+
copyto!!(dst::Number, src::Number) = convert(typeof(dst), src)
29+
copyto!!(dst, src) = copyto!(dst, src)
30+
2831
get_config(::AutoMooncake{Nothing}) = Config()
2932
get_config(backend::AutoMooncake{<:Config}) = backend.config
3033

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
struct MooncakeOneArgPullbackPrep{Y,R} <: PullbackPrep
1+
struct MooncakeOneArgPullbackPrep{Y,R,DX,DY} <: PullbackPrep
22
y_prototype::Y
33
rrule::R
4+
dx_righttype::DX
5+
dy_righttype::DY
46
end
57

68
function DI.prepare_pullback(
@@ -14,7 +16,9 @@ function DI.prepare_pullback(
1416
debug_mode=config.debug_mode,
1517
silence_debug_messages=config.silence_debug_messages,
1618
)
17-
prep = MooncakeOneArgPullbackPrep(y, rrule)
19+
dx_righttype = zero_tangent(x)
20+
dy_righttype = zero_tangent(y)
21+
prep = MooncakeOneArgPullbackPrep(y, rrule, dx_righttype, dy_righttype)
1822
DI.value_and_pullback(f, prep, backend, x, ty, contexts...) # warm up
1923
return prep
2024
end
@@ -28,7 +32,7 @@ function DI.value_and_pullback(
2832
contexts::Vararg{Context,C},
2933
) where {Y,C}
3034
dy = only(ty)
31-
dy_righttype = convert(tangent_type(Y), dy)
35+
dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)
3236
new_y, (_, new_dx) = value_and_pullback!!(
3337
prep.rrule, dy_righttype, f, x, map(unwrap, contexts)...
3438
)
@@ -45,8 +49,8 @@ function DI.value_and_pullback!(
4549
contexts::Vararg{Context,C},
4650
) where {Y,C}
4751
dx, dy = only(tx), only(ty)
48-
dy_righttype = convert(tangent_type(Y), dy)
49-
dx_righttype = set_to_zero!!(convert(tangent_type(typeof(x)), dx))
52+
dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy)
53+
dx_righttype = set_to_zero!!(prep.dx_righttype)
5054
contexts_coduals = map(zero_codual unwrap, contexts)
5155
y, (_, new_dx) = __value_and_pullback!!(
5256
prep.rrule,

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1-
struct MooncakeTwoArgPullbackPrep{R} <: PullbackPrep
1+
struct MooncakeTwoArgPullbackPrep{R,F,Y,DX,DY} <: PullbackPrep
22
rrule::R
3+
df!::F
4+
y_copy::Y
5+
dx_righttype::DX
6+
dy_righttype::DY
7+
dy_righttype_after::DY
38
end
49

510
function DI.prepare_pullback(
@@ -12,7 +17,14 @@ function DI.prepare_pullback(
1217
debug_mode=config.debug_mode,
1318
silence_debug_messages=config.silence_debug_messages,
1419
)
15-
prep = MooncakeTwoArgPullbackPrep(rrule)
20+
df! = zero_tangent(f!)
21+
y_copy = copy(y)
22+
dx_righttype = zero_tangent(x)
23+
dy_righttype = zero_tangent(y)
24+
dy_righttype_after = zero_tangent(y)
25+
prep = MooncakeTwoArgPullbackPrep(
26+
rrule, df!, y_copy, dx_righttype, dy_righttype, dy_righttype_after
27+
)
1628
DI.value_and_pullback(f!, y, prep, backend, x, ty, contexts...) # warm up
1729
return prep
1830
end
@@ -27,46 +39,38 @@ function DI.value_and_pullback(
2739
contexts::Vararg{Context,C},
2840
) where {C}
2941
dy = only(ty)
30-
dy_righttype = convert(tangent_type(typeof(y)), copy(dy))
31-
dx_righttype = zero_tangent(x)
3242

33-
# We want the VJP, not VJP + dx, so I'm going to zero-out `dx`. `set_to_zero!!` has the advantage
34-
# that it will also replace any immutable components of `dx` to zero.
35-
dx_righttype = set_to_zero!!(dx_righttype)
43+
# Set all tangent storage to zero.
44+
df! = set_to_zero!!(prep.df!)
45+
dx_righttype = set_to_zero!!(prep.dx_righttype)
46+
dy_righttype = set_to_zero!!(prep.dy_righttype)
3647

37-
# We want `dy` to correspond to the cotangent of `y` _after_
38-
# running the forwards-pass, so I'm going to take a copy, and zero-out the original.
39-
dy_righttype_backup = copy(dy_righttype)
40-
dy_righttype = set_to_zero!!(dy_righttype)
41-
contexts_coduals = map(zero_fcodual unwrap, contexts)
42-
43-
# Mutate a copy of `y`, so that we can run the reverse-pass later on.
44-
y_copy = copy(y)
48+
# Prepare cotangent to add after the forward pass.
49+
dy_righttype_after = copyto!(prep.dy_righttype_after, dy)
4550

46-
# In case `f!` is a closure
47-
df! = zero_tangent(f!)
51+
contexts_coduals = map(zero_fcodual unwrap, contexts)
4852

49-
# Run the forwards-pass.
53+
# Run the forward pass
5054
out, pb!! = prep.rrule(
5155
CoDual(f!, fdata(df!)),
52-
CoDual(y_copy, fdata(dy_righttype)),
56+
CoDual(prep.y_copy, fdata(dy_righttype)),
5357
CoDual(x, fdata(dx_righttype)),
5458
contexts_coduals...,
5559
)
5660

5761
# Verify that the output is non-differentiable.
5862
@assert primal(out) === nothing
5963

60-
# Set the cotangent of `y` to be equal to the requested value.
61-
dy_righttype = increment!!(dy_righttype, dy_righttype_backup)
64+
# Increment the desired cotangent dy.
65+
dy_righttype = increment!!(dy_righttype, dy_righttype_after)
6266

63-
# Record the state of `y` before running the reverse-pass.
64-
y = copyto!(y, y_copy)
67+
# Record the state of y before running the reverse pass.
68+
y = copyto!(y, prep.y_copy)
6569

66-
# Run the reverse-pass.
70+
# Run the reverse pass.
6771
_, _, new_dx = pb!!(NoRData())
6872

69-
return y, (tangent(fdata(dx_righttype), new_dx),)
73+
return y, (tangent(copy(fdata(dx_righttype)), new_dx),) # TODO: remove this allocation in `value_and_pullback!`
7074
end
7175

7276
function DI.value_and_pullback(

0 commit comments

Comments
 (0)