diff --git a/src/blas/highlevel.jl b/src/blas/highlevel.jl index 14aae87d..33944599 100644 --- a/src/blas/highlevel.jl +++ b/src/blas/highlevel.jl @@ -81,6 +81,50 @@ LinearAlgebra.lmul!(Y::CuVector{T}, A::LinearAlgebra.Transpose{<:Any, CuMatrix{T LinearAlgebra.lmul!(Y::CuVector{T}, A::LinearAlgebra.Adjoint{<:Any, CuMatrix{T}}, B::CuVector{T}) where T<:CublasFloat = gemv_wrapper!(Y, 'T', A.parent, B) LinearAlgebra.lmul!(Y::CuVector{T}, A::LinearAlgebra.Adjoint{<:Any, CuMatrix{T}}, B::CuVector{T}) where T<:CublasComplex = gemv_wrapper!(Y, 'C', A.parent, B) +# TRSV + +function LinearAlgebra.ldiv!( + A::UpperTriangular{T, <:CuMatrix{T}}, + x::CuVector{T}, +) where T<:CublasFloat + return CUBLAS.trsv!('U', 'N', 'N', parent(A), x) +end + +function LinearAlgebra.ldiv!( + A::Adjoint{T, <:UpperTriangular{T, CuMatrix{T}}}, + x::CuVector{T}, +) where {T<:CUBLAS.CublasFloat} + return CUBLAS.trsv!('U', 'C', 'N', parent(parent(A)), x) +end + +function LinearAlgebra.ldiv!( + A::Transpose{T, <:UpperTriangular{T, CuMatrix{T}}}, + x::CuVector{T}, +) where {T<:CUBLAS.CublasFloat} + return CUBLAS.trsv!('U', 'T', 'N', parent(parent(A)), x) +end + +function LinearAlgebra.ldiv!( + A::LowerTriangular{T, <:CuMatrix{T}}, + x::CuVector{T}, +) where T<:CublasFloat + return CUBLAS.trsv!('L', 'N', 'N', parent(A), x) +end + +function LinearAlgebra.ldiv!( + A::Adjoint{T, <:LowerTriangular{T, CuMatrix{T}}}, + x::CuVector{T}, +) where {T<:CUBLAS.CublasFloat} + return CUBLAS.trsv!('L', 'C', 'N', parent(parent(A)), x) +end + +function LinearAlgebra.ldiv!( + A::Transpose{T, <:LowerTriangular{T, CuMatrix{T}}}, + x::CuVector{T}, +) where {T<:CUBLAS.CublasFloat} + return CUBLAS.trsv!('L', 'T', 'N', parent(parent(A)), x) +end + # diff --git a/test/blas.jl b/test/blas.jl index 9c20d7c0..d450e060 100644 --- a/test/blas.jl +++ b/test/blas.jl @@ -303,6 +303,55 @@ end # level 1 testset @test y ≈ h_y end + @testset "ldiv!(::UpperTriangular, ::CuVector)" begin + A = copy(sA) + dA = CuArray(A) + dy = copy(dx) + ldiv!(UpperTriangular(dA), dy) + y = UpperTriangular(A) \ x + @test y ≈ Array(dy) + end + @testset "ldiv!(::AdjointUpperTriangular, ::CuVector)" begin + A = copy(sA) + dA = CuArray(A) + dy = copy(dx) + ldiv!(adjoint(UpperTriangular(dA)), dy) + y = adjoint(UpperTriangular(A)) \ x + @test y ≈ Array(dy) + end + @testset "ldiv!(::TransposeUpperTriangular, ::CuVector)" begin + A = copy(sA) + dA = CuArray(A) + dy = copy(dx) + ldiv!(transpose(UpperTriangular(dA)), dy) + y = transpose(UpperTriangular(A)) \ x + @test y ≈ Array(dy) + end + @testset "ldiv!(::UpperTriangular, ::CuVector)" begin + A = copy(sA) + dA = CuArray(A) + dy = copy(dx) + ldiv!(LowerTriangular(dA), dy) + y = LowerTriangular(A) \ x + @test y ≈ Array(dy) + end + @testset "ldiv!(::AdjointUpperTriangular, ::CuVector)" begin + A = copy(sA) + dA = CuArray(A) + dy = copy(dx) + ldiv!(adjoint(LowerTriangular(dA)), dy) + y = adjoint(LowerTriangular(A)) \ x + @test y ≈ Array(dy) + end + @testset "ldiv!(::TransposeUpperTriangular, ::CuVector)" begin + A = copy(sA) + dA = CuArray(A) + dy = copy(dx) + ldiv!(transpose(LowerTriangular(dA)), dy) + y = transpose(LowerTriangular(A)) \ x + @test y ≈ Array(dy) + end + A = rand(elty,m,m) x = rand(elty,m) y = rand(elty,m)