diff --git a/ext/BatchedRoutinesCUDAExt/factorization.jl b/ext/BatchedRoutinesCUDAExt/factorization.jl index 9a5710d..2400792 100644 --- a/ext/BatchedRoutinesCUDAExt/factorization.jl +++ b/ext/BatchedRoutinesCUDAExt/factorization.jl @@ -79,10 +79,8 @@ function LinearAlgebra.ldiv!(A::CuBatchedQR, b::CuMatrix) @assert nbatches(A) == nbatches(b) (; τ, factors) = A n, m = size(A) .÷ nbatches(A) - # TODO: Threading? for i in 1:nbatches(A) - CUSOLVER.ormqr!('L', 'C', batchview(factors, i), batchview(τ, i), - batchview(b, i)) + CUSOLVER.ormqr!('L', 'C', batchview(factors, i), batchview(τ, i), batchview(b, i)) end vecX = [reshape(view(bᵢ, 1:m), :, 1) for bᵢ in batchview(b)] if n != m @@ -95,8 +93,10 @@ function LinearAlgebra.ldiv!(A::CuBatchedQR, b::CuMatrix) end function LinearAlgebra.ldiv!(X::CuMatrix, A::CuBatchedQR, b::CuMatrix) - copyto!(X, b) - return LinearAlgebra.ldiv!(A, X) + @assert size(X, 1) ≤ size(b, 1) + b_ = LinearAlgebra.ldiv!(A, copy(b)) + copyto!(X, view(b_, 1:size(X, 1), :)) + return X end # Low Level Wrappers diff --git a/ext/BatchedRoutinesForwardDiffExt.jl b/ext/BatchedRoutinesForwardDiffExt.jl index ba85495..0793a10 100644 --- a/ext/BatchedRoutinesForwardDiffExt.jl +++ b/ext/BatchedRoutinesForwardDiffExt.jl @@ -152,7 +152,7 @@ end push!(calls, :(ck = ForwardDiff.Chunk{ForwardDiff.pickchunksize(length(u))}())) else push!(calls, quote - @assert CK ≤ length(u) "Chunk size must be ≤ the length of u" + @assert CK≤length(u) "Chunk size must be ≤ the length of u" ck = ForwardDiff.Chunk{CK}() end) end diff --git a/test/integration_tests.jl b/test/integration_tests.jl index 58a91f9..60c4d60 100644 --- a/test/integration_tests.jl +++ b/test/integration_tests.jl @@ -4,15 +4,25 @@ rng = get_stable_rng(1001) @testset "$mode" for (mode, aType, dev, ongpu) in MODES - A1 = UniformBlockDiagonalMatrix(rand(rng, 32, 32, 8)) |> dev - A2 = Matrix(A1) |> dev - b = rand(rng, size(A1, 2)) |> dev + for dims in ((8, 8, 2), (5, 3, 2)) + A1 = UniformBlockDiagonalMatrix(rand(rng, dims...)) |> dev + A2 = Matrix(A1) |> dev + b = rand(rng, size(A1, 1)) |> dev + + prob1 = LinearProblem(A1, b) + prob2 = LinearProblem(A2, b) - prob1 = LinearProblem(A1, b) - prob2 = LinearProblem(A2, b) + if dims[1] == dims[2] + solvers = [LUFactorization(), QRFactorization(), KrylovJL_GMRES()] + else + solvers = [QRFactorization(), KrylovJL_LSMR()] + end - @test solve(prob1, LUFactorization()).u ≈ solve(prob2, LUFactorization()).u - @test solve(prob1, QRFactorization()).u ≈ solve(prob2, QRFactorization()).u - @test solve(prob1, KrylovJL_GMRES()).u ≈ solve(prob2, KrylovJL_GMRES()).u + @testset "solver: $(solver)" for solver in solvers + x1 = solve(prob1, solver) + x2 = solve(prob2, solver) + @test x1.u ≈ x2.u + end + end end end