diff --git a/DifferentiationInterface/CHANGELOG.md b/DifferentiationInterface/CHANGELOG.md index 22d83c275..af8d11e0c 100644 --- a/DifferentiationInterface/CHANGELOG.md +++ b/DifferentiationInterface/CHANGELOG.md @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [0.7.1] +### Feat + +- Use Mooncake's internal copy utilities ([#809]) + ### Fixed - Take `absstep` into account for FiniteDiff ([#812]) @@ -42,6 +46,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 [#812]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/812 [#810]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/810 +[#809]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/809 [#799]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/799 [#795]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/795 [#790]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/790 diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 5896be5b2..11be6b77d 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -74,7 +74,7 @@ JET = "0.9" JLArrays = "0.2.0" JuliaFormatter = "1,2" LinearAlgebra = "1" -Mooncake = "0.4.88" +Mooncake = "0.4.122" Pkg = "1" PolyesterForwardDiff = "0.1.2" Random = "1" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 52e742b05..6253ea229 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -10,13 +10,12 @@ using Mooncake: tangent_type, value_and_gradient!!, value_and_pullback!!, - zero_tangent + zero_tangent, + _copy_output, + _copy_to_output!! DI.check_available(::AutoMooncake) = true -copyto!!(dst::Number, src::Number) = convert(typeof(dst), src) -copyto!!(dst, src) = DI.ismutable_array(dst) ? copyto!(dst, src) : convert(typeof(dst), src) - get_config(::AutoMooncake{Nothing}) = Config() get_config(backend::AutoMooncake{<:Config}) = backend.config diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index d637a01bd..dac2039ce 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -30,11 +30,11 @@ function DI.value_and_pullback( ) where {F,Y,C} DI.check_prep(f, prep, backend, x, ty, contexts...) dy = only(ty) - dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy) + dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) new_y, (_, new_dx) = value_and_pullback!!( prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)... ) - return new_y, (mycopy(new_dx),) + return new_y, (_copy_output(new_dx),) end function DI.value_and_pullback( @@ -47,11 +47,12 @@ function DI.value_and_pullback( ) where {F,Y,C} DI.check_prep(f, prep, backend, x, ty, contexts...) ys_and_tx = map(ty) do dy - dy_righttype = dy isa tangent_type(Y) ? dy : copyto!!(prep.dy_righttype, dy) + dy_righttype = + dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) y, (_, new_dx) = value_and_pullback!!( prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)... ) - y, mycopy(new_dx) + y, _copy_output(new_dx) end y = first(ys_and_tx[1]) tx = last.(ys_and_tx) @@ -126,7 +127,7 @@ function DI.value_and_gradient( ) where {F,C} DI.check_prep(f, prep, backend, x, contexts...) y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...) - return y, mycopy(new_grad) + return y, _copy_output(new_grad) end function DI.value_and_gradient!( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 24cb39e41..d0bbf282d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -58,7 +58,7 @@ function DI.value_and_pullback( map(DI.unwrap, contexts)..., ) copyto!(y, y_after) - return y, (mycopy(dx),) + return y, (_copy_output(dx),) end function DI.value_and_pullback( @@ -83,7 +83,7 @@ function DI.value_and_pullback( map(DI.unwrap, contexts)..., ) copyto!(y, y_after) - mycopy(dx) + _copy_output(dx) end return y, tx end