From ae35f19997b0e92bfd706cf577a30822ecdb823d Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 3 Apr 2024 12:31:22 +0530 Subject: [PATCH 1/3] zeros/ones/fill/trues/falses may accept AbstractUnitRange dims --- base/array.jl | 6 ++++++ base/bitarray.jl | 2 ++ 2 files changed, 8 insertions(+) diff --git a/base/array.jl b/base/array.jl index 7676b380923ee..fb702725b389a 100644 --- a/base/array.jl +++ b/base/array.jl @@ -529,6 +529,7 @@ function fill end fill(v, dims::DimOrInd...) = fill(v, dims) fill(v, dims::NTuple{N, Union{Integer, OneTo}}) where {N} = fill(v, map(to_dim, dims)) fill(v, dims::NTuple{N, Integer}) where {N} = (a=Array{typeof(v),N}(undef, dims); fill!(a, v); a) +fill(v, dims::NTuple{N, DimOrInd}) where {N} = (a=similar(Array{typeof(v),N}, dims); fill!(a, v); a) fill(v, dims::Tuple{}) = (a=Array{typeof(v),0}(undef, dims); fill!(a, v); a) """ @@ -589,6 +590,11 @@ for (fname, felt) in ((:zeros, :zero), (:ones, :one)) fill!(a, $felt(T)) return a end + function $fname(::Type{T}, dims::NTuple{N, DimOrInd}) where {T,N} + a = similar(Array{T,N}, dims) + fill!(a, $felt(T)) + return a + end end end diff --git a/base/bitarray.jl b/base/bitarray.jl index 079dbefe03a94..dc3fa141fa359 100644 --- a/base/bitarray.jl +++ b/base/bitarray.jl @@ -403,6 +403,7 @@ julia> falses(2,3) falses(dims::DimOrInd...) = falses(dims) falses(dims::NTuple{N, Union{Integer, OneTo}}) where {N} = falses(map(to_dim, dims)) falses(dims::NTuple{N, Integer}) where {N} = fill!(BitArray(undef, dims), false) +falses(dims::NTuple{N, DimOrInd}) where {N} = fill!(BitArray(undef, dims), false) falses(dims::Tuple{}) = fill!(BitArray(undef, dims), false) """ @@ -421,6 +422,7 @@ julia> trues(2,3) trues(dims::DimOrInd...) = trues(dims) trues(dims::NTuple{N, Union{Integer, OneTo}}) where {N} = trues(map(to_dim, dims)) trues(dims::NTuple{N, Integer}) where {N} = fill!(BitArray(undef, dims), true) +trues(dims::NTuple{N, DimOrInd}) where {N} = fill!(BitArray(undef, dims), true) trues(dims::Tuple{}) = fill!(BitArray(undef, dims), true) function one(x::BitMatrix) From cd3a5cdfacc8b06924606f6f08ffe7721d1b31d4 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Fri, 5 Apr 2024 20:05:39 +0530 Subject: [PATCH 2/3] Add tests --- base/bitarray.jl | 2 -- test/abstractarray.jl | 17 +++++++++++++++++ test/testhelpers/SizedArrays.jl | 14 ++++++++++++++ 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/base/bitarray.jl b/base/bitarray.jl index dc3fa141fa359..079dbefe03a94 100644 --- a/base/bitarray.jl +++ b/base/bitarray.jl @@ -403,7 +403,6 @@ julia> falses(2,3) falses(dims::DimOrInd...) = falses(dims) falses(dims::NTuple{N, Union{Integer, OneTo}}) where {N} = falses(map(to_dim, dims)) falses(dims::NTuple{N, Integer}) where {N} = fill!(BitArray(undef, dims), false) -falses(dims::NTuple{N, DimOrInd}) where {N} = fill!(BitArray(undef, dims), false) falses(dims::Tuple{}) = fill!(BitArray(undef, dims), false) """ @@ -422,7 +421,6 @@ julia> trues(2,3) trues(dims::DimOrInd...) = trues(dims) trues(dims::NTuple{N, Union{Integer, OneTo}}) where {N} = trues(map(to_dim, dims)) trues(dims::NTuple{N, Integer}) where {N} = fill!(BitArray(undef, dims), true) -trues(dims::NTuple{N, DimOrInd}) where {N} = fill!(BitArray(undef, dims), true) trues(dims::Tuple{}) = fill!(BitArray(undef, dims), true) function one(x::BitMatrix) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index a0a6ba6b2229a..162f1781af2e3 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -11,6 +11,9 @@ using .Main.StructArrays isdefined(Main, :FillArrays) || @eval Main include("testhelpers/FillArrays.jl") using .Main.FillArrays +isdefined(Main, :SizedArrays) || @eval Main include("testhelpers/SizedArrays.jl") +using .Main.SizedArrays + A = rand(5,4,3) @testset "Bounds checking" begin @test checkbounds(Bool, A, 1, 1, 1) == true @@ -2097,3 +2100,17 @@ end @test r2[i] == z[j] end end + +@testset "zero for arbitrary axes" begin + r = SizedArrays.SOneTo(2) + s = Base.OneTo(2) + _to_oneto(x::Integer) = Base.OneTo(2) + _to_oneto(x::Union{Base.OneTo, SizedArrays.SOneTo}) = x + for (f, v) in ((zeros, 0), (ones, 1), ((x...)->fill(3,x...),3)) + for ax in ((r,r), (s, r), (2, r)) + A = f(ax...) + @test axes(A) == map(_to_oneto, ax) + @test all(==(v), A) + end + end +end diff --git a/test/testhelpers/SizedArrays.jl b/test/testhelpers/SizedArrays.jl index 43bc27e630479..5177de97f4c67 100644 --- a/test/testhelpers/SizedArrays.jl +++ b/test/testhelpers/SizedArrays.jl @@ -43,10 +43,24 @@ Base.size(a::SizedArray) = size(typeof(a)) Base.size(::Type{<:SizedArray{SZ}}) where {SZ} = SZ Base.axes(a::SizedArray) = map(SOneTo, size(a)) Base.getindex(A::SizedArray, i...) = getindex(A.data, i...) +Base.setindex!(A::SizedArray, v, i...) = setindex!(A.data, v, i...) Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T), size(T))) +(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data) ==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data +homogenize_shape(t::Tuple) = (_homogenize_shape(first(t)), homogenize_shape(Base.tail(t))...) +homogenize_shape(::Tuple{}) = () +_homogenize_shape(x::Integer) = x +_homogenize_shape(x::AbstractUnitRange) = length(x) +const Dims = Union{Integer, Base.OneTo, SOneTo} +function Base.similar(::Type{A}, shape::Tuple{Dims, Vararg{Dims}}) where {A<:AbstractArray} + similar(A, homogenize_shape(shape)) +end +function Base.similar(::Type{A}, shape::Tuple{SOneTo, Vararg{SOneTo}}) where {A<:AbstractArray} + R = similar(A, length.(shape)) + SizedArray{length.(shape)}(R) +end + const SizedMatrixLike = Union{SizedMatrix, Transpose{<:Any, <:SizedMatrix}, Adjoint{<:Any, <:SizedMatrix}} _data(S::SizedArray) = S.data From 420e782900e2241f14d5a7d1acd218e418734585 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Sun, 14 Apr 2024 16:30:38 +0530 Subject: [PATCH 3/3] Specialize trues and falses --- base/bitarray.jl | 2 ++ test/abstractarray.jl | 5 +++++ test/bitarray.jl | 22 ++++++++++++++++++++++ test/testhelpers/SizedArrays.jl | 1 + 4 files changed, 30 insertions(+) diff --git a/base/bitarray.jl b/base/bitarray.jl index 079dbefe03a94..f7eeafbb62231 100644 --- a/base/bitarray.jl +++ b/base/bitarray.jl @@ -404,6 +404,7 @@ falses(dims::DimOrInd...) = falses(dims) falses(dims::NTuple{N, Union{Integer, OneTo}}) where {N} = falses(map(to_dim, dims)) falses(dims::NTuple{N, Integer}) where {N} = fill!(BitArray(undef, dims), false) falses(dims::Tuple{}) = fill!(BitArray(undef, dims), false) +falses(dims::NTuple{N, DimOrInd}) where {N} = fill!(similar(BitArray, dims), false) """ trues(dims) @@ -422,6 +423,7 @@ trues(dims::DimOrInd...) = trues(dims) trues(dims::NTuple{N, Union{Integer, OneTo}}) where {N} = trues(map(to_dim, dims)) trues(dims::NTuple{N, Integer}) where {N} = fill!(BitArray(undef, dims), true) trues(dims::Tuple{}) = fill!(BitArray(undef, dims), true) +trues(dims::NTuple{N, DimOrInd}) where {N} = fill!(similar(BitArray, dims), true) function one(x::BitMatrix) m, n = size(x) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index 162f1781af2e3..feb6adaf39fdd 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -2110,6 +2110,11 @@ end for ax in ((r,r), (s, r), (2, r)) A = f(ax...) @test axes(A) == map(_to_oneto, ax) + if all(x -> x isa SizedArrays.SOneTo, ax) + @test A isa SizedArrays.SizedArray && parent(A) isa Array + else + @test A isa Array + end @test all(==(v), A) end end diff --git a/test/bitarray.jl b/test/bitarray.jl index 056a201bd4f6f..2cf285370441e 100644 --- a/test/bitarray.jl +++ b/test/bitarray.jl @@ -3,6 +3,9 @@ using Base: findprevnot, findnextnot using Random, LinearAlgebra, Test +isdefined(Main, :SizedArrays) || @eval Main include("testhelpers/SizedArrays.jl") +using .Main.SizedArrays + tc(r1::NTuple{N,Any}, r2::NTuple{N,Any}) where {N} = all(x->tc(x...), [zip(r1,r2)...]) tc(r1::BitArray{N}, r2::Union{BitArray{N},Array{Bool,N}}) where {N} = true tc(r1::SubArray{Bool,N1,BitArray{N2}}, r2::SubArray{Bool,N1,<:Union{BitArray{N2},Array{Bool,N2}}}) where {N1,N2} = true @@ -82,6 +85,25 @@ allsizes = [((), BitArray{0}), ((v1,), BitVector), @test !isassigned(b, length(b) + 1) end +@testset "trues and falses with custom axes" begin + for ax in ((SizedArrays.SOneTo(2),), (SizedArrays.SOneTo(2), Base.OneTo(2))) + t = trues(ax) + if all(x -> x isa SizedArrays.SOneTo, ax) + @test t isa SizedArrays.SizedArray && parent(t) isa BitArray + else + @test t isa BitArray + end + @test all(t) + + f = falses(ax) + if all(x -> x isa SizedArrays.SOneTo, ax) + @test t isa SizedArrays.SizedArray && parent(t) isa BitArray + else + @test t isa BitArray + end + @test !any(f) + end +end @testset "Conversions for size $sz" for (sz, T) in allsizes b1 = rand!(falses(sz...)) diff --git a/test/testhelpers/SizedArrays.jl b/test/testhelpers/SizedArrays.jl index 5177de97f4c67..2d37cead61a08 100644 --- a/test/testhelpers/SizedArrays.jl +++ b/test/testhelpers/SizedArrays.jl @@ -45,6 +45,7 @@ Base.axes(a::SizedArray) = map(SOneTo, size(a)) Base.getindex(A::SizedArray, i...) = getindex(A.data, i...) Base.setindex!(A::SizedArray, v, i...) = setindex!(A.data, v, i...) Base.zero(::Type{T}) where T <: SizedArray = SizedArray{size(T)}(zeros(eltype(T), size(T))) +Base.parent(S::SizedArray) = S.data +(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = SizedArray{SZ}(S1.data + S2.data) ==(S1::SizedArray{SZ}, S2::SizedArray{SZ}) where {SZ} = S1.data == S2.data