Skip to content
This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Permalink
Add rectangular matrix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Mar 15, 2024
1 parent 56f3cc5 commit 5305eed
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 14 deletions.
10 changes: 5 additions & 5 deletions ext/BatchedRoutinesCUDAExt/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,10 +79,8 @@ function LinearAlgebra.ldiv!(A::CuBatchedQR, b::CuMatrix)
@assert nbatches(A) == nbatches(b)
(; τ, factors) = A
n, m = size(A) nbatches(A)
# TODO: Threading?
for i in 1:nbatches(A)
CUSOLVER.ormqr!('L', 'C', batchview(factors, i), batchview(τ, i),
batchview(b, i))
CUSOLVER.ormqr!('L', 'C', batchview(factors, i), batchview(τ, i), batchview(b, i))
end
vecX = [reshape(view(bᵢ, 1:m), :, 1) for bᵢ in batchview(b)]
if n != m
Expand All @@ -95,8 +93,10 @@ function LinearAlgebra.ldiv!(A::CuBatchedQR, b::CuMatrix)
end

function LinearAlgebra.ldiv!(X::CuMatrix, A::CuBatchedQR, b::CuMatrix)
copyto!(X, b)
return LinearAlgebra.ldiv!(A, X)
@assert size(X, 1) size(b, 1)
b_ = LinearAlgebra.ldiv!(A, copy(b))
copyto!(X, view(b_, 1:size(X, 1), :))
return X
end

# Low Level Wrappers
Expand Down
2 changes: 1 addition & 1 deletion ext/BatchedRoutinesForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ end
push!(calls, :(ck = ForwardDiff.Chunk{ForwardDiff.pickchunksize(length(u))}()))
else
push!(calls, quote
@assert CK length(u) "Chunk size must be ≤ the length of u"
@assert CKlength(u) "Chunk size must be ≤ the length of u"
ck = ForwardDiff.Chunk{CK}()
end)
end
Expand Down
26 changes: 18 additions & 8 deletions test/integration_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,25 @@
rng = get_stable_rng(1001)

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
A1 = UniformBlockDiagonalMatrix(rand(rng, 32, 32, 8)) |> dev
A2 = Matrix(A1) |> dev
b = rand(rng, size(A1, 2)) |> dev
for dims in ((8, 8, 2), (5, 3, 2))
A1 = UniformBlockDiagonalMatrix(rand(rng, dims...)) |> dev
A2 = Matrix(A1) |> dev
b = rand(rng, size(A1, 1)) |> dev

prob1 = LinearProblem(A1, b)
prob2 = LinearProblem(A2, b)

prob1 = LinearProblem(A1, b)
prob2 = LinearProblem(A2, b)
if dims[1] == dims[2]
solvers = [LUFactorization(), QRFactorization(), KrylovJL_GMRES()]
else
solvers = [QRFactorization(), KrylovJL_LSMR()]
end

@test solve(prob1, LUFactorization()).u solve(prob2, LUFactorization()).u
@test solve(prob1, QRFactorization()).u solve(prob2, QRFactorization()).u
@test solve(prob1, KrylovJL_GMRES()).u solve(prob2, KrylovJL_GMRES()).u
@testset "solver: $(solver)" for solver in solvers
x1 = solve(prob1, solver)
x2 = solve(prob2, solver)
@test x1.u x2.u
end
end
end
end

0 comments on commit 5305eed

Please sign in to comment.