Skip to content

Commit 7e79b17

Browse files
committed
Add GPUArrays extension
1 parent 467c784 commit 7e79b17

9 files changed

Lines changed: 74 additions & 94 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
1717
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
1818
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
1919
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
20+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
2021
GTPSA = "b27dd330-f138-47c5-815b-40db9dd9b6e8"
2122
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2223
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
@@ -37,6 +38,7 @@ DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
3738
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
3839
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
3940
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
41+
DifferentiationInterfaceGPUArraysCoreExt = "GPUArraysCore"
4042
DifferentiationInterfaceGTPSAExt = "GTPSA"
4143
DifferentiationInterfaceMooncakeExt = "Mooncake"
4244
DifferentiationInterfacePolyesterForwardDiffExt = ["PolyesterForwardDiff", "ForwardDiff", "DiffResults"]
@@ -109,21 +111,4 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
109111
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
110112

111113
[targets]
112-
test = [
113-
"ADTypes",
114-
"Aqua",
115-
"ComponentArrays",
116-
"DataFrames",
117-
"ExplicitImports",
118-
"JET",
119-
"JLArrays",
120-
"JuliaFormatter",
121-
"Pkg",
122-
"Random",
123-
"SparseArrays",
124-
"SparseConnectivityTracer",
125-
"SparseMatrixColorings",
126-
"StableRNGs",
127-
"StaticArrays",
128-
"Test",
129-
]
114+
test = ["ADTypes", "Aqua", "ComponentArrays", "DataFrames", "ExplicitImports", "JET", "JLArrays", "JuliaFormatter", "Pkg", "Random", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "StableRNGs", "StaticArrays", "Test"]

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ to_val(::DI.BatchSizeSettings{B}) where {B} = Val(B)
99
## Annotations
1010

1111
@inline function get_f_and_df(
12-
f::F, backend::AutoEnzyme{M,Nothing}, mode::Mode, ::Val{B}=Val(1)
12+
f::F, backend::AutoEnzyme{M,Nothing}, mode::Mode, (::Val{B})=Val(1)
1313
) where {F,M,B}
1414
return f
1515
end
1616

1717
@inline function get_f_and_df(
18-
f::F, backend::AutoEnzyme{M,<:Const}, mode::Mode, ::Val{B}=Val(1)
18+
f::F, backend::AutoEnzyme{M,<:Const}, mode::Mode, (::Val{B})=Val(1)
1919
) where {F,M,B}
2020
return Const(f)
2121
end
@@ -34,7 +34,7 @@ end
3434
},
3535
},
3636
mode::Mode,
37-
::Val{B}=Val(1),
37+
(::Val{B})=Val(1),
3838
) where {F,M,B}
3939
# TODO: needs more sophistication for mixed activities
4040
if B == 1
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
module DifferentiationInterfaceGPUArraysCoreExt
2+
3+
import DifferentiationInterface as DI
4+
using GPUArraysCore: AbstractGPUArray
5+
6+
"""
7+
OneElement
8+
9+
Efficient storage for a one-hot array, aka an array in the standard Euclidean basis.
10+
"""
11+
struct OneElement{I,N,T,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
12+
ind::I
13+
val::T
14+
a::A
15+
16+
function OneElement(ind::Integer, val::T, a::A) where {N,T,A<:AbstractArray{T,N}}
17+
right_ind = eachindex(a)[ind]
18+
return new{typeof(right_ind),N,T,A}(right_ind, val, a)
19+
end
20+
21+
function OneElement(
22+
ind::CartesianIndex{N}, val::T, a::A
23+
) where {N,T,A<:AbstractArray{T,N}}
24+
linear_ind = LinearIndices(a)[ind]
25+
right_ind = eachindex(a)[linear_ind]
26+
return new{typeof(right_ind),N,T,A}(right_ind, val, a)
27+
end
28+
end
29+
30+
Base.size(oe::OneElement) = size(oe.a)
31+
Base.IndexStyle(oe::OneElement) = Base.IndexStyle(oe.a)
32+
33+
function Base.getindex(oe::OneElement{<:Integer}, ind::Integer)
34+
return ifelse(ind == oe.ind, oe.val, zero(eltype(oe.a)))
35+
end
36+
37+
function DI.basis(a::AbstractGPUArray{T}, i) where {T}
38+
b = zero(a)
39+
b .+= OneElement(i, one(T), a)
40+
return b
41+
end
42+
43+
end

DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/onearg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function DI.pushforward(
4646
DI.check_prep(f, prep, backend, x, tx, contexts...)
4747
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
4848
ty = map(tx) do dx
49-
foreach((t, xi, dxi) -> (t[0] = xi; t[1] = dxi), prep.xt, x, dx)
49+
foreach((t, xi, dxi) -> (t[0]=xi; t[1]=dxi), prep.xt, x, dx)
5050
yt = fc(prep.xt)
5151
if yt isa Number
5252
return yt[1]
@@ -71,7 +71,7 @@ function DI.pushforward!(
7171
fc = DI.fix_tail(f, map(DI.unwrap, contexts)...)
7272
for b in eachindex(tx, ty)
7373
dx, dy = tx[b], ty[b]
74-
foreach((t, xi, dxi) -> (t[0] = xi; t[1] = dxi), prep.xt, x, dx)
74+
foreach((t, xi, dxi) -> (t[0]=xi; t[1]=dxi), prep.xt, x, dx)
7575
yt = fc(prep.xt)
7676
map!(t -> t[1], dy, yt)
7777
end

DifferentiationInterface/ext/DifferentiationInterfaceGTPSAExt/twoarg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ function DI.pushforward(
5656
DI.check_prep(f!, y, prep, backend, x, tx, contexts...)
5757
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
5858
ty = map(tx) do dx
59-
foreach((t, xi, dxi) -> (t[0] = xi; t[1] = dxi), prep.xt, x, dx)
59+
foreach((t, xi, dxi) -> (t[0]=xi; t[1]=dxi), prep.xt, x, dx)
6060
fc!(prep.yt, prep.xt)
6161
dy = map(t -> t[1], prep.yt)
6262
return dy
@@ -79,7 +79,7 @@ function DI.pushforward!(
7979
fc! = DI.fix_tail(f!, map(DI.unwrap, contexts)...)
8080
for b in eachindex(tx, ty)
8181
dx, dy = tx[b], ty[b]
82-
foreach((t, xi, dxi) -> (t[0] = xi; t[1] = dxi), prep.xt, x, dx)
82+
foreach((t, xi, dxi) -> (t[0]=xi; t[1]=dxi), prep.xt, x, dx)
8383
fc!(prep.yt, prep.xt)
8484
map!(t -> t[1], dy, prep.yt)
8585
end

DifferentiationInterface/src/utils/basis.jl

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,3 @@
1-
"""
2-
OneElement
3-
4-
Efficient storage for a one-hot array, aka an array in the standard Euclidean basis.
5-
"""
6-
struct OneElement{I,N,T,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
7-
ind::I
8-
val::T
9-
a::A
10-
11-
function OneElement(ind::Integer, val::T, a::A) where {N,T,A<:AbstractArray{T,N}}
12-
right_ind = eachindex(a)[ind]
13-
return new{typeof(right_ind),N,T,A}(right_ind, val, a)
14-
end
15-
16-
function OneElement(
17-
ind::CartesianIndex{N}, val::T, a::A
18-
) where {N,T,A<:AbstractArray{T,N}}
19-
linear_ind = LinearIndices(a)[ind]
20-
right_ind = eachindex(a)[linear_ind]
21-
return new{typeof(right_ind),N,T,A}(right_ind, val, a)
22-
end
23-
end
24-
25-
Base.size(oe::OneElement) = size(oe.a)
26-
Base.IndexStyle(oe::OneElement) = Base.IndexStyle(oe.a)
27-
28-
function Base.getindex(oe::OneElement{<:Integer}, ind::Integer)
29-
if ind == oe.ind
30-
return oe.val
31-
else
32-
return zero(eltype(oe.a))
33-
end
34-
end
35-
36-
function Base.getindex(oe::OneElement{<:CartesianIndex{N}}, ind::Vararg{Int,N}) where {N}
37-
if ind == Tuple(oe.ind)
38-
return oe.val
39-
else
40-
return zero(eltype(oe.a))
41-
end
42-
end
43-
441
"""
452
basis(a::AbstractArray, i)
463
@@ -49,7 +6,7 @@ Construct the `i`-th standard basis array in the vector space of `a`.
496
function basis(a::AbstractArray{T}, i) where {T}
507
b = similar(a)
518
fill!(b, zero(T))
52-
b .+= OneElement(i, one(T), a)
9+
b[i] = one(T)
5310
if ismutable_array(a)
5411
return b
5512
else

DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using DifferentiationInterface:
55
AutoReverseFromPrimitive,
66
DenseSparsityDetector
77
using SparseMatrixColorings
8+
using JLArrays, StaticArrays
89
using Test
910

1011
LOGGING = get(ENV, "CI", "false") == "false"
@@ -137,3 +138,9 @@ end
137138
pushforward, copyto!, [1.0], AutoSimpleFiniteDiff(), [1.0], ([1.0], [1.0])
138139
)
139140
end
141+
142+
@testset "Weird arrays" begin
143+
test_differentiation(
144+
AutoSimpleFiniteDiff(), vcat(static_scenarios(), gpu_scenarios()); logging=LOGGING
145+
)
146+
end;

DifferentiationInterfaceTest/ext/DifferentiationInterfaceTestComponentArraysExt/DifferentiationInterfaceTestComponentArraysExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ function comp_to_num(x::ComponentVector)::Number
1010
return sum(sin.(x.a)) + sum(cos.(x.b))
1111
end
1212

13-
comp_to_num_gradient(x) = ComponentVector(; a=cos.(x.a), b=-sin.(x.b))
13+
comp_to_num_gradient(x) = ComponentVector(; a=cos.(x.a), b=(-sin.(x.b)))
1414

1515
function comp_to_num_pushforward(x, dx)
1616
g = comp_to_num_gradient(x)

DifferentiationInterfaceTest/src/tests/correctness_eval.jl

Lines changed: 12 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,7 @@ for op in ALL_OPS
5656
contextsrand = rewrap(map(myrandom unwrap, contexts)...)
5757
local prepstrict
5858
preptup_cands_val, preptup_cands_noval = map(1:2) do _
59-
new_smaller =
60-
if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
59+
new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
6160
deepcopy(scen)
6261
else
6362
deepcopy(smaller)
@@ -124,8 +123,7 @@ for op in ALL_OPS
124123
contextsrand = rewrap(map(myrandom unwrap, contexts)...)
125124
local prepstrict
126125
preptup_cands_val, preptup_cands_noval = map(1:2) do _
127-
new_smaller =
128-
if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
126+
new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
129127
deepcopy(scen)
130128
else
131129
deepcopy(smaller)
@@ -208,8 +206,7 @@ for op in ALL_OPS
208206
contextsrand = rewrap(map(myrandom unwrap, contexts)...)
209207
local prepstrict
210208
preptup_cands_val, preptup_cands_noval = map(1:2) do _
211-
new_smaller =
212-
if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
209+
new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
213210
deepcopy(scen)
214211
else
215212
deepcopy(smaller)
@@ -286,8 +283,7 @@ for op in ALL_OPS
286283
contextsrand = rewrap(map(myrandom unwrap, contexts)...)
287284
local prepstrict
288285
preptup_cands_val, preptup_cands_noval = map(1:2) do _
289-
new_smaller =
290-
if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
286+
new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
291287
deepcopy(scen)
292288
else
293289
deepcopy(smaller)
@@ -375,8 +371,7 @@ for op in ALL_OPS
375371
contextsrand = rewrap(map(myrandom unwrap, contexts)...)
376372
local prepstrict
377373
preptup_cands_val, preptup_cands_noval = map(1:2) do _
378-
new_smaller =
379-
if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
374+
new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
380375
deepcopy(scen)
381376
else
382377
deepcopy(smaller)
@@ -445,8 +440,7 @@ for op in ALL_OPS
445440
contextsrand = rewrap(map(myrandom unwrap, contexts)...)
446441
local prepstrict
447442
preptup_cands_val, preptup_cands_noval = map(1:2) do _
448-
new_smaller =
449-
if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
443+
new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
450444
deepcopy(scen)
451445
else
452446
deepcopy(smaller)
@@ -532,8 +526,7 @@ for op in ALL_OPS
532526
contextsrand = rewrap(map(myrandom unwrap, contexts)...)
533527
local prepstrict
534528
preptup_cands_val, preptup_cands_noval = map(1:2) do _
535-
new_smaller =
536-
if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
529+
new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
537530
deepcopy(scen)
538531
else
539532
deepcopy(smaller)
@@ -599,8 +592,7 @@ for op in ALL_OPS
599592
contextsrand = rewrap(map(myrandom unwrap, contexts)...)
600593
local prepstrict
601594
preptup_cands_val, preptup_cands_noval = map(1:2) do _
602-
new_smaller =
603-
if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
595+
new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
604596
deepcopy(scen)
605597
else
606598
deepcopy(smaller)
@@ -682,8 +674,7 @@ for op in ALL_OPS
682674
contextsrand = rewrap(map(myrandom unwrap, contexts)...)
683675
local prepstrict
684676
preptup_cands_val, preptup_cands_noval = map(1:2) do _
685-
new_smaller =
686-
if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
677+
new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
687678
deepcopy(scen)
688679
else
689680
deepcopy(smaller)
@@ -765,8 +756,7 @@ for op in ALL_OPS
765756
contextsrand = rewrap(map(myrandom unwrap, contexts)...)
766757
local prepstrict
767758
preptup_cands_val, preptup_cands_noval = map(1:2) do _
768-
new_smaller =
769-
if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
759+
new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
770760
deepcopy(scen)
771761
else
772762
deepcopy(smaller)
@@ -867,8 +857,7 @@ for op in ALL_OPS
867857
contextsrand = rewrap(map(myrandom unwrap, contexts)...)
868858
local prepstrict
869859
preptup_cands_val, preptup_cands_noval = map(1:2) do _
870-
new_smaller =
871-
if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
860+
new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
872861
deepcopy(scen)
873862
else
874863
deepcopy(smaller)
@@ -934,8 +923,7 @@ for op in ALL_OPS
934923
contextsrand = rewrap(map(myrandom unwrap, contexts)...)
935924
local prepstrict
936925
preptup_cands_val, preptup_cands_noval = map(1:2) do _
937-
new_smaller =
938-
if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
926+
new_smaller = if isnothing(smaller) || adapt_batchsize(ba, smaller) != ba
939927
deepcopy(scen)
940928
else
941929
deepcopy(smaller)

0 commit comments

Comments
 (0)