1- """
2- OneElement
3-
4- Efficient storage for a one-hot array, aka an array in the standard Euclidean basis.
5- """
6- struct OneElement{I,N,T,A<: AbstractArray{T,N} } <: AbstractArray{T,N}
7- ind:: I
8- val:: T
9- a:: A
10-
11- function OneElement (ind:: Integer , val:: T , a:: A ) where {N,T,A<: AbstractArray{T,N} }
12- right_ind = eachindex (a)[ind]
13- return new {typeof(right_ind),N,T,A} (right_ind, val, a)
14- end
15-
16- function OneElement (
17- ind:: CartesianIndex{N} , val:: T , a:: A
18- ) where {N,T,A<: AbstractArray{T,N} }
19- linear_ind = LinearIndices (a)[ind]
20- right_ind = eachindex (a)[linear_ind]
21- return new {typeof(right_ind),N,T,A} (right_ind, val, a)
22- end
23- end
24-
25- Base. size (oe:: OneElement ) = size (oe. a)
26- Base. IndexStyle (oe:: OneElement ) = Base. IndexStyle (oe. a)
27-
28- function Base. getindex (oe:: OneElement{<:Integer} , ind:: Integer )
29- if ind == oe. ind
30- return oe. val
31- else
32- return zero (eltype (oe. a))
33- end
34- end
35-
36- function Base. getindex (oe:: OneElement{<:CartesianIndex{N}} , ind:: Vararg{Int,N} ) where {N}
37- if ind == Tuple (oe. ind)
38- return oe. val
39- else
40- return zero (eltype (oe. a))
41- end
42- end
43-
441"""
452 basis(a::AbstractArray, i)
463
@@ -49,7 +6,7 @@ Construct the `i`-th standard basis array in the vector space of `a`.
496function basis (a:: AbstractArray{T} , i) where {T}
507 b = similar (a)
518 fill! (b, zero (T))
52- b .+ = OneElement (i, one (T), a )
9+ b[i] = one (T)
5310 if ismutable_array (a)
5411 return b
5512 else
6118 multibasis(a::AbstractArray, inds)
6219
6320Construct the sum of the `i`-th standard basis arrays in the vector space of `a` for all `i ∈ inds`.
21+
22+ !!! warning
23+ Does not work on GPU, since this is intended for sparse autodiff and SparseMatrixColorings.jl doesn't work on GPUs either.
6424"""
6525function multibasis (a:: AbstractArray{T} , inds) where {T}
6626 b = similar (a)
6727 fill! (b, zero (T))
6828 for i in inds
69- b .+ = OneElement (i, one (T), a )
29+ b[i] = one (T)
7030 end
7131 return ismutable_array (a) ? b : map (+ , zero (a), b)
7232end
0 commit comments