Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ ForwardDiff = "0.10.36,1"
GPUArraysCore = "0.2"
GTPSA = "1.4.0"
LinearAlgebra = "1"
Mooncake = "0.4.175"
Mooncake = "0.5.0"
PolyesterForwardDiff = "0.1.2"
ReverseDiff = "1.15.1"
SparseArrays = "1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ using Mooncake:
NoRData,
primal,
_copy_output,
_copy_to_output!!
_copy_to_output!!,
primal_to_tangent!!,
tangent_to_primal!!

const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,14 @@ function DI.prepare_pushforward_nokwarg(
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
config = get_config(backend)
cache = prepare_derivative_cache(
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
f, x, map(DI.unwrap, contexts)...; config
)
dx_righttype = zero_tangent(x)
df = zero_tangent(f)
if config.friendly_tangents
dx_righttype = zero_tangent(x)
else
dx_righttype = nothing
end
context_tangents = map(zero_tangent_unwrap, contexts)
prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype, df, context_tangents)
return prep
Expand All @@ -38,16 +42,19 @@ function DI.value_and_pushforward(
) where {F, C, X}
DI.check_prep(f, prep, backend, x, tx, contexts...)
ys_and_ty = map(tx) do dx
dx_righttype =
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
dx_righttype = isnothing(prep.dx_righttype) ? dx : primal_to_tangent!!(prep.dx_righttype, dx)
y_dual = value_and_derivative!!(
prep.cache,
Dual(f, prep.df),
Dual(x, dx_righttype),
map(Dual_unwrap, contexts, prep.context_tangents)...,
)
y = primal(y_dual)
dy = _copy_output(tangent(y_dual))
if isnothing(prep.dx_righttype)
dy = _copy_output(tangent(y_dual))
else
dy = tangent_to_primal!!(_copy_output(y), tangent(y_dual))
end
return y, dy
end
y = first(ys_and_ty[1])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,15 @@ function DI.prepare_pushforward_nokwarg(
y,
x,
map(DI.unwrap, contexts)...;
config.debug_mode,
config.silence_debug_messages,
config,
)
dx_righttype = zero_tangent(x)
dy_righttype = zero_tangent(y)
if config.friendly_tangents
dx_righttype = zero_tangent(x)
dy_righttype = zero_tangent(y)
else
dx_righttype = nothing
dy_righttype = nothing
end
df! = zero_tangent(f!)
context_tangents = map(zero_tangent_unwrap, contexts)
prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dx_righttype, dy_righttype, df!, context_tangents)
Expand All @@ -48,7 +52,7 @@ function DI.value_and_pushforward(
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
ty = map(tx) do dx
dx_righttype =
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
isnothing(prep.dx_righttype) ? dx : primal_to_tangent!!(prep.dx_righttype, dx)
y_dual = zero_dual(y)
value_and_derivative!!(
prep.cache,
Expand All @@ -57,7 +61,11 @@ function DI.value_and_pushforward(
Dual(x, dx_righttype),
map(Dual_unwrap, contexts, prep.context_tangents)...,
)
dy = _copy_output(tangent(y_dual))
if isnothing(prep.dx_righttype)
dy = _copy_output(tangent(y_dual))
else
dy = tangent_to_primal!!(_copy_output(y), tangent(y_dual))
end
return dy
end
return y, ty
Expand Down Expand Up @@ -89,17 +97,17 @@ function DI.value_and_pushforward!(
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
foreach(tx, ty) do dx, dy
dx_righttype =
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
isnothing(prep.dx_righttype) ? dx : primal_to_tangent!!(prep.dx_righttype, dx)
dy_righttype =
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
isnothing(prep.dy_righttype) ? dy : primal_to_tangent!!(prep.dy_righttype, dy)
value_and_derivative!!(
prep.cache,
Dual(f!, prep.df!),
Dual(y, dy_righttype),
Dual(x, dx_righttype),
map(Dual_unwrap, contexts, prep.context_tangents)...,
)
dy === dy_righttype || copyto!(dy, dy_righttype)
isnothing(prep.dy_righttype) || tangent_to_primal!!(dy, dy_righttype)
end
return y, ty
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
## Pullback

struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG}
struct MooncakeOneArgPullbackPrep{SIG, Tcache, N} <: DI.PullbackPrep{SIG}
_sig::Val{SIG}
cache::Tcache
dy_righttype::DY
args_to_zero::NTuple{N, Bool}
end

Expand All @@ -12,18 +11,14 @@ function DI.prepare_pullback_nokwarg(
) where {F, C}
_sig = DI.signature(f, backend, x, ty, contexts...; strict)
config = get_config(backend)
cache = prepare_pullback_cache(
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
)
y = f(x, map(DI.unwrap, contexts)...)
dy_righttype = zero_tangent(y)
cache = prepare_pullback_cache(f, x, map(DI.unwrap, contexts)...; config)
contexts_tup_false = map(_ -> false, contexts)
args_to_zero = (
false, # f
true, # x
contexts_tup_false...,
)
prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype, args_to_zero)
prep = MooncakeOneArgPullbackPrep(_sig, cache, args_to_zero)
return prep
end

Expand All @@ -37,9 +32,8 @@ function DI.value_and_pullback(
) where {F, Y, C}
DI.check_prep(f, prep, backend, x, ty, contexts...)
dy = only(ty)
dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
new_y, (_, new_dx) = value_and_pullback!!(
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...;
prep.cache, dy, f, x, map(DI.unwrap, contexts)...;
prep.args_to_zero
)
return new_y, (_copy_output(new_dx),)
Expand All @@ -55,10 +49,8 @@ function DI.value_and_pullback(
) where {F, Y, C}
DI.check_prep(f, prep, backend, x, ty, contexts...)
ys_and_tx = map(ty) do dy
dy_righttype =
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
y, (_, new_dx) = value_and_pullback!!(
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...;
prep.cache, dy, f, x, map(DI.unwrap, contexts)...;
prep.args_to_zero
)
y, _copy_output(new_dx)
Expand Down Expand Up @@ -121,9 +113,7 @@ function DI.prepare_gradient_nokwarg(
) where {F, C}
_sig = DI.signature(f, backend, x, contexts...; strict)
config = get_config(backend)
cache = prepare_gradient_cache(
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
)
cache = prepare_gradient_cache(f, x, map(DI.unwrap, contexts)...; config)
contexts_tup_false = map(_ -> false, contexts)
args_to_zero = (
false, # f
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F, N} <: DI.PullbackPrep{SIG}
struct MooncakeTwoArgPullbackPrep{SIG, Tcache, F, N} <: DI.PullbackPrep{SIG}
_sig::Val{SIG}
cache::Tcache
dy_righttype::DY
target_function::F
args_to_zero::NTuple{N, Bool}
end
Expand All @@ -27,10 +26,8 @@ function DI.prepare_pullback_nokwarg(
y,
x,
map(DI.unwrap, contexts)...;
debug_mode = config.debug_mode,
silence_debug_messages = config.silence_debug_messages,
config,
)
dy_righttype_after = zero_tangent(y)
contexts_tup_false = map(_ -> false, contexts)
args_to_zero = (
false, # target_function
Expand All @@ -39,9 +36,7 @@ function DI.prepare_pullback_nokwarg(
true, # x
contexts_tup_false...,
)
prep = MooncakeTwoArgPullbackPrep(
_sig, cache, dy_righttype_after, target_function, args_to_zero
)
prep = MooncakeTwoArgPullbackPrep(_sig, cache, target_function, args_to_zero)
return prep
end

Expand All @@ -56,12 +51,10 @@ function DI.value_and_pullback(
) where {F, C}
DI.check_prep(f!, y, prep, backend, x, ty, contexts...)
dy = only(ty)
# Prepare cotangent to add after the forward pass.
dy_righttype_after = copyto!(prep.dy_righttype, dy)
# Run the reverse-pass and return the results.
y_after, (_, _, _, dx) = value_and_pullback!!(
prep.cache,
dy_righttype_after,
dy,
prep.target_function,
f!,
y,
Expand Down
Loading