Skip to content

Commit

Permalink
Re-implement sparse setindex! as fill!
Browse files Browse the repository at this point in the history
and simplify the call structure -- integers are just as capable as pseudo-vectors in these functions
  • Loading branch information
mbauman committed Mar 13, 2018
1 parent a5a8fd6 commit 7cfbfcd
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 27 deletions.
2 changes: 1 addition & 1 deletion base/sysimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ if false
# simple print definitions for debugging. enable these if something
# goes wrong during bootstrap before printing code is available.
# otherwise, they just just eventually get (noisily) overwritten later
global show, print, println, string
global show, print, println
show(io::IO, x) = Core.show(io, x)
print(io::IO, a...) = Core.print(io, a...)
println(io::IO, x...) = Core.println(io, x...)
Expand Down
47 changes: 21 additions & 26 deletions stdlib/SparseArrays/src/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2350,37 +2350,34 @@ function setindex!(A::SparseMatrixCSC{Tv,Ti}, v::Tv, i::Ti, j::Ti) where Tv wher
return A
end

setindex!(A::SparseMatrixCSC, v::AbstractArray, i::Integer, J::AbstractVector{<:Integer}) = setindex!(A, v, [i], J)
setindex!(A::SparseMatrixCSC, v::AbstractArray, I::AbstractVector{<:Integer}, j::Integer) = setindex!(A, v, I, [j])

# Colon translation
setindex!(A::SparseMatrixCSC, x::AbstractArray, ::Colon) = setindex!(A, x, 1:length(A))
setindex!(A::SparseMatrixCSC, x::AbstractArray, ::Colon, ::Colon) = setindex!(A, x, 1:size(A, 1), 1:size(A,2))
setindex!(A::SparseMatrixCSC, x::AbstractArray, ::Colon, j::Union{Integer, AbstractVector}) = setindex!(A, x, 1:size(A, 1), j)
setindex!(A::SparseMatrixCSC, x::AbstractArray, i::Union{Integer, AbstractVector}, ::Colon) = setindex!(A, x, i, 1:size(A, 2))

# TODO: Revamp this guy to use broadcast
# function setindex!(A::SparseMatrixCSC{Tv}, x::Number,
# I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer}) where Tv
# if isempty(I) || isempty(J); return A; end
# # lt=≤ to check for strict sorting
# if !issorted(I, lt=≤); I = sort!(unique(I)); end
# if !issorted(J, lt=≤); J = sort!(unique(J)); end
# if (I[1] < 1 || I[end] > A.m) || (J[1] < 1 || J[end] > A.n)
# throw(BoundsError(A, (I, J)))
# end
# if x == 0
# _spsetz_setindex!(A, I, J)
# else
# _spsetnz_setindex!(A, convert(Tv, x), I, J)
# end
# end
function Base.fill!(V::SubArray{Tv, <:Any, <:SparseMatrixCSC, Tuple{Vararg{Union{Integer, AbstractVector{<:Integer}},2}}}, x) where Tv
A = V.parent
I, J = V.indices
if isempty(I) || isempty(J); return A; end
# lt=≤ to check for strict sorting
if !issorted(I, lt=); I = sort!(unique(I)); end
if !issorted(J, lt=); J = sort!(unique(J)); end
if (I[1] < 1 || I[end] > A.m) || (J[1] < 1 || J[end] > A.n)
throw(BoundsError(A, (I, J)))
end
if x == 0
_spsetz_setindex!(A, I, J)
else
_spsetnz_setindex!(A, convert(Tv, x), I, J)
end
end
"""
Helper method for immediately preceding setindex! method. For all (i,j) such that i in I and
j in J, assigns zero to A[i,j] if A[i,j] is a presently-stored entry, and otherwise does nothing.
"""
function _spsetz_setindex!(A::SparseMatrixCSC,
I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer})
I::Union{Integer, AbstractVector{<:Integer}}, J::Union{Integer, AbstractVector{<:Integer}})
lengthI = length(I)
for j in J
coljAfirstk = A.colptr[j]
Expand Down Expand Up @@ -2416,7 +2413,7 @@ and j in J, assigns x to A[i,j] if A[i,j] is a presently-stored entry, and alloc
assigns x to A[i,j] if A[i,j] is not presently stored.
"""
function _spsetnz_setindex!(A::SparseMatrixCSC{Tv}, x::Tv,
I::AbstractVector{<:Integer}, J::AbstractVector{<:Integer}) where Tv
I::Union{Integer, AbstractVector{<:Integer}}, J::Union{Integer, AbstractVector{<:Integer}}) where Tv
m, n = size(A)
lenI = length(I)

Expand Down Expand Up @@ -2521,16 +2518,14 @@ function _spsetnz_setindex!(A::SparseMatrixCSC{Tv}, x::Tv,
return A
end

setindex!(A::SparseMatrixCSC{Tv,Ti}, S::Matrix, I::AbstractVector{T}, J::AbstractVector{T}) where {Tv,Ti,T<:Integer} =
setindex!(A::SparseMatrixCSC{Tv,Ti}, S::Matrix, I::Union{Integer, AbstractVector{T}}, J::Union{Integer, AbstractVector{T}}) where {Tv,Ti,T<:Integer} =
setindex!(A, convert(SparseMatrixCSC{Tv,Ti}, S), I, J)

setindex!(A::SparseMatrixCSC, v::AbstractVector, I::AbstractVector{<:Integer}, j::Integer) = setindex!(A, v, I, [j])
setindex!(A::SparseMatrixCSC, v::AbstractVector, i::Integer, J::AbstractVector{<:Integer}) = setindex!(A, v, [i], J)
setindex!(A::SparseMatrixCSC, v::AbstractVector, I::AbstractVector{T}, J::AbstractVector{T}) where {T<:Integer} =
setindex!(A::SparseMatrixCSC, v::AbstractVector, I::Union{Integer, AbstractVector{T}}, J::Union{Integer, AbstractVector{T}}) where {T<:Integer} =
setindex!(A, reshape(v, length(I), length(J)), I, J)

# A[I,J] = B
function setindex!(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}, I::AbstractVector{T}, J::AbstractVector{T}) where {Tv,Ti,T<:Integer}
function setindex!(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}, I::Union{Integer, AbstractVector{T}}, J::Union{Integer, AbstractVector{T}}) where {Tv,Ti,T<:Integer}
if size(B,1) != length(I) || size(B,2) != length(J)
throw(DimensionMismatch(""))
end
Expand Down

0 comments on commit 7cfbfcd

Please sign in to comment.