diff --git a/Project.toml b/Project.toml
index b69ab6d..85ec644 100644
--- a/Project.toml
+++ b/Project.toml
@@ -20,6 +20,7 @@ SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
@@ -27,6 +28,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 BatchedRoutinesCUDAExt = ["CUDA"]
 BatchedRoutinesFiniteDiffExt = ["FiniteDiff"]
 BatchedRoutinesForwardDiffExt = ["ForwardDiff"]
+BatchedRoutinesLinearSolveExt = ["LinearSolve"]
 BatchedRoutinesReverseDiffExt = ["ReverseDiff"]
 BatchedRoutinesZygoteExt = ["Zygote"]
 
diff --git a/ext/BatchedRoutinesLinearSolveExt.jl b/ext/BatchedRoutinesLinearSolveExt.jl
new file mode 100644
index 0000000..a808279
--- /dev/null
+++ b/ext/BatchedRoutinesLinearSolveExt.jl
@@ -0,0 +1,136 @@
+module BatchedRoutinesLinearSolveExt
+
+using ArrayInterface: ArrayInterface
+using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, getdata
+using LinearAlgebra: LinearAlgebra
+using LinearSolve: LinearSolve
+
+# Overload LinearProblem to always construct `LinearProblem{true}` for this operator;
+# without the overload the adjoint code runs into problems.
+function LinearSolve.LinearProblem(op::UniformBlockDiagonalOperator, b, args...; kwargs...)
+    return LinearSolve.LinearProblem{true}(op, b, args...; kwargs...)
+end
+
+# Default Algorithm
+# Square operators use LU. Otherwise the choice follows `assump.condition`: normal
+# Cholesky, then (pivoted when underdetermined) QR, then SVD as conditioning worsens.
+function LinearSolve.defaultalg(
+        op::UniformBlockDiagonalOperator, b, assump::LinearSolve.OperatorAssumptions{Bool})
+    alg = if assump.issq
+        LinearSolve.DefaultAlgorithmChoice.LUFactorization
+    elseif assump.condition === LinearSolve.OperatorCondition.WellConditioned
+        LinearSolve.DefaultAlgorithmChoice.NormalCholeskyFactorization
+    elseif assump.condition === LinearSolve.OperatorCondition.IllConditioned
+        if LinearSolve.is_underdetermined(op)
+            LinearSolve.DefaultAlgorithmChoice.QRFactorizationPivoted
+        else
+            LinearSolve.DefaultAlgorithmChoice.QRFactorization
+        end
+    elseif assump.condition === LinearSolve.OperatorCondition.VeryIllConditioned
+        if LinearSolve.is_underdetermined(op)
+            LinearSolve.DefaultAlgorithmChoice.QRFactorizationPivoted
+        else
+            LinearSolve.DefaultAlgorithmChoice.QRFactorization
+        end
+    elseif assump.condition === LinearSolve.OperatorCondition.SuperIllConditioned
+        LinearSolve.DefaultAlgorithmChoice.SVDFactorization
+    else
+        error("Special factorization not handled in current default algorithm.")
+    end
+    return LinearSolve.DefaultLinearSolver(alg)
+end
+
+# Each `init_cacheval` method below builds its cache value from an empty (0, 0, 1)
+# placeholder instead of `A`; `A` itself is factorized in `do_factorization`/`solve!`.
+
+# GenericLUFactorization
+function LinearSolve.init_cacheval(alg::LinearSolve.GenericLUFactorization,
+        A::UniformBlockDiagonalOperator, b, u, Pl, Pr, maxiters::Int,
+        abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
+    A_ = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
+    return LinearAlgebra.generic_lufact!(A_, alg.pivot; check=false)
+end
+
+function LinearSolve.do_factorization(
+        alg::LinearSolve.GenericLUFactorization, A::UniformBlockDiagonalOperator, b, u)
+    return LinearAlgebra.generic_lufact!(A, alg.pivot; check=false)
+end
+
+# LUFactorization
+function LinearSolve.init_cacheval(
+        alg::LinearSolve.LUFactorization, A::UniformBlockDiagonalOperator, b, u, Pl, Pr,
+        maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
+    A_ = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
+    return LinearAlgebra.lu!(A_, alg.pivot; check=false)
+end
+
+function LinearSolve.do_factorization(
+        alg::LinearSolve.LUFactorization, A::UniformBlockDiagonalOperator, b, u)
+    return LinearAlgebra.lu!(A, alg.pivot; check=false)
+end
+
+# QRFactorization
+function LinearSolve.init_cacheval(
+        alg::LinearSolve.QRFactorization, A::UniformBlockDiagonalOperator, b, u, Pl, Pr,
+        maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
+    A_ = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
+    return LinearAlgebra.qr!(A_, alg.pivot)
+end
+
+function LinearSolve.do_factorization(
+        alg::LinearSolve.QRFactorization, A::UniformBlockDiagonalOperator, b, u)
+    alg.inplace && return LinearAlgebra.qr!(A, alg.pivot)
+    return LinearAlgebra.qr(A, alg.pivot)
+end
+
+# CholeskyFactorization
+function LinearSolve.init_cacheval(alg::LinearSolve.CholeskyFactorization,
+        A::UniformBlockDiagonalOperator, b, u, Pl, Pr, maxiters::Int,
+        abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
+    A_ = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
+    return ArrayInterface.cholesky_instance(A_, alg.pivot)
+end
+
+function LinearSolve.do_factorization(
+        alg::LinearSolve.CholeskyFactorization, A::UniformBlockDiagonalOperator, b, u)
+    return LinearAlgebra.cholesky!(A, alg.pivot; check=false)
+end
+
+# NormalCholeskyFactorization
+function LinearSolve.init_cacheval(alg::LinearSolve.NormalCholeskyFactorization,
+        A::UniformBlockDiagonalOperator, b, u, Pl, Pr, maxiters::Int,
+        abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
+    A_ = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
+    return ArrayInterface.cholesky_instance(A_, alg.pivot)
+end
+
+# Solve via the normal equations `A'A x = A'b`; `A'A` is only re-factorized while
+# `cache.isfresh` is set.
+function LinearSolve.solve!(cache::LinearSolve.LinearCache{<:UniformBlockDiagonalOperator},
+        alg::LinearSolve.NormalCholeskyFactorization; kwargs...)
+    A = cache.A
+    if cache.isfresh
+        fact = LinearAlgebra.cholesky!(A' * A, alg.pivot; check=false)
+        cache.cacheval = fact
+        cache.isfresh = false
+    end
+    y = LinearAlgebra.ldiv!(
+        cache.u, LinearSolve.@get_cacheval(cache, :NormalCholeskyFactorization),
+        A' * cache.b)
+    return LinearSolve.SciMLBase.build_linear_solution(alg, y, nothing, cache)
+end
+
+# SVDFactorization
+function LinearSolve.init_cacheval(
+        alg::LinearSolve.SVDFactorization, A::UniformBlockDiagonalOperator, b, u, Pl, Pr,
+        maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
+    A_ = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
+    return ArrayInterface.svd_instance(A_)
+end
+
+function LinearSolve.do_factorization(
+        alg::LinearSolve.SVDFactorization, A::UniformBlockDiagonalOperator, b, u)
+    return LinearAlgebra.svd!(A; alg.full, alg.alg)
+end
+
+end
diff --git a/src/chainrules.jl b/src/chainrules.jl
index 15ac2c6..44f5e56 100644
--- a/src/chainrules.jl
+++ b/src/chainrules.jl
@@ -86,16 +86,12 @@ end
 # batched_mul rrule
 function CRC.rrule(::typeof(_batched_mul), A::AbstractArray{T1, 3},
         B::AbstractArray{T2, 3}) where {T1, T2}
-    function ∇batched_mul(_Δ)
+    ∇batched_mul = @closure _Δ -> begin
         Δ = CRC.unthunk(_Δ)
-        ∂A = CRC.@thunk begin
-            tmp = batched_mul(Δ, batched_adjoint(B))
-            size(A, 3) == 1 ? sum(tmp; dims=3) : tmp
-        end
-        ∂B = CRC.@thunk begin
-            tmp = batched_mul(batched_adjoint(A), Δ)
-            size(B, 3) == 1 ? sum(tmp; dims=3) : tmp
-        end
+        tmpA = batched_mul(Δ, batched_adjoint(B))
+        ∂A = size(A, 3) == 1 ? sum(tmpA; dims=3) : tmpA
+        tmpB = batched_mul(batched_adjoint(A), Δ)
+        ∂B = size(B, 3) == 1 ? sum(tmpB; dims=3) : tmpB
         return (NoTangent(), ∂A, ∂B)
     end
     return batched_mul(A, B), ∇batched_mul
@@ -103,7 +99,7 @@ end
 
 # constructor
 function CRC.rrule(::Type{<:UniformBlockDiagonalOperator}, data)
-    function ∇UniformBlockDiagonalOperator(Δ)
+    ∇UniformBlockDiagonalOperator = @closure Δ -> begin
         ∂data = Δ isa UniformBlockDiagonalOperator ? getdata(Δ) :
                 (Δ isa NoTangent ? NoTangent() : Δ)
         return (NoTangent(), ∂data)
@@ -113,7 +109,7 @@ end
 
 function CRC.rrule(::typeof(getproperty), op::UniformBlockDiagonalOperator, x::Symbol)
     @assert x === :data
-    ∇getproperty(Δ) = (NoTangent(), UniformBlockDiagonalOperator(Δ))
+    ∇getproperty = @closure Δ -> (NoTangent(), UniformBlockDiagonalOperator(Δ))
     return op.data, ∇getproperty
 end
 
diff --git a/src/factorization.jl b/src/factorization.jl
index 8727a2a..bc0561c 100644
--- a/src/factorization.jl
+++ b/src/factorization.jl
@@ -82,18 +82,20 @@ function Base.show(io::IO, mime::MIME"text/plain", F::GenericBatchedFactorizatio
     show(io, mime, first(F.fact))
 end
 
-for fact in (:qr, :lu, :cholesky)
+for fact in (:qr, :lu, :cholesky, :generic_lufact, :svd)
     fact! = Symbol(fact, :!)
-    @eval begin
-        function LinearAlgebra.$(fact)(op::UniformBlockDiagonalOperator, args...; kwargs...)
+    # No out-of-place wrapper if LinearAlgebra lacks the function (e.g. `generic_lufact`).
+    if isdefined(LinearAlgebra, fact)
+        @eval function LinearAlgebra.$(fact)(
+                op::UniformBlockDiagonalOperator, args...; kwargs...)
             return LinearAlgebra.$(fact!)(copy(op), args...; kwargs...)
         end
+    end
 
-        function LinearAlgebra.$(fact!)(
-                op::UniformBlockDiagonalOperator, args...; kwargs...)
-            fact = map(Aᵢ -> LinearAlgebra.$(fact!)(Aᵢ, args...; kwargs...), batchview(op))
-            return GenericBatchedFactorization(LinearAlgebra.$(fact!), fact)
-        end
+    @eval function LinearAlgebra.$(fact!)(
+            op::UniformBlockDiagonalOperator, args...; kwargs...)
+        fact = map(Aᵢ -> LinearAlgebra.$(fact!)(Aᵢ, args...; kwargs...), batchview(op))
+        return GenericBatchedFactorization(LinearAlgebra.$(fact!), fact)
     end
 end
 
diff --git a/test/integration_tests.jl b/test/integration_tests.jl
index c9a0ec0..b539ae0 100644
--- a/test/integration_tests.jl
+++ b/test/integration_tests.jl
@@ -1,5 +1,5 @@
 @testitem "LinearSolve" setup=[SharedTestSetup] begin
-    using FiniteDiff, LinearSolve, Zygote
+    using FiniteDiff, LinearAlgebra, LinearSolve, Zygote
 
     rng = get_stable_rng(1001)
 
@@ -13,12 +13,25 @@
             prob2 = LinearProblem(A2, b)
 
             if dims[1] == dims[2]
-                solvers = [LUFactorization(), QRFactorization(), KrylovJL_GMRES()]
+                solvers = [LUFactorization(), QRFactorization(),
+                    KrylovJL_GMRES(), svd_factorization(mode), nothing]
             else
-                solvers = [QRFactorization(), KrylovJL_LSMR()]
+                solvers = [
+                    QRFactorization(), KrylovJL_LSMR(), NormalCholeskyFactorization(),
+                    QRFactorization(LinearAlgebra.ColumnNorm()),
+                    svd_factorization(mode), nothing]
             end
 
-            @testset "solver: $(solver)" for solver in solvers
+            @testset "solver: $(nameof(typeof(solver)))" for solver in solvers
+                # FIXME: SVD doesn't define ldiv on CUDA side
+                if mode == "CUDA"
+                    if solver isa SVDFactorization || (solver isa QRFactorization &&
+                        solver.pivot isa LinearAlgebra.ColumnNorm)
+                        # ColumnNorm is not implemented on CUDA
+                        continue
+                    end
+                end
+
                 x1 = solve(prob1, solver)
                 x2 = solve(prob2, solver)
                 @test x1.u ≈ x2.u
diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl
index c3417fc..d69bfc7 100644
--- a/test/shared_testsetup.jl
+++ b/test/shared_testsetup.jl
@@ -32,7 +32,18 @@ end
 
 get_stable_rng(seed=12345) = StableRNG(seed)
 
+# SVD Helper till https://github.com/SciML/LinearSolve.jl/issues/488 is resolved
+using LinearSolve: LinearSolve
+
+function svd_factorization(mode)
+    mode == "CPU" && return LinearSolve.SVDFactorization()
+    mode == "CUDA" &&
+        return LinearSolve.SVDFactorization(true, CUDA.CUSOLVER.JacobiAlgorithm())
+    error("Unsupported mode: $mode")
+end
+
 export @jet, @test_gradients, check_approx
 export GROUP, MODES, cpu_testing, cuda_testing, get_default_rng, get_stable_rng
+export svd_factorization
 
 end