
Add integration tests with LinearSolve.jl
avik-pal committed Mar 15, 2024
1 parent 16513d2 commit 56f3cc5
Showing 8 changed files with 81 additions and 46 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -43,6 +43,7 @@ FillArrays = "1.9.3"
 FiniteDiff = "2.22"
 ForwardDiff = "0.10.36"
 LinearAlgebra = "1.10"
+LinearSolve = "2.27"
 LuxCUDA = "0.3.2"
 LuxDeviceUtils = "0.1.17"
 LuxTestUtils = "0.1.15"
38 changes: 0 additions & 38 deletions examples/linear_solve.jl

This file was deleted.

1 change: 0 additions & 1 deletion examples/nonlinear_solve.jl

This file was deleted.

3 changes: 2 additions & 1 deletion ext/BatchedRoutinesCUDAExt/BatchedRoutinesCUDAExt.jl
@@ -2,7 +2,8 @@ module BatchedRoutinesCUDAExt
 
 using BatchedRoutines: AbstractBatchedMatrixFactorization, BatchedRoutines,
     UniformBlockDiagonalMatrix, batchview, nbatches
-using CUDA: CUBLAS, CUDA, CuArray, CuMatrix, CuPtr, CuVector, DenseCuArray, DenseCuMatrix
+using CUDA: CUBLAS, CUDA, CUSOLVER, CuArray, CuMatrix, CuPtr, CuVector, DenseCuArray,
+    DenseCuMatrix
 using ConcreteStructs: @concrete
 using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero, mul!
52 changes: 52 additions & 0 deletions ext/BatchedRoutinesCUDAExt/factorization.jl
@@ -1,3 +1,4 @@
+# LU Factorization
 @concrete struct CuBatchedLU{T} <: AbstractBatchedMatrixFactorization
     factors
     pivot_array
@@ -47,6 +48,57 @@ function LinearAlgebra.ldiv!(X::CuMatrix, A::CuBatchedLU, b::CuMatrix)
     return LinearAlgebra.ldiv!(A, X)
 end
 
+# QR Factorization
+@concrete struct CuBatchedQR{T} <: AbstractBatchedMatrixFactorization
+    factors
+    τ
+    size
+end
+
+BatchedRoutines.nbatches(QR::CuBatchedQR) = length(QR.factors)
+BatchedRoutines.batchview(QR::CuBatchedQR) = zip(QR.factors, QR.τ)
+BatchedRoutines.batchview(QR::CuBatchedQR, idx::Int) = QR.factors[idx], QR.τ[idx]
+Base.size(QR::CuBatchedQR) = QR.size
+Base.size(QR::CuBatchedQR, i::Integer) = QR.size[i]
+Base.eltype(::CuBatchedQR{T}) where {T} = T
+
+function Base.show(io::IO, QR::CuBatchedQR)
+    return print(io, "CuBatchedQR() with Batch Count: $(nbatches(QR))")
+end
+
+function LinearAlgebra.qr!(::CuUniformBlockDiagonalMatrix, ::ColumnNorm; kwargs...)
+    throw(ArgumentError("ColumnNorm is not supported for batched CUDA QR factorization!"))
+end
+
+function LinearAlgebra.qr!(A::CuUniformBlockDiagonalMatrix, ::NoPivot; kwargs...)
+    τ, factors = CUBLAS.geqrf_batched!(collect(batchview(A)))
+    return CuBatchedQR{eltype(A)}(factors, τ, size(A))
+end
+
+function LinearAlgebra.ldiv!(A::CuBatchedQR, b::CuMatrix)
+    @assert nbatches(A) == nbatches(b)
+    (; τ, factors) = A
+    n, m = size(A) .÷ nbatches(A)
+    # TODO: Threading?
+    for i in 1:nbatches(A)
+        CUSOLVER.ormqr!('L', 'C', batchview(factors, i), batchview(τ, i),
+            batchview(b, i))
+    end
+    vecX = [reshape(view(bᵢ, 1:m), :, 1) for bᵢ in batchview(b)]
+    if n != m
+        sqF = [F_[1:m, 1:m] for F_ in batchview(factors)]
+    else
+        sqF = collect(batchview(factors))
+    end
+    CUBLAS.trsm_batched!('L', 'U', 'N', 'N', one(eltype(A)), sqF, vecX)
+    return b
+end
+
+function LinearAlgebra.ldiv!(X::CuMatrix, A::CuBatchedQR, b::CuMatrix)
+    copyto!(X, b)
+    return LinearAlgebra.ldiv!(A, X)
+end
+
 # Low Level Wrappers
 for (fname, elty) in ((:cublasDgetrsBatched, :Float64), (:cublasSgetrsBatched, :Float32),
     (:cublasZgetrsBatched, :ComplexF64), (:cublasCgetrsBatched, :ComplexF32))
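
The QR path added above forms each batch's factorization with CUBLAS.geqrf_batched!, applies Qᵀ per batch via CUSOLVER.ormqr!, and finishes with a single CUBLAS.trsm_batched! triangular solve. A minimal usage sketch of these methods follows — the sizes are hypothetical, and it assumes a working CUDA device and that UniformBlockDiagonalMatrix can wrap a 3-D CuArray directly:

using CUDA, LinearAlgebra
using BatchedRoutines

A = UniformBlockDiagonalMatrix(CUDA.rand(Float32, 8, 8, 4))  # 4 blocks, each 8×8
b = CUDA.rand(Float32, 8, 4)   # one right-hand-side column per batch
F = qr!(A, NoPivot())          # CuBatchedQR; mutates A via geqrf_batched!
X = similar(b)
ldiv!(X, F, b)                 # per-batch ormqr! plus one trsm_batched! call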
10 changes: 7 additions & 3 deletions src/api.jl
@@ -89,11 +89,15 @@ Return a view of the `idx`-th batch of `A`. If `idx` is not supplied an iterator
 batches is returned.
 """
 batchview(A::AbstractArray, idx::Int) = selectdim(A, ndims(A), idx)
-function batchview(A::AbstractVector, idx::Int)
-    return idx ≥ 2 && throw(BoundsError(batchview(A), idx))
+function batchview(A::AbstractVector{T}, idx::Int) where {T}
+    if isbitstype(T)
+        idx ≥ 2 && throw(BoundsError(batchview(A), idx))
+        return A
+    end
+    return A[idx]
 end
 batchview(A::AbstractArray) = eachslice(A; dims=ndims(A))
-batchview(A::AbstractVector) = (A,)
+batchview(A::AbstractVector{T}) where {T} = isbitstype(T) ? (A,) : A
 
 """
     batched_pinv(A::AbstractArray{T, 3}) where {T}
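
A short sketch of the resulting batchview semantics — plain CPU arrays, assuming batchview is exported as it is used elsewhere in the package:

using BatchedRoutines

x = rand(3, 3, 4)
batchview(x, 2)               # view of the second 3×3 slice

p = rand(5)                   # Float64 is a bits type ⇒ the vector is one batch
batchview(p, 1) === p         # true; batchview(p, 2) throws a BoundsError

ps = [rand(5) for _ in 1:4]   # Vector{Float64} elements are not bits types
batchview(ps, 3) === ps[3]    # true: each element is its own batch
batchview(ps) === ps          # true, instead of the old single-batch (ps,)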
4 changes: 2 additions & 2 deletions test/autodiff_tests.jl
@@ -3,7 +3,7 @@
 
 rng = get_stable_rng(1001)
 
-@testset "$mode" for (mode, aType, device, ongpu) in MODES
+@testset "$mode" for (mode, aType, dev, ongpu) in MODES
     simple_batched_function = function (X, p)
         X_ = reshape(X, :, nbatches(X))
         return sum(abs2, X_ .* p; dims=1) .- sum(abs, X_ .* p; dims=1) .+ p .^ 2
@@ -30,7 +30,7 @@ end
 
 rng = get_stable_rng(1001)
 
-@testset "$mode" for (mode, aType, device, ongpu) in MODES
+@testset "$mode" for (mode, aType, dev, ongpu) in MODES
     simple_batched_function = function (X, p)
         X_ = reshape(X, :, nbatches(X))
         return sum(
18 changes: 17 additions & 1 deletion test/integration_tests.jl
@@ -1,2 +1,18 @@
-@testitem "Linear Solve" setup=[SharedTestSetup] begin
+@testitem "LinearSolve" setup=[SharedTestSetup] begin
+    using LinearSolve
+
+    rng = get_stable_rng(1001)
+
+    @testset "$mode" for (mode, aType, dev, ongpu) in MODES
+        A1 = UniformBlockDiagonalMatrix(rand(rng, 32, 32, 8)) |> dev
+        A2 = Matrix(A1) |> dev
+        b = rand(rng, size(A1, 2)) |> dev
+
+        prob1 = LinearProblem(A1, b)
+        prob2 = LinearProblem(A2, b)
+
+        @test solve(prob1, LUFactorization()).u ≈ solve(prob2, LUFactorization()).u
+        @test solve(prob1, QRFactorization()).u ≈ solve(prob2, QRFactorization()).u
+        @test solve(prob1, KrylovJL_GMRES()).u ≈ solve(prob2, KrylovJL_GMRES()).u
+    end
 end
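
The test compares a 256×256 block-diagonal operator (eight 32×32 blocks) against its dense Matrix materialization, so each LinearSolve.jl algorithm must return the same solution through the batched and the dense paths. A hedged sketch of the intended downstream usage — it assumes UniformBlockDiagonalMatrix is exported and that the CPU factorizations dispatch to the batched kernels like the CUDA ones above:

using LinearSolve, BatchedRoutines

A = UniformBlockDiagonalMatrix(rand(32, 32, 8))  # 256×256, stored as 8 blocks
b = rand(size(A, 2))
sol = solve(LinearProblem(A, b), LUFactorization())
X = reshape(sol.u, 32, 8)     # column i solves the i-th 32×32 block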
