Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion DifferentiationInterface/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.15...main)
## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.16...main)

## [0.7.16](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.15...DifferentiationInterface-v0.7.16)

### Fixed

- Upgrade Mooncake compat to v0.5 ([#961](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/961))

## [0.7.15](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.14...DifferentiationInterface-v0.7.15)

Expand Down
4 changes: 2 additions & 2 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.7.15"
version = "0.7.16"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -69,7 +69,7 @@ ForwardDiff = "0.10.36,1"
GPUArraysCore = "0.2"
GTPSA = "1.4.0"
LinearAlgebra = "1"
Mooncake = "0.4.175"
Mooncake = "0.5.1"
PolyesterForwardDiff = "0.1.2"
ReverseDiff = "1.15.1"
SparseArrays = "1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ using Mooncake:
NoRData,
primal,
_copy_output,
_copy_to_output!!
_copy_to_output!!,
tangent_to_primal!!

const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
## Pushforward

struct MooncakeOneArgPushforwardPrep{SIG, Tcache, DX, FT, CT} <: DI.PushforwardPrep{SIG}
struct MooncakeOneArgPushforwardPrep{SIG, Tcache, FT, CT} <: DI.PushforwardPrep{SIG}
_sig::Val{SIG}
cache::Tcache
dx_righttype::DX
df::FT
context_tangents::CT
end
Expand All @@ -18,13 +17,10 @@ function DI.prepare_pushforward_nokwarg(
) where {F, C}
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
config = get_config(backend)
cache = prepare_derivative_cache(
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
)
dx_righttype = zero_tangent(x)
df = zero_tangent(f)
cache = prepare_derivative_cache(f, x, map(DI.unwrap, contexts)...; config)
df = zero_tangent_or_primal(f, backend)
context_tangents = map(zero_tangent_unwrap, contexts)
prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype, df, context_tangents)
prep = MooncakeOneArgPushforwardPrep(_sig, cache, df, context_tangents)
return prep
end

Expand All @@ -38,19 +34,17 @@ function DI.value_and_pushforward(
) where {F, C, X}
DI.check_prep(f, prep, backend, x, tx, contexts...)
ys_and_ty = map(tx) do dx
dx_righttype =
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
y_dual = value_and_derivative!!(
y_and_dy = value_and_derivative!!(
prep.cache,
Dual(f, prep.df),
Dual(x, dx_righttype),
map(Dual_unwrap, contexts, prep.context_tangents)...,
(f, prep.df),
(x, dx),
map(first_unwrap, contexts, prep.context_tangents)...,
)
y = primal(y_dual)
dy = _copy_output(tangent(y_dual))
y = first(y_and_dy)
dy = _copy_output(last(y_and_dy))
return y, dy
end
y = first(ys_and_ty[1])
y = _copy_output(first(ys_and_ty[1]))
ty = map(last, ys_and_ty)
return y, ty
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
## Pushforward

struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, DX, DY, FT, CT} <: DI.PushforwardPrep{SIG}
struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, FT0, FT, YT, CT} <: DI.PushforwardPrep{SIG}
_sig::Val{SIG}
cache::Tcache
dx_righttype::DX
dy_righttype::DY
dcall::FT0
df!::FT
dy::YT
context_tangents::CT
end

Expand All @@ -21,18 +21,18 @@ function DI.prepare_pushforward_nokwarg(
_sig = DI.signature(f!, y, backend, x, tx, contexts...; strict)
config = get_config(backend)
cache = prepare_derivative_cache(
call_and_return,
f!,
y,
x,
map(DI.unwrap, contexts)...;
config.debug_mode,
config.silence_debug_messages,
config
)
dx_righttype = zero_tangent(x)
dy_righttype = zero_tangent(y)
df! = zero_tangent(f!)
dcall = zero_tangent_or_primal(call_and_return, backend)
df! = zero_tangent_or_primal(f!, backend)
dy = zero_tangent_or_primal(y, backend)
context_tangents = map(zero_tangent_unwrap, contexts)
prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dx_righttype, dy_righttype, df!, context_tangents)
prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dcall, df!, dy, context_tangents)
return prep
end

Expand All @@ -47,18 +47,15 @@ function DI.value_and_pushforward(
) where {F, C, X}
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
ty = map(tx) do dx
dx_righttype =
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
y_dual = zero_dual(y)
value_and_derivative!!(
_, new_dy = value_and_derivative!!(
prep.cache,
Dual(f!, prep.df!),
y_dual,
Dual(x, dx_righttype),
map(Dual_unwrap, contexts, prep.context_tangents)...,
(call_and_return, prep.dcall),
(f!, prep.df!),
(y, prep.dy),
(x, dx),
map(first_unwrap, contexts, prep.context_tangents)...,
)
dy = _copy_output(tangent(y_dual))
return dy
return _copy_output(new_dy)
end
return y, ty
end
Expand Down Expand Up @@ -88,18 +85,15 @@ function DI.value_and_pushforward!(
) where {F, C, X, Y}
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
foreach(tx, ty) do dx, dy
dx_righttype =
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
dy_righttype =
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
value_and_derivative!!(
_, new_dy = value_and_derivative!!(
prep.cache,
Dual(f!, prep.df!),
Dual(y, dy_righttype),
Dual(x, dx_righttype),
map(Dual_unwrap, contexts, prep.context_tangents)...,
(call_and_return, prep.dcall),
(f!, prep.df!),
(y, dy),
(x, dx),
map(first_unwrap, contexts, prep.context_tangents)...,
)
dy === dy_righttype || copyto!(dy, dy_righttype)
copyto!(dy, new_dy)
end
return y, ty
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
## Pullback

struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG}
struct MooncakeOneArgPullbackPrep{SIG, Tcache, N} <: DI.PullbackPrep{SIG}
_sig::Val{SIG}
cache::Tcache
dy_righttype::DY
args_to_zero::NTuple{N, Bool}
end

Expand All @@ -12,18 +11,14 @@ function DI.prepare_pullback_nokwarg(
) where {F, C}
_sig = DI.signature(f, backend, x, ty, contexts...; strict)
config = get_config(backend)
cache = prepare_pullback_cache(
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
)
y = f(x, map(DI.unwrap, contexts)...)
dy_righttype = zero_tangent(y)
cache = prepare_pullback_cache(f, x, map(DI.unwrap, contexts)...; config)
contexts_tup_false = map(_ -> false, contexts)
args_to_zero = (
false, # f
true, # x
contexts_tup_false...,
)
prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype, args_to_zero)
prep = MooncakeOneArgPullbackPrep(_sig, cache, args_to_zero)
return prep
end

Expand All @@ -37,10 +32,8 @@ function DI.value_and_pullback(
) where {F, Y, C}
DI.check_prep(f, prep, backend, x, ty, contexts...)
dy = only(ty)
dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
new_y, (_, new_dx) = value_and_pullback!!(
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...;
prep.args_to_zero
prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
)
return new_y, (_copy_output(new_dx),)
end
Expand All @@ -55,11 +48,8 @@ function DI.value_and_pullback(
) where {F, Y, C}
DI.check_prep(f, prep, backend, x, ty, contexts...)
ys_and_tx = map(ty) do dy
dy_righttype =
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
y, (_, new_dx) = value_and_pullback!!(
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...;
prep.args_to_zero
prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
)
y, _copy_output(new_dx)
end
Expand Down Expand Up @@ -121,9 +111,7 @@ function DI.prepare_gradient_nokwarg(
) where {F, C}
_sig = DI.signature(f, backend, x, contexts...; strict)
config = get_config(backend)
cache = prepare_gradient_cache(
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
)
cache = prepare_gradient_cache(f, x, map(DI.unwrap, contexts)...; config)
contexts_tup_false = map(_ -> false, contexts)
args_to_zero = (
false, # f
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F, N} <: DI.PullbackPrep{SIG}
struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG}
_sig::Val{SIG}
cache::Tcache
dy_righttype::DY
target_function::F
dy_backup::DY
args_to_zero::NTuple{N, Bool}
end

Expand All @@ -16,31 +15,26 @@ function DI.prepare_pullback_nokwarg(
contexts::Vararg{DI.Context, C}
) where {F, C}
_sig = DI.signature(f!, y, backend, x, ty, contexts...; strict)
target_function = function (f!, y, x, contexts...)
f!(y, x, contexts...)
return y
end
config = get_config(backend)
cache = prepare_pullback_cache(
target_function,
call_and_return,
f!,
y,
x,
map(DI.unwrap, contexts)...;
debug_mode = config.debug_mode,
silence_debug_messages = config.silence_debug_messages,
config,
)
dy_righttype_after = zero_tangent(y)
dy_backup = zero_tangent_or_primal(y, backend)
contexts_tup_false = map(_ -> false, contexts)
args_to_zero = (
false, # target_function
false, # call_and_return
false, # f!
false, # y
true, # x
contexts_tup_false...,
)
prep = MooncakeTwoArgPullbackPrep(
_sig, cache, dy_righttype_after, target_function, args_to_zero
_sig, cache, dy_backup, args_to_zero
)
return prep
end
Expand All @@ -57,12 +51,12 @@ function DI.value_and_pullback(
DI.check_prep(f!, y, prep, backend, x, ty, contexts...)
dy = only(ty)
# Prepare cotangent to add after the forward pass.
dy_righttype_after = copyto!(prep.dy_righttype, dy)
dy_backup = copyto!(prep.dy_backup, dy)
# Run the reverse-pass and return the results.
y_after, (_, _, _, dx) = value_and_pullback!!(
prep.cache,
dy_righttype_after,
prep.target_function,
dy_backup,
call_and_return,
f!,
y,
x,
Expand All @@ -84,11 +78,11 @@ function DI.value_and_pullback(
) where {F, C}
DI.check_prep(f!, y, prep, backend, x, ty, contexts...)
tx = map(ty) do dy
dy_righttype_after = copyto!(prep.dy_righttype, dy)
dy_backup = copyto!(prep.dy_backup, dy)
y_after, (_, _, _, dx) = value_and_pullback!!(
prep.cache,
dy_righttype_after,
prep.target_function,
dy_backup,
call_and_return,
f!,
y,
x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,18 @@ get_config(::AnyAutoMooncake{Nothing}) = Config()
get_config(backend::AnyAutoMooncake{<:Config}) = backend.config

@inline zero_tangent_unwrap(c::DI.Context) = zero_tangent(DI.unwrap(c))
@inline Dual_unwrap(c, dc) = Dual(DI.unwrap(c), dc)
@inline first_unwrap(c, dc) = (DI.unwrap(c), dc)

function call_and_return(f!::F, y, x, contexts...) where {F}
f!(y, x, contexts...)
return y
end

function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
if get_config(backend).friendly_tangents
# zero(x) but safer
return tangent_to_primal!!(_copy_output(x), zero_tangent(x))
else
return zero_tangent(x)
end
end
1 change: 1 addition & 0 deletions DifferentiationInterface/test/Back/Mooncake/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Loading
Loading