Tutorial

We present a typical workflow with DifferentiationInterfaceTest.jl, building on the tutorial of the DifferentiationInterface.jl documentation (which we encourage you to read first).

import Chairmarks
using DataFrames
using DifferentiationInterface, DifferentiationInterfaceTest
using ForwardDiff: ForwardDiff
using Zygote: Zygote

Introduction

The AD backends we want to compare are ForwardDiff.jl and Zygote.jl.

backends = [AutoForwardDiff(), AutoZygote()]

To do that, we are going to take gradients of a simple function:

f(x::AbstractArray) = sum(sin, x)

Of course, we know the true gradient mapping in closed form:

∇f(x::AbstractArray) = cos.(x)
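Before using ∇f as a reference, it can be worth sanity-checking it numerically. The sketch below compares it against central finite differences using only Base Julia. The helper `fd_gradient` is hypothetical and not part of any package. The step size `cbrt(eps(T))` is a common rule of thumb for central differences, and the tolerances are assumptions chosen for each precision.

```julia
# Self-contained copies of the tutorial's function and its hand-written gradient.
f(x::AbstractArray) = sum(sin, x)
∇f(x::AbstractArray) = cos.(x)

# Hypothetical helper (not from any package): central-difference gradient of f at x.
function fd_gradient(f, x::AbstractArray{T}) where {T<:AbstractFloat}
    h = cbrt(eps(T))          # step size balancing truncation and rounding error
    g = similar(x)
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

xv = rand(Float32, 3)
xm = rand(Float64, 3, 2)
# Float32 finite differences are much less accurate, hence the looser tolerance.
println(isapprox(fd_gradient(f, xm), ∇f(xm); rtol=1e-6))
println(isapprox(fd_gradient(f, xv), ∇f(xv); rtol=1e-2))
```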

DifferentiationInterfaceTest.jl relies on so-called scenarios (objects of type Scenario), in which you encapsulate the information needed for your test:

  • the operator category (here :gradient)
  • the behavior of the operator (either :in place or :out of place)
  • the function f
  • the input x of the function f (and possible tangents or contexts)
  • the reference first-order result res1 (and possible second-order result res2) of the operator
  • the arguments prep_args passed during preparation

xv = rand(Float32, 3)
xm = rand(Float64, 3, 2)
scenarios = [
    Scenario{:gradient,:out}(f, xv; res1=∇f(xv)),
    Scenario{:gradient,:out}(f, xm; res1=∇f(xm)),
];
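To make the idea of a scenario concrete, here is a toy stand-in written in plain Julia. `ToyScenario`, `check` and `hand_written` are hypothetical names invented for this sketch and are not DifferentiationInterfaceTest.jl's actual types or internals. The sketch only mirrors the pattern described above. The operator and its in/out-of-place behavior are stored as type parameters, as in the `Scenario{:gradient,:out}` syntax. The function, the input and the reference result are stored as fields.

```julia
f(x::AbstractArray) = sum(sin, x)
∇f(x::AbstractArray) = cos.(x)

# Hypothetical toy type: operator `op` and behavior `pl` live in the type parameters.
struct ToyScenario{op,pl,F,X,R}
    f::F
    x::X
    res1::R
end

# Constructor mimicking the Scenario{op,pl}(f, x; res1=...) call style.
ToyScenario{op,pl}(f, x; res1) where {op,pl} =
    ToyScenario{op,pl,typeof(f),typeof(x),typeof(res1)}(f, x, res1)

operator(::ToyScenario{op}) where {op} = op
place(::ToyScenario{op,pl}) where {op,pl} = pl

# Hypothetical correctness check for out-of-place gradients: compare grad(f, x) to res1.
check(grad, s::ToyScenario{:gradient,:out}) = isapprox(grad(s.f, s.x), s.res1)

xv = rand(Float32, 3)
xm = rand(Float64, 3, 2)
toys = [
    ToyScenario{:gradient,:out}(f, xv; res1=∇f(xv)),
    ToyScenario{:gradient,:out}(f, xm; res1=∇f(xm)),
]

hand_written(f, x) = cos.(x)  # stands in for an AD backend in this toy
results = [check(hand_written, s) for s in toys]
println(results)
```

Encoding the operator in the type lets Julia's dispatch pick the right check per scenario, which is one plausible reason for the `Scenario{:gradient,:out}` syntax. That reading is an interpretation, not something the tutorial states.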

Testing

The main entry point for testing is the function test_differentiation. It has many options, but the main ingredients are the following:

test_differentiation(
    backends,  # the backends you want to compare
    scenarios;  # the scenarios you defined
    correctness=true,  # compares values against the reference
    type_stability=:none,  # :none disables the type stability checks performed with JET.jl
    detailed=true,  # prints a detailed test set
)
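A rough manual analogue of the correctness check can be written with DifferentiationInterface.jl alone. The sketch calls the unprepared `gradient(f, backend, x)` for every backend and input, then compares the result with ∇f using `isapprox` and its default tolerance. The loop structure and the tolerance are assumptions of this sketch. It is not a description of how test_differentiation works internally.

```julia
using DifferentiationInterface
using ForwardDiff: ForwardDiff
using Zygote: Zygote

f(x::AbstractArray) = sum(sin, x)
∇f(x::AbstractArray) = cos.(x)

backends = [AutoForwardDiff(), AutoZygote()]
inputs = [rand(Float32, 3), rand(Float64, 3, 2)]

# Compare each backend's gradient with the hand-written reference, input by input.
for backend in backends, x in inputs
    ok = isapprox(gradient(f, backend, x), ∇f(x))
    println(nameof(typeof(backend)), " on ", summary(x), ": ", ok)
end

all_ok = all(isapprox(gradient(f, b, x), ∇f(x)) for b in backends, x in inputs)
```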

Benchmarking

Once you are confident that your backends give the correct answers, you probably want to compare their performance. This is made easy by the benchmark_differentiation function, whose syntax should feel familiar:

table = benchmark_differentiation(backends, scenarios);

The resulting object is a table, which can easily be converted into a DataFrame from DataFrames.jl. Its columns correspond to the fields of DifferentiationBenchmarkDataRow.

df = DataFrame(table)
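Benchmark tools repeat each measurement and summarize the results, because a single timing is noisy. The tutorial loads Chairmarks.jl, a dedicated benchmarking package, and it is far more careful than the crude sketch below. This sketch only illustrates the basic idea: do one warm-up call so that compilation is not timed, then keep the minimum over repeated `@elapsed` runs. The helper `min_elapsed` and the sample count are hypothetical. This is not how DifferentiationInterfaceTest.jl measures performance.

```julia
f(x::AbstractArray) = sum(sin, x)
∇f(x::AbstractArray) = cos.(x)

# Hypothetical crude timer: minimum wall-clock time over repeated calls of g(x).
function min_elapsed(g, x; samples::Int=100)
    g(x)                      # warm-up call, so compilation time is not measured
    best = Inf
    for _ in 1:samples
        t = @elapsed g(x)
        best = min(best, t)
    end
    return best
end

x = rand(1000)
t_f = min_elapsed(f, x)
t_grad = min_elapsed(∇f, x)
println("f: ", t_f, " s, ∇f: ", t_grad, " s (minimum over 100 samples)")
```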