We present a typical workflow with DifferentiationInterfaceTest.jl, building on the tutorial of the DifferentiationInterface.jl documentation (which we encourage you to read first).
```julia
import Chairmarks
using DataFrames
using DifferentiationInterface, DifferentiationInterfaceTest
using ForwardDiff: ForwardDiff
using Zygote: Zygote
```
The AD backends we want to compare are ForwardDiff.jl and Zygote.jl.
```julia
backends = [AutoForwardDiff(), AutoZygote()]
```
To compare them, we will take gradients of a simple function:

```julia
f(x::AbstractArray) = sum(sin, x)
```
Of course, we know the true gradient mapping:

```julia
∇f(x::AbstractArray) = cos.(x)
```
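Since `∇f` will serve as the reference, it can be worth sanity-checking it without any AD backend. The sketch below compares it to central finite differences; the helper `fd_gradient` is hypothetical and not part of either package, and the step `h = cbrt(eps(T))` is a common heuristic for central differences.

```julia
f(x::AbstractArray) = sum(sin, x)
∇f(x::AbstractArray) = cos.(x)

# Hypothetical helper: central finite differences, one coordinate at a time
function fd_gradient(f, x::AbstractArray{T}; h=cbrt(eps(T))) where {T}
    g = similar(x)
    for i in eachindex(x)
        e = zero(x)
        e[i] = h  # perturb only coordinate i
        g[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end

x = rand(Float64, 3, 2)
@assert isapprox(fd_gradient(f, x), ∇f(x); rtol=1e-6)
```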
DifferentiationInterfaceTest.jl relies on so-called `Scenario`s, which encapsulate the information needed for your test:

- the operator category (here `:gradient`)
- the behavior of the operator (either `:in` or `:out` of place)
- the function `f`
- the input `x` of the function `f` (and optional tangents or contexts)
- the reference first-order result `res1` (and optional second-order result `res2`) of the operator
- the arguments `prep_args` passed during preparation
```julia
xv = rand(Float32, 3)
xm = rand(Float64, 3, 2)
scenarios = [
    Scenario{:gradient,:out}(f, xv; res1=∇f(xv)),
    Scenario{:gradient,:out}(f, xm; res1=∇f(xm)),
];
```
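The scenarios above only exercise the out-of-place operator. As a sketch, assuming `Scenario{:gradient,:in}` is accepted with the same arguments as the `:out` variant, an in-place scenario would target the mutating `gradient!`, which fills a preallocated buffer. The manual check at the end is a simplified illustration of what such a scenario verifies, not how DifferentiationInterfaceTest.jl implements it.

```julia
using DifferentiationInterface, DifferentiationInterfaceTest
using ForwardDiff: ForwardDiff

f(x::AbstractArray) = sum(sin, x)
∇f(x::AbstractArray) = cos.(x)

xv = rand(Float32, 3)

# Assumption: same constructor, but the operator behavior is :in (mutating)
scenario_in = Scenario{:gradient,:in}(f, xv; res1=∇f(xv))

# By hand: gradient! writes the result into `grad` instead of allocating it
grad = similar(xv)
gradient!(f, grad, AutoForwardDiff(), xv)
@assert grad ≈ ∇f(xv)
```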
The main entry point for testing is the function `test_differentiation`.
It has many options, but the main ingredients are the following:

```julia
test_differentiation(
    backends,  # the backends you want to compare
    scenarios;  # the scenarios you defined
    correctness=true,  # compares values against the reference
    type_stability=:none,  # skips type stability checks (which rely on JET.jl)
    detailed=true,  # prints a detailed test set
)
```
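Conceptually, the correctness check compares each backend's output with the reference stored in each scenario. The loop below is a simplified manual analogue written directly against DifferentiationInterface.jl's `gradient(f, backend, x)`; it is only meant to illustrate the idea and is not the package's actual test logic, which also covers preparation and the other operator variants.

```julia
using DifferentiationInterface
using ForwardDiff: ForwardDiff
using Zygote: Zygote

f(x::AbstractArray) = sum(sin, x)
∇f(x::AbstractArray) = cos.(x)

backends = [AutoForwardDiff(), AutoZygote()]
inputs = [rand(Float32, 3), rand(Float64, 3, 2)]

# Every backend must agree with the known gradient on every input
for backend in backends, x in inputs
    g = gradient(f, backend, x)
    @assert g ≈ ∇f(x) "gradient mismatch for $backend"
end
```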
Once you are confident that your backends give the correct answers, you probably want to compare their performance.
This is made easy by the `benchmark_differentiation` function, whose syntax should feel familiar:

```julia
table = benchmark_differentiation(backends, scenarios);
```
The resulting object is a table, which can easily be converted into a `DataFrame` from DataFrames.jl.
Its columns correspond to the fields of `DifferentiationBenchmarkDataRow`.

```julia
df = DataFrame(table)
```
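Once the results are in a `DataFrame`, the usual DataFrames.jl split-apply-combine tools apply. The sketch below keeps the best time per operator and backend and sorts them; it assumes the table has columns named `operator`, `backend` and `time` (check the fields of `DifferentiationBenchmarkDataRow` for your version), and the name `fastest` is purely illustrative.

```julia
using DataFrames
using DifferentiationInterface, DifferentiationInterfaceTest
using ForwardDiff: ForwardDiff
using Zygote: Zygote
import Chairmarks

f(x::AbstractArray) = sum(sin, x)
∇f(x::AbstractArray) = cos.(x)

backends = [AutoForwardDiff(), AutoZygote()]
xv = rand(Float32, 3)
scenarios = [Scenario{:gradient,:out}(f, xv; res1=∇f(xv))]

df = DataFrame(benchmark_differentiation(backends, scenarios))

# Assumption: columns :operator, :backend and :time exist
# Keep the minimum time for each (operator, backend) pair, fastest first
fastest = combine(groupby(df, [:operator, :backend]), :time => minimum => :min_time)
sort!(fastest, :min_time)
```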