Skip to content

Commit a81cbef

Browse files
committed
fix: speed up Mooncake reverse mode with selective zeroing
1 parent b7adfb6 commit a81cbef

2 files changed

Lines changed: 44 additions & 12 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

Lines changed: 28 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,10 @@
11
## Pullback
22

3-
struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY} <: DI.PullbackPrep{SIG}
3+
struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG}
44
_sig::Val{SIG}
55
cache::Tcache
66
dy_righttype::DY
7+
args_to_zero::NTuple{N, Bool}
78
end
89

910
function DI.prepare_pullback_nokwarg(
@@ -16,7 +17,12 @@ function DI.prepare_pullback_nokwarg(
1617
)
1718
y = f(x, map(DI.unwrap, contexts)...)
1819
dy_righttype = zero_tangent(y)
19-
prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype)
20+
args_to_zero = (
21+
false, # f
22+
true, # x
23+
map(_ -> false, contexts)...,
24+
)
25+
prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype, args_to_zero)
2026
return prep
2127
end
2228

@@ -32,7 +38,8 @@ function DI.value_and_pullback(
3238
dy = only(ty)
3339
dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
3440
new_y, (_, new_dx) = value_and_pullback!!(
35-
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...
41+
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...;
42+
prep.args_to_zero
3643
)
3744
return new_y, (_copy_output(new_dx),)
3845
end
@@ -50,7 +57,8 @@ function DI.value_and_pullback(
5057
dy_righttype =
5158
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
5259
y, (_, new_dx) = value_and_pullback!!(
53-
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...
60+
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...;
61+
prep.args_to_zero
5462
)
5563
y, _copy_output(new_dx)
5664
end
@@ -101,9 +109,10 @@ end
101109

102110
## Gradient
103111

104-
struct MooncakeGradientPrep{SIG, Tcache} <: DI.GradientPrep{SIG}
112+
struct MooncakeGradientPrep{SIG, Tcache, N} <: DI.GradientPrep{SIG}
105113
_sig::Val{SIG}
106114
cache::Tcache
115+
args_to_zero::NTuple{N, Bool}
107116
end
108117

109118
function DI.prepare_gradient_nokwarg(
@@ -114,7 +123,12 @@ function DI.prepare_gradient_nokwarg(
114123
cache = prepare_gradient_cache(
115124
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
116125
)
117-
prep = MooncakeGradientPrep(_sig, cache)
126+
args_to_zero = (
127+
false, # f
128+
true, # x
129+
map(_ -> false, contexts)...,
130+
)
131+
prep = MooncakeGradientPrep(_sig, cache, args_to_zero)
118132
return prep
119133
end
120134

@@ -126,7 +140,10 @@ function DI.value_and_gradient(
126140
contexts::Vararg{DI.Context, C},
127141
) where {F, C}
128142
DI.check_prep(f, prep, backend, x, contexts...)
129-
y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...)
143+
y, (_, new_grad) = value_and_gradient!!(
144+
prep.cache, f, x, map(DI.unwrap, contexts)...;
145+
prep.args_to_zero
146+
)
130147
return y, _copy_output(new_grad)
131148
end
132149

@@ -139,7 +156,10 @@ function DI.value_and_gradient!(
139156
contexts::Vararg{DI.Context, C},
140157
) where {F, C}
141158
DI.check_prep(f, prep, backend, x, contexts...)
142-
y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...)
159+
y, (_, new_grad) = value_and_gradient!!(
160+
prep.cache, f, x, map(DI.unwrap, contexts)...;
161+
prep.args_to_zero
162+
)
143163
copyto!(grad, new_grad)
144164
return y, grad
145165
end

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl

Lines changed: 16 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,9 @@
1-
struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F} <: DI.PullbackPrep{SIG}
1+
struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F, N} <: DI.PullbackPrep{SIG}
22
_sig::Val{SIG}
33
cache::Tcache
44
dy_righttype::DY
55
target_function::F
6+
args_to_zero::NTuple{N, Bool}
67
end
78

89
function DI.prepare_pullback_nokwarg(
@@ -30,7 +31,16 @@ function DI.prepare_pullback_nokwarg(
3031
silence_debug_messages = config.silence_debug_messages,
3132
)
3233
dy_righttype_after = zero_tangent(y)
33-
prep = MooncakeTwoArgPullbackPrep(_sig, cache, dy_righttype_after, target_function)
34+
args_to_zero = (
35+
false, # target_function
36+
false, # f!
37+
false, # y
38+
true, # x
39+
map(_ -> false, contexts)...,
40+
)
41+
prep = MooncakeTwoArgPullbackPrep(
42+
_sig, cache, dy_righttype_after, target_function, args_to_zero
43+
)
3444
return prep
3545
end
3646

@@ -55,7 +65,8 @@ function DI.value_and_pullback(
5565
f!,
5666
y,
5767
x,
58-
map(DI.unwrap, contexts)...,
68+
map(DI.unwrap, contexts)...;
69+
prep.args_to_zero
5970
)
6071
copyto!(y, y_after)
6172
return y, (_copy_output(dx),)
@@ -80,7 +91,8 @@ function DI.value_and_pullback(
8091
f!,
8192
y,
8293
x,
83-
map(DI.unwrap, contexts)...,
94+
map(DI.unwrap, contexts)...;
95+
prep.args_to_zero
8496
)
8597
copyto!(y, y_after)
8698
_copy_output(dx)

0 commit comments

Comments
 (0)