Skip to content

Commit

Permalink
[WIP] Speed up dense-sparse matmul (#38876)
Browse files Browse the repository at this point in the history
* Speed up dense-sparse matmul

* add one at-simd, minor edits

* improve A_mul_Bq for dense-sparse

* revert ineffective changes

* shift at-inbounds annotation
  • Loading branch information
dkarrasch committed Jan 12, 2021
1 parent 3d1598e commit a3369df
Showing 1 changed file with 65 additions and 93 deletions.
158 changes: 65 additions & 93 deletions stdlib/SparseArrays/src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ function mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVe
if β != 1
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
end
for k = 1:size(C, 2)
@inbounds for col = 1:size(A, 2)
for k in 1:size(C, 2)
@inbounds for col in 1:size(A, 2)
αxj = B[col,k] * α
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
for j in nzrange(A, col)
C[rv[j], k] += nzv[j]*αxj
end
end
Expand All @@ -49,67 +49,38 @@ end
*(A::SparseMatrixCSCUnion{TA}, B::AdjOrTransStridedOrTriangularMatrix{Tx}) where {TA,Tx} =
(T = promote_op(matprod, TA, Tx); mul!(similar(B, T, (size(A, 1), size(B, 2))), A, B, true, false))

function mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
A = adjA.parent
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
nzv = nonzeros(A)
rv = rowvals(A)
if β != 1
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
end
for k = 1:size(C, 2)
@inbounds for col = 1:size(A, 2)
tmp = zero(eltype(C))
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
tmp += adjoint(nzv[j])*B[rv[j],k]
for (T, t) in ((Adjoint, adjoint), (Transpose, transpose))
@eval function mul!(C::StridedVecOrMat, xA::$T{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
A = xA.parent
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
nzv = nonzeros(A)
rv = rowvals(A)
if β != 1
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
end
for k in 1:size(C, 2)
@inbounds for col in 1:size(A, 2)
tmp = zero(eltype(C))
for j in nzrange(A, col)
tmp += $t(nzv[j])*B[rv[j],k]
end
C[col,k] += tmp * α
end
C[col,k] += tmp * α
end
C
end
C
end
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, x::StridedVector{Tx}) where {Tx} =
(T = promote_op(matprod, eltype(adjA), Tx); mul!(similar(x, T, size(adjA, 1)), adjA, x, true, false))
*(adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransStridedOrTriangularMatrix) =
(T = promote_op(matprod, eltype(adjA), eltype(B)); mul!(similar(B, T, (size(adjA, 1), size(B, 2))), adjA, B, true, false))

function mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}, α::Number, β::Number)
A = transA.parent
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
nzv = nonzeros(A)
rv = rowvals(A)
if β != 1
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
end
for k = 1:size(C, 2)
@inbounds for col = 1:size(A, 2)
tmp = zero(eltype(C))
for j = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
tmp += transpose(nzv[j])*B[rv[j],k]
end
C[col,k] += tmp * α
end
end
C
end
*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, x::StridedVector{Tx}) where {Tx} =
(T = promote_op(matprod, eltype(transA), Tx); mul!(similar(x, T, size(transA, 1)), transA, x, true, false))
*(transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::AdjOrTransStridedOrTriangularMatrix) =
(T = promote_op(matprod, eltype(transA), eltype(B)); mul!(similar(B, T, (size(transA, 1), size(B, 2))), transA, B, true, false))

# For compatibility with dense multiplication API. Should be deleted when dense multiplication
# API is updated to follow BLAS API.
mul!(C::StridedVecOrMat, A::AbstractSparseMatrixCSC, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}) =
mul!(C, A, B, true, false)
mul!(C::StridedVecOrMat, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}) =
mul!(C, adjA, B, true, false)
mul!(C::StridedVecOrMat, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, B::Union{StridedVector,AdjOrTransStridedOrTriangularMatrix}) =
mul!(C, transA, B, true, false)

function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, A::AbstractSparseMatrixCSC, α::Number, β::Number)
mX, nX = size(X)
nX == size(A, 1) || throw(DimensionMismatch())
Expand All @@ -120,49 +91,50 @@ function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, A::Abs
if β != 1
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
end
@inbounds for multivec_row=1:mX, col = 1:size(A, 2), k=getcolptr(A)[col]:(getcolptr(A)[col+1]-1)
C[multivec_row, col] += α * X[multivec_row, rv[k]] * nzv[k] # perhaps suboptimal position of α?
if X isa StridedOrTriangularMatrix
@inbounds for col in 1:size(A, 2), k in nzrange(A, col)
Aiα = nzv[k] * α
rvk = rv[k]
@simd for multivec_row in 1:mX
C[multivec_row, col] += X[multivec_row, rvk] * Aiα
end
end
else # X isa Adjoint or Transpose
for multivec_row in 1:mX, col in 1:size(A, 2)
@inbounds for k in nzrange(A, col)
C[multivec_row, col] += X[multivec_row, rv[k]] * nzv[k] * α
end
end
end
C
end
*(X::AdjOrTransStridedOrTriangularMatrix, A::SparseMatrixCSCUnion{TvA}) where {TvA} =
(T = promote_op(matprod, eltype(X), TvA); mul!(similar(X, T, (size(X, 1), size(A, 2))), X, A, true, false))

function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}, α::Number, β::Number)
A = adjA.parent
mX, nX = size(X)
nX == size(A, 2) || throw(DimensionMismatch())
mX == size(C, 1) || throw(DimensionMismatch())
size(A, 1) == size(C, 2) || throw(DimensionMismatch())
rv = rowvals(A)
nzv = nonzeros(A)
if β != 1
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
end
@inbounds for col = 1:size(A, 2), k=getcolptr(A)[col]:(getcolptr(A)[col+1]-1), multivec_col=1:mX
C[multivec_col, rv[k]] += α * X[multivec_col, col] * adjoint(nzv[k]) # perhaps suboptimal position of α?
for (T, t) in ((Adjoint, adjoint), (Transpose, transpose))
@eval function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, xA::$T{<:Any,<:AbstractSparseMatrixCSC}, α::Number, β::Number)
A = xA.parent
mX, nX = size(X)
nX == size(A, 2) || throw(DimensionMismatch())
mX == size(C, 1) || throw(DimensionMismatch())
size(A, 1) == size(C, 2) || throw(DimensionMismatch())
rv = rowvals(A)
nzv = nonzeros(A)
if β != 1
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
end
@inbounds for col in 1:size(A, 2), k in nzrange(A, col)
Aiα = $t(nzv[k]) * α
rvk = rv[k]
@simd for multivec_col in 1:mX
C[multivec_col, rvk] += X[multivec_col, col] * Aiα
end
end
C
end
C
end
*(X::AdjOrTransStridedOrTriangularMatrix, adjA::Adjoint{<:Any,<:AbstractSparseMatrixCSC}) =
(T = promote_op(matprod, eltype(X), eltype(adjA)); mul!(similar(X, T, (size(X, 1), size(adjA, 2))), X, adjA, true, false))

function mul!(C::StridedVecOrMat, X::AdjOrTransStridedOrTriangularMatrix, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}, α::Number, β::Number)
A = transA.parent
mX, nX = size(X)
nX == size(A, 2) || throw(DimensionMismatch())
mX == size(C, 1) || throw(DimensionMismatch())
size(A, 1) == size(C, 2) || throw(DimensionMismatch())
rv = rowvals(A)
nzv = nonzeros(A)
if β != 1
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
end
@inbounds for col = 1:size(A, 2), k=getcolptr(A)[col]:(getcolptr(A)[col+1]-1), multivec_col=1:mX
C[multivec_col, rv[k]] += α * X[multivec_col, col] * transpose(nzv[k]) # perhaps suboptimal position of α?
end
C
end
*(X::AdjOrTransStridedOrTriangularMatrix, transA::Transpose{<:Any,<:AbstractSparseMatrixCSC}) =
(T = promote_op(matprod, eltype(X), eltype(transA)); mul!(similar(X, T, (size(X, 1), size(transA, 2))), X, transA, true, false))

Expand Down Expand Up @@ -896,7 +868,7 @@ function ldiv!(D::Diagonal{T}, A::AbstractSparseMatrixCSC{T}) where {T}
for i=1:length(b)
iszero(b[i]) && throw(SingularException(i))
end
@inbounds for col = 1:size(A, 2), p = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
@inbounds for col in 1:size(A, 2), p in nzrange(A, col)
nonz[p] = b[Arowval[p]] \ nonz[p]
end
A
Expand All @@ -916,7 +888,7 @@ function triu(S::AbstractSparseMatrixCSC{Tv,Ti}, k::Integer=0) where {Tv,Ti}
colptr[col] = 1
end
for col = max(k+1,1) : n
for c1 = getcolptr(S)[col] : getcolptr(S)[col+1]-1
for c1 in nzrange(S, col)
rowvals(S)[c1] > col - k && break
nnz += 1
end
Expand All @@ -927,7 +899,7 @@ function triu(S::AbstractSparseMatrixCSC{Tv,Ti}, k::Integer=0) where {Tv,Ti}
A = SparseMatrixCSC(m, n, colptr, rowval, nzval)
for col = max(k+1,1) : n
c1 = getcolptr(S)[col]
for c2 = getcolptr(A)[col] : getcolptr(A)[col+1]-1
for c2 in nzrange(A, col)
rowvals(A)[c2] = rowvals(S)[c1]
nonzeros(A)[c2] = nonzeros(S)[c1]
c1 += 1
Expand Down Expand Up @@ -981,7 +953,7 @@ function sparse_diff1(S::AbstractSparseMatrixCSC{Tv,Ti}) where {Tv,Ti}
for col = 1 : n
last_row = 0
last_val = 0
for k = getcolptr(S)[col] : getcolptr(S)[col+1]-1
for k in nzrange(S, col)
row = rowvals(S)[k]
val = nonzeros(S)[k]
if row > 1
Expand Down Expand Up @@ -1124,7 +1096,7 @@ function opnorm(A::AbstractSparseMatrixCSC, p::Real=2)
nA::Tsum = 0
for j=1:n
colSum::Tsum = 0
for i = getcolptr(A)[j]:getcolptr(A)[j+1]-1
for i in nzrange(A, j)
colSum += abs(nonzeros(A)[i])
end
nA = max(nA, colSum)
Expand Down Expand Up @@ -1469,7 +1441,7 @@ function mul!(C::AbstractSparseMatrixCSC, A::AbstractSparseMatrixCSC, D::Diagona
Cnzval = nonzeros(C)
Anzval = nonzeros(A)
resize!(Cnzval, length(Anzval))
for col = 1:n, p = getcolptr(A)[col]:(getcolptr(A)[col+1]-1)
for col in 1:n, p in nzrange(A, col)
@inbounds Cnzval[p] = Anzval[p] * b[col]
end
C
Expand All @@ -1484,7 +1456,7 @@ function mul!(C::AbstractSparseMatrixCSC, D::Diagonal, A::AbstractSparseMatrixCS
Anzval = nonzeros(A)
Arowval = rowvals(A)
resize!(Cnzval, length(Anzval))
for col = 1:n, p = getcolptr(A)[col]:(getcolptr(A)[col+1]-1)
for col in 1:n, p in nzrange(A, col)
@inbounds Cnzval[p] = b[Arowval[p]] * Anzval[p]
end
C
Expand Down Expand Up @@ -1520,7 +1492,7 @@ function rmul!(A::AbstractSparseMatrixCSC, D::Diagonal)
m, n = size(A)
(n == size(D, 1)) || throw(DimensionMismatch())
Anzval = nonzeros(A)
@inbounds for col = 1:n, p = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
@inbounds for col in 1:n, p in nzrange(A, col)
Anzval[p] = Anzval[p] * D.diag[col]
end
return A
Expand All @@ -1531,7 +1503,7 @@ function lmul!(D::Diagonal, A::AbstractSparseMatrixCSC)
(m == size(D, 2)) || throw(DimensionMismatch())
Anzval = nonzeros(A)
Arowval = rowvals(A)
@inbounds for col = 1:n, p = getcolptr(A)[col]:(getcolptr(A)[col + 1] - 1)
@inbounds for col in 1:n, p in nzrange(A, col)
Anzval[p] = D.diag[Arowval[p]] * Anzval[p]
end
return A
Expand Down

0 comments on commit a3369df

Please sign in to comment.