diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 30072363a34c3..1f1120740e99a 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -1580,6 +1580,7 @@ cat_indices(A::AbstractArray, d) = axes(A, d) cat_similar(A, ::Type{T}, shape) where T = Array{T}(undef, shape) cat_similar(A::AbstractArray, ::Type{T}, shape) where T = similar(A, T, shape) +# These are for backwards compatibility (even though internal) cat_shape(dims, shape::Tuple{Vararg{Int}}) = shape function cat_shape(dims, shapes::Tuple) out_shape = () @@ -1588,6 +1589,11 @@ function cat_shape(dims, shapes::Tuple) end return out_shape end +# The new way to compute the shape (more inferrable than combining cat_size & cat_shape, due to Varargs + issue#36454) +cat_size_shape(dims) = ntuple(zero, Val(length(dims))) +@inline cat_size_shape(dims, X, tail...) = _cat_size_shape(dims, _cshp(1, dims, (), cat_size(X)), tail...) +_cat_size_shape(dims, shape) = shape +@inline _cat_size_shape(dims, shape, X, tail...) = _cat_size_shape(dims, _cshp(1, dims, shape, cat_size(X)), tail...) _cshp(ndim::Int, ::Tuple{}, ::Tuple{}, ::Tuple{}) = () _cshp(ndim::Int, ::Tuple{}, ::Tuple{}, nshape) = nshape @@ -1631,7 +1637,7 @@ _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims) @inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...) @inline function _cat_t(dims, ::Type{T}, X...) where {T} catdims = dims2cat(dims) - shape = cat_shape(catdims, map(cat_size, X)) + shape = cat_size_shape(catdims, X...) A = cat_similar(X[1], T, shape) if count(!iszero, catdims)::Int > 1 fill!(A, zero(T)) diff --git a/test/abstractarray.jl b/test/abstractarray.jl index f00f1f80332bb..52af916acbdac 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -692,6 +692,12 @@ function test_cat(::Type{TestAbstractArray}) # 36041 @test_throws MethodError cat(["a"], ["b"], dims=[1, 2]) @test cat([1], [1], dims=[1, 2]) == I(2) + + # inferrability + As = [zeros(2, 2) for _ = 1:2] + @test @inferred(cat(As...; dims=Val(3))) == zeros(2, 2, 2) + cat3v(As) = cat(As...; dims=Val(3)) + @test @inferred(cat3v(As)) == zeros(2, 2, 2) end function test_ind2sub(::Type{TestAbstractArray})