Skip to content

Commit 7979f75

Browse files
committed
Friendly tangents in forward mode
1 parent a93e0d1 commit 7979f75

3 files changed

Lines changed: 30 additions & 16 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ using Mooncake:
2929
NoRData,
3030
primal,
3131
_copy_output,
32-
_copy_to_output!!
32+
_copy_to_output!!,
33+
primal_to_tangent!!,
34+
tangent_to_primal!!
3335

3436
const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}}
3537

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
## Pushforward
22

3-
# TODO: needs friendly tangents support
4-
53
struct MooncakeOneArgPushforwardPrep{SIG, Tcache, DX, FT, CT} <: DI.PushforwardPrep{SIG}
64
_sig::Val{SIG}
75
cache::Tcache
@@ -23,8 +21,12 @@ function DI.prepare_pushforward_nokwarg(
2321
cache = prepare_derivative_cache(
2422
f, x, map(DI.unwrap, contexts)...; config
2523
)
26-
dx_righttype = zero_tangent(x)
2724
df = zero_tangent(f)
25+
if config.friendly_tangents
26+
dx_righttype = zero_tangent(x)
27+
else
28+
dx_righttype = nothing
29+
end
2830
context_tangents = map(zero_tangent_unwrap, contexts)
2931
prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype, df, context_tangents)
3032
return prep
@@ -40,16 +42,19 @@ function DI.value_and_pushforward(
4042
) where {F, C, X}
4143
DI.check_prep(f, prep, backend, x, tx, contexts...)
4244
ys_and_ty = map(tx) do dx
43-
dx_righttype =
44-
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
45+
dx_righttype = isnothing(prep.dx_righttype) ? dx : primal_to_tangent!!(prep.dx_righttype, dx)
4546
y_dual = value_and_derivative!!(
4647
prep.cache,
4748
Dual(f, prep.df),
4849
Dual(x, dx_righttype),
4950
map(Dual_unwrap, contexts, prep.context_tangents)...,
5051
)
5152
y = primal(y_dual)
52-
dy = _copy_output(tangent(y_dual))
53+
if isnothing(prep.dx_righttype)
54+
dy = _copy_output(tangent(y_dual))
55+
else
56+
dy = tangent_to_primal!!(_copy_output(y), tangent(y_dual))
57+
end
5358
return y, dy
5459
end
5560
y = first(ys_and_ty[1])

DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
## Pushforward
22

3-
# TODO: needs friendly tangents support
4-
53
struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, DX, DY, FT, CT} <: DI.PushforwardPrep{SIG}
64
_sig::Val{SIG}
75
cache::Tcache
@@ -29,8 +27,13 @@ function DI.prepare_pushforward_nokwarg(
2927
map(DI.unwrap, contexts)...;
3028
config,
3129
)
32-
dx_righttype = zero_tangent(x)
33-
dy_righttype = zero_tangent(y)
30+
if config.friendly_tangents
31+
dx_righttype = zero_tangent(x)
32+
dy_righttype = zero_tangent(y)
33+
else
34+
dx_righttype = nothing
35+
dy_righttype = nothing
36+
end
3437
df! = zero_tangent(f!)
3538
context_tangents = map(zero_tangent_unwrap, contexts)
3639
prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dx_righttype, dy_righttype, df!, context_tangents)
@@ -49,7 +52,7 @@ function DI.value_and_pushforward(
4952
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
5053
ty = map(tx) do dx
5154
dx_righttype =
52-
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
55+
isnothing(prep.dx_righttype) ? dx : primal_to_tangent!!(prep.dx_righttype, dx)
5356
y_dual = zero_dual(y)
5457
value_and_derivative!!(
5558
prep.cache,
@@ -58,7 +61,11 @@ function DI.value_and_pushforward(
5861
Dual(x, dx_righttype),
5962
map(Dual_unwrap, contexts, prep.context_tangents)...,
6063
)
61-
dy = _copy_output(tangent(y_dual))
64+
if isnothing(prep.dx_righttype)
65+
dy = _copy_output(tangent(y_dual))
66+
else
67+
dy = tangent_to_primal!!(_copy_output(y), tangent(y_dual))
68+
end
6269
return dy
6370
end
6471
return y, ty
@@ -90,17 +97,17 @@ function DI.value_and_pushforward!(
9097
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
9198
foreach(tx, ty) do dx, dy
9299
dx_righttype =
93-
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
100+
isnothing(prep.dx_righttype) ? dx : primal_to_tangent!!(prep.dx_righttype, dx)
94101
dy_righttype =
95-
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
102+
isnothing(prep.dy_righttype) ? dy : primal_to_tangent!!(prep.dy_righttype, dy)
96103
value_and_derivative!!(
97104
prep.cache,
98105
Dual(f!, prep.df!),
99106
Dual(y, dy_righttype),
100107
Dual(x, dx_righttype),
101108
map(Dual_unwrap, contexts, prep.context_tangents)...,
102109
)
103-
dy === dy_righttype || copyto!(dy, dy_righttype)
110+
isnothing(prep.dy_righttype) || tangent_to_primal!!(dy, dy_righttype)
104111
end
105112
return y, ty
106113
end

0 commit comments

Comments
 (0)