You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl
+2Lines changed: 2 additions & 0 deletions
Original file line number
Diff line number
Diff line change
@@ -7,6 +7,7 @@ import DifferentiationInterface as DI
Callable function wrapper that enforces differentiation with a specified (inner) backend.
4
+
Function wrapper that enforces differentiation with a "substitute" AD backend, possible different from the "true" AD backend that is called.
5
5
6
-
This works by defining new rules overriding the behavior of the outer backend that would normally be used.
6
+
For instance, suppose a function `f` is not differentiable with Zygote because it involves mutation, but you know that it is differentiable with Enzyme.
7
+
Then `f2 = DifferentiateWith(f, AutoEnzyme())` is a new function that behaves like `f`, except that `f2` is differentiable with Zygote (thanks to a chain rule which calls Enzyme under the hood).
8
+
Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be differentiable with Zygote (as long as `f` was the only Zygote blocker).
9
+
10
+
!!! tip
11
+
This is mainly relevant for package developers who want to produce differentiable code at low cost, without writing the differentiation rules themselves.
12
+
If you sprinkle a few `DifferentiateWith` in places where some AD backends may struggle, end users can pick from a wider variety of packages to differentiate your algorithms.
7
13
8
14
!!! warning
9
-
This is an experimental functionality, whose API cannot yet be considered stable.
10
-
It only supports out-of-place functions, and rules are only defined for [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible outer backends.
15
+
`DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments.
16
+
It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) or compatible with [ChainRules](https://github.com/JuliaDiff/ChainRules.jl).
17
+
For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper).
11
18
12
19
# Fields
13
20
14
-
- `f`: the function in question
15
-
- `backend::AbstractADType`: the inner backend to use for differentiation
21
+
- `f`: the function in question, with signature `f(x)`
22
+
- `backend::AbstractADType`: the substitute backend to use for differentiation
23
+
24
+
!!! note
25
+
For the substitute AD backend to be called under the hood, its package needs to be loaded in addition to the package of the true AD backend.
16
26
17
27
# Constructor
18
28
@@ -21,40 +31,50 @@ This works by defining new rules overriding the behavior of the outer backend th
21
31
# Example
22
32
23
33
```jldoctest
24
-
using DifferentiationInterface
25
-
import ForwardDiff, Zygote
34
+
julia> using DifferentiationInterface
26
35
27
-
function f(x)
28
-
a = Vector{eltype(x)}(undef, 1)
29
-
a[1] = sum(x) # mutation that breaks Zygote
30
-
return a[1]
31
-
end
36
+
julia> import Enzyme, ForwardDiff, Zygote
37
+
38
+
julia> function f(x::Vector{Float64})
39
+
a = Vector{Float64}(undef, 1) # type constraint breaks ForwardDiff
0 commit comments