Skip to content

Commit

Permalink
Generalize broadcast!(f, ::BitVector) optimization to BitArray. (#…
Browse files Browse the repository at this point in the history
…52736)

Follows #32048.
This PR fully avoids the allocation thus make nd logical broadcast
better scaled for small inputs.

---------

Co-authored-by: Matt Bauman <mbauman@gmail.com>
  • Loading branch information
N5N3 and mbauman authored Jan 5, 2024
1 parent 316cc4a commit 50788cd
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 40 deletions.
69 changes: 29 additions & 40 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -974,53 +974,42 @@ preprocess(dest, x) = extrude(broadcast_unalias(dest, x))
return dest
end

# Performance optimization: for BitVector outputs, we cache the result
# in a 64-bit register before writing into memory (to bypass LSQ)
@inline function copyto!(dest::BitVector, bc::Broadcasted{Nothing})
axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
ischunkedbroadcast(dest, bc) && return chunkedcopyto!(dest, bc)
destc = dest.chunks
bcp = preprocess(dest, bc)
length(bcp) <= 0 && return dest
len = Base.num_bit_chunks(Int(length(bcp)))
@inbounds for i = 0:(len - 2)
z = UInt64(0)
for j = 0:63
z |= UInt64(bcp[i*64 + j + 1]::Bool) << (j & 63)
end
destc[i + 1] = z
end
@inbounds let i = len - 1
z = UInt64(0)
for j = 0:((length(bcp) - 1) & 63)
z |= UInt64(bcp[i*64 + j + 1]::Bool) << (j & 63)
end
destc[i + 1] = z
end
return dest
end

# Performance optimization: for BitArray outputs, we cache the result
# in a "small" Vector{Bool}, and then copy in chunks into the output
# in a 64-bit register before writing into memory (to bypass LSQ)
@inline function copyto!(dest::BitArray, bc::Broadcasted{Nothing})
axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc))
ischunkedbroadcast(dest, bc) && return chunkedcopyto!(dest, bc)
length(dest) < 256 && return invoke(copyto!, Tuple{AbstractArray, Broadcasted{Nothing}}, dest, bc)
tmp = Vector{Bool}(undef, bitcache_size)
destc = dest.chunks
cind = 1
ndims(dest) == 0 && (dest[] = bc[]; return dest)
bc′ = preprocess(dest, bc)
@inbounds for P in Iterators.partition(eachindex(bc′), bitcache_size)
ind = 1
@simd for I in P
tmp[ind] = bc′[I]
ind += 1
ax = axes(bc′)
ax1, out = ax[1], CartesianIndices(tail(ax))
destc, indc = dest.chunks, 0
bitst, remain = 0, UInt64(0)
for I in out
i = first(ax1) - 1
if ndims(bc) == 1 || bitst >= 64 - length(ax1)
if ndims(bc) > 1 && bitst != 0
@inbounds @simd for j = bitst:63
remain |= UInt64(convert(Bool, bc′[i+=1, I])) << (j & 63)
end
@inbounds destc[indc+=1] = remain
bitst, remain = 0, UInt64(0)
end
while i <= last(ax1) - 64
z = UInt64(0)
@inbounds @simd for j = 0:63
z |= UInt64(convert(Bool, bc′[i+=1, I])) << (j & 63)
end
@inbounds destc[indc+=1] = z
end
end
@simd for i in ind:bitcache_size
tmp[i] = false
@inbounds @simd for j = i+1:last(ax1)
remain |= UInt64(convert(Bool, bc′[j, I])) << (bitst & 63)
bitst += 1
end
dumpbitcache(destc, cind, tmp)
cind += bitcache_chunks
end
@inbounds if bitst != 0
destc[indc+=1] = remain
end
return dest
end
Expand Down
10 changes: 10 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -592,6 +592,16 @@ end
end
end

@testset "convert behavior of logical broadcast" begin
a = mod.(1:4, 2)
@test !isa(a, BitArray)
for T in (Array{Bool}, BitArray)
la = T(a)
la .= mod.(0:3, 2)
@test la == [false; true; false; true]
end
end

# Test that broadcast treats type arguments as scalars, i.e. containertype yields Any,
# even for subtypes of abstract array. (https://github.com/JuliaStats/DataArrays.jl/issues/229)
@testset "treat type arguments as scalars, DataArrays issue 229" begin
Expand Down

0 comments on commit 50788cd

Please sign in to comment.