Skip to content

Commit 4bf5dc6

Browse files
committed
No dual output
1 parent a62c518 commit 4bf5dc6

2 files changed

Lines changed: 5 additions & 6 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,14 @@ function DI.value_and_pushforward(
3636
) where {F, C, X}
3737
DI.check_prep(f, prep, backend, x, tx, contexts...)
3838
ys_and_ty = map(tx) do dx
39-
y_dual = value_and_derivative!!(
39+
y_and_dy = value_and_derivative!!(
4040
prep.cache,
4141
(f, prep.df),
4242
(x, dx),
4343
map(first_unwrap, contexts, prep.context_tangents)...,
4444
)
45-
y = primal(y_dual)
46-
dy = _copy_output(tangent(y_dual))
45+
y = first(y_and_dy)
46+
dy = _copy_output(last(y_and_dy))
4747
return y, dy
4848
end
4949
y = first(ys_and_ty[1])

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,14 @@ function DI.value_and_pushforward(
4242
) where {F, C, X}
4343
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
4444
ty = map(tx) do dx
45-
y_dual = zero_dual(y)
45+
dy = zero_tangent(y) # TODO: remove allocation?
4646
value_and_derivative!!(
4747
prep.cache,
4848
(f!, prep.df!),
49-
y_dual,
49+
(y, dy),
5050
(x, dx),
5151
map(first_unwrap, contexts, prep.context_tangents)...,
5252
)
53-
dy = _copy_output(tangent(y_dual))
5453
return dy
5554
end
5655
return y, ty

0 commit comments

Comments
 (0)