Skip to content

Commit

Permalink
Improve inferability of shape::Dims for cat (#39294)
Browse files Browse the repository at this point in the history
`cat` is often called with Varargs or heterogenous inputs,
and inference almost always fails. Even when all the arrays
are of the same type, if the number of varargs isn't known
inference typically fails. The culprit is probably #36454.

This reduces the number of failures considerably, by avoiding
creation of vararg length tuples in the shape-inference pipeline.
  • Loading branch information
timholy committed Jan 19, 2021
1 parent fb39bdb commit 815076b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
8 changes: 7 additions & 1 deletion base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1580,6 +1580,7 @@ cat_indices(A::AbstractArray, d) = axes(A, d)
cat_similar(A, ::Type{T}, shape) where T = Array{T}(undef, shape)
cat_similar(A::AbstractArray, ::Type{T}, shape) where T = similar(A, T, shape)

# These are for backwards compatibility (even though internal)
cat_shape(dims, shape::Tuple{Vararg{Int}}) = shape
function cat_shape(dims, shapes::Tuple)
out_shape = ()
Expand All @@ -1588,6 +1589,11 @@ function cat_shape(dims, shapes::Tuple)
end
return out_shape
end
# The new way to compute the shape (more inferrable than combining cat_size & cat_shape, due to Varargs + issue#36454)
cat_size_shape(dims) = ntuple(zero, Val(length(dims)))
@inline cat_size_shape(dims, X, tail...) = _cat_size_shape(dims, _cshp(1, dims, (), cat_size(X)), tail...)
_cat_size_shape(dims, shape) = shape
@inline _cat_size_shape(dims, shape, X, tail...) = _cat_size_shape(dims, _cshp(1, dims, shape, cat_size(X)), tail...)

_cshp(ndim::Int, ::Tuple{}, ::Tuple{}, ::Tuple{}) = ()
_cshp(ndim::Int, ::Tuple{}, ::Tuple{}, nshape) = nshape
Expand Down Expand Up @@ -1631,7 +1637,7 @@ _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims)
@inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...)
@inline function _cat_t(dims, ::Type{T}, X...) where {T}
catdims = dims2cat(dims)
shape = cat_shape(catdims, map(cat_size, X))
shape = cat_size_shape(catdims, X...)
A = cat_similar(X[1], T, shape)
if count(!iszero, catdims)::Int > 1
fill!(A, zero(T))
Expand Down
6 changes: 6 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,12 @@ function test_cat(::Type{TestAbstractArray})
# 36041
@test_throws MethodError cat(["a"], ["b"], dims=[1, 2])
@test cat([1], [1], dims=[1, 2]) == I(2)

# inferrability
As = [zeros(2, 2) for _ = 1:2]
@test @inferred(cat(As...; dims=Val(3))) == zeros(2, 2, 2)
cat3v(As) = cat(As...; dims=Val(3))
@test @inferred(cat3v(As)) == zeros(2, 2, 2)
end

function test_ind2sub(::Type{TestAbstractArray})
Expand Down

0 comments on commit 815076b

Please sign in to comment.