|
| 1 | +# Table of overloads |
| 2 | + |
| 3 | +As described in the [overview](@ref sec-overview), DifferentiationInterface provides multiple high-level operators like [`jacobian`](@ref), |
| 4 | +each with several variants: |
| 5 | +* **out-of-place** or **in-place** return values |
| 6 | +* **with** or **without primal** output value |
| 7 | +* support for **one-argument functions** `y = f(x)` or **two-argument functions** `f!(y, x)` |
| 8 | + |
| 9 | +To support a new backend, it is only required to [define either a pushforward or a pullback function](@ref ssec-requirements), |
| 10 | +since DifferentiationInterface provides default implementations of all operators using just these two primitives. |
| 11 | +However, backends sometimes provide their own implementations of operators, which can be more performant. |
| 12 | +When available, DifferentiationInterface **always** calls these backend-specific implementations, which we call *"overloads"*. |
| 13 | + |
| 14 | +The following tables summarize all implemented overloads for each backend. |
| 15 | +Each cell can have three values: |
| 16 | + |
| 17 | +- ❌: the operator is not overloaded because the backend does not support it |
| 18 | +- ✅: the operator is overloaded |
| 19 | +- NA: the operator does not exist |
| 20 | + |
| 21 | +!!! tip |
| 22 | + Check marks (✅) are clickable and link to the source code. |
| 23 | + |
| 24 | +```@setup overloads |
| 25 | +using ADTypes |
| 26 | +using DifferentiationInterface |
| 27 | +using DifferentiationInterface: backend_string, mutation_support, MutationSupported |
| 28 | +using Markdown: Markdown |
| 29 | +using Diffractor: Diffractor |
| 30 | +using Enzyme: Enzyme |
| 31 | +using FastDifferentiation: FastDifferentiation |
| 32 | +using FiniteDiff: FiniteDiff |
| 33 | +using FiniteDifferences: FiniteDifferences |
| 34 | +using ForwardDiff: ForwardDiff |
| 35 | +using PolyesterForwardDiff: PolyesterForwardDiff |
| 36 | +using ReverseDiff: ReverseDiff |
| 37 | +using Tapir: Tapir |
| 38 | +using Tracker: Tracker |
| 39 | +using Zygote: Zygote |
| 40 | +
|
| 41 | +function operators_and_types_f(backend::T) where {T<:AbstractADType} |
| 42 | + return ( |
| 43 | + # (op, types_op), |
| 44 | + # (op!, types_op!), |
| 45 | + # (val_and_op, types_val_and_op), |
| 46 | + # (val_and_op!, types_val_and_op!), |
| 47 | + ( |
| 48 | + (:derivative, (Any, T, Any, Any)), |
| 49 | + (:derivative!, (Any, Any, T, Any, Any)), |
| 50 | + (:value_and_derivative, (Any, T, Any, Any)), |
| 51 | + (:value_and_derivative!, (Any, Any, T, Any, Any)), |
| 52 | + ), |
| 53 | + ( |
| 54 | + (:gradient, (Any, T, Any, Any)), |
| 55 | + (:gradient!, (Any, Any, T, Any, Any)), |
| 56 | + (:value_and_gradient, (Any, T, Any, Any)), |
| 57 | + (:value_and_gradient!, (Any, Any, T, Any, Any)), |
| 58 | + ), |
| 59 | + ( |
| 60 | + (:jacobian, (Any, T, Any, Any)), |
| 61 | + (:jacobian!, (Any, Any, T, Any, Any)), |
| 62 | + (:value_and_jacobian, (Any, T, Any, Any)), |
| 63 | + (:value_and_jacobian!, (Any, Any, T, Any, Any)), |
| 64 | + ), |
| 65 | + ( |
| 66 | + (:hessian, (Any, T, Any, Any)), |
| 67 | + (:hessian!, (Any, Any, T, Any, Any)), |
| 68 | + (nothing, nothing), |
| 69 | + (nothing, nothing), |
| 70 | + ), |
| 71 | + ( |
| 72 | + (:hvp, (Any, T, Any, Any, Any)), |
| 73 | + (:hvp!, (Any, Any, T, Any, Any, Any)), |
| 74 | + (nothing, nothing), |
| 75 | + (nothing, nothing), |
| 76 | + ), |
| 77 | + ( |
| 78 | + (:pullback, (Any, T, Any, Any, Any)), |
| 79 | + (:pullback!, (Any, Any, T, Any, Any, Any)), |
| 80 | + (:value_and_pullback, (Any, T, Any, Any, Any)), |
| 81 | + (:value_and_pullback!, (Any, Any, T, Any, Any, Any)), |
| 82 | + ), |
| 83 | + ( |
| 84 | + (:pushforward, (Any, T, Any, Any, Any)), |
| 85 | + (:pushforward!, (Any, Any, T, Any, Any, Any)), |
| 86 | + (:value_and_pushforward, (Any, T, Any, Any, Any)), |
| 87 | + (:value_and_pushforward!, (Any, Any, T, Any, Any, Any)), |
| 88 | + ), |
| 89 | + ) |
| 90 | +end |
| 91 | +function operators_and_types_f!(backend::T) where {T<:AbstractADType} |
| 92 | + return ( |
| 93 | + ( |
| 94 | + (:derivative, (Any, Any, T, Any, Any)), |
| 95 | + (:derivative!, (Any, Any, Any, T, Any, Any)), |
| 96 | + (:value_and_derivative, (Any, Any, T, Any, Any)), |
| 97 | + (:value_and_derivative!, (Any, Any, Any, T, Any, Any)), |
| 98 | + ), |
| 99 | + ( |
| 100 | + (:jacobian, (Any, Any, T, Any, Any)), |
| 101 | + (:jacobian!, (Any, Any, Any, T, Any, Any)), |
| 102 | + (:value_and_jacobian, (Any, Any, T, Any, Any)), |
| 103 | + (:value_and_jacobian!, (Any, Any, Any, T, Any, Any)), |
| 104 | + ), |
| 105 | + ( |
| 106 | + (:pullback, (Any, Any, T, Any, Any, Any)), |
| 107 | + (:pullback!, (Any, Any, Any, T, Any, Any, Any)), |
| 108 | + (:value_and_pullback, (Any, Any, T, Any, Any, Any)), |
| 109 | + (:value_and_pullback!, (Any, Any, Any, T, Any, Any, Any)), |
| 110 | + ), |
| 111 | + ( |
| 112 | + (:pushforward, (Any, Any, T, Any, Any, Any)), |
| 113 | + (:pushforward!, (Any, Any, Any, T, Any, Any, Any)), |
| 114 | + (:value_and_pushforward, (Any, Any, T, Any, Any, Any)), |
| 115 | + (:value_and_pushforward!, (Any, Any, Any, T, Any, Any, Any)), |
| 116 | + ), |
| 117 | + ) |
| 118 | +end |
| 119 | +
|
| 120 | +function method_overloaded(operator::Symbol, argtypes, ext::Module) |
| 121 | + f = @eval DifferentiationInterface.$operator |
| 122 | + ms = methods(f, argtypes, ext) |
| 123 | +
|
| 124 | + n = length(ms) |
| 125 | + n == 0 && return "❌" |
| 126 | + n == 1 && return "[✅]($(Base.url(only(ms))))" |
| 127 | + return "[✅]($(Base.url(first(ms))))" # Optional TODO: return all URLs? |
| 128 | +end |
| 129 | +
|
| 130 | +function print_overload_table(io::IO, operators_and_types, ext::Module) |
| 131 | + println(io, "| Operator | `op` | `op!` | `value_and_op` | `value_and_op!` |") |
| 132 | + println(io, "|:---------|:----:|:-----:|:--------------:|:---------------:|") |
| 133 | + for operator_variants in operators_and_types |
| 134 | + opname = first(first(operator_variants)) |
| 135 | + print(io, "| `$opname` |") |
| 136 | + for (op, type_signature) in operator_variants |
| 137 | + if isnothing(op) |
| 138 | + print(io, "NA") |
| 139 | + else |
| 140 | + print(io, method_overloaded(op, type_signature, ext)) |
| 141 | + end |
| 142 | + print(io, '|') |
| 143 | + end |
| 144 | + println(io) |
| 145 | + end |
| 146 | +end |
| 147 | +
|
| 148 | +function print_overloads(backend, ext::Symbol) |
| 149 | + io = IOBuffer() |
| 150 | + ext = Base.get_extension(DifferentiationInterface, ext) |
| 151 | +
|
| 152 | + println(io, "#### One-argument functions `y = f(x)`") |
| 153 | + println(io) |
| 154 | + print_overload_table(io, operators_and_types_f(backend), ext) |
| 155 | +
|
| 156 | + println(io, "#### Two-argument functions `f!(y, x)`") |
| 157 | + println(io) |
| 158 | + if mutation_support(backend) == MutationSupported() |
| 159 | + print_overload_table(io, operators_and_types_f!(backend), ext) |
| 160 | + else |
| 161 | + println(io, "Backend doesn't support mutating functions.") |
| 162 | + end |
| 163 | +
|
| 164 | + return Markdown.parse(String(take!(io))) |
| 165 | +end |
| 166 | +``` |
| 167 | + |
| 168 | +## Diffractor (forward/reverse) |
| 169 | +```@example overloads |
| 170 | +print_overloads(AutoDiffractor(), :DifferentiationInterfaceDiffractorExt) # hide |
| 171 | +``` |
| 172 | + |
| 173 | +## Enzyme (forward) |
| 174 | +```@example overloads |
| 175 | +print_overloads(AutoEnzyme(; mode=Enzyme.Forward), :DifferentiationInterfaceEnzymeExt) # hide |
| 176 | +``` |
| 177 | + |
| 178 | +## Enzyme (reverse) |
| 179 | +```@example overloads |
| 180 | +print_overloads(AutoEnzyme(; mode=Enzyme.Reverse), :DifferentiationInterfaceEnzymeExt) # hide |
| 181 | +``` |
| 182 | + |
| 183 | +## FastDifferentiation (symbolic) |
| 184 | +```@example overloads |
| 185 | +print_overloads(AutoFastDifferentiation(), :DifferentiationInterfaceFastDifferentiationExt) # hide |
| 186 | +``` |
| 187 | + |
| 188 | +## FiniteDiff (forward) |
| 189 | +```@example overloads |
| 190 | +print_overloads(AutoFiniteDiff(), :DifferentiationInterfaceFiniteDiffExt) # hide |
| 191 | +``` |
| 192 | + |
| 193 | +## FiniteDifferences (forward) |
| 194 | +```@example overloads |
| 195 | +print_overloads(AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)), :DifferentiationInterfaceFiniteDifferencesExt) # hide |
| 196 | +``` |
| 197 | + |
| 198 | +## ForwardDiff (forward) |
| 199 | +```@example overloads |
| 200 | +print_overloads(AutoForwardDiff(), :DifferentiationInterfaceForwardDiffExt) # hide |
| 201 | +``` |
| 202 | + |
| 203 | +## PolyesterForwardDiff (forward) |
| 204 | +```@example overloads |
| 205 | +print_overloads(AutoPolyesterForwardDiff(; chunksize=1), :DifferentiationInterfacePolyesterForwardDiffExt) # hide |
| 206 | +``` |
| 207 | + |
| 208 | +## ReverseDiff (reverse) |
| 209 | +```@example overloads |
| 210 | +print_overloads(AutoReverseDiff(), :DifferentiationInterfaceReverseDiffExt) # hide |
| 211 | +``` |
| 212 | + |
| 213 | +## Tapir (reverse) |
| 214 | +```@example overloads |
| 215 | +print_overloads(AutoTapir(), :DifferentiationInterfaceTapirExt) # hide |
| 216 | +``` |
| 217 | + |
| 218 | +## Tracker (reverse) |
| 219 | +```@example overloads |
| 220 | +print_overloads(AutoTracker(), :DifferentiationInterfaceTrackerExt) # hide |
| 221 | +``` |
| 222 | + |
| 223 | +## Zygote (reverse) |
| 224 | +```@example overloads |
| 225 | +print_overloads(AutoZygote(), :DifferentiationInterfaceZygoteExt) # hide |
| 226 | +``` |
0 commit comments