Skip to content

Commit

Permalink
Specialize trues and falses
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Apr 14, 2024
1 parent cd3a5cd commit 420e782
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 0 deletions.
2 changes: 2 additions & 0 deletions base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,7 @@ falses(dims::DimOrInd...) = falses(dims)
falses(dims::NTuple{N, Union{Integer, OneTo}}) where {N} = falses(map(to_dim, dims))
falses(dims::NTuple{N, Integer}) where {N} = fill!(BitArray(undef, dims), false)
falses(dims::Tuple{}) = fill!(BitArray(undef, dims), false)
falses(dims::NTuple{N, DimOrInd}) where {N} = fill!(similar(BitArray, dims), false)

"""
trues(dims)
Expand All @@ -422,6 +423,7 @@ trues(dims::DimOrInd...) = trues(dims)
trues(dims::NTuple{N, Union{Integer, OneTo}}) where {N} = trues(map(to_dim, dims))
trues(dims::NTuple{N, Integer}) where {N} = fill!(BitArray(undef, dims), true)
trues(dims::Tuple{}) = fill!(BitArray(undef, dims), true)
trues(dims::NTuple{N, DimOrInd}) where {N} = fill!(similar(BitArray, dims), true)

function one(x::BitMatrix)
m, n = size(x)
Expand Down
5 changes: 5 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2110,6 +2110,11 @@ end
for ax in ((r,r), (s, r), (2, r))
A = f(ax...)
@test axes(A) == map(_to_oneto, ax)
if all(x -> x isa SizedArrays.SOneTo, ax)
@test A isa SizedArrays.SizedArray && parent(A) isa Array
else
@test A isa Array
end
@test all(==(v), A)
end
end
Expand Down
22 changes: 22 additions & 0 deletions test/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@
using Base: findprevnot, findnextnot
using Random, LinearAlgebra, Test

isdefined(Main, :SizedArrays) || @eval Main include("testhelpers/SizedArrays.jl")
using .Main.SizedArrays

tc(r1::NTuple{N,Any}, r2::NTuple{N,Any}) where {N} = all(x->tc(x...), [zip(r1,r2)...])
tc(r1::BitArray{N}, r2::Union{BitArray{N},Array{Bool,N}}) where {N} = true
tc(r1::SubArray{Bool,N1,BitArray{N2}}, r2::SubArray{Bool,N1,<:Union{BitArray{N2},Array{Bool,N2}}}) where {N1,N2} = true
Expand Down Expand Up @@ -82,6 +85,25 @@ allsizes = [((), BitArray{0}), ((v1,), BitVector),
@test !isassigned(b, length(b) + 1)
end

@testset "trues and falses with custom axes" begin
for ax in ((SizedArrays.SOneTo(2),), (SizedArrays.SOneTo(2), Base.OneTo(2)))
t = trues(ax)
if all(x -> x isa SizedArrays.SOneTo, ax)
@test t isa SizedArrays.SizedArray && parent(t) isa BitArray
else
@test t isa BitArray
end
@test all(t)

f = falses(ax)
if all(x -> x isa SizedArrays.SOneTo, ax)
@test t isa SizedArrays.SizedArray && parent(t) isa BitArray
else
@test t isa BitArray
end
@test !any(f)
end
end

@testset "Conversions for size $sz" for (sz, T) in allsizes
b1 = rand!(falses(sz...))
Expand Down
1 change: 1 addition & 0 deletions test/testhelpers/SizedArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ Base.axes(a::SizedArray) = map(SOneTo, size(a))
Base.getindex(A::SizedArray, i...) = getindex(A.data, i...)
Base.setindex!(A::SizedArray, v, i...) = setindex!(A.data, v, i...)
Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T), size(T)))
Base.parent(S::SizedArray) = S.data
+(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data)
==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data

Expand Down

0 comments on commit 420e782

Please sign in to comment.