@@ -364,15 +364,37 @@ end
364364 ArrayPartition (f, N)
365365end
366366
367+ # old version
368+ # @inline function Base.copyto!(dest::ArrayPartition,
369+ # bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where {
370+ # Style,
371+ # }
372+ # N = npartitions(dest, bc)
373+ # @inline function f(i)
374+ # copyto!(dest.x[i], unpack(bc, i))
375+ # end
376+ # ntuple(f, Val(N))
377+ # dest
378+ # end
379+
380+ # new version
367381@inline function Base. copyto! (dest:: ArrayPartition ,
368- bc:: Broadcast.Broadcasted{ArrayPartitionStyle{Style}} ) where {
369- Style,
370- }
382+ bc:: Broadcast.Broadcasted{ArrayPartitionStyle{Style}} ) where {Style}
371383 N = npartitions (dest, bc)
372- @inline function f (i)
373- copyto! (dest. x[i], unpack (bc, i))
384+ # Check if this is a simple enough broadcast that we can optimize
385+ if bc. f isa Union{typeof (+ ), typeof (* ), typeof (muladd)}
386+ # @show "hey", bc, N
387+ @inbounds for i in 1 : N
388+ # Use materialize! which is more efficient than copyto! for simple broadcasts
389+ Base. Broadcast. materialize! (dest. x[i], unpack (bc, i))
390+ end
391+ else
392+ # Fall back to original implementation for complex broadcasts
393+ @inline function f (i)
394+ copyto! (dest. x[i], unpack (bc, i))
395+ end
396+ ntuple (f, Val (N))
374397 end
375- ntuple (f, Val (N))
376398 dest
377399end
378400
411433 i) where {Style <: Broadcast.DefaultArrayStyle }
412434 Broadcast. Broadcasted {Style} (bc. f, unpack_args (i, bc. args))
413435end
414- unpack (x, :: Any ) = x
415- unpack (x:: ArrayPartition , i) = x. x[i]
436+
437+ @inline unpack (x, :: Any ) = x
438+ @inline unpack (x:: ArrayPartition , i) = x. x[i]
439+
416440
417441@inline function unpack_args (i, args:: Tuple )
418442 (unpack (args[1 ], i), unpack_args (i, Base. tail (args))... )
0 commit comments