Skip to content

Commit

Permalink
Merge pull request #7162 from tanmaykm/tanmaykm
Browse files Browse the repository at this point in the history
fix sparse getindex regression
  • Loading branch information
tanmaykm committed Jun 10, 2014
2 parents 9ffbaa8 + 1254792 commit 664cab5
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 29 deletions.
51 changes: 22 additions & 29 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -712,33 +712,21 @@ prod{T}(A::SparseMatrixCSC{T}, region) = reducedim(*,A,region,one(T))
#sum(A::SparseMatrixCSC{Bool}) = countnz(A)

## getindex
function binarysearch(haystack::AbstractVector, needle, lo::Int, hi::Int)
# Finds the first occurrence of needle in haystack[lo:hi]
lo = lo-1
hi2 = hi
hi = hi+1
@inbounds while lo < hi-1
m = (lo+hi)>>>1
if haystack[m] < needle
lo = m
else
hi = m
end
end
(hi==hi2+1 || haystack[hi]!=needle) ? -1 : hi
end
function rangesearch(haystack::Range, needle)
(i,rem) = divrem(needle - first(haystack), step(haystack))
(rem==0 && 1<=i+1<=length(haystack)) ? i+1 : -1
(rem==0 && 1<=i+1<=length(haystack)) ? i+1 : 0
end

getindex(A::SparseMatrixCSC, i::Integer) = getindex(A, ind2sub(size(A),i))
getindex(A::SparseMatrixCSC, I::(Integer,Integer)) = getindex(A, I[1], I[2])

function getindex{T}(A::SparseMatrixCSC{T}, i0::Integer, i1::Integer)
if !(1 <= i0 <= A.m && 1 <= i1 <= A.n); throw(BoundsError()); end
ind = binarysearch(A.rowval, i0, A.colptr[i1], A.colptr[i1+1]-1)
ind > -1 ? A.nzval[ind] : zero(T)
r1 = int(A.colptr[i1])
r2 = int(A.colptr[i1+1]-1)
(r1 > r2) && return zero(T)
r1 = searchsortedfirst(A.rowval, i0, r1, r2, Forward)
((r1 > r2) || (A.rowval[r1] != i0)) ? zero(T) : A.nzval[r1]
end

getindex{T<:Integer}(A::SparseMatrixCSC, I::AbstractVector{T}, j::Integer) = getindex(A,I,[j])
Expand Down Expand Up @@ -810,7 +798,7 @@ function getindex{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv,Ti}, I::Range, J::Abstra
for k = colptrA[col]:colptrA[col+1]-1
rowA = rowvalA[k]
i = rangesearch(I, rowA)
if i > -1
if i > 0
rowvalS[ptrS] = i
nzvalS[ptrS] = nzvalA[k]
ptrS += 1
Expand Down Expand Up @@ -1073,15 +1061,17 @@ function setindex!{T,Ti}(A::SparseMatrixCSC{T,Ti}, v, i0::Integer, i1::Integer)
i1 = convert(Ti, i1)
if !(1 <= i0 <= A.m && 1 <= i1 <= A.n); throw(BoundsError()); end
v = convert(T, v)
r1 = A.colptr[i1]
r2 = A.colptr[i1+1]-1
r1 = int(A.colptr[i1])
r2 = int(A.colptr[i1+1]-1)
if v == 0 #either do nothing or delete entry if it exists
loc = binarysearch(A.rowval, i0, r1, r2)
if loc != -1
deleteat!(A.rowval, loc)
deleteat!(A.nzval, loc)
for j = (i1+1):(A.n+1)
A.colptr[j] -= 1
if r1 <= r2
r1 = searchsortedfirst(A.rowval, i0, r1, r2, Forward)
if (r1 <= r2) && (A.rowval[r1] == i0)
deleteat!(A.rowval, r1)
deleteat!(A.nzval, r1)
for j = (i1+1):(A.n+1)
A.colptr[j] -= 1
end
end
end
return A
Expand Down Expand Up @@ -1851,8 +1841,11 @@ done(d::SpDiagIterator, j) = j > d.n

function next{Tv}(d::SpDiagIterator{Tv}, j)
A = d.A
idx = binarysearch(A.rowval, j, A.colptr[j], A.colptr[j+1]-1)
((idx == -1) ? zero(Tv) : A.nzval[idx], j+1)
r1 = int(A.colptr[j])
r2 = int(A.colptr[j+1]-1)
(r1 > r2) && (return (zero(Tv), j+1))
r1 = searchsortedfirst(A.rowval, j, r1, r2, Forward)
(((r1 > r2) || (A.rowval[r1] != j)) ? zero(Tv) : A.nzval[r1], j+1)
end

function trace{Tv}(A::SparseMatrixCSC{Tv})
Expand Down
8 changes: 8 additions & 0 deletions test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,14 @@ for (aa116, ss116) in [(a116, s116), (ad116, sd116)]
@test full(ss116[li,lj]) == aa116[li,lj]
end

let S = SparseMatrixCSC(3, 3, Uint8[1,1,1,1], Uint8[], Int64[])
S[1,1] = 1
S[5] = 2
S[end] = 3
@test S[end] == (S[1] + S[2,2])
@test 6 == sum(diag(S))
end


# setindex tests
let a = spzeros(Int, 10, 10)
Expand Down

0 comments on commit 664cab5

Please sign in to comment.