Skip to content

Commit

Permalink
add non-checking and checking constructor - improve check performance
Browse files Browse the repository at this point in the history
  • Loading branch information
KlausC committed Jun 18, 2019
1 parent c5fc16e commit d545e34
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 26 deletions.
44 changes: 20 additions & 24 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,52 +23,48 @@ struct SparseMatrixCSC{Tv,Ti<:Integer} <: AbstractSparseMatrix{Tv,Ti}
rowval::Vector{Ti} # Row indices of stored values
nzval::Vector{Tv} # Stored values, typically nonzeros

function SparseMatrixCSC{Tv,Ti}(m::Integer, n::Integer, colptr::Vector{Ti}, rowval::Vector{Ti},
nzval::Vector{Tv}) where {Tv,Ti<:Integer}

sparse_check_Ti(m, n, Ti)
sparse_check(n, colptr, rowval, nzval)
# silently shorten rowval and nzval to usable index positions.
maxlen = widemul(m, n)
isbitstype(Ti) && (maxlen = min(maxlen, typemax(Ti) - 1))
length(rowval) > maxlen && resize!(rowval, maxlen)
length(nzval) > maxlen && resize!(nzval, maxlen)
function SparseMatrixCSC{Tv,Ti}(m::Integer, n::Integer, colptr::Vector{Ti},
rowval::Vector{Ti}, nzval::Vector{Tv}) where {Tv,Ti<:Integer}
@noinline throwsz(str, lbl, k) =
throw(ArgumentError("number of $str ($lbl) must be ≥ 0, got $k"))
m < 0 && throwsz("rows", 'm', m)
n < 0 && throwsz("columns", 'n', n)
new(Int(m), Int(n), colptr, rowval, nzval)
end
end
function SparseMatrixCSC(m::Integer, n::Integer, colptr::Vector, rowval::Vector, nzval::Vector)
Tv = eltype(nzval)
Ti = promote_type(eltype(colptr), eltype(rowval))
sparse_check_Ti(m, n, Ti)
sparse_check(n, colptr, rowval, nzval)
# silently shorten rowval and nzval to usable index positions.
maxlen = abs(widemul(m, n))
isbitstype(Ti) && (maxlen = min(maxlen, typemax(Ti) - 1))
length(rowval) > maxlen && resize!(rowval, maxlen)
length(nzval) > maxlen && resize!(nzval, maxlen)
SparseMatrixCSC{Tv,Ti}(m, n, colptr, rowval, nzval)
end

@noinline function sparse_check_Ti(m::Integer, n::Integer, Ti::Type)
@noinline throwsz(str, lbl, k) =
throw(ArgumentError("number of $str ($lbl) must be ≥ 0, got $k"))
function sparse_check_Ti(m::Integer, n::Integer, Ti::Type)
@noinline throwTi(str, lbl, k) =
throw(ArgumentError("$str ($lbl = $k) does not fit in Ti = $(Ti)"))
m < 0 && throwsz("rows", 'm', m)
n < 0 && throwsz("columns", 'n', n)
!isbitstype(Ti) || m typemax(Ti) || throwTi("number of rows", "m", m)
!isbitstype(Ti) || n typemax(Ti) || throwTi("number of columns", "n", n)
0 m && (!isbitstype(Ti) || m typemax(Ti)) || throwTi("number of rows", "m", m)
0 n && (!isbitstype(Ti) || n typemax(Ti)) || throwTi("number of columns", "n", n)
end

@noinline function sparse_check(n::Integer, colptr::Vector{Ti}, rowval, nzval) where Ti
nc = length(colptr)
function sparse_check(n::Integer, colptr::Vector{Ti}, rowval, nzval) where Ti
sparse_check_length("colptr", colptr, n+1, String) # don't check upper bound
ckp = Ti(1)
ckp == colptr[1] || throw(ArgumentError("$ckp == colptr[1] != 1"))
k = 1
while k <= n + 1
@inbounds for k = 2:n+1
ck = colptr[k]
ckp <= ck || throw(ArgumentError("$ckp == colptr[$(k-1)] > colptr[$k] == $ck"))
ckp = ck
k += 1
end
sparse_check_length("rowval", rowval, ckp-1, Ti)
sparse_check_length("nzval", nzval, 0, Ti) # we allow empty nzval !!!
end
@noinline function sparse_check_length(rowstr, rowval, minlen, Ti)
function sparse_check_length(rowstr, rowval, minlen, Ti)
len = length(rowval)
len >= minlen || throw(ArgumentError("$len == length($rowstr) < $minlen"))
!isbitstype(Ti) || len < typemax(Ti) ||
Expand Down Expand Up @@ -633,7 +629,7 @@ function sparse!(I::AbstractVector{Ti}, J::AbstractVector{Ti},
# Compute the CSR form's row counts and store them shifted forward by one in csrrowptr
fill!(csrrowptr, Tj(0))
coolen = length(I)
min(length(J), length(V)) >= coolen || throw(ArgumentError("I and V need length >= length(I) = $coolen"))
min(length(J), length(V)) >= coolen || throw(ArgumentError("J and V need length >= length(I) = $coolen"))
@inbounds for k in 1:coolen
Ik = I[k]
if 1 > Ik || m < Ik
Expand Down
6 changes: 4 additions & 2 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2626,9 +2626,11 @@ end
# nzval short
@test SparseMatrixCSC(10, 3, [1,2,2,4], [1,2,3], Float64[]) !== nothing
# length(rowval) >= typemax
@test_throws ArgumentError SparseMatrixCSC{Int,Int8}(5, 1, Int8[1,2], fill(Int8(1),127), Int[1,2,3])
@test_throws ArgumentError SparseMatrixCSC(5, 1, Int8[1,2], fill(Int8(1),127), Int[1,2,3])
@test SparseMatrixCSC{Int,Int8}(5, 1, Int8[1,2], fill(Int8(1),127), Int[1,2,3]) != 0
# length(nzval) >= typemax
@test_throws ArgumentError SparseMatrixCSC{Int,Int8}(5, 1, Int8[1,2], Int8[1], fill(7, 127))
@test_throws ArgumentError SparseMatrixCSC(5, 1, Int8[1,2], Int8[1], fill(7, 127))
@test SparseMatrixCSC{Int,Int8}(5, 1, Int8[1,2], Int8[1], fill(7, 127)) != 0

# length(I) >= typemax
@test_throws ArgumentError sparse(UInt8.(1:255), fill(UInt8(1), 255), fill(1, 255))
Expand Down

0 comments on commit d545e34

Please sign in to comment.