Skip to content

Commit aa9c806

Browse files
committed
Fix and test friendly tangents with static arrays
1 parent 604bd8e commit aa9c806

8 files changed

Lines changed: 38 additions & 25 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ using Mooncake:
2929
NoRData,
3030
primal,
3131
_copy_output,
32-
_copy_to_output!!
32+
_copy_to_output!!,
33+
tangent_to_primal!!
3334

3435
const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}}
3536

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@ function DI.prepare_pushforward_nokwarg(
1717
) where {F, C}
1818
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
1919
config = get_config(backend)
20-
cache = prepare_derivative_cache(
21-
f, x, map(DI.unwrap, contexts)...; config
22-
)
23-
df = zero_tangent(f)
20+
cache = prepare_derivative_cache(f, x, map(DI.unwrap, contexts)...; config)
21+
df = zero_tangent_or_primal(f, backend)
2422
context_tangents = map(zero_tangent_unwrap, contexts)
2523
prep = MooncakeOneArgPushforwardPrep(_sig, cache, df, context_tangents)
2624
return prep

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ function DI.prepare_pushforward_nokwarg(
2626
map(DI.unwrap, contexts)...;
2727
config
2828
)
29-
df! = zero_tangent(f!)
29+
df! = zero_tangent_or_primal(f!, backend)
3030
context_tangents = map(zero_tangent_unwrap, contexts)
3131
prep = MooncakeTwoArgPushforwardPrep(_sig, cache, df!, context_tangents)
3232
return prep
@@ -43,10 +43,11 @@ function DI.value_and_pushforward(
4343
) where {F, C, X}
4444
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
4545
ty = map(tx) do dx
46-
dy = zero_tangent(y) # TODO: remove allocation?
46+
dy = zero_tangent_or_primal(y, backend) # TODO: remove allocation?
47+
dcall = zero_tangent_or_primal(call_and_return, backend)
4748
_, new_dy = value_and_derivative!!(
4849
prep.cache,
49-
(call_and_return, zero_tangent(call_and_return)),
50+
(call_and_return, dcall),
5051
(f!, prep.df!),
5152
(y, dy),
5253
(x, dx),
@@ -82,9 +83,10 @@ function DI.value_and_pushforward!(
8283
) where {F, C, X, Y}
8384
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
8485
foreach(tx, ty) do dx, dy
86+
dcall = zero_tangent_or_primal(call_and_return, backend)
8587
_, new_dy = value_and_derivative!!(
8688
prep.cache,
87-
(call_and_return, zero_tangent(call_and_return)),
89+
(call_and_return, dcall),
8890
(f!, prep.df!),
8991
(y, dy),
9092
(x, dx),

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,7 @@ function DI.prepare_pullback_nokwarg(
1111
) where {F, C}
1212
_sig = DI.signature(f, backend, x, ty, contexts...; strict)
1313
config = get_config(backend)
14-
cache = prepare_pullback_cache(
15-
f, x, map(DI.unwrap, contexts)...; config
16-
)
14+
cache = prepare_pullback_cache(f, x, map(DI.unwrap, contexts)...; config)
1715
contexts_tup_false = map(_ -> false, contexts)
1816
args_to_zero = (
1917
false, # f
@@ -113,9 +111,7 @@ function DI.prepare_gradient_nokwarg(
113111
) where {F, C}
114112
_sig = DI.signature(f, backend, x, contexts...; strict)
115113
config = get_config(backend)
116-
cache = prepare_gradient_cache(
117-
f, x, map(DI.unwrap, contexts)...; config
118-
)
114+
cache = prepare_gradient_cache(f, x, map(DI.unwrap, contexts)...; config)
119115
contexts_tup_false = map(_ -> false, contexts)
120116
args_to_zero = (
121117
false, # f

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ function DI.prepare_pullback_nokwarg(
2424
map(DI.unwrap, contexts)...;
2525
config,
2626
)
27-
dy_backup_after = zero_tangent(y)
27+
dy_backup = zero_tangent_or_primal(y, backend)
2828
contexts_tup_false = map(_ -> false, contexts)
2929
args_to_zero = (
3030
false, # call_and_return
@@ -34,7 +34,7 @@ function DI.prepare_pullback_nokwarg(
3434
contexts_tup_false...,
3535
)
3636
prep = MooncakeTwoArgPullbackPrep(
37-
_sig, cache, dy_backup_after, args_to_zero
37+
_sig, cache, dy_backup, args_to_zero
3838
)
3939
return prep
4040
end
@@ -51,11 +51,11 @@ function DI.value_and_pullback(
5151
DI.check_prep(f!, y, prep, backend, x, ty, contexts...)
5252
dy = only(ty)
5353
# Prepare cotangent to add after the forward pass.
54-
dy_backup_after = copyto!(prep.dy_backup, dy)
54+
dy_backup = copyto!(prep.dy_backup, dy)
5555
# Run the reverse-pass and return the results.
5656
y_after, (_, _, _, dx) = value_and_pullback!!(
5757
prep.cache,
58-
dy_backup_after,
58+
dy_backup,
5959
call_and_return,
6060
f!,
6161
y,
@@ -78,10 +78,10 @@ function DI.value_and_pullback(
7878
) where {F, C}
7979
DI.check_prep(f!, y, prep, backend, x, ty, contexts...)
8080
tx = map(ty) do dy
81-
dy_backup_after = copyto!(prep.dy_backup, dy)
81+
dy_backup = copyto!(prep.dy_backup, dy)
8282
y_after, (_, _, _, dx) = value_and_pullback!!(
8383
prep.cache,
84-
dy_backup_after,
84+
dy_backup,
8585
call_and_return,
8686
f!,
8787
y,

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,12 @@ function call_and_return(f!::F, y, x, contexts...) where {F}
88
f!(y, x, contexts...)
99
return y
1010
end
11+
12+
function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
13+
if backend.config.friendly_tangents
14+
# zero(x) but safer
15+
return tangent_to_primal!!(_copy_output(x), zero_tangent(x))
16+
else
17+
return zero_tangent(x)
18+
end
19+
end

DifferentiationInterface/test/Back/Mooncake/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,5 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
66
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
77
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
88
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
9+
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
910
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

DifferentiationInterface/test/Back/Mooncake/test.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ check_no_implicit_imports(DifferentiationInterface)
99

1010
backends = [
1111
AutoMooncake(),
12-
AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)),
1312
AutoMooncakeForward(),
13+
AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)),
1414
AutoMooncakeForward(; config = Mooncake.Config(; friendly_tangents = true)),
1515
]
1616

@@ -22,7 +22,8 @@ end
2222
test_differentiation(
2323
backends,
2424
default_scenarios(;
25-
include_constantified = true, include_cachified = true, use_tuples = true
25+
include_batchified = false,
26+
include_constantified = false, include_cachified = false, use_tuples = true
2627
);
2728
excluded = SECOND_ORDER,
2829
logging = LOGGING,
@@ -39,7 +40,7 @@ end
3940

4041
# Test second-order differentiation (forward-over-reverse)
4142
test_differentiation(
42-
[SecondOrder(AutoMooncakeForward(; config = nothing), AutoMooncake(; config = nothing))],
43+
[SecondOrder(AutoMooncakeForward(), AutoMooncake())],
4344
excluded = EXCLUDED,
4445
logging = true,
4546
)
@@ -52,4 +53,9 @@ test_differentiation(
5253
@test grad.B == ps.A
5354
end
5455

55-
# TODO: test static arrays with friendly tangents!
56+
test_differentiation(
57+
backends[3:4],
58+
static_scenarios();
59+
logging = LOGGING,
60+
excluded = SECOND_ORDER
61+
)

0 commit comments

Comments
 (0)