From 4420717206c87551ad680fe2ae7add692162a0ff Mon Sep 17 00:00:00 2001 From: Jeremy Date: Tue, 12 Jun 2018 16:26:47 -0400 Subject: [PATCH] added sparse matrix inner product (#27470) --- NEWS.md | 3 +++ stdlib/SparseArrays/src/SparseArrays.jl | 2 +- stdlib/SparseArrays/src/linalg.jl | 31 +++++++++++++++++++++++++ stdlib/SparseArrays/test/sparse.jl | 9 +++++++ 4 files changed, 44 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index fa216371d0176..c1229cb0eb51e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -694,6 +694,9 @@ Library improvements * `Sys.which()` provides a cross-platform method to find executable files, similar to the Unix `which` command. ([#26559]) + * Added an optimized method of `vecdot` for taking the Frobenius inner product + of sparse matrices. ([#27470]) + Compiler/Runtime improvements ----------------------------- diff --git a/stdlib/SparseArrays/src/SparseArrays.jl b/stdlib/SparseArrays/src/SparseArrays.jl index 7ec270532ea0d..33330420d385a 100644 --- a/stdlib/SparseArrays/src/SparseArrays.jl +++ b/stdlib/SparseArrays/src/SparseArrays.jl @@ -13,7 +13,7 @@ using LinearAlgebra import Base: +, -, *, \, /, &, |, xor, == import LinearAlgebra: mul!, ldiv!, rdiv!, chol, adjoint!, diag, dot, eigen, - issymmetric, istril, istriu, lu, tr, transpose!, tril!, triu!, + issymmetric, istril, istriu, lu, tr, transpose!, tril!, triu!, vecdot, vecnorm, cond, diagm, factorize, ishermitian, norm, lmul!, rmul!, tril, triu import Base: @get!, acos, acosd, acot, acotd, acsch, asech, asin, asind, asinh, diff --git a/stdlib/SparseArrays/src/linalg.jl b/stdlib/SparseArrays/src/linalg.jl index b0e4228bcd5d2..c308883147eda 100644 --- a/stdlib/SparseArrays/src/linalg.jl +++ b/stdlib/SparseArrays/src/linalg.jl @@ -203,6 +203,37 @@ function spmatmul(A::SparseMatrixCSC{Tv,Ti}, B::SparseMatrixCSC{Tv,Ti}; return C end +# Frobenius inner product: trace(A'B) +function vecdot(A::SparseMatrixCSC{T1,S1},B::SparseMatrixCSC{T2,S2}) where {T1,T2,S1,S2} + m, n = size(A) + size(B) == (m,n) || throw(DimensionMismatch("matrices must have the same dimensions")) + r = vecdot(zero(T1), zero(T2)) + @inbounds for j = 1:n + ia = A.colptr[j]; ia_nxt = A.colptr[j+1] + ib = B.colptr[j]; ib_nxt = B.colptr[j+1] + if ia < ia_nxt && ib < ib_nxt + ra = A.rowval[ia]; rb = B.rowval[ib] + while true + if ra < rb + ia += oneunit(S1) + ia < ia_nxt || break + ra = A.rowval[ia] + elseif ra > rb + ib += oneunit(S2) + ib < ib_nxt || break + rb = B.rowval[ib] + else # ra == rb + r += vecdot(A.nzval[ia], B.nzval[ib]) + ia += oneunit(S1); ib += oneunit(S2) + ia < ia_nxt && ib < ib_nxt || break + ra = A.rowval[ia]; rb = B.rowval[ib] + end + end + end + end + return r +end + ## solvers function fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat) # forward substitution for CSC matrices diff --git a/stdlib/SparseArrays/test/sparse.jl b/stdlib/SparseArrays/test/sparse.jl index f98ff20776f3c..909ce848e7cc9 100644 --- a/stdlib/SparseArrays/test/sparse.jl +++ b/stdlib/SparseArrays/test/sparse.jl @@ -327,6 +327,15 @@ end end end +@testset "sparse Frobenius inner product" begin + for i = 1:5 + A = sprand(ComplexF64,10,15,0.4) + B = sprand(ComplexF64,10,15,0.5) + @test vecdot(A,B) ≈ vecdot(Matrix(A),Matrix(B)) + end + @test_throws DimensionMismatch vecdot(sprand(5,5,0.2),sprand(5,6,0.2)) +end + sA = sprandn(3, 7, 0.5) sC = similar(sA) dA = Array(sA)