Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize broadcast!(f, ::BitVector) optimization to BitArray. #52736

Merged
merged 4 commits into from
Jan 5, 2024
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 30 additions & 40 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -974,53 +974,43 @@ preprocess(dest, x) = extrude(broadcast_unalias(dest, x))
return dest
end

# Performance optimization: for BitVector outputs, we cache the result
# in a 64-bit register before writing into memory (to bypass LSQ)
@inline function copyto!(dest::BitVector, bc::Broadcasted{Nothing})
axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
ischunkedbroadcast(dest, bc) && return chunkedcopyto!(dest, bc)
destc = dest.chunks
bcp = preprocess(dest, bc)
length(bcp) <= 0 && return dest
len = Base.num_bit_chunks(Int(length(bcp)))
@inbounds for i = 0:(len - 2)
z = UInt64(0)
for j = 0:63
z |= UInt64(bcp[i*64 + j + 1]::Bool) << (j & 63)
end
destc[i + 1] = z
end
@inbounds let i = len - 1
z = UInt64(0)
for j = 0:((length(bcp) - 1) & 63)
z |= UInt64(bcp[i*64 + j + 1]::Bool) << (j & 63)
end
destc[i + 1] = z
end
return dest
end

# Performance optimization: for BitArray outputs, we cache the result
# in a "small" Vector{Bool}, and then copy in chunks into the output
# in a 64-bit register before writing into memory (to bypass LSQ)
@inline function copyto!(dest::BitArray, bc::Broadcasted{Nothing})
axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
ischunkedbroadcast(dest, bc) && return chunkedcopyto!(dest, bc)
length(dest) < 256 && return invoke(copyto!, Tuple{AbstractArray, Broadcasted{Nothing}}, dest, bc)
tmp = Vector{Bool}(undef, bitcache_size)
destc = dest.chunks
cind = 1
ndims(dest) == 0 && (dest[] = bc[]; return dest)
bc′ = preprocess(dest, bc)
@inbounds for P in Iterators.partition(eachindex(bc′), bitcache_size)
ind = 1
@simd for I in P
tmp[ind] = bc′[I]
ind += 1
ax = axes(bc′)
ax1, out = ax[1], CartesianIndices(tail(ax))
destc, indc = dest.chunks, 0
bitst, remain = 0, UInt64(0)
@inbounds for I in out
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
i = first(ax1) - 1
if ndims(bc) == 1 || bitst >= 64 - length(ax1)
if ndims(bc) > 1 && bitst != 0
z = remain
@simd for j = bitst:63
z |= UInt64(convert(Bool, bc′[i+=1, I])) << (j & 63)
N5N3 marked this conversation as resolved.
Show resolved Hide resolved
end
destc[indc+=1] = z
end
while i <= last(ax1) - 64
z = UInt64(0)
@simd for j = 0:63
z |= UInt64(convert(Bool, bc′[i+=1, I])) << (j & 63)
end
destc[indc+=1] = z
end
bitst, remain = 0, UInt64(0)
end
@simd for i in ind:bitcache_size
tmp[i] = false
@simd for j = i+1:last(ax1)
remain |= UInt64(convert(Bool, bc′[j, I])) << (bitst & 63)
bitst += 1
end
dumpbitcache(destc, cind, tmp)
cind += bitcache_chunks
end
@inbounds if bitst != 0
destc[indc+=1] = remain
end
return dest
end
Expand Down