diff --git a/DifferentiationInterface/CHANGELOG.md b/DifferentiationInterface/CHANGELOG.md index f01ae3b2a..d60b733ec 100644 --- a/DifferentiationInterface/CHANGELOG.md +++ b/DifferentiationInterface/CHANGELOG.md @@ -5,7 +5,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.15...main) +## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.16...main) + +## [0.7.16](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.15...DifferentiationInterface-v0.7.16) + +### Fixed + +- Upgrade Mooncake compat to v0.5 ([#961](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/961)) ## [0.7.15](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.14...DifferentiationInterface-v0.7.15) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index d50bdcf11..af16a6ae2 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.7.15" +version = "0.7.16" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -69,7 +69,7 @@ ForwardDiff = "0.10.36,1" GPUArraysCore = "0.2" GTPSA = "1.4.0" LinearAlgebra = "1" -Mooncake = "0.4.175" +Mooncake = "0.5.1" PolyesterForwardDiff = "0.1.2" ReverseDiff = "1.15.1" SparseArrays = "1" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 561c06cc8..3513d548c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -29,7 +29,8 @@ using Mooncake: NoRData, primal, _copy_output, - _copy_to_output!! + _copy_to_output!!, + tangent_to_primal!! const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}} diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl index 61946a0d3..c470b6473 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -1,9 +1,8 @@ ## Pushforward -struct MooncakeOneArgPushforwardPrep{SIG, Tcache, DX, FT, CT} <: DI.PushforwardPrep{SIG} +struct MooncakeOneArgPushforwardPrep{SIG, Tcache, FT, CT} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} cache::Tcache - dx_righttype::DX df::FT context_tangents::CT end @@ -18,13 +17,10 @@ function DI.prepare_pushforward_nokwarg( ) where {F, C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) config = get_config(backend) - cache = prepare_derivative_cache( - f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages - ) - dx_righttype = zero_tangent(x) - df = zero_tangent(f) + cache = prepare_derivative_cache(f, x, map(DI.unwrap, contexts)...; config) + df = zero_tangent_or_primal(f, backend) context_tangents = map(zero_tangent_unwrap, contexts) - prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype, df, context_tangents) + prep = MooncakeOneArgPushforwardPrep(_sig, cache, df, context_tangents) return prep end @@ -38,19 +34,17 @@ function DI.value_and_pushforward( ) where {F, C, X} DI.check_prep(f, prep, backend, x, tx, contexts...) ys_and_ty = map(tx) do dx - dx_righttype = - dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) - y_dual = value_and_derivative!!( + y_and_dy = value_and_derivative!!( prep.cache, - Dual(f, prep.df), - Dual(x, dx_righttype), - map(Dual_unwrap, contexts, prep.context_tangents)..., + (f, prep.df), + (x, dx), + map(first_unwrap, contexts, prep.context_tangents)..., ) - y = primal(y_dual) - dy = _copy_output(tangent(y_dual)) + y = first(y_and_dy) + dy = _copy_output(last(y_and_dy)) return y, dy end - y = first(ys_and_ty[1]) + y = _copy_output(first(ys_and_ty[1])) ty = map(last, ys_and_ty) return y, ty end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index dc8f8c1f0..3c75f530b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -1,11 +1,11 @@ ## Pushforward -struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, DX, DY, FT, CT} <: DI.PushforwardPrep{SIG} +struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, FT0, FT, YT, CT} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} cache::Tcache - dx_righttype::DX - dy_righttype::DY + dcall::FT0 df!::FT + dy::YT context_tangents::CT end @@ -21,18 +21,18 @@ function DI.prepare_pushforward_nokwarg( _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) config = get_config(backend) cache = prepare_derivative_cache( + call_and_return, f!, y, x, map(DI.unwrap, contexts)...; - config.debug_mode, - config.silence_debug_messages, + config ) - dx_righttype = zero_tangent(x) - dy_righttype = zero_tangent(y) - df! = zero_tangent(f!) + dcall = zero_tangent_or_primal(call_and_return, backend) + df! = zero_tangent_or_primal(f!, backend) + dy = zero_tangent_or_primal(y, backend) context_tangents = map(zero_tangent_unwrap, contexts) - prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dx_righttype, dy_righttype, df!, context_tangents) + prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dcall, df!, dy, context_tangents) return prep end @@ -47,18 +47,15 @@ function DI.value_and_pushforward( ) where {F, C, X} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = map(tx) do dx - dx_righttype = - dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) - y_dual = zero_dual(y) - value_and_derivative!!( + _, new_dy = value_and_derivative!!( prep.cache, - Dual(f!, prep.df!), - y_dual, - Dual(x, dx_righttype), - map(Dual_unwrap, contexts, prep.context_tangents)..., + (call_and_return, prep.dcall), + (f!, prep.df!), + (y, prep.dy), + (x, dx), + map(first_unwrap, contexts, prep.context_tangents)..., ) - dy = _copy_output(tangent(y_dual)) - return dy + return _copy_output(new_dy) end return y, ty end @@ -88,18 +85,15 @@ function DI.value_and_pushforward!( ) where {F, C, X, Y} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) foreach(tx, ty) do dx, dy - dx_righttype = - dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) - dy_righttype = - dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) - value_and_derivative!!( + _, new_dy = value_and_derivative!!( prep.cache, - Dual(f!, prep.df!), - Dual(y, dy_righttype), - Dual(x, dx_righttype), - map(Dual_unwrap, contexts, prep.context_tangents)..., + (call_and_return, prep.dcall), + (f!, prep.df!), + (y, dy), + (x, dx), + map(first_unwrap, contexts, prep.context_tangents)..., ) - dy === dy_righttype || copyto!(dy, dy_righttype) + copyto!(dy, new_dy) end return y, ty end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index ab9818735..2514cdc40 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -1,9 +1,8 @@ ## Pullback -struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG} +struct MooncakeOneArgPullbackPrep{SIG, Tcache, N} <: DI.PullbackPrep{SIG} _sig::Val{SIG} cache::Tcache - dy_righttype::DY args_to_zero::NTuple{N, Bool} end @@ -12,18 +11,14 @@ function DI.prepare_pullback_nokwarg( ) where {F, C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) config = get_config(backend) - cache = prepare_pullback_cache( - f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages - ) - y = f(x, map(DI.unwrap, contexts)...) - dy_righttype = zero_tangent(y) + cache = prepare_pullback_cache(f, x, map(DI.unwrap, contexts)...; config) contexts_tup_false = map(_ -> false, contexts) args_to_zero = ( false, # f true, # x contexts_tup_false..., ) - prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype, args_to_zero) + prep = MooncakeOneArgPullbackPrep(_sig, cache, args_to_zero) return prep end @@ -37,10 +32,8 @@ function DI.value_and_pullback( ) where {F, Y, C} DI.check_prep(f, prep, backend, x, ty, contexts...) dy = only(ty) - dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) new_y, (_, new_dx) = value_and_pullback!!( - prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...; - prep.args_to_zero + prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) return new_y, (_copy_output(new_dx),) end @@ -55,11 +48,8 @@ function DI.value_and_pullback( ) where {F, Y, C} DI.check_prep(f, prep, backend, x, ty, contexts...) ys_and_tx = map(ty) do dy - dy_righttype = - dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) y, (_, new_dx) = value_and_pullback!!( - prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...; - prep.args_to_zero + prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) y, _copy_output(new_dx) end @@ -121,9 +111,7 @@ function DI.prepare_gradient_nokwarg( ) where {F, C} _sig = DI.signature(f, backend, x, contexts...; strict) config = get_config(backend) - cache = prepare_gradient_cache( - f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages - ) + cache = prepare_gradient_cache(f, x, map(DI.unwrap, contexts)...; config) contexts_tup_false = map(_ -> false, contexts) args_to_zero = ( false, # f diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 2ee11b5ae..2b55131b9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -1,8 +1,7 @@ -struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F, N} <: DI.PullbackPrep{SIG} +struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG} _sig::Val{SIG} cache::Tcache - dy_righttype::DY - target_function::F + dy_backup::DY args_to_zero::NTuple{N, Bool} end @@ -16,31 +15,26 @@ function DI.prepare_pullback_nokwarg( contexts::Vararg{DI.Context, C} ) where {F, C} _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) - target_function = function (f!, y, x, contexts...) - f!(y, x, contexts...) - return y - end config = get_config(backend) cache = prepare_pullback_cache( - target_function, + call_and_return, f!, y, x, map(DI.unwrap, contexts)...; - debug_mode = config.debug_mode, - silence_debug_messages = config.silence_debug_messages, + config, ) - dy_righttype_after = zero_tangent(y) + dy_backup = zero_tangent_or_primal(y, backend) contexts_tup_false = map(_ -> false, contexts) args_to_zero = ( - false, # target_function + false, # call_and_return false, # f! false, # y true, # x contexts_tup_false..., ) prep = MooncakeTwoArgPullbackPrep( - _sig, cache, dy_righttype_after, target_function, args_to_zero + _sig, cache, dy_backup, args_to_zero ) return prep end @@ -57,12 +51,12 @@ function DI.value_and_pullback( DI.check_prep(f!, y, prep, backend, x, ty, contexts...) dy = only(ty) # Prepare cotangent to add after the forward pass. - dy_righttype_after = copyto!(prep.dy_righttype, dy) + dy_backup = copyto!(prep.dy_backup, dy) # Run the reverse-pass and return the results. y_after, (_, _, _, dx) = value_and_pullback!!( prep.cache, - dy_righttype_after, - prep.target_function, + dy_backup, + call_and_return, f!, y, x, @@ -84,11 +78,11 @@ function DI.value_and_pullback( ) where {F, C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) tx = map(ty) do dy - dy_righttype_after = copyto!(prep.dy_righttype, dy) + dy_backup = copyto!(prep.dy_backup, dy) y_after, (_, _, _, dx) = value_and_pullback!!( prep.cache, - dy_righttype_after, - prep.target_function, + dy_backup, + call_and_return, f!, y, x, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index 56e4b966a..b22d8d49b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -2,4 +2,18 @@ get_config(::AnyAutoMooncake{Nothing}) = Config() get_config(backend::AnyAutoMooncake{<:Config}) = backend.config @inline zero_tangent_unwrap(c::DI.Context) = zero_tangent(DI.unwrap(c)) -@inline Dual_unwrap(c, dc) = Dual(DI.unwrap(c), dc) +@inline first_unwrap(c, dc) = (DI.unwrap(c), dc) + +function call_and_return(f!::F, y, x, contexts...) where {F} + f!(y, x, contexts...) + return y +end + +function zero_tangent_or_primal(x, backend::AnyAutoMooncake) + if get_config(backend).friendly_tangents + # zero(x) but safer + return tangent_to_primal!!(_copy_output(x), zero_tangent(x)) + else + return zero_tangent(x) + end +end diff --git a/DifferentiationInterface/test/Back/Mooncake/Project.toml b/DifferentiationInterface/test/Back/Mooncake/Project.toml index 39a878d7e..f13f37c5b 100644 --- a/DifferentiationInterface/test/Back/Mooncake/Project.toml +++ b/DifferentiationInterface/test/Back/Mooncake/Project.toml @@ -6,4 +6,5 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index f2d6add4c..8c67aa1bb 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -7,11 +7,11 @@ using Test using ExplicitImports check_no_implicit_imports(DifferentiationInterface) - backends = [ - AutoMooncake(; config = nothing), - AutoMooncake(; config = Mooncake.Config()), - AutoMooncakeForward(; config = nothing), + AutoMooncake(), + AutoMooncakeForward(), + AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)), + AutoMooncakeForward(; config = Mooncake.Config(; friendly_tangents = true)), ] for backend in backends @@ -39,9 +39,9 @@ end # Test second-order differentiation (forward-over-reverse) test_differentiation( - [SecondOrder(AutoMooncakeForward(; config = nothing), AutoMooncake(; config = nothing))], + [SecondOrder(AutoMooncakeForward(), AutoMooncake())], excluded = EXCLUDED, - logging = true, + logging = LOGGING, ) @testset "NamedTuples" begin @@ -51,3 +51,10 @@ test_differentiation( @test grad.A == ps.B @test grad.B == ps.A end + +test_differentiation( + backends[3:4], + static_scenarios(); + logging = LOGGING, + excluded = SECOND_ORDER +)