Skip to content

Commit 1d13671

Browse files
committed
fix: overloaded_input_type for one-element vector input
1 parent cd515b3 commit 1d13671

2 files changed

Lines changed: 5 additions & 2 deletions

File tree

  • DifferentiationInterface
    • ext/DifferentiationInterfaceForwardDiffExt
    • test/Back/ForwardDiff

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/misc.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,5 @@ DI.overloaded_input_type(prep::ForwardDiffTwoArgDerivativePrep) = typeof(prep.co
2828
DI.overloaded_input_type(prep::ForwardDiffGradientPrep) = typeof(prep.config.duals)
2929

3030
## Jacobian
31-
DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) = typeof(prep.config.duals[2])
32-
DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep) = typeof(prep.config.duals[2])
31+
DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) = typeof(last(prep.config.duals))
32+
DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep) = typeof(last(prep.config.duals))

DifferentiationInterface/test/Back/ForwardDiff/test.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,9 @@ end
120120
@test DI.overloaded_input_type(
121121
prepare_jacobian(copyto!, similar(x), sparse_backend, x)
122122
) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 1}}
123+
# Jacobian with one-element input
124+
@test DI.overloaded_input_type(prepare_jacobian(copy, backend, [1.0])) ==
125+
ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 1}
123126
end;
124127

125128
include("benchmark.jl")

0 commit comments

Comments
 (0)