Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Permalink
Merge pull request #383 from willtebbutt/wct/ldiv-vector
Browse files Browse the repository at this point in the history
Add ldiv! and tests
  • Loading branch information
vchuravy authored Aug 6, 2019
2 parents d729323 + fe8fe6b commit 3ab04e2
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 0 deletions.
44 changes: 44 additions & 0 deletions src/blas/highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,50 @@ LinearAlgebra.lmul!(Y::CuVector{T}, A::LinearAlgebra.Transpose{<:Any, CuMatrix{T
LinearAlgebra.lmul!(Y::CuVector{T}, A::LinearAlgebra.Adjoint{<:Any, CuMatrix{T}}, B::CuVector{T}) where T<:CublasFloat = gemv_wrapper!(Y, 'T', A.parent, B)
LinearAlgebra.lmul!(Y::CuVector{T}, A::LinearAlgebra.Adjoint{<:Any, CuMatrix{T}}, B::CuVector{T}) where T<:CublasComplex = gemv_wrapper!(Y, 'C', A.parent, B)

# TRSV

function LinearAlgebra.ldiv!(
A::UpperTriangular{T, <:CuMatrix{T}},
x::CuVector{T},
) where T<:CublasFloat
return CUBLAS.trsv!('U', 'N', 'N', parent(A), x)
end

function LinearAlgebra.ldiv!(
A::Adjoint{T, <:UpperTriangular{T, CuMatrix{T}}},
x::CuVector{T},
) where {T<:CUBLAS.CublasFloat}
return CUBLAS.trsv!('U', 'C', 'N', parent(parent(A)), x)
end

function LinearAlgebra.ldiv!(
A::Transpose{T, <:UpperTriangular{T, CuMatrix{T}}},
x::CuVector{T},
) where {T<:CUBLAS.CublasFloat}
return CUBLAS.trsv!('U', 'T', 'N', parent(parent(A)), x)
end

function LinearAlgebra.ldiv!(
A::LowerTriangular{T, <:CuMatrix{T}},
x::CuVector{T},
) where T<:CublasFloat
return CUBLAS.trsv!('L', 'N', 'N', parent(A), x)
end

function LinearAlgebra.ldiv!(
A::Adjoint{T, <:LowerTriangular{T, CuMatrix{T}}},
x::CuVector{T},
) where {T<:CUBLAS.CublasFloat}
return CUBLAS.trsv!('L', 'C', 'N', parent(parent(A)), x)
end

function LinearAlgebra.ldiv!(
A::Transpose{T, <:LowerTriangular{T, CuMatrix{T}}},
x::CuVector{T},
) where {T<:CUBLAS.CublasFloat}
return CUBLAS.trsv!('L', 'T', 'N', parent(parent(A)), x)
end



#
Expand Down
49 changes: 49 additions & 0 deletions test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,55 @@ end # level 1 testset
@test y h_y
end

@testset "ldiv!(::UpperTriangular, ::CuVector)" begin
A = copy(sA)
dA = CuArray(A)
dy = copy(dx)
ldiv!(UpperTriangular(dA), dy)
y = UpperTriangular(A) \ x
@test y Array(dy)
end
@testset "ldiv!(::AdjointUpperTriangular, ::CuVector)" begin
A = copy(sA)
dA = CuArray(A)
dy = copy(dx)
ldiv!(adjoint(UpperTriangular(dA)), dy)
y = adjoint(UpperTriangular(A)) \ x
@test y Array(dy)
end
@testset "ldiv!(::TransposeUpperTriangular, ::CuVector)" begin
A = copy(sA)
dA = CuArray(A)
dy = copy(dx)
ldiv!(transpose(UpperTriangular(dA)), dy)
y = transpose(UpperTriangular(A)) \ x
@test y Array(dy)
end
@testset "ldiv!(::UpperTriangular, ::CuVector)" begin
A = copy(sA)
dA = CuArray(A)
dy = copy(dx)
ldiv!(LowerTriangular(dA), dy)
y = LowerTriangular(A) \ x
@test y Array(dy)
end
@testset "ldiv!(::AdjointUpperTriangular, ::CuVector)" begin
A = copy(sA)
dA = CuArray(A)
dy = copy(dx)
ldiv!(adjoint(LowerTriangular(dA)), dy)
y = adjoint(LowerTriangular(A)) \ x
@test y Array(dy)
end
@testset "ldiv!(::TransposeUpperTriangular, ::CuVector)" begin
A = copy(sA)
dA = CuArray(A)
dy = copy(dx)
ldiv!(transpose(LowerTriangular(dA)), dy)
y = transpose(LowerTriangular(A)) \ x
@test y Array(dy)
end

A = rand(elty,m,m)
x = rand(elty,m)
y = rand(elty,m)
Expand Down

0 comments on commit 3ab04e2

Please sign in to comment.