Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

More tests for BLAS and improve error msg #352

Merged
merged 1 commit into from
Jun 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/blas/highlevel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ function LinearAlgebra.BLAS.dotc(DX::CuArray{T}, DY::CuArray{T}) where T<:Union{
end

function LinearAlgebra.BLAS.dot(DX::CuArray{T}, DY::CuArray{T}) where T<:Union{ComplexF32,ComplexF64}
dotc(DX, DY)
BLAS.dotc(DX, DY)
end

function LinearAlgebra.BLAS.dotu(DX::CuArray{T}, DY::CuArray{T}) where T<:Union{ComplexF32,ComplexF64}
Expand All @@ -43,7 +43,7 @@ LinearAlgebra.norm(x::CublasArray) = nrm2(x)
LinearAlgebra.BLAS.asum(x::CublasArray) = asum(length(x), x, 1)

function LinearAlgebra.axpy!(alpha::Number, x::CuArray{T}, y::CuArray{T}) where T<:CublasFloat
length(x)==length(y) || throw(DimensionMismatch(""))
length(x)==length(y) || throw(DimensionMismatch("axpy arguments have lengths $(length(x)) and $(length(y))"))
axpy!(length(x), convert(T,alpha), x, 1, y, 1)
end

Expand Down
21 changes: 21 additions & 0 deletions test/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ CUBLAS.cublasSetMathMode(CUBLAS.CUBLAS_DEFAULT_MATH)
if T <: Real
@test testf(argmin, rand(T, m))
@test testf(argmax, rand(T, m))
else
@test testf(BLAS.dotu, rand(T, m), rand(T, m))
x = rand(T, m)
y = rand(T, m)
dx = CuArray(x)
dy = CuArray(y)
dz = BLAS.dot(dx, dy)
z = BLAS.dotc(x, y)
@test dz ≈ z
end
end # level 1 testset

Expand All @@ -50,6 +59,16 @@ end # level 1 testset
@test testf(*, rand(elty, m, n), rand(elty, n))
@test testf(*, transpose(rand(elty, m, n)), rand(elty, m))
@test testf(*, rand(elty, m, n)', rand(elty, m))
x = rand(elty, m)
A = rand(elty, m, m + 1 )
y = rand(elty, m)
dx = CuArray(x)
dA = CuArray(A)
dy = CuArray(y)
@test_throws DimensionMismatch mul!(dy, dA, dx)
A = rand(elty, m + 1, m )
dA = CuArray(A)
@test_throws DimensionMismatch mul!(dy, dA, dx)
end
@testset "banded methods" begin
# bands
Expand Down Expand Up @@ -360,6 +379,8 @@ end # level 1 testset
# compare
@test C1 ≈ h_C1
@test C2 ≈ h_C2
@test_throws ArgumentError mul!(dhA, dhA, dsA)
@test_throws DimensionMismatch mul!(d_C1, d_A, dsA)
end

@testset "gemm" begin
Expand Down