From f7c505221aecfce85a79140e5b1cf226c232ab02 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Thu, 7 Dec 2023 20:53:19 +0800 Subject: [PATCH] Generalize `broadcast!(f, ::BitVector)` optimization to `BitArray`. Follows #32048. --- base/broadcast.jl | 70 ++++++++++++++++++++--------------------------- 1 file changed, 30 insertions(+), 40 deletions(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index d6e5513889cee3..2f31260fefe020 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -974,53 +974,43 @@ preprocess(dest, x) = extrude(broadcast_unalias(dest, x)) return dest end -# Performance optimization: for BitVector outputs, we cache the result -# in a 64-bit register before writing into memory (to bypass LSQ) -@inline function copyto!(dest::BitVector, bc::Broadcasted{Nothing}) - axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc)) - ischunkedbroadcast(dest, bc) && return chunkedcopyto!(dest, bc) - destc = dest.chunks - bcp = preprocess(dest, bc) - length(bcp) <= 0 && return dest - len = Base.num_bit_chunks(Int(length(bcp))) - @inbounds for i = 0:(len - 2) - z = UInt64(0) - for j = 0:63 - z |= UInt64(bcp[i*64 + j + 1]::Bool) << (j & 63) - end - destc[i + 1] = z - end - @inbounds let i = len - 1 - z = UInt64(0) - for j = 0:((length(bcp) - 1) & 63) - z |= UInt64(bcp[i*64 + j + 1]::Bool) << (j & 63) - end - destc[i + 1] = z - end - return dest -end - # Performance optimization: for BitArray outputs, we cache the result -# in a "small" Vector{Bool}, and then copy in chunks into the output +# in a 64-bit register before writing into memory (to bypass LSQ) @inline function copyto!(dest::BitArray, bc::Broadcasted{Nothing}) axes(dest) == axes(bc) || throwdm(axes(dest), axes(bc)) ischunkedbroadcast(dest, bc) && return chunkedcopyto!(dest, bc) - length(dest) < 256 && return invoke(copyto!, Tuple{AbstractArray, Broadcasted{Nothing}}, dest, bc) - tmp = Vector{Bool}(undef, bitcache_size) - destc = dest.chunks - cind = 1 + ndims(dest) == 0 && return Base.@invoke copyto!(dest::AbstractArray, bc) bc′ = preprocess(dest, bc) - @inbounds for P in Iterators.partition(eachindex(bc′), bitcache_size) - ind = 1 - @simd for I in P - tmp[ind] = bc′[I] - ind += 1 + ax = axes(bc′) + ax1, out = ax[1], CartesianIndices(tail(ax)) + destc, indc = dest.chunks, 0 + bitst, remain = 0, UInt64(0) + @inbounds for I in out + i = first(ax1) - 1 + if ndims(bc) == 1 || bitst >= 64 - length(ax1) + if ndims(bc) > 1 && bitst != 0 + z = remain + @simd for j = bitst:63 + z |= UInt64(convert(Bool, bc′[i+=1, I])) << (j & 63) + end + destc[indc+=1] = z + end + while i <= last(ax1) - 64 + z = UInt64(0) + @simd for j = 0:63 + z |= UInt64(convert(Bool, bc′[i+=1, I])) << (j & 63) + end + destc[indc+=1] = z + end + bitst, remain = 0, UInt64(0) end - @simd for i in ind:bitcache_size - tmp[i] = false + @simd for j = i+1:last(ax1) + remain |= UInt64(convert(Bool, bc′[j, I])) << (bitst & 63) + bitst += 1 end - dumpbitcache(destc, cind, tmp) - cind += bitcache_chunks + end + @inbounds if bitst != 0 + destc[indc+=1] = remain end return dest end