From a81cbefd1992ba4f0421d5dc5c75d8e742d40efd Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 13 Nov 2025 20:34:50 +0100 Subject: [PATCH 1/3] fix: speed up Mooncake reverse mode with selective zeroing --- .../onearg.jl | 36 ++++++++++++++----- .../twoarg.jl | 20 ++++++++--- 2 files changed, 44 insertions(+), 12 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index 209367d5a..6b32142e3 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -1,9 +1,10 @@ ## Pullback -struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY} <: DI.PullbackPrep{SIG} +struct MooncakeOneArgPullbackPrep{SIG, Tcache, DY, N} <: DI.PullbackPrep{SIG} _sig::Val{SIG} cache::Tcache dy_righttype::DY + args_to_zero::NTuple{N, Bool} end function DI.prepare_pullback_nokwarg( @@ -16,7 +17,12 @@ function DI.prepare_pullback_nokwarg( ) y = f(x, map(DI.unwrap, contexts)...) dy_righttype = zero_tangent(y) - prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype) + args_to_zero = ( + false, # f + true, # x + map(_ -> false, contexts)..., + ) + prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype, args_to_zero) return prep end @@ -32,7 +38,8 @@ function DI.value_and_pullback( dy = only(ty) dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) new_y, (_, new_dx) = value_and_pullback!!( - prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)... + prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...; + prep.args_to_zero ) return new_y, (_copy_output(new_dx),) end @@ -50,7 +57,8 @@ function DI.value_and_pullback( dy_righttype = dy isa tangent_type(Y) ? dy : _copy_to_output!!(prep.dy_righttype, dy) y, (_, new_dx) = value_and_pullback!!( - prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)... + prep.cache, dy_righttype, f, x, map(DI.unwrap, contexts)...; + prep.args_to_zero ) y, _copy_output(new_dx) end @@ -101,9 +109,10 @@ end ## Gradient -struct MooncakeGradientPrep{SIG, Tcache} <: DI.GradientPrep{SIG} +struct MooncakeGradientPrep{SIG, Tcache, N} <: DI.GradientPrep{SIG} _sig::Val{SIG} cache::Tcache + args_to_zero::NTuple{N, Bool} end function DI.prepare_gradient_nokwarg( @@ -114,7 +123,12 @@ function DI.prepare_gradient_nokwarg( cache = prepare_gradient_cache( f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages ) - prep = MooncakeGradientPrep(_sig, cache) + args_to_zero = ( + false, # f + true, # x + map(_ -> false, contexts)..., + ) + prep = MooncakeGradientPrep(_sig, cache, args_to_zero) return prep end @@ -126,7 +140,10 @@ function DI.value_and_gradient( contexts::Vararg{DI.Context, C}, ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) - y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...) + y, (_, new_grad) = value_and_gradient!!( + prep.cache, f, x, map(DI.unwrap, contexts)...; + prep.args_to_zero + ) return y, _copy_output(new_grad) end @@ -139,7 +156,10 @@ function DI.value_and_gradient!( contexts::Vararg{DI.Context, C}, ) where {F, C} DI.check_prep(f, prep, backend, x, contexts...) - y, (_, new_grad) = value_and_gradient!!(prep.cache, f, x, map(DI.unwrap, contexts)...) + y, (_, new_grad) = value_and_gradient!!( + prep.cache, f, x, map(DI.unwrap, contexts)...; + prep.args_to_zero + ) copyto!(grad, new_grad) return y, grad end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index da3e5b217..837c7f7a9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -1,8 +1,9 @@ -struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F} <: DI.PullbackPrep{SIG} +struct MooncakeTwoArgPullbackPrep{SIG, Tcache, DY, F, N} <: DI.PullbackPrep{SIG} _sig::Val{SIG} cache::Tcache dy_righttype::DY target_function::F + args_to_zero::NTuple{N, Bool} end function DI.prepare_pullback_nokwarg( @@ -30,7 +31,16 @@ function DI.prepare_pullback_nokwarg( silence_debug_messages = config.silence_debug_messages, ) dy_righttype_after = zero_tangent(y) - prep = MooncakeTwoArgPullbackPrep(_sig, cache, dy_righttype_after, target_function) + args_to_zero = ( + false, # target_function + false, # f! + false, # y + true, # x + map(_ -> false, contexts)..., + ) + prep = MooncakeTwoArgPullbackPrep( + _sig, cache, dy_righttype_after, target_function, args_to_zero + ) return prep end @@ -55,7 +65,8 @@ function DI.value_and_pullback( f!, y, x, - map(DI.unwrap, contexts)..., + map(DI.unwrap, contexts)...; + prep.args_to_zero ) copyto!(y, y_after) return y, (_copy_output(dx),) @@ -80,7 +91,8 @@ function DI.value_and_pullback( f!, y, x, - map(DI.unwrap, contexts)..., + map(DI.unwrap, contexts)...; + prep.args_to_zero ) copyto!(y, y_after) _copy_output(dx) From b84e56c30ace330e74edcc9d56527c54485f8a60 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 13 Nov 2025 20:37:15 +0100 Subject: [PATCH 2/3] Selective tests --- .github/workflows/Test.yml | 150 +++++++++--------- .../test/Back/Mooncake/test.jl | 2 +- 2 files changed, 76 insertions(+), 76 deletions(-) diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 89223f597..64bc3a3a4 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -28,29 +28,29 @@ jobs: fail-fast: true # TODO: toggle matrix: version: - - '1.10' + # - '1.10' - '1.11' - '1.12' group: - Core/Internals - - Back/DifferentiateWith - - Core/SimpleFiniteDiff - - Back/SparsityDetector - - Core/ZeroBackends - - Back/ChainRules + # - Back/DifferentiateWith + # - Core/SimpleFiniteDiff + # - Back/SparsityDetector + # - Core/ZeroBackends + # - Back/ChainRules # - Back/Diffractor - - Back/Enzyme - - Back/FastDifferentiation - - Back/FiniteDiff - - Back/FiniteDifferences - - Back/ForwardDiff - - Back/GTPSA + # - Back/Enzyme + # - Back/FastDifferentiation + # - Back/FiniteDiff + # - Back/FiniteDifferences + # - Back/ForwardDiff + # - Back/GTPSA - Back/Mooncake - - Back/PolyesterForwardDiff - - Back/ReverseDiff - - Back/Symbolics - - Back/Tracker - - Back/Zygote + # - Back/PolyesterForwardDiff + # - Back/ReverseDiff + # - Back/Symbolics + # - Back/Tracker + # - Back/Zygote skip_lts: - ${{ github.event.pull_request.draft }} skip_pre: @@ -104,61 +104,61 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: false - test-DIT: - name: ${{ matrix.version }} - DIT (${{ matrix.group }}) - runs-on: ubuntu-latest - if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} - timeout-minutes: 60 - permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created - actions: write - contents: read - strategy: - fail-fast: true - matrix: - version: - - '1.10' - - '1.11' - - '1.12' - group: - - Formalities - - Zero - - Standard - - Weird - skip_lts: - - ${{ github.event.pull_request.draft }} - skip_pre: - - ${{ github.event.pull_request.draft }} - exclude: - - skip_lts: true - version: '1.10' - - skip_pre: true - version: '1.12' - env: - JULIA_DIT_TEST_GROUP: ${{ matrix.group }} - JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} - steps: - - uses: actions/checkout@v5 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - arch: x64 - - uses: julia-actions/cache@v2 - - name: Install dependencies & run tests - run: julia --project=./DifferentiationInterfaceTest --color=yes -e ' - using Pkg; - Pkg.Registry.update(); - Pkg.develop(path="./DifferentiationInterface"); - if ENV["JULIA_DI_PR_DRAFT"] == "true"; - Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]); - else; - Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true); - end;' - - uses: julia-actions/julia-processcoverage@v1 - with: - directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test - - uses: codecov/codecov-action@v5 - with: - files: lcov.info - flags: DIT - token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: false + # test-DIT: + # name: ${{ matrix.version }} - DIT (${{ matrix.group }}) + # runs-on: ubuntu-latest + # if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} + # timeout-minutes: 60 + # permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created + # actions: write + # contents: read + # strategy: + # fail-fast: true + # matrix: + # version: + # - '1.10' + # - '1.11' + # - '1.12' + # group: + # - Formalities + # - Zero + # - Standard + # - Weird + # skip_lts: + # - ${{ github.event.pull_request.draft }} + # skip_pre: + # - ${{ github.event.pull_request.draft }} + # exclude: + # - skip_lts: true + # version: '1.10' + # - skip_pre: true + # version: '1.12' + # env: + # JULIA_DIT_TEST_GROUP: ${{ matrix.group }} + # JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} + # steps: + # - uses: actions/checkout@v5 + # - uses: julia-actions/setup-julia@v2 + # with: + # version: ${{ matrix.version }} + # arch: x64 + # - uses: julia-actions/cache@v2 + # - name: Install dependencies & run tests + # run: julia --project=./DifferentiationInterfaceTest --color=yes -e ' + # using Pkg; + # Pkg.Registry.update(); + # Pkg.develop(path="./DifferentiationInterface"); + # if ENV["JULIA_DI_PR_DRAFT"] == "true"; + # Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]); + # else; + # Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true); + # end;' + # - uses: julia-actions/julia-processcoverage@v1 + # with: + # directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test + # - uses: codecov/codecov-action@v5 + # with: + # files: lcov.info + # flags: DIT + # token: ${{ secrets.CODECOV_TOKEN }} + # fail_ci_if_error: false diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 3b61e3547..163d94483 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -1,5 +1,5 @@ using Pkg -Pkg.add("Mooncake") +Pkg.add(url = "https://github.com/gdalle/Mooncake.jl", rev = "selective_zeroing") using DifferentiationInterface, DifferentiationInterfaceTest using Mooncake: Mooncake From 01bdd5501fccb3892a40577fc71bc0c95447dc70 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 14 Nov 2025 12:07:46 +0100 Subject: [PATCH 3/3] Mooncake feature has released --- .github/workflows/Test.yml | 150 +++++++++--------- DifferentiationInterface/Project.toml | 8 +- .../onearg.jl | 6 +- .../twoarg.jl | 3 +- .../test/Back/Mooncake/test.jl | 2 +- 5 files changed, 88 insertions(+), 81 deletions(-) diff --git a/.github/workflows/Test.yml b/.github/workflows/Test.yml index 64bc3a3a4..89223f597 100644 --- a/.github/workflows/Test.yml +++ b/.github/workflows/Test.yml @@ -28,29 +28,29 @@ jobs: fail-fast: true # TODO: toggle matrix: version: - # - '1.10' + - '1.10' - '1.11' - '1.12' group: - Core/Internals - # - Back/DifferentiateWith - # - Core/SimpleFiniteDiff - # - Back/SparsityDetector - # - Core/ZeroBackends - # - Back/ChainRules + - Back/DifferentiateWith + - Core/SimpleFiniteDiff + - Back/SparsityDetector + - Core/ZeroBackends + - Back/ChainRules # - Back/Diffractor - # - Back/Enzyme - # - Back/FastDifferentiation - # - Back/FiniteDiff - # - Back/FiniteDifferences - # - Back/ForwardDiff - # - Back/GTPSA + - Back/Enzyme + - Back/FastDifferentiation + - Back/FiniteDiff + - Back/FiniteDifferences + - Back/ForwardDiff + - Back/GTPSA - Back/Mooncake - # - Back/PolyesterForwardDiff - # - Back/ReverseDiff - # - Back/Symbolics - # - Back/Tracker - # - Back/Zygote + - Back/PolyesterForwardDiff + - Back/ReverseDiff + - Back/Symbolics + - Back/Tracker + - Back/Zygote skip_lts: - ${{ github.event.pull_request.draft }} skip_pre: @@ -104,61 +104,61 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: false - # test-DIT: - # name: ${{ matrix.version }} - DIT (${{ matrix.group }}) - # runs-on: ubuntu-latest - # if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} - # timeout-minutes: 60 - # permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created - # actions: write - # contents: read - # strategy: - # fail-fast: true - # matrix: - # version: - # - '1.10' - # - '1.11' - # - '1.12' - # group: - # - Formalities - # - Zero - # - Standard - # - Weird - # skip_lts: - # - ${{ github.event.pull_request.draft }} - # skip_pre: - # - ${{ github.event.pull_request.draft }} - # exclude: - # - skip_lts: true - # version: '1.10' - # - skip_pre: true - # version: '1.12' - # env: - # JULIA_DIT_TEST_GROUP: ${{ matrix.group }} - # JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} - # steps: - # - uses: actions/checkout@v5 - # - uses: julia-actions/setup-julia@v2 - # with: - # version: ${{ matrix.version }} - # arch: x64 - # - uses: julia-actions/cache@v2 - # - name: Install dependencies & run tests - # run: julia --project=./DifferentiationInterfaceTest --color=yes -e ' - # using Pkg; - # Pkg.Registry.update(); - # Pkg.develop(path="./DifferentiationInterface"); - # if ENV["JULIA_DI_PR_DRAFT"] == "true"; - # Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]); - # else; - # Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true); - # end;' - # - uses: julia-actions/julia-processcoverage@v1 - # with: - # directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test - # - uses: codecov/codecov-action@v5 - # with: - # files: lcov.info - # flags: DIT - # token: ${{ secrets.CODECOV_TOKEN }} - # fail_ci_if_error: false + test-DIT: + name: ${{ matrix.version }} - DIT (${{ matrix.group }}) + runs-on: ubuntu-latest + if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }} + timeout-minutes: 60 + permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created + actions: write + contents: read + strategy: + fail-fast: true + matrix: + version: + - '1.10' + - '1.11' + - '1.12' + group: + - Formalities + - Zero + - Standard + - Weird + skip_lts: + - ${{ github.event.pull_request.draft }} + skip_pre: + - ${{ github.event.pull_request.draft }} + exclude: + - skip_lts: true + version: '1.10' + - skip_pre: true + version: '1.12' + env: + JULIA_DIT_TEST_GROUP: ${{ matrix.group }} + JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }} + steps: + - uses: actions/checkout@v5 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + arch: x64 + - uses: julia-actions/cache@v2 + - name: Install dependencies & run tests + run: julia --project=./DifferentiationInterfaceTest --color=yes -e ' + using Pkg; + Pkg.Registry.update(); + Pkg.develop(path="./DifferentiationInterface"); + if ENV["JULIA_DI_PR_DRAFT"] == "true"; + Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]); + else; + Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true); + end;' + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test + - uses: codecov/codecov-action@v5 + with: + files: lcov.info + flags: DIT + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false diff --git a/DifferentiationInterface/Project.toml b/DifferentiationInterface/Project.toml index 799808b1b..e014b3a67 100644 --- a/DifferentiationInterface/Project.toml +++ b/DifferentiationInterface/Project.toml @@ -41,7 +41,11 @@ DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"] DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore" DifferentiationInterfaceGTPSAExt = "GTPSA" DifferentiationInterfaceMooncakeExt = "Mooncake" -DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"] +DifferentiationInterfacePolyesterForwardDiffExt = [ + "PolyesterForwardDiff", + "ForwardDiff", + "DiffResults", +] DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"] DifferentiationInterfaceSparseArraysExt = "SparseArrays" DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer" @@ -65,7 +69,7 @@ ForwardDiff = "0.10.36,1" GPUArraysCore = "0.2" GTPSA = "1.4.0" LinearAlgebra = "1" -Mooncake = "0.4.147" +Mooncake = "0.4.175" PolyesterForwardDiff = "0.1.2" ReverseDiff = "1.15.1" SparseArrays = "1" diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl index 6b32142e3..ab9818735 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/onearg.jl @@ -17,10 +17,11 @@ function DI.prepare_pullback_nokwarg( ) y = f(x, map(DI.unwrap, contexts)...) dy_righttype = zero_tangent(y) + contexts_tup_false = map(_ -> false, contexts) args_to_zero = ( false, # f true, # x - map(_ -> false, contexts)..., + contexts_tup_false..., ) prep = MooncakeOneArgPullbackPrep(_sig, cache, dy_righttype, args_to_zero) return prep @@ -123,10 +124,11 @@ function DI.prepare_gradient_nokwarg( cache = prepare_gradient_cache( f, x, map(DI.unwrap, contexts)...; config.debug_mode, config.silence_debug_messages ) + contexts_tup_false = map(_ -> false, contexts) args_to_zero = ( false, # f true, # x - map(_ -> false, contexts)..., + contexts_tup_false..., ) prep = MooncakeGradientPrep(_sig, cache, args_to_zero) return prep diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl index 837c7f7a9..2ee11b5ae 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceMooncakeExt/twoarg.jl @@ -31,12 +31,13 @@ function DI.prepare_pullback_nokwarg( silence_debug_messages = config.silence_debug_messages, ) dy_righttype_after = zero_tangent(y) + contexts_tup_false = map(_ -> false, contexts) args_to_zero = ( false, # target_function false, # f! false, # y true, # x - map(_ -> false, contexts)..., + contexts_tup_false..., ) prep = MooncakeTwoArgPullbackPrep( _sig, cache, dy_righttype_after, target_function, args_to_zero diff --git a/DifferentiationInterface/test/Back/Mooncake/test.jl b/DifferentiationInterface/test/Back/Mooncake/test.jl index 163d94483..3b61e3547 100644 --- a/DifferentiationInterface/test/Back/Mooncake/test.jl +++ b/DifferentiationInterface/test/Back/Mooncake/test.jl @@ -1,5 +1,5 @@ using Pkg -Pkg.add(url = "https://github.com/gdalle/Mooncake.jl", rev = "selective_zeroing") +Pkg.add("Mooncake") using DifferentiationInterface, DifferentiationInterfaceTest using Mooncake: Mooncake