Skip to content

Commit 66f1fdc

Browse files
authored
Implement DifferentiateWith to translate between backends (#218)
* Implement DifferentiateWith to translate between backends * Fix parsing
1 parent fac64b7 commit 66f1fdc

22 files changed

Lines changed: 276 additions & 121 deletions

DifferentiationInterface/docs/src/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@ check_twoarg
9191
check_hessian
9292
```
9393

94+
## Translation
95+
96+
```@docs
97+
DifferentiateWith
98+
```
99+
94100
## Internals
95101

96102
This is not part of the public API.

DifferentiationInterface/docs/src/backends.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ CollapsedDocStrings = true
55

66
```@setup backends
77
using DifferentiationInterface
8-
using DifferentiationInterface: backend_string
8+
using DifferentiationInterface: backend_str
99
import Markdown
1010
import Diffractor, Enzyme, FastDifferentiation, FiniteDiff, FiniteDifferences, ForwardDiff, PolyesterForwardDiff, ReverseDiff, Tapir, Tracker, Zygote
1111
@@ -37,7 +37,7 @@ println(io, "|:--------|:------------:|:----------------------:|:---------------
3737
3838
for example in backend_examples
3939
b = eval(Meta.parse(example)) # backend
40-
join(io, [backend_string(b), unicode_check_available(b), unicode_check_twoarg(b), unicode_check_hessian(b), "`$example`"], '|')
40+
join(io, [backend_str(b), unicode_check_available(b), unicode_check_twoarg(b), unicode_check_hessian(b), "`$example`"], '|')
4141
println(io, '|' )
4242
end
4343
backend_table = Markdown.parse(String(take!(io)))

DifferentiationInterface/docs/src/overloads.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ Each cell can have three values:
2424
```@setup overloads
2525
using ADTypes: AbstractADType
2626
using DifferentiationInterface
27-
using DifferentiationInterface: backend_string, mutation_support, MutationSupported
27+
using DifferentiationInterface: backend_str, mutation_support, MutationSupported
2828
using Markdown: Markdown
2929
using Diffractor: Diffractor
3030
using Enzyme: Enzyme

DifferentiationInterface/docs/src/overview.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,12 @@ We make this available for all backends with the following operators:
123123
| :--------------------------------- | :---------------------------------- |
124124
| [`value_and_pullback_split`](@ref) | [`value_and_pullback!_split`](@ref) |
125125

126+
## Translation
127+
128+
The wrapper [`DifferentiateWith`](@ref) allows you to take a function and specify that it should be differentiated with the backend of your choice.
129+
In other words, when you differentiate `dw = DifferentiateWith(f, backend1)` with `backend2`, `backend1` steps in to compute the derivatives of `f`, and `backend2` does not differentiate through `f` itself.
130+
At the moment it only works when `backend2` supports [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl).
131+
126132
## Going further
127133

128134
### Non-standard types

DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,15 @@ module DifferentiationInterfaceChainRulesCoreExt
22

33
using ADTypes: ADTypes, AutoChainRules
44
using ChainRulesCore:
5-
HasForwardsMode, HasReverseMode, NoTangent, RuleConfig, frule_via_ad, rrule_via_ad
5+
ChainRulesCore,
6+
HasForwardsMode,
7+
HasReverseMode,
8+
NoTangent,
9+
RuleConfig,
10+
frule_via_ad,
11+
rrule_via_ad
612
import DifferentiationInterface as DI
7-
using DifferentiationInterface: NoPullbackExtras, NoPushforwardExtras
13+
using DifferentiationInterface: DifferentiateWith, NoPullbackExtras, NoPushforwardExtras
814

915
ruleconfig(backend::AutoChainRules) = backend.ruleconfig
1016

@@ -14,32 +20,7 @@ const AutoReverseChainRules = AutoChainRules{<:RuleConfig{>:HasReverseMode}}
1420
DI.check_available(::AutoChainRules) = true
1521
DI.mutation_support(::AutoChainRules) = DI.MutationNotSupported()
1622

17-
## Pullback
18-
19-
DI.prepare_pullback(f, ::AutoReverseChainRules, x, dy) = NoPullbackExtras()
20-
21-
function DI.value_and_pullback_split(
22-
f, backend::AutoReverseChainRules, x, ::NoPullbackExtras
23-
)
24-
rc = ruleconfig(backend)
25-
y, pullback = rrule_via_ad(rc, f, x)
26-
pullbackfunc(dy) = last(pullback(dy))
27-
return y, pullbackfunc
28-
end
29-
30-
function DI.value_and_pullback!_split(
31-
f, backend::AutoReverseChainRules, x, extras::NoPullbackExtras
32-
)
33-
y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras)
34-
pullbackfunc!(dx, dy) = copyto!(dx, pullbackfunc(dy))
35-
return y, pullbackfunc!
36-
end
37-
38-
function DI.value_and_pullback(
39-
f, backend::AutoReverseChainRules, x, dy, extras::NoPullbackExtras
40-
)
41-
y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras)
42-
return y, pullbackfunc(dy)
43-
end
23+
include("reverse_onearg.jl")
24+
include("differentiate_with.jl")
4425

4526
end
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Forward rule for a `DifferentiateWith` wrapper: the value and pushforward of the
# wrapped function are computed by DifferentiationInterface with the wrapper's own
# backend. The tangent associated with the wrapper itself (first entry of the
# tangent tuple) is ignored.
function ChainRulesCore.frule((_, dx), dw::DifferentiateWith, x)
    primal, tangent = DI.value_and_pushforward(dw.f, dw.backend, x, dx)
    return primal, tangent
end
6+
7+
# Reverse rule for a `DifferentiateWith` wrapper: the value and pullback of the
# wrapped function are computed by DifferentiationInterface with the wrapper's own
# backend.
function ChainRulesCore.rrule(dw::DifferentiateWith, x)
    y, inner_pullback = DI.value_and_pullback_split(dw.f, dw.backend, x)
    # Prepend `NoTangent()` for the wrapper itself, then the cotangent for `x`
    outer_pullback(dy) = (NoTangent(), inner_pullback(dy))
    return y, outer_pullback
end
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
## Pullback
2+
3+
# Preparation for reverse-mode ChainRules backends always yields an empty
# `NoPullbackExtras` object.
function DI.prepare_pullback(f, ::AutoReverseChainRules, x, dy)
    return NoPullbackExtras()
end
4+
5+
# Return `f(x)` together with a closure mapping an output cotangent `dy` to the
# corresponding input cotangent, obtained through `rrule_via_ad` with the rule
# configuration stored in the backend.
function DI.value_and_pullback_split(
    f, backend::AutoReverseChainRules, x, ::NoPullbackExtras
)
    y, full_pullback = rrule_via_ad(ruleconfig(backend), f, x)
    # Keep only the last entry of the pullback output (the one associated with `x`)
    input_pullback(dy) = last(full_pullback(dy))
    return y, input_pullback
end
13+
14+
# Same as `DI.value_and_pullback_split`, except that the returned closure takes a
# destination `dx` as first argument and copies the input cotangent into it.
function DI.value_and_pullback!_split(
    f, backend::AutoReverseChainRules, x, extras::NoPullbackExtras
)
    y, pullback_of = DI.value_and_pullback_split(f, backend, x, extras)
    # `copyto!` fills `dx` with the freshly computed cotangent and returns `dx`
    inplace_pullback!(dx, dy) = copyto!(dx, pullback_of(dy))
    return y, inplace_pullback!
end
21+
22+
# Return `f(x)` and the pullback closure applied to the output cotangent `dy`,
# built on top of `DI.value_and_pullback_split`.
function DI.value_and_pullback(
    f, backend::AutoReverseChainRules, x, dy, extras::NoPullbackExtras
)
    y, pullback_of = DI.value_and_pullback_split(f, backend, x, extras)
    dx = pullback_of(dy)
    return y, dx
end

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ include("sparse/fallbacks.jl")
6060
include("sparse/jacobian.jl")
6161
include("sparse/hessian.jl")
6262

63+
include("translation/differentiate_with.jl")
64+
6365
export SecondOrder
6466

6567
export value_and_pushforward!, value_and_pushforward
@@ -87,6 +89,8 @@ export prepare_second_derivative, prepare_hvp, prepare_hessian
8789

8890
export check_available, check_twoarg, check_hessian
8991

92+
export DifferentiateWith
93+
9094
# Re-export backends from ADTypes
9195
export AutoChainRules
9296
export AutoDiffractor
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
"""
2+
DifferentiateWith
3+
4+
Callable function wrapper that enforces differentiation with a specified (inner) backend.
5+
6+
This works by defining new rules overriding the behavior of the outer backend that would normally be used.
7+
8+
!!! warning
9+
This is an experimental feature whose API cannot yet be considered stable.
10+
At the moment, it only supports one-argument functions, and rules are only defined for [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible outer backends.
11+
12+
# Fields
13+
14+
- `f`: the function in question
15+
- `backend::AbstractADType`: the inner backend to use for differentiation
16+
17+
# Constructor
18+
19+
DifferentiateWith(f, backend)
20+
21+
# Example
22+
23+
```@repl
24+
using DifferentiationInterface
25+
import ForwardDiff, Zygote
26+
27+
function f(x)
28+
a = Vector{eltype(x)}(undef, 1)
29+
a[1] = sum(x) # mutation that breaks Zygote
30+
return a[1]
31+
end
32+
33+
dw = DifferentiateWith(f, AutoForwardDiff());
34+
35+
gradient(dw, AutoZygote(), [1.0, 2.0]) # works because it calls ForwardDiff instead
36+
gradient(f, AutoZygote(), [1.0, 2.0]) # fails
37+
```
38+
"""
39+
struct DifferentiateWith{F,B<:AbstractADType}
    # wrapped function, called as `f(x)` when the wrapper itself is called
    f::F
    # inner backend handed to DifferentiationInterface when differentiating `f`
    backend::B
end
43+
44+
"""
45+
(dw::DifferentiateWith)(x)
46+
47+
Call the underlying function `dw.f` of a [`DifferentiateWith`](@ref) wrapper.
48+
"""
49+
function (dw::DifferentiateWith)(x)
    return dw.f(x)
end
50+
51+
# Print the wrapped function followed by a short description of the inner backend.
function Base.show(io::IO, dw::DifferentiateWith)
    # Build the full text first (equivalent to string interpolation), then print it
    description = string(dw.f, " differentiated with ", backend_str(dw.backend))
    return print(io, description)
end

DifferentiationInterface/src/utils/exceptions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ struct MissingBackendError <: Exception
22
backend::AbstractADType
33
end
44
function Base.showerror(io::IO, e::MissingBackendError)
5-
println(io, "failed to use $(backend_string(e.backend)) backend.")
5+
println(io, "failed to use $(backend_str(e.backend)) backend.")
66
if !check_available(e.backend)
77
print(
88
io,

0 commit comments

Comments
 (0)