Skip to content

Commit 1369010

Browse files
authored
Make DifferentiateWith compatible with ForwardDiff, clarify docs (#487)
* Make DifferentiateWith compatible with ForwardDiff, clarify docs * More details
1 parent f262504 commit 1369010

4 files changed

Lines changed: 74 additions & 39 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import DifferentiationInterface as DI
77
using DifferentiationInterface:
88
Context,
99
DerivativeExtras,
10+
DifferentiateWith,
1011
GradientExtras,
1112
HessianExtras,
1213
HVPExtras,
@@ -59,5 +60,6 @@ include("utils.jl")
5960
include("onearg.jl")
6061
include("twoarg.jl")
6162
include("secondorder.jl")
63+
include("differentiate_with.jl")
6264

6365
end # module
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
function (dw::DifferentiateWith)(x::Dual{T,V,N}) where {T,V,N}
2+
@compat (; f, backend) = dw
3+
xval = myvalue(T, x)
4+
tx = mypartials(T, Val(N), x)
5+
y, ty = DI.value_and_pushforward(f, backend, xval, tx)
6+
return make_dual(T, y, ty)
7+
end
8+
9+
function (dw::DifferentiateWith)(x::AbstractArray{Dual{T,V,N}}) where {T,V,N}
10+
@compat (; f, backend) = dw
11+
xval = myvalue(T, x)
12+
tx = mypartials(T, Val(N), x)
13+
y, ty = DI.value_and_pushforward(f, backend, xval, tx)
14+
return make_dual(T, y, ty)
15+
end
Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,28 @@
11
"""
22
DifferentiateWith
33
4-
Callable function wrapper that enforces differentiation with a specified (inner) backend.
4+
Function wrapper that enforces differentiation with a "substitute" AD backend, possibly different from the "true" AD backend that is called.
55
6-
This works by defining new rules overriding the behavior of the outer backend that would normally be used.
6+
For instance, suppose a function `f` is not differentiable with Zygote because it involves mutation, but you know that it is differentiable with Enzyme.
7+
Then `f2 = DifferentiateWith(f, AutoEnzyme())` is a new function that behaves like `f`, except that `f2` is differentiable with Zygote (thanks to a chain rule which calls Enzyme under the hood).
8+
Moreover, any larger algorithm `alg` that calls `f2` instead of `f` will also be differentiable with Zygote (as long as `f` was the only Zygote blocker).
9+
10+
!!! tip
11+
This is mainly relevant for package developers who want to produce differentiable code at low cost, without writing the differentiation rules themselves.
12+
If you sprinkle a few `DifferentiateWith` wrappers in places where some AD backends may struggle, end users can pick from a wider variety of packages to differentiate your algorithms.
713
814
!!! warning
9-
This is an experimental functionality, whose API cannot yet be considered stable.
10-
It only supports out-of-place functions, and rules are only defined for [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible outer backends.
15+
`DifferentiateWith` only supports out-of-place functions `y = f(x)` without additional context arguments.
16+
It only makes these functions differentiable if the true backend is either [ForwardDiff](https://github.com/JuliaDiff/ForwardDiff.jl) or compatible with [ChainRules](https://github.com/JuliaDiff/ChainRules.jl).
17+
For any other true backend, the differentiation behavior is not altered by `DifferentiateWith` (it becomes a transparent wrapper).
1118
1219
# Fields
1320
14-
- `f`: the function in question
15-
- `backend::AbstractADType`: the inner backend to use for differentiation
21+
- `f`: the function in question, with signature `f(x)`
22+
- `backend::AbstractADType`: the substitute backend to use for differentiation
23+
24+
!!! note
25+
For the substitute AD backend to be called under the hood, its package needs to be loaded in addition to the package of the true AD backend.
1626
1727
# Constructor
1828
@@ -21,40 +31,50 @@ This works by defining new rules overriding the behavior of the outer backend th
2131
# Example
2232
2333
```jldoctest
24-
using DifferentiationInterface
25-
import ForwardDiff, Zygote
34+
julia> using DifferentiationInterface
2635
27-
function f(x)
28-
a = Vector{eltype(x)}(undef, 1)
29-
a[1] = sum(x) # mutation that breaks Zygote
30-
return a[1]
31-
end
36+
julia> import Enzyme, ForwardDiff, Zygote
37+
38+
julia> function f(x::Vector{Float64})
39+
a = Vector{Float64}(undef, 1) # type constraint breaks ForwardDiff
40+
a[1] = sum(abs2, x) # mutation breaks Zygote
41+
return a[1]
42+
end;
3243
33-
dw = DifferentiateWith(f, AutoForwardDiff());
44+
julia> f2 = DifferentiateWith(f, AutoEnzyme());
3445
35-
gradient(dw, AutoZygote(), [2.0]) # calls ForwardDiff instead
46+
julia> f([3.0, 5.0]) == f2([3.0, 5.0])
47+
true
3648
37-
# output
49+
julia> alg(x) = 7 * f2(x);
3850
39-
1-element Vector{Float64}:
40-
1.0
51+
julia> ForwardDiff.gradient(alg, [3.0, 5.0])
52+
2-element Vector{Float64}:
53+
42.0
54+
70.0
55+
56+
julia> Zygote.gradient(alg, [3.0, 5.0])[1]
57+
2-element Vector{Float64}:
58+
42.0
59+
70.0
4160
```
4261
"""
4362
struct DifferentiateWith{F,B<:AbstractADType}
4463
f::F
4564
backend::B
4665
end
4766

48-
"""
49-
(dw::DifferentiateWith)(x)
50-
51-
Call the underlying function `dw.f` of a [`DifferentiateWith`](@ref) wrapper.
52-
"""
5367
(dw::DifferentiateWith)(x) = dw.f(x)
5468

5569
function Base.show(io::IO, dw::DifferentiateWith)
5670
@compat (; f, backend) = dw
5771
return print(
58-
io, DifferentiateWith, "(", repr(f; context=io), ",", repr(backend; context=io), ")"
72+
io,
73+
DifferentiateWith,
74+
"(",
75+
repr(f; context=io),
76+
", ",
77+
repr(backend; context=io),
78+
")",
5979
)
6080
end
Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,29 @@
11
using Pkg
2-
Pkg.add(["ForwardDiff", "Zygote"])
2+
Pkg.add(["FiniteDiff", "ForwardDiff", "Zygote"])
33

44
using DifferentiationInterface, DifferentiationInterfaceTest
55
import DifferentiationInterfaceTest as DIT
6+
using FiniteDiff: FiniteDiff
67
using ForwardDiff: ForwardDiff
78
using Zygote: Zygote
89
using Test
910

1011
LOGGING = get(ENV, "CI", "false") == "false"
1112

12-
function zygote_breaking_scenarios()
13-
outofplace_scens = filter(default_scenarios()) do scen
14-
DIT.operator_place(scen) == :out
15-
end
16-
bad_outofplace_scens = map(outofplace_scens) do scen
17-
function bad_f(x)
18-
a = Vector{eltype(x)}(undef, 1)
19-
a[1] = sum(x)
20-
return scen.f(x)
13+
function differentiatewith_scenarios()
14+
bad_scens = # these closurified scenarios have mutation and type constraints
15+
filter(default_scenarios(; include_normal=false, include_closurified=true)) do scen
16+
DIT.function_place(scen) == :out
2117
end
22-
wrapped_bad_f = DifferentiateWith(bad_f, AutoForwardDiff())
23-
bad_scen = DIT.change_function(scen, wrapped_bad_f)
24-
return bad_scen
18+
good_scens = map(bad_scens) do scen
19+
DIT.change_function(scen, DifferentiateWith(scen.f, AutoFiniteDiff()))
2520
end
26-
return bad_outofplace_scens
21+
return good_scens
2722
end
2823

2924
test_differentiation(
30-
AutoZygote(), zygote_breaking_scenarios(); second_order=false, logging=LOGGING
25+
[AutoForwardDiff(), AutoZygote()],
26+
differentiatewith_scenarios();
27+
second_order=false,
28+
logging=LOGGING,
3129
)

0 commit comments

Comments
 (0)