A very common use case is wanting not only to differentiate an objective, but also to get some auxiliary output (intermediate results, the predictions of an ML model, data structures of a PDE solver, etc.).
For example, JAX has the `has_aux` keyword option in `jax.value_and_grad`, which is actually the most common usage pattern of AD in JAX that I have seen. The pattern looks like this (see e.g. the flax docs for a full example in context):
```python
import jax

def loss_fn(params):
    ...
    # Only the first output is differentiated; the second is passed through.
    return loss, extra_data

grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, extra_data), grads = grad_fn(params)
```
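To make the pattern concrete, here is a minimal runnable sketch. The linear model, the toy data, and the names `x`, `y`, and `preds` are made up for illustration and are not taken from the flax docs; the only JAX feature it relies on is `has_aux=True`, which tells `jax.value_and_grad` to differentiate the first output and return the second one unchanged.

```python
import jax
import jax.numpy as jnp

# Toy regression data (illustrative values only).
x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])

def loss_fn(params):
    # The predictions are the auxiliary output we want to keep around.
    preds = x @ params["w"] + params["b"]
    loss = jnp.mean((preds - y) ** 2)
    return loss, preds

params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}

# has_aux=True: loss_fn returns (value, aux); only value is differentiated.
grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
(loss, preds), grads = grad_fn(params)
```

Here `grads` has the same structure as `params` (a dict with keys `"w"` and `"b"`), while `preds` comes back from the same forward pass without having to re-run the model.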
I typically use some hacky workarounds to achieve similar behavior in Julia, but maybe it is common enough to solve it at the interface level?