Skip to content

Commit a93e0d1

Browse files
committed
Update Mooncake support to 0.5
1 parent 8d33550 commit a93e0d1

5 files changed

Lines changed: 17 additions & 31 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ ForwardDiff = "0.10.36,1"
6969
GPUArraysCore = "0.2"
7070
GTPSA = "1.4.0"
7171
LinearAlgebra = "1"
72-
Mooncake = "0.4.175"
72+
Mooncake = "0.5.0"
7373
PolyesterForwardDiff = "0.1.2"
7474
ReverseDiff = "1.15.1"
7575
SparseArrays = "1"

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
## Pushforward
22

3+
# TODO: needs friendly tangents support
4+
35
struct MooncakeOneArgPushforwardPrep{SIG, Tcache, DX, FT, CT} <: DI.PushforwardPrep{SIG}
46
_sig::Val{SIG}
57
cache::Tcache
@@ -19,7 +21,7 @@ function DI.prepare_pushforward_nokwarg(
1921
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
2022
config = get_config(backend)
2123
cache = prepare_derivative_cache(
22-
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
24+
f, x, map(DI.unwrap, contexts)...; config
2325
)
2426
dx_righttype = zero_tangent(x)
2527
df = zero_tangent(f)

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,7 @@
11
## Pushforward
22

3+
# TODO: needs friendly tangents support
4+
35
struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, DX, DY, FT, CT} <: DI.PushforwardPrep{SIG}
46
_sig::Val{SIG}
57
cache::Tcache
@@ -25,8 +27,7 @@ function DI.prepare_pushforward_nokwarg(
2527
y,
2628
x,
2729
map(DI.unwrap, contexts)...;
28-
config.debug_mode,
29-
config.silence_debug_messages,
30+
config,
3031
)
3132
dx_righttype = zero_tangent(x)
3233
dy_righttype = zero_tangent(y)

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

Lines changed: 6 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,8 @@
11
## Pullback
22

3-
struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG}
3+
struct MooncakeOneArgPullbackPrep{SIG, Tcache, N} <: DI.PullbackPrep{SIG}
44
_sig::Val{SIG}
55
cache::Tcache
6-
dy_righttype::DY
76
args_to_zero::NTuple{N, Bool}
87
end
98

@@ -12,18 +11,14 @@ function DI.prepare_pullback_nokwarg(
1211
) where {F, C}
1312
_sig = DI.signature(f, backend, x, ty, contexts...; strict)
1413
config = get_config(backend)
15-
cache = prepare_pullback_cache(
16-
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
17-
)
18-
y = f(x, map(DI.unwrap, contexts)...)
19-
dy_righttype = zero_tangent(y)
14+
cache = prepare_pullback_cache(f, x, map(DI.unwrap, contexts)...; config)
2015
contexts_tup_false = map(_ -> false, contexts)
2116
args_to_zero = (
2217
false, # f
2318
true, # x
2419
contexts_tup_false...,
2520
)
26-
prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype, args_to_zero)
21+
prep = MooncakeOneArgPullbackPrep(_sig, cache, args_to_zero)
2722
return prep
2823
end
2924

@@ -37,9 +32,8 @@ function DI.value_and_pullback(
3732
) where {F, Y, C}
3833
DI.check_prep(f, prep, backend, x, ty, contexts...)
3934
dy = only(ty)
40-
dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
4135
new_y, (_, new_dx) = value_and_pullback!!(
42-
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...;
36+
prep.cache, dy, f, x, map(DI.unwrap, contexts)...;
4337
prep.args_to_zero
4438
)
4539
return new_y, (_copy_output(new_dx),)
@@ -55,10 +49,8 @@ function DI.value_and_pullback(
5549
) where {F, Y, C}
5650
DI.check_prep(f, prep, backend, x, ty, contexts...)
5751
ys_and_tx = map(ty) do dy
58-
dy_righttype =
59-
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
6052
y, (_, new_dx) = value_and_pullback!!(
61-
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...;
53+
prep.cache, dy, f, x, map(DI.unwrap, contexts)...;
6254
prep.args_to_zero
6355
)
6456
y, _copy_output(new_dx)
@@ -121,9 +113,7 @@ function DI.prepare_gradient_nokwarg(
121113
) where {F, C}
122114
_sig = DI.signature(f, backend, x, contexts...; strict)
123115
config = get_config(backend)
124-
cache = prepare_gradient_cache(
125-
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
126-
)
116+
cache = prepare_gradient_cache(f, x, map(DI.unwrap, contexts)...; config)
127117
contexts_tup_false = map(_ -> false, contexts)
128118
args_to_zero = (
129119
false, # f

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl

Lines changed: 4 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,6 @@
1-
struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F, N} <: DI.PullbackPrep{SIG}
1+
struct MooncakeTwoArgPullbackPrep{SIG, Tcache, F, N} <: DI.PullbackPrep{SIG}
22
_sig::Val{SIG}
33
cache::Tcache
4-
dy_righttype::DY
54
target_function::F
65
args_to_zero::NTuple{N, Bool}
76
end
@@ -27,10 +26,8 @@ function DI.prepare_pullback_nokwarg(
2726
y,
2827
x,
2928
map(DI.unwrap, contexts)...;
30-
debug_mode = config.debug_mode,
31-
silence_debug_messages = config.silence_debug_messages,
29+
config,
3230
)
33-
dy_righttype_after = zero_tangent(y)
3431
contexts_tup_false = map(_ -> false, contexts)
3532
args_to_zero = (
3633
false, # target_function
@@ -39,9 +36,7 @@ function DI.prepare_pullback_nokwarg(
3936
true, # x
4037
contexts_tup_false...,
4138
)
42-
prep = MooncakeTwoArgPullbackPrep(
43-
_sig, cache, dy_righttype_after, target_function, args_to_zero
44-
)
39+
prep = MooncakeTwoArgPullbackPrep(_sig, cache, target_function, args_to_zero)
4540
return prep
4641
end
4742

@@ -56,12 +51,10 @@ function DI.value_and_pullback(
5651
) where {F, C}
5752
DI.check_prep(f!, y, prep, backend, x, ty, contexts...)
5853
dy = only(ty)
59-
# Prepare cotangent to add after the forward pass.
60-
dy_righttype_after = copyto!(prep.dy_righttype, dy)
6154
# Run the reverse-pass and return the results.
6255
y_after, (_, _, _, dx) = value_and_pullback!!(
6356
prep.cache,
64-
dy_righttype_after,
57+
dy,
6558
prep.target_function,
6659
f!,
6760
y,

0 commit comments

Comments
 (0)