Skip to content

Commit 36bfdb6

Browse files
Fix GPU CuArray ambiguity with typed AbstractGPUArray constructor
Disambiguate CuArray(::AbstractArray{T,N}) from CUDA.jl by defining (::Type{GA})(::AbstractVectorOfArray{T,N}) where {T,N,GA<:AbstractGPUArray} AbstractVectorOfArray{T,N} is strictly more specific than AbstractArray{T,N} on arg2, so this method wins dispatch for VectorOfArray arguments. Uses stack(VA.u) to stay on GPU (avoids GPU→CPU→GPU round-trip). Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 626da00 commit 36bfdb6

1 file changed

Lines changed: 7 additions & 3 deletions

File tree

src/RecursiveArrayTools.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,13 @@ module RecursiveArrayTools
176176
# AbstractVectorOfArray uses AbstractArray's show
177177

178178
import GPUArraysCore
179-
Base.convert(T::Type{<:GPUArraysCore.AnyGPUArray}, VA::AbstractVectorOfArray) = stack(VA.u)
180-
# Disambiguate with CuArray(::AbstractArray{T,N}) by providing the typed method
181-
(T::Type{<:GPUArraysCore.AnyGPUArray})(VA::AbstractVectorOfArray{<:Any, N}) where {N} = T(stack(VA.u))
179+
Base.convert(::Type{T}, VA::AbstractVectorOfArray) where {T <: GPUArraysCore.AnyGPUArray} = T(stack(VA.u))
180+
# Constructor: CuArray(va) etc. Must disambiguate with CuArray(::AbstractArray{T,N})
181+
# from CUDA.jl. AbstractVectorOfArray{T,N} is more specific than AbstractArray{T,N}
182+
# on arg2, matching {T,N} ensures equal specificity to CUDA's method.
183+
function (::Type{GA})(VA::AbstractVectorOfArray{T, N}) where {T, N, GA <: GPUArraysCore.AbstractGPUArray}
184+
return GA(stack(VA.u))
185+
end
182186

183187
export VectorOfArray, VA, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray,
184188
AllObserved, vecarr_to_vectors, tuples

0 commit comments

Comments
 (0)