From 151ad470265c562555e65aae3ceec32beb1d9b54 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 29 Jan 2026 10:02:24 +0100 Subject: [PATCH 1/9] fix: upgrade Mooncake compat to v0.5 --- DifferentiationInterface/CHANGELOG.md | 8 +++++++- DifferentiationInterface/Project.toml | 4 ++-- .../DifferentiationInterfaceMooncakeExt/forward_onearg.jl | 2 +- .../DifferentiationInterfaceMooncakeExt/forward_twoarg.jl | 3 +-- .../ext/DifferentiationInterfaceMooncakeExt/onearg.jl | 4 ++-- .../ext/DifferentiationInterfaceMooncakeExt/twoarg.jl | 3 +-- 6 files changed, 14 insertions(+), 10 deletions(-) diff --git a/DifferentiationInterface/CHANGELOG.md b/DifferentiationInterface/CHANGELOG.md index f01ae3b2a..d60b733ec 100644 --- a/DifferentiationInterface/CHANGELOG.md +++ b/DifferentiationInterface/CHANGELOG.md @@ -5,7 +5,13 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.15...main) +## [Unreleased](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.16...main) + +## [0.7.16](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.15...DifferentiationInterface-v0.7.16) + +### Fixed + +- Upgrade Mooncake compat to v0.5 ([#961](https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/961)) ## [0.7.15](https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.7.14...DifferentiationInterface-v0.7.15) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index d50bdcf11..f54fd1f4f 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -1,7 +1,7 @@ name = "DifferentiationInterface" uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" authors = ["Guillaume Dalle", "Adrian Hill"] -version = "0.7.15" +version = "0.7.16" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -69,7 +69,7 @@ ForwardDiff = "0.10.36,1" GPUArraysCore = "0.2" GTPSA = "1.4.0" LinearAlgebra = "1" -Mooncake = "0.4.175" +Mooncake = "0.5.0" PolyesterForwardDiff = "0.1.2" ReverseDiff = "1.15.1" SparseArrays = "1" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl index 61946a0d3..b64549b89 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -19,7 +19,7 @@ function DI.prepare_pushforward_nokwarg( _sig = DI.signature(f, backend, x, tx, contexts...; strict) config = get_config(backend) cache = prepare_derivative_cache( - f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages + f, x, map(DI.unwrap, contexts)...; config ) dx_righttype = zero_tangent(x) df = zero_tangent(f) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index dc8f8c1f0..dea6138bc 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -25,8 +25,7 @@ function DI.prepare_pushforward_nokwarg( y, x, map(DI.unwrap, contexts)...; - config.debug_mode, - config.silence_debug_messages, + config ) dx_righttype = zero_tangent(x) dy_righttype = zero_tangent(y) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index ab9818735..e52a2a498 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -13,7 +13,7 @@ function DI.prepare_pullback_nokwarg( _sig = DI.signature(f, backend, x, ty, contexts...; strict) config = get_config(backend) cache = prepare_pullback_cache( - f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages + f, x, map(DI.unwrap, contexts)...; config ) y = f(x, map(DI.unwrap, contexts)...) dy_righttype = zero_tangent(y) @@ -122,7 +122,7 @@ function DI.prepare_gradient_nokwarg( _sig = DI.signature(f, backend, x, contexts...; strict) config = get_config(backend) cache = prepare_gradient_cache( - f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages + f, x, map(DI.unwrap, contexts)...; config ) contexts_tup_false = map(_ -> false, contexts) args_to_zero = ( diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 2ee11b5ae..14b5c2503 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -27,8 +27,7 @@ function DI.prepare_pullback_nokwarg( y, x, map(DI.unwrap, contexts)...; - debug_mode = config.debug_mode, - silence_debug_messages = config.silence_debug_messages, + config, ) dy_righttype_after = zero_tangent(y) contexts_tup_false = map(_ -> false, contexts) From a62c5189b5cd3219ac5f88668d41792b01496de5 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 30 Jan 2026 10:34:39 +0100 Subject: [PATCH 2/9] Remove useless _righttype caches --- DifferentiationInterface/Project.toml | 2 +- .../forward_onearg.jl | 14 ++++----- .../forward_twoarg.jl | 29 ++++++------------- .../onearg.jl | 16 +++------- .../twoarg.jl | 14 ++++----- .../utils.jl | 2 +- .../test/Back/Mooncake/Project.toml | 3 ++ .../test/Back/Mooncake/test.jl | 10 ++++--- 8 files changed, 36 insertions(+), 54 deletions(-) diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index f54fd1f4f..af16a6ae2 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -69,7 +69,7 @@ ForwardDiff = "0.10.36,1" GPUArraysCore = "0.2" GTPSA = "1.4.0" LinearAlgebra = "1" -Mooncake = "0.5.0" +Mooncake = "0.5.1" PolyesterForwardDiff = "0.1.2" ReverseDiff = "1.15.1" SparseArrays = "1" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl index b64549b89..5615cdbd7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -1,9 +1,8 @@ ## Pushforward -struct MooncakeOneArgPushforwardPrep{SIG, Tcache, DX, FT, CT} <: DI.PushforwardPrep{SIG} +struct MooncakeOneArgPushforwardPrep{SIG, Tcache, FT, CT} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} cache::Tcache - dx_righttype::DX df::FT context_tangents::CT end @@ -21,10 +20,9 @@ function DI.prepare_pushforward_nokwarg( cache = prepare_derivative_cache( f, x, map(DI.unwrap, contexts)...; config ) - dx_righttype = zero_tangent(x) df = zero_tangent(f) context_tangents = map(zero_tangent_unwrap, contexts) - prep = MooncakeOneArgPushforwardPrep(_sig, cache, dx_righttype, df, context_tangents) + prep = MooncakeOneArgPushforwardPrep(_sig, cache, df, context_tangents) return prep end @@ -38,13 +36,11 @@ function DI.value_and_pushforward( ) where {F, C, X} DI.check_prep(f, prep, backend, x, tx, contexts...) ys_and_ty = map(tx) do dx - dx_righttype = - dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) y_dual = value_and_derivative!!( prep.cache, - Dual(f, prep.df), - Dual(x, dx_righttype), - map(Dual_unwrap, contexts, prep.context_tangents)..., + (f, prep.df), + (x, dx), + map(first_unwrap, contexts, prep.context_tangents)..., ) y = primal(y_dual) dy = _copy_output(tangent(y_dual)) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index dea6138bc..929be485f 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -1,10 +1,8 @@ ## Pushforward -struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, DX, DY, FT, CT} <: DI.PushforwardPrep{SIG} +struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, FT, CT} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} cache::Tcache - dx_righttype::DX - dy_righttype::DY df!::FT context_tangents::CT end @@ -27,11 +25,9 @@ function DI.prepare_pushforward_nokwarg( map(DI.unwrap, contexts)...; config ) - dx_righttype = zero_tangent(x) - dy_righttype = zero_tangent(y) df! = zero_tangent(f!) context_tangents = map(zero_tangent_unwrap, contexts) - prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dx_righttype, dy_righttype, df!, context_tangents) + prep = MooncakeTwoArgPushforwardPrep(_sig, cache, df!, context_tangents) return prep end @@ -46,15 +42,13 @@ function DI.value_and_pushforward( ) where {F, C, X} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = map(tx) do dx - dx_righttype = - dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) y_dual = zero_dual(y) value_and_derivative!!( prep.cache, - Dual(f!, prep.df!), + (f!, prep.df!), y_dual, - Dual(x, dx_righttype), - map(Dual_unwrap, contexts, prep.context_tangents)..., + (x, dx), + map(first_unwrap, contexts, prep.context_tangents)..., ) dy = _copy_output(tangent(y_dual)) return dy @@ -87,18 +81,13 @@ function DI.value_and_pushforward!( ) where {F, C, X, Y} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) foreach(tx, ty) do dx, dy - dx_righttype = - dx isa tangent_type(X) ? dx : _copy_to_output!!(prep.dx_righttype, dx) - dy_righttype = - dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) value_and_derivative!!( prep.cache, - Dual(f!, prep.df!), - Dual(y, dy_righttype), - Dual(x, dx_righttype), - map(Dual_unwrap, contexts, prep.context_tangents)..., + (f!, prep.df!), + (y, dy), + (x, dx), + map(first_unwrap, contexts, prep.context_tangents)..., ) - dy === dy_righttype || copyto!(dy, dy_righttype) end return y, ty end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index e52a2a498..6131d9fe4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -1,9 +1,8 @@ ## Pullback -struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG} +struct MooncakeOneArgPullbackPrep{SIG, Tcache, N} <: DI.PullbackPrep{SIG} _sig::Val{SIG} cache::Tcache - dy_righttype::DY args_to_zero::NTuple{N, Bool} end @@ -15,15 +14,13 @@ function DI.prepare_pullback_nokwarg( cache = prepare_pullback_cache( f, x, map(DI.unwrap, contexts)...; config ) - y = f(x, map(DI.unwrap, contexts)...) - dy_righttype = zero_tangent(y) contexts_tup_false = map(_ -> false, contexts) args_to_zero = ( false, # f true, # x contexts_tup_false..., ) - prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype, args_to_zero) + prep = MooncakeOneArgPullbackPrep(_sig, cache, args_to_zero) return prep end @@ -37,10 +34,8 @@ function DI.value_and_pullback( ) where {F, Y, C} DI.check_prep(f, prep, backend, x, ty, contexts...) dy = only(ty) - dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) new_y, (_, new_dx) = value_and_pullback!!( - prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...; - prep.args_to_zero + prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) return new_y, (_copy_output(new_dx),) end @@ -55,11 +50,8 @@ function DI.value_and_pullback( ) where {F, Y, C} DI.check_prep(f, prep, backend, x, ty, contexts...) ys_and_tx = map(ty) do dy - dy_righttype = - dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) y, (_, new_dx) = value_and_pullback!!( - prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...; - prep.args_to_zero + prep.cache, dy, f, x, map(DI.unwrap, contexts)...; prep.args_to_zero ) y, _copy_output(new_dx) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 14b5c2503..90ed26443 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -1,7 +1,7 @@ struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F, N} <: DI.PullbackPrep{SIG} _sig::Val{SIG} cache::Tcache - dy_righttype::DY + dy_backup::DY target_function::F args_to_zero::NTuple{N, Bool} end @@ -29,7 +29,7 @@ function DI.prepare_pullback_nokwarg( map(DI.unwrap, contexts)...; config, ) - dy_righttype_after = zero_tangent(y) + dy_backup_after = zero_tangent(y) contexts_tup_false = map(_ -> false, contexts) args_to_zero = ( false, # target_function @@ -39,7 +39,7 @@ function DI.prepare_pullback_nokwarg( contexts_tup_false..., ) prep = MooncakeTwoArgPullbackPrep( - _sig, cache, dy_righttype_after, target_function, args_to_zero + _sig, cache, dy_backup_after, target_function, args_to_zero ) return prep end @@ -56,11 +56,11 @@ function DI.value_and_pullback( DI.check_prep(f!, y, prep, backend, x, ty, contexts...) dy = only(ty) # Prepare cotangent to add after the forward pass. - dy_righttype_after = copyto!(prep.dy_righttype, dy) + dy_backup_after = copyto!(prep.dy_backup, dy) # Run the reverse-pass and return the results. y_after, (_, _, _, dx) = value_and_pullback!!( prep.cache, - dy_righttype_after, + dy_backup_after, prep.target_function, f!, y, @@ -83,10 +83,10 @@ function DI.value_and_pullback( ) where {F, C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) tx = map(ty) do dy - dy_righttype_after = copyto!(prep.dy_righttype, dy) + dy_backup_after = copyto!(prep.dy_backup, dy) y_after, (_, _, _, dx) = value_and_pullback!!( prep.cache, - dy_righttype_after, + dy_backup_after, prep.target_function, f!, y, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index 56e4b966a..b56437d66 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -2,4 +2,4 @@ get_config(::AnyAutoMooncake{Nothing}) = Config() get_config(backend::AnyAutoMooncake{<:Config}) = backend.config @inline zero_tangent_unwrap(c::DI.Context) = zero_tangent(DI.unwrap(c)) -@inline Dual_unwrap(c, dc) = Dual(DI.unwrap(c), dc) +@inline first_unwrap(c, dc) = (DI.unwrap(c), dc) diff --git a/DifferentiationInterface/test/Back/Mooncake/Project.toml b/DifferentiationInterface/test/Back/Mooncake/Project.toml index 39a878d7e..3fb318369 100644 --- a/DifferentiationInterface/test/Back/Mooncake/Project.toml +++ b/DifferentiationInterface/test/Back/Mooncake/Project.toml @@ -7,3 +7,6 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + +[sources] +Mooncake = { url = "https://github.com/gdalle/Mooncake.jl", rev = "gd/forwardcache" } diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index f2d6add4c..55891e8a4 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -7,11 +7,11 @@ using Test using ExplicitImports check_no_implicit_imports(DifferentiationInterface) - backends = [ - AutoMooncake(; config = nothing), - AutoMooncake(; config = Mooncake.Config()), - AutoMooncakeForward(; config = nothing), + AutoMooncake(), + AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)), + AutoMooncakeForward(), + AutoMooncakeForward(; config = Mooncake.Config(; friendly_tangents = true)), ] for backend in backends @@ -51,3 +51,5 @@ test_differentiation( @test grad.A == ps.B @test grad.B == ps.A end + +# TODO: test static arrays with friendly tangents! From 4bf5dc674ce5f0eb92ddfe01e286c1ccaef74d48 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 30 Jan 2026 10:55:06 +0100 Subject: [PATCH 3/9] No dual output --- .../DifferentiationInterfaceMooncakeExt/forward_onearg.jl | 6 +++--- .../DifferentiationInterfaceMooncakeExt/forward_twoarg.jl | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl index 5615cdbd7..42bc8fb0a 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -36,14 +36,14 @@ function DI.value_and_pushforward( ) where {F, C, X} DI.check_prep(f, prep, backend, x, tx, contexts...) ys_and_ty = map(tx) do dx - y_dual = value_and_derivative!!( + y_and_dy = value_and_derivative!!( prep.cache, (f, prep.df), (x, dx), map(first_unwrap, contexts, prep.context_tangents)..., ) - y = primal(y_dual) - dy = _copy_output(tangent(y_dual)) + y = first(y_and_dy) + dy = _copy_output(last(y_and_dy)) return y, dy end y = first(ys_and_ty[1]) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index 929be485f..b3dff8380 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -42,15 +42,14 @@ function DI.value_and_pushforward( ) where {F, C, X} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = map(tx) do dx - y_dual = zero_dual(y) + dy = zero_tangent(y) # TODO: remove allocation? value_and_derivative!!( prep.cache, (f!, prep.df!), - y_dual, + (y, dy), (x, dx), map(first_unwrap, contexts, prep.context_tangents)..., ) - dy = _copy_output(tangent(y_dual)) return dy end return y, ty From 604bd8ec4d78f7269141c85f355d03ccf2a60b27 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 30 Jan 2026 16:30:23 +0100 Subject: [PATCH 4/9] Fixes --- .../forward_onearg.jl | 2 +- .../forward_twoarg.jl | 10 +++++++--- .../twoarg.jl | 17 ++++++----------- .../utils.jl | 5 +++++ .../test/Back/Mooncake/Project.toml | 3 --- 5 files changed, 19 insertions(+), 18 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl index 42bc8fb0a..dbcd8f7e0 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -46,7 +46,7 @@ function DI.value_and_pushforward( dy = _copy_output(last(y_and_dy)) return y, dy end - y = first(ys_and_ty[1]) + y = _copy_output(first(ys_and_ty[1])) ty = map(last, ys_and_ty) return y, ty end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index b3dff8380..4375d74cd 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -19,6 +19,7 @@ function DI.prepare_pushforward_nokwarg( _sig = DI.signature(f!, y, backend, x, tx, contexts...; strict) config = get_config(backend) cache = prepare_derivative_cache( + call_and_return, f!, y, x, @@ -43,14 +44,15 @@ function DI.value_and_pushforward( DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = map(tx) do dx dy = zero_tangent(y) # TODO: remove allocation? - value_and_derivative!!( + _, new_dy = value_and_derivative!!( prep.cache, + (call_and_return, zero_tangent(call_and_return)), (f!, prep.df!), (y, dy), (x, dx), map(first_unwrap, contexts, prep.context_tangents)..., ) - return dy + return _copy_output(new_dy) end return y, ty end @@ -80,13 +82,15 @@ function DI.value_and_pushforward!( ) where {F, C, X, Y} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) foreach(tx, ty) do dx, dy - value_and_derivative!!( + _, new_dy = value_and_derivative!!( prep.cache, + (call_and_return, zero_tangent(call_and_return)), (f!, prep.df!), (y, dy), (x, dx), map(first_unwrap, contexts, prep.context_tangents)..., ) + copyto!(dy, new_dy) end return y, ty end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 90ed26443..e29fe92e9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -1,8 +1,7 @@ -struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F, N} <: DI.PullbackPrep{SIG} +struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG} _sig::Val{SIG} cache::Tcache dy_backup::DY - target_function::F args_to_zero::NTuple{N, Bool} end @@ -16,13 +15,9 @@ function DI.prepare_pullback_nokwarg( contexts::Vararg{DI.Context, C} ) where {F, C} _sig = DI.signature(f!, y, backend, x, ty, contexts...; strict) - target_function = function (f!, y, x, contexts...) - f!(y, x, contexts...) - return y - end config = get_config(backend) cache = prepare_pullback_cache( - target_function, + call_and_return, f!, y, x, @@ -32,14 +27,14 @@ function DI.prepare_pullback_nokwarg( dy_backup_after = zero_tangent(y) contexts_tup_false = map(_ -> false, contexts) args_to_zero = ( - false, # target_function + false, # call_and_return false, # f! false, # y true, # x contexts_tup_false..., ) prep = MooncakeTwoArgPullbackPrep( - _sig, cache, dy_backup_after, target_function, args_to_zero + _sig, cache, dy_backup_after, args_to_zero ) return prep end @@ -61,7 +56,7 @@ function DI.value_and_pullback( y_after, (_, _, _, dx) = value_and_pullback!!( prep.cache, dy_backup_after, - prep.target_function, + call_and_return, f!, y, x, @@ -87,7 +82,7 @@ function DI.value_and_pullback( y_after, (_, _, _, dx) = value_and_pullback!!( prep.cache, dy_backup_after, - prep.target_function, + call_and_return, f!, y, x, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index b56437d66..5c4c0cf87 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -3,3 +3,8 @@ get_config(backend::AnyAutoMooncake{<:Config}) = backend.config @inline zero_tangent_unwrap(c::DI.Context) = zero_tangent(DI.unwrap(c)) @inline first_unwrap(c, dc) = (DI.unwrap(c), dc) + +function call_and_return(f!::F, y, x, contexts...) where {F} + f!(y, x, contexts...) + return y +end diff --git a/DifferentiationInterface/test/Back/Mooncake/Project.toml b/DifferentiationInterface/test/Back/Mooncake/Project.toml index 3fb318369..39a878d7e 100644 --- a/DifferentiationInterface/test/Back/Mooncake/Project.toml +++ b/DifferentiationInterface/test/Back/Mooncake/Project.toml @@ -7,6 +7,3 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" - -[sources] -Mooncake = { url = "https://github.com/gdalle/Mooncake.jl", rev = "gd/forwardcache" } From aa9c806440e55d19525462d6845339a84a16ef77 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 31 Jan 2026 10:40:29 +0100 Subject: [PATCH 5/9] Fix and test friendly tangents with static arrays --- .../DifferentiationInterfaceMooncakeExt.jl | 3 ++- .../forward_onearg.jl | 6 ++---- .../forward_twoarg.jl | 10 ++++++---- .../DifferentiationInterfaceMooncakeExt/onearg.jl | 8 ++------ .../DifferentiationInterfaceMooncakeExt/twoarg.jl | 12 ++++++------ .../DifferentiationInterfaceMooncakeExt/utils.jl | 9 +++++++++ .../test/Back/Mooncake/Project.toml | 1 + .../test/Back/Mooncake/test.jl | 14 ++++++++++---- 8 files changed, 38 insertions(+), 25 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl index 561c06cc8..3513d548c 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/DifferentiationInterfaceMooncakeExt.jl @@ -29,7 +29,8 @@ using Mooncake: NoRData, primal, _copy_output, - _copy_to_output!! + _copy_to_output!!, + tangent_to_primal!! const AnyAutoMooncake{C} = Union{AutoMooncake{C}, AutoMooncakeForward{C}} diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl index dbcd8f7e0..c470b6473 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_onearg.jl @@ -17,10 +17,8 @@ function DI.prepare_pushforward_nokwarg( ) where {F, C} _sig = DI.signature(f, backend, x, tx, contexts...; strict) config = get_config(backend) - cache = prepare_derivative_cache( - f, x, map(DI.unwrap, contexts)...; config - ) - df = zero_tangent(f) + cache = prepare_derivative_cache(f, x, map(DI.unwrap, contexts)...; config) + df = zero_tangent_or_primal(f, backend) context_tangents = map(zero_tangent_unwrap, contexts) prep = MooncakeOneArgPushforwardPrep(_sig, cache, df, context_tangents) return prep diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index 4375d74cd..20033a659 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -26,7 +26,7 @@ function DI.prepare_pushforward_nokwarg( map(DI.unwrap, contexts)...; config ) - df! = zero_tangent(f!) + df! = zero_tangent_or_primal(f!, backend) context_tangents = map(zero_tangent_unwrap, contexts) prep = MooncakeTwoArgPushforwardPrep(_sig, cache, df!, context_tangents) return prep @@ -43,10 +43,11 @@ function DI.value_and_pushforward( ) where {F, C, X} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = map(tx) do dx - dy = zero_tangent(y) # TODO: remove allocation? + dy = zero_tangent_or_primal(y, backend) # TODO: remove allocation? + dcall = zero_tangent_or_primal(call_and_return, backend) _, new_dy = value_and_derivative!!( prep.cache, - (call_and_return, zero_tangent(call_and_return)), + (call_and_return, dcall), (f!, prep.df!), (y, dy), (x, dx), @@ -82,9 +83,10 @@ function DI.value_and_pushforward!( ) where {F, C, X, Y} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) foreach(tx, ty) do dx, dy + dcall = zero_tangent_or_primal(call_and_return, backend) _, new_dy = value_and_derivative!!( prep.cache, - (call_and_return, zero_tangent(call_and_return)), + (call_and_return, dcall), (f!, prep.df!), (y, dy), (x, dx), diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index 6131d9fe4..2514cdc40 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -11,9 +11,7 @@ function DI.prepare_pullback_nokwarg( ) where {F, C} _sig = DI.signature(f, backend, x, ty, contexts...; strict) config = get_config(backend) - cache = prepare_pullback_cache( - f, x, map(DI.unwrap, contexts)...; config - ) + cache = prepare_pullback_cache(f, x, map(DI.unwrap, contexts)...; config) contexts_tup_false = map(_ -> false, contexts) args_to_zero = ( false, # f @@ -113,9 +111,7 @@ function DI.prepare_gradient_nokwarg( ) where {F, C} _sig = DI.signature(f, backend, x, contexts...; strict) config = get_config(backend) - cache = prepare_gradient_cache( - f, x, map(DI.unwrap, contexts)...; config - ) + cache = prepare_gradient_cache(f, x, map(DI.unwrap, contexts)...; config) contexts_tup_false = map(_ -> false, contexts) args_to_zero = ( false, # f diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index e29fe92e9..2b55131b9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -24,7 +24,7 @@ function DI.prepare_pullback_nokwarg( map(DI.unwrap, contexts)...; config, ) - dy_backup_after = zero_tangent(y) + dy_backup = zero_tangent_or_primal(y, backend) contexts_tup_false = map(_ -> false, contexts) args_to_zero = ( false, # call_and_return @@ -34,7 +34,7 @@ function DI.prepare_pullback_nokwarg( contexts_tup_false..., ) prep = MooncakeTwoArgPullbackPrep( - _sig, cache, dy_backup_after, args_to_zero + _sig, cache, dy_backup, args_to_zero ) return prep end @@ -51,11 +51,11 @@ function DI.value_and_pullback( DI.check_prep(f!, y, prep, backend, x, ty, contexts...) dy = only(ty) # Prepare cotangent to add after the forward pass. - dy_backup_after = copyto!(prep.dy_backup, dy) + dy_backup = copyto!(prep.dy_backup, dy) # Run the reverse-pass and return the results. y_after, (_, _, _, dx) = value_and_pullback!!( prep.cache, - dy_backup_after, + dy_backup, call_and_return, f!, y, @@ -78,10 +78,10 @@ function DI.value_and_pullback( ) where {F, C} DI.check_prep(f!, y, prep, backend, x, ty, contexts...) tx = map(ty) do dy - dy_backup_after = copyto!(prep.dy_backup, dy) + dy_backup = copyto!(prep.dy_backup, dy) y_after, (_, _, _, dx) = value_and_pullback!!( prep.cache, - dy_backup_after, + dy_backup, call_and_return, f!, y, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index 5c4c0cf87..8f6437b49 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -8,3 +8,12 @@ function call_and_return(f!::F, y, x, contexts...) where {F} f!(y, x, contexts...) return y end + +function zero_tangent_or_primal(x, backend::AnyAutoMooncake) + if backend.config.friendly_tangents + # zero(x) but safer + return tangent_to_primal!!(_copy_output(x), zero_tangent(x)) + else + return zero_tangent(x) + end +end diff --git a/DifferentiationInterface/test/Back/Mooncake/Project.toml b/DifferentiationInterface/test/Back/Mooncake/Project.toml index 39a878d7e..f13f37c5b 100644 --- a/DifferentiationInterface/test/Back/Mooncake/Project.toml +++ b/DifferentiationInterface/test/Back/Mooncake/Project.toml @@ -6,4 +6,5 @@ ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5" SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 55891e8a4..7748a316b 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -9,8 +9,8 @@ check_no_implicit_imports(DifferentiationInterface) backends = [ AutoMooncake(), - AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)), AutoMooncakeForward(), + AutoMooncake(; config = Mooncake.Config(; friendly_tangents = true)), AutoMooncakeForward(; config = Mooncake.Config(; friendly_tangents = true)), ] @@ -22,7 +22,8 @@ end test_differentiation( backends, default_scenarios(; - include_constantified = true, include_cachified = true, use_tuples = true + include_batchified = false, + include_constantified = false, include_cachified = false, use_tuples = true ); excluded = SECOND_ORDER, logging = LOGGING, @@ -39,7 +40,7 @@ end # Test second-order differentiation (forward-over-reverse) test_differentiation( - [SecondOrder(AutoMooncakeForward(; config = nothing), AutoMooncake(; config = nothing))], + [SecondOrder(AutoMooncakeForward(), AutoMooncake())], excluded = EXCLUDED, logging = true, ) @@ -52,4 +53,9 @@ test_differentiation( @test grad.B == ps.A end -# TODO: test static arrays with friendly tangents! +test_differentiation( + backends[3:4], + static_scenarios(); + logging = LOGGING, + excluded = SECOND_ORDER +) From 0b13cda1c41b11a82b87d6189c788ee17f0d3a4a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 31 Jan 2026 11:01:35 +0100 Subject: [PATCH 6/9] Fix config --- .github/workflows/Test.yml | 276 +++++++++--------- .../utils.jl | 2 +- 2 files changed, 139 insertions(+), 139 deletions(-) diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 69ecf753e..52689f618 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -16,71 +16,71 @@ concurrency: cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: - test-DI-Core: - name: ${{ matrix.version }} - DI Core (${{ matrix.group }}) - runs-on: ubuntu-latest - if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} - timeout-minutes: 120 - permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created - actions: write - contents: read - strategy: - fail-fast: false # TODO: toggle - matrix: - version: - - '1.10' - - '1.11' - - '1.12' - group: - - Internals - - SimpleFiniteDiff - - ZeroBackends - skip_lts: - - ${{ github.event.pull_request.draft }} - skip_pre: - - ${{ github.event.pull_request.draft }} - exclude: - - skip_lts: true - version: '1.10' - - skip_pre: true - version: '1.12' - env: - JULIA_DI_TEST_TYPE: 'Core' - JULIA_DI_TEST_GROUP: ${{ matrix.group }} - JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} - steps: - - uses: actions/checkout@v6 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - arch: x64 - - uses: julia-actions/cache@v2 - - name: Install dependencies & run tests - run: julia --color=yes -e ' - using Pkg; - Pkg.activate("./DifferentiationInterface/test"); - if VERSION < v"1.11"; - Pkg.rm("DifferentiationInterfaceTest"); - Pkg.resolve(); - else; - Pkg.develop(; path="./DifferentiationInterfaceTest"); - end; - Pkg.activate("./DifferentiationInterface"); - test_kwargs = (; allow_reresolve=false, coverage=true); - if ENV["JULIA_DI_PR_DRAFT"] == "true"; - Pkg.test("DifferentiationInterface"; julia_args=["-O1"], test_kwargs...); - else; - Pkg.test("DifferentiationInterface"; test_kwargs...); - end;' - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: ./DifferentiationInterface/src,./DifferentiationInterface/ext,./DifferentiationInterface/test - - uses: codecov/codecov-action@v5 - with: - files: lcov.info - flags: DI - token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: false + # test-DI-Core: + # name: ${{ matrix.version }} - DI Core (${{ matrix.group }}) + # runs-on: ubuntu-latest + # if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} + # timeout-minutes: 120 + # permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created + # actions: write + # contents: read + # strategy: + # fail-fast: false # TODO: toggle + # matrix: + # version: + # - '1.10' + # - '1.11' + # - '1.12' + # group: + # - Internals + # - SimpleFiniteDiff + # - ZeroBackends + # skip_lts: + # - ${{ github.event.pull_request.draft }} + # skip_pre: + # - ${{ github.event.pull_request.draft }} + # exclude: + # - skip_lts: true + # version: '1.10' + # - skip_pre: true + # version: '1.12' + # env: + # JULIA_DI_TEST_TYPE: 'Core' + # JULIA_DI_TEST_GROUP: ${{ matrix.group }} + # JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} + # steps: + # - uses: actions/checkout@v6 + # - uses: julia-actions/setup-julia@v2 + # with: + # version: ${{ matrix.version }} + # arch: x64 + # - uses: julia-actions/cache@v2 + # - name: Install dependencies & run tests + # run: julia --color=yes -e ' + # using Pkg; + # Pkg.activate("./DifferentiationInterface/test"); + # if VERSION < v"1.11"; + # Pkg.rm("DifferentiationInterfaceTest"); + # Pkg.resolve(); + # else; + # Pkg.develop(; path="./DifferentiationInterfaceTest"); + # end; + # Pkg.activate("./DifferentiationInterface"); + # test_kwargs = (; allow_reresolve=false, coverage=true); + # if ENV["JULIA_DI_PR_DRAFT"] == "true"; + # Pkg.test("DifferentiationInterface"; julia_args=["-O1"], test_kwargs...); + # else; + # Pkg.test("DifferentiationInterface"; test_kwargs...); + # end;' + # - uses: julia-actions/julia-processcoverage@v1 + # with: + # directories: ./DifferentiationInterface/src,./DifferentiationInterface/ext,./DifferentiationInterface/test + # - uses: codecov/codecov-action@v5 + # with: + # files: lcov.info + # flags: DI + # token: ${{ secrets.CODECOV_TOKEN }} + # fail_ci_if_error: false test-DI-Backend: name: ${{ matrix.version }} - DI Back (${{ matrix.group }}) @@ -98,22 +98,22 @@ jobs: - '1.11' - '1.12' group: - - ChainRules - - DifferentiateWith - # - Diffractor - - Enzyme - - FastDifferentiation - - FiniteDiff - - FiniteDifferences - - ForwardDiff - - GTPSA + # - ChainRules + # - DifferentiateWith + # # - Diffractor + # - Enzyme + # - FastDifferentiation + # - FiniteDiff + # - FiniteDifferences + # - ForwardDiff + # - GTPSA - Mooncake - - PolyesterForwardDiff - - ReverseDiff - - SparsityDetector - - Symbolics - - Tracker - - Zygote + # - PolyesterForwardDiff + # - ReverseDiff + # - SparsityDetector + # - Symbolics + # - Tracker + # - Zygote skip_lts: - ${{ github.event.pull_request.draft }} skip_pre: @@ -157,61 +157,61 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: false - test-DIT: - name: ${{ matrix.version }} - DIT (${{ matrix.group }}) - runs-on: ubuntu-latest - if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} - timeout-minutes: 60 - permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created - actions: write - contents: read - strategy: - fail-fast: false # TODO: toggle - matrix: - version: - - '1.10' - - '1.11' - - '1.12' - group: - - Formalities - - Zero - - Standard - - Weird - skip_lts: - - ${{ github.event.pull_request.draft }} - skip_pre: - - ${{ github.event.pull_request.draft }} - exclude: - - skip_lts: true - version: '1.10' - - skip_pre: true - version: '1.12' - env: - JULIA_DIT_TEST_GROUP: ${{ matrix.group }} - JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} - steps: - - uses: actions/checkout@v6 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - arch: x64 - - uses: julia-actions/cache@v2 - - name: Install dependencies & run tests - run: julia --project=./DifferentiationInterfaceTest --color=yes -e ' - using Pkg; - Pkg.Registry.update(); - Pkg.develop(path="./DifferentiationInterface"); - if ENV["JULIA_DI_PR_DRAFT"] == "true"; - Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]); - else; - Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true); - end;' - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test - - uses: codecov/codecov-action@v5 - with: - files: lcov.info - flags: DIT - token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: false + # test-DIT: + # name: ${{ matrix.version }} - DIT (${{ matrix.group }}) + # runs-on: ubuntu-latest + # if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} + # timeout-minutes: 60 + # permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created + # actions: write + # contents: read + # strategy: + # fail-fast: false # TODO: toggle + # matrix: + # version: + # - '1.10' + # - '1.11' + # - '1.12' + # group: + # - Formalities + # - Zero + # - Standard + # - Weird + # skip_lts: + # - ${{ github.event.pull_request.draft }} + # skip_pre: + # - ${{ github.event.pull_request.draft }} + # exclude: + # - skip_lts: true + # version: '1.10' + # - skip_pre: true + # version: '1.12' + # env: + # JULIA_DIT_TEST_GROUP: ${{ matrix.group }} + # JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} + # steps: + # - uses: actions/checkout@v6 + # - uses: julia-actions/setup-julia@v2 + # with: + # version: ${{ matrix.version }} + # arch: x64 + # - uses: julia-actions/cache@v2 + # - name: Install dependencies & run tests + # run: julia --project=./DifferentiationInterfaceTest --color=yes -e ' + # using Pkg; + # Pkg.Registry.update(); + # Pkg.develop(path="./DifferentiationInterface"); + # if ENV["JULIA_DI_PR_DRAFT"] == "true"; + # Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]); + # else; + # Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true); + # end;' + # - uses: julia-actions/julia-processcoverage@v1 + # with: + # directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test + # - uses: codecov/codecov-action@v5 + # with: + # files: lcov.info + # flags: DIT + # token: ${{ secrets.CODECOV_TOKEN }} + # fail_ci_if_error: false diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl index 8f6437b49..b22d8d49b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/utils.jl @@ -10,7 +10,7 @@ function call_and_return(f!::F, y, x, contexts...) where {F} end function zero_tangent_or_primal(x, backend::AnyAutoMooncake) - if backend.config.friendly_tangents + if get_config(backend).friendly_tangents # zero(x) but safer return tangent_to_primal!!(_copy_output(x), zero_tangent(x)) else From ed3de268c919a384fac371fb0731f54ac77064d3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 31 Jan 2026 11:05:08 +0100 Subject: [PATCH 7/9] No logging --- DifferentiationInterface/test/Back/Mooncake/test.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 7748a316b..6156b3aec 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -42,7 +42,7 @@ end test_differentiation( [SecondOrder(AutoMooncakeForward(), AutoMooncake())], excluded = EXCLUDED, - logging = true, + logging = LOGGING, ) @testset "NamedTuples" begin From 00b7f09ad7fed89502985e4fb4c63a399bcfb741 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 31 Jan 2026 11:54:45 +0100 Subject: [PATCH 8/9] Reactivate other tests --- .github/workflows/Test.yml | 276 ++++++++++++++++++------------------- 1 file changed, 138 insertions(+), 138 deletions(-) diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 52689f618..69ecf753e 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -16,71 +16,71 @@ concurrency: cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: - # test-DI-Core: - # name: ${{ matrix.version }} - DI Core (${{ matrix.group }}) - # runs-on: ubuntu-latest - # if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} - # timeout-minutes: 120 - # permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created - # actions: write - # contents: read - # strategy: - # fail-fast: false # TODO: toggle - # matrix: - # version: - # - '1.10' - # - '1.11' - # - '1.12' - # group: - # - Internals - # - SimpleFiniteDiff - # - ZeroBackends - # skip_lts: - # - ${{ github.event.pull_request.draft }} - # skip_pre: - # - ${{ github.event.pull_request.draft }} - # exclude: - # - skip_lts: true - # version: '1.10' - # - skip_pre: true - # version: '1.12' - # env: - # JULIA_DI_TEST_TYPE: 'Core' - # JULIA_DI_TEST_GROUP: ${{ matrix.group }} - # JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} - # steps: - # - uses: actions/checkout@v6 - # - uses: julia-actions/setup-julia@v2 - # with: - # version: ${{ matrix.version }} - # arch: x64 - # - uses: julia-actions/cache@v2 - # - name: Install dependencies & run tests - # run: julia --color=yes -e ' - # using Pkg; - # Pkg.activate("./DifferentiationInterface/test"); - # if VERSION < v"1.11"; - # Pkg.rm("DifferentiationInterfaceTest"); - # Pkg.resolve(); - # else; - # Pkg.develop(; path="./DifferentiationInterfaceTest"); - # end; - # Pkg.activate("./DifferentiationInterface"); - # test_kwargs = (; allow_reresolve=false, coverage=true); - # if ENV["JULIA_DI_PR_DRAFT"] == "true"; - # Pkg.test("DifferentiationInterface"; julia_args=["-O1"], test_kwargs...); - # else; - # Pkg.test("DifferentiationInterface"; test_kwargs...); - # end;' - # - uses: julia-actions/julia-processcoverage@v1 - # with: - # directories: ./DifferentiationInterface/src,./DifferentiationInterface/ext,./DifferentiationInterface/test - # - uses: codecov/codecov-action@v5 - # with: - # files: lcov.info - # flags: DI - # token: ${{ secrets.CODECOV_TOKEN }} - # fail_ci_if_error: false + test-DI-Core: + name: ${{ matrix.version }} - DI Core (${{ matrix.group }}) + runs-on: ubuntu-latest + if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} + timeout-minutes: 120 + permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created + actions: write + contents: read + strategy: + fail-fast: false # TODO: toggle + matrix: + version: + - '1.10' + - '1.11' + - '1.12' + group: + - Internals + - SimpleFiniteDiff + - ZeroBackends + skip_lts: + - ${{ github.event.pull_request.draft }} + skip_pre: + - ${{ github.event.pull_request.draft }} + exclude: + - skip_lts: true + version: '1.10' + - skip_pre: true + version: '1.12' + env: + JULIA_DI_TEST_TYPE: 'Core' + JULIA_DI_TEST_GROUP: ${{ matrix.group }} + JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} + steps: + - uses: actions/checkout@v6 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + arch: x64 + - uses: julia-actions/cache@v2 + - name: Install dependencies & run tests + run: julia --color=yes -e ' + using Pkg; + Pkg.activate("./DifferentiationInterface/test"); + if VERSION < v"1.11"; + Pkg.rm("DifferentiationInterfaceTest"); + Pkg.resolve(); + else; + Pkg.develop(; path="./DifferentiationInterfaceTest"); + end; + Pkg.activate("./DifferentiationInterface"); + test_kwargs = (; allow_reresolve=false, coverage=true); + if ENV["JULIA_DI_PR_DRAFT"] == "true"; + Pkg.test("DifferentiationInterface"; julia_args=["-O1"], test_kwargs...); + else; + Pkg.test("DifferentiationInterface"; test_kwargs...); + end;' + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: ./DifferentiationInterface/src,./DifferentiationInterface/ext,./DifferentiationInterface/test + - uses: codecov/codecov-action@v5 + with: + files: lcov.info + flags: DI + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false test-DI-Backend: name: ${{ matrix.version }} - DI Back (${{ matrix.group }}) @@ -98,22 +98,22 @@ jobs: - '1.11' - '1.12' group: - # - ChainRules - # - DifferentiateWith - # # - Diffractor - # - Enzyme - # - FastDifferentiation - # - FiniteDiff - # - FiniteDifferences - # - ForwardDiff - # - GTPSA + - ChainRules + - DifferentiateWith + # - Diffractor + - Enzyme + - FastDifferentiation + - FiniteDiff + - FiniteDifferences + - ForwardDiff + - GTPSA - Mooncake - # - PolyesterForwardDiff - # - ReverseDiff - # - SparsityDetector - # - Symbolics - # - Tracker - # - Zygote + - PolyesterForwardDiff + - ReverseDiff + - SparsityDetector + - Symbolics + - Tracker + - Zygote skip_lts: - ${{ github.event.pull_request.draft }} skip_pre: @@ -157,61 +157,61 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: false - # test-DIT: - # name: ${{ matrix.version }} - DIT (${{ matrix.group }}) - # runs-on: ubuntu-latest - # if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} - # timeout-minutes: 60 - # permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created - # actions: write - # contents: read - # strategy: - # fail-fast: false # TODO: toggle - # matrix: - # version: - # - '1.10' - # - '1.11' - # - '1.12' - # group: - # - Formalities - # - Zero - # - Standard - # - Weird - # skip_lts: - # - ${{ github.event.pull_request.draft }} - # skip_pre: - # - ${{ github.event.pull_request.draft }} - # exclude: - # - skip_lts: true - # version: '1.10' - # - skip_pre: true - # version: '1.12' - # env: - # JULIA_DIT_TEST_GROUP: ${{ matrix.group }} - # JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} - # steps: - # - uses: actions/checkout@v6 - # - uses: julia-actions/setup-julia@v2 - # with: - # version: ${{ matrix.version }} - # arch: x64 - # - uses: julia-actions/cache@v2 - # - name: Install dependencies & run tests - # run: julia --project=./DifferentiationInterfaceTest --color=yes -e ' - # using Pkg; - # Pkg.Registry.update(); - # Pkg.develop(path="./DifferentiationInterface"); - # if ENV["JULIA_DI_PR_DRAFT"] == "true"; - # Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]); - # else; - # Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true); - # end;' - # - uses: julia-actions/julia-processcoverage@v1 - # with: - # directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test - # - uses: codecov/codecov-action@v5 - # with: - # files: lcov.info - # flags: DIT - # token: ${{ secrets.CODECOV_TOKEN }} - # fail_ci_if_error: false + test-DIT: + name: ${{ matrix.version }} - DIT (${{ matrix.group }}) + runs-on: ubuntu-latest + if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} + timeout-minutes: 60 + permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created + actions: write + contents: read + strategy: + fail-fast: false # TODO: toggle + matrix: + version: + - '1.10' + - '1.11' + - '1.12' + group: + - Formalities + - Zero + - Standard + - Weird + skip_lts: + - ${{ github.event.pull_request.draft }} + skip_pre: + - ${{ github.event.pull_request.draft }} + exclude: + - skip_lts: true + version: '1.10' + - skip_pre: true + version: '1.12' + env: + JULIA_DIT_TEST_GROUP: ${{ matrix.group }} + JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} + steps: + - uses: actions/checkout@v6 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + arch: x64 + - uses: julia-actions/cache@v2 + - name: Install dependencies & run tests + run: julia --project=./DifferentiationInterfaceTest --color=yes -e ' + using Pkg; + Pkg.Registry.update(); + Pkg.develop(path="./DifferentiationInterface"); + if ENV["JULIA_DI_PR_DRAFT"] == "true"; + Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]); + else; + Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true); + end;' + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test + - uses: codecov/codecov-action@v5 + with: + files: lcov.info + flags: DIT + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false From 6639952594a70db85c708a580099ed727f4f8782 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sat, 31 Jan 2026 19:24:48 +0100 Subject: [PATCH 9/9] Retoggle tests --- .../forward_twoarg.jl | 17 +++++++++-------- .../test/Back/Mooncake/test.jl | 3 +-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl index 20033a659..3c75f530b 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/forward_twoarg.jl @@ -1,9 +1,11 @@ ## Pushforward -struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, FT, CT} <: DI.PushforwardPrep{SIG} +struct MooncakeTwoArgPushforwardPrep{SIG, Tcache, FT0, FT, YT, CT} <: DI.PushforwardPrep{SIG} _sig::Val{SIG} cache::Tcache + dcall::FT0 df!::FT + dy::YT context_tangents::CT end @@ -26,9 +28,11 @@ function DI.prepare_pushforward_nokwarg( map(DI.unwrap, contexts)...; config ) + dcall = zero_tangent_or_primal(call_and_return, backend) df! = zero_tangent_or_primal(f!, backend) + dy = zero_tangent_or_primal(y, backend) context_tangents = map(zero_tangent_unwrap, contexts) - prep = MooncakeTwoArgPushforwardPrep(_sig, cache, df!, context_tangents) + prep = MooncakeTwoArgPushforwardPrep(_sig, cache, dcall, df!, dy, context_tangents) return prep end @@ -43,13 +47,11 @@ function DI.value_and_pushforward( ) where {F, C, X} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) ty = map(tx) do dx - dy = zero_tangent_or_primal(y, backend) # TODO: remove allocation? - dcall = zero_tangent_or_primal(call_and_return, backend) _, new_dy = value_and_derivative!!( prep.cache, - (call_and_return, dcall), + (call_and_return, prep.dcall), (f!, prep.df!), - (y, dy), + (y, prep.dy), (x, dx), map(first_unwrap, contexts, prep.context_tangents)..., ) @@ -83,10 +85,9 @@ function DI.value_and_pushforward!( ) where {F, C, X, Y} DI.check_prep(f!, y, prep, backend, x, tx, contexts...) foreach(tx, ty) do dx, dy - dcall = zero_tangent_or_primal(call_and_return, backend) _, new_dy = value_and_derivative!!( prep.cache, - (call_and_return, dcall), + (call_and_return, prep.dcall), (f!, prep.df!), (y, dy), (x, dx), diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 6156b3aec..8c67aa1bb 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -22,8 +22,7 @@ end test_differentiation( backends, default_scenarios(; - include_batchified = false, - include_constantified = false, include_cachified = false, use_tuples = true + include_constantified = true, include_cachified = true, use_tuples = true ); excluded = SECOND_ORDER, logging = LOGGING,