Skip to content

Commit

Permalink
Adapt CUSPARSE to changes in triangular wrappers.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Nov 24, 2020
1 parent c2a98ed commit d0ab2ed
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 61 deletions.
143 changes: 107 additions & 36 deletions lib/cusparse/interfaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,24 @@ function mv_wrapper(transa::SparseChar, alpha::Number, A::CuSparseMatrix{T}, X::
mv!(transa, alpha, A, X, beta, Y, 'O')
end

LinearAlgebra.mul!(C::CuVector{T},A::CuSparseMatrix,B::DenseCuVector,alpha::Number,beta::Number) where {T} = mv_wrapper('N',alpha,A,B,beta,C)
LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any,<:CuSparseMatrix},B::DenseCuVector,alpha::Number,beta::Number) where {T} = mv_wrapper('T',alpha,parent(transA),B,beta,C)
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any,<:CuSparseMatrix},B::DenseCuVector,alpha::Number,beta::Number) where {T} = mv_wrapper('C',alpha,parent(adjA),B,beta,C)
LinearAlgebra.mul!(C::CuVector{T},A::HermOrSym{T,<:CuSparseMatrix{T}},B::DenseCuVector{T},alpha::Number,beta::Number) where T = mv_wrapper('N',alpha,A,B,beta,C)
LinearAlgebra.mul!(C::CuVector{T},transA::Transpose{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::DenseCuVector{T},alpha::Number,beta::Number) where {T} = mv_wrapper('T',alpha,parent(transA),B,beta,C)
LinearAlgebra.mul!(C::CuVector{T},adjA::Adjoint{<:Any, <:HermOrSym{T,<:CuSparseMatrix{T}}},B::DenseCuVector{T},alpha::Number,beta::Number) where {T} = mv_wrapper('C',alpha,parent(adjA),B,beta,C)
LinearAlgebra.mul!(C::DenseCuVector{T}, A::CuSparseMatrix,
B::DenseCuVector, alpha::Number, beta::Number) where {T} =
mv_wrapper('N',alpha,A,B,beta,C)
LinearAlgebra.mul!(C::DenseCuVector{T}, A::Transpose{T,<:CuSparseMatrix},
B::DenseCuVector, alpha::Number, beta::Number) where {T} =
mv_wrapper('T',alpha,parent(A),B,beta,C)
LinearAlgebra.mul!(C::DenseCuVector{T}, A::Adjoint{T,<:CuSparseMatrix},
B::DenseCuVector, alpha::Number, beta::Number) where {T} =
mv_wrapper('C',alpha,parent(A),B,beta,C)
LinearAlgebra.mul!(C::DenseCuVector{T}, A::HermOrSym{T,<:CuSparseMatrix},
B::DenseCuVector{T}, alpha::Number, beta::Number) where T =
mv_wrapper('N',alpha,A,B,beta,C)
LinearAlgebra.mul!(C::DenseCuVector{T}, A::Transpose{<:Any, <:HermOrSym{T,<:CuSparseMatrix}},
B::DenseCuVector{T}, alpha::Number, beta::Number) where {T} =
mv_wrapper('T',alpha,parent(A),B,beta,C)
LinearAlgebra.mul!(C::DenseCuVector{T}, A::Adjoint{<:Any, <:HermOrSym{T,<:CuSparseMatrix}},
B::DenseCuVector{T}, alpha::Number, beta::Number) where {T} =
mv_wrapper('C',alpha,parent(A),B,beta,C)

function mm_wrapper(transa::SparseChar, transb::SparseChar, alpha::Number,
A::CuSparseMatrix{T}, B::CuMatrix{T}, beta::Number, C::CuMatrix{T}) where {T}
Expand All @@ -24,33 +36,92 @@ function mm_wrapper(transa::SparseChar, transb::SparseChar, alpha::Number,
mm!(transa, transb, alpha, A, B, beta, C, 'O')
end

LinearAlgebra.mul!(C::CuMatrix{T},A::CuSparseMatrix{T},B::DenseCuMatrix{T},alpha::Number,beta::Number) where {T} = mm_wrapper('N','N',alpha,A,B,beta,C)
LinearAlgebra.mul!(C::CuMatrix{T},A::CuSparseMatrix{T},transB::Transpose{<:Any, <:DenseCuMatrix{T}},alpha::Number,beta::Number) where {T} = mm_wrapper('N','T',alpha,A,parent(transB),beta,C)
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:CuSparseMatrix{T}},B::DenseCuMatrix{T},alpha::Number,beta::Number) where {T} = mm_wrapper('T','N',alpha,parent(transA),B,beta,C)
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:CuSparseMatrix{T}},transB::Transpose{<:Any, <:DenseCuMatrix{T}},alpha::Number,beta::Number) where {T} = mm_wrapper('T','T',alpha,parent(transA),parent(transB),beta,C)
LinearAlgebra.mul!(C::CuMatrix{T},adjA::Adjoint{<:Any, <:CuSparseMatrix{T}},B::DenseCuMatrix{T},alpha::Number,beta::Number) where {T} = mm_wrapper('C','N',alpha,parent(adjA),B,beta,C)

LinearAlgebra.mul!(C::CuMatrix{T},A::HermOrSym{<:Number, <:CuSparseMatrix},B::DenseCuMatrix,alpha::Number,beta::Number) where {T} = mm_wrapper('N',alpha,A,B,beta,C)
LinearAlgebra.mul!(C::CuMatrix{T},transA::Transpose{<:Any, <:HermOrSym{<:Number, <:CuSparseMatrix}},B::DenseCuMatrix,alpha::Number,beta::Number) where {T} = mm_wrapper('T',alpha,parent(transA),B,beta,C)
LinearAlgebra.mul!(C::CuMatrix{T},adjA::Adjoint{<:Any, <:HermOrSym{<:Number, <:CuSparseMatrix}},B::DenseCuMatrix,alpha::Number,beta::Number) where {T} = mm_wrapper('C',alpha,parent(adjA),B,beta,C)

Base.:(\)(A::Union{UpperTriangular{T, S},LowerTriangular{T, S}}, B::DenseCuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('N',A,B,'O')
Base.:(\)(transA::Transpose{T, UpperTriangular{T, S}}, B::DenseCuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('T',parent(transA),B,'O')
Base.:(\)(transA::Transpose{T, LowerTriangular{T, S}}, B::DenseCuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('T',parent(transA),B,'O')
Base.:(\)(adjA::Adjoint{T, UpperTriangular{T, S}},B::DenseCuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('C',parent(adjA),B,'O')
Base.:(\)(adjA::Adjoint{T, LowerTriangular{T, S}},B::DenseCuMatrix{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sm('C',parent(adjA),B,'O')

Base.:(\)(A::Union{UpperTriangular{T, S},LowerTriangular{T, S}}, B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('N',A,B,'O')
Base.:(\)(transA::Transpose{T, UpperTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('T',parent(transA),B,'O')
Base.:(\)(transA::Transpose{T, LowerTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('T',parent(transA),B,'O')
Base.:(\)(adjA::Adjoint{T, UpperTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('C',parent(adjA),B,'O')
Base.:(\)(adjA::Adjoint{T, LowerTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('C',parent(adjA),B,'O')

Base.:(\)(A::Union{UnitUpperTriangular{T, S},UnitLowerTriangular{T, S}}, B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('N',A,B,'O',unit_diag=true)
Base.:(\)(transA::Transpose{T, UnitUpperTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('T',parent(transA),B,'O',unit_diag=true)
Base.:(\)(transA::Transpose{T, UnitLowerTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('T',parent(transA),B,'O',unit_diag=true)
Base.:(\)(adjA::Adjoint{T, UnitUpperTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('C',parent(adjA),B,'O',unit_diag=true)
Base.:(\)(adjA::Adjoint{T, UnitLowerTriangular{T, S}},B::DenseCuVector{T}) where {T<:BlasFloat, S<:AbstractCuSparseMatrix{T}} = sv2('C',parent(adjA),B,'O',unit_diag=true)

Base.:(+)(A::Union{CuSparseMatrixCSR,CuSparseMatrixCSC},B::Union{CuSparseMatrixCSR,CuSparseMatrixCSC}) = geam(A,B,'O','O','O')
Base.:(-)(A::Union{CuSparseMatrixCSR,CuSparseMatrixCSC},B::Union{CuSparseMatrixCSR,CuSparseMatrixCSC}) = geam(A,-one(eltype(A)),B,'O','O','O')
LinearAlgebra.mul!(C::DenseCuMatrix{T}, A::CuSparseMatrix{T},
B::DenseCuMatrix{T}, alpha::Number, beta::Number) where {T} =
mm_wrapper('N','N',alpha,A,B,beta,C)
LinearAlgebra.mul!(C::DenseCuMatrix{T}, A::CuSparseMatrix{T},
B::Transpose{T, <:DenseCuMatrix}, alpha::Number, beta::Number) where {T} =
mm_wrapper('N','T',alpha,A,parent(B),beta,C)
LinearAlgebra.mul!(C::DenseCuMatrix{T}, A::Transpose{<:Any, <:CuSparseMatrix{T}},
B::DenseCuMatrix{T}, alpha::Number, beta::Number) where {T} =
mm_wrapper('T','N',alpha,parent(A),B,beta,C)
LinearAlgebra.mul!(C::DenseCuMatrix{T}, A::Transpose{<:Any, <:CuSparseMatrix{T}},
B::Transpose{<:Any, <:DenseCuMatrix{T}}, alpha::Number, beta::Number) where {T} =
mm_wrapper('T','T',alpha,parent(A),parent(B),beta,C)
LinearAlgebra.mul!(C::DenseCuMatrix{T}, A::Adjoint{<:Any, <:CuSparseMatrix{T}},
B::DenseCuMatrix{T}, alpha::Number, beta::Number) where {T} =
mm_wrapper('C','N',alpha,parent(A),B,beta,C)

LinearAlgebra.mul!(C::DenseCuMatrix{T}, A::HermOrSym{<:Number, <:CuSparseMatrix},
B::DenseCuMatrix, alpha::Number, beta::Number) where {T} =
mm_wrapper('N',alpha,A,B,beta,C)
LinearAlgebra.mul!(C::DenseCuMatrix{T}, A::Transpose{<:Any, <:HermOrSym{<:Number, <:CuSparseMatrix}},
B::DenseCuMatrix, alpha::Number, beta::Number) where {T} =
mm_wrapper('T',alpha,parent(A),B,beta,C)
LinearAlgebra.mul!(C::DenseCuMatrix{T}, A::Adjoint{<:Any, <:HermOrSym{<:Number, <:CuSparseMatrix}},
B::DenseCuMatrix, alpha::Number, beta::Number) where {T} =
mm_wrapper('C',alpha,parent(A),B,beta,C)

Base.:(\)(A::UpperTriangular{T, <:AbstractCuSparseMatrix{T}},
B::DenseCuMatrix{T}) where {T<:BlasFloat} =
sm('N',A,B,'O')
Base.:(\)(A::LowerTriangular{T, <:AbstractCuSparseMatrix{T}},
B::DenseCuMatrix{T}) where {T<:BlasFloat} =
sm('N',A,B,'O')
Base.:(\)(A::UpperTriangular{<:Any, <:Transpose{T, <:AbstractCuSparseMatrix{T}}},
B::DenseCuMatrix{T}) where {T<:BlasFloat} =
sm('T',parent(A),B,'O')
Base.:(\)(A::LowerTriangular{<:Any, <:Transpose{T, <:AbstractCuSparseMatrix{T}}},
B::DenseCuMatrix{T}) where {T<:BlasFloat} =
sm('T',parent(A),B,'O')
Base.:(\)(A::UpperTriangular{<:Any, <:Adjoint{T, <:AbstractCuSparseMatrix{T}}},
B::DenseCuMatrix{T}) where {T<:BlasFloat} =
sm('C',parent(A),B,'O')
Base.:(\)(A::LowerTriangular{<:Any, <:Adjoint{T, <:AbstractCuSparseMatrix{T}}},
B::DenseCuMatrix{T}) where {T<:BlasFloat} =
sm('C',parent(A),B,'O')

# TODO: some metaprogramming to reduce the amount of definitions here

Base.:(\)(A::UpperTriangular{T, <:AbstractCuSparseMatrix{T}},
B::DenseCuVector{T}) where {T<:BlasFloat} =
sv2('N', 'U', one(T), parent(A), B,'O')
Base.:(\)(A::LowerTriangular{T, <:AbstractCuSparseMatrix{T}},
B::DenseCuVector{T}) where {T<:BlasFloat} =
sv2('N', 'L', one(T), parent(A), B,'O')
Base.:(\)(A::UpperTriangular{<:Any, <:Transpose{T, <:AbstractCuSparseMatrix{T}}},
B::DenseCuVector{T}) where {T<:BlasFloat} =
sv2('T', 'L', one(T), parent(parent(A)), B, 'O')
Base.:(\)(A::LowerTriangular{<:Any, <:Transpose{T, <:AbstractCuSparseMatrix{T}}},
B::DenseCuVector{T}) where {T<:BlasFloat} =
sv2('T', 'U', one(T), parent(parent(A)), B, 'O')
Base.:(\)(A::UpperTriangular{<:Any, <:Adjoint{T, <:AbstractCuSparseMatrix{T}}},
B::DenseCuVector{T}) where {T<:BlasFloat} =
sv2('C', 'L', one(T), parent(parent(A)), B, 'O')
Base.:(\)(A::LowerTriangular{<:Any, <:Adjoint{T, <:AbstractCuSparseMatrix{T}}},
B::DenseCuVector{T}) where {T<:BlasFloat} =
sv2('C', 'U', one(T), parent(parent(A)), B, 'O')

Base.:(\)(A::UnitUpperTriangular{T, <:AbstractCuSparseMatrix{T}},
B::DenseCuVector{T}) where {T<:BlasFloat} =
sv2('N', 'U', one(T), parent(A), B, 'O', unit_diag=true)
Base.:(\)(A::UnitLowerTriangular{T, <:AbstractCuSparseMatrix{T}},
B::DenseCuVector{T}) where {T<:BlasFloat} =
sv2('N', 'L', one(T), parent(A), B, 'O', unit_diag=true)
Base.:(\)(A::UnitUpperTriangular{<:Any, <:Transpose{T, <:AbstractCuSparseMatrix{T}}},
B::DenseCuVector{T}) where {T<:BlasFloat} =
sv2('T', 'L', one(T), parent(parent(A)), B, 'O', unit_diag=true)
Base.:(\)(A::UnitLowerTriangular{<:Any, <:Transpose{T, <:AbstractCuSparseMatrix{T}}},
B::DenseCuVector{T}) where {T<:BlasFloat} =
sv2('T', 'U', one(T), parent(parent(A)), B, 'O', unit_diag=true)
Base.:(\)(A::UnitUpperTriangular{<:Any, <:Adjoint{T, <:AbstractCuSparseMatrix{T}}},
B::DenseCuVector{T}) where {T<:BlasFloat} =
sv2('C', 'L', one(T), parent(parent(A)), B, 'O', unit_diag=true)
Base.:(\)(A::UnitLowerTriangular{<:Any, <:Adjoint{T, <:AbstractCuSparseMatrix{T}}},
B::DenseCuVector{T}) where {T<:BlasFloat} =
sv2('C', 'U', one(T), parent(parent(A)), B, 'O', unit_diag=true)

Base.:(+)(A::Union{CuSparseMatrixCSR,CuSparseMatrixCSC},
B::Union{CuSparseMatrixCSR,CuSparseMatrixCSC}) = geam(A,B,'O','O','O')
Base.:(-)(A::Union{CuSparseMatrixCSR,CuSparseMatrixCSC},
B::Union{CuSparseMatrixCSR,CuSparseMatrixCSC}) = geam(A,-one(eltype(A)),B,'O','O','O')
25 changes: 0 additions & 25 deletions lib/cusparse/level2.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
# sparse linear algebra functions that perform operations between sparse matrices and dense
# vectors

using LinearAlgebra: AbstractTriangular

export sv2!, sv2, sv

"""
Expand Down Expand Up @@ -121,29 +119,6 @@ for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
unit_diag::Bool=false)
sv2!(transa,uplo,one($elty),A,copy(X),index,unit_diag=unit_diag)
end
function sv2(transa::SparseChar,
alpha::Number,
A::AbstractTriangular,
X::CuVector{$elty},
index::SparseChar;
unit_diag::Bool=false)
uplo = 'U'
if typeof(A) <: Union{LowerTriangular, UnitLowerTriangular}
uplo = 'L'
end
sv2!(transa,uplo,alpha,A.data,copy(X),index,unit_diag=unit_diag)
end
function sv2(transa::SparseChar,
A::AbstractTriangular,
X::CuVector{$elty},
index::SparseChar;
unit_diag::Bool=false)
uplo = 'U'
if typeof(A) <: Union{LowerTriangular, UnitLowerTriangular}
uplo = 'L'
end
sv2!(transa,uplo,one($elty),A.data,copy(X),index,unit_diag=unit_diag)
end
end
end

Expand Down

0 comments on commit d0ab2ed

Please sign in to comment.