Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize broadcast!(f, ::BitVector) optimization to BitArray. #52736

Merged
merged 4 commits into from
Jan 5, 2024

Conversation

N5N3
Copy link
Member

@N5N3 N5N3 commented Jan 4, 2024

Follows #32048.
This PR fully avoids the allocation thus make nd logical broadcast better scaled for small inputs.

Some Benchmark
import Random, Statistics
using BenchmarkTools
using Base:tail, bitcache_size, bitcache_chunks, dumpbitcache
Random.seed!(0)

@inline function master_copyto!(dest::BitVector, bc)
    axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
    Base.Broadcast.ischunkedbroadcast(dest, bc) && return Base.Broadcast.chunkedcopyto!(dest, bc)
    destc = dest.chunks
    bcp = Base.Broadcast.preprocess(dest, bc)
    length(bcp) <= 0 && return dest
    len = Base.num_bit_chunks(Int(length(bcp)))
    @inbounds for i = 0:(len - 2)
        z = UInt64(0)
        for j = 0:63
           z |= UInt64(bcp[i*64 + j + 1]::Bool) << (j & 63)
        end
        destc[i + 1] = z
    end
    @inbounds let i = len - 1
        z = UInt64(0)
        for j = 0:((length(bcp) - 1) & 63)
             z |= UInt64(bcp[i*64 + j + 1]::Bool) << (j & 63)
        end
        destc[i + 1] = z
    end
    return dest
end

# Performance optimization: for BitArray outputs, we cache the result
# in a "small" Vector{Bool}, and then copy in chunks into the output
@inline function master_copyto!(dest::BitArray, bc)
    axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
    Base.Broadcast.ischunkedbroadcast(dest, bc) && return Base.Broadcast.chunkedcopyto!(dest, bc)
    length(dest) < 256 && return Base.@invoke copyto!(dest::AbstractArray, bc)
    tmp = Vector{Bool}(undef, bitcache_size)
    destc = dest.chunks
    cind = 1
    bc′ = Base.Broadcast.preprocess(dest, bc)
    @inbounds for P in Iterators.partition(eachindex(bc′), bitcache_size)
        ind = 1
        @simd for I in P
            tmp[ind] = bc′[I]
            ind += 1
        end
        @simd for i in ind:bitcache_size
            tmp[i] = false
        end
        dumpbitcache(destc, cind, tmp)
        cind += bitcache_chunks
    end
    return dest
end

function pr_copyto!(dest::BitArray, bc)
    axes(dest) == axes(bc) || Base.Broadcast.throwdm(axes(dest), axes(bc))
    Base.Broadcast.ischunkedbroadcast(dest, bc) && return Base.Broadcast.chunkedcopyto!(dest, bc)
    ndims(dest) == 0 && return Base.@invoke copyto!(dest::AbstractArray, bc)
    ax = axes(bc)
    ax1, out = ax[1], CartesianIndices(tail(ax))
    destc, indc = dest.chunks, 0
    bitst, remain = 0, UInt64(0)
    bc′ = Base.Broadcast.preprocess(dest, bc)
     @inbounds for I in out
        i = first(ax1) - 1
        if ndims(bc) == 1 || bitst >= 64 - length(ax1)
            if ndims(bc) > 1 && bitst != 0
                z = remain
                @simd for j = bitst:63
                    z |= UInt64(convert(Bool, bc[i+=1, I])) << (j & 63)
                end
                destc[indc+=1] = z
            end
            while i <= last(ax1) - 64
                z = UInt64(0)
                @simd for j = 0:63
                    z |= UInt64(convert(Bool, bc[i+=1, I])) << (j & 63)
                end
                destc[indc+=1] = z
            end
            bitst, remain = 0, UInt64(0)
        end
        @simd for j = i+1:last(ax1)
            remain |= UInt64(convert(Bool, bc[j, I])) << (bitst & 63)
            bitst += 1
        end
    end
    @inbounds if bitst != 0
        destc[indc+=1] = remain
    end
    return dest
end

for n in round.(Int, [3 5 10 63 64 65 100])
    for T in (Int8, Int32, Int64)
        for (desc, dims) in (("1d", (100n,)),
                             ("2d (wide)", (10, 10n)),
                             ("2d (tall)", (10n, 10)),
                             (n > 10 ? (("3d (wide)", (10, 10, n)),
                                        ("3d (tall)", (n, 10, 10)),) : ())...)
            a = rand(T(0):T(4), dims)
            a = view(a, axes(a)...)
            bc = Base.broadcasted(==, a, 0)
            bc = Base.Broadcast.instantiate(bc)
            b = BitArray(undef, size(a))
            bc = convert(Base.Broadcast.Broadcasted{Nothing}, bc)
            master = @benchmark master_copyto!($b, $bc)
            pr = @benchmark pr_copyto!($b, $bc)
            println("$T, $(100n), $(Base.dims2string(dims)), $desc, $(median(pr).time/median(master).time)")
        end
    end
end
type len size dimension ratio
Int8 300 300 1d 0.8049701082178423
Int8 300 10×30 2d (wide) 0.14661725552845356
Int8 300 30×10 2d (tall) 0.21419304402787712
Int32 300 300 1d 0.8562138296047829
Int32 300 10×30 2d (wide) 0.36087548666933134
Int32 300 30×10 2d (tall) 0.23401590810405398
Int64 300 300 1d 0.7572145199966925
Int64 300 10×30 2d (wide) 0.3141521846653253
Int64 300 30×10 2d (tall) 0.19794190115328072
Int8 500 500 1d 0.772538940009253
Int8 500 10×50 2d (wide) 0.47195179556897127
Int8 500 50×10 2d (tall) 0.27832403048421167
Int32 500 500 1d 0.8102381389012601
Int32 500 10×50 2d (wide) 0.49712456275567674
Int32 500 50×10 2d (tall) 0.29639818364298487
Int64 500 500 1d 0.680161943319838
Int64 500 10×50 2d (wide) 0.47691933916423707
Int64 500 50×10 2d (tall) 0.25733235297889534
Int8 1000 1000 1d 0.7740494843876912
Int8 1000 10×100 2d (wide) 0.7955315763989763
Int8 1000 100×10 2d (tall) 0.39482795983723
Int32 1000 1000 1d 0.7664919113876304
Int32 1000 10×100 2d (wide) 0.7918898474427308
Int32 1000 100×10 2d (tall) 0.4206732580037665
Int64 1000 1000 1d 0.6162508215892311
Int64 1000 10×100 2d (wide) 0.7419391308280197
Int64 1000 100×10 2d (tall) 0.28843806921675774
Int8 6300 6300 1d 0.7508786158421195
Int8 6300 10×630 2d (wide) 1.1015625
Int8 6300 630×10 2d (tall) 0.5841503267973855
Int8 6300 10×10×63 3d (wide) 1.0446096654275092
Int8 6300 63×10×10 3d (tall) 0.9523809523809523
Int32 6300 6300 1d 0.7714673386315177
Int32 6300 10×630 2d (wide) 1.060377358490566
Int32 6300 630×10 2d (tall) 0.5455993628036638
Int32 6300 10×10×63 3d (wide) 1.0881226053639848
Int32 6300 63×10×10 3d (tall) 1.03125
Int64 6300 6300 1d 0.5650955545919919
Int64 6300 10×630 2d (wide) 1.1023622047244095
Int64 6300 630×10 2d (tall) 0.3621870397643593
Int64 6300 10×10×63 3d (wide) 1.0795454545454546
Int64 6300 63×10×10 3d (tall) 0.7745098039215687
Int8 6400 6400 1d 0.7665206279664111
Int8 6400 10×640 2d (wide) 1.14453125
Int8 6400 640×10 2d (tall) 0.5405405405405406
Int8 6400 10×10×64 3d (wide) 1.032490974729242
Int8 6400 64×10×10 3d (tall) 0.6187290969899666
Int32 6400 6400 1d 0.763771712158809
Int32 6400 10×640 2d (wide) 1.055350553505535
Int32 6400 640×10 2d (tall) 0.5044358507734303
Int32 6400 10×10×64 3d (wide) 1.1162790697674418
Int32 6400 64×10×10 3d (tall) 0.5995934959349594
Int64 6400 6400 1d 0.551921470342523
Int64 6400 10×640 2d (wide) 1.111969111969112
Int64 6400 640×10 2d (tall) 0.32497607655502386
Int64 6400 10×10×64 3d (wide) 1.1254752851711027
Int64 6400 64×10×10 3d (tall) 0.39924961377179435
Int8 6500 6500 1d 0.760268456375839
Int8 6500 10×650 2d (wide) 1.1221374045801527
Int8 6500 650×10 2d (tall) 0.5633802816901409
Int8 6500 10×10×65 3d (wide) 1.0469314079422383
Int8 6500 65×10×10 3d (tall) 1.064102564102564
Int32 6500 6500 1d 0.7692039916453933
Int32 6500 10×650 2d (wide) 1.1153846153846154
Int32 6500 650×10 2d (tall) 0.5496323529411764
Int32 6500 10×10×65 3d (wide) 1.0541516245487366
Int32 6500 65×10×10 3d (tall) 1.0121212121212122
Int64 6500 6500 1d 0.5520719460113399
Int64 6500 10×650 2d (wide) 1.118320610687023
Int64 6500 650×10 2d (tall) 0.361455716080402
Int64 6500 10×10×65 3d (wide) 1.0640569395017794
Int64 6500 65×10×10 3d (tall) 0.7994923857868022
Int8 10000 10000 1d 0.7973856209150327
Int8 10000 10×1000 2d (wide) 1.1263208453410183
Int8 10000 1000×10 2d (tall) 0.5541871921182265
Int8 10000 10×10×100 3d (wide) 1.0932475884244375
Int8 10000 100×10×10 3d (tall) 0.8782608695652174
Int32 10000 10000 1d 0.7612903225806451
Int32 10000 10×1000 2d (wide) 1.1429263565891472
Int32 10000 1000×10 2d (tall) 0.5195454545454545
Int32 10000 10×10×100 3d (wide) 1.1467236467236466
Int32 10000 100×10×10 3d (tall) 0.8681818181818182
Int64 10000 10000 1d 0.5546050755631594
Int64 10000 10×1000 2d (wide) 1.1632850241545893
Int64 10000 1000×10 2d (tall) 0.349034749034749
Int64 10000 10×10×100 3d (wide) 1.1310185185185184
Int64 10000 100×10×10 3d (tall) 0.6534351145038167

So looks like this change is a general win as the iteration in the 1st dimension is better vectorized.

@N5N3 N5N3 added performance Must go faster broadcast Applying a function over a collection labels Jan 4, 2024
base/broadcast.jl Outdated Show resolved Hide resolved
@mbauman
Copy link
Member

mbauman commented Jan 4, 2024

This is great!

Co-authored-by: Matt Bauman <mbauman@gmail.com>
@oscardssmith
Copy link
Member

I really wish the compiler was smart enough to do this itself. It is somewhat annoying that there isn't a good way to efficiently use BitVectors other than working with the chunks directly, but while that is the case, I approve of this.

base/broadcast.jl Outdated Show resolved Hide resolved
base/broadcast.jl Outdated Show resolved Hide resolved
Copy link
Member

@vtjnash vtjnash left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this SGTM, but may need added tests?

@vtjnash
Copy link
Member

vtjnash commented Jan 4, 2024

@nanosoldier runbenchmarks("broadcast" || "BitArray", vs=":master")

@nanosoldier
Copy link
Collaborator

Your benchmark job has completed - possible performance regressions were detected. A full report can be found here.

N5N3 added 2 commits January 5, 2024 09:00
And limit the usage of `@inbounds`
@N5N3 N5N3 merged commit 50788cd into JuliaLang:master Jan 5, 2024
4 of 7 checks passed
@N5N3 N5N3 deleted the bitarray_bc branch January 5, 2024 08:11
maleadt added a commit that referenced this pull request Jan 6, 2024
N5N3 added a commit that referenced this pull request Jan 7, 2024
…ay`." (#52736) (#52776)

Reland "Generalize broadcast!(f, ::BitVector) optimization to `BitArray`." (#52736)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
broadcast Applying a function over a collection performance Must go faster
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants