diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl index 6cf329ce9..87b3779d4 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl @@ -22,11 +22,15 @@ end function DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep) return DI.overloaded_input_type(prep.pushforward_prep) end -DI.overloaded_input_type(prep::ForwardDiffTwoArgDerivativePrep) = typeof(prep.config.duals) +function DI.overloaded_input_type( + prep::ForwardDiffTwoArgDerivativePrep{SIG, X, <:DerivativeConfig{T}} + ) where {SIG, X, T} + return typeof(Dual{T}(one(X), one(X))) +end ## Gradient DI.overloaded_input_type(prep::ForwardDiffGradientPrep) = typeof(prep.config.duals) ## Jacobian -DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) = typeof(prep.config.duals[2]) +DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) = typeof(prep.config.duals) DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep) = typeof(prep.config.duals[2]) diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index 6e7d53117..3930f768d 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -182,20 +182,21 @@ end ### Prepared -struct ForwardDiffTwoArgDerivativePrep{SIG, C, CD} <: DI.DerivativePrep{SIG} +struct ForwardDiffTwoArgDerivativePrep{SIG, X, C, CD} <: DI.DerivativePrep{SIG} _sig::Val{SIG} + x::X config::C contexts_dual::CD end function DI.prepare_derivative_nokwarg( - strict::Val, f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context, C} - ) where {F, C} + strict::Val, f!::F, y, backend::AutoForwardDiff, x::X, contexts::Vararg{DI.Context, C} + ) where {F, C, X} _sig = DI.signature(f!, y, backend, x, contexts...; strict) tag = get_tag(f!, backend, x) config = DerivativeConfig(nothing, y, x, tag) contexts_dual = translate_toprep(dual_type(config), contexts) - return ForwardDiffTwoArgDerivativePrep(_sig, config, contexts_dual) + return ForwardDiffTwoArgDerivativePrep(_sig, copy(x), config, contexts_dual) end function DI.prepare!_derivative( diff --git a/DifferentiationInterface/src/misc/overloading.jl b/DifferentiationInterface/src/misc/overloading.jl index 605791271..de0c2fa19 100644 --- a/DifferentiationInterface/src/misc/overloading.jl +++ b/DifferentiationInterface/src/misc/overloading.jl @@ -1,7 +1,7 @@ """ overloaded_input_type(prep) -If it exists, return the overloaded input type which will be passed to the differentiated function when preparation result `prep` is reused. +If it exists, return the overloaded input type (for the differentiated argument `x`) which will be passed to the differentiated function when preparation result `prep` is reused. !!! danger diff --git a/DifferentiationInterface/test/Back/ForwardDiff/test.jl b/DifferentiationInterface/test/Back/ForwardDiff/test.jl index df6050045..5260c9cea 100644 --- a/DifferentiationInterface/test/Back/ForwardDiff/test.jl +++ b/DifferentiationInterface/test/Back/ForwardDiff/test.jl @@ -104,7 +104,7 @@ end @test DI.overloaded_input_type(prepare_derivative(copy, backend, x)) == ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 1} @test DI.overloaded_input_type(prepare_derivative(copyto!, y, backend, x)) == - Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 1}} + ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 1} # Gradient x = [1.0, 1.0] @@ -114,12 +114,15 @@ end # Jacobian x = [1.0, 0.0, 0.0] @test DI.overloaded_input_type(prepare_jacobian(copy, backend, x)) == - ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 3} + Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 3}} @test DI.overloaded_input_type(prepare_jacobian(copyto!, similar(x), backend, x)) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 3}} @test DI.overloaded_input_type( prepare_jacobian(copyto!, similar(x), sparse_backend, x) ) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 1}} + # Jacobian with one-element input + @test DI.overloaded_input_type(prepare_jacobian(copy, backend, [1.0])) == + Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 1}} end; include("benchmark.jl")