From 1d136715ba55e500d3bd97ac63615533452cf14b Mon Sep 17 00:00:00 2001 From: Qingyu Qu <2283984853@qq.com> Date: Thu, 22 Jan 2026 18:10:57 +0800 Subject: [PATCH 1/2] fix: overloaded_input_type for one-element vector input --- .../ext/DifferentiationInterfaceForwardDiffExt/misc.jl | 4 ++-- DifferentiationInterface/test/Back/ForwardDiff/test.jl | 3 +++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl index 6cf329ce9..250253002 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl @@ -28,5 +28,5 @@ DI.overloaded_input_type(prep::ForwardDiffTwoArgDerivativePrep) = typeof(prep.co DI.overloaded_input_type(prep::ForwardDiffGradientPrep) = typeof(prep.config.duals) ## Jacobian -DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) = typeof(prep.config.duals[2]) -DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep) = typeof(prep.config.duals[2]) +DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) = typeof(last(prep.config.duals)) +DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep) = typeof(last(prep.config.duals)) diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index df6050045..1cc7bac46 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -120,6 +120,9 @@ end @test DI.overloaded_input_type( prepare_jacobian(copyto!, similar(x), sparse_backend, x) ) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 1}} + # Jacobian with one-element input + @test DI.overloaded_input_type(prepare_jacobian(copy, backend, [1.0])) == + ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 1} end; include("benchmark.jl") From 44d8e954802196b9f2e4095fb91f666e0726e672 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 23 Jan 2026 08:55:59 +0100 Subject: [PATCH 2/2] Fix bugs in overloaded input type --- .../ext/DifferentiationInterfaceForwardDiffExt/misc.jl | 10 +++++++--- .../DifferentiationInterfaceForwardDiffExt/twoarg.jl | 9 +++++---- DifferentiationInterface/src/misc/overloading.jl | 2 +- DifferentiationInterface/test/Back/ForwardDiff/test.jl | 6 +++--- 4 files changed, 16 insertions(+), 11 deletions(-) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl index 250253002..87b3779d4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl @@ -22,11 +22,15 @@ end function DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep) return DI.overloaded_input_type(prep.pushforward_prep) end -DI.overloaded_input_type(prep::ForwardDiffTwoArgDerivativePrep) = typeof(prep.config.duals) +function DI.overloaded_input_type( + prep::ForwardDiffTwoArgDerivativePrep{SIG, X, <:DerivativeConfig{T}} + ) where {SIG, X, T} + return typeof(Dual{T}(one(X), one(X))) +end ## Gradient DI.overloaded_input_type(prep::ForwardDiffGradientPrep) = typeof(prep.config.duals) ## Jacobian -DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) = typeof(last(prep.config.duals)) -DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep) = typeof(last(prep.config.duals)) +DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) = typeof(prep.config.duals) +DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep) = typeof(prep.config.duals[2]) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index 6e7d53117..3930f768d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -182,20 +182,21 @@ end ### Prepared -struct ForwardDiffTwoArgDerivativePrep{SIG, C, CD} <: DI.DerivativePrep{SIG} +struct ForwardDiffTwoArgDerivativePrep{SIG, X, C, CD} <: DI.DerivativePrep{SIG} _sig::Val{SIG} + x::X config::C contexts_dual::CD end function DI.prepare_derivative_nokwarg( - strict::Val, f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context, C} - ) where {F, C} + strict::Val, f!::F, y, backend::AutoForwardDiff, x::X, contexts::Vararg{DI.Context, C} + ) where {F, C, X} _sig = DI.signature(f!, y, backend, x, contexts...; strict) tag = get_tag(f!, backend, x) config = DerivativeConfig(nothing, y, x, tag) contexts_dual = translate_toprep(dual_type(config), contexts) - return ForwardDiffTwoArgDerivativePrep(_sig, config, contexts_dual) + return ForwardDiffTwoArgDerivativePrep(_sig, copy(x), config, contexts_dual) end function DI.prepare!_derivative( diff --git a/DifferentiationInterface/src/misc/overloading.jl b/DifferentiationInterface/src/misc/overloading.jl index 605791271..de0c2fa19 100644 --- a/DifferentiationInterface/src/misc/overloading.jl +++ b/DifferentiationInterface/src/misc/overloading.jl @@ -1,7 +1,7 @@ """ overloaded_input_type(prep) -If it exists, return the overloaded input type which will be passed to the differentiated function when preparation result `prep` is reused. +If it exists, return the overloaded input type (for the differentiated argument `x`) which will be passed to the differentiated function when preparation result `prep` is reused. !!! danger diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index 1cc7bac46..5260c9cea 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -104,7 +104,7 @@ end @test DI.overloaded_input_type(prepare_derivative(copy, backend, x)) == ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 1} @test DI.overloaded_input_type(prepare_derivative(copyto!, y, backend, x)) == - Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 1}} + ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 1} # Gradient x = [1.0, 1.0] @@ -114,7 +114,7 @@ end # Jacobian x = [1.0, 0.0, 0.0] @test DI.overloaded_input_type(prepare_jacobian(copy, backend, x)) == - ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 3} + Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 3}} @test DI.overloaded_input_type(prepare_jacobian(copyto!, similar(x), backend, x)) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 3}} @test DI.overloaded_input_type( @@ -122,7 +122,7 @@ end ) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 1}} # Jacobian with one-element input @test DI.overloaded_input_type(prepare_jacobian(copy, backend, [1.0])) == - ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 1} + Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 1}} end; include("benchmark.jl")