Skip to content

Commit b85ea79

Browse files
authored
Generate table of overloads in docs (#207)
* Generate table of overloads in docs * Fix table header * Remove question mark option * Add NA option * Manually specify function signatures * Reorganize tables * Manually set subsections * In case of multiple methods, return URL of first instead of none * Minor fixes * Filter backends that don't support `f!(y, x)` * Explain what "tables of overloads" are * Minor tweaks
1 parent e549302 commit b85ea79

4 files changed

Lines changed: 229 additions & 3 deletions

File tree

DifferentiationInterface/docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ makedocs(;
4545
"Home" => "index.md", #
4646
"Start here" => ["tutorial.md", "overview.md", "backends.md"],
4747
"API reference" => "api.md",
48-
"Advanced" => ["design.md", "extensions.md"],
48+
"Advanced" => ["design.md", "extensions.md", "overloads.md"],
4949
],
5050
)
5151

DifferentiationInterface/docs/src/design.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Package design
22

3-
## Backend requirements
3+
## [Backend requirements](@id ssec-requirements)
44

55
To be usable with DifferentiationInterface.jl, an AD backend needs an object subtyping `ADTypes.AbstractADType`.
66
In addition, some operators must be defined:
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
# Table of overloads
2+
3+
As described in the [overview](@ref sec-overview), DifferentiationInterface provides multiple high-level operators like [`jacobian`](@ref),
4+
each with several variants:
5+
* **out-of-place** or **in-place** return values
6+
* **with** or **without primal** output value
7+
* support for **one-argument functions** `y = f(x)` or **two-argument functions** `f!(y, x)`
8+
9+
To support a new backend, it is only required to [define either a pushforward or a pullback function](@ref ssec-requirements),
10+
since DifferentiationInterface provides default implementations of all operators using just these two primitives.
11+
However, backends sometimes provide their own implementations of operators, which can be more performant.
12+
When available, DifferentiationInterface **always** calls these backend-specific implementations, which we call *"overloads"*.
13+
14+
The following tables summarize all implemented overloads for each backend.
15+
Each cell can have three values:
16+
17+
- ❌: the operator is not overloaded because the backend does not support it
18+
- ✅: the operator is overloaded
19+
- NA: the operator does not exist
20+
21+
!!! tip
22+
Check marks (✅) are clickable and link to the source code.
23+
24+
```@setup overloads
25+
using ADTypes
26+
using DifferentiationInterface
27+
using DifferentiationInterface: backend_string, mutation_support, MutationSupported
28+
using Markdown: Markdown
29+
using Diffractor: Diffractor
30+
using Enzyme: Enzyme
31+
using FastDifferentiation: FastDifferentiation
32+
using FiniteDiff: FiniteDiff
33+
using FiniteDifferences: FiniteDifferences
34+
using ForwardDiff: ForwardDiff
35+
using PolyesterForwardDiff: PolyesterForwardDiff
36+
using ReverseDiff: ReverseDiff
37+
using Tapir: Tapir
38+
using Tracker: Tracker
39+
using Zygote: Zygote
40+
41+
function operators_and_types_f(backend::T) where {T<:AbstractADType}
42+
return (
43+
# (op, types_op),
44+
# (op!, types_op!),
45+
# (val_and_op, types_val_and_op),
46+
# (val_and_op!, types_val_and_op!),
47+
(
48+
(:derivative, (Any, T, Any, Any)),
49+
(:derivative!, (Any, Any, T, Any, Any)),
50+
(:value_and_derivative, (Any, T, Any, Any)),
51+
(:value_and_derivative!, (Any, Any, T, Any, Any)),
52+
),
53+
(
54+
(:gradient, (Any, T, Any, Any)),
55+
(:gradient!, (Any, Any, T, Any, Any)),
56+
(:value_and_gradient, (Any, T, Any, Any)),
57+
(:value_and_gradient!, (Any, Any, T, Any, Any)),
58+
),
59+
(
60+
(:jacobian, (Any, T, Any, Any)),
61+
(:jacobian!, (Any, Any, T, Any, Any)),
62+
(:value_and_jacobian, (Any, T, Any, Any)),
63+
(:value_and_jacobian!, (Any, Any, T, Any, Any)),
64+
),
65+
(
66+
(:hessian, (Any, T, Any, Any)),
67+
(:hessian!, (Any, Any, T, Any, Any)),
68+
(nothing, nothing),
69+
(nothing, nothing),
70+
),
71+
(
72+
(:hvp, (Any, T, Any, Any, Any)),
73+
(:hvp!, (Any, Any, T, Any, Any, Any)),
74+
(nothing, nothing),
75+
(nothing, nothing),
76+
),
77+
(
78+
(:pullback, (Any, T, Any, Any, Any)),
79+
(:pullback!, (Any, Any, T, Any, Any, Any)),
80+
(:value_and_pullback, (Any, T, Any, Any, Any)),
81+
(:value_and_pullback!, (Any, Any, T, Any, Any, Any)),
82+
),
83+
(
84+
(:pushforward, (Any, T, Any, Any, Any)),
85+
(:pushforward!, (Any, Any, T, Any, Any, Any)),
86+
(:value_and_pushforward, (Any, T, Any, Any, Any)),
87+
(:value_and_pushforward!, (Any, Any, T, Any, Any, Any)),
88+
),
89+
)
90+
end
91+
function operators_and_types_f!(backend::T) where {T<:AbstractADType}
92+
return (
93+
(
94+
(:derivative, (Any, Any, T, Any, Any)),
95+
(:derivative!, (Any, Any, Any, T, Any, Any)),
96+
(:value_and_derivative, (Any, Any, T, Any, Any)),
97+
(:value_and_derivative!, (Any, Any, Any, T, Any, Any)),
98+
),
99+
(
100+
(:jacobian, (Any, Any, T, Any, Any)),
101+
(:jacobian!, (Any, Any, Any, T, Any, Any)),
102+
(:value_and_jacobian, (Any, Any, T, Any, Any)),
103+
(:value_and_jacobian!, (Any, Any, Any, T, Any, Any)),
104+
),
105+
(
106+
(:pullback, (Any, Any, T, Any, Any, Any)),
107+
(:pullback!, (Any, Any, Any, T, Any, Any, Any)),
108+
(:value_and_pullback, (Any, Any, T, Any, Any, Any)),
109+
(:value_and_pullback!, (Any, Any, Any, T, Any, Any, Any)),
110+
),
111+
(
112+
(:pushforward, (Any, Any, T, Any, Any, Any)),
113+
(:pushforward!, (Any, Any, Any, T, Any, Any, Any)),
114+
(:value_and_pushforward, (Any, Any, T, Any, Any, Any)),
115+
(:value_and_pushforward!, (Any, Any, Any, T, Any, Any, Any)),
116+
),
117+
)
118+
end
119+
120+
function method_overloaded(operator::Symbol, argtypes, ext::Module)
121+
f = @eval DifferentiationInterface.$operator
122+
ms = methods(f, argtypes, ext)
123+
124+
n = length(ms)
125+
n == 0 && return "❌"
126+
n == 1 && return "[✅]($(Base.url(only(ms))))"
127+
return "[✅]($(Base.url(first(ms))))" # Optional TODO: return all URLs?
128+
end
129+
130+
function print_overload_table(io::IO, operators_and_types, ext::Module)
131+
println(io, "| Operator | `op` | `op!` | `value_and_op` | `value_and_op!` |")
132+
println(io, "|:---------|:----:|:-----:|:--------------:|:---------------:|")
133+
for operator_variants in operators_and_types
134+
opname = first(first(operator_variants))
135+
print(io, "| `$opname` |")
136+
for (op, type_signature) in operator_variants
137+
if isnothing(op)
138+
print(io, "NA")
139+
else
140+
print(io, method_overloaded(op, type_signature, ext))
141+
end
142+
print(io, '|')
143+
end
144+
println(io)
145+
end
146+
end
147+
148+
function print_overloads(backend, ext::Symbol)
149+
io = IOBuffer()
150+
ext = Base.get_extension(DifferentiationInterface, ext)
151+
152+
println(io, "#### One-argument functions `y = f(x)`")
153+
println(io)
154+
print_overload_table(io, operators_and_types_f(backend), ext)
155+
156+
println(io, "#### Two-argument functions `f!(y, x)`")
157+
println(io)
158+
if mutation_support(backend) == MutationSupported()
159+
print_overload_table(io, operators_and_types_f!(backend), ext)
160+
else
161+
println(io, "Backend doesn't support mutating functions.")
162+
end
163+
164+
return Markdown.parse(String(take!(io)))
165+
end
166+
```
167+
168+
## Diffractor (forward/reverse)
169+
```@example overloads
170+
print_overloads(AutoDiffractor(), :DifferentiationInterfaceDiffractorExt) # hide
171+
```
172+
173+
## Enzyme (forward)
174+
```@example overloads
175+
print_overloads(AutoEnzyme(; mode=Enzyme.Forward), :DifferentiationInterfaceEnzymeExt) # hide
176+
```
177+
178+
## Enzyme (reverse)
179+
```@example overloads
180+
print_overloads(AutoEnzyme(; mode=Enzyme.Reverse), :DifferentiationInterfaceEnzymeExt) # hide
181+
```
182+
183+
## FastDifferentiation (symbolic)
184+
```@example overloads
185+
print_overloads(AutoFastDifferentiation(), :DifferentiationInterfaceFastDifferentiationExt) # hide
186+
```
187+
188+
## FiniteDiff (forward)
189+
```@example overloads
190+
print_overloads(AutoFiniteDiff(), :DifferentiationInterfaceFiniteDiffExt) # hide
191+
```
192+
193+
## FiniteDifferences (forward)
194+
```@example overloads
195+
print_overloads(AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)), :DifferentiationInterfaceFiniteDifferencesExt) # hide
196+
```
197+
198+
## ForwardDiff (forward)
199+
```@example overloads
200+
print_overloads(AutoForwardDiff(), :DifferentiationInterfaceForwardDiffExt) # hide
201+
```
202+
203+
## PolyesterForwardDiff (forward)
204+
```@example overloads
205+
print_overloads(AutoPolyesterForwardDiff(; chunksize=1), :DifferentiationInterfacePolyesterForwardDiffExt) # hide
206+
```
207+
208+
## ReverseDiff (reverse)
209+
```@example overloads
210+
print_overloads(AutoReverseDiff(), :DifferentiationInterfaceReverseDiffExt) # hide
211+
```
212+
213+
## Tapir (reverse)
214+
```@example overloads
215+
print_overloads(AutoTapir(), :DifferentiationInterfaceTapirExt) # hide
216+
```
217+
218+
## Tracker (reverse)
219+
```@example overloads
220+
print_overloads(AutoTracker(), :DifferentiationInterfaceTrackerExt) # hide
221+
```
222+
223+
## Zygote (reverse)
224+
```@example overloads
225+
print_overloads(AutoZygote(), :DifferentiationInterfaceZygoteExt) # hide
226+
```

DifferentiationInterface/docs/src/overview.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Overview
1+
# [Overview](@id sec-overview)
22

33
## Operators
44

0 commit comments

Comments
 (0)