diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl index d4f8570b0..47165a5d7 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl @@ -13,6 +13,7 @@ using ForwardDiff: HessianConfig, JacobianConfig, Tag, + checktag, derivative, derivative!, extract_derivative, diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl index f50f030a7..292b9b787 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl @@ -313,7 +313,7 @@ function DI.prepare_gradient( ) where {F,C} fc = DI.with_contexts(f, contexts...) chunk = choose_chunk(backend, x) - tag = get_tag(fc, backend, x) + tag = get_tag(f, backend, x) config = GradientConfig(fc, x, chunk, tag) return ForwardDiffGradientPrep(config) end @@ -329,7 +329,10 @@ function DI.value_and_gradient!( fc = DI.with_contexts(f, contexts...) result = DiffResult(zero(eltype(x)), (grad,)) CHK = tag_type(backend) === Nothing - result = gradient!(result, fc, x, prep.config, Val(CHK)) + if CHK + checktag(prep.config, f, x) + end + result = gradient!(result, fc, x, prep.config, Val(false)) y = DR.value(result) grad === DR.gradient(result) || copyto!(grad, DR.gradient(result)) return y, grad @@ -345,7 +348,10 @@ function DI.value_and_gradient( fc = DI.with_contexts(f, contexts...) result = GradientResult(x) CHK = tag_type(backend) === Nothing - result = gradient!(result, fc, x, prep.config, Val(CHK)) + if CHK + checktag(prep.config, f, x) + end + result = gradient!(result, fc, x, prep.config, Val(false)) return DR.value(result), DR.gradient(result) end @@ -359,7 +365,10 @@ function DI.gradient!( ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing - return gradient!(grad, fc, x, prep.config, Val(CHK)) + if CHK + checktag(prep.config, f, x) + end + return gradient!(grad, fc, x, prep.config, Val(false)) end function DI.gradient( @@ -371,7 +380,10 @@ function DI.gradient( ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing - return gradient(fc, x, prep.config, Val(CHK)) + if CHK + checktag(prep.config, f, x) + end + return gradient(fc, x, prep.config, Val(false)) end ## Jacobian @@ -456,7 +468,7 @@ function DI.prepare_jacobian( ) where {F,C} fc = DI.with_contexts(f, contexts...) chunk = choose_chunk(backend, x) - tag = get_tag(fc, backend, x) + tag = get_tag(f, backend, x) config = JacobianConfig(fc, x, chunk, tag) return ForwardDiffOneArgJacobianPrep(config) end @@ -473,7 +485,10 @@ function DI.value_and_jacobian!( y = fc(x) result = DiffResult(y, (jac,)) CHK = tag_type(backend) === Nothing - result = jacobian!(result, fc, x, prep.config, Val(CHK)) + if CHK + checktag(prep.config, f, x) + end + result = jacobian!(result, fc, x, prep.config, Val(false)) y = DR.value(result) jac === DR.jacobian(result) || copyto!(jac, DR.jacobian(result)) return y, jac @@ -488,7 +503,10 @@ function DI.value_and_jacobian( ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing - return fc(x), jacobian(fc, x, prep.config, Val(CHK)) + if CHK + checktag(prep.config, f, x) + end + return fc(x), jacobian(fc, x, prep.config, Val(false)) end function DI.jacobian!( @@ -501,7 +519,10 @@ function DI.jacobian!( ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing - return jacobian!(jac, fc, x, prep.config, Val(CHK)) + if CHK + checktag(prep.config, f, x) + end + return jacobian!(jac, fc, x, prep.config, Val(false)) end function DI.jacobian( @@ -513,7 +534,10 @@ function DI.jacobian( ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing - return jacobian(fc, x, prep.config, Val(CHK)) + if CHK + checktag(prep.config, f, x) + end + return jacobian(fc, x, prep.config, Val(false)) end ## Second derivative @@ -738,7 +762,7 @@ function DI.prepare_hessian( ) where {F,C} fc = DI.with_contexts(f, contexts...) chunk = choose_chunk(backend, x) - tag = get_tag(fc, backend, x) + tag = get_tag(f, backend, x) result = HessianResult(x) array_config = HessianConfig(fc, x, chunk, tag) result_config = HessianConfig(fc, result, x, chunk, tag) @@ -755,7 +779,10 @@ function DI.hessian!( ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing - return hessian!(hess, fc, x, prep.array_config, Val(CHK)) + if CHK + checktag(prep.array_config, f, x) + end + return hessian!(hess, fc, x, prep.array_config, Val(false)) end function DI.hessian( @@ -767,7 +794,10 @@ function DI.hessian( ) where {F,C} fc = DI.with_contexts(f, contexts...) CHK = tag_type(backend) === Nothing - return hessian(fc, x, prep.array_config, Val(CHK)) + if CHK + checktag(prep.array_config, f, x) + end + return hessian(fc, x, prep.array_config, Val(false)) end function DI.value_gradient_and_hessian!( @@ -782,7 +812,10 @@ function DI.value_gradient_and_hessian!( fc = DI.with_contexts(f, contexts...) result = DiffResult(one(eltype(x)), (grad, hess)) CHK = tag_type(backend) === Nothing - result = hessian!(result, fc, x, prep.result_config, Val(CHK)) + if CHK + checktag(prep.result_config, f, x) + end + result = hessian!(result, fc, x, prep.result_config, Val(false)) y = DR.value(result) grad === DR.gradient(result) || copyto!(grad, DR.gradient(result)) hess === DR.hessian(result) || copyto!(hess, DR.hessian(result)) @@ -799,6 +832,9 @@ function DI.value_gradient_and_hessian( fc = DI.with_contexts(f, contexts...) result = HessianResult(x) CHK = tag_type(backend) === Nothing - result = hessian!(result, fc, x, prep.result_config, Val(CHK)) + if CHK + checktag(prep.result_config, f, x) + end + result = hessian!(result, fc, x, prep.result_config, Val(false)) return (DR.value(result), DR.gradient(result), DR.hessian(result)) end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl index 7289c5379..291f96887 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/twoarg.jl @@ -194,7 +194,7 @@ function DI.prepare_derivative( contexts::Vararg{DI.ConstantOrFunctionOrBackend,C}, ) where {F,C} fc! = DI.with_contexts(f!, contexts...) - tag = get_tag(fc!, backend, x) + tag = get_tag(f!, backend, x) config = DerivativeConfig(fc!, y, x, tag) return ForwardDiffTwoArgDerivativePrep(config) end @@ -227,7 +227,10 @@ function DI.value_and_derivative( fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (similar(y),)) CHK = tag_type(backend) === Nothing - result = derivative!(result, fc!, y, x, prep.config, Val(CHK)) + if CHK + checktag(prep.config, f!, x) + end + result = derivative!(result, fc!, y, x, prep.config, Val(false)) return DiffResults.value(result), DiffResults.derivative(result) end @@ -243,7 +246,10 @@ function DI.value_and_derivative!( fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (der,)) CHK = tag_type(backend) === Nothing - result = derivative!(result, fc!, y, x, prep.config, Val(CHK)) + if CHK + checktag(prep.config, f!, x) + end + result = derivative!(result, fc!, y, x, prep.config, Val(false)) return DiffResults.value(result), DiffResults.derivative(result) end @@ -257,7 +263,10 @@ function DI.derivative( ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing - return derivative(fc!, y, x, prep.config, Val(CHK)) + if CHK + checktag(prep.config, f!, x) + end + return derivative(fc!, y, x, prep.config, Val(false)) end function DI.derivative!( @@ -271,7 +280,10 @@ function DI.derivative!( ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing - return derivative!(der, fc!, y, x, prep.config, Val(CHK)) + if CHK + checktag(prep.config, f!, x) + end + return derivative!(der, fc!, y, x, prep.config, Val(false)) end ## Jacobian @@ -364,7 +376,7 @@ function DI.prepare_jacobian( ) where {F,C} fc! = DI.with_contexts(f!, contexts...) chunk = choose_chunk(backend, x) - tag = get_tag(fc!, backend, x) + tag = get_tag(f!, backend, x) config = JacobianConfig(fc!, y, x, chunk, tag) return ForwardDiffTwoArgJacobianPrep(config) end @@ -400,7 +412,10 @@ function DI.value_and_jacobian( jac = similar(y, length(y), length(x)) result = MutableDiffResult(y, (jac,)) CHK = tag_type(backend) === Nothing - result = jacobian!(result, fc!, y, x, prep.config, Val(CHK)) + if CHK + checktag(prep.config, f!, x) + end + result = jacobian!(result, fc!, y, x, prep.config, Val(false)) return DiffResults.value(result), DiffResults.jacobian(result) end @@ -416,7 +431,10 @@ function DI.value_and_jacobian!( fc! = DI.with_contexts(f!, contexts...) result = MutableDiffResult(y, (jac,)) CHK = tag_type(backend) === Nothing - result = jacobian!(result, fc!, y, x, prep.config, Val(CHK)) + if CHK + checktag(prep.config, f!, x) + end + result = jacobian!(result, fc!, y, x, prep.config, Val(false)) return DiffResults.value(result), DiffResults.jacobian(result) end @@ -430,7 +448,10 @@ function DI.jacobian( ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing - return jacobian(fc!, y, x, prep.config, Val(CHK)) + if CHK + checktag(prep.config, f!, x) + end + return jacobian(fc!, y, x, prep.config, Val(false)) end function DI.jacobian!( @@ -444,5 +465,8 @@ function DI.jacobian!( ) where {F,C} fc! = DI.with_contexts(f!, contexts...) CHK = tag_type(backend) === Nothing - return jacobian!(jac, fc!, y, x, prep.config, Val(CHK)) + if CHK + checktag(prep.config, f!, x) + end + return jacobian!(jac, fc!, y, x, prep.config, Val(false)) end