Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@ DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore"
DifferentiationInterfaceGTPSAExt = "GTPSA"
DifferentiationInterfaceMooncakeExt = "Mooncake"
DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"]
DifferentiationInterfacePolyesterForwardDiffExt = [
"PolyesterForwardDiff",
"ForwardDiff",
"DiffResults",
]
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer"
Expand All @@ -65,7 +69,7 @@ ForwardDiff = "0.10.36,1"
GPUArraysCore = "0.2"
GTPSA = "1.4.0"
LinearAlgebra = "1"
Mooncake = "0.4.147"
Mooncake = "0.4.175"
PolyesterForwardDiff = "0.1.2"
ReverseDiff = "1.15.1"
SparseArrays = "1"
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
## Pullback

struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY} <: DI.PullbackPrep{SIG}
struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG}
_sig::Val{SIG}
cache::Tcache
dy_righttype::DY
args_to_zero::NTuple{N, Bool}
end

function DI.prepare_pullback_nokwarg(
Expand All @@ -16,7 +17,13 @@ function DI.prepare_pullback_nokwarg(
)
y = f(x, map(DI.unwrap, contexts)...)
dy_righttype = zero_tangent(y)
prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype)
contexts_tup_false = map(_ -> false, contexts)
args_to_zero = (
false, # f
true, # x
contexts_tup_false...,
)
prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype, args_to_zero)
return prep
end

Expand All @@ -32,7 +39,8 @@ function DI.value_and_pullback(
dy = only(ty)
dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
new_y, (_, new_dx) = value_and_pullback!!(
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...;
prep.args_to_zero
)
return new_y, (_copy_output(new_dx),)
end
Expand All @@ -50,7 +58,8 @@ function DI.value_and_pullback(
dy_righttype =
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
y, (_, new_dx) = value_and_pullback!!(
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...;
prep.args_to_zero
)
y, _copy_output(new_dx)
end
Expand Down Expand Up @@ -101,9 +110,10 @@ end

## Gradient

struct MooncakeGradientPrep{SIG, Tcache} <: DI.GradientPrep{SIG}
struct MooncakeGradientPrep{SIG, Tcache, N} <: DI.GradientPrep{SIG}
_sig::Val{SIG}
cache::Tcache
args_to_zero::NTuple{N, Bool}
end

function DI.prepare_gradient_nokwarg(
Expand All @@ -114,7 +124,13 @@ function DI.prepare_gradient_nokwarg(
cache = prepare_gradient_cache(
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
)
prep = MooncakeGradientPrep(_sig, cache)
contexts_tup_false = map(_ -> false, contexts)
args_to_zero = (
false, # f
true, # x
contexts_tup_false...,
)
prep = MooncakeGradientPrep(_sig, cache, args_to_zero)
return prep
end

Expand All @@ -126,7 +142,10 @@ function DI.value_and_gradient(
contexts::Vararg{DI.Context, C},
) where {F, C}
DI.check_prep(f, prep, backend, x, contexts...)
y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...)
y, (_, new_grad) = value_and_gradient!!(
prep.cache, f, x, map(DI.unwrap, contexts)...;
prep.args_to_zero
)
return y, _copy_output(new_grad)
end

Expand All @@ -139,7 +158,10 @@ function DI.value_and_gradient!(
contexts::Vararg{DI.Context, C},
) where {F, C}
DI.check_prep(f, prep, backend, x, contexts...)
y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...)
y, (_, new_grad) = value_and_gradient!!(
prep.cache, f, x, map(DI.unwrap, contexts)...;
prep.args_to_zero
)
copyto!(grad, new_grad)
return y, grad
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F} <: DI.PullbackPrep{SIG}
struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F, N} <: DI.PullbackPrep{SIG}
_sig::Val{SIG}
cache::Tcache
dy_righttype::DY
target_function::F
args_to_zero::NTuple{N, Bool}
end

function DI.prepare_pullback_nokwarg(
Expand Down Expand Up @@ -30,7 +31,17 @@ function DI.prepare_pullback_nokwarg(
silence_debug_messages = config.silence_debug_messages,
)
dy_righttype_after = zero_tangent(y)
prep = MooncakeTwoArgPullbackPrep(_sig, cache, dy_righttype_after, target_function)
contexts_tup_false = map(_ -> false, contexts)
args_to_zero = (
false, # target_function
false, # f!
false, # y
true, # x
contexts_tup_false...,
)
prep = MooncakeTwoArgPullbackPrep(
_sig, cache, dy_righttype_after, target_function, args_to_zero
)
return prep
end

Expand All @@ -55,7 +66,8 @@ function DI.value_and_pullback(
f!,
y,
x,
map(DI.unwrap, contexts)...,
map(DI.unwrap, contexts)...;
prep.args_to_zero
)
copyto!(y, y_after)
return y, (_copy_output(dx),)
Expand All @@ -80,7 +92,8 @@ function DI.value_and_pullback(
f!,
y,
x,
map(DI.unwrap, contexts)...,
map(DI.unwrap, contexts)...;
prep.args_to_zero
)
copyto!(y, y_after)
_copy_output(dx)
Expand Down
Loading