Skip to content

Commit 3936bc3

Browse files
authored
Put testing into subpackage (#110)
* Put testing into subpackage * Fix deps * Fix more deps * More fixes * Less reimporting * Fix CI for docs * Fix docs CI * Rm JET
1 parent 99e67d9 commit 3936bc3

48 files changed

Lines changed: 302 additions & 564 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

.github/workflows/CI.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ jobs:
6262
using Pkg
6363
Pkg.Registry.update() # TODO: remove? useful for fresh releases
6464
Pkg.develop(PackageSpec(path=pwd()))
65+
Pkg.develop(PackageSpec(path=joinpath(pwd(), "lib", "DifferentiationInterfaceTest")))
6566
Pkg.instantiate()
6667
- uses: julia-actions/julia-buildpkg@v1
6768
- uses: julia-actions/julia-docdeploy@v1

Project.toml

Lines changed: 31 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -7,58 +7,40 @@ version = "0.1.0"
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
99
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
10-
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1110
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
12-
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1311

1412
[weakdeps]
1513
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
1614
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
17-
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
18-
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
19-
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
2015
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
2116
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
2217
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
2318
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
2419
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2520
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
26-
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
27-
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
2821
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
2922
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
30-
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
31-
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
3223
Taped = "07d77754-e150-4737-8c94-cd238a1fb45b"
3324
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
3425
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3526

3627
[extensions]
3728
DifferentiationInterfaceChainRulesCoreExt = "ChainRulesCore"
38-
DifferentiationInterfaceChairmarksExt = ["Chairmarks"]
39-
DifferentiationInterfaceComponentArraysExt = ["ComponentArrays"]
4029
DifferentiationInterfaceDiffractorExt = ["AbstractDifferentiation", "Diffractor"]
4130
DifferentiationInterfaceEnzymeExt = "Enzyme"
42-
DifferentiationInterfaceFastDifferentiationExt = ["FastDifferentiation", "RuntimeGeneratedFunctions"]
31+
DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
4332
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
4433
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
45-
DifferentiationInterfaceForwardDiffExt = ["DiffResults", "ForwardDiff"]
46-
DifferentiationInterfaceJETExt = ["JET"]
47-
DifferentiationInterfaceJLArraysExt = ["JLArrays"]
48-
DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"]
49-
DifferentiationInterfaceReverseDiffExt = ["DiffResults", "ReverseDiff"]
50-
DifferentiationInterfaceStaticArraysExt = ["StaticArrays"]
51-
DifferentiationInterfaceTapedExt = ["Taped"]
52-
DifferentiationInterfaceTrackerExt = ["Tracker"]
53-
DifferentiationInterfaceZygoteExt = ["Zygote"]
34+
DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
35+
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
36+
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
37+
DifferentiationInterfaceTrackerExt = "Tracker"
38+
DifferentiationInterfaceZygoteExt = "Zygote"
5439

5540
[compat]
5641
ADTypes = "0.2.7"
5742
AbstractDifferentiation = "0.6"
5843
ChainRulesCore = "1.19"
59-
Chairmarks = "1.2"
60-
ComponentArrays = "0.15"
61-
DiffResults = "1.1"
6244
Diffractor = "0.2"
6345
DocStringExtensions = "0.9"
6446
Enzyme = "0.11"
@@ -67,16 +49,35 @@ FillArrays = "1"
6749
FiniteDiff = "2.22"
6850
FiniteDifferences = "0.12"
6951
ForwardDiff = "0.10"
70-
Functors = "0.4"
71-
JET = "0.8"
72-
JLArrays = "0.1"
7352
LinearAlgebra = "1"
7453
PolyesterForwardDiff = "0.1"
7554
ReverseDiff = "1.15"
76-
RuntimeGeneratedFunctions = "0.5"
77-
StaticArrays = "1.9"
78-
Taped = "1"
55+
Taped = "0.1"
7956
Test = "1"
8057
Tracker = "0.2"
8158
Zygote = "0.6"
8259
julia = "1.10"
60+
61+
[extras]
62+
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
63+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
64+
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
65+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
66+
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
67+
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
68+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
69+
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
70+
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
71+
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
72+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
73+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
74+
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
75+
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
76+
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
77+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
78+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
79+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
80+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
81+
82+
[targets]
83+
test = ["ADTypes", "Aqua", "Chairmarks", "DataFrames", "Diffractor", "Documenter", "Enzyme", "FastDifferentiation", "FiniteDiff", "FiniteDifferences", "ForwardDiff", "JET", "JuliaFormatter", "Pkg", "PolyesterForwardDiff", "ReverseDiff", "Test", "Tracker", "Zygote"]

docs/Project.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
44
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
55
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
66
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
7-
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
87
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
98
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
109
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
@@ -15,7 +14,6 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1514
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1615
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
1716
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
18-
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1917
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
2018
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
2119
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"

docs/make.jl

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
using Base: get_extension
22
using DifferentiationInterface
3-
using DifferentiationInterface.DifferentiationTest
3+
using DifferentiationInterfaceTest
44
import DifferentiationInterface as DI
55
using Documenter
66
using DocumenterMermaid
7-
using JET
87
using Random
98
using Test
109

@@ -40,7 +39,7 @@ makedocs(;
4039
modules=[
4140
ADTypes,
4241
DifferentiationInterface,
43-
DifferentiationInterface.DifferentiationTest,
42+
DifferentiationInterfaceTest,
4443
get_extension(DI, :DifferentiationInterfaceChainRulesCoreExt),
4544
get_extension(DI, :DifferentiationInterfaceDiffractorExt),
4645
get_extension(DI, :DifferentiationInterfaceEnzymeExt),

docs/src/api.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
```@meta
2-
CurrentModule = DifferentiationInterface
2+
CurrentModule = Main
33
CollapsedDocStrings = true
44
```
55

@@ -61,7 +61,7 @@ Pages = ["prepare.jl"]
6161
## Testing & benchmarking
6262

6363
```@autodocs
64-
Modules = [DifferentiationTest]
64+
Modules = [DifferentiationInterfaceTest]
6565
Private = false
6666
```
6767

@@ -77,6 +77,6 @@ Filter = t -> !(t isa Type && t <: ADTypes.AbstractADType)
7777
```
7878

7979
```@autodocs
80-
Modules = [DifferentiationTest]
80+
Modules = [DifferentiationInterfaceTest]
8181
Public = false
8282
```

docs/src/backends.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
```@meta
2+
CurrentModule = Main
23
CollapsedDocStrings = true
34
```
45

56
```@setup backends
67
using ADTypes, DifferentiationInterface
7-
using DifferentiationInterface.DifferentiationTest: backend_string
8+
using DifferentiationInterfaceTest: backend_string
89
import Markdown
910
import Enzyme, FastDifferentiation, FiniteDiff, FiniteDifferences, ForwardDiff, PolyesterForwardDiff, ReverseDiff, Tracker, Zygote
1011

docs/src/tutorial.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
```@meta
2+
CurrentModule = Main
3+
```
4+
15
# Tutorial
26

37
We present a typical workflow with DifferentiationInterface.jl and showcase its potential performance benefits.
@@ -132,10 +136,10 @@ You didn't need to look at the docs of either ForwardDiff.jl or Enzyme.jl to ach
132136
## Testing and benchmarking
133137

134138
DifferentiationInterface.jl also provides some utilities for more involved comparison between backends.
135-
They are gathered in a submodule.
139+
They are gathered in a submodule called [`DifferentiationInterfaceTest`](https://github.com/gdalle/DifferentiationInterface.jl/tree/main/lib/DifferentiationInterfaceTest).
136140

137141
```@repl tuto
138-
using DifferentiationInterface.DifferentiationTest
142+
using DifferentiationInterfaceTest
139143
```
140144

141145
The main entry point is [`test_differentiation`](@ref), which is used as follows:

ext/DifferentiationInterfaceEnzymeExt/DifferentiationInterfaceEnzymeExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
module DifferentiationInterfaceEnzymeExt
22

33
using ADTypes: ADTypes, AutoEnzyme
4-
using DifferentiationInterface: mymul!!, myupdate!!, mysimilar, myzero, myzero!!
4+
using DifferentiationInterface: myupdate!!
55
import DifferentiationInterface as DI
66
using DocStringExtensions
77
using Enzyme:
@@ -44,7 +44,7 @@ DI.mode(::AutoReverseEnzyme) = ADTypes.AbstractReverseMode
4444

4545
# Enzyme's `Duplicated(x, dx)` expects both arguments to be of the same type
4646
function DI.basisarray(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T}
47-
b = myzero(a)
47+
b = zero(a)
4848
b[i] = one(T)
4949
return b
5050
end
Lines changed: 19 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,15 @@
11
## Pullback
22

3-
function DI.value_and_pullback!!(
4-
f, _dx, ::AutoReverseEnzyme, x::Number, dy::Number, extras::Nothing
3+
function DI.value_and_pullback(
4+
f, ::AutoReverseEnzyme, x::Number, dy::Number, extras::Nothing
55
)
66
der, y = autodiff(ReverseWithPrimal, f, Active, Active(x))
77
new_dx = dy * only(der)
88
return y, new_dx
99
end
1010

11-
function DI.value_and_pullback!!(
12-
f, dx, ::AutoReverseEnzyme, x::AbstractArray, dy::Number, extras::Nothing
13-
)
14-
dx_sametype = convert(typeof(x), dx)
15-
dx_sametype = myzero!!(dx_sametype)
16-
_, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx_sametype))
17-
dx_sametype = mymul!!(dx_sametype, dy)
18-
return y, myupdate!!(dx, dx_sametype)
19-
end
20-
21-
function DI.value_and_pullback!!(
22-
f, _dx, ::AutoReverseEnzyme, x::Number, dy::AbstractArray, extras::Nothing
11+
function DI.value_and_pullback(
12+
f, ::AutoReverseEnzyme, x::Number, dy::AbstractArray, extras::Nothing
2313
)
2414
forw, rev = autodiff_thunk(
2515
ReverseSplitWithPrimal, Const{typeof(f)}, Duplicated, Active{typeof(x)}
@@ -30,11 +20,21 @@ function DI.value_and_pullback!!(
3020
return y, new_dx
3121
end
3222

23+
function DI.value_and_pullback!!(
24+
f, dx, ::AutoReverseEnzyme, x::AbstractArray, dy::Number, extras::Nothing
25+
)
26+
dx_sametype = convert(typeof(x), dx)
27+
dx_sametype .= zero(eltype(dx_sametype))
28+
_, y = autodiff(ReverseWithPrimal, f, Active, Duplicated(x, dx_sametype))
29+
dx_sametype .*= dy
30+
return y, myupdate!!(dx, dx_sametype)
31+
end
32+
3333
function DI.value_and_pullback!!(
3434
f, dx, ::AutoReverseEnzyme, x::AbstractArray, dy::AbstractArray, extras::Nothing
3535
)
3636
dx_sametype = convert(typeof(x), dx)
37-
dx_sametype = myzero!!(dx_sametype)
37+
dx_sametype .= zero(eltype(dx_sametype))
3838
forw, rev = autodiff_thunk(
3939
ReverseSplitWithPrimal, Const{typeof(f)}, Duplicated, Duplicated{typeof(x)}
4040
)
@@ -44,25 +44,17 @@ function DI.value_and_pullback!!(
4444
return y, myupdate!!(dx, dx_sametype)
4545
end
4646

47-
function DI.value_and_pullback(f, backend::AutoReverseEnzyme, x, dy, extras)
48-
dx = mysimilar(x)
47+
function DI.value_and_pullback(f, backend::AutoReverseEnzyme, x::AbstractArray, dy, extras)
48+
dx = similar(x)
4949
return DI.value_and_pullback!!(f, dx, backend, x, dy, extras)
5050
end
5151

5252
## Gradient
5353

54-
function DI.gradient(f, backend::AutoReverseEnzyme, x, extras::Nothing)
54+
function DI.gradient(f, ::AutoReverseEnzyme, x::AbstractArray, extras::Nothing)
5555
return gradient(Reverse, f, x)
5656
end
5757

58-
function DI.gradient!!(f, grad, backend::AutoReverseEnzyme, x, extras::Nothing)
58+
function DI.gradient!!(f, grad, ::AutoReverseEnzyme, x::AbstractArray, extras::Nothing)
5959
return gradient!(Reverse, grad, f, x)
6060
end
61-
62-
function DI.gradient(f, backend::AutoReverseEnzyme, x::Number, extras::Nothing)
63-
return autodiff(Reverse, f, Active(x))
64-
end
65-
66-
function DI.gradient!!(f, grad, backend::AutoReverseEnzyme, x::Number, extras::Nothing)
67-
return autodiff(Reverse, f, Active(x))
68-
end

ext/DifferentiationInterfaceEnzymeExt/reverse_mutating.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@ function DI.value_and_pullback!!(
88
return y, new_dx
99
end
1010

11-
function DI.value_and_pullback!!(f!, y, dx, ::AutoReverseEnzyme, x, dy, extras::Nothing)
11+
function DI.value_and_pullback!!(
12+
f!, y, dx, ::AutoReverseEnzyme, x::AbstractArray, dy, extras::Nothing
13+
)
1214
dx_sametype = convert(typeof(x), dx)
13-
dx_sametype = myzero!!(dx_sametype)
14-
dy_sametype = convert(typeof(y), copy(dy))
15+
dx_sametype .= zero(eltype(dx_sametype))
16+
dy_sametype = convert(typeof(y), copy(dy)) # TODO: how to get rid of copy?
1517
autodiff(Reverse, f!, Const, Duplicated(y, dy_sametype), Duplicated(x, dx_sametype))
1618
return y, myupdate!!(dx, dx_sametype)
1719
end

0 commit comments

Comments
 (0)