Skip to content

Commit 8ca970e

Browse files
committed
Better docs
1 parent d3af887 commit 8ca970e

9 files changed

Lines changed: 138 additions & 33 deletions

File tree

DifferentiationInterface/docs/make.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using Zygote: Zygote
1010

1111
links = InterLinks(
1212
"ADTypes" => "https://sciml.github.io/ADTypes.jl/stable/",
13+
"Reactant" => "https://enzymead.github.io/Reactant.jl/stable/",
1314
"SparseConnectivityTracer" => "https://adrianhill.de/SparseConnectivityTracer.jl/stable/",
1415
"SparseMatrixColorings" => "https://gdalle.github.io/SparseMatrixColorings.jl/stable/",
1516
"Symbolics" => "https://symbolics.juliasymbolics.org/stable/",

DifferentiationInterface/docs/src/api.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,3 +144,9 @@ DifferentiationInterface.AutoReverseFromPrimitive
144144
```@docs
145145
DifferentiationInterface.Prep
146146
```
147+
148+
### Reactant
149+
150+
```@docs
151+
DifferentiationInterface.to_reactant
152+
```

DifferentiationInterface/docs/src/explanation/backends.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,10 @@ If a GTPSA [`Descriptor`](https://bmad-sim.github.io/GTPSA.jl/stable/man/b_descr
169169

170170
Most operators fall back on `AutoForwardDiff`.
171171

172+
### Reactant
173+
174+
See the docstring for [`AutoReactant`](@ref).
175+
172176
### ReverseDiff
173177

174178
With `AutoReverseDiff(compile=false)`, preparation preallocates a [config](https://juliadiff.org/ReverseDiff.jl/dev/api/#The-AbstractConfig-API).

DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@ function DI.prepare_gradient_nokwarg(
1313
) where {F, C}
1414
_sig = DI.signature(f, rebackend, x; strict)
1515
backend = rebackend.mode
16-
xr = to_reac(x)
17-
gr = to_reac(similar(x))
18-
contextsr = map(to_reac, contexts)
16+
xr = x isa ConcreteRArray ? nothing : ConcreteRArray(x)
17+
gr = x isa ConcreteRArray ? nothing : ConcreteRArray(similar(x))
18+
contextsr = map(_to_reactant, contexts)
1919
compiled_gradient = @compile DI.gradient(f, backend, xr, contextsr...)
2020
compiled_gradient! = @compile DI.gradient!(f, gr, backend, xr, contextsr...)
2121
compiled_value_and_gradient = @compile DI.value_and_gradient(f, backend, xr, contextsr...)
@@ -36,10 +36,9 @@ function DI.gradient(
3636
) where {F, C}
3737
DI.check_prep(f, prep, rebackend, x)
3838
backend = rebackend.mode
39-
(; xr, compiled_gradient) = prep
40-
copyto!(xr, x)
41-
contextsr = map(to_reac, contexts)
42-
gr = compiled_gradient(f, backend, xr, contextsr...)
39+
xr = isnothing(prep.xr) ? x : copyto!(prep.xr, x)
40+
contextsr = map(_to_reactant, contexts)
41+
gr = prep.compiled_gradient(f, backend, xr, contextsr...)
4342
return gr
4443
end
4544

@@ -48,10 +47,9 @@ function DI.value_and_gradient(
4847
) where {F, C}
4948
DI.check_prep(f, prep, rebackend, x)
5049
backend = rebackend.mode
51-
(; xr, compiled_value_and_gradient) = prep
52-
copyto!(xr, x)
53-
contextsr = map(to_reac, contexts)
54-
yr, gr = compiled_value_and_gradient(f, backend, xr, contextsr...)
50+
xr = isnothing(prep.xr) ? x : copyto!(prep.xr, x)
51+
contextsr = map(_to_reactant, contexts)
52+
yr, gr = prep.compiled_value_and_gradient(f, backend, xr, contextsr...)
5553
return yr, gr
5654
end
5755

@@ -60,21 +58,21 @@ function DI.gradient!(
6058
) where {F, C}
6159
DI.check_prep(f, prep, rebackend, x)
6260
backend = rebackend.mode
63-
(; xr, gr, compiled_gradient!) = prep
64-
copyto!(xr, x)
65-
contextsr = map(to_reac, contexts)
66-
compiled_gradient!(f, gr, backend, xr, contextsr...)
67-
return copyto!(grad, gr)
61+
xr = isnothing(prep.xr) ? x : copyto!(prep.xr, x)
62+
gr = isnothing(prep.gr) ? grad : prep.gr
63+
contextsr = map(_to_reactant, contexts)
64+
prep.compiled_gradient!(f, gr, backend, xr, contextsr...)
65+
return isnothing(prep.gr) ? grad : copyto!(grad, gr)
6866
end
6967

7068
function DI.value_and_gradient!(
7169
f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C}
7270
) where {F, C}
7371
DI.check_prep(f, prep, rebackend, x)
7472
backend = rebackend.mode
75-
(; xr, gr, compiled_value_and_gradient!) = prep
76-
copyto!(xr, x)
77-
contextsr = map(to_reac, contexts)
78-
yr, gr = compiled_value_and_gradient!(f, gr, backend, xr, contextsr...)
79-
return yr, copyto!(grad, gr)
73+
xr = isnothing(prep.xr) ? x : copyto!(prep.xr, x)
74+
gr = isnothing(prep.gr) ? grad : prep.gr
75+
contextsr = map(_to_reactant, contexts)
76+
yr, gr = prep.compiled_value_and_gradient!(f, gr, backend, xr, contextsr...)
77+
return yr, isnothing(prep.gr) ? grad : copyto!(grad, gr)
8078
end
Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
to_reac(x::AbstractArray) = to_rarray(x)
2-
to_reac(x::ConcreteRArray) = x
3-
to_reac(x::Number) = ConcreteRNumber(x)
4-
to_reac(x::ConcreteRNumber) = x
5-
6-
to_reac(c::DI.Constant) = DI.Constant(to_reac(DI.unwrap(c)))
7-
to_reac(c::DI.Cache) = DI.Cache(to_reac(DI.unwrap(c)))
1+
_to_reactant(x) = DI.to_reactant(x)
2+
_to_reactant(c::DI.Constant) = DI.Constant(_to_reactant(DI.unwrap(c)))
3+
_to_reactant(c::DI.Cache) = DI.Cache(_to_reactant(DI.unwrap(c)))

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ include("misc/sparsity_detector.jl")
6969
include("misc/simple_finite_diff.jl")
7070
include("misc/zero_backends.jl")
7171
include("misc/overloading.jl")
72+
include("misc/reactant.jl")
7273

7374
## Exported
7475

@@ -132,6 +133,7 @@ export AutoSparse
132133
@public inner, outer
133134
@public AutoForwardFromPrimitive, AutoReverseFromPrimitive
134135
@public Prep
136+
@public to_reactant
135137

136138
include("init.jl")
137139

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""
2+
!!! tip "DI-specific information"
3+
This part of the docstring is related to the use of `AutoReactant` inside [DifferentiationInterface.jl](https://github.com/JuliaDiff/DifferentiationInterface.jl), or DI for short.
4+
Reactant's tutorial on [partial evaluation](https://enzymead.github.io/Reactant.jl/stable/tutorials/partial-evaluation) is useful reading to understand what follows.
5+
6+
The `AutoReactant` backend inside DI imposes the following restrictions / assumptions:
7+
8+
- The only supported operator (at the moment) is `DI.gradient` (along with its variants).
9+
- The input `x` must be an `AbstractArray` such that `Reactant.ConcreteRArray(x)` is well-defined.
10+
- By default, contexts such as `DI.Constant` and `DI.Cache` will be partially evaluated inside the compiled differentiation operator at preparation time. This means that the context value provided at preparation will be reused at every subsequent execution, while the context value provided at execution will be ignored. In particular, `DI.Cache` contexts will usually error and `DI.Constant` contexts will be frozen to one value.
11+
12+
To disable partial evaluation and enforce tracing of contexts instead, first wrap them into types that _you own_.
13+
Then, overload [`DifferentiationInterface.to_reactant`](@ref) on these types to perform tracing in the way you see fit, for instance with `Reactant.to_rarray`.
14+
Every value you choose not to trace will still be partially evaluated at preparation time.
15+
16+
# Example
17+
18+
```jldoctest
19+
using DifferentiationInterface
20+
import DifferentiationInterface as DI
21+
import Reactant
22+
23+
struct MyArgument{T1 <: Number, T2 <: AbstractArray}
24+
u::T1
25+
v::T2
26+
end
27+
28+
f(x, a::MyArgument) = a.u * sum(a.v .* x .^ 2)
29+
30+
DI.to_reactant(a::MyArgument) = Reactant.to_rarray(a; track_numbers = false)
31+
32+
# preparation time
33+
x0 = zeros(2)
34+
a0 = MyArgument(1.0, [2.0, 3.0])
35+
36+
# execution time
37+
x = [4.0, 5.0]
38+
a = MyArgument(6.0, [7.0, 8.0])
39+
40+
backend = AutoReactant()
41+
prep = prepare_gradient(f, backend, x0, Constant(a0));
42+
43+
g = gradient(f, prep, backend, x, Constant(a))
44+
g ≈ a0.u * 2 * (a.v .* x) # a0.u is partially evaluated, a0.v is traced
45+
46+
# output
47+
48+
true
49+
```
50+
"""
51+
AutoReactant
52+
53+
"""
54+
to_reactant(a)
55+
56+
Convert an argument `a` to an object `ar` containing the same values, where all the fields and subfields that can contain active (differentiated) data have been translated to [Reactant.jl](https://github.com/EnzymeAD/Reactant.jl) types such as [`ConcreteRArray`](@extref Reactant.ConcreteRArray) or [`ConcreteRNumber`](@extref Reactant.ConcreteRNumber).
57+
58+
!!! danger
59+
DifferentiationInterface.jl implements this function as the identity, on purpose.
60+
It should not be overloaded on base types, but only on types that you own, to modify the default behavior of `AutoReactant`.
61+
62+
# Example
63+
64+
```jldoctest
65+
import DifferentiationInterface as DI
66+
import Reactant
67+
68+
struct MyArgument{T1 <: Number, T2 <: AbstractArray}
69+
u::T1
70+
v::T2
71+
end
72+
73+
DI.to_reactant(a::MyArgument) = Reactant.to_rarray(a; track_numbers = false)
74+
75+
a = MyArgument(1.0, [2.0, 3.0])
76+
ar = DI.to_reactant(a)
77+
ar isa MyArgument{Float64, <:Reactant.ConcreteRArray}
78+
79+
# output
80+
81+
true
82+
```
83+
"""
84+
to_reactant(x) = x
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
[deps]
2+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
3+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
4+
DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
5+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
6+
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
7+
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
8+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
9+
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
10+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

DifferentiationInterface/test/Back/Reactant/test.jl

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,24 @@
1-
using Pkg
2-
Pkg.add(url = "https://github.com/EnzymeAD/Enzyme.jl")
3-
Pkg.add("Reactant")
1+
include("../../testutils.jl")
42

53
using DifferentiationInterface
4+
import DifferentiationInterface as DI
65
using DifferentiationInterfaceTest
7-
using Reactant
6+
import DifferentiationInterfaceTest as DIT
7+
import Enzyme, Reactant
88
using Test
99

1010
backend = AutoReactant()
1111

1212
@test check_available(backend)
1313
@test check_inplace(backend)
1414

15+
scen1 = DIT.Scenario(
16+
17+
)
18+
1519
test_differentiation(
1620
backend, DifferentiationInterfaceTest.default_scenarios(;
17-
include_constantified = true, include_cachified = false
21+
include_constantified = false, include_cachified = false
1822
);
1923
excluded = vcat(SECOND_ORDER, :jacobian, :derivative, :pushforward, :pullback),
2024
logging = false

0 commit comments

Comments
 (0)