diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 799808b1b..e014b3a67 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -41,7 +41,11 @@ DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore" DifferentiationInterfaceGTPSAExt = "GTPSA" DifferentiationInterfaceMooncakeExt = "Mooncake" -DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"] +DifferentiationInterfacePolyesterForwardDiffExt = [ + "PolyesterForwardDiff", + "ForwardDiff", + "DiffResults", +] DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] DifferentiationInterfaceSparseArraysExt = "SparseArrays" DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer" @@ -65,7 +69,7 @@ ForwardDiff = "0.10.36,1" GPUArraysCore = "0.2" GTPSA = "1.4.0" LinearAlgebra = "1" -Mooncake = "0.4.147" +Mooncake = "0.4.175" PolyesterForwardDiff = "0.1.2" ReverseDiff = "1.15.1" SparseArrays = "1" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index 209367d5a..ab9818735 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -1,9 +1,10 @@ ## Pullback -struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY} <: DI.PullbackPrep{SIG} +struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG} _sig::Val{SIG} cache::Tcache dy_righttype::DY + args_to_zero::NTuple{N, Bool} end function DI.prepare_pullback_nokwarg( @@ -16,7 +17,13 @@ function DI.prepare_pullback_nokwarg( ) y = f(x, map(DI.unwrap, contexts)...) dy_righttype = zero_tangent(y) - prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype) + contexts_tup_false = map(_ -> false, contexts) + args_to_zero = ( + false, # f + true, # x + contexts_tup_false..., + ) + prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype, args_to_zero) return prep end @@ -32,7 +39,8 @@ function DI.value_and_pullback( dy = only(ty) dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) new_y, (_, new_dx) = value_and_pullback!!( - prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)... + prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...; + prep.args_to_zero ) return new_y, (_copy_output(new_dx),) end @@ -50,7 +58,8 @@ function DI.value_and_pullback( dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) y, (_, new_dx) = value_and_pullback!!( - prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)... + prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...; + prep.args_to_zero ) y, _copy_output(new_dx) end @@ -101,9 +110,10 @@ end ## Gradient -struct MooncakeGradientPrep{SIG, Tcache} <: DI.GradientPrep{SIG} +struct MooncakeGradientPrep{SIG, Tcache, N} <: DI.GradientPrep{SIG} _sig::Val{SIG} cache::Tcache + args_to_zero::NTuple{N, Bool} end function DI.prepare_gradient_nokwarg( @@ -114,7 +124,13 @@ function DI.prepare_gradient_nokwarg( cache = prepare_gradient_cache( f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages ) - prep = MooncakeGradientPrep(_sig, cache) + contexts_tup_false = map(_ -> false, contexts) + args_to_zero = ( + false, # f + true, # x + contexts_tup_false..., + ) + prep = MooncakeGradientPrep(_sig, cache, args_to_zero) return prep end @@ -126,7 +142,10 @@ function DI.value_and_gradient( contexts::Vararg{DI.Context, C}, ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) - y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...) + y, (_, new_grad) = value_and_gradient!!( + prep.cache, f, x, map(DI.unwrap, contexts)...; + prep.args_to_zero + ) return y, _copy_output(new_grad) end @@ -139,7 +158,10 @@ function DI.value_and_gradient!( contexts::Vararg{DI.Context, C}, ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) - y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...) + y, (_, new_grad) = value_and_gradient!!( + prep.cache, f, x, map(DI.unwrap, contexts)...; + prep.args_to_zero + ) copyto!(grad, new_grad) return y, grad end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index da3e5b217..2ee11b5ae 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -1,8 +1,9 @@ -struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F} <: DI.PullbackPrep{SIG} +struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F, N} <: DI.PullbackPrep{SIG} _sig::Val{SIG} cache::Tcache dy_righttype::DY target_function::F + args_to_zero::NTuple{N, Bool} end function DI.prepare_pullback_nokwarg( @@ -30,7 +31,17 @@ function DI.prepare_pullback_nokwarg( silence_debug_messages = config.silence_debug_messages, ) dy_righttype_after = zero_tangent(y) - prep = MooncakeTwoArgPullbackPrep(_sig, cache, dy_righttype_after, target_function) + contexts_tup_false = map(_ -> false, contexts) + args_to_zero = ( + false, # target_function + false, # f! + false, # y + true, # x + contexts_tup_false..., + ) + prep = MooncakeTwoArgPullbackPrep( + _sig, cache, dy_righttype_after, target_function, args_to_zero + ) return prep end @@ -55,7 +66,8 @@ function DI.value_and_pullback( f!, y, x, - map(DI.unwrap, contexts)..., + map(DI.unwrap, contexts)...; + prep.args_to_zero ) copyto!(y, y_after) return y, (_copy_output(dx),) @@ -80,7 +92,8 @@ function DI.value_and_pullback( f!, y, x, - map(DI.unwrap, contexts)..., + map(DI.unwrap, contexts)...; + prep.args_to_zero ) copyto!(y, y_after) _copy_output(dx)