Commit 4d39e77 (1 parent: 973676a)

Better explanation of preparation (#536)

* Clarify documentation on preparation
* Handle Val in Enzyme

4 files changed: 24 additions & 14 deletions


DifferentiationInterface/docs/src/explanation/backends.md

Lines changed: 8 additions & 3 deletions
```diff
@@ -137,9 +137,12 @@ For every operator, preparation generates an [executable function](https://brian
 ### FiniteDiff
 
 Whenever possible, preparation creates a cache object.
+Pushforward is implemented rather slowly using a closure.
 
 ### FiniteDifferences
 
+Nothing specific to mention.
+
 ### ForwardDiff
 
 We implement [`pushforward`](@ref) directly using [`Dual` numbers](https://juliadiff.org/ForwardDiff.jl/stable/dev/how_it_works/), and preparation allocates the necessary space.
```

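The `Dual`-number pushforward mentioned for ForwardDiff can be sketched by hand. This is a minimal illustration of the mechanism using `ForwardDiff.Dual` directly, not DifferentiationInterface's actual implementation; the helper name `naive_pushforward` is made up:

```julia
using ForwardDiff: Dual, partials

# Pushforward of f at x along the seed dx: evaluate f on Dual numbers
# carrying one partial each, then extract the propagated partials.
function naive_pushforward(f, x::AbstractVector, dx::AbstractVector)
    xdual = Dual.(x, dx)            # attach the tangent dx to x
    ydual = f(xdual)
    return map(d -> partials(d, 1), ydual)
end

f(x) = [sum(abs2, x), prod(x)]
naive_pushforward(f, [1.0, 2.0], [1.0, 0.0])  # J(x) * dx = [2.0, 2.0]
```

Preparation in the real implementation allocates the `Dual` work buffers once, so repeated calls avoid the allocations this naive version performs on every call.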
```diff
@@ -152,9 +155,12 @@ Most operators fall back on `AutoForwardDiff`.
 ### ReverseDiff
 
 Wherever possible, preparation records a [tape](https://juliadiff.org/ReverseDiff.jl/dev/api/#The-AbstractTape-API) of the function's execution.
+This tape is computed from the arguments `x` and `contexts...` provided at preparation time.
+It is control-flow dependent, so only one branch is recorded at each `if` statement.
 
-!!! warning
-    This tape is specific to the control flow inside the function, and cannot be reused if the control flow is value-dependent (like `if x[1] > 0`).
+!!! danger
+    If your function has value-specific control flow (like `if x[1] > 0` or `if c == 1`), you may get silently wrong results whenever it takes new branches that were not taken during preparation.
+    You must make sure to run preparation with an input and contexts whose values trigger the correct control flow for future executions.
 
 ### Symbolics
 
@@ -176,4 +182,3 @@ Same-point preparation runs the forward sweep and returns the pullback closure at `x`.
 
 We implement `pullback` based on `Zygote.pullback`.
 Same-point preparation runs the forward sweep and returns the pullback closure at `x`.
-
```

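The tape danger documented in this file's diff can be reproduced in a few lines. A sketch assuming the ReverseDiff package is available: the tape records whichever branch is taken at preparation time and silently replays it at other inputs:

```julia
using ReverseDiff: GradientTape, gradient!

# Value-dependent control flow: only one branch ends up on the tape.
f(x) = x[1] > 0 ? sum(abs2, x) : sum(x)

x = [1.0, 2.0]
tape = GradientTape(f, x)        # records the `x[1] > 0` branch

g = similar(x)
gradient!(g, tape, [-1.0, 2.0])  # replays the recorded branch
# g == [-2.0, 4.0] (gradient of sum(abs2, x)), even though f takes the
# other branch at this input, where the true gradient is [1.0, 1.0]
```

No error is thrown, which is why the documentation upgrades the admonition from `warning` to `danger`.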
DifferentiationInterface/docs/src/explanation/operators.md

Lines changed: 6 additions & 5 deletions

```diff
@@ -125,13 +125,14 @@ Here are the general rules that we strive to implement:
 
 For different-point preparation, the output `prep` of `prepare_op(f, b, x, [t])` can be reused in `op(f, prep, b, other_x, [other_t])`, provided that:
 
-- the inputs `x` and `other_x` have similar types and equal shapes
-- the tangents in `t` and `other_t` have similar types and equal shapes
+- the inputs `x` and `other_x` have the same types and sizes
+- the tangents in `t` and `other_t` have the same types and sizes
 
 For same-point preparation, the output `prep` of `prepare_op_same_point(f, b, x, [t])` can be reused in `op(f, prep, b, x, other_t)`, provided that:
 
-- the input `x` remains the same (as well as the [`Context`](@ref) constants)
-- the tangents in `t` and `other_t` have similar types and equal shapes
+- the input `x` remains exactly the same (as well as any [`Constant`](@ref) context)
+- the tangents in `t` and `other_t` have the same types and sizes
 
 !!! warning
-    These rules hold for the majority of backends, but there are some exceptions.
+    These rules hold for the majority of backends, but there are some exceptions.
+    The most important exception is [ReverseDiff](@ref) and its taping mechanism, which is sensitive to control flow inside the function.
```
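The reuse rules above translate to code along these lines. A sketch assuming DifferentiationInterface's `prepare_gradient`/`gradient` argument order, with ForwardDiff loaded as the backend package:

```julia
using DifferentiationInterface
import ForwardDiff  # backend package must be loaded for AutoForwardDiff

f(x) = sum(abs2, x)
backend = AutoForwardDiff()

x = rand(3)
prep = prepare_gradient(f, backend, x)

# Different-point reuse: `other_x` has the same type and size as `x`.
other_x = rand(3)
g = gradient(f, prep, backend, other_x)  # == 2 .* other_x

# Reusing `prep` with a different size would break the rules above:
# gradient(f, prep, backend, rand(4))  # undefined behavior
```

The size-mismatched call is left commented out on purpose: depending on the backend it may throw, or worse, return a wrong result silently.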

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/forward_onearg.jl

Lines changed: 4 additions & 4 deletions

```diff
@@ -117,8 +117,8 @@ end
 function DI.prepare_gradient(
     f::F, backend::AutoEnzyme{<:ForwardMode,<:Union{Nothing,Const}}, x
 ) where {F}
-    B = pick_batchsize(backend, length(x))
-    shadows = create_shadows(Val(B), x)
+    valB = pick_batchsize(backend, length(x))
+    shadows = create_shadows(valB, x)
     return EnzymeForwardGradientPrep{B,typeof(shadows)}(shadows)
 end
@@ -180,8 +180,8 @@ function DI.prepare_jacobian(
     f::F, backend::AutoEnzyme{<:Union{ForwardMode,Nothing},<:Union{Nothing,Const}}, x
 ) where {F}
     y = f(x)
-    B = pick_batchsize(backend, length(x))
-    shadows = create_shadows(Val(B), x)
+    valB = pick_batchsize(backend, length(x))
+    shadows = create_shadows(valB, x)
     return EnzymeForwardOneArgJacobianPrep{B,typeof(shadows)}(shadows, length(y))
 end
```
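The switch from an integer `B` to a `Val` in this file follows a standard Julia pattern: keeping the batch size in the type domain lets downstream code build fully typed containers. A generic sketch with hypothetical helper names (`pick_batchsize_demo`, `create_shadows_demo`), not Enzyme's actual internals:

```julia
# Hypothetical batch-size policy: cap the batch at 8 (illustration only).
pick_batchsize_demo(n::Integer) = Val(min(n, 8))

# Dispatching on ::Val{B} makes B a compile-time constant, so `ntuple`
# returns a concrete NTuple{B,...} rather than a dynamically sized tuple.
function create_shadows_demo(::Val{B}, x::AbstractVector) where {B}
    return ntuple(_ -> zero(x), Val(B))
end

valB = pick_batchsize_demo(3)
shadows = create_shadows_demo(valB, [1.0, 2.0, 3.0])
length(shadows)  # 3 zeroed shadow vectors
```

Passing the `Val` itself (instead of unwrapping to an `Int` and re-wrapping with `Val(B)`) keeps the constant in the type domain across the call, which is what this hunk does.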

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/reverse_onearg.jl

Lines changed: 6 additions & 2 deletions

```diff
@@ -349,11 +349,15 @@ end
 
 struct EnzymeReverseOneArgJacobianPrep{Sy,B} <: JacobianPrep end
 
+function EnzymeReverseOneArgJacobianPrep(::Val{Sy}, ::Val{B}) where {Sy,B}
+    return EnzymeReverseOneArgJacobianPrep{Sy,B}()
+end
+
 function DI.prepare_jacobian(f::F, backend::AutoEnzyme{<:ReverseMode,Nothing}, x) where {F}
     y = f(x)
     Sy = size(y)
-    B = pick_batchsize(backend, prod(Sy))
-    return EnzymeReverseOneArgJacobianPrep{Sy,B}()
+    valB = pick_batchsize(backend, prod(Sy))
+    return EnzymeReverseOneArgJacobianPrep(Val(Sy), valB)
 end
 
 function DI.jacobian(
```
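The constructor added in this hunk acts as a function barrier: runtime values wrapped in `Val` become type parameters of the prep struct, so only the constructor call is type-unstable. A minimal standalone sketch (the name `DemoJacobianPrep` is made up):

```julia
# Toy version of a prep type whose parameters are an output size Sy
# and a batch size B, both carried in the type domain.
struct DemoJacobianPrep{Sy,B} end

# Outer constructor: unwrap the Vals into type parameters.
DemoJacobianPrep(::Val{Sy}, ::Val{B}) where {Sy,B} = DemoJacobianPrep{Sy,B}()

prep = DemoJacobianPrep(Val((2, 3)), Val(4))
prep isa DemoJacobianPrep{(2, 3),4}  # true
```

Callers like `prepare_jacobian` compute `Sy` and the batch size at runtime, wrap them, and let dispatch lift them into the type, instead of splicing runtime variables directly into `{Sy,B}` braces.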
