Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ using ForwardDiff:
HessianConfig,
JacobianConfig,
Tag,
checktag,
derivative,
derivative!,
extract_derivative,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ function DI.prepare_gradient(
) where {F,C}
fc = DI.with_contexts(f, contexts...)
chunk = choose_chunk(backend, x)
tag = get_tag(fc, backend, x)
tag = get_tag(f, backend, x)
config = GradientConfig(fc, x, chunk, tag)
return ForwardDiffGradientPrep(config)
end
Expand All @@ -329,7 +329,10 @@ function DI.value_and_gradient!(
fc = DI.with_contexts(f, contexts...)
result = DiffResult(zero(eltype(x)), (grad,))
CHK = tag_type(backend) === Nothing
result = gradient!(result, fc, x, prep.config, Val(CHK))
if CHK
checktag(prep.config, f, x)
end
result = gradient!(result, fc, x, prep.config, Val(false))
y = DR.value(result)
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
return y, grad
Expand All @@ -345,7 +348,10 @@ function DI.value_and_gradient(
fc = DI.with_contexts(f, contexts...)
result = GradientResult(x)
CHK = tag_type(backend) === Nothing
result = gradient!(result, fc, x, prep.config, Val(CHK))
if CHK
checktag(prep.config, f, x)
end
result = gradient!(result, fc, x, prep.config, Val(false))
return DR.value(result), DR.gradient(result)
end

Expand All @@ -359,7 +365,10 @@ function DI.gradient!(
) where {F,C}
fc = DI.with_contexts(f, contexts...)
CHK = tag_type(backend) === Nothing
return gradient!(grad, fc, x, prep.config, Val(CHK))
if CHK
checktag(prep.config, f, x)
end
return gradient!(grad, fc, x, prep.config, Val(false))
end

function DI.gradient(
Expand All @@ -371,7 +380,10 @@ function DI.gradient(
) where {F,C}
fc = DI.with_contexts(f, contexts...)
CHK = tag_type(backend) === Nothing
return gradient(fc, x, prep.config, Val(CHK))
if CHK
checktag(prep.config, f, x)
end
return gradient(fc, x, prep.config, Val(false))
end

## Jacobian
Expand Down Expand Up @@ -456,7 +468,7 @@ function DI.prepare_jacobian(
) where {F,C}
fc = DI.with_contexts(f, contexts...)
chunk = choose_chunk(backend, x)
tag = get_tag(fc, backend, x)
tag = get_tag(f, backend, x)
config = JacobianConfig(fc, x, chunk, tag)
return ForwardDiffOneArgJacobianPrep(config)
end
Expand All @@ -473,7 +485,10 @@ function DI.value_and_jacobian!(
y = fc(x)
result = DiffResult(y, (jac,))
CHK = tag_type(backend) === Nothing
result = jacobian!(result, fc, x, prep.config, Val(CHK))
if CHK
checktag(prep.config, f, x)
end
result = jacobian!(result, fc, x, prep.config, Val(false))
y = DR.value(result)
jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result))
return y, jac
Expand All @@ -488,7 +503,10 @@ function DI.value_and_jacobian(
) where {F,C}
fc = DI.with_contexts(f, contexts...)
CHK = tag_type(backend) === Nothing
return fc(x), jacobian(fc, x, prep.config, Val(CHK))
if CHK
checktag(prep.config, f, x)
end
return fc(x), jacobian(fc, x, prep.config, Val(false))
end

function DI.jacobian!(
Expand All @@ -501,7 +519,10 @@ function DI.jacobian!(
) where {F,C}
fc = DI.with_contexts(f, contexts...)
CHK = tag_type(backend) === Nothing
return jacobian!(jac, fc, x, prep.config, Val(CHK))
if CHK
checktag(prep.config, f, x)
end
return jacobian!(jac, fc, x, prep.config, Val(false))
end

function DI.jacobian(
Expand All @@ -513,7 +534,10 @@ function DI.jacobian(
) where {F,C}
fc = DI.with_contexts(f, contexts...)
CHK = tag_type(backend) === Nothing
return jacobian(fc, x, prep.config, Val(CHK))
if CHK
checktag(prep.config, f, x)
end
return jacobian(fc, x, prep.config, Val(false))
end

## Second derivative
Expand Down Expand Up @@ -738,7 +762,7 @@ function DI.prepare_hessian(
) where {F,C}
fc = DI.with_contexts(f, contexts...)
chunk = choose_chunk(backend, x)
tag = get_tag(fc, backend, x)
tag = get_tag(f, backend, x)
result = HessianResult(x)
array_config = HessianConfig(fc, x, chunk, tag)
result_config = HessianConfig(fc, result, x, chunk, tag)
Expand All @@ -755,7 +779,10 @@ function DI.hessian!(
) where {F,C}
fc = DI.with_contexts(f, contexts...)
CHK = tag_type(backend) === Nothing
return hessian!(hess, fc, x, prep.array_config, Val(CHK))
if CHK
checktag(prep.array_config, f, x)
end
return hessian!(hess, fc, x, prep.array_config, Val(false))
end

function DI.hessian(
Expand All @@ -767,7 +794,10 @@ function DI.hessian(
) where {F,C}
fc = DI.with_contexts(f, contexts...)
CHK = tag_type(backend) === Nothing
return hessian(fc, x, prep.array_config, Val(CHK))
if CHK
checktag(prep.array_config, f, x)
end
return hessian(fc, x, prep.array_config, Val(false))
end

function DI.value_gradient_and_hessian!(
Expand All @@ -782,7 +812,10 @@ function DI.value_gradient_and_hessian!(
fc = DI.with_contexts(f, contexts...)
result = DiffResult(one(eltype(x)), (grad, hess))
CHK = tag_type(backend) === Nothing
result = hessian!(result, fc, x, prep.result_config, Val(CHK))
if CHK
checktag(prep.result_config, f, x)
end
result = hessian!(result, fc, x, prep.result_config, Val(false))
y = DR.value(result)
grad === DR.gradient(result) || copyto!(grad, DR.gradient(result))
hess === DR.hessian(result) || copyto!(hess, DR.hessian(result))
Expand All @@ -799,6 +832,9 @@ function DI.value_gradient_and_hessian(
fc = DI.with_contexts(f, contexts...)
result = HessianResult(x)
CHK = tag_type(backend) === Nothing
result = hessian!(result, fc, x, prep.result_config, Val(CHK))
if CHK
checktag(prep.result_config, f, x)
end
result = hessian!(result, fc, x, prep.result_config, Val(false))
return (DR.value(result), DR.gradient(result), DR.hessian(result))
end
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@
contexts::Vararg{DI.ConstantOrFunctionOrBackend,C},
) where {F,C}
fc! = DI.with_contexts(f!, contexts...)
tag = get_tag(fc!, backend, x)
tag = get_tag(f!, backend, x)
config = DerivativeConfig(fc!, y, x, tag)
return ForwardDiffTwoArgDerivativePrep(config)
end
Expand Down Expand Up @@ -227,7 +227,10 @@
fc! = DI.with_contexts(f!, contexts...)
result = MutableDiffResult(y, (similar(y),))
CHK = tag_type(backend) === Nothing
result = derivative!(result, fc!, y, x, prep.config, Val(CHK))
if CHK
checktag(prep.config, f!, x)
end
result = derivative!(result, fc!, y, x, prep.config, Val(false))
return DiffResults.value(result), DiffResults.derivative(result)
end

Expand All @@ -243,7 +246,10 @@
fc! = DI.with_contexts(f!, contexts...)
result = MutableDiffResult(y, (der,))
CHK = tag_type(backend) === Nothing
result = derivative!(result, fc!, y, x, prep.config, Val(CHK))
if CHK
checktag(prep.config, f!, x)
end
result = derivative!(result, fc!, y, x, prep.config, Val(false))
return DiffResults.value(result), DiffResults.derivative(result)
end

Expand All @@ -257,7 +263,10 @@
) where {F,C}
fc! = DI.with_contexts(f!, contexts...)
CHK = tag_type(backend) === Nothing
return derivative(fc!, y, x, prep.config, Val(CHK))
if CHK
checktag(prep.config, f!, x)
end
return derivative(fc!, y, x, prep.config, Val(false))
end

function DI.derivative!(
Expand All @@ -271,7 +280,10 @@
) where {F,C}
fc! = DI.with_contexts(f!, contexts...)
CHK = tag_type(backend) === Nothing
return derivative!(der, fc!, y, x, prep.config, Val(CHK))
if CHK
checktag(prep.config, f!, x)
end
return derivative!(der, fc!, y, x, prep.config, Val(false))
end

## Jacobian
Expand Down Expand Up @@ -364,7 +376,7 @@
) where {F,C}
fc! = DI.with_contexts(f!, contexts...)
chunk = choose_chunk(backend, x)
tag = get_tag(fc!, backend, x)
tag = get_tag(f!, backend, x)

Check warning on line 379 in DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl#L379

Added line #L379 was not covered by tests
config = JacobianConfig(fc!, y, x, chunk, tag)
return ForwardDiffTwoArgJacobianPrep(config)
end
Expand Down Expand Up @@ -400,7 +412,10 @@
jac = similar(y, length(y), length(x))
result = MutableDiffResult(y, (jac,))
CHK = tag_type(backend) === Nothing
result = jacobian!(result, fc!, y, x, prep.config, Val(CHK))
if CHK
checktag(prep.config, f!, x)

Check warning on line 416 in DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl#L415-L416

Added lines #L415 - L416 were not covered by tests
end
result = jacobian!(result, fc!, y, x, prep.config, Val(false))

Check warning on line 418 in DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl#L418

Added line #L418 was not covered by tests
return DiffResults.value(result), DiffResults.jacobian(result)
end

Expand All @@ -416,7 +431,10 @@
fc! = DI.with_contexts(f!, contexts...)
result = MutableDiffResult(y, (jac,))
CHK = tag_type(backend) === Nothing
result = jacobian!(result, fc!, y, x, prep.config, Val(CHK))
if CHK
checktag(prep.config, f!, x)

Check warning on line 435 in DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl#L434-L435

Added lines #L434 - L435 were not covered by tests
end
result = jacobian!(result, fc!, y, x, prep.config, Val(false))

Check warning on line 437 in DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl#L437

Added line #L437 was not covered by tests
return DiffResults.value(result), DiffResults.jacobian(result)
end

Expand All @@ -430,7 +448,10 @@
) where {F,C}
fc! = DI.with_contexts(f!, contexts...)
CHK = tag_type(backend) === Nothing
return jacobian(fc!, y, x, prep.config, Val(CHK))
if CHK
checktag(prep.config, f!, x)

Check warning on line 452 in DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl#L451-L452

Added lines #L451 - L452 were not covered by tests
end
return jacobian(fc!, y, x, prep.config, Val(false))

Check warning on line 454 in DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl#L454

Added line #L454 was not covered by tests
end

function DI.jacobian!(
Expand All @@ -444,5 +465,8 @@
) where {F,C}
fc! = DI.with_contexts(f!, contexts...)
CHK = tag_type(backend) === Nothing
return jacobian!(jac, fc!, y, x, prep.config, Val(CHK))
if CHK
checktag(prep.config, f!, x)

Check warning on line 469 in DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl#L468-L469

Added lines #L468 - L469 were not covered by tests
end
return jacobian!(jac, fc!, y, x, prep.config, Val(false))

Check warning on line 471 in DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl#L471

Added line #L471 was not covered by tests
end
Loading