Skip to content

Commit 7043da2

Browse files
committed
fix: convert raw Mooncake.Tangent in pullback/gradient results
On Julia 1.11, Mooncake may return raw Tangent objects instead of friendly arrays for StaticArrays even with friendly_tangents=true. Add _maybe_to_primal dispatch as a safety net that converts leaked Tangent objects to primal-shaped values, no-op otherwise.
1 parent 106f50f commit 7043da2

4 files changed

Lines changed: 17 additions & 15 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function DI.value_and_pullback(
3535
new_y, (_, new_dx) = value_and_pullback!!(
3636
prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
3737
)
38-
return new_y, (_copy_output(new_dx),)
38+
return new_y, (_maybe_to_primal(new_dx, x),)
3939
end
4040

4141
function DI.value_and_pullback(
@@ -51,7 +51,7 @@ function DI.value_and_pullback(
5151
y, (_, new_dx) = value_and_pullback!!(
5252
prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
5353
)
54-
y, _copy_output(new_dx)
54+
y, _maybe_to_primal(new_dx, x)
5555
end
5656
y = first(ys_and_tx[1])
5757
tx = map(last, ys_and_tx)
@@ -134,7 +134,7 @@ function DI.value_and_gradient(
134134
prep.cache, f, x, map(DI.unwrap, contexts)...;
135135
prep.args_to_zero
136136
)
137-
return y, _copy_output(new_grad)
137+
return y, _maybe_to_primal(new_grad, x)
138138
end
139139

140140
function DI.value_and_gradient!(
@@ -150,7 +150,7 @@ function DI.value_and_gradient!(
150150
prep.cache, f, x, map(DI.unwrap, contexts)...;
151151
prep.args_to_zero
152152
)
153-
copyto!(grad, new_grad)
153+
copyto!(grad, _maybe_to_primal(new_grad, x))
154154
return y, grad
155155
end
156156

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ function DI.value_and_pullback(
6464
prep.args_to_zero
6565
)
6666
copyto!(y, y_after)
67-
return y, (_copy_output(dx),)
67+
return y, (_maybe_to_primal(dx, x),)
6868
end
6969

7070
function DI.value_and_pullback(
@@ -90,7 +90,7 @@ function DI.value_and_pullback(
9090
prep.args_to_zero
9191
)
9292
copyto!(y, y_after)
93-
_copy_output(dx)
93+
_maybe_to_primal(dx, x)
9494
end
9595
return y, tx
9696
end

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,11 @@ function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
2121
end
2222
end
2323

24+
# Safety net: if Mooncake returns a raw Tangent (e.g. Julia 1.11 + StaticArrays),
25+
# convert it to a primal-shaped value. No-op for already-converted results.
26+
_maybe_to_primal(tx, x) = _copy_output(tx)
27+
_maybe_to_primal(tx::Mooncake.Tangent, x) = tangent_to_user_primal(tx, x)
28+
2429
@inline maybe_getfield(mod, name::Symbol) =
2530
isdefined(mod, name) ? getfield(mod, name) : nothing
2631

DifferentiationInterface/test/Back/Mooncake/test.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,15 +75,12 @@ test_differentiation(
7575
@test grad.B == ps.A
7676
end
7777

78-
# friendly_tangents + StaticArrays broken on Julia 1.11 (upstream Mooncake bug)
79-
@static if !(VERSION v"1.11-" && VERSION < v"1.12-")
80-
test_differentiation(
81-
backends[3:4],
82-
nomatrix(static_scenarios());
83-
logging = LOGGING,
84-
excluded = SECOND_ORDER,
85-
)
86-
end
78+
test_differentiation(
79+
backends[3:4],
80+
nomatrix(static_scenarios());
81+
logging = LOGGING,
82+
excluded = SECOND_ORDER,
83+
)
8784

8885
@testset "Friendly tangents structured matrices" begin
8986
backend = AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true))

0 commit comments

Comments
 (0)