Skip to content

Commit 2f13597

Browse files
committed
feat: add gradient with AutoReactant
1 parent bbc39fd commit 2f13597

6 files changed

Lines changed: 194 additions & 79 deletions

File tree

.github/workflows/Test.yml

Lines changed: 81 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -28,29 +28,30 @@ jobs:
2828
fail-fast: true # TODO: toggle
2929
matrix:
3030
version:
31-
- '1.10'
31+
# - '1.10'
3232
- '1.11'
3333
- '1.12'
3434
group:
35-
- Core/Internals
36-
- Back/DifferentiateWith
37-
- Core/SimpleFiniteDiff
38-
- Back/SparsityDetector
39-
- Core/ZeroBackends
40-
- Back/ChainRules
41-
# - Back/Diffractor
42-
- Back/Enzyme
43-
- Back/FastDifferentiation
44-
- Back/FiniteDiff
45-
- Back/FiniteDifferences
46-
- Back/ForwardDiff
47-
- Back/GTPSA
48-
- Back/Mooncake
49-
- Back/PolyesterForwardDiff
50-
- Back/ReverseDiff
51-
- Back/Symbolics
52-
- Back/Tracker
53-
- Back/Zygote
35+
# - Core/Internals
36+
# - Back/DifferentiateWith
37+
# - Core/SimpleFiniteDiff
38+
# - Back/SparsityDetector
39+
# - Core/ZeroBackends
40+
# - Back/ChainRules
41+
# # - Back/Diffractor
42+
# - Back/Enzyme
43+
# - Back/FastDifferentiation
44+
# - Back/FiniteDiff
45+
# - Back/FiniteDifferences
46+
# - Back/ForwardDiff
47+
# - Back/GTPSA
48+
# - Back/Mooncake
49+
# - Back/PolyesterForwardDiff
50+
- Back/Reactant
51+
# - Back/ReverseDiff
52+
# - Back/Symbolics
53+
# - Back/Tracker
54+
# - Back/Zygote
5455
skip_lts:
5556
- ${{ github.event.pull_request.draft }}
5657
skip_pre:
@@ -64,6 +65,8 @@ jobs:
6465
group: Back/ChainRules
6566
- version: '1.12'
6667
group: Back/Enzyme
68+
- version: '1.12'
69+
group: Back/Reactant
6770
- version: '1.12'
6871
group: Back/DifferentiateWith
6972
env:
@@ -104,61 +107,61 @@ jobs:
104107
token: ${{ secrets.CODECOV_TOKEN }}
105108
fail_ci_if_error: false
106109

107-
test-DIT:
108-
name: ${{ matrix.version }} - DIT (${{ matrix.group }})
109-
runs-on: ubuntu-latest
110-
if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }}
111-
timeout-minutes: 60
112-
permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created
113-
actions: write
114-
contents: read
115-
strategy:
116-
fail-fast: true
117-
matrix:
118-
version:
119-
- '1.10'
120-
- '1.11'
121-
- '1.12'
122-
group:
123-
- Formalities
124-
- Zero
125-
- Standard
126-
- Weird
127-
skip_lts:
128-
- ${{ github.event.pull_request.draft }}
129-
skip_pre:
130-
- ${{ github.event.pull_request.draft }}
131-
exclude:
132-
- skip_lts: true
133-
version: '1.10'
134-
- skip_pre: true
135-
version: '1.12'
136-
env:
137-
JULIA_DIT_TEST_GROUP: ${{ matrix.group }}
138-
JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }}
139-
steps:
140-
- uses: actions/checkout@v5
141-
- uses: julia-actions/setup-julia@v2
142-
with:
143-
version: ${{ matrix.version }}
144-
arch: x64
145-
- uses: julia-actions/cache@v2
146-
- name: Install dependencies & run tests
147-
run: julia --project=./DifferentiationInterfaceTest --color=yes -e '
148-
using Pkg;
149-
Pkg.Registry.update();
150-
Pkg.develop(path="./DifferentiationInterface");
151-
if ENV["JULIA_DI_PR_DRAFT"] == "true";
152-
Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]);
153-
else;
154-
Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true);
155-
end;'
156-
- uses: julia-actions/julia-processcoverage@v1
157-
with:
158-
directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test
159-
- uses: codecov/codecov-action@v5
160-
with:
161-
files: lcov.info
162-
flags: DIT
163-
token: ${{ secrets.CODECOV_TOKEN }}
164-
fail_ci_if_error: false
110+
# test-DIT:
111+
# name: ${{ matrix.version }} - DIT (${{ matrix.group }})
112+
# runs-on: ubuntu-latest
113+
# if: ${{ !contains(github.event.pull_request.labels.*.name, 'skipci') }}
114+
# timeout-minutes: 60
115+
# permissions: # needed to allow julia-actions/cache to proactively delete old caches that it has created
116+
# actions: write
117+
# contents: read
118+
# strategy:
119+
# fail-fast: true
120+
# matrix:
121+
# version:
122+
# - '1.10'
123+
# - '1.11'
124+
# - '1.12'
125+
# group:
126+
# - Formalities
127+
# - Zero
128+
# - Standard
129+
# - Weird
130+
# skip_lts:
131+
# - ${{ github.event.pull_request.draft }}
132+
# skip_pre:
133+
# - ${{ github.event.pull_request.draft }}
134+
# exclude:
135+
# - skip_lts: true
136+
# version: '1.10'
137+
# - skip_pre: true
138+
# version: '1.12'
139+
# env:
140+
# JULIA_DIT_TEST_GROUP: ${{ matrix.group }}
141+
# JULIA_DI_PR_DRAFT: ${{ github.event.pull_request.draft }}
142+
# steps:
143+
# - uses: actions/checkout@v5
144+
# - uses: julia-actions/setup-julia@v2
145+
# with:
146+
# version: ${{ matrix.version }}
147+
# arch: x64
148+
# - uses: julia-actions/cache@v2
149+
# - name: Install dependencies & run tests
150+
# run: julia --project=./DifferentiationInterfaceTest --color=yes -e '
151+
# using Pkg;
152+
# Pkg.Registry.update();
153+
# Pkg.develop(path="./DifferentiationInterface");
154+
# if ENV["JULIA_DI_PR_DRAFT"] == "true";
155+
# Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true, julia_args=["-O1"]);
156+
# else;
157+
# Pkg.test("DifferentiationInterfaceTest"; allow_reresolve=false, coverage=true);
158+
# end;'
159+
# - uses: julia-actions/julia-processcoverage@v1
160+
# with:
161+
# directories: ./DifferentiationInterfaceTest/src,./DifferentiationInterfaceTest/ext,./DifferentiationInterfaceTest/test
162+
# - uses: codecov/codecov-action@v5
163+
# with:
164+
# files: lcov.info
165+
# flags: DIT
166+
# token: ${{ secrets.CODECOV_TOKEN }}
167+
# fail_ci_if_error: false

DifferentiationInterface/Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
2121
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
2222
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2323
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
24+
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
2425
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
2526
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2627
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
@@ -46,6 +47,7 @@ DifferentiationInterfacePolyesterForwardDiffExt = [
4647
"ForwardDiff",
4748
"DiffResults",
4849
]
50+
DifferentiationInterfaceReactantExt = ["Reactant", "Enzyme"]
4951
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
5052
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
5153
DifferentiationInterfaceSparseConnectivityTracerExt = "SparseConnectivityTracer"
@@ -56,7 +58,7 @@ DifferentiationInterfaceTrackerExt = "Tracker"
5658
DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]
5759

5860
[compat]
59-
ADTypes = "1.18.0"
61+
ADTypes = "1.19.0"
6062
ChainRulesCore = "1.23.0"
6163
DiffResults = "1.1.0"
6264
Diffractor = "=0.2.6"
@@ -71,6 +73,7 @@ GTPSA = "1.4.0"
7173
LinearAlgebra = "1"
7274
Mooncake = "0.4.175"
7375
PolyesterForwardDiff = "0.1.2"
76+
Reactant = "0.2.178"
7477
ReverseDiff = "1.15.1"
7578
SparseArrays = "1"
7679
SparseConnectivityTracer = "0.6.14, 1"
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
module DifferentiationInterfaceReactantExt
2+
3+
using ADTypes: ADTypes, AutoReactant
4+
import DifferentiationInterface as DI
5+
using Reactant: @compile, to_rarray
6+
7+
DI.check_available(backend::AutoReactant) = DI.check_available(backend.mode)
8+
DI.inplace_support(backend::AutoReactant) = DI.inplace_support(backend.mode)
9+
10+
include("onearg.jl")
11+
12+
end # module
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
struct ReactantGradientPrep{SIG, XR, GR, CG, CG!, CVG, CVG!} <: DI.GradientPrep{SIG}
2+
_sig::Val{SIG}
3+
xr::XR
4+
gr::GR
5+
compiled_gradient::CG
6+
compiled_gradient!::CG!
7+
compiled_value_and_gradient::CVG
8+
compiled_value_and_gradient!::CVG!
9+
end
10+
11+
function DI.prepare_gradient_nokwarg(strict::Val, f::F, rebackend::AutoReactant, x) where {F}
12+
_sig = DI.signature(f, rebackend, x; strict)
13+
backend = rebackend.mode
14+
xr = to_rarray(x)
15+
gr = to_rarray(similar(x))
16+
_gradient(_xr) = DI.gradient(f, backend, _xr)
17+
_gradient!(_gr, _xr) = copy!(_gr, DI.gradient(f, backend, _xr))
18+
_value_and_gradient(_xr) = DI.value_and_gradient(f, backend, _xr)
19+
function _value_and_gradient!(_gr, _xr)
20+
y, __gr = DI.value_and_gradient(f, backend, _xr)
21+
copy!(_gr, __gr)
22+
return y, _gr
23+
end
24+
compiled_gradient = @compile _gradient(xr)
25+
compiled_gradient! = @compile _gradient!(gr, xr)
26+
compiled_value_and_gradient = @compile _value_and_gradient(xr)
27+
compiled_value_and_gradient! = @compile _value_and_gradient!(gr, xr)
28+
return ReactantGradientPrep(
29+
_sig,
30+
xr,
31+
gr,
32+
compiled_gradient,
33+
compiled_gradient!,
34+
compiled_value_and_gradient,
35+
compiled_value_and_gradient!,
36+
)
37+
end
38+
39+
function DI.gradient(
40+
f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x
41+
) where {F}
42+
DI.check_prep(f, prep, rebackend, x)
43+
(; xr, compiled_gradient) = prep
44+
copy!(xr, x)
45+
gr = compiled_gradient(xr)
46+
g = convert(typeof(x), gr)
47+
return g
48+
end
49+
50+
function DI.value_and_gradient(
51+
f::F, prep::ReactantGradientPrep, rebackend::AutoReactant, x
52+
) where {F}
53+
DI.check_prep(f, prep, rebackend, x)
54+
(; xr, compiled_value_and_gradient) = prep
55+
copy!(xr, x)
56+
yr, gr = compiled_value_and_gradient(xr)
57+
y = convert(eltype(x), yr)
58+
g = convert(typeof(x), gr)
59+
return y, g
60+
end
61+
62+
function DI.gradient!(
63+
f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x
64+
) where {F}
65+
DI.check_prep(f, prep, rebackend, x)
66+
(; xr, gr, compiled_gradient!) = prep
67+
copy!(xr, x)
68+
compiled_gradient!(gr, xr)
69+
return copy!(grad, gr)
70+
end
71+
72+
function DI.value_and_gradient!(
73+
f::F, grad, prep::ReactantGradientPrep, rebackend::AutoReactant, x
74+
) where {F}
75+
DI.check_prep(f, prep, rebackend, x)
76+
(; xr, gr, compiled_value_and_gradient!) = prep
77+
copy!(xr, x)
78+
yr, gr = compiled_value_and_gradient!(gr, xr)
79+
y = convert(eltype(x), yr)
80+
return y, copy!(grad, gr)
81+
end

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ using ADTypes:
3030
AutoMooncake,
3131
AutoMooncakeForward,
3232
AutoPolyesterForwardDiff,
33+
AutoReactant,
3334
AutoReverseDiff,
3435
AutoSymbolics,
3536
AutoTracker,
@@ -118,6 +119,7 @@ export AutoGTPSA
118119
export AutoMooncake
119120
export AutoMooncakeForward
120121
export AutoPolyesterForwardDiff
122+
export AutoReactant
121123
export AutoReverseDiff
122124
export AutoSymbolics
123125
export AutoTracker
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using Pkg
2+
Pkg.add("Reactant")
3+
4+
using DifferentiationInterface
5+
using DifferentiationInterfaceTest
6+
using Reactant
7+
8+
backend = AutoReactant()
9+
10+
test_differentiation(
11+
backend, DifferentiationInterfaceTest.default_scenarios();
12+
excluded = vcat(SECOND_ORDER, :jacobian, :derivative, :pushforward, :pullback),
13+
logging = true
14+
)

0 commit comments

Comments
 (0)