Skip to content

Commit a62c518

Browse files
committed
Remove useless _righttype caches
1 parent 151ad47 commit a62c518

8 files changed

Lines changed: 36 additions & 54 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ ForwardDiff = "0.10.36,1"
6969
GPUArraysCore = "0.2"
7070
GTPSA = "1.4.0"
7171
LinearAlgebra = "1"
72-
Mooncake = "0.5.0"
72+
Mooncake = "0.5.1"
7373
PolyesterForwardDiff = "0.1.2"
7474
ReverseDiff = "1.15.1"
7575
SparseArrays = "1"

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
## Pushforward
22

3-
struct MooncakeOneArgPushforwardPrep{SIG, Tcache, DX, FT, CT} <: DI.PushforwardPrep{SIG}
3+
struct MooncakeOneArgPushforwardPrep{SIG, Tcache, FT, CT} <: DI.PushforwardPrep{SIG}
44
_sig::Val{SIG}
55
cache::Tcache
6-
dx_righttype::DX
76
df::FT
87
context_tangents::CT
98
end
@@ -21,10 +20,9 @@ function DI.prepare_pushforward_nokwarg(
2120
cache = prepare_derivative_cache(
2221
f, x, map(DI.unwrap, contexts)...; config
2322
)
24-
dx_righttype = zero_tangent(x)
2523
df = zero_tangent(f)
2624
context_tangents = map(zero_tangent_unwrap, contexts)
27-
prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype, df, context_tangents)
25+
prep = MooncakeOneArgPushforwardPrep(_sig, cache, df, context_tangents)
2826
return prep
2927
end
3028

@@ -38,13 +36,11 @@ function DI.value_and_pushforward(
3836
) where {F, C, X}
3937
DI.check_prep(f, prep, backend, x, tx, contexts...)
4038
ys_and_ty = map(tx) do dx
41-
dx_righttype =
42-
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
4339
y_dual = value_and_derivative!!(
4440
prep.cache,
45-
Dual(f, prep.df),
46-
Dual(x, dx_righttype),
47-
map(Dual_unwrap, contexts, prep.context_tangents)...,
41+
(f, prep.df),
42+
(x, dx),
43+
map(first_unwrap, contexts, prep.context_tangents)...,
4844
)
4945
y = primal(y_dual)
5046
dy = _copy_output(tangent(y_dual))

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
## Pushforward
22

3-
struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, DX, DY, FT, CT} <: DI.PushforwardPrep{SIG}
3+
struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, FT, CT} <: DI.PushforwardPrep{SIG}
44
_sig::Val{SIG}
55
cache::Tcache
6-
dx_righttype::DX
7-
dy_righttype::DY
86
df!::FT
97
context_tangents::CT
108
end
@@ -27,11 +25,9 @@ function DI.prepare_pushforward_nokwarg(
2725
map(DI.unwrap, contexts)...;
2826
config
2927
)
30-
dx_righttype = zero_tangent(x)
31-
dy_righttype = zero_tangent(y)
3228
df! = zero_tangent(f!)
3329
context_tangents = map(zero_tangent_unwrap, contexts)
34-
prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dx_righttype, dy_righttype, df!, context_tangents)
30+
prep = MooncakeTwoArgPushforwardPrep(_sig, cache, df!, context_tangents)
3531
return prep
3632
end
3733

@@ -46,15 +42,13 @@ function DI.value_and_pushforward(
4642
) where {F, C, X}
4743
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
4844
ty = map(tx) do dx
49-
dx_righttype =
50-
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
5145
y_dual = zero_dual(y)
5246
value_and_derivative!!(
5347
prep.cache,
54-
Dual(f!, prep.df!),
48+
(f!, prep.df!),
5549
y_dual,
56-
Dual(x, dx_righttype),
57-
map(Dual_unwrap, contexts, prep.context_tangents)...,
50+
(x, dx),
51+
map(first_unwrap, contexts, prep.context_tangents)...,
5852
)
5953
dy = _copy_output(tangent(y_dual))
6054
return dy
@@ -87,18 +81,13 @@ function DI.value_and_pushforward!(
8781
) where {F, C, X, Y}
8882
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
8983
foreach(tx, ty) do dx, dy
90-
dx_righttype =
91-
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
92-
dy_righttype =
93-
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
9484
value_and_derivative!!(
9585
prep.cache,
96-
Dual(f!, prep.df!),
97-
Dual(y, dy_righttype),
98-
Dual(x, dx_righttype),
99-
map(Dual_unwrap, contexts, prep.context_tangents)...,
86+
(f!, prep.df!),
87+
(y, dy),
88+
(x, dx),
89+
map(first_unwrap, contexts, prep.context_tangents)...,
10090
)
101-
dy === dy_righttype || copyto!(dy, dy_righttype)
10291
end
10392
return y, ty
10493
end

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
## Pullback
22

3-
struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG}
3+
struct MooncakeOneArgPullbackPrep{SIG, Tcache, N} <: DI.PullbackPrep{SIG}
44
_sig::Val{SIG}
55
cache::Tcache
6-
dy_righttype::DY
76
args_to_zero::NTuple{N, Bool}
87
end
98

@@ -15,15 +14,13 @@ function DI.prepare_pullback_nokwarg(
1514
cache = prepare_pullback_cache(
1615
f, x, map(DI.unwrap, contexts)...; config
1716
)
18-
y = f(x, map(DI.unwrap, contexts)...)
19-
dy_righttype = zero_tangent(y)
2017
contexts_tup_false = map(_ -> false, contexts)
2118
args_to_zero = (
2219
false, # f
2320
true, # x
2421
contexts_tup_false...,
2522
)
26-
prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype, args_to_zero)
23+
prep = MooncakeOneArgPullbackPrep(_sig, cache, args_to_zero)
2724
return prep
2825
end
2926

@@ -37,10 +34,8 @@ function DI.value_and_pullback(
3734
) where {F, Y, C}
3835
DI.check_prep(f, prep, backend, x, ty, contexts...)
3936
dy = only(ty)
40-
dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
4137
new_y, (_, new_dx) = value_and_pullback!!(
42-
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...;
43-
prep.args_to_zero
38+
prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
4439
)
4540
return new_y, (_copy_output(new_dx),)
4641
end
@@ -55,11 +50,8 @@ function DI.value_and_pullback(
5550
) where {F, Y, C}
5651
DI.check_prep(f, prep, backend, x, ty, contexts...)
5752
ys_and_tx = map(ty) do dy
58-
dy_righttype =
59-
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
6053
y, (_, new_dx) = value_and_pullback!!(
61-
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...;
62-
prep.args_to_zero
54+
prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
6355
)
6456
y, _copy_output(new_dx)
6557
end

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F, N} <: DI.PullbackPrep{SIG}
22
_sig::Val{SIG}
33
cache::Tcache
4-
dy_righttype::DY
4+
dy_backup::DY
55
target_function::F
66
args_to_zero::NTuple{N, Bool}
77
end
@@ -29,7 +29,7 @@ function DI.prepare_pullback_nokwarg(
2929
map(DI.unwrap, contexts)...;
3030
config,
3131
)
32-
dy_righttype_after = zero_tangent(y)
32+
dy_backup_after = zero_tangent(y)
3333
contexts_tup_false = map(_ -> false, contexts)
3434
args_to_zero = (
3535
false, # target_function
@@ -39,7 +39,7 @@ function DI.prepare_pullback_nokwarg(
3939
contexts_tup_false...,
4040
)
4141
prep = MooncakeTwoArgPullbackPrep(
42-
_sig, cache, dy_righttype_after, target_function, args_to_zero
42+
_sig, cache, dy_backup_after, target_function, args_to_zero
4343
)
4444
return prep
4545
end
@@ -56,11 +56,11 @@ function DI.value_and_pullback(
5656
DI.check_prep(f!, y, prep, backend, x, ty, contexts...)
5757
dy = only(ty)
5858
# Prepare cotangent to add after the forward pass.
59-
dy_righttype_after = copyto!(prep.dy_righttype, dy)
59+
dy_backup_after = copyto!(prep.dy_backup, dy)
6060
# Run the reverse-pass and return the results.
6161
y_after, (_, _, _, dx) = value_and_pullback!!(
6262
prep.cache,
63-
dy_righttype_after,
63+
dy_backup_after,
6464
prep.target_function,
6565
f!,
6666
y,
@@ -83,10 +83,10 @@ function DI.value_and_pullback(
8383
) where {F, C}
8484
DI.check_prep(f!, y, prep, backend, x, ty, contexts...)
8585
tx = map(ty) do dy
86-
dy_righttype_after = copyto!(prep.dy_righttype, dy)
86+
dy_backup_after = copyto!(prep.dy_backup, dy)
8787
y_after, (_, _, _, dx) = value_and_pullback!!(
8888
prep.cache,
89-
dy_righttype_after,
89+
dy_backup_after,
9090
prep.target_function,
9191
f!,
9292
y,

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@ get_config(::AnyAutoMooncake{Nothing}) = Config()
22
get_config(backend::AnyAutoMooncake{<:Config}) = backend.config
33

44
@inline zero_tangent_unwrap(c::DI.Context) = zero_tangent(DI.unwrap(c))
5-
@inline Dual_unwrap(c, dc) = Dual(DI.unwrap(c), dc)
5+
@inline first_unwrap(c, dc) = (DI.unwrap(c), dc)

DifferentiationInterface/test/Back/Mooncake/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
77
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
88
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
99
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
10+
11+
[sources]
12+
Mooncake = { url = "https://github.com/gdalle/Mooncake.jl", rev = "gd/forwardcache" }

DifferentiationInterface/test/Back/Mooncake/test.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ using Test
77
using ExplicitImports
88
check_no_implicit_imports(DifferentiationInterface)
99

10-
1110
backends = [
12-
AutoMooncake(; config = nothing),
13-
AutoMooncake(; config = Mooncake.Config()),
14-
AutoMooncakeForward(; config = nothing),
11+
AutoMooncake(),
12+
AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)),
13+
AutoMooncakeForward(),
14+
AutoMooncakeForward(; config = Mooncake.Config(; friendly_tangents = true)),
1515
]
1616

1717
for backend in backends
@@ -51,3 +51,5 @@ test_differentiation(
5151
@test grad.A == ps.B
5252
@test grad.B == ps.A
5353
end
54+
55+
# TODO: test static arrays with friendly tangents!

0 commit comments

Comments
 (0)