From 04b4097f0ea33b1281a23172c233b87601ca61e9 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 5 Jun 2025 20:28:33 +0200 Subject: [PATCH 1/3] feat: use Mooncake's copy utilities --- DifferentiationInterface/Project.toml | 29 +++++++++++++++---- .../DifferentiationInterfaceMooncakeExt.jl | 7 ++--- .../onearg.jl | 10 +++---- .../twoarg.jl | 4 +-- 4 files changed, 34 insertions(+), 16 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 486b91a5b..30043c9ac 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.7.0" +version = "0.7.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -41,7 +41,9 @@ DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore" DifferentiationInterfaceGTPSAExt = "GTPSA" DifferentiationInterfaceMooncakeExt = "Mooncake" -DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"] +DifferentiationInterfacePolyesterForwardDiffExt = [ + "PolyesterForwardDiff", "ForwardDiff", "DiffResults" +] DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] DifferentiationInterfaceSparseArraysExt = "SparseArrays" DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer" @@ -53,7 +55,7 @@ DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"] [compat] Aqua = "0.8.12" -ADTypes = "1.13.0" +ADTypes = "1.15.0" ChainRulesCore = "1.23.0" ComponentArrays = "0.15.27" DataFrames = "1.7.0" @@ -72,7 +74,7 @@ JET = "0.9" JLArrays = "0.2.0" JuliaFormatter = "1,2" LinearAlgebra = "1" -Mooncake = "0.4.88" +Mooncake = "0.4.121" Pkg = "1" PolyesterForwardDiff = "0.1.2" Random = "1" @@ -121,4 +123,21 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "ExplicitImports", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test"] +test = [ + "ADTypes", + "Aqua", + "ComponentArrays", + "DataFrames", + "ExplicitImports", + "JET", + "JLArrays", + "JuliaFormatter", + "Pkg", + "Random", + "SparseArrays", + "SparseConnectivityTracer", + "SparseMatrixColorings", + "StableRNGs", + "StaticArrays", + "Test", +] diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 52e742b05..fa5c6d4ec 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -10,13 +10,12 @@ using Mooncake: tangent_type, value_and_gradient!!, value_and_pullback!!, - zero_tangent + zero_tangent, + _copy_output, + _copy_to_output! DI.check_available(::AutoMooncake) = true -copyto!!(dst::Number, src::Number) = convert(typeof(dst), src) -copyto!!(dst, src) = DI.ismutable_array(dst) ? copyto!(dst, src) : convert(typeof(dst), src) - get_config(::AutoMooncake{Nothing}) = Config() get_config(backend::AutoMooncake{<:Config}) = backend.config diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index d637a01bd..7a56f577e 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -30,11 +30,11 @@ function DI.value_and_pullback( ) where {F,Y,C} DI.check_prep(f, prep, backend, x, ty, contexts...) dy = only(ty) - dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy) + dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!(prep.dy_righttype, dy) new_y, (_, new_dx) = value_and_pullback!!( prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)... ) - return new_y, (mycopy(new_dx),) + return new_y, (_copy_output(new_dx),) end function DI.value_and_pullback( @@ -47,11 +47,11 @@ function DI.value_and_pullback( ) where {F,Y,C} DI.check_prep(f, prep, backend, x, ty, contexts...) ys_and_tx = map(ty) do dy - dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy) + dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!(prep.dy_righttype, dy) y, (_, new_dx) = value_and_pullback!!( prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)... ) - y, mycopy(new_dx) + y, _copy_output(new_dx) end y = first(ys_and_tx[1]) tx = last.(ys_and_tx) @@ -126,7 +126,7 @@ function DI.value_and_gradient( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...) - return y, mycopy(new_grad) + return y, _copy_output(new_grad) end function DI.value_and_gradient!( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 24cb39e41..d0bbf282d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -58,7 +58,7 @@ function DI.value_and_pullback( map(DI.unwrap, contexts)..., ) copyto!(y, y_after) - return y, (mycopy(dx),) + return y, (_copy_output(dx),) end function DI.value_and_pullback( @@ -83,7 +83,7 @@ function DI.value_and_pullback( map(DI.unwrap, contexts)..., ) copyto!(y, y_after) - mycopy(dx) + _copy_output(dx) end return y, tx end From 9b03f2cf18f4db2ca6b13ae76b9f06e7b975fe8f Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 10 Jun 2025 18:43:02 +0200 Subject: [PATCH 2/3] Fix compat --- DifferentiationInterface/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 30043c9ac..5ffc3e4b6 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -55,7 +55,7 @@ DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"] [compat] Aqua = "0.8.12" -ADTypes = "1.15.0" +ADTypes = "1.13.0" ChainRulesCore = "1.23.0" ComponentArrays = "0.15.27" DataFrames = "1.7.0" From c05dfb63f0c33acf434988aa3d76a3722f81b2d0 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 10 Jun 2025 20:03:47 +0200 Subject: [PATCH 3/3] Adapt to latest Mooncake --- DifferentiationInterface/CHANGELOG.md | 5 +++++ DifferentiationInterface/Project.toml | 2 +- .../DifferentiationInterfaceMooncakeExt.jl | 2 +- .../ext/DifferentiationInterfaceMooncakeExt/onearg.jl | 5 +++-- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/DifferentiationInterface/CHANGELOG.md b/DifferentiationInterface/CHANGELOG.md index 7dee7f9ba..39a8aa928 100644 --- a/DifferentiationInterface/CHANGELOG.md +++ b/DifferentiationInterface/CHANGELOG.md @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.7.1] +### Feat + +- Use Mooncake's internal copy utilities ([#809]) + ### Fixed - Make basis work for `CuArray` ([#810]) @@ -40,6 +44,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53 [#810]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/810 +[#809]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/809 [#799]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/799 [#795]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/795 [#790]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/790 diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 5ffc3e4b6..11be6b77d 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -74,7 +74,7 @@ JET = "0.9" JLArrays = "0.2.0" JuliaFormatter = "1,2" LinearAlgebra = "1" -Mooncake = "0.4.121" +Mooncake = "0.4.122" Pkg = "1" PolyesterForwardDiff = "0.1.2" Random = "1" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index fa5c6d4ec..6253ea229 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -12,7 +12,7 @@ using Mooncake: value_and_pullback!!, zero_tangent, _copy_output, - _copy_to_output! + _copy_to_output!! DI.check_available(::AutoMooncake) = true diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index 7a56f577e..dac2039ce 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -30,7 +30,7 @@ function DI.value_and_pullback( ) where {F,Y,C} DI.check_prep(f, prep, backend, x, ty, contexts...) dy = only(ty) - dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!(prep.dy_righttype, dy) + dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) new_y, (_, new_dx) = value_and_pullback!!( prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)... ) @@ -47,7 +47,8 @@ function DI.value_and_pullback( ) where {F,Y,C} DI.check_prep(f, prep, backend, x, ty, contexts...) ys_and_tx = map(ty) do dy - dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!(prep.dy_righttype, dy) + dy_righttype = + dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) y, (_, new_dx) = value_and_pullback!!( prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)... )