Skip to content

Commit ce72baf

Browse files
committed
fix: handle MutableTangent and forward mode tangent leaks
Also convert leaked Mooncake.MutableTangent (e.g. MVector tangents) and apply _maybe_to_primal in forward mode (pushforward) paths.
1 parent 7043da2 commit ce72baf

3 files changed

Lines changed: 4 additions & 3 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ function DI.value_and_pushforward(
4141
map(first_unwrap, contexts, prep.context_tangents)...,
4242
)
4343
y = first(y_and_dy)
44-
dy = _copy_output(last(y_and_dy))
44+
dy = _maybe_to_primal(last(y_and_dy), y)
4545
return y, dy
4646
end
4747
y = _copy_output(first(ys_and_ty[1]))

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ function DI.value_and_pushforward(
5555
(x, dx),
5656
map(first_unwrap, contexts, prep.context_tangents)...,
5757
)
58-
return _copy_output(new_dy)
58+
return _maybe_to_primal(new_dy, y)
5959
end
6060
return y, ty
6161
end
@@ -93,7 +93,7 @@ function DI.value_and_pushforward!(
9393
(x, dx),
9494
map(first_unwrap, contexts, prep.context_tangents)...,
9595
)
96-
copyto!(dy, new_dy)
96+
copyto!(dy, _maybe_to_primal(new_dy, y))
9797
end
9898
return y, ty
9999
end

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ end
2525
# convert it to a primal-shaped value. No-op for already-converted results.
2626
_maybe_to_primal(tx, x) = _copy_output(tx)
2727
_maybe_to_primal(tx::Mooncake.Tangent, x) = tangent_to_user_primal(tx, x)
28+
_maybe_to_primal(tx::Mooncake.MutableTangent, x) = tangent_to_user_primal(tx, x)
2829

2930
@inline maybe_getfield(mod, name::Symbol) =
3031
isdefined(mod, name) ? getfield(mod, name) : nothing

0 commit comments

Comments
 (0)