diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index d50bdcf11..b2fa8cf7f 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -69,7 +69,7 @@ ForwardDiff = "0.10.36,1" GPUArraysCore = "0.2" GTPSA = "1.4.0" LinearAlgebra = "1" -Mooncake = "0.4.175" +Mooncake = "0.5.0" PolyesterForwardDiff = "0.1.2" ReverseDiff = "1.15.1" SparseArrays = "1" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 561c06cc8..b27dc0472 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -29,7 +29,9 @@ using Mooncake: NoRData, primal, _copy_output, - _copy_to_output!! + _copy_to_output!!, + primal_to_tangent!!, + tangent_to_primal!! const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}} diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl index 61946a0d3..298fbc546 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -19,10 +19,14 @@ function DI.prepare_pushforward_nokwarg( _sig = DI.signature(f, backend, x, tx, contexts...; strict) config = get_config(backend) cache = prepare_derivative_cache( - f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages + f, x, map(DI.unwrap, contexts)...; config ) - dx_righttype = zero_tangent(x) df = zero_tangent(f) + if config.friendly_tangents + dx_righttype = zero_tangent(x) + else + dx_righttype = nothing + end context_tangents = map(zero_tangent_unwrap, contexts) prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype, df, context_tangents) return prep @@ -38,8 +42,7 @@ function DI.value_and_pushforward( ) where {F, C, X} DI.check_prep(f, prep, backend, x, tx, contexts...) ys_and_ty = map(tx) do dx - dx_righttype = - dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) + dx_righttype = isnothing(prep.dx_righttype) ? dx : primal_to_tangent!!(prep.dx_righttype, dx) y_dual = value_and_derivative!!( prep.cache, Dual(f, prep.df), @@ -47,7 +50,11 @@ function DI.value_and_pushforward( map(Dual_unwrap, contexts, prep.context_tangents)..., ) y = primal(y_dual) - dy = _copy_output(tangent(y_dual)) + if isnothing(prep.dx_righttype) + dy = _copy_output(tangent(y_dual)) + else + dy = tangent_to_primal!!(_copy_output(y), tangent(y_dual)) + end return y, dy end y = first(ys_and_ty[1]) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index dc8f8c1f0..0930c90dd 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -25,11 +25,15 @@ function DI.prepare_pushforward_nokwarg( y, x, map(DI.unwrap, contexts)...; - config.debug_mode, - config.silence_debug_messages, + config, ) - dx_righttype = zero_tangent(x) - dy_righttype = zero_tangent(y) + if config.friendly_tangents + dx_righttype = zero_tangent(x) + dy_righttype = zero_tangent(y) + else + dx_righttype = nothing + dy_righttype = nothing + end df! = zero_tangent(f!) context_tangents = map(zero_tangent_unwrap, contexts) prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dx_righttype, dy_righttype, df!, context_tangents) @@ -48,7 +52,7 @@ function DI.value_and_pushforward( DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = map(tx) do dx dx_righttype = - dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) + isnothing(prep.dx_righttype) ? dx : primal_to_tangent!!(prep.dx_righttype, dx) y_dual = zero_dual(y) value_and_derivative!!( prep.cache, @@ -57,7 +61,11 @@ function DI.value_and_pushforward( Dual(x, dx_righttype), map(Dual_unwrap, contexts, prep.context_tangents)..., ) - dy = _copy_output(tangent(y_dual)) + if isnothing(prep.dx_righttype) + dy = _copy_output(tangent(y_dual)) + else + dy = tangent_to_primal!!(_copy_output(y), tangent(y_dual)) + end return dy end return y, ty @@ -89,9 +97,9 @@ function DI.value_and_pushforward!( DI.check_prep(f!, y, prep, backend, x, tx, contexts...) foreach(tx, ty) do dx, dy dx_righttype = - dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) + isnothing(prep.dx_righttype) ? dx : primal_to_tangent!!(prep.dx_righttype, dx) dy_righttype = - dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) + isnothing(prep.dy_righttype) ? dy : primal_to_tangent!!(prep.dy_righttype, dy) value_and_derivative!!( prep.cache, Dual(f!, prep.df!), @@ -99,7 +107,7 @@ function DI.value_and_pushforward!( Dual(x, dx_righttype), map(Dual_unwrap, contexts, prep.context_tangents)..., ) - dy === dy_righttype || copyto!(dy, dy_righttype) + isnothing(prep.dy_righttype) || tangent_to_primal!!(dy, dy_righttype) end return y, ty end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index ab9818735..49ac0b656 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -1,9 +1,8 @@ ## Pullback -struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG} +struct MooncakeOneArgPullbackPrep{SIG, Tcache, N} <: DI.PullbackPrep{SIG} _sig::Val{SIG} cache::Tcache - dy_righttype::DY args_to_zero::NTuple{N, Bool} end @@ -12,18 +11,14 @@ function DI.prepare_pullback_nokwarg( ) where {F, C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) config = get_config(backend) - cache = prepare_pullback_cache( - f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages - ) - y = f(x, map(DI.unwrap, contexts)...) - dy_righttype = zero_tangent(y) + cache = prepare_pullback_cache(f, x, map(DI.unwrap, contexts)...; config) contexts_tup_false = map(_ -> false, contexts) args_to_zero = ( false, # f true, # x contexts_tup_false..., ) - prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype, args_to_zero) + prep = MooncakeOneArgPullbackPrep(_sig, cache, args_to_zero) return prep end @@ -37,9 +32,8 @@ function DI.value_and_pullback( ) where {F, Y, C} DI.check_prep(f, prep, backend, x, ty, contexts...) dy = only(ty) - dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) new_y, (_, new_dx) = value_and_pullback!!( - prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...; + prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) return new_y, (_copy_output(new_dx),) @@ -55,10 +49,8 @@ function DI.value_and_pullback( ) where {F, Y, C} DI.check_prep(f, prep, backend, x, ty, contexts...) ys_and_tx = map(ty) do dy - dy_righttype = - dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) y, (_, new_dx) = value_and_pullback!!( - prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...; + prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) y, _copy_output(new_dx) @@ -121,9 +113,7 @@ function DI.prepare_gradient_nokwarg( ) where {F, C} _sig = DI.signature(f, backend, x, contexts...; strict) config = get_config(backend) - cache = prepare_gradient_cache( - f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages - ) + cache = prepare_gradient_cache(f, x, map(DI.unwrap, contexts)...; config) contexts_tup_false = map(_ -> false, contexts) args_to_zero = ( false, # f diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 2ee11b5ae..ff41be77a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -1,7 +1,6 @@ -struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F, N} <: DI.PullbackPrep{SIG} +struct MooncakeTwoArgPullbackPrep{SIG, Tcache, F, N} <: DI.PullbackPrep{SIG} _sig::Val{SIG} cache::Tcache - dy_righttype::DY target_function::F args_to_zero::NTuple{N, Bool} end @@ -27,10 +26,8 @@ function DI.prepare_pullback_nokwarg( y, x, map(DI.unwrap, contexts)...; - debug_mode = config.debug_mode, - silence_debug_messages = config.silence_debug_messages, + config, ) - dy_righttype_after = zero_tangent(y) contexts_tup_false = map(_ -> false, contexts) args_to_zero = ( false, # target_function @@ -39,9 +36,7 @@ function DI.prepare_pullback_nokwarg( true, # x contexts_tup_false..., ) - prep = MooncakeTwoArgPullbackPrep( - _sig, cache, dy_righttype_after, target_function, args_to_zero - ) + prep = MooncakeTwoArgPullbackPrep(_sig, cache, target_function, args_to_zero) return prep end @@ -56,12 +51,10 @@ function DI.value_and_pullback( ) where {F, C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) dy = only(ty) - # Prepare cotangent to add after the forward pass. - dy_righttype_after = copyto!(prep.dy_righttype, dy) # Run the reverse-pass and return the results. y_after, (_, _, _, dx) = value_and_pullback!!( prep.cache, - dy_righttype_after, + dy, prep.target_function, f!, y,