From 5970046488266a36d1da96e4b0a553c163ca3f7f Mon Sep 17 00:00:00 2001
From: "Christopher M. Pierce"
Date: Mon, 27 Apr 2020 19:37:36 -0400
Subject: [PATCH] adds wrappers for syevjBatched/heevjBatched family of CUSOLVER functions

---
 src/solver/wrappers.jl | 77 ++++++++++++++++++++++++++++++++++++++++++
 test/solver.jl         | 54 +++++++++++++++++++++++++++++
 2 files changed, 131 insertions(+)

diff --git a/src/solver/wrappers.jl b/src/solver/wrappers.jl
index d482f309..84f4ad9c 100644
--- a/src/solver/wrappers.jl
+++ b/src/solver/wrappers.jl
@@ -823,3 +823,80 @@ for (jname, bname, fname, elty, relty) in ((:sygvj!, :cusolverDnSsygvj_bufferSiz
         end
     end
 end
+
+# Wrappers for the batched Jacobi eigensolvers (cusolverDn{S,D}syevjBatched and
+# cusolverDn{C,Z}heevjBatched). The third dimension of `A` indexes the batch: one
+# solver call processes all `size(A, 3)` matrices.
+for (jname, bname, fname, elty, relty) in ((:syevjBatched!, :cusolverDnSsyevjBatched_bufferSize, :cusolverDnSsyevjBatched, :Float32, :Float32),
+                                           (:syevjBatched!, :cusolverDnDsyevjBatched_bufferSize, :cusolverDnDsyevjBatched, :Float64, :Float64),
+                                           (:heevjBatched!, :cusolverDnCheevjBatched_bufferSize, :cusolverDnCheevjBatched, :ComplexF32, :Float32),
+                                           (:heevjBatched!, :cusolverDnZheevjBatched_bufferSize, :cusolverDnZheevjBatched, :ComplexF64, :Float64)
+                                           )
+    @eval begin
+        # jobz       - 'N' returns only the eigenvalues, 'V' also returns `A`
+        # uplo       - triangle selector, translated by `cublasfill`
+        # A          - batch of square matrices, modified in place
+        # tol        - forwarded to `cusolverDnXsyevjSetTolerance`
+        # max_sweeps - forwarded to `cusolverDnXsyevjSetMaxSweeps`
+        # Returns `W` (n × batchSize eigenvalues) for jobz == 'N' and `(W, A)` for
+        # jobz == 'V'; throws ArgumentError when a solver reports a negative status.
+        function $jname(jobz::Char,
+                        uplo::Char,
+                        A::CuArray{$elty};
+                        tol::$relty=eps($relty),
+                        max_sweeps::Int=100)
+
+            # Set up information for the solver arguments
+            cuuplo = cublasfill(uplo)
+            cujobz = cusolverjob(jobz)
+            n = checksquare(A)
+            lda = max(1, stride(A, 2))
+            batchSize = size(A,3)
+            W = CuArray{$relty}(undef, n,batchSize)
+            params = Ref{syevjInfo_t}(C_NULL)
+            devinfo = CuArray{Cint}(undef, batchSize)
+
+            # Create the solver parameter object. It has to be destroyed again, so
+            # everything that uses it runs under try/finally.
+            cusolverDnCreateSyevjInfo(params)
+            info = try
+                cusolverDnXsyevjSetTolerance(params[], tol)
+                cusolverDnXsyevjSetMaxSweeps(params[], max_sweeps)
+
+                # Calculate the workspace size. The handle itself (`params[]`) is
+                # passed, matching every other call that takes the parameter object.
+                lwork = @argout(CUSOLVER.$bname(dense_handle(), cujobz, cuuplo, n,
+                                A, lda, W, out(Ref{Cint}(0)), params[], batchSize))[]
+
+                # Run the solver
+                @workspace eltyp=$elty size=lwork work->begin
+                    $fname(dense_handle(), cujobz, cuuplo, n, A, lda, W, work,
+                           lwork, devinfo, params[], batchSize)
+                end
+
+                # Copy the per-matrix solver status back to the host
+                @allowscalar collect(devinfo)
+            finally
+                # Release the parameter object and the device status buffer
+                cusolverDnDestroySyevjInfo(params[])
+                unsafe_free!(devinfo)
+            end
+
+            # Double check the solver's exit status: -k flags the k-th argument
+            # NOTE(review): positive status values are not inspected here - confirm
+            # whether they should be reported to the caller
+            for i = 1:batchSize
+                if info[i] < 0
+                    throw(ArgumentError("The $(-info[i])th parameter of the $(i)th solver is wrong"))
+                end
+            end
+
+            # Return eigenvalues (in W) and possibly eigenvectors (in A)
+            if jobz == 'N'
+                return W
+            elseif jobz == 'V'
+                return W, A
+            end
+        end
+    end
+end
diff --git a/test/solver.jl b/test/solver.jl
index 698ec07d..96b8c056 100644
--- a/test/solver.jl
+++ b/test/solver.jl
@@ -279,6 +279,60 @@ k = 1
         @test Eig.values ≈ h_W
     end
+    @testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64]
+        @testset "syevjBatched!" begin
+            # Generate a random symmetric/hermitian matrix
+            A = rand(elty, m,m,n)
+            A += permutedims(A, (2,1,3))
+            d_A = CuArray(A)
+
+            # Run the solver
+            local d_W, d_V
+            if( elty <: Complex )
+                d_W, d_V = CUSOLVER.heevjBatched!('V','U', d_A)
+            else
+                d_W, d_V = CUSOLVER.syevjBatched!('V','U', d_A)
+            end
+
+            # Pull it back to hardware
+            h_W = collect(d_W)
+            h_V = collect(d_V)
+
+            # Use non-GPU blas to estimate the eigenvalues as well
+            for i = 1:n
+                # Get our eigenvalues
+                Eig = eigen(LinearAlgebra.Hermitian(A[:,:,i]))
+
+                # Compare to the actual ones
+                @test Eig.values ≈ h_W[:,i]
+                @test abs.(Eig.vectors'*h_V[:,:,i]) ≈ I
+            end
+
+            # Do it all again, but with the option to not compute eigenvectors
+            d_A = CuArray(A)
+
+            # Run the solver
+            local d_W
+            if( elty <: Complex )
+                d_W = CUSOLVER.heevjBatched!('N','U', d_A)
+            else
+                d_W = CUSOLVER.syevjBatched!('N','U', d_A)
+            end
+
+            # Pull it back to hardware
+            h_W = collect(d_W)
+
+            # Use non-GPU blas to estimate the eigenvalues as well
+            for i = 1:n
+                # Get the reference results
+                Eig = eigen(LinearAlgebra.Hermitian(A[:,:,i]))
+
+                # Compare to the actual ones
+                @test Eig.values ≈ h_W[:,i]
+            end
+        end
+    end
+
+
     @testset "svd with $method method" for method in (CUSOLVER.QRAlgorithm, CUSOLVER.JacobiAlgorithm), (_m, _n) in ((m, n), (n, m))