Skip to content

Commit 467c784

Browse files
committed
perf: optimize multibasis for sparse differentiation
1 parent f442ed4 commit 467c784

2 files changed

Lines changed: 5 additions & 2 deletions

File tree

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.51"
4+
version = "0.6.52"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/src/utils/basis.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,15 @@ end
6161
multibasis(a::AbstractArray, inds)
6262
6363
Construct the sum of the `i`-th standard basis arrays in the vector space of `a` for all `i ∈ inds`.
64+
65+
!!! warning
66+
Does not work on GPU, since this is intended for sparse autodiff and SparseMatrixColorings.jl doesn't work on GPUs either.
6467
"""
6568
function multibasis(a::AbstractArray{T}, inds) where {T}
6669
b = similar(a)
6770
fill!(b, zero(T))
6871
for i in inds
69-
b .+= OneElement(i, one(T), a)
72+
b[i] = one(T)
7073
end
7174
return ismutable_array(a) ? b : map(+, zero(a), b)
7275
end

0 commit comments

Comments
 (0)