Skip to content

Commit 2084b3d

Browse files
committed
fix: make (multi)basis work on CuArray
1 parent d8905f5 commit 2084b3d

2 files changed

Lines changed: 12 additions & 36 deletions

File tree

DifferentiationInterface/ext/DifferentiationInterfaceGPUArraysCoreExt/DifferentiationInterfaceGPUArraysCoreExt.jl

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,21 @@
11
module DifferentiationInterfaceGPUArraysCoreExt
22

33
import DifferentiationInterface as DI
4-
using GPUArraysCore: AbstractGPUArray
4+
using GPUArraysCore: @allowscalar, AbstractGPUArray
55

6-
"""
7-
OneElement
8-
9-
Efficient storage for a one-hot array, aka an array in the standard Euclidean basis.
10-
"""
11-
struct OneElement{I,N,T,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
12-
ind::I
13-
val::T
14-
a::A
15-
16-
function OneElement(ind::Integer, val::T, a::A) where {N,T,A<:AbstractArray{T,N}}
17-
right_ind = eachindex(a)[ind]
18-
return new{typeof(right_ind),N,T,A}(right_ind, val, a)
19-
end
20-
21-
function OneElement(
22-
ind::CartesianIndex{N}, val::T, a::A
23-
) where {N,T,A<:AbstractArray{T,N}}
24-
linear_ind = LinearIndices(a)[ind]
25-
right_ind = eachindex(a)[linear_ind]
26-
return new{typeof(right_ind),N,T,A}(right_ind, val, a)
27-
end
28-
end
29-
30-
Base.size(oe::OneElement) = size(oe.a)
31-
Base.IndexStyle(oe::OneElement) = Base.IndexStyle(oe.a)
32-
33-
function Base.getindex(oe::OneElement{<:Integer}, ind::Integer)
34-
return ifelse(ind == oe.ind, oe.val, zero(eltype(oe.a)))
6+
function DI.basis(a::AbstractGPUArray{T}, i) where {T}
7+
b = similar(a)
8+
fill!(b, zero(T))
9+
@allowscalar b[i] = one(T)
10+
return b
3511
end
3612

37-
function DI.basis(a::AbstractGPUArray{T}, i) where {T}
38-
b = zero(a)
39-
b .+= OneElement(i, one(T), a)
13+
function DI.multibasis(a::AbstractGPUArray{T}, inds) where {T}
14+
b = similar(a)
15+
fill!(b, zero(T))
16+
for i in inds
17+
@allowscalar b[i] = one(T)
18+
end
4019
return b
4120
end
4221

DifferentiationInterface/src/utils/basis.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ end
1818
multibasis(a::AbstractArray, inds)
1919
2020
Construct the sum of the `i`-th standard basis arrays in the vector space of `a` for all `i ∈ inds`.
21-
22-
!!! warning
23-
Does not work on GPU, since this is intended for sparse autodiff and SparseMatrixColorings.jl doesn't work on GPUs either.
2421
"""
2522
function multibasis(a::AbstractArray{T}, inds) where {T}
2623
b = similar(a)

0 commit comments

Comments
 (0)