Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Argument checks for SparseMatrixCSC constructors #31724

Merged
merged 11 commits into from
Jun 26, 2019
4 changes: 2 additions & 2 deletions stdlib/SparseArrays/src/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -224,12 +224,12 @@ _maxnnzfrom(shape::NTuple{2}, A::SparseMatrixCSC) = nnz(A) * div(shape[1], A.m)
@inline _checked_maxnnzbcres(shape::NTuple{1}, As...) = shape[1] != 0 ? _unchecked_maxnnzbcres(shape, As) : 0
@inline _checked_maxnnzbcres(shape::NTuple{2}, As...) = shape[1] != 0 && shape[2] != 0 ? _unchecked_maxnnzbcres(shape, As) : 0
@inline function _allocres(shape::NTuple{1}, indextype, entrytype, maxnnz)
storedinds = Vector{indextype}(undef, maxnnz)
storedinds = ones(indextype, maxnnz)
storedvals = Vector{entrytype}(undef, maxnnz)
return SparseVector(shape..., storedinds, storedvals)
end
@inline function _allocres(shape::NTuple{2}, indextype, entrytype, maxnnz)
pointers = Vector{indextype}(undef, shape[2] + 1)
pointers = ones(indextype, shape[2] + 1)
storedinds = Vector{indextype}(undef, maxnnz)
storedvals = Vector{entrytype}(undef, maxnnz)
return SparseMatrixCSC(shape..., pointers, storedinds, storedvals)
Expand Down
67 changes: 57 additions & 10 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# Compressed sparse columns data structure
# Assumes that no zeros are stored in the data structure
# No assumptions about stored zeros in the data structure
# Assumes that row values in rowval for each column are sorted
# issorted(rowval[colptr[i]:(colptr[i+1]-1)]) == true
# Assumes that 1 <= colptr[i] <= colptr[i+1] for i in 1..n
# Assumes that nnz <= length(rowval) < typemax(Ti)
# Assumes that nnz <= length(nzval) < typemax(Ti)

"""
SparseMatrixCSC{Tv,Ti<:Integer} <: AbstractSparseMatrix{Tv,Ti}
Expand All @@ -22,10 +25,9 @@ struct SparseMatrixCSC{Tv,Ti<:Integer} <: AbstractSparseMatrix{Tv,Ti}

function SparseMatrixCSC{Tv,Ti}(m::Integer, n::Integer, colptr::Vector{Ti}, rowval::Vector{Ti},
nzval::Vector{Tv}) where {Tv,Ti<:Integer}
@noinline throwsz(str, lbl, k) =
throw(ArgumentError("number of $str ($lbl) must be ≥ 0, got $k"))
m < 0 && throwsz("rows", 'm', m)
n < 0 && throwsz("columns", 'n', n)

sparse_check_Ti(m, n, Ti)
sparse_check(n, colptr, rowval, nzval)
new(Int(m), Int(n), colptr, rowval, nzval)
end
end
Expand All @@ -35,6 +37,39 @@ function SparseMatrixCSC(m::Integer, n::Integer, colptr::Vector, rowval::Vector,
SparseMatrixCSC{Tv,Ti}(m, n, colptr, rowval, nzval)
end

@noinline function sparse_check_Ti(m::Integer, n::Integer, Ti::Type)
@noinline throwsz(str, lbl, k) =
throw(ArgumentError("number of $str ($lbl) must be ≥ 0, got $k"))
@noinline throwTi(str, lbl, k) =
throw(ArgumentError("$str ($lbl = $k) does not fit in Ti = $(Ti)"))
m < 0 && throwsz("rows", 'm', m)
n < 0 && throwsz("columns", 'n', n)
!isbitstype(Ti) || m ≤ typemax(Ti) || throwTi("number of rows", "m", m)
KlausC marked this conversation as resolved.
Show resolved Hide resolved
!isbitstype(Ti) || n ≤ typemax(Ti) || throwTi("number of columns", "n", n)
end

@noinline function sparse_check(n::Integer, colptr::Vector{Ti}, rowval, nzval) where Ti
nc = length(colptr)
sparse_check_length("colptr", colptr, n+1, String) # don't check upper bound
ckp = Ti(1)
ckp == colptr[1] || throw(ArgumentError("$ckp == colptr[1] != 1"))
k = 1
while k <= n + 1
ck = colptr[k]
ckp <= ck || throw(ArgumentError("$ckp == colptr[$(k-1)] > colptr[$k] == $ck"))
ckp = ck
k += 1
end
sparse_check_length("rowval", rowval, ckp-1, Ti)
sparse_check_length("nzval", nzval, 0, Ti) # we allow empty nzval !!!
end
@noinline function sparse_check_length(rowstr, rowval, minlen, Ti)
len = length(rowval)
len >= minlen || throw(ArgumentError("$len == length($rowstr) < $minlen"))
!isbitstype(Ti) || len < typemax(Ti) ||
throw(ArgumentError("$len == length($rowstr) >= $(typemax(Ti))"))
end

size(S::SparseMatrixCSC) = (S.m, S.n)

# Define an alias for views of a SparseMatrixCSC which include all rows and a unit range of the columns.
Expand Down Expand Up @@ -585,9 +620,13 @@ function sparse!(I::AbstractVector{Ti}, J::AbstractVector{Ti},
csccolptr::Vector{Ti}, cscrowval::Vector{Ti}, cscnzval::Vector{Tv}) where {Tv,Ti<:Integer}

require_one_based_indexing(I, J, V)
sparse_check_Ti(m, n, Ti)
sparse_check_length("I", I, 0, Ti)
# Compute the CSR form's row counts and store them shifted forward by one in csrrowptr
fill!(csrrowptr, Ti(0))
coolen = length(I)
min(length(J), length(V)) >= coolen || throw(ArgumentError("I and V need length >= length(I) = $coolen"))
coolen < typemax(Ti) || throw(ArgumentError("length(I) exceeds typemax($Ti)"))
@inbounds for k in 1:coolen
Ik = I[k]
if 1 > Ik || m < Ik
Expand All @@ -613,6 +652,9 @@ function sparse!(I::AbstractVector{Ti}, J::AbstractVector{Ti},
throw(ArgumentError("column indices J[k] must satisfy 1 <= J[k] <= n"))
end
csrk = csrrowptr[Ik+1]
if csrk < Ti(1)
throw(ArgumentError("count of nonzeros in row $Ik exceeds $(typemax(Ti))"))
end
csrrowptr[Ik+1] = csrk + Ti(1)
csrcolval[csrk] = Jk
csrnzval[csrk] = V[k]
Expand Down Expand Up @@ -826,7 +868,7 @@ adjoint!(X::SparseMatrixCSC{Tv,Ti}, A::SparseMatrixCSC{Tv,Ti}) where {Tv,Ti} = f

function ftranspose(A::SparseMatrixCSC{Tv,Ti}, f::Function) where {Tv,Ti}
X = SparseMatrixCSC(A.n, A.m,
Vector{Ti}(undef, A.m+1),
ones(Ti, A.m+1),
Vector{Ti}(undef, nnz(A)),
Vector{Tv}(undef, nnz(A)))
halfperm!(X, A, 1:A.n, f)
Expand Down Expand Up @@ -1045,7 +1087,7 @@ function permute!(X::SparseMatrixCSC{Tv,Ti}, A::SparseMatrixCSC{Tv,Ti},
_checkargs_sourcecompatdest_permute!(A, X)
_checkargs_sourcecompatperms_permute!(A, p, q)
C = SparseMatrixCSC(A.n, A.m,
Vector{Ti}(undef, A.m + 1),
ones(Ti, A.m + 1),
Vector{Ti}(undef, nnz(A)),
Vector{Tv}(undef, nnz(A)))
_checkargs_permutationsvalid_permute!(p, C.colptr, q, X.colptr)
Expand All @@ -1064,7 +1106,7 @@ function permute!(A::SparseMatrixCSC{Tv,Ti}, p::AbstractVector{<:Integer},
q::AbstractVector{<:Integer}) where {Tv,Ti}
_checkargs_sourcecompatperms_permute!(A, p, q)
C = SparseMatrixCSC(A.n, A.m,
Vector{Ti}(undef, A.m + 1),
ones(Ti, A.m + 1),
Vector{Ti}(undef, nnz(A)),
Vector{Tv}(undef, nnz(A)))
workcolptr = Vector{Ti}(undef, A.n + 1)
Expand Down Expand Up @@ -1135,11 +1177,11 @@ function permute(A::SparseMatrixCSC{Tv,Ti}, p::AbstractVector{<:Integer},
q::AbstractVector{<:Integer}) where {Tv,Ti}
_checkargs_sourcecompatperms_permute!(A, p, q)
X = SparseMatrixCSC(A.m, A.n,
Vector{Ti}(undef, A.n + 1),
ones(Ti, A.n + 1),
Vector{Ti}(undef, nnz(A)),
Vector{Tv}(undef, nnz(A)))
C = SparseMatrixCSC(A.n, A.m,
Vector{Ti}(undef, A.m + 1),
ones(Ti, A.m + 1),
Vector{Ti}(undef, nnz(A)),
Vector{Tv}(undef, nnz(A)))
_checkargs_permutationsvalid_permute!(p, C.colptr, q, X.colptr)
Expand Down Expand Up @@ -2404,6 +2446,11 @@ function _setindex_scalar!(A::SparseMatrixCSC{Tv,Ti}, _v, _i::Integer, _j::Integ
# Column j does not contain entry A[i,j]. If v is nonzero, insert entry A[i,j] = v
# and return. If to the contrary v is zero, then simply return.
if v != 0
# throw exception before state is partially modified
!isbitstype(Ti) || A.colptr[A.n+1] < typemax(Ti) ||
throw(ArgumentError("nnz(A) going to exceed typemax(Ti) = $(typemax(Ti))"))

# TODO if nnz(A) < length(rowval/nzval): no need to grow rowval and preserve values
insert!(A.rowval, searchk, i)
insert!(A.nzval, searchk, v)
@simd for m in (j + 1):(A.n + 1)
Expand Down
41 changes: 41 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2586,4 +2586,45 @@ end
@test sum(x1, dims=2) == sum(x2, dims=2)
end

@testset "Ti cannot store all potential values #31024" begin
A = SparseMatrixCSC(12, 12, fill(Int8(1),13), Int8[], Int[])
@test size(A) == (12,12)
@test nnz(A) == 0
I1 = [Int8(i) for i in 1:20 for _ in 1:20]
J1 = [Int8(i) for _ in 1:20 for i in 1:20]
@test_throws ArgumentError sparse(I1, J1, zero(length(I1)zero(length(I1))))
end

@testset "Typecheck too strict #31435" begin
A = SparseMatrixCSC{Int,Int8}(70, 2, fill(Int8(1), 3), Int8[], Int[])
A[5:67,1:2] .= ones(Int, 63, 2)
@test nnz(A) == 126
# nnz >= typemax
@test_throws ArgumentError A[2,1] = 42

# colptr short
@test_throws ArgumentError SparseMatrixCSC(1, 1, Int[], Int[], Float64[])
# colptr[1] must be 1
@test_throws ArgumentError SparseMatrixCSC(10, 3, [0,1,1,1], Int[], Float64[])
# colptr not ascending
@test_throws ArgumentError SparseMatrixCSC(10, 3, [1,2,1,2], Int[], Float64[])
# rowwal (and nzval) short
@test_throws ArgumentError SparseMatrixCSC(10, 3, [1,2,2,4], [1,2], Float64[])
# nzval short
@test SparseMatrixCSC(10, 3, [1,2,2,4], [1,2,3], Float64[]) !== nothing
# length(rowval) >= typemax
@test_throws ArgumentError SparseMatrixCSC{Int,Int8}(5, 1, Int8[1,2], fill(Int8(1),127), Int[1,2,3])
# length(nzval) >= typemax
@test_throws ArgumentError SparseMatrixCSC{Int,Int8}(5, 1, Int8[1,2], Int8[1], fill(7, 127))

# length(I) >= typemax
@test_throws ArgumentError sparse(UInt8.(1:255), fill(UInt8(1), 255), fill(1, 255))
# m > typemax
@test_throws ArgumentError sparse(UInt8.(1:254), fill(UInt8(1), 254), fill(1, 254), 256, 1)
# n > typemax
@test_throws ArgumentError sparse(UInt8.(1:254), fill(UInt8(1), 254), fill(1, 254), 255, 256)
# n, m maximal
@test sparse(UInt8.(1:254), fill(UInt8(1), 254), fill(1, 254), 255, 255) !== nothing
end

end # module