diff --git a/lib/cusparse/interfaces.jl b/lib/cusparse/interfaces.jl index f5727bbef5..68cecd155a 100644 --- a/lib/cusparse/interfaces.jl +++ b/lib/cusparse/interfaces.jl @@ -8,12 +8,24 @@ function mv_wrapper(transa::SparseChar, alpha::Number, A::CuSparseMatrix{T}, X:: mv!(transa, alpha, A, X, beta, Y, 'O') end -LinearAlgebra.mul!(C::CuVector{T},A::CuSparseMatrix,B::DenseCuVector,alpha::Number,beta::Number) where {T} = mv_wrapper('N',alpha,A,B,beta,C) -LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any,<:CuSparseMatrix},B::DenseCuVector,alpha::Number,beta::Number) where {T} = mv_wrapper('T',alpha,parent(transA),B,beta,C) -LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any,<:CuSparseMatrix},B::DenseCuVector,alpha::Number,beta::Number) where {T} = mv_wrapper('C',alpha,parent(adjA),B,beta,C) -LinearAlgebra.mul!(C::CuVector{T},A::HermOrSym{T,<:CuSparseMatrix{T}},B::DenseCuVector{T},alpha::Number,beta::Number) where T = mv_wrapper('N',alpha,A,B,beta,C) -LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::DenseCuVector{T},alpha::Number,beta::Number) where {T} = mv_wrapper('T',alpha,parent(transA),B,beta,C) -LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::DenseCuVector{T},alpha::Number,beta::Number) where {T} = mv_wrapper('C',alpha,parent(adjA),B,beta,C) +LinearAlgebra.mul!(C::DenseCuVector{T}, A::CuSparseMatrix, + B::DenseCuVector, alpha::Number, beta::Number) where {T} = + mv_wrapper('N',alpha,A,B,beta,C) +LinearAlgebra.mul!(C::DenseCuVector{T}, A::Transpose{T,<:CuSparseMatrix}, + B::DenseCuVector, alpha::Number, beta::Number) where {T} = + mv_wrapper('T',alpha,parent(A),B,beta,C) +LinearAlgebra.mul!(C::DenseCuVector{T}, A::Adjoint{T,<:CuSparseMatrix}, + B::DenseCuVector, alpha::Number, beta::Number) where {T} = + mv_wrapper('C',alpha,parent(A),B,beta,C) +LinearAlgebra.mul!(C::DenseCuVector{T}, A::HermOrSym{T,<:CuSparseMatrix}, + B::DenseCuVector{T}, alpha::Number, beta::Number) where T = + mv_wrapper('N',alpha,A,B,beta,C) +LinearAlgebra.mul!(C::DenseCuVector{T}, A::Transpose{<:Any, <:HermOrSym{T,<:CuSparseMatrix}}, + B::DenseCuVector{T}, alpha::Number, beta::Number) where {T} = + mv_wrapper('T',alpha,parent(A),B,beta,C) +LinearAlgebra.mul!(C::DenseCuVector{T}, A::Adjoint{<:Any, <:HermOrSym{T,<:CuSparseMatrix}}, + B::DenseCuVector{T}, alpha::Number, beta::Number) where {T} = + mv_wrapper('C',alpha,parent(A),B,beta,C) function mm_wrapper(transa::SparseChar, transb::SparseChar, alpha::Number, A::CuSparseMatrix{T}, B::CuMatrix{T}, beta::Number, C::CuMatrix{T}) where {T} @@ -24,33 +36,92 @@ function mm_wrapper(transa::SparseChar, transb::SparseChar, alpha::Number, mm!(transa, transb, alpha, A, B, beta, C, 'O') end -LinearAlgebra.mul!(C::CuMatrix{T},A::CuSparseMatrix{T},B::DenseCuMatrix{T},alpha::Number,beta::Number) where {T} = mm_wrapper('N','N',alpha,A,B,beta,C) -LinearAlgebra.mul!(C::CuMatrix{T},A::CuSparseMatrix{T},transB::Transpose{<:Any, <:DenseCuMatrix{T}},alpha::Number,beta::Number) where {T} = mm_wrapper('N','T',alpha,A,parent(transB),beta,C) -LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:CuSparseMatrix{T}},B::DenseCuMatrix{T},alpha::Number,beta::Number) where {T} = mm_wrapper('T','N',alpha,parent(transA),B,beta,C) -LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:CuSparseMatrix{T}},transB::Transpose{<:Any, <:DenseCuMatrix{T}},alpha::Number,beta::Number) where {T} = mm_wrapper('T','T',alpha,parent(transA),parent(transB),beta,C) -LinearAlgebra.mul!(C::CuMatrix{T},adjA::Adjoint{<:Any, <:CuSparseMatrix{T}},B::DenseCuMatrix{T},alpha::Number,beta::Number) where {T} = mm_wrapper('C','N',alpha,parent(adjA),B,beta,C) - -LinearAlgebra.mul!(C::CuMatrix{T},A::HermOrSym{<:Number, <:CuSparseMatrix},B::DenseCuMatrix,alpha::Number,beta::Number) where {T} = mm_wrapper('N',alpha,A,B,beta,C) -LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:HermOrSym{<:Number, <:CuSparseMatrix}},B::DenseCuMatrix,alpha::Number,beta::Number) where {T} = mm_wrapper('T',alpha,parent(transA),B,beta,C) -LinearAlgebra.mul!(C::CuMatrix{T},adjA::Adjoint{<:Any, <:HermOrSym{<:Number, <:CuSparseMatrix}},B::DenseCuMatrix,alpha::Number,beta::Number) where {T} = mm_wrapper('C',alpha,parent(adjA),B,beta,C) - -Base.:(\)(A::Union{UpperTriangular{T, S},LowerTriangular{T, S}}, B::DenseCuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('N',A,B,'O') -Base.:(\)(transA::Transpose{T, UpperTriangular{T, S}}, B::DenseCuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('T',parent(transA),B,'O') -Base.:(\)(transA::Transpose{T, LowerTriangular{T, S}}, B::DenseCuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('T',parent(transA),B,'O') -Base.:(\)(adjA::Adjoint{T, UpperTriangular{T, S}},B::DenseCuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('C',parent(adjA),B,'O') -Base.:(\)(adjA::Adjoint{T, LowerTriangular{T, S}},B::DenseCuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('C',parent(adjA),B,'O') - -Base.:(\)(A::Union{UpperTriangular{T, S},LowerTriangular{T, S}}, B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('N',A,B,'O') -Base.:(\)(transA::Transpose{T, UpperTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('T',parent(transA),B,'O') -Base.:(\)(transA::Transpose{T, LowerTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('T',parent(transA),B,'O') -Base.:(\)(adjA::Adjoint{T, UpperTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('C',parent(adjA),B,'O') -Base.:(\)(adjA::Adjoint{T, LowerTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('C',parent(adjA),B,'O') - -Base.:(\)(A::Union{UnitUpperTriangular{T, S},UnitLowerTriangular{T, S}}, B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('N',A,B,'O',unit_diag=true) -Base.:(\)(transA::Transpose{T, UnitUpperTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('T',parent(transA),B,'O',unit_diag=true) -Base.:(\)(transA::Transpose{T, UnitLowerTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('T',parent(transA),B,'O',unit_diag=true) -Base.:(\)(adjA::Adjoint{T, UnitUpperTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('C',parent(adjA),B,'O',unit_diag=true) -Base.:(\)(adjA::Adjoint{T, UnitLowerTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('C',parent(adjA),B,'O',unit_diag=true) - -Base.:(+)(A::Union{CuSparseMatrixCSR,CuSparseMatrixCSC},B::Union{CuSparseMatrixCSR,CuSparseMatrixCSC}) = geam(A,B,'O','O','O') -Base.:(-)(A::Union{CuSparseMatrixCSR,CuSparseMatrixCSC},B::Union{CuSparseMatrixCSR,CuSparseMatrixCSC}) = geam(A,-one(eltype(A)),B,'O','O','O') +LinearAlgebra.mul!(C::DenseCuMatrix{T}, A::CuSparseMatrix{T}, + B::DenseCuMatrix{T}, alpha::Number, beta::Number) where {T} = + mm_wrapper('N','N',alpha,A,B,beta,C) +LinearAlgebra.mul!(C::DenseCuMatrix{T}, A::CuSparseMatrix{T}, + B::Transpose{T, <:DenseCuMatrix}, alpha::Number, beta::Number) where {T} = + mm_wrapper('N','T',alpha,A,parent(B),beta,C) +LinearAlgebra.mul!(C::DenseCuMatrix{T}, A::Transpose{<:Any, <:CuSparseMatrix{T}}, + B::DenseCuMatrix{T}, alpha::Number, beta::Number) where {T} = + mm_wrapper('T','N',alpha,parent(A),B,beta,C) +LinearAlgebra.mul!(C::DenseCuMatrix{T}, A::Transpose{<:Any, <:CuSparseMatrix{T}}, + B::Transpose{<:Any, <:DenseCuMatrix{T}}, alpha::Number, beta::Number) where {T} = + mm_wrapper('T','T',alpha,parent(A),parent(B),beta,C) +LinearAlgebra.mul!(C::DenseCuMatrix{T}, A::Adjoint{<:Any, <:CuSparseMatrix{T}}, + B::DenseCuMatrix{T}, alpha::Number, beta::Number) where {T} = + mm_wrapper('C','N',alpha,parent(A),B,beta,C) + +LinearAlgebra.mul!(C::DenseCuMatrix{T}, A::HermOrSym{<:Number, <:CuSparseMatrix}, + B::DenseCuMatrix, alpha::Number, beta::Number) where {T} = + mm_wrapper('N',alpha,A,B,beta,C) +LinearAlgebra.mul!(C::DenseCuMatrix{T}, A::Transpose{<:Any, <:HermOrSym{<:Number, <:CuSparseMatrix}}, + B::DenseCuMatrix, alpha::Number, beta::Number) where {T} = + mm_wrapper('T',alpha,parent(A),B,beta,C) +LinearAlgebra.mul!(C::DenseCuMatrix{T}, A::Adjoint{<:Any, <:HermOrSym{<:Number, <:CuSparseMatrix}}, + B::DenseCuMatrix, alpha::Number, beta::Number) where {T} = + mm_wrapper('C',alpha,parent(A),B,beta,C) + +Base.:(\)(A::UpperTriangular{T, <:AbstractCuSparseMatrix{T}}, + B::DenseCuMatrix{T}) where {T<:BlasFloat} = + sm('N',A,B,'O') +Base.:(\)(A::LowerTriangular{T, <:AbstractCuSparseMatrix{T}}, + B::DenseCuMatrix{T}) where {T<:BlasFloat} = + sm('N',A,B,'O') +Base.:(\)(A::UpperTriangular{<:Any, <:Transpose{T, <:AbstractCuSparseMatrix{T}}}, + B::DenseCuMatrix{T}) where {T<:BlasFloat} = + sm('T',parent(A),B,'O') +Base.:(\)(A::LowerTriangular{<:Any, <:Transpose{T, <:AbstractCuSparseMatrix{T}}}, + B::DenseCuMatrix{T}) where {T<:BlasFloat} = + sm('T',parent(A),B,'O') +Base.:(\)(A::UpperTriangular{<:Any, <:Adjoint{T, <:AbstractCuSparseMatrix{T}}}, + B::DenseCuMatrix{T}) where {T<:BlasFloat} = + sm('C',parent(A),B,'O') +Base.:(\)(A::LowerTriangular{<:Any, <:Adjoint{T, <:AbstractCuSparseMatrix{T}}}, + B::DenseCuMatrix{T}) where {T<:BlasFloat} = + sm('C',parent(A),B,'O') + +# TODO: some metaprogramming to reduce the amount of definitions here + +Base.:(\)(A::UpperTriangular{T, <:AbstractCuSparseMatrix{T}}, + B::DenseCuVector{T}) where {T<:BlasFloat} = + sv2('N', 'U', one(T), parent(A), B,'O') +Base.:(\)(A::LowerTriangular{T, <:AbstractCuSparseMatrix{T}}, + B::DenseCuVector{T}) where {T<:BlasFloat} = + sv2('N', 'L', one(T), parent(A), B,'O') +Base.:(\)(A::UpperTriangular{<:Any, <:Transpose{T, <:AbstractCuSparseMatrix{T}}}, + B::DenseCuVector{T}) where {T<:BlasFloat} = + sv2('T', 'L', one(T), parent(parent(A)), B, 'O') +Base.:(\)(A::LowerTriangular{<:Any, <:Transpose{T, <:AbstractCuSparseMatrix{T}}}, + B::DenseCuVector{T}) where {T<:BlasFloat} = + sv2('T', 'U', one(T), parent(parent(A)), B, 'O') +Base.:(\)(A::UpperTriangular{<:Any, <:Adjoint{T, <:AbstractCuSparseMatrix{T}}}, + B::DenseCuVector{T}) where {T<:BlasFloat} = + sv2('C', 'L', one(T), parent(parent(A)), B, 'O') +Base.:(\)(A::LowerTriangular{<:Any, <:Adjoint{T, <:AbstractCuSparseMatrix{T}}}, + B::DenseCuVector{T}) where {T<:BlasFloat} = + sv2('C', 'U', one(T), parent(parent(A)), B, 'O') + +Base.:(\)(A::UnitUpperTriangular{T, <:AbstractCuSparseMatrix{T}}, + B::DenseCuVector{T}) where {T<:BlasFloat} = + sv2('N', 'U', one(T), parent(A), B, 'O', unit_diag=true) +Base.:(\)(A::UnitLowerTriangular{T, <:AbstractCuSparseMatrix{T}}, + B::DenseCuVector{T}) where {T<:BlasFloat} = + sv2('N', 'L', one(T), parent(A), B, 'O', unit_diag=true) +Base.:(\)(A::UnitUpperTriangular{<:Any, <:Transpose{T, <:AbstractCuSparseMatrix{T}}}, + B::DenseCuVector{T}) where {T<:BlasFloat} = + sv2('T', 'L', one(T), parent(parent(A)), B, 'O', unit_diag=true) +Base.:(\)(A::UnitLowerTriangular{<:Any, <:Transpose{T, <:AbstractCuSparseMatrix{T}}}, + B::DenseCuVector{T}) where {T<:BlasFloat} = + sv2('T', 'U', one(T), parent(parent(A)), B, 'O', unit_diag=true) +Base.:(\)(A::UnitUpperTriangular{<:Any, <:Adjoint{T, <:AbstractCuSparseMatrix{T}}}, + B::DenseCuVector{T}) where {T<:BlasFloat} = + sv2('C', 'L', one(T), parent(parent(A)), B, 'O', unit_diag=true) +Base.:(\)(A::UnitLowerTriangular{<:Any, <:Adjoint{T, <:AbstractCuSparseMatrix{T}}}, + B::DenseCuVector{T}) where {T<:BlasFloat} = + sv2('C', 'U', one(T), parent(parent(A)), B, 'O', unit_diag=true) + +Base.:(+)(A::Union{CuSparseMatrixCSR,CuSparseMatrixCSC}, + B::Union{CuSparseMatrixCSR,CuSparseMatrixCSC}) = geam(A,B,'O','O','O') +Base.:(-)(A::Union{CuSparseMatrixCSR,CuSparseMatrixCSC}, + B::Union{CuSparseMatrixCSR,CuSparseMatrixCSC}) = geam(A,-one(eltype(A)),B,'O','O','O') diff --git a/lib/cusparse/level2.jl b/lib/cusparse/level2.jl index 85b1eeb2ca..8fd8d8233e 100644 --- a/lib/cusparse/level2.jl +++ b/lib/cusparse/level2.jl @@ -1,8 +1,6 @@ # sparse linear algebra functions that perform operations between sparse matrices and dense # vectors -using LinearAlgebra: AbstractTriangular - export sv2!, sv2, sv """ @@ -121,29 +119,6 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64) unit_diag::Bool=false) sv2!(transa,uplo,one($elty),A,copy(X),index,unit_diag=unit_diag) end - function sv2(transa::SparseChar, - alpha::Number, - A::AbstractTriangular, - X::CuVector{$elty}, - index::SparseChar; - unit_diag::Bool=false) - uplo = 'U' - if typeof(A) <: Union{LowerTriangular, UnitLowerTriangular} - uplo = 'L' - end - sv2!(transa,uplo,alpha,A.data,copy(X),index,unit_diag=unit_diag) - end - function sv2(transa::SparseChar, - A::AbstractTriangular, - X::CuVector{$elty}, - index::SparseChar; - unit_diag::Bool=false) - uplo = 'U' - if typeof(A) <: Union{LowerTriangular, UnitLowerTriangular} - uplo = 'L' - end - sv2!(transa,uplo,one($elty),A.data,copy(X),index,unit_diag=unit_diag) - end end end