Skip to content

Commit

Permalink
Merge pull request #16270 from martinholters/sparse_setindex_checkbounds
Browse files Browse the repository at this point in the history
Minor fixes to sparse setindex!
  • Loading branch information
tkelman committed May 17, 2016
2 parents a717fbe + de810b4 commit 807ec46
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 4 deletions.
2 changes: 1 addition & 1 deletion base/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

module SparseArrays

using Base: ReshapedArray
using Base: ReshapedArray, setindex_shape_check
using Base.Sort: Forward
using Base.LinAlg: AbstractTriangular, PosDefException

Expand Down
33 changes: 30 additions & 3 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2265,7 +2265,15 @@ function spset!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, x::Tv, I::AbstractVector

m, n = size(A)
lenI = length(I)
((I[end] > m) || (J[end] > n)) && throw(DimensionMismatch(""))

if (!isempty(I) && (I[1] < 1 || I[end] > m)) || (!isempty(J) && (J[1] < 1 || J[end] > n))
throw(BoundsError(A, (I, J)))
end

if isempty(I) || isempty(J)
return A
end

nnzA = nnz(A) + lenI * length(J)

colptr = A.colptr
Expand Down Expand Up @@ -2371,7 +2379,13 @@ function spdelete!{Tv,Ti<:Integer}(A::SparseMatrixCSC{Tv}, I::AbstractVector{Ti}
!issorted(I) && (I = sort(I))
!issorted(J) && (J = sort(J))

((I[end] > m) || (J[end] > n)) && throw(DimensionMismatch(""))
if (!isempty(I) && (I[1] < 1 || I[end] > m)) || (!isempty(J) && (J[1] < 1 || J[end] > n))
throw(BoundsError(A, (I, J)))
end

if isempty(I) || isempty(J)
return A
end

colptr = A.colptr
rowval = rowvalA = A.rowval
Expand Down Expand Up @@ -2451,6 +2465,14 @@ function setindex!{Tv,Ti,T<:Integer}(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixC
m, n = size(A)
mB, nB = size(B)

if (!isempty(I) && (I[1] < 1 || I[end] > m)) || (!isempty(J) && (J[1] < 1 || J[end] > n))
throw(BoundsError(A, (I, J)))
end

if isempty(I) || isempty(J)
return A
end

nI = length(I)
nJ = length(J)

Expand Down Expand Up @@ -2671,14 +2693,19 @@ function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVecto
S = issorted(I) ? (1:n) : sortperm(I)
sxidx = r1 = r2 = 0

if (!isempty(I) && (I[S[1]] < 1 || I[S[end]] > length(A)))
throw(BoundsError(A, I))
end

isa(x, AbstractArray) && setindex_shape_check(x, length(I))

lastcol = 0
(nrowA, ncolA) = szA
@inbounds for xidx in 1:n
sxidx = S[xidx]
(sxidx < n) && (I[sxidx] == I[sxidx+1]) && continue

row,col = ind2sub(szA, I[sxidx])
((row > nrowA) || (col > ncolA)) && throw(BoundsError())
v = isa(x, AbstractArray) ? x[sxidx] : x

if col > lastcol
Expand Down
34 changes: 34 additions & 0 deletions test/sparsedir/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,40 @@ let a = spzeros(Int, 10, 10)
@test a[1,:] == sparse([1:10;])
a[:,2] = 1:10
@test a[:,2] == sparse([1:10;])

a[1,1:0] = []
@test a[1,:] == sparse([1; 1; 3:10])
a[1:0,2] = []
@test a[:,2] == sparse([1:10;])
a[1,1:0] = 0
@test a[1,:] == sparse([1; 1; 3:10])
a[1:0,2] = 0
@test a[:,2] == sparse([1:10;])
a[1,1:0] = 1
@test a[1,:] == sparse([1; 1; 3:10])
a[1:0,2] = 1
@test a[:,2] == sparse([1:10;])

@test_throws BoundsError a[:,11] = spzeros(10,1)
@test_throws BoundsError a[11,:] = spzeros(1,10)
@test_throws BoundsError a[:,-1] = spzeros(10,1)
@test_throws BoundsError a[-1,:] = spzeros(1,10)
@test_throws BoundsError a[0:9] = spzeros(1,10)
@test_throws BoundsError a[:,11] = 0
@test_throws BoundsError a[11,:] = 0
@test_throws BoundsError a[:,-1] = 0
@test_throws BoundsError a[-1,:] = 0
@test_throws BoundsError a[0:9] = 0
@test_throws BoundsError a[:,11] = 1
@test_throws BoundsError a[11,:] = 1
@test_throws BoundsError a[:,-1] = 1
@test_throws BoundsError a[-1,:] = 1
@test_throws BoundsError a[0:9] = 1

@test_throws DimensionMismatch a[1:2,1:2] = 1:3
@test_throws DimensionMismatch a[1:2,1] = 1:3
@test_throws DimensionMismatch a[1,1:2] = 1:3
@test_throws DimensionMismatch a[1:2] = 1:3
end

let A = spzeros(Int, 10, 20)
Expand Down

0 comments on commit 807ec46

Please sign in to comment.