@@ -313,7 +313,7 @@ function DI.prepare_gradient(
313313) where {F,C}
314314 fc = DI. with_contexts (f, contexts... )
315315 chunk = choose_chunk (backend, x)
316- tag = get_tag (fc , backend, x)
316+ tag = get_tag (f , backend, x)
317317 config = GradientConfig (fc, x, chunk, tag)
318318 return ForwardDiffGradientPrep (config)
319319end
@@ -329,7 +329,10 @@ function DI.value_and_gradient!(
329329 fc = DI. with_contexts (f, contexts... )
330330 result = DiffResult (zero (eltype (x)), (grad,))
331331 CHK = tag_type (backend) === Nothing
332- result = gradient! (result, fc, x, prep. config, Val (CHK))
332+ if CHK
333+ checktag (prep. config, f, x)
334+ end
335+ result = gradient! (result, fc, x, prep. config, Val (false ))
333336 y = DR. value (result)
334337 grad === DR. gradient (result) || copyto! (grad, DR. gradient (result))
335338 return y, grad
@@ -345,7 +348,10 @@ function DI.value_and_gradient(
345348 fc = DI. with_contexts (f, contexts... )
346349 result = GradientResult (x)
347350 CHK = tag_type (backend) === Nothing
348- result = gradient! (result, fc, x, prep. config, Val (CHK))
351+ if CHK
352+ checktag (prep. config, f, x)
353+ end
354+ result = gradient! (result, fc, x, prep. config, Val (false ))
349355 return DR. value (result), DR. gradient (result)
350356end
351357
@@ -359,7 +365,10 @@ function DI.gradient!(
359365) where {F,C}
360366 fc = DI. with_contexts (f, contexts... )
361367 CHK = tag_type (backend) === Nothing
362- return gradient! (grad, fc, x, prep. config, Val (CHK))
368+ if CHK
369+ checktag (prep. config, f, x)
370+ end
371+ return gradient! (grad, fc, x, prep. config, Val (false ))
363372end
364373
365374function DI. gradient (
@@ -371,7 +380,10 @@ function DI.gradient(
371380) where {F,C}
372381 fc = DI. with_contexts (f, contexts... )
373382 CHK = tag_type (backend) === Nothing
374- return gradient (fc, x, prep. config, Val (CHK))
383+ if CHK
384+ checktag (prep. config, f, x)
385+ end
386+ return gradient (fc, x, prep. config, Val (false ))
375387end
376388
377389# # Jacobian
@@ -456,7 +468,7 @@ function DI.prepare_jacobian(
456468) where {F,C}
457469 fc = DI. with_contexts (f, contexts... )
458470 chunk = choose_chunk (backend, x)
459- tag = get_tag (fc , backend, x)
471+ tag = get_tag (f , backend, x)
460472 config = JacobianConfig (fc, x, chunk, tag)
461473 return ForwardDiffOneArgJacobianPrep (config)
462474end
@@ -473,7 +485,10 @@ function DI.value_and_jacobian!(
473485 y = fc (x)
474486 result = DiffResult (y, (jac,))
475487 CHK = tag_type (backend) === Nothing
476- result = jacobian! (result, fc, x, prep. config, Val (CHK))
488+ if CHK
489+ checktag (prep. config, f, x)
490+ end
491+ result = jacobian! (result, fc, x, prep. config, Val (false ))
477492 y = DR. value (result)
478493 jac === DR. jacobian (result) || copyto! (jac, DR. jacobian (result))
479494 return y, jac
@@ -488,7 +503,10 @@ function DI.value_and_jacobian(
488503) where {F,C}
489504 fc = DI. with_contexts (f, contexts... )
490505 CHK = tag_type (backend) === Nothing
491- return fc (x), jacobian (fc, x, prep. config, Val (CHK))
506+ if CHK
507+ checktag (prep. config, f, x)
508+ end
509+ return fc (x), jacobian (fc, x, prep. config, Val (false ))
492510end
493511
494512function DI. jacobian! (
@@ -501,7 +519,10 @@ function DI.jacobian!(
501519) where {F,C}
502520 fc = DI. with_contexts (f, contexts... )
503521 CHK = tag_type (backend) === Nothing
504- return jacobian! (jac, fc, x, prep. config, Val (CHK))
522+ if CHK
523+ checktag (prep. config, f, x)
524+ end
525+ return jacobian! (jac, fc, x, prep. config, Val (false ))
505526end
506527
507528function DI. jacobian (
@@ -513,7 +534,10 @@ function DI.jacobian(
513534) where {F,C}
514535 fc = DI. with_contexts (f, contexts... )
515536 CHK = tag_type (backend) === Nothing
516- return jacobian (fc, x, prep. config, Val (CHK))
537+ if CHK
538+ checktag (prep. config, f, x)
539+ end
540+ return jacobian (fc, x, prep. config, Val (false ))
517541end
518542
519543# # Second derivative
@@ -738,7 +762,7 @@ function DI.prepare_hessian(
738762) where {F,C}
739763 fc = DI. with_contexts (f, contexts... )
740764 chunk = choose_chunk (backend, x)
741- tag = get_tag (fc , backend, x)
765+ tag = get_tag (f , backend, x)
742766 result = HessianResult (x)
743767 array_config = HessianConfig (fc, x, chunk, tag)
744768 result_config = HessianConfig (fc, result, x, chunk, tag)
@@ -755,7 +779,10 @@ function DI.hessian!(
755779) where {F,C}
756780 fc = DI. with_contexts (f, contexts... )
757781 CHK = tag_type (backend) === Nothing
758- return hessian! (hess, fc, x, prep. array_config, Val (CHK))
782+ if CHK
783+ checktag (prep. array_config, f, x)
784+ end
785+ return hessian! (hess, fc, x, prep. array_config, Val (false ))
759786end
760787
761788function DI. hessian (
@@ -767,7 +794,10 @@ function DI.hessian(
767794) where {F,C}
768795 fc = DI. with_contexts (f, contexts... )
769796 CHK = tag_type (backend) === Nothing
770- return hessian (fc, x, prep. array_config, Val (CHK))
797+ if CHK
798+ checktag (prep. array_config, f, x)
799+ end
800+ return hessian (fc, x, prep. array_config, Val (false ))
771801end
772802
773803function DI. value_gradient_and_hessian! (
@@ -782,7 +812,10 @@ function DI.value_gradient_and_hessian!(
782812 fc = DI. with_contexts (f, contexts... )
783813 result = DiffResult (one (eltype (x)), (grad, hess))
784814 CHK = tag_type (backend) === Nothing
785- result = hessian! (result, fc, x, prep. result_config, Val (CHK))
815+ if CHK
816+ checktag (prep. result_config, f, x)
817+ end
818+ result = hessian! (result, fc, x, prep. result_config, Val (false ))
786819 y = DR. value (result)
787820 grad === DR. gradient (result) || copyto! (grad, DR. gradient (result))
788821 hess === DR. hessian (result) || copyto! (hess, DR. hessian (result))
@@ -799,6 +832,9 @@ function DI.value_gradient_and_hessian(
799832 fc = DI. with_contexts (f, contexts... )
800833 result = HessianResult (x)
801834 CHK = tag_type (backend) === Nothing
802- result = hessian! (result, fc, x, prep. result_config, Val (CHK))
835+ if CHK
836+ checktag (prep. result_config, f, x)
837+ end
838+ result = hessian! (result, fc, x, prep. result_config, Val (false ))
803839 return (DR. value (result), DR. gradient (result), DR. hessian (result))
804840end
0 commit comments