Merged
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
@@ -12,4 +12,4 @@ Apart from the conditions above, this repository follows the [ColPrac](https://g
Its code is formatted using [Runic.jl](https://github.com/fredrikekre/Runic.jl).
As part of continuous integration, a set of formal tests is run using [pre-commit](https://pre-commit.com/).
We invite you to install pre-commit so that these checks are performed locally before you open or update a pull request.
You can refer to the [dev guide](https://juliadiff.org/DifferentiationInterface.jl/DifferentiationInterface/dev/dev_guide/) for details on the package structure and the testing pipeline.
You can refer to the relevant page of the development documentation for details on the package structure and the testing pipeline.
5 changes: 4 additions & 1 deletion DifferentiationInterface/docs/make.jl
@@ -35,7 +35,10 @@ makedocs(;
],
"FAQ" => ["faq/limitations.md", "faq/differentiability.md"],
"api.md",
"dev_guide.md",
"Development" => [
"dev/internals.md",
"dev/contributing.md",
],
],
plugins = [links],
)
10 changes: 3 additions & 7 deletions DifferentiationInterface/docs/src/api.md
@@ -139,12 +139,8 @@ DifferentiationInterface.AutoForwardFromPrimitive
DifferentiationInterface.AutoReverseFromPrimitive
```

## Internals
### Preparation type

The following is not part of the public API.

```@autodocs
Modules = [DifferentiationInterface]
Public = false
Filter = t -> !(Symbol(t) in [:outer, :inner])
```@docs
DifferentiationInterface.Prep
```
@@ -1,4 +1,4 @@
# Dev guide
# Contributing

This page is important reading if you want to contribute to DifferentiationInterface.jl.
It is not part of the public API and the content below may become outdated, in which case you should refer to the source code as the ground truth.
@@ -7,26 +7,27 @@ It is not part of the public API and the content below may become outdated, in w

The package is structured around 8 [operators](@ref Operators):

- [`derivative`](@ref)
- [`second_derivative`](@ref)
- [`gradient`](@ref)
- [`jacobian`](@ref)
- [`hessian`](@ref)
- [`pushforward`](@ref)
- [`pullback`](@ref)
- [`hvp`](@ref)
- [`derivative`](@ref)
- [`second_derivative`](@ref)
- [`gradient`](@ref)
- [`jacobian`](@ref)
- [`hessian`](@ref)
- [`pushforward`](@ref)
- [`pullback`](@ref)
- [`hvp`](@ref)

Most operators have 4 variants, which at first order look like this: `operator`, `operator!`, `value_and_operator`, `value_and_operator!`.
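The four variants can be sketched as follows with `gradient` as the example operator; this is a hedged usage sketch (the function `f`, the input `x` and the choice of `AutoForwardDiff` are illustrative, not prescribed by the package):

```julia
# Illustrative sketch of the four first-order variants for `gradient`.
using DifferentiationInterface
import ForwardDiff  # any AD backend package could stand in here

f(x) = sum(abs2, x)            # example scalar-valued function
backend = AutoForwardDiff()
x = [1.0, 2.0, 3.0]
grad = similar(x)              # preallocated storage for in-place variants

g = gradient(f, backend, x)                      # out-of-place
gradient!(f, grad, backend, x)                   # in-place, overwrites `grad`
y, g2 = value_and_gradient(f, backend, x)        # also returns f(x)
y2, _ = value_and_gradient!(f, grad, backend, x) # in-place and returns f(x)
```

The same naming pattern (`operator`, `operator!`, `value_and_operator`, `value_and_operator!`) repeats for the other first-order operators.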

## New operator

To implement a new operator for an existing backend, you need to write 5 methods: 1 for [preparation](@ref Preparation) and 4 corresponding to the variants of the operator (see above).
For first-order operators, you may also want to support [in-place functions](@ref "Mutation and signatures"), which requires another 5 methods (defined on `f!` instead of `f`).
For some operators, you will also need to support [in-place functions](@ref "Mutation and signatures"), which requires another 5 methods (defined on `f!` instead of `f`).

The method `prepare_operator_nokwarg` must output a `prep` object of the correct type.
For instance, `prepare_gradient(strict, f, backend, x)` must return a [`DifferentiationInterface.GradientPrep`](@ref).
Assuming you don't need any preparation for said operator, you can use the trivial prep that are already defined, like `DifferentiationInterface.NoGradientPrep{SIG}`.
For instance, `prepare_gradient_nokwarg(strict, f, backend, x)` must return a [`DifferentiationInterface.GradientPrep`](@ref).
Assuming you don't need any preparation for said operator, you can use the trivial preparation types that are already defined, like `DifferentiationInterface.NoGradientPrep{SIG}`.
Otherwise, define a custom struct like `MyGradientPrep{SIG} <: DifferentiationInterface.GradientPrep{SIG}` and put the necessary storage in there.
Take inspiration from existing operators to see how to enforce the signature `SIG`.
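A custom preparation struct might be sketched like this; the struct name and its field are hypothetical, only the supertype `DifferentiationInterface.GradientPrep` comes from the package:

```julia
# Hypothetical preparation struct carrying a reusable workspace.
import DifferentiationInterface as DI

struct MyGradientPrep{SIG, C} <: DI.GradientPrep{SIG}
    cache::C  # whatever storage the backend needs between calls
end
```

The `prepare_gradient_nokwarg` method would then build and return a `MyGradientPrep`, and the four `gradient` variants would unpack `prep.cache` instead of reallocating on every call.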

## New backend

@@ -36,18 +37,18 @@ Your AD package needs to be registered first.
### Core code

In the main package, you should define a new struct `SuperDiffBackend` which subtypes [`ADTypes.AbstractADType`](@extref ADTypes), and endow it with the fields you need to parametrize your differentiation routines.
You also have to define [`ADTypes.mode`](@extref) and [`DifferentiationInterface.inplace_support`](@ref) on `SuperDiffBackend`.
You also have to define [`ADTypes.mode`](@extref), [`DifferentiationInterface.check_available`](@ref) and [`DifferentiationInterface.inplace_support`](@ref) on `SuperDiffBackend`.

!!! info

In the end, this backend struct will need to be contributed to [ADTypes.jl](https://github.com/SciML/ADTypes.jl).
However, putting it in the DifferentiationInterface.jl PR is a good first step for debugging.

In a [package extension](https://pkgdocs.julialang.org/v1/creating-packages/#Conditional-loading-of-code-in-packages-(Extensions)) named `DifferentiationInterfaceSuperDiffExt`, you need to implement at least [`pushforward`](@ref) or [`pullback`](@ref) (and their variants).
The exact requirements depend on the differentiation mode you chose:

| backend mode | pushforward necessary | pullback necessary |
|:------------------------------------------------- |:--------------------- |:------------------ |
| :------------------------------------------------ | :-------------------- | :----------------- |
| [`ADTypes.ForwardMode`](@extref ADTypes) | yes | no |
| [`ADTypes.ReverseMode`](@extref ADTypes) | no | yes |
| [`ADTypes.ForwardOrReverseMode`](@extref ADTypes) | yes | yes |
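The core definitions for a forward-mode backend might be sketched as follows; `SuperDiffBackend` and its `chunk_size` field are the illustrative names used above, and the trait return values are assumptions to be checked against the actual ADTypes.jl and DifferentiationInterface.jl APIs:

```julia
# Hedged sketch of the main-package definitions for a new forward-mode backend.
import ADTypes
import DifferentiationInterface as DI

struct SuperDiffBackend <: ADTypes.AbstractADType
    chunk_size::Int  # example field parametrizing the differentiation routines
end

ADTypes.mode(::SuperDiffBackend) = ADTypes.ForwardMode()
DI.check_available(::SuperDiffBackend) = true
DI.inplace_support(::SuperDiffBackend) = DI.InPlaceSupported()
```

With `ForwardMode`, the table above says only [`pushforward`](@ref) (and its variants) must then be implemented in the package extension.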
9 changes: 9 additions & 0 deletions DifferentiationInterface/docs/src/dev/internals.md
@@ -0,0 +1,9 @@
# Internals

The following names are not part of the public API.

```@autodocs
Modules = [DifferentiationInterface]
Public = false
Filter = t -> !(Symbol(t) in [:outer, :inner, :Prep, :AutoForwardFromPrimitive, :AutoReverseFromPrimitive])
```
1 change: 1 addition & 0 deletions DifferentiationInterface/src/DifferentiationInterface.jl
@@ -129,6 +129,7 @@ export AutoSparse

@public inner, outer
@public AutoForwardFromPrimitive, AutoReverseFromPrimitive
@public Prep

include("init.jl")

8 changes: 8 additions & 0 deletions DifferentiationInterface/src/utils/prep.jl
@@ -1,3 +1,11 @@
"""
Prep

Abstract supertype for all preparation results (outputs of `prepare_operator` functions).

!!! warning
The public API does not make any guarantees about the type parameters or field layout of `Prep`; the only guarantee is that the type exists.
"""
abstract type Prep{SIG} end

"""
14 changes: 14 additions & 0 deletions DifferentiationInterface/test/Core/Internals/prep.jl
@@ -0,0 +1,14 @@
using DifferentiationInterface: Prep
using InteractiveUtils: subtypes
using Test

@test subtypes(Prep) == [
DifferentiationInterface.DerivativePrep,
DifferentiationInterface.GradientPrep,
DifferentiationInterface.HVPPrep,
DifferentiationInterface.HessianPrep,
DifferentiationInterface.JacobianPrep,
DifferentiationInterface.PullbackPrep,
DifferentiationInterface.PushforwardPrep,
DifferentiationInterface.SecondDerivativePrep,
]
2 changes: 2 additions & 0 deletions DifferentiationInterface/test/Project.toml
@@ -5,6 +5,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -23,6 +24,7 @@ ComponentArrays = "0.15.27"
DataFrames = "1.7.0"
Dates = "1"
ExplicitImports = "1.10.1"
InteractiveUtils = "1"
JET = "0.9,0.10"
JLArrays = "0.2.0"
Pkg = "1"