Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve cat design / performance #49322

Merged
merged 2 commits into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 24 additions & 26 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1651,7 +1651,7 @@ function _typed_hcat(::Type{T}, A::AbstractVecOrTuple{AbstractVecOrMat}) where T
for j = 1:nargs
Aj = A[j]
if size(Aj, 1) != nrows
throw(ArgumentError("number of rows of each array must match (got $(map(x->size(x,1), A)))"))
throw(DimensionMismatch("number of rows of each array must match (got $(map(x->size(x,1), A)))"))
end
dense &= isa(Aj,Array)
nd = ndims(Aj)
Expand Down Expand Up @@ -1686,7 +1686,7 @@ function _typed_vcat(::Type{T}, A::AbstractVecOrTuple{AbstractVecOrMat}) where T
ncols = size(A[1], 2)
for j = 2:nargs
if size(A[j], 2) != ncols
throw(ArgumentError("number of columns of each array must match (got $(map(x->size(x,2), A)))"))
throw(DimensionMismatch("number of columns of each array must match (got $(map(x->size(x,2), A)))"))
end
end
B = similar(A[1], T, nrows, ncols)
Expand Down Expand Up @@ -1984,16 +1984,14 @@ julia> cat(1, [2], [3;;]; dims=Val(2))

# The specializations for 1 and 2 inputs are important
# especially when running with --inline=no, see #11158
# The specializations for Union{AbstractVecOrMat,Number} are necessary
# to have more specialized methods here than in LinearAlgebra/uniformscaling.jl
vcat(A::AbstractArray) = cat(A; dims=Val(1))
vcat(A::AbstractArray, B::AbstractArray) = cat(A, B; dims=Val(1))
vcat(A::AbstractArray...) = cat(A...; dims=Val(1))
vcat(A::Union{AbstractVecOrMat,Number}...) = cat(A...; dims=Val(1))
vcat(A::Union{AbstractArray,Number}...) = cat(A...; dims=Val(1))
hcat(A::AbstractArray) = cat(A; dims=Val(2))
hcat(A::AbstractArray, B::AbstractArray) = cat(A, B; dims=Val(2))
hcat(A::AbstractArray...) = cat(A...; dims=Val(2))
hcat(A::Union{AbstractVecOrMat,Number}...) = cat(A...; dims=Val(2))
hcat(A::Union{AbstractArray,Number}...) = cat(A...; dims=Val(2))

typed_vcat(T::Type, A::AbstractArray) = _cat_t(Val(1), T, A)
typed_vcat(T::Type, A::AbstractArray, B::AbstractArray) = _cat_t(Val(1), T, A, B)
Expand Down Expand Up @@ -2055,8 +2053,8 @@ julia> hvcat((2,2,2), a,b,c,d,e,f) == hvcat(2, a,b,c,d,e,f)
true
```
"""
hvcat(rows::Tuple{Vararg{Int}}, xs::AbstractVecOrMat...) = typed_hvcat(promote_eltype(xs...), rows, xs...)
hvcat(rows::Tuple{Vararg{Int}}, xs::AbstractVecOrMat{T}...) where {T} = typed_hvcat(T, rows, xs...)
hvcat(rows::Tuple{Vararg{Int}}, xs::AbstractArray...) = typed_hvcat(promote_eltype(xs...), rows, xs...)
hvcat(rows::Tuple{Vararg{Int}}, xs::AbstractArray{T}...) where {T} = typed_hvcat(T, rows, xs...)

function typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, as::AbstractVecOrMat...) where T
nbr = length(rows) # number of block rows
Expand Down Expand Up @@ -2084,16 +2082,16 @@ function typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, as::AbstractVecOrMat..
Aj = as[a+j-1]
szj = size(Aj,2)
if size(Aj,1) != szi
throw(ArgumentError("mismatched height in block row $(i) (expected $szi, got $(size(Aj,1)))"))
throw(DimensionMismatch("mismatched height in block row $(i) (expected $szi, got $(size(Aj,1)))"))
end
if c-1+szj > nc
throw(ArgumentError("block row $(i) has mismatched number of columns (expected $nc, got $(c-1+szj))"))
throw(DimensionMismatch("block row $(i) has mismatched number of columns (expected $nc, got $(c-1+szj))"))
end
out[r:r-1+szi, c:c-1+szj] = Aj
c += szj
end
if c != nc+1
throw(ArgumentError("block row $(i) has mismatched number of columns (expected $nc, got $(c-1))"))
throw(DimensionMismatch("block row $(i) has mismatched number of columns (expected $nc, got $(c-1))"))
end
r += szi
a += rows[i]
Expand All @@ -2115,7 +2113,7 @@ function hvcat(rows::Tuple{Vararg{Int}}, xs::T...) where T<:Number
k = 1
@inbounds for i=1:nr
if nc != rows[i]
throw(ArgumentError("row $(i) has mismatched number of columns (expected $nc, got $(rows[i]))"))
throw(DimensionMismatch("row $(i) has mismatched number of columns (expected $nc, got $(rows[i]))"))
end
for j=1:nc
a[i,j] = xs[k]
Expand Down Expand Up @@ -2144,14 +2142,14 @@ end
hvcat(rows::Tuple{Vararg{Int}}, xs::Number...) = typed_hvcat(promote_typeof(xs...), rows, xs...)
hvcat(rows::Tuple{Vararg{Int}}, xs...) = typed_hvcat(promote_eltypeof(xs...), rows, xs...)
# the following method is needed to provide a more specific one compared to LinearAlgebra/uniformscaling.jl
hvcat(rows::Tuple{Vararg{Int}}, xs::Union{AbstractVecOrMat,Number}...) = typed_hvcat(promote_eltypeof(xs...), rows, xs...)
hvcat(rows::Tuple{Vararg{Int}}, xs::Union{AbstractArray,Number}...) = typed_hvcat(promote_eltypeof(xs...), rows, xs...)

function typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, xs::Number...) where T
nr = length(rows)
nc = rows[1]
for i = 2:nr
if nc != rows[i]
throw(ArgumentError("row $(i) has mismatched number of columns (expected $nc, got $(rows[i]))"))
throw(DimensionMismatch("row $(i) has mismatched number of columns (expected $nc, got $(rows[i]))"))
end
end
hvcat_fill!(Matrix{T}(undef, nr, nc), xs)
Expand Down Expand Up @@ -2319,7 +2317,7 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as::AbstractArray...) where {T, N}
Ndim += cat_size(as[i], N)
nd = max(nd, cat_ndims(as[i]))
for d ∈ 1:N - 1
cat_size(as[1], d) == cat_size(as[i], d) || throw(ArgumentError("mismatched size along axis $d in element $i"))
cat_size(as[1], d) == cat_size(as[i], d) || throw(DimensionMismatch("mismatched size along axis $d in element $i"))
end
end

Expand All @@ -2346,7 +2344,7 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as...) where {T, N}
nd = max(nd, cat_ndims(as[i]))
for d ∈ 1:N-1
cat_size(as[i], d) == 1 ||
throw(ArgumentError("all dimensions of element $i other than $N must be of length 1"))
throw(DimensionMismatch("all dimensions of element $i other than $N must be of length 1"))
end
end

Expand Down Expand Up @@ -2463,7 +2461,7 @@ function _typed_hvncat_dims(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, as
for dd ∈ 1:N
dd == d && continue
if cat_size(as[startelementi], dd) != cat_size(as[i], dd)
throw(ArgumentError("incompatible shape in element $i"))
throw(DimensionMismatch("incompatible shape in element $i"))
end
end
end
Expand Down Expand Up @@ -2500,18 +2498,18 @@ function _typed_hvncat_dims(::Type{T}, dims::NTuple{N, Int}, row_first::Bool, as
elseif currentdims[d] < outdims[d] # dimension in progress
break
else # exceeded dimension
throw(ArgumentError("argument $i has too many elements along axis $d"))
throw(DimensionMismatch("argument $i has too many elements along axis $d"))
end
end
end
elseif currentdims[d1] > outdims[d1] # exceeded dimension
throw(ArgumentError("argument $i has too many elements along axis $d1"))
throw(DimensionMismatch("argument $i has too many elements along axis $d1"))
end
end

outlen = prod(outdims)
elementcount == outlen ||
throw(ArgumentError("mismatched number of elements; expected $(outlen), got $(elementcount)"))
throw(DimensionMismatch("mismatched number of elements; expected $(outlen), got $(elementcount)"))

# copy into final array
A = cat_similar(as[1], T, outdims)
Expand Down Expand Up @@ -2572,8 +2570,8 @@ function _typed_hvncat_shape(::Type{T}, shape::NTuple{N, Tuple}, row_first, as::
if d == 1 || i == 1 || wasstartblock
currentdims[d] += dsize
elseif dsize != cat_size(as[i - 1], ad)
throw(ArgumentError("argument $i has a mismatched number of elements along axis $ad; \
expected $(cat_size(as[i - 1], ad)), got $dsize"))
throw(DimensionMismatch("argument $i has a mismatched number of elements along axis $ad; \
expected $(cat_size(as[i - 1], ad)), got $dsize"))
end

wasstartblock = blockcounts[d] == 1 # remember for next dimension
Expand All @@ -2583,15 +2581,15 @@ function _typed_hvncat_shape(::Type{T}, shape::NTuple{N, Tuple}, row_first, as::
if outdims[d] == -1
outdims[d] = currentdims[d]
elseif outdims[d] != currentdims[d]
throw(ArgumentError("argument $i has a mismatched number of elements along axis $ad; \
expected $(abs(outdims[d] - (currentdims[d] - dsize))), got $dsize"))
throw(DimensionMismatch("argument $i has a mismatched number of elements along axis $ad; \
expected $(abs(outdims[d] - (currentdims[d] - dsize))), got $dsize"))
end
currentdims[d] = 0
blockcounts[d] = 0
shapepos[d] += 1
d > 1 && (blockcounts[d - 1] == 0 ||
throw(ArgumentError("shape in level $d is inconsistent; level counts must nest \
evenly into each other")))
throw(DimensionMismatch("shape in level $d is inconsistent; level counts must nest \
evenly into each other")))
end
end
end
Expand Down
12 changes: 0 additions & 12 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2041,18 +2041,6 @@ function vcat(arrays::Vector{T}...) where T
end
vcat(A::Vector...) = cat(A...; dims=Val(1)) # more special than SparseArrays's vcat

# disambiguation with LinAlg/special.jl
# Union{Number,Vector,Matrix} is for LinearAlgebra._DenseConcatGroup
# VecOrMat{T} is for LinearAlgebra._TypedDenseConcatGroup
hcat(A::Union{Number,Vector,Matrix}...) = cat(A...; dims=Val(2))
hcat(A::VecOrMat{T}...) where {T} = typed_hcat(T, A...)
vcat(A::Union{Number,Vector,Matrix}...) = cat(A...; dims=Val(1))
vcat(A::VecOrMat{T}...) where {T} = typed_vcat(T, A...)
hvcat(rows::Tuple{Vararg{Int}}, xs::Union{Number,Vector,Matrix}...) =
typed_hvcat(promote_eltypeof(xs...), rows, xs...)
hvcat(rows::Tuple{Vararg{Int}}, xs::VecOrMat{T}...) where {T} =
typed_hvcat(T, rows, xs...)

_cat(n::Integer, x::Integer...) = reshape([x...], (ntuple(Returns(1), n-1)..., length(x)))

## find ##
Expand Down

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
e59c1c57b97e17a73eba758d65022bd7
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
ad88ebe77aaf1580e6d7ee7649ac5b812a23b9d9bf947f26babe9dd79902f6da11aa69bf63f22f67f6eae92a4c6e665cc3b950bb7c648c623e9cb4b9cb4daac4
26 changes: 5 additions & 21 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -330,27 +330,11 @@ end
==(A::Bidiagonal, B::SymTridiagonal) = iszero(_evview(B)) && iszero(A.ev) && A.dv == B.dv
==(B::SymTridiagonal, A::Bidiagonal) = A == B

# concatenation
const _SpecialArrays = Union{Diagonal, Bidiagonal, Tridiagonal, SymTridiagonal}
const _Symmetric_DenseArrays{T,A<:Matrix} = Symmetric{T,A}
const _Hermitian_DenseArrays{T,A<:Matrix} = Hermitian{T,A}
const _Triangular_DenseArrays{T,A<:Matrix} = UpperOrLowerTriangular{T,A}
const _Annotated_DenseArrays = Union{_SpecialArrays, _Triangular_DenseArrays, _Symmetric_DenseArrays, _Hermitian_DenseArrays}
const _Annotated_Typed_DenseArrays{T} = Union{_Triangular_DenseArrays{T}, _Symmetric_DenseArrays{T}, _Hermitian_DenseArrays{T}}
const _DenseConcatGroup = Union{Number, Vector, Adjoint{<:Any,<:Vector}, Transpose{<:Any,<:Vector}, Matrix, _Annotated_DenseArrays}
const _TypedDenseConcatGroup{T} = Union{Vector{T}, Adjoint{T,Vector{T}}, Transpose{T,Vector{T}}, Matrix{T}, _Annotated_Typed_DenseArrays{T}}

promote_to_array_type(::Tuple{Vararg{Union{_DenseConcatGroup,UniformScaling}}}) = Matrix

Base._cat(dims, xs::_DenseConcatGroup...) = Base._cat_t(dims, promote_eltype(xs...), xs...)
vcat(A::_DenseConcatGroup...) = Base.typed_vcat(promote_eltype(A...), A...)
hcat(A::_DenseConcatGroup...) = Base.typed_hcat(promote_eltype(A...), A...)
hvcat(rows::Tuple{Vararg{Int}}, xs::_DenseConcatGroup...) = Base.typed_hvcat(promote_eltype(xs...), rows, xs...)
# For performance, specially handle the case where the matrices/vectors have homogeneous eltype
Base._cat(dims, xs::_TypedDenseConcatGroup{T}...) where {T} = Base._cat_t(dims, T, xs...)
vcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_vcat(T, A...)
hcat(A::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hcat(T, A...)
hvcat(rows::Tuple{Vararg{Int}}, xs::_TypedDenseConcatGroup{T}...) where {T} = Base.typed_hvcat(T, rows, xs...)
# TODO: remove these deprecations (used by SparseArrays in the past)
const _DenseConcatGroup = Union{}
const _SpecialArrays = Union{}

promote_to_array_type(::Tuple) = Matrix

# factorizations
function cholesky(S::RealHermSymComplexHerm{<:Real,<:SymTridiagonal}, ::NoPivot = NoPivot(); check::Bool = true)
Expand Down
14 changes: 6 additions & 8 deletions stdlib/LinearAlgebra/src/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ end
# so that we can re-use this code for sparse-matrix hcat etcetera.
promote_to_arrays_(n::Int, ::Type, a::Number) = a
promote_to_arrays_(n::Int, ::Type{Matrix}, J::UniformScaling{T}) where {T} = Matrix(J, n, n)
promote_to_arrays_(n::Int, ::Type, A::AbstractVecOrMat) = A
promote_to_arrays_(n::Int, ::Type, A::AbstractArray) = A
promote_to_arrays(n,k, ::Type) = ()
promote_to_arrays(n,k, ::Type{T}, A) where {T} = (promote_to_arrays_(n[k], T, A),)
promote_to_arrays(n,k, ::Type{T}, A, B) where {T} =
Expand All @@ -417,17 +417,16 @@ promote_to_arrays(n,k, ::Type{T}, A, B, C) where {T} =
(promote_to_arrays_(n[k], T, A), promote_to_arrays_(n[k+1], T, B), promote_to_arrays_(n[k+2], T, C))
promote_to_arrays(n,k, ::Type{T}, A, B, Cs...) where {T} =
(promote_to_arrays_(n[k], T, A), promote_to_arrays_(n[k+1], T, B), promote_to_arrays(n,k+2, T, Cs...)...)
promote_to_array_type(A::Tuple{Vararg{Union{AbstractVecOrMat,UniformScaling,Number}}}) = Matrix

_us2number(A) = A
_us2number(J::UniformScaling) = J.λ

for (f, _f, dim, name) in ((:hcat, :_hcat, 1, "rows"), (:vcat, :_vcat, 2, "cols"))
@eval begin
@inline $f(A::Union{AbstractVecOrMat,UniformScaling}...) = $_f(A...)
@inline $f(A::Union{AbstractArray,UniformScaling}...) = $_f(A...)
# if there's a Number present, J::UniformScaling must be 1x1-dimensional
@inline $f(A::Union{AbstractVecOrMat,UniformScaling,Number}...) = $f(map(_us2number, A)...)
function $_f(A::Union{AbstractVecOrMat,UniformScaling,Number}...; array_type = promote_to_array_type(A))
@inline $f(A::Union{AbstractArray,UniformScaling,Number}...) = $f(map(_us2number, A)...)
function $_f(A::Union{AbstractArray,UniformScaling,Number}...; array_type = promote_to_array_type(A))
n = -1
for a in A
if !isa(a, UniformScaling)
Expand All @@ -445,9 +444,8 @@ for (f, _f, dim, name) in ((:hcat, :_hcat, 1, "rows"), (:vcat, :_vcat, 2, "cols"
end
end

hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScaling}...) = _hvcat(rows, A...)
hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScaling,Number}...) = _hvcat(rows, A...)
function _hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractVecOrMat,UniformScaling,Number}...; array_type = promote_to_array_type(A))
hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractArray,UniformScaling,Number}...) = _hvcat(rows, A...)
function _hvcat(rows::Tuple{Vararg{Int}}, A::Union{AbstractArray,UniformScaling,Number}...; array_type = promote_to_array_type(A))
require_one_based_indexing(A...)
nr = length(rows)
sum(rows) == length(A) || throw(ArgumentError("mismatch between row sizes and number of arguments"))
Expand Down
2 changes: 1 addition & 1 deletion stdlib/SparseArrays.version
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
SPARSEARRAYS_BRANCH = main
SPARSEARRAYS_SHA1 = 2c7f4d6d839e9a97027454a037bfa004c1eb34b0
SPARSEARRAYS_SHA1 = b4b0e721ada6e8cf5f6391aff4db307be69b0401
SPARSEARRAYS_GIT_URL := https://github.com/JuliaSparse/SparseArrays.jl.git
SPARSEARRAYS_TAR_URL = https://api.github.com/repos/JuliaSparse/SparseArrays.jl/tarball/$1
Loading