Skip to content

Commit 8ca6f2e

Browse files
committed
Fix and test Tables API
1 parent 80ec41d commit 8ca6f2e

4 files changed

Lines changed: 49 additions & 10 deletions

File tree

DifferentiationInterfaceTest/src/tests/benchmark.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,18 +82,18 @@ function DifferentiationBenchmark()
8282
return DifferentiationBenchmark(DifferentiationBenchmarkDataRow{Float64}[])
8383
end
8484

85-
Tables.istable(::Type{DifferentiationBenchmark}) = true
85+
Tables.istable(::Type{<:DifferentiationBenchmark}) = true
8686
DataAPI.nrow(data::DifferentiationBenchmark) = length(data.rows)
8787
DataAPI.ncol(data::DifferentiationBenchmark) = 12
8888

89-
Tables.rowaccess(::Type{DifferentiationBenchmark}) = true
89+
Tables.rowaccess(::Type{<:DifferentiationBenchmark}) = true
9090
Tables.rows(data::DifferentiationBenchmark) = data.rows
9191

9292
Tables.getcolumn(row::DifferentiationBenchmarkDataRow, i::Int) = getfield(row, i)
9393
Tables.getcolumn(row::DifferentiationBenchmarkDataRow, nm::Symbol) = getproperty(row, nm)
9494
Tables.columnnames(row::DifferentiationBenchmarkDataRow) = fieldnames(typeof(row))
9595

96-
Tables.columnaccess(::Type{DifferentiationBenchmark}) = true
96+
Tables.columnaccess(::Type{<:DifferentiationBenchmark}) = true
9797
Tables.columns(data::DifferentiationBenchmark) = data
9898

9999
Tables.getcolumn(cols::DifferentiationBenchmark, i::Int) = getfield.(cols.rows, i)

DifferentiationInterfaceTest/test/Project.toml

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,12 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
44
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
55
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
6+
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
67
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
78
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
89
DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
910
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
1011
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
11-
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
12-
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1312
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1413
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1514
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
@@ -19,9 +18,13 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1918
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
2019
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
2120
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
21+
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
2222
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2323
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2424

25+
[sources]
26+
DifferentiationInterfaceTest = {path = ".."}
27+
2528
[compat]
2629
ADTypes = "1.18"
2730
Aqua = "0.8.12"
@@ -31,8 +34,6 @@ DataFrames = "1.8.1"
3134
DifferentiationInterface = "0.7.10"
3235
ExplicitImports = "1.10.1"
3336
FiniteDiff = "2.27.0"
34-
FiniteDifferences = "0.12.33"
35-
Flux = "0.16.5"
3637
ForwardDiff = "1.2.2"
3738
JET = "0.9, 0.10, 0.11"
3839
JLArrays = "0.3"
@@ -41,6 +42,3 @@ SparseMatrixColorings = "0.4.9"
4142
StaticArrays = "1.9.15"
4243
Zygote = "0.7.10"
4344
julia = "1.10.10"
44-
45-
[sources]
46-
DifferentiationInterfaceTest = { path = ".." }
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
using Pkg; Pkg.activate(@__DIR__)
2+
3+
using ADTypes
4+
using DifferentiationInterfaceTest
5+
import DifferentiationInterfaceTest as DIT
6+
using Tables, DataAPI
7+
using Test
8+
9+
row1 = DIT.DifferentiationBenchmarkDataRow(;
10+
backend = AutoForwardDiff(),
11+
scenario = Scenario{:gradient, :out}(sum, ones(2)),
12+
operator = :gradient,
13+
prepared = true,
14+
calls = 2,
15+
samples = 1,
16+
evals = 1,
17+
time = 1.0,
18+
allocs = 10.0,
19+
bytes = 100.0,
20+
gc_fraction = 0.5,
21+
compile_fraction = 0.1
22+
)
23+
24+
tab = DIT.DifferentiationBenchmark([row1, row1])
25+
26+
@testset "Tables API" begin
27+
@test Tables.istable(typeof(tab))
28+
@test Tables.rowaccess(typeof(tab))
29+
@test Tables.columnaccess(typeof(tab))
30+
@test DataAPI.nrow(tab) == 2
31+
@test DataAPI.ncol(tab) == 12
32+
@test Tables.rows(tab) == tab.rows
33+
@test Tables.columns(tab) == tab
34+
@test Tables.getcolumn(tab, :samples) == [1, 1]
35+
@test Tables.getcolumn(row1, :samples) == 1
36+
@test Tables.getcolumn(tab, 5) == [2, 2]
37+
@test Tables.getcolumn(row1, 5) == 2
38+
@test Tables.columnnames(tab) |> length == 12
39+
@test Tables.columnnames(row1) |> length == 12
40+
end

DifferentiationInterfaceTest/test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ LOGGING = get(ENV, "CI", "false") == "false"
2929
if GROUP == "Formalities" || GROUP == "All"
3030
@testset verbose = true "Formalities" begin
3131
include("formalities.jl")
32+
include("benchmark.jl")
3233
end
3334
@testset verbose = true "Scenarios" begin
3435
include("scenario.jl")

0 commit comments

Comments
 (0)