Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion DifferentiationInterface/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.15...main)
## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.16...main)

## [0.7.16](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.15...DifferentiationInterface-v0.7.16)

### Fixed

- Upgrade Mooncake compat to v0.5 ([#961](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/961))

## [0.7.15](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.14...DifferentiationInterface-v0.7.15)

Expand Down
4 changes: 2 additions & 2 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.7.15"
version = "0.7.16"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -69,7 +69,7 @@ ForwardDiff = "0.10.36,1"
GPUArraysCore = "0.2"
GTPSA = "1.4.0"
LinearAlgebra = "1"
Mooncake = "0.4.175"
Mooncake = "0.5.1"
PolyesterForwardDiff = "0.1.2"
ReverseDiff = "1.15.1"
SparseArrays = "1"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ using Mooncake:
NoRData,
primal,
_copy_output,
_copy_to_output!!
_copy_to_output!!,
tangent_to_primal!!

const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
## Pushforward

struct MooncakeOneArgPushforwardPrep{SIG, Tcache, DX, FT, CT} <: DI.PushforwardPrep{SIG}
struct MooncakeOneArgPushforwardPrep{SIG, Tcache, FT, CT} <: DI.PushforwardPrep{SIG}
_sig::Val{SIG}
cache::Tcache
dx_righttype::DX
df::FT
context_tangents::CT
end
Expand All @@ -18,13 +17,10 @@ function DI.prepare_pushforward_nokwarg(
) where {F, C}
_sig = DI.signature(f, backend, x, tx, contexts...; strict)
config = get_config(backend)
cache = prepare_derivative_cache(
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
)
dx_righttype = zero_tangent(x)
df = zero_tangent(f)
cache = prepare_derivative_cache(f, x, map(DI.unwrap, contexts)...; config)
df = zero_tangent_or_primal(f, backend)
context_tangents = map(zero_tangent_unwrap, contexts)
prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype, df, context_tangents)
prep = MooncakeOneArgPushforwardPrep(_sig, cache, df, context_tangents)
return prep
end

Expand All @@ -38,19 +34,17 @@ function DI.value_and_pushforward(
) where {F, C, X}
DI.check_prep(f, prep, backend, x, tx, contexts...)
ys_and_ty = map(tx) do dx
dx_righttype =
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
y_dual = value_and_derivative!!(
y_and_dy = value_and_derivative!!(
prep.cache,
Dual(f, prep.df),
Dual(x, dx_righttype),
map(Dual_unwrap, contexts, prep.context_tangents)...,
(f, prep.df),
(x, dx),
map(first_unwrap, contexts, prep.context_tangents)...,
)
y = primal(y_dual)
dy = _copy_output(tangent(y_dual))
y = first(y_and_dy)
dy = _copy_output(last(y_and_dy))
return y, dy
end
y = first(ys_and_ty[1])
y = _copy_output(first(ys_and_ty[1]))
ty = map(last, ys_and_ty)
return y, ty
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
## Pushforward

struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, DX, DY, FT, CT} <: DI.PushforwardPrep{SIG}
struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, FT, CT} <: DI.PushforwardPrep{SIG}
_sig::Val{SIG}
cache::Tcache
dx_righttype::DX
dy_righttype::DY
df!::FT
context_tangents::CT
end
Expand All @@ -21,18 +19,16 @@ function DI.prepare_pushforward_nokwarg(
_sig = DI.signature(f!, y, backend, x, tx, contexts...; strict)
config = get_config(backend)
cache = prepare_derivative_cache(
call_and_return,
f!,
y,
x,
map(DI.unwrap, contexts)...;
config.debug_mode,
config.silence_debug_messages,
config
)
dx_righttype = zero_tangent(x)
dy_righttype = zero_tangent(y)
df! = zero_tangent(f!)
df! = zero_tangent_or_primal(f!, backend)
context_tangents = map(zero_tangent_unwrap, contexts)
prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dx_righttype, dy_righttype, df!, context_tangents)
prep = MooncakeTwoArgPushforwardPrep(_sig, cache, df!, context_tangents)
return prep
end

Expand All @@ -47,18 +43,17 @@ function DI.value_and_pushforward(
) where {F, C, X}
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
ty = map(tx) do dx
dx_righttype =
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
y_dual = zero_dual(y)
value_and_derivative!!(
dy = zero_tangent_or_primal(y, backend) # TODO: remove allocation?
dcall = zero_tangent_or_primal(call_and_return, backend)
_, new_dy = value_and_derivative!!(
prep.cache,
Dual(f!, prep.df!),
y_dual,
Dual(x, dx_righttype),
map(Dual_unwrap, contexts, prep.context_tangents)...,
(call_and_return, dcall),
(f!, prep.df!),
(y, dy),
(x, dx),
map(first_unwrap, contexts, prep.context_tangents)...,
)
dy = _copy_output(tangent(y_dual))
return dy
return _copy_output(new_dy)
end
return y, ty
end
Expand Down Expand Up @@ -88,18 +83,16 @@ function DI.value_and_pushforward!(
) where {F, C, X, Y}
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
foreach(tx, ty) do dx, dy
dx_righttype =
dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx)
dy_righttype =
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
value_and_derivative!!(
dcall = zero_tangent_or_primal(call_and_return, backend)
_, new_dy = value_and_derivative!!(
prep.cache,
Dual(f!, prep.df!),
Dual(y, dy_righttype),
Dual(x, dx_righttype),
map(Dual_unwrap, contexts, prep.context_tangents)...,
(call_and_return, dcall),
(f!, prep.df!),
(y, dy),
(x, dx),
map(first_unwrap, contexts, prep.context_tangents)...,
)
dy === dy_righttype || copyto!(dy, dy_righttype)
copyto!(dy, new_dy)
end
return y, ty
end
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
## Pullback

struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG}
struct MooncakeOneArgPullbackPrep{SIG, Tcache, N} <: DI.PullbackPrep{SIG}
_sig::Val{SIG}
cache::Tcache
dy_righttype::DY
args_to_zero::NTuple{N, Bool}
end

Expand All @@ -12,18 +11,14 @@ function DI.prepare_pullback_nokwarg(
) where {F, C}
_sig = DI.signature(f, backend, x, ty, contexts...; strict)
config = get_config(backend)
cache = prepare_pullback_cache(
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
)
y = f(x, map(DI.unwrap, contexts)...)
dy_righttype = zero_tangent(y)
cache = prepare_pullback_cache(f, x, map(DI.unwrap, contexts)...; config)
contexts_tup_false = map(_ -> false, contexts)
args_to_zero = (
false, # f
true, # x
contexts_tup_false...,
)
prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype, args_to_zero)
prep = MooncakeOneArgPullbackPrep(_sig, cache, args_to_zero)
return prep
end

Expand All @@ -37,10 +32,8 @@ function DI.value_and_pullback(
) where {F, Y, C}
DI.check_prep(f, prep, backend, x, ty, contexts...)
dy = only(ty)
dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
new_y, (_, new_dx) = value_and_pullback!!(
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...;
prep.args_to_zero
prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
)
return new_y, (_copy_output(new_dx),)
end
Expand All @@ -55,11 +48,8 @@ function DI.value_and_pullback(
) where {F, Y, C}
DI.check_prep(f, prep, backend, x, ty, contexts...)
ys_and_tx = map(ty) do dy
dy_righttype =
dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy)
y, (_, new_dx) = value_and_pullback!!(
prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...;
prep.args_to_zero
prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero
)
y, _copy_output(new_dx)
end
Expand Down Expand Up @@ -121,9 +111,7 @@ function DI.prepare_gradient_nokwarg(
) where {F, C}
_sig = DI.signature(f, backend, x, contexts...; strict)
config = get_config(backend)
cache = prepare_gradient_cache(
f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages
)
cache = prepare_gradient_cache(f, x, map(DI.unwrap, contexts)...; config)
contexts_tup_false = map(_ -> false, contexts)
args_to_zero = (
false, # f
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F, N} <: DI.PullbackPrep{SIG}
struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG}
_sig::Val{SIG}
cache::Tcache
dy_righttype::DY
target_function::F
dy_backup::DY
args_to_zero::NTuple{N, Bool}
end

Expand All @@ -16,31 +15,26 @@ function DI.prepare_pullback_nokwarg(
contexts::Vararg{DI.Context, C}
) where {F, C}
_sig = DI.signature(f!, y, backend, x, ty, contexts...; strict)
target_function = function (f!, y, x, contexts...)
f!(y, x, contexts...)
return y
end
config = get_config(backend)
cache = prepare_pullback_cache(
target_function,
call_and_return,
f!,
y,
x,
map(DI.unwrap, contexts)...;
debug_mode = config.debug_mode,
silence_debug_messages = config.silence_debug_messages,
config,
)
dy_righttype_after = zero_tangent(y)
dy_backup = zero_tangent_or_primal(y, backend)
contexts_tup_false = map(_ -> false, contexts)
args_to_zero = (
false, # target_function
false, # call_and_return
false, # f!
false, # y
true, # x
contexts_tup_false...,
)
prep = MooncakeTwoArgPullbackPrep(
_sig, cache, dy_righttype_after, target_function, args_to_zero
_sig, cache, dy_backup, args_to_zero
)
return prep
end
Expand All @@ -57,12 +51,12 @@ function DI.value_and_pullback(
DI.check_prep(f!, y, prep, backend, x, ty, contexts...)
dy = only(ty)
# Prepare cotangent to add after the forward pass.
dy_righttype_after = copyto!(prep.dy_righttype, dy)
dy_backup = copyto!(prep.dy_backup, dy)
# Run the reverse-pass and return the results.
y_after, (_, _, _, dx) = value_and_pullback!!(
prep.cache,
dy_righttype_after,
prep.target_function,
dy_backup,
call_and_return,
f!,
y,
x,
Expand All @@ -84,11 +78,11 @@ function DI.value_and_pullback(
) where {F, C}
DI.check_prep(f!, y, prep, backend, x, ty, contexts...)
tx = map(ty) do dy
dy_righttype_after = copyto!(prep.dy_righttype, dy)
dy_backup = copyto!(prep.dy_backup, dy)
y_after, (_, _, _, dx) = value_and_pullback!!(
prep.cache,
dy_righttype_after,
prep.target_function,
dy_backup,
call_and_return,
f!,
y,
x,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,18 @@ get_config(::AnyAutoMooncake{Nothing}) = Config()
get_config(backend::AnyAutoMooncake{<:Config}) = backend.config

@inline zero_tangent_unwrap(c::DI.Context) = zero_tangent(DI.unwrap(c))
@inline Dual_unwrap(c, dc) = Dual(DI.unwrap(c), dc)
@inline first_unwrap(c, dc) = (DI.unwrap(c), dc)

function call_and_return(f!::F, y, x, contexts...) where {F}
f!(y, x, contexts...)
return y
end

function zero_tangent_or_primal(x, backend::AnyAutoMooncake)
if get_config(backend).friendly_tangents
# zero(x) but safer
return tangent_to_primal!!(_copy_output(x), zero_tangent(x))
else
return zero_tangent(x)
end
end
1 change: 1 addition & 0 deletions DifferentiationInterface/test/Back/Mooncake/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Loading
Loading