Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@ end
function DI.overloaded_input_type(prep::ForwardDiffOneArgDerivativePrep)
return DI.overloaded_input_type(prep.pushforward_prep)
end
DI.overloaded_input_type(prep::ForwardDiffTwoArgDerivativePrep) = typeof(prep.config.duals)
function DI.overloaded_input_type(
prep::ForwardDiffTwoArgDerivativePrep{SIG, X, <:DerivativeConfig{T}}
) where {SIG, X, T}
return typeof(Dual{T}(one(X), one(X)))
end

## Gradient
DI.overloaded_input_type(prep::ForwardDiffGradientPrep) = typeof(prep.config.duals)

## Jacobian
DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) = typeof(prep.config.duals[2])
DI.overloaded_input_type(prep::ForwardDiffOneArgJacobianPrep) = typeof(prep.config.duals)
DI.overloaded_input_type(prep::ForwardDiffTwoArgJacobianPrep) = typeof(prep.config.duals[2])
Original file line number Diff line number Diff line change
Expand Up @@ -182,20 +182,21 @@ end

### Prepared

struct ForwardDiffTwoArgDerivativePrep{SIG, C, CD} <: DI.DerivativePrep{SIG}
struct ForwardDiffTwoArgDerivativePrep{SIG, X, C, CD} <: DI.DerivativePrep{SIG}
_sig::Val{SIG}
x::X
config::C
contexts_dual::CD
end

function DI.prepare_derivative_nokwarg(
strict::Val, f!::F, y, backend::AutoForwardDiff, x, contexts::Vararg{DI.Context, C}
) where {F, C}
strict::Val, f!::F, y, backend::AutoForwardDiff, x::X, contexts::Vararg{DI.Context, C}
) where {F, C, X}
_sig = DI.signature(f!, y, backend, x, contexts...; strict)
tag = get_tag(f!, backend, x)
config = DerivativeConfig(nothing, y, x, tag)
contexts_dual = translate_toprep(dual_type(config), contexts)
return ForwardDiffTwoArgDerivativePrep(_sig, config, contexts_dual)
return ForwardDiffTwoArgDerivativePrep(_sig, copy(x), config, contexts_dual)
end

function DI.prepare!_derivative(
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/src/misc/overloading.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
overloaded_input_type(prep)

If it exists, return the overloaded input type which will be passed to the differentiated function when preparation result `prep` is reused.
If it exists, return the overloaded input type (for the differentiated argument `x`) which will be passed to the differentiated function when preparation result `prep` is reused.

!!! danger

Expand Down
7 changes: 5 additions & 2 deletions DifferentiationInterface/test/Back/ForwardDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ end
@test DI.overloaded_input_type(prepare_derivative(copy, backend, x)) ==
ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 1}
@test DI.overloaded_input_type(prepare_derivative(copyto!, y, backend, x)) ==
Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 1}}
ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 1}

# Gradient
x = [1.0, 1.0]
Expand All @@ -114,12 +114,15 @@ end
# Jacobian
x = [1.0, 0.0, 0.0]
@test DI.overloaded_input_type(prepare_jacobian(copy, backend, x)) ==
ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 3}
Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 3}}
@test DI.overloaded_input_type(prepare_jacobian(copyto!, similar(x), backend, x)) ==
Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 3}}
@test DI.overloaded_input_type(
prepare_jacobian(copyto!, similar(x), sparse_backend, x)
) == Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copyto!), Float64}, Float64, 1}}
# Jacobian with one-element input
@test DI.overloaded_input_type(prepare_jacobian(copy, backend, [1.0])) ==
Vector{ForwardDiff.Dual{ForwardDiff.Tag{typeof(copy), Float64}, Float64, 1}}
end;

include("benchmark.jl")
Loading