This page recaps the mathematical model of automatic differentiation used by DI, which justifies how preparation results are constructed.
It is inspired by
Consider a mathematical function $f(x, c, s) = y$ where
- $x \in \mathcal{X}$ is the active argument (the one being differentiated)
- $c \in \mathcal{C}$ is a constant argument (corresponds to `Constant` contexts)
- $s \in \mathcal{S}$ is a scratch argument (corresponds to `Cache` contexts)
- $y \in \mathcal{Y}$ is the output
In Julia code, some of the input arguments might be mutated, while the output may be written to as well.
Therefore, the proper model is a function $\phi(x_0, c_0, s_0, y_0) = (x_1, c_1, s_1, y_1)$ where $a_0$ is the state of argument $a$ before $f$ is run, while $a_1$ is its state after $f$ is run.
DI makes the following hypotheses on the implementation of $f$ (aka the behavior of $\phi$):
- The active argument $x$ is not mutated, so $x_1 = x_0$
- The constant argument $c$ is not mutated, so $c_1 = c_0$
- The initial value of the scratch argument $s_0$ does not matter
- The initial value of the output $y_0$ does not matter
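To make the four hypotheses concrete, here is a minimal sketch of an in-place Julia function that satisfies them. The name `f!` and the formula $y = (c \odot x)^2$ are assumptions chosen for illustration; they do not come from DI.

```julia
# Hypothetical in-place function: y = (c .* x) .^ 2, with s as scratch space.
function f!(y, x, c, s)
    s .= c .* x      # s is overwritten before it is read (hypothesis 3)
    y .= s .^ 2      # y is overwritten and never read (hypothesis 4)
    return nothing   # x and c are only read (hypotheses 1 and 2)
end

x = [1.0, 2.0]
c = [3.0, 4.0]

# Two runs with different initial contents in y and s
y_a = fill(NaN, 2);   s_a = fill(NaN, 2)
y_b = [100.0, -7.0];  s_b = [0.5, 0.5]
f!(y_a, x, c, s_a)
f!(y_b, x, c, s_b)

y_a == y_b   # true: both are [9.0, 64.0]
```

Both runs give the same output whatever was stored in `y` and `s` beforehand, and `x` and `c` are left untouched.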
We want to compute a Jacobian-Vector Product (JVP) $\dot{y} = \left(\frac{\partial f}{\partial x}\right) \dot{x}$ where $\dot{x} \in \mathcal{X}$ is an input tangent.
To do that, we run our AD backend on $\phi$ with input tangents $(\dot{x}_0, \dot{c}_0, \dot{s}_0, \dot{y}_0)$ and obtain $(\dot{x}_1, \dot{c}_1, \dot{s}_1, \dot{y}_1)$.
The interesting value is
$$\dot{y}_1 = \frac{\partial y_1}{\partial x_0} \dot{x}_0 + \frac{\partial y_1}{\partial c_0} \dot{c}_0 + \frac{\partial y_1}{\partial s_0} \dot{s}_0 + \frac{\partial y_1}{\partial y_0} \dot{y}_0$$
Thanks to our hypotheses 3 and 4 on the function's implementation, $\frac{\partial y_1}{\partial s_0} = 0$ and $\frac{\partial y_1}{\partial y_0} = 0$, so we are left with:
$$\dot{y}_1 = \frac{\partial y_1}{\partial x_0} \dot{x}_0 + \frac{\partial y_1}{\partial c_0} \dot{c}_0$$
Thus, as long as $\dot{c}_0 = 0$, the output tangent $\dot{y}_1$ contains the correct JVP.
Let us now look at $\dot{c}_1$ with the help of hypothesis 2:
$$\dot{c}_1 = \frac{\partial c_1}{\partial x_0} \dot{x}_0 + \frac{\partial c_1}{\partial c_0} \dot{c}_0 + \frac{\partial c_1}{\partial s_0} \dot{s}_0 + \frac{\partial c_1}{\partial y_0} \dot{y}_0 = \dot{c}_0$$
The tangent of $c$ will always be preserved by differentiation.
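The forward-mode reasoning can be checked on a hand-written tangent propagation. The sketch below is not what an AD backend generates; `f_forward!` is a hypothetical manual derivative of the toy function $y = (c \odot x)^2$, written so that each tangent buffer is updated from $\dot{a}_0$ to $\dot{a}_1$ in place.

```julia
# Hypothetical hand-written forward pass for y = (c .* x) .^ 2 with scratch s.
# The d-prefixed arrays hold the tangents of the corresponding arguments.
function f_forward!(y, dy, x, dx, c, dc, s, ds)
    s  .= c .* x
    ds .= dc .* x .+ c .* dx   # overwrites ds: its initial value is ignored
    y  .= s .^ 2
    dy .= 2 .* s .* ds         # overwrites dy: its initial value is ignored
    return nothing             # dx and dc are never written to
end

x = [1.0, 2.0];  c = [3.0, 4.0];  dx = [1.0, 0.0]
y = zeros(2);    s = zeros(2)

# Case 1: tangent of c is zero, garbage in ds and dy
dc = zeros(2);  ds = fill(NaN, 2);  dy = fill(NaN, 2)
f_forward!(y, dy, x, dx, c, dc, s, ds)
dy_clean = copy(dy)       # [18.0, 0.0], the JVP (2 .* c.^2 .* x) .* dx

# Case 2: nonzero tangent of c pollutes the result
dc_bad = [1.0, 1.0]
f_forward!(y, dy, x, dx, c, dc_bad, s, ds)
dy_polluted = copy(dy)    # [24.0, 32.0], no longer the JVP
```

In both cases the tangent of `c` comes out exactly as it went in, which is why a zero initialization is enough.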
We want to compute a Vector-Jacobian Product (VJP) $\bar{x} = \left(\frac{\partial f}{\partial x}\right)^* \bar{y}$ where $\bar{y} \in \mathcal{Y}$ is an output sensitivity.
To do that, we run our AD backend on $\phi$ with output sensitivities $(\bar{x}_1, \bar{c}_1, \bar{s}_1, \bar{y}_1)$ and obtain $(\bar{x}_0, \bar{c}_0, \bar{s}_0, \bar{y}_0)$.
The interesting value is
$$\bar{x}_0 = \left(\frac{\partial x_1}{\partial x_0}\right)^* \bar{x}_1 + \left(\frac{\partial c_1}{\partial x_0}\right)^* \bar{c}_1 + \left(\frac{\partial s_1}{\partial x_0}\right)^* \bar{s}_1 + \left(\frac{\partial y_1}{\partial x_0}\right)^* \bar{y}_1$$
Thanks to our hypotheses 1 and 2 on the function's implementation, $\frac{\partial x_1}{\partial x_0} = I$ and $\frac{\partial c_1}{\partial x_0} = 0$, so we are left with:
$$\bar{x}_0 = \bar{x}_1 + \left(\frac{\partial s_1}{\partial x_0}\right)^* \bar{s}_1 + \left(\frac{\partial y_1}{\partial x_0}\right)^* \bar{y}_1$$
Thus, as long as $\bar{x}_1 = 0$ and $\bar{s}_1 = 0$, the input sensitivity $\bar{x}_0$ contains the correct VJP.
Let us now look at $\bar{s}_0$ with the help of hypothesis 3:
$$\bar{s}_0 = \left(\frac{\partial x_1}{\partial s_0}\right)^* \bar{x}_1 + \left(\frac{\partial c_1}{\partial s_0}\right)^* \bar{c}_1 + \left(\frac{\partial s_1}{\partial s_0}\right)^* \bar{s}_1 + \left(\frac{\partial y_1}{\partial s_0}\right)^* \bar{y}_1 = 0$$
The sensitivity of $s$ will always be set to $0$ by differentiation.
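The reverse-mode formulas can also be checked by hand. The sketch below assumes a backend that accumulates into the sensitivity buffers it receives; `f_reverse!` is a hypothetical manual adjoint of the toy function $y = (c \odot x)^2$, not code produced by any AD package.

```julia
# Hypothetical hand-written reverse pass for y = (c .* x) .^ 2 with scratch s.
# The b-prefixed arrays hold sensitivities and are updated from a_1 to a_0.
function f_reverse!(y, by, x, bx, c, bc, s, bs)
    # Forward sweep
    s .= c .* x
    y .= s .^ 2
    # Reverse sweep, statement by statement in reverse order
    bs .+= 2 .* s .* by   # adjoint of y .= s .^ 2
    by .= 0               # y was overwritten, so nothing flows back to y_0
    bx .+= c .* bs        # adjoint of s .= c .* x
    bc .+= x .* bs
    bs .= 0               # s was overwritten, so nothing flows back to s_0
    return nothing
end

x = [1.0, 2.0];  c = [3.0, 4.0]
y = zeros(2);    s = zeros(2)

# Case 1: sensitivities of x and s start at zero
bx = zeros(2);  bc = zeros(2);  bs = zeros(2);  by = [1.0, 1.0]
f_reverse!(y, by, x, bx, c, bc, s, bs)
vjp = copy(bx)            # [18.0, 64.0], the VJP (2 .* c.^2 .* x) .* [1, 1]

# Case 2: a nonzero initial sensitivity of s pollutes the result
bx2 = zeros(2);  bs2 = [1.0, 1.0];  by2 = [1.0, 1.0]
f_reverse!(y, by2, x, bx2, c, zeros(2), s, bs2)
bx2                       # [21.0, 68.0], no longer the VJP
```

After either call the scratch sensitivity is back to zero, matching the statement above.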
DI's preparation mechanism allows pre-allocating the memory for tangents and sensitivities, inside a prep object.
This object is then reused across several AD calls.
For mutable objects, each AD call performs the following transformations on the provided shadow/dual storage (Duplicated for Enzyme, Dual / CoDual for Mooncake):
- In forward mode, $\dot{a}$ is updated from $\dot{a}_0$ to $\dot{a}_1$
- In reverse mode, $\bar{a}$ is updated from $\bar{a}_1$ to $\bar{a}_0$
How to initialize shadow/dual memory inside prep?
- In forward mode, make sure that $\dot{c} = 0$.
- In reverse mode, make sure that $\bar{x} = 0$ and $\bar{s} = 0$.
Should the shadow/dual memory inside prep be reset before every AD call?
- In forward mode, no need ($\dot{c}$ will remain $0$ if it is initialized to $0$)
- In reverse mode, just set $\bar{x} = 0$ ($\bar{s}$ will be reset to $0$ at every AD call)
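The reset rules for reverse mode can be illustrated with a toy prep object. Everything here is hypothetical (`ToyPrep`, `toy_reverse!`, `toy_pullback!`) and is only a sketch of the bookkeeping, not DI's actual implementation. For brevity, the constant `c` gets no shadow in this sketch.

```julia
# Hypothetical stand-in for a reverse-mode backend call on y = (c .* x) .^ 2.
# It accumulates into the shadow buffers it is given.
function toy_reverse!(y, by, x, bx, c, s, bs)
    s .= c .* x
    y .= s .^ 2
    bs .+= 2 .* s .* by
    by .= 0
    bx .+= c .* bs
    bs .= 0
    return nothing
end

# Hypothetical prep object: primal and shadow buffers, allocated once
struct ToyPrep
    y::Vector{Float64}
    s::Vector{Float64}
    bx::Vector{Float64}
    bs::Vector{Float64}
    by::Vector{Float64}
end

# All buffers start at zero, so bx and bs are correctly initialized
ToyPrep(n::Integer) = ToyPrep(ntuple(_ -> zeros(n), 5)...)

function toy_pullback!(prep::ToyPrep, x, c, ty)
    fill!(prep.bx, 0)      # reset: the previous call left its VJP in here
    copyto!(prep.by, ty)   # seed the output sensitivity
    # No reset for prep.bs: every call leaves it at zero
    toy_reverse!(prep.y, prep.by, x, prep.bx, c, prep.s, prep.bs)
    return copy(prep.bx)
end

prep = ToyPrep(2)
x = [1.0, 2.0];  c = [3.0, 4.0]
v1 = toy_pullback!(prep, x, c, [1.0, 1.0])   # [18.0, 64.0]
v2 = toy_pullback!(prep, x, c, [1.0, 1.0])   # [18.0, 64.0] again

# Skipping the reset makes the new VJP pile up on top of the old one
copyto!(prep.by, [1.0, 1.0])
toy_reverse!(prep.y, prep.by, x, prep.bx, c, prep.s, prep.bs)
stale = copy(prep.bx)                        # [36.0, 128.0]
```

The same prep object serves several calls, and only the shadow of `x` needs an explicit reset between them.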