From 12547920154259bf07f8ced46fe75e02d3a8f7a9 Mon Sep 17 00:00:00 2001 From: tan Date: Sat, 7 Jun 2014 06:28:10 +0530 Subject: [PATCH] fix sparse getindex regression. also replaced binarysearch with methods from sort.jl ref #7131, #7047 --- base/sparse/sparsematrix.jl | 51 ++++++++++++++++--------------------- test/sparse.jl | 8 ++++++ 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/base/sparse/sparsematrix.jl b/base/sparse/sparsematrix.jl index 7478cc19aca30..ec9c5c85b3f99 100644 --- a/base/sparse/sparsematrix.jl +++ b/base/sparse/sparsematrix.jl @@ -712,24 +712,9 @@ prod{T}(A::SparseMatrixCSC{T}, region) = reducedim(*,A,region,one(T)) #sum(A::SparseMatrixCSC{Bool}) = countnz(A) ## getindex -function binarysearch(haystack::AbstractVector, needle, lo::Int, hi::Int) - # Finds the first occurrence of needle in haystack[lo:hi] - lo = lo-1 - hi2 = hi - hi = hi+1 - @inbounds while lo < hi-1 - m = (lo+hi)>>>1 - if haystack[m] < needle - lo = m - else - hi = m - end - end - (hi==hi2+1 || haystack[hi]!=needle) ? -1 : hi -end function rangesearch(haystack::Range, needle) (i,rem) = divrem(needle - first(haystack), step(haystack)) - (rem==0 && 1<=i+1<=length(haystack)) ? i+1 : -1 + (rem==0 && 1<=i+1<=length(haystack)) ? i+1 : 0 end getindex(A::SparseMatrixCSC, i::Integer) = getindex(A, ind2sub(size(A),i)) @@ -737,8 +722,11 @@ getindex(A::SparseMatrixCSC, I::(Integer,Integer)) = getindex(A, I[1], I[2]) function getindex{T}(A::SparseMatrixCSC{T}, i0::Integer, i1::Integer) if !(1 <= i0 <= A.m && 1 <= i1 <= A.n); throw(BoundsError()); end - ind = binarysearch(A.rowval, i0, A.colptr[i1], A.colptr[i1+1]-1) - ind > -1 ? A.nzval[ind] : zero(T) + r1 = int(A.colptr[i1]) + r2 = int(A.colptr[i1+1]-1) + (r1 > r2) && return zero(T) + r1 = searchsortedfirst(A.rowval, i0, r1, r2, Forward) + ((r1 > r2) || (A.rowval[r1] != i0)) ? zero(T) : A.nzval[r1] end getindex{T<:Integer}(A::SparseMatrixCSC, I::AbstractVector{T}, j::Integer) = getindex(A,I,[j]) @@ -810,7 +798,7 @@ function getindex{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv,Ti}, I::Range, J::Abstra for k = colptrA[col]:colptrA[col+1]-1 rowA = rowvalA[k] i = rangesearch(I, rowA) - if i > -1 + if i > 0 rowvalS[ptrS] = i nzvalS[ptrS] = nzvalA[k] ptrS += 1 @@ -1073,15 +1061,17 @@ function setindex!{T,Ti}(A::SparseMatrixCSC{T,Ti}, v, i0::Integer, i1::Integer) i1 = convert(Ti, i1) if !(1 <= i0 <= A.m && 1 <= i1 <= A.n); throw(BoundsError()); end v = convert(T, v) - r1 = A.colptr[i1] - r2 = A.colptr[i1+1]-1 + r1 = int(A.colptr[i1]) + r2 = int(A.colptr[i1+1]-1) if v == 0 #either do nothing or delete entry if it exists - loc = binarysearch(A.rowval, i0, r1, r2) - if loc != -1 - deleteat!(A.rowval, loc) - deleteat!(A.nzval, loc) - for j = (i1+1):(A.n+1) - A.colptr[j] -= 1 + if r1 <= r2 + r1 = searchsortedfirst(A.rowval, i0, r1, r2, Forward) + if (r1 <= r2) && (A.rowval[r1] == i0) + deleteat!(A.rowval, r1) + deleteat!(A.nzval, r1) + for j = (i1+1):(A.n+1) + A.colptr[j] -= 1 + end end end return A @@ -1851,8 +1841,11 @@ done(d::SpDiagIterator, j) = j > d.n function next{Tv}(d::SpDiagIterator{Tv}, j) A = d.A - idx = binarysearch(A.rowval, j, A.colptr[j], A.colptr[j+1]-1) - ((idx == -1) ? zero(Tv) : A.nzval[idx], j+1) + r1 = int(A.colptr[j]) + r2 = int(A.colptr[j+1]-1) + (r1 > r2) && (return (zero(Tv), j+1)) + r1 = searchsortedfirst(A.rowval, j, r1, r2, Forward) + (((r1 > r2) || (A.rowval[r1] != j)) ? zero(Tv) : A.nzval[r1], j+1) end function trace{Tv}(A::SparseMatrixCSC{Tv}) diff --git a/test/sparse.jl b/test/sparse.jl index 4e93fa42ca324..14f56476edb69 100644 --- a/test/sparse.jl +++ b/test/sparse.jl @@ -289,6 +289,14 @@ for (aa116, ss116) in [(a116, s116), (ad116, sd116)] @test full(ss116[li,lj]) == aa116[li,lj] end +let S = SparseMatrixCSC(3, 3, Uint8[1,1,1,1], Uint8[], Int64[]) + S[1,1] = 1 + S[5] = 2 + S[end] = 3 + @test S[end] == (S[1] + S[2,2]) + @test 6 == sum(diag(S)) +end + # setindex tests let a = spzeros(Int, 10, 10)