This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Custom rrules and special batched backslash implementation for CUDA
avik-pal committed Apr 1, 2024
1 parent b55fc0e commit 92e957b
Showing 10 changed files with 209 additions and 47 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -26,6 +26,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
BatchedRoutinesCUDAExt = ["CUDA"]
BatchedRoutinesCUDALinearSolveExt = ["CUDA", "LinearSolve"]
BatchedRoutinesFiniteDiffExt = ["FiniteDiff"]
BatchedRoutinesForwardDiffExt = ["ForwardDiff"]
BatchedRoutinesLinearSolveExt = ["LinearSolve"]
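The new `BatchedRoutinesCUDALinearSolveExt` entry is a weak dependency that only loads when both `CUDA` and `LinearSolve` are present. A sketch of what triggers it in a user session (package names as declared above):

using BatchedRoutines
using CUDA, LinearSolve   # loading both triggers BatchedRoutinesCUDALinearSolveExt
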
2 changes: 2 additions & 0 deletions ext/BatchedRoutinesCUDAExt/BatchedRoutinesCUDAExt.jl
@@ -12,6 +12,8 @@ const CuBlasFloat = Union{Float16, Float32, Float64, ComplexF32, ComplexF64}
const CuUniformBlockDiagonalOperator{T} = UniformBlockDiagonalOperator{
T, <:CUDA.AnyCuArray{T, 3}}

include("low_level.jl")

include("batched_mul.jl")
include("factorization.jl")

53 changes: 24 additions & 29 deletions ext/BatchedRoutinesCUDAExt/factorization.jl
@@ -120,34 +120,29 @@ function LinearAlgebra.ldiv!(X::CuMatrix, A::CuBatchedQR, b::CuMatrix)
return X
end

# Low Level Wrappers
for (fname, elty) in ((:cublasDgetrsBatched, :Float64), (:cublasSgetrsBatched, :Float32),
(:cublasZgetrsBatched, :ComplexF64), (:cublasCgetrsBatched, :ComplexF32))
@eval begin
function getrs_batched!(trans::Char, n, nrhs, Aptrs::CuVector{CuPtr{$elty}},
lda, p, Bptrs::CuVector{CuPtr{$elty}}, ldb)
batchSize = length(Aptrs)
info = Array{Cint}(undef, batchSize)
CUBLAS.$fname(
CUBLAS.handle(), trans, n, nrhs, Aptrs, lda, p, Bptrs, ldb, info, batchSize)
CUDA.unsafe_free!(Aptrs)
CUDA.unsafe_free!(Bptrs)
return info
end
# Direct Ldiv
function BatchedRoutines.__internal_backslash(
op::CuUniformBlockDiagonalOperator{T1}, b::AbstractMatrix{T2}) where {T1, T2}
T = promote_type(T1, T2)
return __internal_backslash(T != T1 ? T.(op) : op, T != T2 ? T.(b) : b)
end

function BatchedRoutines.__internal_backslash(
op::CuUniformBlockDiagonalOperator{T}, b::AbstractMatrix{T}) where {T}
size(op, 1) != length(b) && throw(DimensionMismatch("size(op, 1) != length(b)"))
x = similar(b, T, size(BatchedRoutines.getdata(op), 2), nbatches(op))
m, n = size(op)
if n < m # Underdetermined: LQ or QR with ColumnNorm
error("Underdetermined systems are not supported yet! Please open an issue if you \
care about this feature.")
elseif n == m # Square: LU with Pivoting
p, _, F = CUBLAS.getrf_strided_batched!(copy(BatchedRoutines.getdata(op)), true)
copyto!(x, b)
getrs_strided_batched!('N', F, p, x)
else # Overdetermined: QR
CUBLAS.gels_batched!('N', batchview(copy(BatchedRoutines.getdata(op))),
[reshape(bᵢ, :, 1) for bᵢ in batchview(b)])
copyto!(x, view(b, 1:n, :))
end
end

function getrs_strided_batched!(trans::Char, F::DenseCuArray{<:Any, 3}, p::DenseCuMatrix,
B::Union{DenseCuArray{<:Any, 3}, DenseCuMatrix})
m, n = size(F, 1), size(F, 2)
m != n && throw(DimensionMismatch("All matrices must be square!"))
lda = max(1, stride(F, 2))
ldb = max(1, stride(B, 2))
nrhs = ifelse(ndims(B) == 2, 1, size(B, 2))

Fptrs = CUBLAS.unsafe_strided_batch(F)
Bptrs = CUBLAS.unsafe_strided_batch(B)
info = getrs_batched!(trans, n, nrhs, Fptrs, lda, p, Bptrs, ldb)

return B, info
return x
end
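
Taken together, the rewritten `__internal_backslash` picks a CUBLAS path per block shape: square blocks go through pivoted batched LU (`getrf_strided_batched!` followed by `getrs_strided_batched!`), while non-square blocks either route to batched QR (`gels_batched!`) or raise an error for now. A minimal sketch of the square path (not part of the commit; sizes and element type are illustrative):

using CUDA, BatchedRoutines

A = UniformBlockDiagonalOperator(CUDA.rand(Float32, 8, 8, 2))   # two square 8×8 blocks
b = CUDA.rand(Float32, 16)                                      # length(b) == size(A, 1)
x = A \ b                                                       # one batched LU factorization + solve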
31 changes: 31 additions & 0 deletions ext/BatchedRoutinesCUDAExt/low_level.jl
@@ -0,0 +1,31 @@
# Low Level Wrappers
for (fname, elty) in ((:cublasDgetrsBatched, :Float64), (:cublasSgetrsBatched, :Float32),
(:cublasZgetrsBatched, :ComplexF64), (:cublasCgetrsBatched, :ComplexF32))
@eval begin
function getrs_batched!(trans::Char, n, nrhs, Aptrs::CuVector{CuPtr{$elty}},
lda, p, Bptrs::CuVector{CuPtr{$elty}}, ldb)
batchSize = length(Aptrs)
info = Array{Cint}(undef, batchSize)
CUBLAS.$fname(
CUBLAS.handle(), trans, n, nrhs, Aptrs, lda, p, Bptrs, ldb, info, batchSize)
CUDA.unsafe_free!(Aptrs)
CUDA.unsafe_free!(Bptrs)
return info
end
end
end

function getrs_strided_batched!(trans::Char, F::DenseCuArray{<:Any, 3}, p::DenseCuMatrix,
B::Union{DenseCuArray{<:Any, 3}, DenseCuMatrix})
m, n = size(F, 1), size(F, 2)
m != n && throw(DimensionMismatch("All matrices must be square!"))
lda = max(1, stride(F, 2))
ldb = max(1, stride(B, 2))
nrhs = ifelse(ndims(B) == 2, 1, size(B, 2))

Fptrs = CUBLAS.unsafe_strided_batch(F)
Bptrs = CUBLAS.unsafe_strided_batch(B)
info = getrs_batched!(trans, n, nrhs, Fptrs, lda, p, Bptrs, ldb)

return B, info
end
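
`getrs_strided_batched!` is the companion to `CUBLAS.getrf_strided_batched!`: factorize every block once, then solve in place against the factors and pivots. A hedged sketch of the call pattern from inside the extension (the helper is internal to `BatchedRoutinesCUDAExt`; sizes are illustrative):

A = CUDA.rand(Float64, 4, 4, 8)                            # 8 independent 4×4 systems
B = CUDA.rand(Float64, 4, 8)                               # one right-hand side column per batch

p, _, F = CUBLAS.getrf_strided_batched!(copy(A), true)     # pivoted LU of every block
X = copy(B)
getrs_strided_batched!('N', F, p, X)                       # X now holds the per-batch solutions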
24 changes: 24 additions & 0 deletions ext/BatchedRoutinesCUDALinearSolveExt.jl
@@ -0,0 +1,24 @@
module BatchedRoutinesCUDALinearSolveExt

using BatchedRoutines: UniformBlockDiagonalOperator, getdata
using CUDA: CUDA
using LinearAlgebra: LinearAlgebra
using LinearSolve: LinearSolve

const CuUniformBlockDiagonalOperator{T} = UniformBlockDiagonalOperator{
T, <:CUDA.AnyCuArray{T, 3}}

function LinearSolve.init_cacheval(
alg::LinearSolve.SVDFactorization, A::CuUniformBlockDiagonalOperator, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
return nothing
end

function LinearSolve.init_cacheval(
alg::LinearSolve.QRFactorization, A::CuUniformBlockDiagonalOperator, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
A_ = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
return LinearAlgebra.qr!(A_) # ignore the pivot since CUDA doesn't support it
end

end
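
These `init_cacheval` overloads keep LinearSolve from building an unsupported cache on the GPU: SVD caches nothing, and QR caches a dummy non-pivoted factorization of an empty operator just to pin down the cache's type. A hedged usage sketch (solver choice and sizes are illustrative, not taken from the commit):

using CUDA, LinearSolve, BatchedRoutines

A = UniformBlockDiagonalOperator(CUDA.rand(Float32, 8, 8, 2))
b = CUDA.rand(Float32, 16)

prob = LinearProblem(A, b)
sol = solve(prob, QRFactorization())   # assumes the default non-pivoted QR; ColumnNorm is unsupported on CUDA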
73 changes: 72 additions & 1 deletion ext/BatchedRoutinesLinearSolveExt.jl
@@ -2,8 +2,12 @@ module BatchedRoutinesLinearSolveExt

using ArrayInterface: ArrayInterface
using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, getdata
using ChainRulesCore: ChainRulesCore, NoTangent
using FastClosures: @closure
using LinearAlgebra: LinearAlgebra
using LinearSolve: LinearSolve
using LinearSolve: LinearSolve, SciMLBase

const CRC = ChainRulesCore

# Overload LinearProblem, else causing problems in the adjoint code
function LinearSolve.LinearProblem(op::UniformBlockDiagonalOperator, b, args...; kwargs...)
@@ -125,4 +129,71 @@ function LinearSolve.do_factorization(
return LinearAlgebra.svd!(A; alg.full, alg.alg)
end

# We need a custom rrule here to prevent spurious gradients for zero blocks
# Copied from https://github.com/SciML/LinearSolve.jl/blob/7911113c6b14b6897cc356e277ccd5a98faa7dd7/src/adjoint.jl#L31 except the Lazy Arrays part
function CRC.rrule(::typeof(SciMLBase.solve),
prob::SciMLBase.LinearProblem{T1, T2, <:UniformBlockDiagonalOperator},
alg::LinearSolve.SciMLLinearSolveAlgorithm, args...;
alias_A=LinearSolve.default_alias_A(alg, prob.A, prob.b), kwargs...) where {T1, T2}
cache = SciMLBase.init(prob, alg, args...; kwargs...)
(; A, sensealg) = cache

@assert sensealg isa LinearSolve.LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."

# Decide if we need to cache `A` and `b` for the reverse pass
A_ = A
if sensealg.linsolve === missing
# We can reuse the factorization so no copy is needed
# Krylov Methods don't modify `A`, so it's safe to just reuse it
# No Copy is needed even for the default case
if !(alg isa LinearSolve.AbstractFactorization ||
alg isa LinearSolve.AbstractKrylovSubspaceMethod ||
alg isa LinearSolve.DefaultLinearSolver)
A_ = alias_A ? deepcopy(A) : A
end
else
A_ = deepcopy(A)
end

sol = SciMLBase.solve!(cache)

proj_A = CRC.ProjectTo(getdata(A))
proj_b = CRC.ProjectTo(prob.b)

∇linear_solve = @closure ∂sol -> begin
∂u = ∂sol.u
if sensealg.linsolve === missing
λ = if cache.cacheval isa LinearAlgebra.Factorization
cache.cacheval' \ ∂u
elseif cache.cacheval isa Tuple &&
cache.cacheval[1] isa LinearAlgebra.Factorization
first(cache.cacheval)' \ ∂u
elseif alg isa LinearSolve.AbstractKrylovSubspaceMethod
invprob = SciMLBase.LinearProblem(transpose(cache.A), ∂u)
SciMLBase.solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
elseif alg isa LinearSolve.DefaultLinearSolver
LinearSolve.defaultalg_adjoint_eval(cache, ∂u)
else
invprob = SciMLBase.LinearProblem(transpose(A_), ∂u) # We cached `A`
SciMLBase.solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u
end
else
invprob = SciMLBase.LinearProblem(transpose(A_), ∂u) # We cached `A`
λ = SciMLBase.solve(
invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u
end

uᵀ = reshape(sol.u, 1, :, BatchedRoutines.nbatches(A))
∂A = UniformBlockDiagonalOperator(proj_A(BatchedRoutines.batched_mul(
reshape(λ, :, 1, BatchedRoutines.nbatches(A)), -uᵀ)))
∂b = proj_b(λ)
∂prob = SciMLBase.LinearProblem(∂A, ∂b, NoTangent())

return (
NoTangent(), ∂prob, NoTangent(), ntuple(Returns(NoTangent()), length(args))...)
end

return sol, ∇linear_solve
end

end
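
This rrule mirrors the upstream LinearSolve adjoint but forms ∂A with `batched_mul` of λ against -uᵀ, one block per batch, so the zero off-diagonal blocks of the operator never pick up spurious gradients. A hedged sketch of differentiating through `solve`, in the spirit of the updated integration tests (solver and sizes are illustrative):

using BatchedRoutines, LinearSolve, Zygote

A = UniformBlockDiagonalOperator(rand(8, 8, 2))
b = rand(16)

loss(A, b) = sum(abs2, solve(LinearProblem(A, b), QRFactorization()).u)
∂A, ∂b = Zygote.gradient(loss, A, b)   # ∂A comes back as a UniformBlockDiagonalOperator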
2 changes: 1 addition & 1 deletion src/BatchedRoutines.jl
@@ -24,7 +24,7 @@ function __init__()
printstyled(io, "\nHINT: "; bold=true)
printstyled(
io, "`UniformBlockDiagonalOperator` doesn't support AbstractArray \
operations. If you want this supported open an issue at \
operations. If you want this supported, open an issue at \
https://github.com/LuxDL/BatchedRoutines.jl to discuss it.";
color=:cyan)
end
17 changes: 17 additions & 0 deletions src/chainrules.jl
@@ -141,3 +141,20 @@ function CRC.rrule(::typeof(sum), ::typeof(identity), op::UniformBlockDiagonalOp
end
return y, ∇sum_abs2
end

# Direct Ldiv
function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(\),
op::UniformBlockDiagonalOperator, b::AbstractMatrix)
# We haven't implemented the rrule for least squares yet, so we directly AD through the code
size(op, 1) != size(op, 2) && return CRC.rrule_via_ad(cfg, __internal_backslash, op, b)
# TODO: reuse the factorization once `factorize(op)` has been implemented
u = op \ b
proj_A = CRC.ProjectTo(getdata(op))
proj_b = CRC.ProjectTo(b)
∇backslash = @closure ∂u -> begin
λ = op' \ ∂u
∂A = -batched_mul(λ, batched_adjoint(reshape(u, :, 1, nbatches(u))))
return NoTangent(), UniformBlockDiagonalOperator(proj_A(∂A)), proj_b(λ)
end
return u, ∇backslash
end
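
The adjoint here is the standard linear-solve rule applied blockwise: with u = A \ b and incoming cotangent ∂u, set λ = Aᴴ \ ∂u, then ∂b = λ and ∂A = -λ uᴴ per block, which is what the `batched_mul` against the batched adjoint of u computes. A hedged pullback check (shapes are illustrative):

using BatchedRoutines, Zygote

A = UniformBlockDiagonalOperator(rand(4, 4, 3))
b = rand(4, 3)

u, back = Zygote.pullback((A, b) -> A \ b, A, b)
∂A, ∂b = back(ones(4, 3))   # expect ∂b ≈ A' \ ones(4, 3) and ∂A ≈ -∂b * u' blockwise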
17 changes: 17 additions & 0 deletions src/operator.jl
@@ -196,6 +196,11 @@ function Base.:-(op1::UniformBlockDiagonalOperator, op2::UniformBlockDiagonalOpe
return UniformBlockDiagonalOperator(getdata(op1) - getdata(op2))
end

function Base.isapprox(
op1::UniformBlockDiagonalOperator, op2::UniformBlockDiagonalOperator; kwargs...)
return isapprox(getdata(op1), getdata(op2); kwargs...)
end

# Adapt
@inline function Adapt.adapt_structure(to, op::UniformBlockDiagonalOperator)
return UniformBlockDiagonalOperator(Adapt.adapt(to, getdata(op)))
@@ -269,3 +274,15 @@ function LinearAlgebra.mul!(C::AbstractArray{T1, 3}, A::UniformBlockDiagonalOper
batched_mul!(C, getdata(A), B)
return C
end

# Direct \ operator
function Base.:\(op::UniformBlockDiagonalOperator, b::AbstractVector)
return vec(op \ reshape(b, :, nbatches(op)))
end
Base.:\(op::UniformBlockDiagonalOperator, b::AbstractMatrix) = __internal_backslash(op, b)

## This exists to allow direct autodiff through the code, e.g., for non-square systems
@inline function __internal_backslash(op::UniformBlockDiagonalOperator, b::AbstractMatrix)
size(op, 1) != length(b) && throw(DimensionMismatch("size(op, 1) != length(b)"))
return mapfoldl(((Aᵢ, bᵢ),) -> Aᵢ \ bᵢ, hcat, zip(batchview(op), batchview(b)))
end
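
These methods give the operator a blockwise `\`: the vector method reshapes `b` into one column per batch, and `__internal_backslash` is the generic per-block fallback used both as the CPU path and for direct autodiff through non-square systems. A quick CPU sketch (sizes are illustrative):

using BatchedRoutines

op = UniformBlockDiagonalOperator(rand(5, 5, 4))
B = rand(5, 4)                  # one right-hand side column per block

X = op \ B                      # blockwise solve, size (5, 4)
x = op \ vec(B)                 # vector form; x == vec(X)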
36 changes: 20 additions & 16 deletions test/integration_tests.jl
@@ -1,12 +1,12 @@
@testitem "LinearSolve" setup=[SharedTestSetup] begin
using FiniteDiff, LinearAlgebra, LinearSolve, Zygote
using LinearAlgebra, LinearSolve, Zygote

rng = get_stable_rng(1001)

@testset "$mode" for (mode, aType, dev, ongpu) in MODES
for dims in ((8, 8, 2), (5, 3, 2))
A1 = UniformBlockDiagonalOperator(rand(rng, dims...)) |> dev
A2 = Matrix(A1) |> dev
A2 = collect(A1)
b = rand(rng, size(A1, 1)) |> dev

prob1 = LinearProblem(A1, b)
@@ -22,10 +22,17 @@
svd_factorization(mode), nothing]
end

if dims[1] == dims[2]
test_chainrules_adjoint = (A, b) -> sum(abs2, A \ b)

∂A_cr, ∂b_cr = Zygote.gradient(test_chainrules_adjoint, A1, b)
else
∂A_cr, ∂b_cr = nothing, nothing
end

@testset "solver: $(nameof(typeof(solver)))" for solver in solvers
# FIXME: SVD doesn't define ldiv on CUDA side
if mode == "CUDA"
@show solver, solver isa SVDFactorization
if solver isa SVDFactorization || (solver isa QRFactorization &&
solver.pivot isa LinearAlgebra.ColumnNorm)
# ColumnNorm is not implemented on CUDA
@@ -34,33 +41,30 @@
end

x1 = solve(prob1, solver)
x2 = solve(prob2, solver)
@test x1.u ≈ x2.u
if !ongpu && !(solver isa NormalCholeskyFactorization)
x2 = solve(prob2, solver)
@test x1.u ≈ x2.u
end

dims[1] != dims[2] && continue

test_adjoint = function (A, b)
sol = solve(LinearProblem(A, b), solver)
return sum(abs2, sol.u)
end

dims[1] != dims[2] && continue

∂A_fd = FiniteDiff.finite_difference_gradient(
x -> test_adjoint(x, Array(b)), Array(A1))
∂b_fd = FiniteDiff.finite_difference_gradient(
x -> test_adjoint(Array(A1), x), Array(b))

if solver isa QRFactorization && ongpu
@test_broken begin
∂A, ∂b = Zygote.gradient(test_adjoint, A1, b)

@test Array(∂A)≈∂A_fd atol=1e-1 rtol=1e-1
@test Array(∂b)≈∂b_fd atol=1e-1 rtol=1e-1
@test ∂A ≈ ∂A_cr
@test ∂b ≈ ∂b_cr
end
else
∂A, ∂b = Zygote.gradient(test_adjoint, A1, b)

@test Array(∂A)≈∂A_fd atol=1e-1 rtol=1e-1
@test Array(∂b)≈∂b_fd atol=1e-1 rtol=1e-1
@test ∂A ≈ ∂A_cr
@test ∂b ≈ ∂b_cr
end
end
end