|
| 1 | +# Differentiability |
| 2 | + |
| 3 | +DifferentiationInterface.jl and its sibling package [DifferentiationInterfaceTest.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl/tree/main/DifferentiationInterfaceTest) allow you to try out differentiation of existing code with a variety of AD backends. |
| 4 | +However, they will not help you _write_ differentiable code in the first place. |
| 5 | +To make your functions compatible with several backends, you need to mind the restrictions imposed by each one. |
| 6 | + |
| 7 | +The list of backends available at [juliadiff.org](https://juliadiff.org/) is split into two main families: operator overloading and source transformation. |
| 8 | +Writing differentiable code requires a specific approach in each paradigm: |
| 9 | + |
| 10 | +- For operator overloading, ensure type-genericity. |
| 11 | +- For source transformation, rely on existing rules or write your own. |
| 12 | + |
| 13 | +!!! tip |
| 14 | + Depending on your intended use case, you may not need to ensure compatibility with every single backend. |
| 15 | + In particular, some applications strongly suggest a specific "mode" of AD (forward or reverse), in which case backends limited to the other mode are mostly irrelevant. |
| 16 | + |
| 17 | +In what follows, we do not discuss AD with finite differences ([FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl) and [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl)) because those packages will work as long as your function itself can run, which is obviously a prerequisite. |
| 18 | + |
| 19 | +## Operator overloading |
| 20 | + |
| 21 | +### ForwardDiff |
| 22 | + |
| 23 | +One of the most common backends in the ecosystem is [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl). |
| 24 | +It performs AD at a scalar level by replacing plain numbers with [`Dual` numbers](https://juliadiff.org/ForwardDiff.jl/stable/dev/how_it_works/), which carry derivative information. |
| 25 | +As explained in the [limitations of ForwardDiff](https://juliadiff.org/ForwardDiff.jl/stable/user/limitations/), this will only work if the differentiated code does not restrict number types too much. |
| 26 | +Otherwise, you may encounter errors like this one: |
| 27 | + |
| 28 | +```julia |
| 29 | +MethodError: no method matching Float64(::ForwardDiff.Dual{...}) |
| 30 | +``` |
| 31 | + |
| 32 | +To prevent them, here are a few things to look out for: |
| 33 | + |
| 34 | +- Avoid functions with overly specific type annotations. |
| 35 | + |
| 36 | +```julia |
| 37 | +f(x::Vector{Float64}) = ... # bad |
| 38 | +f(x::AbstractVector{<:Real}) = ... # good |
| 39 | +``` |
| 40 | + |
| 41 | +- When creating new containers or buffers, adapt to the input number type if necessary. |
| 42 | + |
| 43 | +```julia |
| 44 | +tmp = zeros(length(x)) # bad |
| 45 | +tmp = zeros(eltype(x), length(x)) # good |
| 46 | +tmp = similar(x) # best when possible (uninitialized: write every entry before reading it) |
| 47 | +``` |
| 48 | + |
| 49 | +In some situations, manually writing overloads for `x::Dual` or `x::AbstractArray{<:Dual}` can be necessary. |
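
As a rough illustration of the two guidelines above, here is a minimal sketch comparing an overly strict function with a generic one. The names `sumsq_strict` and `sumsq_generic` are hypothetical and only serve the example; the only ForwardDiff.jl call used is `ForwardDiff.gradient`.

```julia
using ForwardDiff

# Hypothetical example functions, not taken from any package.

# Too strict: the annotation rejects vectors of Dual numbers,
# and `zeros` would allocate a Float64 buffer in any case.
function sumsq_strict(x::Vector{Float64})
    tmp = zeros(length(x))
    for i in eachindex(x)
        tmp[i] = x[i]^2
    end
    return sum(tmp)
end

# Generic: accepts any real element type and allocates a matching buffer.
function sumsq_generic(x::AbstractVector{<:Real})
    tmp = similar(x)        # same element type as x, uninitialized
    for i in eachindex(x)
        tmp[i] = x[i]^2     # every entry is written before being read
    end
    return sum(tmp)
end

x = [1.0, 2.0, 3.0]

# The gradient of sum(x .^ 2) is 2x.
g = ForwardDiff.gradient(sumsq_generic, x)

# The strict version has no method for a vector of Dual numbers.
strict_fails = try
    ForwardDiff.gradient(sumsq_strict, x)
    false
catch err
    err isa MethodError
end
```

Both functions return the same value on a `Vector{Float64}`, but only the generic one can be called on the vector of `Dual` numbers that ForwardDiff.jl passes in.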
| 50 | + |
| 51 | +### ReverseDiff |
| 52 | + |
| 53 | +[ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl) relies on operator overloading for scalars, but also for arrays. |
| 54 | +The relevant types are called `TrackedReal` and `TrackedArray`. Their [limitations](https://juliadiff.org/ReverseDiff.jl/stable/limits/) are very similar to those of ForwardDiff.jl's `Dual`, and they cause similar errors. |
| 55 | + |
| 56 | +### Symbolic backends |
| 57 | + |
| 58 | +[Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl) and [FastDifferentiation.jl](https://github.com/brianguenter/FastDifferentiation.jl) are also based on operator overloading. |
| 59 | +However, their respective number types are a bit different because they represent symbolic variables instead of numerical values. |
| 60 | +The operator overloading aims at reconstructing a symbolic representation of the function (an equation, more or less), which means certain language constructs (for instance, an `if` statement whose condition depends on the value of the input) will not be tolerated even though ForwardDiff.jl or ReverseDiff.jl could handle them. |
| 61 | + |
| 62 | +## Source transformation |
| 63 | + |
| 64 | +### Zygote |
| 65 | + |
| 66 | +[Zygote.jl](https://github.com/FluxML/Zygote.jl) can differentiate a lot of Julia code, but it does have some major [limitations](https://fluxml.ai/Zygote.jl/stable/limitations/). |
| 67 | +The most frequently encountered is the lack of support for mutation: if you try to modify the contents of an array during differentiation, you will get an error like |
| 68 | + |
| 69 | +```julia |
| 70 | +ERROR: Mutating arrays is not supported |
| 71 | +``` |
| 72 | + |
| 73 | +Mutations and some other language constructs (exceptions, foreign calls) will make a function incompatible with Zygote. |
| 74 | +In such cases, the proper workaround is to define a reverse rule (`rrule`) for that function using [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl). |
| 75 | +You can find a [pedagogical example](https://juliadiff.org/ChainRulesCore.jl/stable/rule_author/example.html) for rule-writing in the documentation of ChainRulesCore.jl. |
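
For a rough idea of what such a rule looks like, here is a minimal sketch for a function that mutates an internal buffer. The function `sumsq` and its pullback are hypothetical names chosen for this example; `rrule`, `NoTangent` and `unthunk` are exported by ChainRulesCore.jl. The derivative is written by hand, under the assumption that `sumsq(x)` computes the sum of squares of `x`.

```julia
using ChainRulesCore

# Hypothetical function: the writes into `tmp` are array mutations,
# which Zygote cannot differentiate without a rule.
function sumsq(x::AbstractVector{<:Real})
    tmp = similar(x)
    for i in eachindex(x)
        tmp[i] = x[i]^2
    end
    return sum(tmp)
end

# Reverse rule: return the primal output together with a pullback.
function ChainRulesCore.rrule(::typeof(sumsq), x::AbstractVector{<:Real})
    y = sumsq(x)
    function sumsq_pullback(dy)
        # One cotangent per argument, starting with the function itself.
        # `sumsq` has no fields, hence NoTangent(); the gradient with respect to x is 2x.
        return NoTangent(), 2 .* x .* unthunk(dy)
    end
    return y, sumsq_pullback
end

# The rule can be checked directly, without loading an AD backend.
y, pullback = rrule(sumsq, [1.0, 2.0, 3.0])
_, dx = pullback(1.0)
```

Once this method is defined, backends that consume ChainRulesCore.jl rules, such as Zygote.jl, should use it instead of tracing through the body of `sumsq`.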
| 76 | + |
| 77 | +### Enzyme |
| 78 | + |
| 79 | +By targeting a lower-level code representation than Zygote.jl, [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) is able to differentiate a much wider set of functions. |
| 80 | +The [FAQ](https://enzymead.github.io/Enzyme.jl/stable/faq/) gives some details on the breadth of coverage, which should be enough for a lot of use cases. |
| 81 | + |
| 82 | +Enzyme.jl also has an extensible [rule system](https://enzymead.github.io/Enzyme.jl/stable/generated/custom_rule/) which you can use to circumvent differentiation errors. |
| 83 | +Note that writing rules for Enzyme.jl is very different from writing rules for ChainRulesCore.jl, due to the presence of input activity [annotations](https://enzymead.github.io/Enzyme.jl/stable/api/#EnzymeCore.Annotation). |
| 84 | + |
| 85 | +### Mooncake |
| 86 | + |
| 87 | +[Mooncake.jl](https://github.com/compintell/Mooncake.jl) is a recent package which also handles a large subset of all Julia programs out of the box. |
| 88 | + |
| 89 | +Its [rule system](https://compintell.github.io/Mooncake.jl/dev/understanding_mooncake/rule_system/) is less expressive than that of Enzyme.jl, which might make it easier to start with. |
| 90 | + |
| 91 | +## A rule mayhem? |
| 92 | + |
| 93 | +To summarize, here are the main rule systems which coexist at the moment: |
| 94 | + |
| 95 | +- `Dual` numbers in ForwardDiff.jl |
| 96 | +- ChainRulesCore.jl |
| 97 | +- Enzyme.jl |
| 98 | +- Mooncake.jl |
| 99 | + |
| 100 | +### Rule translation |
| 101 | + |
| 102 | +This split situation is unfortunate, but AD packages are so complex that making a cross-backend rule system is a very ambitious endeavor. |
| 103 | +ChainRulesCore.jl is the closest thing we have to a standard, but it does not handle mutation. |
| 104 | +As a result, Enzyme.jl and Mooncake.jl both rolled out their own designs, which are not mutually compatible. |
| 105 | +There are, however, translation utilities: |
| 106 | + |
| 107 | +- from ChainRulesCore.jl to ForwardDiff.jl with [ForwardDiffChainRules.jl](https://github.com/ThummeTo/ForwardDiffChainRules.jl) |
| 108 | +- from ChainRulesCore.jl to Enzyme.jl with [`Enzyme.@import_rrule`](https://enzymead.github.io/Enzyme.jl/stable/api/#Enzyme.@import_rrule-Tuple) |
| 109 | +- from ChainRulesCore.jl to Mooncake.jl with [`Mooncake.@from_rrule`](https://compintell.github.io/Mooncake.jl/dev/utilities/tools_for_rules/#Using-ChainRules.jl) |
| 110 | + |
| 111 | +### Backend switch |
| 112 | + |
| 113 | +Also note the existence of [`DifferentiationInterface.DifferentiateWith`](@ref), which allows the user to wrap a function that should be differentiated with a specific backend. |
| 114 | +Right now, it only comes with differentiation rules for ForwardDiff.jl and ChainRulesCore.jl, but PRs are welcome to define Enzyme.jl and Mooncake.jl rules for this object. |
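
A minimal sketch of such a backend switch is given below, assuming that ForwardDiff.jl and Zygote.jl are both installed. The functions `inner` and `outer` are hypothetical: `inner` mutates a buffer, which Zygote.jl cannot handle on its own, so it is wrapped to be differentiated with ForwardDiff.jl instead.

```julia
using DifferentiationInterface
import ForwardDiff, Zygote

# Hypothetical inner function: type-generic (fine for ForwardDiff)
# but mutating (not supported by Zygote).
function inner(x::AbstractVector{<:Real})
    tmp = similar(x)
    for i in eachindex(x)
        tmp[i] = x[i]^2
    end
    return sum(tmp)
end

# Whenever a compatible outer backend meets this wrapper,
# it delegates the differentiation of `inner` to ForwardDiff.
inner_fd = DifferentiateWith(inner, AutoForwardDiff())

# Hypothetical outer function, differentiated with Zygote.
outer(x) = 3 * inner_fd(x)

# Since outer(x) = 3 * sum(x .^ 2), the gradient is 6x.
grad = gradient(outer, AutoZygote(), [1.0, 2.0, 3.0])
```

The outer backend only sees the wrapper's rule, so the mutation inside `inner` never reaches Zygote.jl.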