Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Permalink
Try #695:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Apr 29, 2020
2 parents 2258a24 + 5970046 commit cae2a33
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 0 deletions.
58 changes: 58 additions & 0 deletions src/solver/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -823,3 +823,61 @@ for (jname, bname, fname, elty, relty) in ((:sygvj!, :cusolverDnSsygvj_bufferSiz
end
end
end

for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBatched_bufferSize, :cusolverDnSsyevjBatched, :Float32, :Float32),
(:syevjBatched!, :cusolverDnDsyevjBatched_bufferSize, :cusolverDnDsyevjBatched, :Float64, :Float64),
(:heevjBatched!, :cusolverDnCheevjBatched_bufferSize, :cusolverDnCheevjBatched, :ComplexF32, :Float32),
(:heevjBatched!, :cusolverDnZheevjBatched_bufferSize, :cusolverDnZheevjBatched, :ComplexF64, :Float64)
)
@eval begin
function $jname(jobz::Char,
uplo::Char,
A::CuArray{$elty};
tol::$relty=eps($relty),
max_sweeps::Int=100)

# Set up information for the solver arguments
cuuplo = cublasfill(uplo)
cujobz = cusolverjob(jobz)
n = checksquare(A)
lda = max(1, stride(A, 2))
batchSize = size(A,3)
W = CuArray{$relty}(undef, n,batchSize)
params = Ref{syevjInfo_t}(C_NULL)
devinfo = CuArray{Cint}(undef, batchSize)

# Initialize the solver parameters
cusolverDnCreateSyevjInfo(params)
cusolverDnXsyevjSetTolerance(params[], tol)
cusolverDnXsyevjSetMaxSweeps(params[], max_sweeps)

# Calculate the workspace size
lwork = @argout(CUSOLVER.$bname(dense_handle(), cujobz, cuuplo, n,
A, lda, W, out(Ref{Cint}(0)), params, batchSize))[]

# Run the solver
@workspace eltyp=$elty size=lwork work->begin
$fname(dense_handle(), cujobz, cuuplo, n, A, lda, W, work,
lwork, devinfo, params[], batchSize)
end

# Copy the solver info and delete the device memory
info = @allowscalar collect(devinfo)
unsafe_free!(devinfo)

# Double check the solver's exit status
for i = 1:batchSize
if info[i] < 0
throw(ArgumentError("The $(info)th parameter of the $(i)th solver is wrong"))
end
end

# Return eigenvalues (in W) and possibly eigenvectors (in A)
if jobz == 'N'
return W
elseif jobz == 'V'
return W, A
end
end
end
end
54 changes: 54 additions & 0 deletions test/solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,60 @@ k = 1
@test Eig.values h_W
end

@testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
@testset "syevjBatched!" begin
# Generate a random symmetric/hermitian matrix
A = rand(elty, m,m,n)
A += permutedims(A, (2,1,3))
d_A = CuArray(A)

# Run the solver
local d_W, d_V
if( elty <: Complex )
d_W, d_V = CUSOLVER.heevjBatched!('V','U', d_A)
else
d_W, d_V = CUSOLVER.syevjBatched!('V','U', d_A)
end

# Pull it back to hardware
h_W = collect(d_W)
h_V = collect(d_V)

# Use non-GPU blas to estimate the eigenvalues as well
for i = 1:n
# Get our eigenvalues
Eig = eigen(LinearAlgebra.Hermitian(A[:,:,i]))

# Compare to the actual ones
@test Eig.values h_W[:,i]
@test abs.(Eig.vectors'*h_V[:,:,i]) I
end

# Do it all again, but with the option to not compute eigenvectors
d_A = CuArray(A)

# Run the solver
local d_W
if( elty <: Complex )
d_W = CUSOLVER.heevjBatched!('N','U', d_A)
else
d_W = CUSOLVER.syevjBatched!('N','U', d_A)
end

# Pull it back to hardware
h_W = collect(d_W)

# Use non-GPU blas to estimate the eigenvalues as well
for i = 1:n
# Get the reference results
Eig = eigen(LinearAlgebra.Hermitian(A[:,:,i]))

# Compare to the actual ones
@test Eig.values h_W[:,i]
end
end
end

@testset "svd with $method method" for
method in (CUSOLVER.QRAlgorithm, CUSOLVER.JacobiAlgorithm),
(_m, _n) in ((m, n), (n, m))
Expand Down

0 comments on commit cae2a33

Please sign in to comment.