Skip to content

Commit 6155b2a

Browse files
authored
Pushforwards and pullbacks for everyone (#113)
1 parent fb07d7f commit 6155b2a

22 files changed

Lines changed: 320 additions & 315 deletions

File tree

docs/src/tutorial.md

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,8 @@ Note the double exclamation mark, which is a convention telling you that `grad`
7171
@btime gradient!!($f, _grad, $backend, $x) evals=1 setup=(_grad=similar($x));
7272
```
7373

74-
For some reason the in-place version is slower than our first attempt, but as you can see it has one less allocation, corresponding to the gradient vector.
74+
For some reason the in-place version is not much better than our first attempt.
75+
However, as you can see, it has one less allocation: it corresponds to the gradient vector we provided.
7576
Don't worry, we're not done yet.
7677

7778
## Preparing for multiple gradients
@@ -133,36 +134,48 @@ It's blazingly fast.
133134
And you know what's even better?
134135
You didn't need to look at the docs of either ForwardDiff.jl or Enzyme.jl to achieve top performance with both, or to compare them.
135136

136-
## Testing and benchmarking
137+
## Testing
137138

138139
DifferentiationInterface.jl also provides some utilities for more involved comparison between backends.
139-
They are gathered in a submodule called [`DifferentiationInterfaceTest`](https://github.com/gdalle/DifferentiationInterface.jl/tree/main/lib/DifferentiationInterfaceTest).
140+
They are gathered in a submodule called `DifferentiationInterfaceTest`, located [here](https://github.com/gdalle/DifferentiationInterface.jl/tree/main/lib/DifferentiationInterfaceTest) in the repo.
140141

141142
```@repl tuto
142143
using DifferentiationInterfaceTest
143144
```
144145

145-
The main entry point is [`test_differentiation`](@ref), which is used as follows:
146+
For testing, you can use [`test_differentiation`](@ref) as follows:
146147

147148
```@repl tuto
148-
data = test_differentiation(
149+
test_differentiation(
149150
[AutoForwardDiff(), AutoEnzyme(Enzyme.Reverse)], # backends to compare
150-
[gradient], # operators to try
151-
[Scenario(f; x=x)]; # test scenario
151+
[gradient, pullback], # operators to try
152+
[Scenario(f; x=rand(3)), Scenario(f; x=rand(3,3))]; # test scenarios
152153
correctness=AutoZygote(), # compare results to a "ground truth" from Zygote
153-
benchmark=true, # measure runtime and allocations too
154154
detailed=true, # print detailed test set
155155
);
156156
```
157157

158-
The output of `test_differentiation` when `benchmark=true` can be converted to a `DataFrame` from [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl):
158+
## Benchmarking
159+
160+
Once you have ascertained correctness, performance will be your next concern.
161+
The interface of [`benchmark_differentiation`](@ref) is very similar to the one we've just seen, but this time it returns a data object.
162+
163+
```@repl tuto
164+
data = benchmark_differentiation(
165+
[AutoForwardDiff(), AutoEnzyme(Enzyme.Reverse)],
166+
[gradient, pullback],
167+
[Scenario(f; x=rand(3)), Scenario(f; x=rand(3,3))];
168+
);
169+
```
170+
171+
The `BenchmarkData` object is just a struct of vectors, and you can easily convert to a `DataFrame` from [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl):
159172

160173
```@repl tuto
161174
df = DataFrames.DataFrame(pairs(data)...)
162175
```
163176

164177
Here's what the resulting `DataFrame` looks like with all its columns.
165-
Note that the results may be slightly different from the ones presented above (we use [Chairmarks.jl](https://github.com/LilithHafner/Chairmarks.jl) internally instead of BenchmarkTools.jl, and measure slightly different operators).
178+
Note that the results may vary from the ones presented above (we use [Chairmarks.jl](https://github.com/LilithHafner/Chairmarks.jl) internally instead of BenchmarkTools.jl, and measure slightly different operators).
166179

167180
```@example tuto
168181
import Markdown, PrettyTables # hide

ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ DI.mode(::AutoForwardEnzyme) = ADTypes.AbstractForwardMode
4343
DI.mode(::AutoReverseEnzyme) = ADTypes.AbstractReverseMode
4444

4545
# Enzyme's `Duplicated(x, dx)` expects both arguments to be of the same type
46-
function DI.basisarray(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T}
46+
function DI.basis(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T}
4747
b = zero(a)
4848
b[i] = one(T)
4949
return b

ext/DifferentiationInterfaceFastDifferentiationExt/DifferentiationInterfaceFastDifferentiationExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ using FastDifferentiation.RuntimeGeneratedFunctions: RuntimeGeneratedFunction
1515

1616
DI.mode(::AutoFastDifferentiation) = ADTypes.AbstractSymbolicDifferentiationMode
1717
DI.supports_mutation(::AutoFastDifferentiation) = DI.MutationNotSupported()
18-
DI.supports_pullback(::AutoFastDifferentiation) = DI.PullbackNotSupported()
18+
DI.pullback_performance(::AutoFastDifferentiation) = DI.PullbackSlow()
1919

2020
myvec(x::Number) = [x]
2121
myvec(x::AbstractArray) = vec(x)

ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module DifferentiationInterfaceFiniteDifferencesExt
33
using ADTypes: AutoFiniteDifferences
44
import DifferentiationInterface as DI
55
using FillArrays: OneElement
6-
using FiniteDifferences: FiniteDifferences, jvp
6+
using FiniteDifferences: FiniteDifferences, jvp, j′vp
77
using LinearAlgebra: dot
88

99
DI.supports_mutation(::AutoFiniteDifferences) = DI.MutationNotSupported()
@@ -19,4 +19,15 @@ function DI.value_and_pushforward(
1919
return y, jvp(backend.fdm, f, (x, dx))
2020
end
2121

22+
#=
23+
# TODO: why does this fail?
24+
25+
function DI.value_and_pullback(
26+
f, backend::AutoFiniteDifferences{fdm}, x, dy, extras::Nothing
27+
) where {fdm}
28+
y = f(x)
29+
return y, j′vp(backend.fdm, f, x, dy)[1]
30+
end
31+
=#
32+
2233
end

lib/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ using DifferentiationInterface:
2222
mode,
2323
outer,
2424
supports_mutation,
25-
supports_pushforward,
26-
supports_pullback
25+
pushforward_performance,
26+
pullback_performance
2727
using DocStringExtensions
2828
import DifferentiationInterface as DI
2929
using JET: @test_call, @test_opt
@@ -42,17 +42,19 @@ include("utils/zero.jl")
4242
include("utils/compatibility.jl")
4343
include("utils/printing.jl")
4444
include("utils/misc.jl")
45+
include("utils/filter.jl")
4546

4647
include("tests/correctness.jl")
4748
include("tests/type_stability.jl")
4849
include("tests/call_count.jl")
4950
include("tests/benchmark.jl")
5051
include("tests/test.jl")
5152

53+
export all_operators
5254
export Scenario
5355
export default_scenarios
5456
export static_scenarios, component_scenarios, gpu_scenarios
55-
export BenchmarkData, record!
56-
export all_operators, test_differentiation
57+
export BenchmarkData
58+
export test_differentiation, benchmark_differentiation
5759

5860
end

lib/DifferentiationInterfaceTest/src/tests/benchmark.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,49 @@ function Base.pairs(data::BenchmarkData)
3939
return ns .=> getfield.(Ref(data), ns)
4040
end
4141

42+
"""
43+
benchmark_differentiation(backends, [operators, scenarios]; [kwargs...])
44+
45+
Benchmark a list of `backends` for a list of `operators` on a list of `scenarios`.
46+
47+
# Keyword arguments
48+
49+
- filtering: same as [`test_differentiation`](@ref) for the filtering part.
50+
- `logging=true`: whether to log progress
51+
"""
52+
function benchmark_differentiation(
53+
backends::Vector{<:AbstractADType},
54+
operators::Vector{<:Function}=all_operators(),
55+
scenarios::Vector{<:Scenario}=default_scenarios();
56+
# filtering
57+
input_type::Type=Any,
58+
output_type::Type=Any,
59+
allocating=true,
60+
mutating=true,
61+
first_order=true,
62+
second_order=true,
63+
excluded::Vector{<:Function}=Function[],
64+
# options
65+
logging=false,
66+
)
67+
operators = filter_operators(operators; first_order, second_order, excluded)
68+
scenarios = filter_scenarios(scenarios; input_type, output_type, allocating, mutating)
69+
70+
benchmark_data = BenchmarkData()
71+
for backend in backends
72+
for op in operators
73+
for scen in filter(scenarios) do scen
74+
compatible(backend, op, scen)
75+
end
76+
logging &&
77+
@info "Benchmarking: $(backend_string(backend)) - $op - $(string(scen))"
78+
run_benchmark!(benchmark_data, backend, op, scen; allocations=false)
79+
end
80+
end
81+
end
82+
return benchmark_data
83+
end
84+
4285
function record!(data, tup::NamedTuple)
4386
for n in fieldnames(typeof(tup))
4487
push!(getfield(data, n), getfield(tup, n))

lib/DifferentiationInterfaceTest/src/tests/test.jl

Lines changed: 21 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,7 @@
1-
"""
2-
all_operators()
3-
4-
List all operators that can be tested with [`test_differentiation`](@ref).
5-
"""
6-
function all_operators()
7-
return [
8-
pushforward,
9-
pullback,
10-
derivative,
11-
gradient,
12-
jacobian,
13-
second_derivative,
14-
hvp,
15-
hessian,
16-
]
17-
end
18-
19-
function filter_operators(
20-
operators::Vector{<:Function};
21-
first_order::Bool,
22-
second_order::Bool,
23-
excluded::Vector{<:Function},
24-
)
25-
!first_order && (
26-
operators = filter(
27-
!in([pushforward, pullback, derivative, gradient, jacobian]), operators
28-
)
29-
)
30-
!second_order && (operators = filter(!in([second_derivative, hvp, hessian]), operators))
31-
operators = filter(!in(excluded), operators)
32-
return operators
33-
end
34-
35-
function filter_scenarios(
36-
scenarios::Vector{<:Scenario};
37-
input_type::Type,
38-
output_type::Type,
39-
allocating::Bool,
40-
mutating::Bool,
41-
)
42-
scenarios = filter(scenarios) do scen
43-
typeof(scen.x) <: input_type && typeof(scen.y) <: output_type
44-
end
45-
!allocating && (scenarios = filter(is_mutating, scenarios))
46-
!mutating && (scenarios = filter(!is_mutating, scenarios))
47-
return scenarios
48-
end
49-
501
"""
512
test_differentiation(backends, [operators, scenarios]; [kwargs...])
523
53-
Cross-test a list of `backends` for a list of `operators` on a list of `scenarios`, running a variety of different tests.
54-
55-
If `benchmark=true`, return a [`BenchmarkData`](@ref) object, otherwise return `nothing`.
4+
Test a list of `backends` for a list of `operators` on a list of `scenarios`.
565
576
# Default arguments
587
@@ -66,9 +15,7 @@ Testing:
6615
- `correctness=true`: whether to compare the differentiation results with the theoretical values specified in each scenario. If a backend object like `correctness=AutoForwardDiff()` is passed instead of a boolean, the results will be compared using that reference backend as the ground truth.
6716
- `call_count=false`: whether to check that the function is called the right number of times
6817
- `type_stability=false`: whether to check type stability with JET.jl (thanks to `@test_opt`)
69-
- `benchmark=false`: whether to run and return a benchmark suite with Chairmarks.jl
70-
- `allocations=false`: whether to check that the benchmarks are allocation-free
71-
- `detailed=false`: whether to print a detailed test set (by scenario) or condensed test set (by operator)
18+
- `detailed=false`: whether to print a detailed or condensed test log
7219
7320
Filtering:
7421
@@ -82,6 +29,7 @@ Filtering:
8229
8330
Options:
8431
32+
- `logging=true`: whether to log progress
8533
- `isapprox=isapprox`: function used to compare objects, only needs to be set for complicated cases beyond arrays / scalars
8634
- `rtol=1e-3`: precision for correctness testing (when comparing to the reference outputs)
8735
"""
@@ -93,8 +41,6 @@ function test_differentiation(
9341
correctness::Union{Bool,AbstractADType}=true,
9442
type_stability::Bool=false,
9543
call_count::Bool=false,
96-
benchmark::Bool=false,
97-
allocations::Bool=false,
9844
detailed=false,
9945
# filtering
10046
input_type::Type=Any,
@@ -105,64 +51,45 @@ function test_differentiation(
10551
second_order=true,
10652
excluded::Vector{<:Function}=Function[],
10753
# options
54+
logging=false,
10855
isapprox=isapprox,
10956
rtol=1e-3,
11057
)
11158
operators = filter_operators(operators; first_order, second_order, excluded)
11259
scenarios = filter_scenarios(scenarios; input_type, output_type, allocating, mutating)
11360

114-
benchmark_data = BenchmarkData()
61+
if correctness isa AbstractADType
62+
scenarios = change_ref.(scenarios, Ref(correctness))
63+
end
11564

11665
title =
11766
"Differentiation tests -" *
11867
(correctness != false ? " correctness" : "") *
11968
(call_count ? " calls" : "") *
120-
(type_stability ? " types" : "") *
121-
(benchmark ? " benchmark" : "") *
122-
(allocations ? " allocations" : "")
69+
(type_stability ? " types" : "")
12370

124-
@testset verbose = detailed "$(backend_string(backend))" for backend in backends
125-
@testset verbose = detailed "$op" for op in operators
126-
@testset "$scen" for scen in filter(scenarios) do scen
127-
compatible(backend, op, scen)
128-
end
129-
if correctness != false
130-
@testset "Correctness" begin
131-
if correctness isa AbstractADType
132-
test_correctness(
133-
backend, op, change_ref(scen, correctness); isapprox, rtol
134-
)
135-
else
136-
test_correctness(backend, op, scen; isapprox, rtol)
137-
end
138-
end
71+
@testset verbose = true "$title" begin
72+
@testset verbose = detailed "$(backend_string(backend))" for backend in backends
73+
@testset verbose = detailed "$op" for op in operators
74+
@testset "$scen" for scen in filter(scenarios) do scen
75+
compatible(backend, op, scen)
13976
end
140-
if call_count
141-
@testset "Call count" begin
77+
logging &&
78+
@info "Testing: $(backend_string(backend)) - $op - $(string(scen))"
79+
correctness != false && @testset "Correctness" begin
80+
test_correctness(backend, op, scen; isapprox, rtol)
81+
end
82+
call_count && @testset "Call count" begin
14283
test_call_count(backend, op, scen)
14384
end
144-
end
145-
if type_stability
146-
@testset "Type stability" begin
85+
type_stability && @testset "Type stability" begin
14786
test_jet(backend, op, scen)
14887
end
14988
end
150-
if benchmark || allocations
151-
@testset "Allocations" begin
152-
run_benchmark!(
153-
benchmark_data, backend, op, scen; allocations=allocations
154-
)
155-
end
156-
end
15789
end
15890
end
15991
end
160-
161-
if benchmark
162-
return benchmark_data
163-
else
164-
return nothing
165-
end
92+
return nothing
16693
end
16794

16895
"""

lib/DifferentiationInterfaceTest/src/utils/compatibility.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,6 @@ function compatible(::AbstractADType, ::Function)
44
return true
55
end
66

7-
function compatible(backend::AbstractADType, ::typeof(pushforward))
8-
return Bool(supports_pushforward(backend))
9-
end
10-
11-
function compatible(backend::AbstractADType, ::typeof(pullback))
12-
return Bool(supports_pullback(backend))
13-
end
14-
157
## Backend-scenario
168

179
function compatible(::AbstractADType, ::Scenario{false})

0 commit comments

Comments
 (0)