Skip to content

Commit 8c4602c

Browse files
Add informative error when threading VoA{SArray} without Polyester
When Polyester is not loaded and a user requests threaded FastBroadcast on a VectorOfArray{SArray}, throw an error explaining that they need to load Polyester.jl. Also fix the tests to use Vector-of-SVector construction (not Matrix-of-SVector) so that they properly exercise the SArray-specific path.

Co-Authored-By: Chris Rackauckas <accounts@chrisrackauckas.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 0e5e41e commit 8c4602c

4 files changed

Lines changed: 22 additions & 8 deletions

File tree

ext/RecursiveArrayToolsFastBroadcastExt.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,20 @@ const AbstractVectorOfSArray = AbstractVectorOfArray{
2727
return dst
2828
end
2929

30-
# Fallback for VectorOfArray: the generic threaded path splits along the last
31-
# axis via views, which does not correctly partition work for VectorOfArray.
32-
# Fall back to serial broadcasting. The RecursiveArrayToolsFastBroadcastPolyesterExt
33-
# extension provides proper Polyester-based threading when Polyester is loaded.
30+
# Fallback for non-SArray VectorOfArray: the generic threaded path splits along
31+
# the last axis via views, which does not correctly partition work for
32+
# VectorOfArray. Fall back to serial broadcasting.
33+
# For SArray VectorOfArray, throw an informative error telling the user to
34+
# load Polyester.jl for threaded broadcasting.
3435
@inline function FastBroadcast.fast_materialize!(
3536
::Threaded, dst::AbstractVectorOfArray,
3637
bc::Broadcast.Broadcasted
3738
)
39+
if dst isa AbstractVectorOfSArray && !RecursiveArrayTools.POLYESTER_LOADED[]
40+
error("Threaded FastBroadcast on VectorOfArray{SArray} requires Polyester.jl. " *
41+
"Add `using Polyester` to enable threaded broadcasting, or use " *
42+
"`@.. thread=false` for serial broadcasting.")
43+
end
3844
return FastBroadcast.fast_materialize!(Serial(), dst, bc)
3945
end
4046

ext/RecursiveArrayToolsFastBroadcastPolyesterExt.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@ using FastBroadcast: Serial, Threaded
66
using Polyester
77
using StaticArraysCore
88

9+
# Signal to the base FastBroadcast extension that Polyester threading is available.
10+
# Set the flag from `__init__` instead of at top level. Extensions are precompiled,
# and a top-level mutation of another module's state happens only while precompiling;
# it is not replayed when the extension is loaded from its cache, which would leave
# `POLYESTER_LOADED[]` at `false` in a normal session. `__init__` runs at load time.
function __init__()
    RecursiveArrayTools.POLYESTER_LOADED[] = true
    return nothing
end
11+
912
const AbstractVectorOfSArray = AbstractVectorOfArray{
1013
T, N, <:AbstractVector{<:StaticArraysCore.SArray},
1114
} where {T, N}

src/RecursiveArrayTools.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,11 @@ module RecursiveArrayTools
142142

143143
export ArrayPartition, AP, NamedArrayPartition
144144

145+
# Flag set to `true` by RecursiveArrayToolsFastBroadcastPolyesterExt when
146+
# Polyester is loaded. Checked by the FastBroadcast ext to decide whether
147+
# to throw an informative error on threaded VoA{SArray} operations.
148+
const POLYESTER_LOADED = Ref(false)
149+
145150
include("precompilation.jl")
146151

147152
end # module

test/interface_tests.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -305,15 +305,15 @@ end
305305
:RecursiveArrayToolsFastBroadcastPolyesterExt
306306
) !== nothing
307307

308-
# Test basic threaded broadcast with Polyester
309-
u_p = VectorOfArray(fill(SVector(2.0, 3.0), 3, 3))
308+
# Test basic threaded broadcast with Polyester (Vector-of-SVector storage)
309+
u_p = VectorOfArray([SVector(2.0, 3.0) for _ in 1:9])
310310
v_p = copy(u_p)
311311
@.. thread = true v_p = v_p + u_p
312312
@test all(x -> x == SVector(4.0, 6.0), v_p.u)
313313

314314
# Test with larger array to exercise Polyester batching
315-
u_large = VectorOfArray(fill(SVector(1.0, 1.0, 1.0), 100))
316-
v_large = VectorOfArray(fill(SVector(0.0, 0.0, 0.0), 100))
315+
u_large = VectorOfArray([SVector(1.0, 1.0, 1.0) for _ in 1:100])
316+
v_large = VectorOfArray([SVector(0.0, 0.0, 0.0) for _ in 1:100])
317317
@.. thread = true v_large = u_large * 2.0
318318
@test all(x -> x == SVector(2.0, 2.0, 2.0), v_large.u)
319319
end

0 commit comments

Comments
 (0)