Skip to content

Commit

Permalink
Use sparse triangular solvers for sparse triangular solves. Fixes #13…
Browse files Browse the repository at this point in the history
…792.

Make fwd/bwdTriSolve! work for triagular views

Add check for triangular matrices in sparse factorize
  • Loading branch information
andreasnoack committed Nov 1, 2015
1 parent ee584e8 commit 81753cc
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 36 deletions.
2 changes: 1 addition & 1 deletion base/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Base: Func, AddFun, OrFun, ConjFun, IdFun
using Base.Sort: Forward
using Base.LinAlg: AbstractTriangular, PosDefException

import Base: +, -, *, &, |, $, .+, .-, .*, ./, .\, .^, .<, .!=, ==
import Base: +, -, *, \, &, |, $, .+, .-, .*, ./, .\, .^, .<, .!=, ==
import Base: A_mul_B!, Ac_mul_B, Ac_mul_B!, At_mul_B, At_mul_B!, A_ldiv_B!

import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh,
Expand Down
96 changes: 61 additions & 35 deletions base/sparse/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,16 +167,10 @@ end
## solvers
function fwdTriSolve!(A::SparseMatrixCSC, B::AbstractVecOrMat)
# forward substitution for CSC matrices
n = length(B)
if isa(B, Vector)
nrowB = n
ncolB = 1
else
nrowB, ncolB = size(B)
end
ncol = chksquare(A)
nrowB, ncolB = size(B, 1), size(B, 2)
ncol = LinAlg.chksquare(A)
if nrowB != ncol
throw(DimensionMismatch("A is $(ncol)X$(ncol) and B has length $(n)"))
throw(DimensionMismatch("A is $(ncol) columns and B has $(nrowB) rows"))
end

aa = A.nzval
Expand All @@ -185,56 +179,88 @@ function fwdTriSolve!(A::SparseMatrixCSC, B::AbstractVecOrMat)

joff = 0
for k = 1:ncolB
for j = 1:(nrowB-1)
jb = joff + j
for j = 1:nrowB
i1 = ia[j]
i2 = ia[j+1]-1
B[jb] /= aa[i1]
bj = B[jb]
for i = i1+1:i2
B[joff+ja[i]] -= bj*aa[i]
i2 = ia[j + 1] - 1

# loop through the structural zeros
ii = i1
jai = ja[ii]
while ii <= i2 && jai < j
ii += 1
jai = ja[ii]
end

# check for zero pivot and divide with pivot
if jai == j
bj = B[joff + jai]/aa[ii]
B[joff + jai] = bj
ii += 1
else
throw(LinAlg.SingularException(j))
end

# update remaining part
for i = ii:i2
B[joff + ja[i]] -= bj*aa[i]
end
end
joff += nrowB
B[joff] /= aa[end]
end
return B
B
end

function bwdTriSolve!(A::SparseMatrixCSC, B::AbstractVecOrMat)
# backward substitution for CSC matrices
n = length(B)
if isa(B, Vector)
nrowB = n
ncolB = 1
else
nrowB, ncolB = size(B)
nrowB, ncolB = size(B, 1), size(B, 2)
ncol = LinAlg.chksquare(A)
if nrowB != ncol
throw(DimensionMismatch("A is $(ncol) columns and B has $(nrowB) rows"))
end
ncol = chksquare(A)
if nrowB != ncol throw(DimensionMismatch("A is $(ncol)X$(ncol) and B has length $(n)")) end

aa = A.nzval
ja = A.rowval
ia = A.colptr

joff = 0
for k = 1:ncolB
for j = nrowB:-1:2
jb = joff + j
for j = nrowB:-1:1
i1 = ia[j]
i2 = ia[j+1]-1
B[jb] /= aa[i2]
bj = B[jb]
for i = i2-1:-1:i1
B[joff+ja[i]] -= bj*aa[i]
i2 = ia[j + 1] - 1

# loop through the structural zeros
ii = i2
jai = ja[ii]
while ii >= i1 && jai > j
ii -= 1
jai = ja[ii]
end

# check for zero pivot and divide with pivot
if jai == j
bj = B[joff + jai]/aa[ii]
B[joff + jai] = bj
ii -= 1
else
throw(LinAlg.SingularException(j))
end

# update remaining part
for i = ii:-1:i1
B[joff + ja[i]] -= bj*aa[i]
end
end
B[joff+1] /= aa[1]
joff += nrowB
end
return B
B
end

A_ldiv_B!{T,Ti}(L::LowerTriangular{T,SparseMatrixCSC{T,Ti}}, B::StridedVecOrMat) = fwdTriSolve!(L.data, B)
A_ldiv_B!{T,Ti}(U::UpperTriangular{T,SparseMatrixCSC{T,Ti}}, B::StridedVecOrMat) = bwdTriSolve!(U.data, B)

(\){T,Ti}(L::LowerTriangular{T,SparseMatrixCSC{T,Ti}}, B::SparseMatrixCSC) = A_ldiv_B!(L, full(B))
(\){T,Ti}(U::UpperTriangular{T,SparseMatrixCSC{T,Ti}}, B::SparseMatrixCSC) = A_ldiv_B!(U, full(B))

## triu, tril

function triu{Tv,Ti}(S::SparseMatrixCSC{Tv,Ti}, k::Integer=0)
Expand Down
12 changes: 12 additions & 0 deletions test/sparsedir/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1167,3 +1167,15 @@ let
@test_throws ErrorException eig(A)
@test_throws ErrorException inv(A)
end

let
n = 100
A = sprandn(n, n, 0.5) + sqrt(n)*I
x = LowerTriangular(A)*ones(n)
@test LowerTriangular(A)\x ones(n)
x = UpperTriangular(A)*ones(n)
@test UpperTriangular(A)\x ones(n)
A[2,2] = 0
@test_throws LinAlg.SingularException LowerTriangular(A)\ones(n)
@test_throws LinAlg.SingularException UpperTriangular(A)\ones(n)
end

0 comments on commit 81753cc

Please sign in to comment.