Skip to content

Commit c7e7598

Browse files
committed
Include contexts
1 parent 2f13597 commit c7e7598

8 files changed

Lines changed: 63 additions & 44 deletions

File tree

DifferentiationInterface/docs/src/explanation/backends.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ We support the following dense backend choices from [ADTypes.jl](https://github.
1919
- [`AutoTracker`](@extref ADTypes.AutoTracker)
2020
- [`AutoZygote`](@extref ADTypes.AutoZygote)
2121

22+
In addition, we provide experimental support for [`AutoReactant`](@extref ADTypes.AutoReactant), so far only for [`gradient`](@ref) and its variants.
23+
2224
## Features
2325

2426
Given a backend object, you can use:

DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/DifferentiationInterfaceReactantExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@ module DifferentiationInterfaceReactantExt
22

33
using ADTypes: ADTypes, AutoReactant
44
import DifferentiationInterface as DI
5-
using Reactant: @compile, to_rarray
5+
using Reactant: @compile, ConcreteRArray, ConcreteRNumber, to_rarray
66

77
DI.check_available(backend::AutoReactant) = DI.check_available(backend.mode)
88
DI.inplace_support(backend::AutoReactant) = DI.inplace_support(backend.mode)
99

10+
include("utils.jl")
1011
include("onearg.jl")
1112

1213
end # module

DifferentiationInterface/ext/DifferentiationInterfaceReactantExt/onearg.jl

Lines changed: 38 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,18 @@ struct ReactantGradientPrep{SIG, XR, GR, CG, CG!, CVG, CVG!} <: DI.GradientPrep{
88
compiled_value_and_gradient!::CVG!
99
end
1010

11-
function DI.prepare_gradient_nokwarg(strict::Val, f::F, rebackend::AutoReactant, x) where {F}
11+
function DI.prepare_gradient_nokwarg(
12+
strict::Val, f::F, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C}
13+
) where {F, C}
1214
_sig = DI.signature(f, rebackend, x; strict)
1315
backend = rebackend.mode
14-
xr = to_rarray(x)
15-
gr = to_rarray(similar(x))
16-
_gradient(_xr) = DI.gradient(f, backend, _xr)
17-
_gradient!(_gr, _xr) = copy!(_gr, DI.gradient(f, backend, _xr))
18-
_value_and_gradient(_xr) = DI.value_and_gradient(f, backend, _xr)
19-
function _value_and_gradient!(_gr, _xr)
20-
y, __gr = DI.value_and_gradient(f, backend, _xr)
21-
copy!(_gr, __gr)
22-
return y, _gr
23-
end
24-
compiled_gradient = @compile _gradient(xr)
25-
compiled_gradient! = @compile _gradient!(gr, xr)
26-
compiled_value_and_gradient = @compile _value_and_gradient(xr)
27-
compiled_value_and_gradient! = @compile _value_and_gradient!(gr, xr)
16+
xr = to_reac(x)
17+
gr = to_reac(similar(x))
18+
contextsr = map(to_reac, contexts)
19+
compiled_gradient = @compile DI.gradient(f, backend, xr, contextsr...)
20+
compiled_gradient! = @compile DI.gradient!(f, gr, backend, xr, contextsr...)
21+
compiled_value_and_gradient = @compile DI.value_and_gradient(f, backend, xr, contextsr...)
22+
compiled_value_and_gradient! = @compile DI.value_and_gradient!(f, gr, backend, xr, contextsr...)
2823
return ReactantGradientPrep(
2924
_sig,
3025
xr,
@@ -37,45 +32,49 @@ function DI.prepare_gradient_nokwarg(strict::Val, f::F, rebackend::AutoReactant,
3732
end
3833

3934
function DI.gradient(
40-
f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x
41-
) where {F}
35+
f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C}
36+
) where {F, C}
4237
DI.check_prep(f, prep, rebackend, x)
38+
backend = rebackend.mode
4339
(; xr, compiled_gradient) = prep
44-
copy!(xr, x)
45-
gr = compiled_gradient(xr)
46-
g = convert(typeof(x), gr)
47-
return g
40+
copyto!(xr, x)
41+
contextsr = map(to_reac, contexts)
42+
gr = compiled_gradient(f, backend, xr, contextsr...)
43+
return gr
4844
end
4945

5046
function DI.value_and_gradient(
51-
f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x
52-
) where {F}
47+
f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C}
48+
) where {F, C}
5349
DI.check_prep(f, prep, rebackend, x)
50+
backend = rebackend.mode
5451
(; xr, compiled_value_and_gradient) = prep
55-
copy!(xr, x)
56-
yr, gr = compiled_value_and_gradient(xr)
57-
y = convert(eltype(x), yr)
58-
g = convert(typeof(x), gr)
59-
return y, g
52+
copyto!(xr, x)
53+
contextsr = map(to_reac, contexts)
54+
yr, gr = compiled_value_and_gradient(f, backend, xr, contextsr...)
55+
return yr, gr
6056
end
6157

6258
function DI.gradient!(
63-
f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x
64-
) where {F}
59+
f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C}
60+
) where {F, C}
6561
DI.check_prep(f, prep, rebackend, x)
62+
backend = rebackend.mode
6663
(; xr, gr, compiled_gradient!) = prep
67-
copy!(xr, x)
68-
compiled_gradient!(gr, xr)
69-
return copy!(grad, gr)
64+
copyto!(xr, x)
65+
contextsr = map(to_reac, contexts)
66+
compiled_gradient!(f, gr, backend, xr, contextsr...)
67+
return copyto!(grad, gr)
7068
end
7169

7270
function DI.value_and_gradient!(
73-
f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x
74-
) where {F}
71+
f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x, contexts::Vararg{DI.Context, C}
72+
) where {F, C}
7573
DI.check_prep(f, prep, rebackend, x)
74+
backend = rebackend.mode
7675
(; xr, gr, compiled_value_and_gradient!) = prep
77-
copy!(xr, x)
78-
yr, gr = compiled_value_and_gradient!(gr, xr)
79-
y = convert(eltype(x), yr)
80-
return y, copy!(grad, gr)
76+
copyto!(xr, x)
77+
contextsr = map(to_reac, contexts)
78+
yr, gr = compiled_value_and_gradient!(f, gr, backend, xr, contextsr...)
79+
return yr, copyto!(grad, gr)
8180
end
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
to_reac(x::AbstractArray) = to_rarray(x)
2+
to_reac(x::ConcreteRArray) = x
3+
to_reac(x::Number) = ConcreteRNumber(x)
4+
to_reac(x::ConcreteRNumber) = x
5+
6+
to_reac(c::DI.Constant) = DI.Constant(to_reac(DI.unwrap(c)))
7+
to_reac(c::DI.Cache) = DI.Cache(to_reac(DI.unwrap(c)))
Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
using Pkg
2+
Pkg.add(url = "https://github.com/EnzymeAD/Enzyme.jl")
23
Pkg.add("Reactant")
34

45
using DifferentiationInterface
56
using DifferentiationInterfaceTest
67
using Reactant
8+
using Test
79

810
backend = AutoReactant()
911

12+
@test check_available(backend)
13+
@test check_inplace(backend)
14+
1015
test_differentiation(
11-
backend, DifferentiationInterfaceTest.default_scenarios();
16+
backend, DifferentiationInterfaceTest.default_scenarios(;
17+
include_constantified = true, include_cachified = false
18+
);
1219
excluded = vcat(SECOND_ORDER, :jacobian, :derivative, :pushforward, :pullback),
13-
logging = true
20+
logging = false
1421
)

DifferentiationInterfaceTest/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
1010
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1111
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1212
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
13+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1314
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1415
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1516
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
@@ -40,6 +41,7 @@ DataFrames = "1.6.1"
4041
DifferentiationInterface = "0.7.7"
4142
DocStringExtensions = "0.8,0.9"
4243
ForwardDiff = "0.10.36,1"
44+
GPUArraysCore = "0.2.0"
4345
JET = "0.9,0.10,0.11"
4446
JLArrays = "0.1,0.2,0.3"
4547
LinearAlgebra = "1"

DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ using DifferentiationInterface:
9292
using DifferentiationInterface: Rewrap, Context, Constant, Cache, ConstantOrCache, unwrap
9393
using DifferentiationInterface: PreparationMismatchError
9494
using DocStringExtensions: TYPEDFIELDS, TYPEDSIGNATURES
95+
using GPUArraysCore: @allowscalar
9596
using JET: @test_opt
9697
using LinearAlgebra: Adjoint, Diagonal, Transpose, I, dot, parent
9798
using PrecompileTools: @compile_workload

DifferentiationInterfaceTest/src/scenarios/modify.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -224,8 +224,8 @@ Base.show(io::IO, f::StoreInCache) = print(io, "StoreInCache($(f.f))")
224224
function (sc::StoreInCache{:out})(x, y_cache) # no annotation otherwise Zygote.Buffer cries
225225
y = sc.f(x)
226226
if y isa Number
227-
y_cache[1] = y
228-
return y_cache[1]
227+
@allowscalar y_cache[1] = y
228+
return @allowscalar y_cache[1]
229229
else
230230
copyto!(y_cache, y)
231231
return copy(y_cache)

0 commit comments

Comments
 (0)