Skip to content

Commit

Permalink
1d set/getindex for abstract sparse
Browse files Browse the repository at this point in the history
  • Loading branch information
tanmaykm committed Jun 4, 2014
1 parent 182f791 commit 2d30e1e
Show file tree
Hide file tree
Showing 3 changed files with 408 additions and 1 deletion.
53 changes: 53 additions & 0 deletions base/sparse/abstractsparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,56 @@ issparse(A::AbstractArray) = false
issparse(S::AbstractSparseArray) = true

indtype{Tv,Ti}(S::AbstractSparseArray{Tv,Ti}) = Ti

function getindex{T<:AbstractSparseMatrix}(A::T, I::AbstractArray{Bool})
checkbounds(A, I)
n = sum(I)
out = similar(A, n, 1)
c = 1
for i = 1:length(I)
if I[i]
out[c] = A[i]
c += 1
end
end
out
end

function getindex{T<:AbstractSparseMatrix}(A::T, I::AbstractArray)
x = similar(A, size(I,1), size(I,2))
for i=1:length(I)
x[i] = A[I[i]]
end
return x
end


function setindex!{S<:AbstractSparseMatrix}(A::S, x, I::AbstractVector)
if isa(x, AbstractArray)
for i in I
A[i] = x[i]
end
else
for i in I
A[i] = x
end
end
return A
end

function setindex!{S<:AbstractSparseMatrix}(A::S, x, I::AbstractArray{Bool,2})
checkbounds(A, I)
if isa(x, AbstractArray)
c = 1
for i = 1:length(I)
I[i] && (A[i] = X[c]; c += 1)
end
(length(X) == c-1) || throw(DimensionMismatch("assigned $(length(X)) elements to length $(c-1) destination"))
else
for i = 1:length(I)
I[i] && (A[i] = x)
end
end
A
end

316 changes: 315 additions & 1 deletion base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ copy(S::SparseMatrixCSC) =
similar(S::SparseMatrixCSC, Tv::NonTupleType=eltype(S)) = SparseMatrixCSC(S.m, S.n, copy(S.colptr), copy(S.rowval), Array(Tv, length(S.nzval)))
similar{Tv,Ti,TvNew}(S::SparseMatrixCSC{Tv,Ti}, ::Type{TvNew}, ::Type{Ti}) = similar(S, TvNew)
similar{Tv,Ti,TvNew,TiNew}(S::SparseMatrixCSC{Tv,Ti}, ::Type{TvNew}, ::Type{TiNew}) = SparseMatrixCSC(S.m, S.n, convert(Array{TiNew},S.colptr), convert(Array{TiNew}, S.rowval), Array(TvNew, length(S.nzval)))
similar(S::SparseMatrixCSC, Tv::Type, d::(Integer,Integer)) = spzeros(Tv, d[1], d[2])
similar{Tv}(S::SparseMatrixCSC, ::Type{Tv}, d::NTuple{Integer}) = spzeros(Tv, d...)

function convert{Tv,Ti,TvS,TiS}(::Type{SparseMatrixCSC{Tv,Ti}}, S::SparseMatrixCSC{TvS,TiS})
if Tv == TvS && Ti == TiS
Expand Down Expand Up @@ -1029,6 +1029,84 @@ getindex(A::SparseMatrixCSC, I::AbstractVector{Bool}, J::AbstractVector{Bool}) =
getindex{T<:Integer}(A::SparseMatrixCSC, I::AbstractVector{T}, J::AbstractVector{Bool}) = A[I,find(J)]
getindex{T<:Integer}(A::SparseMatrixCSC, I::AbstractVector{Bool}, J::AbstractVector{T}) = A[find(I),J]

function getindex{Tv}(A::SparseMatrixCSC{Tv}, I::AbstractArray{Bool})
checkbounds(A, I)
n = sum(I)

colptrA = A.colptr; rowvalA = A.rowval; nzvalA = A.nzval
colptrB = Int[1,n+1]
rowvalB = Array(Int, n)
nzvalB = Array(Tv, n)
c = 1
rowB = 1

for col in 1:A.n
r1 = colptrA[col]
r2 = colptrA[col+1]-1

for row in 1:A.m
if I[row, col]
while (r1 <= r2) && (rowvalA[r1] < row)
r1 += 1
end
if (r1 <= r2) && (rowvalA[r1] == row)
nzvalB[c] = nzvalA[r1]
rowvalB[c] = rowB
c += 1
end
rowB += 1
(rowB > n) && break
end
end
(rowB > n) && break
end
colptrB[end] = c
n = length(nzvalB)
if n > (c-1)
deleteat!(nzvalB, c:n)
deleteat!(rowvalB, c:n)
end
SparseMatrixCSC(n, 1, colptrB, rowvalB, nzvalB)
end

function getindex{Tv}(A::SparseMatrixCSC{Tv}, I::AbstractArray)
szA = size(A); colptrA = A.colptr; rowvalA = A.rowval; nzvalA = A.nzval

n = length(I)
outm = size(I,1)
outn = size(I,2)
szB = (outm, outn)
colptrB = zeros(Int, outn+1)
rowvalB = Array(Int, n)
nzvalB = Array(Tv, n)

colB = 1
rowB = 1
colptrB[colB] = 1
idxB = 1

for i in 1:n
row,col = ind2sub(szA, I[i])
for r in colptrA[col]:(colptrA[col+1]-1)
if rowvalA[r] == row
rowB,colB = ind2sub(szB, i)
colptrB[colB+1] += 1
rowvalB[idxB] = rowB
nzvalB[idxB] = nzvalA[r]
idxB += 1
break
end
end
end
colptrB = cumsum(colptrB)
if n > (idxB-1)
deleteat!(nzvalB, idxB:n)
deleteat!(rowvalB, idxB:n)
end
SparseMatrixCSC(outm, outn, colptrB, rowvalB, nzvalB)
end


## setindex!
setindex!(A::SparseMatrixCSC, v, i::Integer) = setindex!(A, v, ind2sub(size(A),i)...)

Expand Down Expand Up @@ -1451,6 +1529,242 @@ setindex!(A::Matrix, x::SparseMatrixCSC, I::AbstractVector{Bool}, J::AbstractVec
setindex!{T<:Integer}(A::Matrix, x::SparseMatrixCSC, I::AbstractVector{T}, J::AbstractVector{Bool}) = setindex!(A, full(x), I, find(J))
setindex!{T<:Integer}(A::Matrix, x::SparseMatrixCSC, I::AbstractVector{Bool}, J::AbstractVector{T}) = setindex!(A, full(x), find(I), J)

function setindex!{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractArray{Bool,2})
checkbounds(A, I)
n = sum(I)
(n == 0) && (return A)

colptrA = A.colptr; rowvalA = A.rowval; nzvalA = A.nzval
colptrB = colptrA; rowvalB = rowvalA; nzvalB = nzvalA
nadd = ndel = 0
bidx = xidx = 1
ridx = r1 = r2 = 0
last = mid = midval = 0

for col in 1:A.n
r1 = int(colptrA[col])
r2 = int(colptrA[col+1]-1)

for row in 1:A.m
if I[row, col]
v = isa(x, AbstractArray) ? x[xidx] : x
xidx += 1

if r1 <= r2
ridx = r1
last = r2
@inbounds while ridx <= last
mid = (ridx + last) >> 1
midval = int(rowvalA[mid])
if midval > row
last = mid - 1
elseif midval == row
ridx = mid
break
else
ridx = mid + 1
end
end

copylen = ridx - r1
if (copylen > 0)
if (nadd > 0) || (ndel > 0)
copy!(rowvalB, bidx, rowvalA, r1, copylen)
copy!(nzvalB, bidx, nzvalA, r1, copylen)
end
bidx += copylen
r1 += copylen
end
end

# 0: no change, 1: update, 2: delete, 3: add new
mode = ((r1 <= r2) && (rowvalA[r1] == row)) ? ((v == 0) ? 2 : 1) : ((v == 0) ? 0 : 3)

if (mode > 1) && (nadd == 0) && (ndel == 0)
# copy storage to take changes
colptrB = copy(colptrA)
memreq = (x == 0) ? 0 : n
rowvalB = Array(Ti, length(rowvalA)+memreq); copy!(rowvalB, 1, rowvalA, 1, r1-1)
nzvalB = Array(Tv, length(nzvalA)+memreq); copy!(nzvalB, 1, nzvalA, 1, r1-1)
end
if mode == 1
rowvalB[bidx] = row
nzvalB[bidx] = v
bidx += 1
r1 += 1
elseif mode == 2
r1 += 1
ndel += 1
elseif mode == 3
rowvalB[bidx] = row
nzvalB[bidx] = v
bidx += 1
nadd += 1
end
(xidx > n) && break
end # if I[row, col]
end # for row in 1:A.m

if ((nadd != 0) || (ndel != 0))
l = r2-r1+1
if l > 0
copy!(rowvalB, bidx, rowvalA, r1, l)
copy!(nzvalB, bidx, nzvalA, r1, l)
bidx += l
end
colptrB[col+1] = bidx

if (xidx > n) && (length(colptrB) > (col+1))
diff = nadd - ndel
colptrB[(col+2):end] = colptrA[(col+2):end] .+ diff
r1 = colptrA[col+1]
r2 = colptrA[end]-1
l = r2-r1+1
if l > 0
copy!(rowvalB, bidx, rowvalA, r1, l)
copy!(nzvalB, bidx, nzvalA, r1, l)
bidx += l
end
end
else
bidx = colptrA[col+1]
end
(xidx > n) && break
end # for col in 1:A.n

if (nadd != 0) || (ndel != 0)
n = length(nzvalB)
if n > (bidx-1)
deleteat!(nzvalB, bidx:n)
deleteat!(rowvalB, bidx:n)
end
A.nzval = nzvalB; A.rowval = rowvalB; A.colptr = colptrB
end
A
end


function setindex!{Tv,Ti,T<:Real}(A::SparseMatrixCSC{Tv,Ti}, x, I::AbstractVector{T})
n = length(I)
(n == 0) && (return A)

colptrA = A.colptr; rowvalA = A.rowval; nzvalA = A.nzval; szA = size(A)
colptrB = colptrA; rowvalB = rowvalA; nzvalB = nzvalA
nadd = ndel = 0
bidx = aidx = 1

S = issorted(I) ? (1:n) : sortperm(I)
sxidx = ridx = r1 = r2 = 0
last = mid = midval = 0

lastcol = 0
for xidx in 1:n
sxidx = S[xidx]
(sxidx < n) && (I[sxidx] == I[sxidx+1]) && continue

row,col = ind2sub(szA, I[sxidx])
v = isa(x, AbstractArray) ? x[sxidx] : x

if col > lastcol
r1 = int(colptrA[col])
r2 = int(colptrA[col+1] - 1)

# copy from last position till current column
if (nadd > 0) || (ndel > 0)
colptrB[(lastcol+1):col] = colptrA[(lastcol+1):col] .+ (nadd - ndel)
copylen = r1 - aidx
if copylen > 0
copy!(rowvalB, bidx, rowvalA, aidx, copylen)
copy!(nzvalB, bidx, nzvalA, aidx, copylen)
aidx += copylen
bidx += copylen
end
else
aidx = bidx = r1
end
lastcol = col
end

if r1 <= r2
ridx = r1
last = r2
@inbounds while ridx <= last
mid = (ridx + last) >> 1
midval = int(rowvalA[mid])
if midval > row
last = mid - 1
elseif midval == row
ridx = mid
break
else
ridx = mid + 1
end
end

copylen = ridx - r1
if (copylen > 0)
if (nadd > 0) || (ndel > 0)
copy!(rowvalB, bidx, rowvalA, r1, copylen)
copy!(nzvalB, bidx, nzvalA, r1, copylen)
end
bidx += copylen
r1 += copylen
aidx += copylen
end
end

# 0: no change, 1: update, 2: delete, 3: add new
mode = ((r1 <= r2) && (rowvalA[r1] == row)) ? ((v == 0) ? 2 : 1) : ((v == 0) ? 0 : 3)

if (mode > 1) && (nadd == 0) && (ndel == 0)
# copy storage to take changes
colptrB = copy(colptrA)
memreq = (x == 0) ? 0 : n
rowvalB = Array(Ti, length(rowvalA)+memreq); copy!(rowvalB, 1, rowvalA, 1, r1-1)
nzvalB = Array(Tv, length(nzvalA)+memreq); copy!(nzvalB, 1, nzvalA, 1, r1-1)
end
if mode == 1
rowvalB[bidx] = row
nzvalB[bidx] = v
bidx += 1
aidx += 1
r1 += 1
elseif mode == 2
r1 += 1
aidx += 1
ndel += 1
elseif mode == 3
rowvalB[bidx] = row
nzvalB[bidx] = v
bidx += 1
nadd += 1
end
end

# copy the rest
if (nadd > 0) || (ndel > 0)
colptrB[(lastcol+1):end] = colptrA[(lastcol+1):end] .+ (nadd - ndel)
r1 = colptrA[end]-1
copylen = r1 - aidx + 1
if copylen > 0
copy!(rowvalB, bidx, rowvalA, aidx, copylen)
copy!(nzvalB, bidx, nzvalA, aidx, copylen)
aidx += copylen
bidx += copylen
end

n = length(nzvalB)
if n > (bidx-1)
deleteat!(nzvalB, bidx:n)
deleteat!(rowvalB, bidx:n)
end
A.nzval = nzvalB; A.rowval = rowvalB; A.colptr = colptrB
end
A
end



# Sparse concatenation

function vcat(X::SparseMatrixCSC...)
Expand Down
Loading

0 comments on commit 2d30e1e

Please sign in to comment.