Skip to content

Commit

Permalink
Generalize broadcast!(f, ::BitVector) optimization to BitArray.
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 committed Dec 21, 2023
1 parent 34d1b71 commit f7c5052
Showing 1 changed file with 30 additions and 40 deletions.
70 changes: 30 additions & 40 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -974,53 +974,43 @@ preprocess(dest, x) = extrude(broadcast_unalias(dest, x))
return dest
end

# Performance optimization: for BitVector outputs, every 64-bit chunk is
# assembled in a register and stored with a single write (to bypass LSQ).
#
# Returns `dest`. Calls `throwdm` when the axes of `dest` and `bc` differ and
# hands over to `chunkedcopyto!` whenever `ischunkedbroadcast(dest, bc)` holds.
@inline function copyto!(dest::BitVector, bc::Broadcasted{Nothing})
    axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
    if ischunkedbroadcast(dest, bc)
        return chunkedcopyto!(dest, bc)
    end
    chunks = dest.chunks
    src = preprocess(dest, bc)
    nbits = length(src)
    nbits <= 0 && return dest
    nchunks = Base.num_bit_chunks(Int(nbits))
    # All chunks but the last hold exactly 64 results, so the inner loop has
    # a constant trip count.
    @inbounds for c = 1:(nchunks - 1)
        offset = (c - 1) * 64
        word = UInt64(0)
        for bit = 0:63
            word |= UInt64(src[offset + bit + 1]::Bool) << (bit & 63)
        end
        chunks[c] = word
    end
    # The last chunk may be partial; bits above the final element stay zero.
    @inbounds begin
        offset = (nchunks - 1) * 64
        word = UInt64(0)
        for bit = 0:((nbits - 1) & 63)
            word |= UInt64(src[offset + bit + 1]::Bool) << (bit & 63)
        end
        chunks[nchunks] = word
    end
    return dest
end

# Performance optimization: for BitArray outputs, we cache the result
# in a 64-bit register before writing into memory (to bypass LSQ)
#
# The broadcast is walked one "column" at a time: `ax1` is the first axis and
# `out` enumerates the remaining axes. Results are packed into `UInt64` words
# that are stored to `dest.chunks` in order. Because a column does not have to
# be a multiple of 64 long, `remain` carries the bits of a partially filled
# word over to the next column and `bitst` counts how many bits it holds.
#
# Returns `dest`. Calls `throwdm` when the axes of `dest` and `bc` differ and
# hands over to `chunkedcopyto!` whenever `ischunkedbroadcast(dest, bc)` holds.
@inline function copyto!(dest::BitArray, bc::Broadcasted{Nothing})
    axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
    ischunkedbroadcast(dest, bc) && return chunkedcopyto!(dest, bc)
    # A 0-dimensional destination has no first axis to walk; use the generic method.
    ndims(dest) == 0 &&
        return invoke(copyto!, Tuple{AbstractArray, Broadcasted{Nothing}}, dest, bc)
    bc′ = preprocess(dest, bc)
    ax = axes(bc′)
    ax1, out = ax[1], CartesianIndices(tail(ax))
    destc, indc = dest.chunks, 0      # `indc`: index of the last chunk written
    bitst, remain = 0, UInt64(0)      # pending bits carried between columns
    @inbounds for I in out
        i = first(ax1) - 1            # last position of `ax1` consumed so far
        # Take the word-at-a-time path when this column is long enough to
        # complete the pending word (always the case for a 1-d broadcast,
        # where nothing is ever pending on entry).
        if ndims(bc) == 1 || bitst >= 64 - length(ax1)
            if ndims(bc) > 1 && bitst != 0
                # Top up the pending word with the first `64 - bitst` elements.
                z = remain
                @simd for j = bitst:63
                    z |= UInt64(convert(Bool, bc′[i+=1, I])) << (j & 63)
                end
                destc[indc+=1] = z
            end
            # Emit full words while at least 64 elements of the column remain.
            while i <= last(ax1) - 64
                z = UInt64(0)
                @simd for j = 0:63
                    z |= UInt64(convert(Bool, bc′[i+=1, I])) << (j & 63)
                end
                destc[indc+=1] = z
            end
            bitst, remain = 0, UInt64(0)
        end
        # Whatever is left of the column (fewer than 64 bits together with the
        # pending ones) is accumulated for the next column or the final flush.
        @simd for j = i+1:last(ax1)
            remain |= UInt64(convert(Bool, bc′[j, I])) << (bitst & 63)
            bitst += 1
        end
    end
    # Flush the trailing partial word; its unused high bits are zero.
    @inbounds if bitst != 0
        destc[indc+=1] = remain
    end
    return dest
end
Expand Down

0 comments on commit f7c5052

Please sign in to comment.