Skip to content

Commit de77c97

Browse files
committed
small changes which make it ~1.5x faster
1 parent 3707ad3 commit de77c97

2 files changed

Lines changed: 33 additions & 8 deletions

File tree

src/array_partition.jl

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -364,15 +364,37 @@ end
364364
ArrayPartition(f, N)
365365
end
366366

367+
# old version
368+
# @inline function Base.copyto!(dest::ArrayPartition,
369+
# bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where {
370+
# Style,
371+
# }
372+
# N = npartitions(dest, bc)
373+
# @inline function f(i)
374+
# copyto!(dest.x[i], unpack(bc, i))
375+
# end
376+
# ntuple(f, Val(N))
377+
# dest
378+
# end
379+
380+
# new version
367381
@inline function Base.copyto!(dest::ArrayPartition,
368-
bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where {
369-
Style,
370-
}
382+
bc::Broadcast.Broadcasted{ArrayPartitionStyle{Style}}) where {Style}
371383
N = npartitions(dest, bc)
372-
@inline function f(i)
373-
copyto!(dest.x[i], unpack(bc, i))
384+
# Check if this is a simple enough broadcast that we can optimize
385+
if bc.f isa Union{typeof(+), typeof(*), typeof(muladd)}
386+
# @show "hey", bc, N
387+
@inbounds for i in 1:N
388+
# Use materialize! which is more efficient than copyto! for simple broadcasts
389+
Base.Broadcast.materialize!(dest.x[i], unpack(bc, i))
390+
end
391+
else
392+
# Fall back to original implementation for complex broadcasts
393+
@inline function f(i)
394+
copyto!(dest.x[i], unpack(bc, i))
395+
end
396+
ntuple(f, Val(N))
374397
end
375-
ntuple(f, Val(N))
376398
dest
377399
end
378400

@@ -411,8 +433,10 @@ end
411433
i) where {Style <: Broadcast.DefaultArrayStyle}
412434
Broadcast.Broadcasted{Style}(bc.f, unpack_args(i, bc.args))
413435
end
414-
unpack(x, ::Any) = x
415-
unpack(x::ArrayPartition, i) = x.x[i]
436+
437+
@inline unpack(x, ::Any) = x
438+
@inline unpack(x::ArrayPartition, i) = x.x[i]
439+
416440

417441
@inline function unpack_args(i, args::Tuple)
418442
(unpack(args[1], i), unpack_args(i, Base.tail(args))...)

src/named_array_partition.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ end
135135
NamedArrayPartition(f, N, getfield(x, :names_to_indices))
136136
end
137137

138+
# TODO: has this also performance problems and can be improved?
138139
@inline function Base.copyto!(dest::NamedArrayPartition,
139140
bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArrayPartition}})
140141
N = npartitions(dest, bc)

0 commit comments

Comments
 (0)