diff --git a/Project.toml b/Project.toml
index 85ec644..4a79155 100644
--- a/Project.toml
+++ b/Project.toml
@@ -26,6 +26,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [extensions]
 BatchedRoutinesCUDAExt = ["CUDA"]
+BatchedRoutinesCUDALinearSolveExt = ["CUDA", "LinearSolve"]
 BatchedRoutinesFiniteDiffExt = ["FiniteDiff"]
 BatchedRoutinesForwardDiffExt = ["ForwardDiff"]
 BatchedRoutinesLinearSolveExt = ["LinearSolve"]
diff --git a/ext/BatchedRoutinesCUDAExt/BatchedRoutinesCUDAExt.jl b/ext/BatchedRoutinesCUDAExt/BatchedRoutinesCUDAExt.jl
index 1616870..8f088a3 100644
--- a/ext/BatchedRoutinesCUDAExt/BatchedRoutinesCUDAExt.jl
+++ b/ext/BatchedRoutinesCUDAExt/BatchedRoutinesCUDAExt.jl
@@ -12,6 +12,8 @@ const CuBlasFloat = Union{Float16, Float32, Float64, ComplexF32, ComplexF64}
 const CuUniformBlockDiagonalOperator{T} = UniformBlockDiagonalOperator{
     T, <:CUDA.AnyCuArray{T, 3}}
 
+include("low_level.jl")
+
 include("batched_mul.jl")
 include("factorization.jl")
 
diff --git a/ext/BatchedRoutinesCUDAExt/factorization.jl b/ext/BatchedRoutinesCUDAExt/factorization.jl
index 13d7936..85db0ce 100644
--- a/ext/BatchedRoutinesCUDAExt/factorization.jl
+++ b/ext/BatchedRoutinesCUDAExt/factorization.jl
@@ -120,34 +120,29 @@ function LinearAlgebra.ldiv!(X::CuMatrix, A::CuBatchedQR, b::CuMatrix)
     return X
 end
 
-# Low Level Wrappers
-for (fname, elty) in ((:cublasDgetrsBatched, :Float64), (:cublasSgetrsBatched, :Float32),
-    (:cublasZgetrsBatched, :ComplexF64), (:cublasCgetrsBatched, :ComplexF32))
-    @eval begin
-        function getrs_batched!(trans::Char, n, nrhs, Aptrs::CuVector{CuPtr{$elty}},
-                lda, p, Bptrs::CuVector{CuPtr{$elty}}, ldb)
-            batchSize = length(Aptrs)
-            info = Array{Cint}(undef, batchSize)
-            CUBLAS.$fname(
-                CUBLAS.handle(), trans, n, nrhs, Aptrs, lda, p, Bptrs, ldb, info, batchSize)
-            CUDA.unsafe_free!(Aptrs)
-            CUDA.unsafe_free!(Bptrs)
-            return info
-        end
+# Direct Ldiv
+function BatchedRoutines.__internal_backslash(
+        op::CuUniformBlockDiagonalOperator{T1}, b::AbstractMatrix{T2}) where {T1, T2}
+    T = promote_type(T1, T2)
+    return __internal_backslash(T != T1 ? T.(op) : op, T != T2 ? T.(b) : b)
+end
+
+function BatchedRoutines.__internal_backslash(
+        op::CuUniformBlockDiagonalOperator{T}, b::AbstractMatrix{T}) where {T}
+    size(op, 1) != length(b) && throw(DimensionMismatch("size(op, 1) != length(b)"))
+    x = similar(b, T, size(BatchedRoutines.getdata(op), 2), nbatches(op))
+    m, n = size(op)
+    if n < m # Underdetermined: LQ or QR with ColumnNorm
+        error("Underdetermined systems are not supported yet! Please open an issue if you \
+               care about this feature.")
+    elseif n == m # Square: LU with Pivoting
+        p, _, F = CUBLAS.getrf_strided_batched!(copy(BatchedRoutines.getdata(op)), true)
+        copyto!(x, b)
+        getrs_strided_batched!('N', F, p, x)
+    else # Overdetermined: QR
+        CUBLAS.gels_batched!('N', batchview(copy(BatchedRoutines.getdata(op))),
+            [reshape(bᵢ, :, 1) for bᵢ in batchview(b)])
+        copyto!(x, view(b, 1:n, :))
     end
-end
-
-function getrs_strided_batched!(trans::Char, F::DenseCuArray{<:Any, 3}, p::DenseCuMatrix,
-        B::Union{DenseCuArray{<:Any, 3}, DenseCuMatrix})
-    m, n = size(F, 1), size(F, 2)
-    m != n && throw(DimensionMismatch("All matrices must be square!"))
-    lda = max(1, stride(F, 2))
-    ldb = max(1, stride(B, 2))
-    nrhs = ifelse(ndims(B) == 2, 1, size(B, 2))
-
-    Fptrs = CUBLAS.unsafe_strided_batch(F)
-    Bptrs = CUBLAS.unsafe_strided_batch(B)
-    info = getrs_batched!(trans, n, nrhs, Fptrs, lda, p, Bptrs, ldb)
-
-    return B, info
+    return x
 end
diff --git a/ext/BatchedRoutinesCUDAExt/low_level.jl b/ext/BatchedRoutinesCUDAExt/low_level.jl
new file mode 100644
index 0000000..a858e52
--- /dev/null
+++ b/ext/BatchedRoutinesCUDAExt/low_level.jl
@@ -0,0 +1,31 @@
+# Low Level Wrappers
+for (fname, elty) in ((:cublasDgetrsBatched, :Float64), (:cublasSgetrsBatched, :Float32),
+    (:cublasZgetrsBatched, :ComplexF64), (:cublasCgetrsBatched, :ComplexF32))
+    @eval begin
+        function getrs_batched!(trans::Char, n, nrhs, Aptrs::CuVector{CuPtr{$elty}},
+                lda, p, Bptrs::CuVector{CuPtr{$elty}}, ldb)
+            batchSize = length(Aptrs)
+            info = Array{Cint}(undef, batchSize)
+            CUBLAS.$fname(
+                CUBLAS.handle(), trans, n, nrhs, Aptrs, lda, p, Bptrs, ldb, info, batchSize)
+            CUDA.unsafe_free!(Aptrs)
+            CUDA.unsafe_free!(Bptrs)
+            return info
+        end
+    end
+end
+
+function getrs_strided_batched!(trans::Char, F::DenseCuArray{<:Any, 3}, p::DenseCuMatrix,
+        B::Union{DenseCuArray{<:Any, 3}, DenseCuMatrix})
+    m, n = size(F, 1), size(F, 2)
+    m != n && throw(DimensionMismatch("All matrices must be square!"))
+    lda = max(1, stride(F, 2))
+    ldb = max(1, stride(B, 2))
+    nrhs = ifelse(ndims(B) == 2, 1, size(B, 2))
+
+    Fptrs = CUBLAS.unsafe_strided_batch(F)
+    Bptrs = CUBLAS.unsafe_strided_batch(B)
+    info = getrs_batched!(trans, n, nrhs, Fptrs, lda, p, Bptrs, ldb)
+
+    return B, info
+end
diff --git a/ext/BatchedRoutinesCUDALinearSolveExt.jl b/ext/BatchedRoutinesCUDALinearSolveExt.jl
new file mode 100644
index 0000000..d9da190
--- /dev/null
+++ b/ext/BatchedRoutinesCUDALinearSolveExt.jl
@@ -0,0 +1,24 @@
+module BatchedRoutinesCUDALinearSolveExt
+
+using BatchedRoutines: UniformBlockDiagonalOperator, getdata
+using CUDA: CUDA
+using LinearAlgebra: LinearAlgebra
+using LinearSolve: LinearSolve
+
+const CuUniformBlockDiagonalOperator{T} = UniformBlockDiagonalOperator{
+    T, <:CUDA.AnyCuArray{T, 3}}
+
+function LinearSolve.init_cacheval(
+        alg::LinearSolve.SVDFactorization, A::CuUniformBlockDiagonalOperator, b, u, Pl, Pr,
+        maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
+    return nothing
+end
+
+function LinearSolve.init_cacheval(
+        alg::LinearSolve.QRFactorization, A::CuUniformBlockDiagonalOperator, b, u, Pl, Pr,
+        maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
+    A_ = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
+    return LinearAlgebra.qr!(A_) # ignore the pivot since CUDA doesn't support it
+end
+
+end
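For context, a rough usage sketch of the new GPU direct-ldiv path (hypothetical sizes; assumes CUDA.jl is functional so the extension code above is loaded):

```julia
using BatchedRoutines, CUDA

# three independent 8x8 blocks stored as an 8x8x3 CuArray
op = UniformBlockDiagonalOperator(CUDA.rand(Float32, 8, 8, 3))
b = CUDA.rand(Float32, 8, 3)  # one right-hand-side column per block

# square blocks take the getrf_strided_batched! / getrs_strided_batched! route above;
# overdetermined blocks (more rows than columns) go through gels_batched! instead
x = op \ b                    # 8x3 CuMatrix, one solution per block
```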
diff --git a/ext/BatchedRoutinesLinearSolveExt.jl b/ext/BatchedRoutinesLinearSolveExt.jl
index a808279..42b2808 100644
--- a/ext/BatchedRoutinesLinearSolveExt.jl
+++ b/ext/BatchedRoutinesLinearSolveExt.jl
@@ -2,8 +2,12 @@ module BatchedRoutinesLinearSolveExt
 
 using ArrayInterface: ArrayInterface
 using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, getdata
+using ChainRulesCore: ChainRulesCore, NoTangent
+using FastClosures: @closure
 using LinearAlgebra: LinearAlgebra
-using LinearSolve: LinearSolve
+using LinearSolve: LinearSolve, SciMLBase
+
+const CRC = ChainRulesCore
 
 # Overload LinearProblem, else causing problems in the adjoint code
 function LinearSolve.LinearProblem(op::UniformBlockDiagonalOperator, b, args...; kwargs...)
@@ -125,4 +129,71 @@ function LinearSolve.do_factorization(
     return LinearAlgebra.svd!(A; alg.full, alg.alg)
 end
 
+# We need a custom rrule here to prevent spurious gradients for zero blocks
+# Copied from https://github.com/SciML/LinearSolve.jl/blob/7911113c6b14b6897cc356e277ccd5a98faa7dd7/src/adjoint.jl#L31 except the Lazy Arrays part
+function CRC.rrule(::typeof(SciMLBase.solve),
+        prob::SciMLBase.LinearProblem{T1, T2, <:UniformBlockDiagonalOperator},
+        alg::LinearSolve.SciMLLinearSolveAlgorithm, args...;
+        alias_A=LinearSolve.default_alias_A(alg, prob.A, prob.b), kwargs...) where {T1, T2}
+    cache = SciMLBase.init(prob, alg, args...; kwargs...)
+    (; A, sensealg) = cache
+
+    @assert sensealg isa LinearSolve.LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."
+
+    # Decide if we need to cache `A` and `b` for the reverse pass
+    A_ = A
+    if sensealg.linsolve === missing
+        # We can reuse the factorization so no copy is needed
+        # Krylov Methods don't modify `A`, so it's safe to just reuse it
+        # No copy is needed even for the default case
+        if !(alg isa LinearSolve.AbstractFactorization ||
+             alg isa LinearSolve.AbstractKrylovSubspaceMethod ||
+             alg isa LinearSolve.DefaultLinearSolver)
+            A_ = alias_A ? deepcopy(A) : A
+        end
+    else
+        A_ = deepcopy(A)
+    end
+
+    sol = SciMLBase.solve!(cache)
+
+    proj_A = CRC.ProjectTo(getdata(A))
+    proj_b = CRC.ProjectTo(prob.b)
+
+    ∇linear_solve = @closure ∂sol -> begin
+        ∂u = ∂sol.u
+        if sensealg.linsolve === missing
+            λ = if cache.cacheval isa LinearAlgebra.Factorization
+                cache.cacheval' \ ∂u
+            elseif cache.cacheval isa Tuple &&
+                   cache.cacheval[1] isa LinearAlgebra.Factorization
+                first(cache.cacheval)' \ ∂u
+            elseif alg isa LinearSolve.AbstractKrylovSubspaceMethod
+                invprob = SciMLBase.LinearProblem(transpose(cache.A), ∂u)
+                SciMLBase.solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
+            elseif alg isa LinearSolve.DefaultLinearSolver
+                LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
+            else
+                invprob = SciMLBase.LinearProblem(transpose(A_), ∂u) # We cached `A`
+                SciMLBase.solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
+            end
+        else
+            invprob = SciMLBase.LinearProblem(transpose(A_), ∂u) # We cached `A`
+            λ = SciMLBase.solve(
+                invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
+        end
+
+        uᵀ = reshape(sol.u, 1, :, BatchedRoutines.nbatches(A))
+        ∂A = UniformBlockDiagonalOperator(proj_A(BatchedRoutines.batched_mul(
+            reshape(λ, :, 1, BatchedRoutines.nbatches(A)), -uᵀ)))
+        ∂b = proj_b(λ)
+        ∂prob = SciMLBase.LinearProblem(∂A, ∂b, NoTangent())
+
+        return (
+            NoTangent(), ∂prob, NoTangent(), ntuple(Returns(NoTangent()), length(args))...)
+    end
+
+    return sol, ∇linear_solve
+end
+
 end
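This is the call pattern the rrule above targets; a minimal CPU sketch with illustrative sizes and an illustrative factorization choice, mirroring the integration test further down:

```julia
using BatchedRoutines, LinearSolve, Zygote

op = UniformBlockDiagonalOperator(rand(8, 8, 2))  # two 8x8 blocks
b = rand(size(op, 1))                             # stacked right-hand sides

loss(A, y) = sum(abs2, solve(LinearProblem(A, y), LUFactorization()).u)

# ∂op comes back as a UniformBlockDiagonalOperator, so the zero off-diagonal
# blocks never receive spurious gradients
∂op, ∂b = Zygote.gradient(loss, op, b)
```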
diff --git a/src/BatchedRoutines.jl b/src/BatchedRoutines.jl
index efb9cec..692b71c 100644
--- a/src/BatchedRoutines.jl
+++ b/src/BatchedRoutines.jl
@@ -24,7 +24,7 @@ function __init__()
         printstyled(io, "\nHINT: "; bold=true)
         printstyled(
             io, "`UniformBlockDiagonalOperator` doesn't support AbstractArray \
-                 operations. If you want this supported open an issue at \
+                 operations. If you want this supported, open an issue at \
                  https://github.com/LuxDL/BatchedRoutines.jl to discuss it."; color=:cyan)
     end
diff --git a/src/chainrules.jl b/src/chainrules.jl
index 44f5e56..7b7655c 100644
--- a/src/chainrules.jl
+++ b/src/chainrules.jl
@@ -141,3 +141,20 @@ function CRC.rrule(::typeof(sum), ::typeof(identity), op::UniformBlockDiagonalOp
     end
     return y, ∇sum_abs2
 end
+
+# Direct Ldiv
+function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(\),
+        op::UniformBlockDiagonalOperator, b::AbstractMatrix)
+    # We haven't implemented the rrule for least squares yet, so we directly AD through the code
+    size(op, 1) != size(op, 2) && return CRC.rrule_via_ad(cfg, __internal_backslash, op, b)
+    # TODO: reuse the factorization once `factorize(op)` has been implemented
+    u = op \ b
+    proj_A = CRC.ProjectTo(getdata(op))
+    proj_b = CRC.ProjectTo(b)
+    ∇backslash = @closure ∂u -> begin
+        λ = op' \ ∂u
+        ∂A = -batched_mul(λ, batched_adjoint(reshape(u, :, 1, nbatches(u))))
+        return NoTangent(), UniformBlockDiagonalOperator(proj_A(∂A)), proj_b(λ)
+    end
+    return u, ∇backslash
+end
diff --git a/src/operator.jl b/src/operator.jl
index b632045..034b583 100644
--- a/src/operator.jl
+++ b/src/operator.jl
@@ -196,6 +196,11 @@ function Base.:-(op1::UniformBlockDiagonalOperator, op2::UniformBlockDiagonalOpe
     return UniformBlockDiagonalOperator(getdata(op1) - getdata(op2))
 end
 
+function Base.isapprox(
+        op1::UniformBlockDiagonalOperator, op2::UniformBlockDiagonalOperator; kwargs...)
+    return isapprox(getdata(op1), getdata(op2); kwargs...)
+end
+
 # Adapt
 @inline function Adapt.adapt_structure(to, op::UniformBlockDiagonalOperator)
     return UniformBlockDiagonalOperator(Adapt.adapt(to, getdata(op)))
@@ -269,3 +274,15 @@ function LinearAlgebra.mul!(C::AbstractArray{T1, 3}, A::UniformBlockDiagonalOper
     batched_mul!(C, getdata(A), B)
     return C
 end
+
+# Direct \ operator
+function Base.:\(op::UniformBlockDiagonalOperator, b::AbstractVector)
+    return vec(op \ reshape(b, :, nbatches(op)))
+end
+Base.:\(op::UniformBlockDiagonalOperator, b::AbstractMatrix) = __internal_backslash(op, b)
+
+## This exists to allow a direct autodiff through the code, e.g., for non-square systems
+@inline function __internal_backslash(op::UniformBlockDiagonalOperator, b::AbstractMatrix)
+    size(op, 1) != length(b) && throw(DimensionMismatch("size(op, 1) != length(b)"))
+    return mapfoldl(((Aᵢ, bᵢ),) -> Aᵢ \ bᵢ, hcat, zip(batchview(op), batchview(b)))
+end
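A short sketch of the new `\` entry points on plain CPU arrays (hypothetical sizes): square blocks hit the chainrules.jl rrule above, while non-square blocks fall back to the block-wise mapfoldl path:

```julia
using BatchedRoutines, Zygote

# two overdetermined 5x3 blocks: solved block-wise via the generic fallback
op = UniformBlockDiagonalOperator(rand(5, 3, 2))
b = rand(size(op, 1))   # length 10, stacked per-block right-hand sides
x = op \ b              # length 6, least-squares solution for each block

# two square 4x4 blocks: `\` additionally has the analytic reverse rule
op2 = UniformBlockDiagonalOperator(rand(4, 4, 2))
b2 = rand(size(op2, 1))
∂op2, ∂b2 = Zygote.gradient((A, y) -> sum(abs2, A \ y), op2, b2)
```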
diff --git a/test/integration_tests.jl b/test/integration_tests.jl
index b539ae0..8f6433f 100644
--- a/test/integration_tests.jl
+++ b/test/integration_tests.jl
@@ -1,12 +1,12 @@
 @testitem "LinearSolve" setup=[SharedTestSetup] begin
-    using FiniteDiff, LinearAlgebra, LinearSolve, Zygote
+    using LinearAlgebra, LinearSolve, Zygote
 
     rng = get_stable_rng(1001)
 
     @testset "$mode" for (mode, aType, dev, ongpu) in MODES
         for dims in ((8, 8, 2), (5, 3, 2))
             A1 = UniformBlockDiagonalOperator(rand(rng, dims...)) |> dev
-            A2 = Matrix(A1) |> dev
+            A2 = collect(A1)
             b = rand(rng, size(A1, 1)) |> dev
 
             prob1 = LinearProblem(A1, b)
@@ -22,10 +22,17 @@
                     svd_factorization(mode), nothing]
             end
 
+            if dims[1] == dims[2]
+                test_chainrules_adjoint = (A, b) -> sum(abs2, A \ b)
+
+                ∂A_cr, ∂b_cr = Zygote.gradient(test_chainrules_adjoint, A1, b)
+            else
+                ∂A_cr, ∂b_cr = nothing, nothing
+            end
+
             @testset "solver: $(nameof(typeof(solver)))" for solver in solvers
                 # FIXME: SVD doesn't define ldiv on CUDA side
                 if mode == "CUDA"
-                    @show solver, solver isa SVDFactorization
                     if solver isa SVDFactorization ||
                        (solver isa QRFactorization &&
                         solver.pivot isa LinearAlgebra.ColumnNorm) # ColumnNorm is not implemented on CUDA
@@ -34,33 +41,30 @@
                 end
 
                 x1 = solve(prob1, solver)
-                x2 = solve(prob2, solver)
-                @test x1.u ≈ x2.u
+                if !ongpu && !(solver isa NormalCholeskyFactorization)
+                    x2 = solve(prob2, solver)
+                    @test x1.u ≈ x2.u
+                end
+
+                dims[1] != dims[2] && continue
 
                 test_adjoint = function (A, b)
                     sol = solve(LinearProblem(A, b), solver)
                     return sum(abs2, sol.u)
                 end
 
-                dims[1] != dims[2] && continue
-
-                ∂A_fd = FiniteDiff.finite_difference_gradient(
-                    x -> test_adjoint(x, Array(b)), Array(A1))
-                ∂b_fd = FiniteDiff.finite_difference_gradient(
-                    x -> test_adjoint(Array(A1), x), Array(b))
-
                 if solver isa QRFactorization && ongpu
                     @test_broken begin
                         ∂A, ∂b = Zygote.gradient(test_adjoint, A1, b)
-                        @test Array(∂A)≈∂A_fd atol=1e-1 rtol=1e-1
-                        @test Array(∂b)≈∂b_fd atol=1e-1 rtol=1e-1
+                        @test ∂A ≈ ∂A_cr
+                        @test ∂b ≈ ∂b_cr
                     end
                 else
                     ∂A, ∂b = Zygote.gradient(test_adjoint, A1, b)
-                    @test Array(∂A)≈∂A_fd atol=1e-1 rtol=1e-1
-                    @test Array(∂b)≈∂b_fd atol=1e-1 rtol=1e-1
+                    @test ∂A ≈ ∂A_cr
+                    @test ∂b ≈ ∂b_cr
                 end
             end