You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
The wrapper [`DifferentiateWith`](@ref) allows you to take a function and specify that it should be differentiated with the backend of your choice.
129
+
In other words, when you try to differentiate `dw = DifferentiateWith(f, backend1)` with `backend2`, then `backend1` steps in and `backend2` does nothing.
130
+
At the moment it only works when `backend2` supports [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl).
Copy file name to clipboardExpand all lines: DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl
Callable function wrapper that enforces differentiation with a specified (inner) backend.
5
+
6
+
This works by defining new rules overriding the behavior of the outer backend that would normally be used.
7
+
8
+
!!! warning
9
+
This is an experimental functionality, whose API cannot yet be considered stable.
10
+
At the moment, it only supports one-argument functions, and rules are only defined for [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible outer backends.
11
+
12
+
# Fields
13
+
14
+
- `f`: the function in question
15
+
- `backend::AbstractADType`: the inner backend to use for differentiation
16
+
17
+
# Constructor
18
+
19
+
DifferentiateWith(f, backend)
20
+
21
+
# Example
22
+
23
+
```@repl
24
+
using DifferentiationInterface
25
+
import ForwardDiff, Zygote
26
+
27
+
function f(x)
28
+
a = Vector{eltype(x)}(undef, 1)
29
+
a[1] = sum(x) # mutation that breaks Zygote
30
+
return a[1]
31
+
end
32
+
33
+
dw = DifferentiateWith(f, AutoForwardDiff());
34
+
35
+
gradient(dw, AutoZygote(), [1.0, 2.0]) # works because it calls ForwardDiff instead
36
+
gradient(f, AutoZygote(), [1.0, 2.0]) # fails
37
+
```
38
+
"""
39
+
struct DifferentiateWith{F,B<:AbstractADType}
40
+
f::F
41
+
backend::B
42
+
end
43
+
44
+
"""
45
+
(dw::DifferentiateWith)(x)
46
+
47
+
Call the underlying function `dw.f` of a [`DifferentiateWith`](@ref) wrapper.
48
+
"""
49
+
(dw::DifferentiateWith)(x) = dw.f(x)
50
+
51
+
function Base.show(io::IO, dw::DifferentiateWith)
52
+
(; f, backend) = dw
53
+
returnprint(io, "$f differentiated with $(backend_str(backend))")
0 commit comments