This repository has been archived by the owner on Nov 1, 2024. It is now read-only.

Commit

Merge pull request #11 from LuxDL/ap/operator
Migrate to an operator based implementation
avik-pal authored Apr 1, 2024
2 parents d0ce078 + 256f0cc commit 54b894e
Showing 18 changed files with 863 additions and 550 deletions.
7 changes: 7 additions & 0 deletions Project.toml
@@ -14,18 +14,24 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
BatchedRoutinesCUDAExt = ["CUDA"]
BatchedRoutinesComponentArraysForwardDiffExt = ["ComponentArrays", "ForwardDiff"]
BatchedRoutinesCUDALinearSolveExt = ["CUDA", "LinearSolve"]
BatchedRoutinesFiniteDiffExt = ["FiniteDiff"]
BatchedRoutinesForwardDiffExt = ["ForwardDiff"]
BatchedRoutinesLinearSolveExt = ["LinearSolve"]
BatchedRoutinesReverseDiffExt = ["ReverseDiff"]
BatchedRoutinesZygoteExt = ["Zygote"]

@@ -53,6 +59,7 @@ PrecompileTools = "1.2.0"
Random = "<0.0.1, 1"
ReTestItems = "1.23.1"
ReverseDiff = "1.15"
SciMLOperators = "0.3.8"
StableRNGs = "1.0.1"
Statistics = "1.11.1"
Test = "<0.0.1, 1"
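The new SciMLOperators dependency backs the operator rewrite, while the added [weakdeps]/[extensions] entries are only loaded when their trigger packages are present. A minimal sketch of how the new CUDA + LinearSolve extension would be activated (assumes Julia's package-extension mechanism, 1.9+; the warning message is illustrative):

# Usage sketch (not part of the diff).
using BatchedRoutines
using CUDA, LinearSolve   # trigger packages listed under [weakdeps]

# Once both triggers are loaded, the extension module is compiled and attached lazily.
ext = Base.get_extension(BatchedRoutines, :BatchedRoutinesCUDALinearSolveExt)
ext === nothing && @warn "BatchedRoutinesCUDALinearSolveExt was not loaded"
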
7 changes: 5 additions & 2 deletions ext/BatchedRoutinesCUDAExt/BatchedRoutinesCUDAExt.jl
@@ -1,15 +1,18 @@
module BatchedRoutinesCUDAExt

using BatchedRoutines: AbstractBatchedMatrixFactorization, BatchedRoutines,
UniformBlockDiagonalMatrix, batchview, nbatches
UniformBlockDiagonalOperator, batchview, nbatches
using CUDA: CUBLAS, CUDA, CUSOLVER, CuArray, CuMatrix, CuPtr, CuVector, DenseCuArray,
DenseCuMatrix
using ConcreteStructs: @concrete
using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero, mul!

const CuBlasFloat = Union{Float16, Float32, Float64, ComplexF32, ComplexF64}

const CuUniformBlockDiagonalMatrix{T} = UniformBlockDiagonalMatrix{T, <:CuArray{T, 3}}
const CuUniformBlockDiagonalOperator{T} = UniformBlockDiagonalOperator{
T, <:CUDA.AnyCuArray{T, 3}}

include("low_level.jl")

include("batched_mul.jl")
include("factorization.jl")
65 changes: 32 additions & 33 deletions ext/BatchedRoutinesCUDAExt/factorization.jl
@@ -28,14 +28,14 @@ end

for pT in (:RowMaximum, :RowNonZero, :NoPivot)
@eval begin
function LinearAlgebra.lu!(A::CuUniformBlockDiagonalMatrix, pivot::$pT; kwargs...)
function LinearAlgebra.lu!(A::CuUniformBlockDiagonalOperator, pivot::$pT; kwargs...)
return LinearAlgebra.lu!(A, !(pivot isa NoPivot); kwargs...)
end
end
end

function LinearAlgebra.lu!(
A::CuUniformBlockDiagonalMatrix, pivot::Bool=true; check::Bool=true, kwargs...)
A::CuUniformBlockDiagonalOperator, pivot::Bool=true; check::Bool=true, kwargs...)
pivot_array, info_, factors = CUBLAS.getrf_strided_batched!(A.data, pivot)
info = Array(info_)
check && LinearAlgebra.checknonsingular.(info)
@@ -82,11 +82,15 @@ function Base.show(io::IO, QR::CuBatchedQR)
return print(io, "CuBatchedQR() with Batch Count: $(nbatches(QR))")
end

function LinearAlgebra.qr!(::CuUniformBlockDiagonalMatrix, ::ColumnNorm; kwargs...)
function LinearAlgebra.qr!(A::CuUniformBlockDiagonalOperator; kwargs...)
return LinearAlgebra.qr!(A, NoPivot(); kwargs...)
end

function LinearAlgebra.qr!(::CuUniformBlockDiagonalOperator, ::ColumnNorm; kwargs...)
throw(ArgumentError("ColumnNorm is not supported for batched CUDA QR factorization!"))
end

function LinearAlgebra.qr!(A::CuUniformBlockDiagonalMatrix, ::NoPivot; kwargs...)
function LinearAlgebra.qr!(A::CuUniformBlockDiagonalOperator, ::NoPivot; kwargs...)
τ, factors = CUBLAS.geqrf_batched!(collect(batchview(A)))
return CuBatchedQR{eltype(A)}(factors, τ, size(A))
end
@@ -116,34 +120,29 @@ function LinearAlgebra.ldiv!(X::CuMatrix, A::CuBatchedQR, b::CuMatrix)
return X
end

# Low Level Wrappers
for (fname, elty) in ((:cublasDgetrsBatched, :Float64), (:cublasSgetrsBatched, :Float32),
(:cublasZgetrsBatched, :ComplexF64), (:cublasCgetrsBatched, :ComplexF32))
@eval begin
function getrs_batched!(trans::Char, n, nrhs, Aptrs::CuVector{CuPtr{$elty}},
lda, p, Bptrs::CuVector{CuPtr{$elty}}, ldb)
batchSize = length(Aptrs)
info = Array{Cint}(undef, batchSize)
CUBLAS.$fname(
CUBLAS.handle(), trans, n, nrhs, Aptrs, lda, p, Bptrs, ldb, info, batchSize)
CUDA.unsafe_free!(Aptrs)
CUDA.unsafe_free!(Bptrs)
return info
end
# Direct Ldiv
function BatchedRoutines.__internal_backslash(
op::CuUniformBlockDiagonalOperator{T1}, b::AbstractMatrix{T2}) where {T1, T2}
T = promote_type(T1, T2)
return __internal_backslash(T != T1 ? T.(op) : op, T != T2 ? T.(b) : b)
end

function BatchedRoutines.__internal_backslash(
op::CuUniformBlockDiagonalOperator{T}, b::AbstractMatrix{T}) where {T}
size(op, 1) != length(b) && throw(DimensionMismatch("size(op, 1) != length(b)"))
x = similar(b, T, size(BatchedRoutines.getdata(op), 2), nbatches(op))
m, n = size(op)
if n < m # Underdetermined: LQ or QR with ColumnNorm
error("Underdetermined systems are not supported yet! Please open an issue if you \
care about this feature.")
elseif n == m # Square: LU with Pivoting
p, _, F = CUBLAS.getrf_strided_batched!(copy(BatchedRoutines.getdata(op)), true)
copyto!(x, b)
getrs_strided_batched!('N', F, p, x)
else # Overdetermined: QR
CUBLAS.gels_batched!('N', batchview(copy(BatchedRoutines.getdata(op))),
[reshape(bᵢ, :, 1) for bᵢ in batchview(b)])
copyto!(x, view(b, 1:n, :))
end
end

function getrs_strided_batched!(trans::Char, F::DenseCuArray{<:Any, 3}, p::DenseCuMatrix,
B::Union{DenseCuArray{<:Any, 3}, DenseCuMatrix})
m, n = size(F, 1), size(F, 2)
m != n && throw(DimensionMismatch("All matrices must be square!"))
lda = max(1, stride(F, 2))
ldb = max(1, stride(B, 2))
nrhs = ifelse(ndims(B) == 2, 1, size(B, 2))

Fptrs = CUBLAS.unsafe_strided_batch(F)
Bptrs = CUBLAS.unsafe_strided_batch(B)
info = getrs_batched!(trans, n, nrhs, Fptrs, lda, p, Bptrs, ldb)

return B, info
return x
end
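
The new __internal_backslash methods above give the operator a direct solve path: square blocks go through strided batched LU (getrf/getrs), overdetermined blocks through batched QR via gels, and underdetermined blocks error for now. A hedged usage sketch, assuming UniformBlockDiagonalOperator is exported and that the main package forwards `\` on the operator to __internal_backslash:

# Usage sketch (not part of the diff).
using BatchedRoutines, CUDA

A = UniformBlockDiagonalOperator(CUDA.rand(Float32, 4, 4, 8))  # 8 independent 4x4 blocks
b = CUDA.rand(Float32, 4, 8)                                   # one right-hand side column per batch
x = A \ b                       # square case: pivoted strided batched LU on the GPU
@assert size(x) == (4, 8)
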
31 changes: 31 additions & 0 deletions ext/BatchedRoutinesCUDAExt/low_level.jl
@@ -0,0 +1,31 @@
# Low Level Wrappers
for (fname, elty) in ((:cublasDgetrsBatched, :Float64), (:cublasSgetrsBatched, :Float32),
(:cublasZgetrsBatched, :ComplexF64), (:cublasCgetrsBatched, :ComplexF32))
@eval begin
function getrs_batched!(trans::Char, n, nrhs, Aptrs::CuVector{CuPtr{$elty}},
lda, p, Bptrs::CuVector{CuPtr{$elty}}, ldb)
batchSize = length(Aptrs)
info = Array{Cint}(undef, batchSize)
CUBLAS.$fname(
CUBLAS.handle(), trans, n, nrhs, Aptrs, lda, p, Bptrs, ldb, info, batchSize)
CUDA.unsafe_free!(Aptrs)
CUDA.unsafe_free!(Bptrs)
return info
end
end
end

function getrs_strided_batched!(trans::Char, F::DenseCuArray{<:Any, 3}, p::DenseCuMatrix,
B::Union{DenseCuArray{<:Any, 3}, DenseCuMatrix})
m, n = size(F, 1), size(F, 2)
m != n && throw(DimensionMismatch("All matrices must be square!"))
lda = max(1, stride(F, 2))
ldb = max(1, stride(B, 2))
nrhs = ifelse(ndims(B) == 2, 1, size(B, 2))

Fptrs = CUBLAS.unsafe_strided_batch(F)
Bptrs = CUBLAS.unsafe_strided_batch(B)
info = getrs_batched!(trans, n, nrhs, Fptrs, lda, p, Bptrs, ldb)

return B, info
end
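
These wrappers pair with CUBLAS.getrf_strided_batched!: factorize every block in one call, then back-substitute all right-hand sides in place, which is what the square branch of __internal_backslash does. A sketch of that pattern, under the assumption it runs inside the extension (both wrappers are module-internal; array names are illustrative):

# Usage sketch (not part of the diff).
using CUDA
using CUDA: CUBLAS

Adata = CUDA.rand(Float32, 4, 4, 8)                         # n x n x nbatches blocks
B     = CUDA.rand(Float32, 4, 8)                            # n x nbatches right-hand sides
p, _, F = CUBLAS.getrf_strided_batched!(copy(Adata), true)  # pivoted LU of every block
X = copy(B)
getrs_strided_batched!('N', F, p, X)                        # X now holds A_i \ b_i for each batch i
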
24 changes: 24 additions & 0 deletions ext/BatchedRoutinesCUDALinearSolveExt.jl
@@ -0,0 +1,24 @@
module BatchedRoutinesCUDALinearSolveExt

using BatchedRoutines: UniformBlockDiagonalOperator, getdata
using CUDA: CUDA
using LinearAlgebra: LinearAlgebra
using LinearSolve: LinearSolve

const CuUniformBlockDiagonalOperator{T} = UniformBlockDiagonalOperator{
T, <:CUDA.AnyCuArray{T, 3}}

function LinearSolve.init_cacheval(
alg::LinearSolve.SVDFactorization, A::CuUniformBlockDiagonalOperator, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
return nothing
end

function LinearSolve.init_cacheval(
alg::LinearSolve.QRFactorization, A::CuUniformBlockDiagonalOperator, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
A_ = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
return LinearAlgebra.qr!(A_) # ignore the pivot since CUDA doesn't support it
end

end
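
These init_cacheval overloads let LinearSolve skip its generic SVD setup and build a pivot-free batched QR cache for the CUDA-backed operator. A heavily hedged sketch of the intended call pattern (the actual solve dispatch lives in the generic BatchedRoutinesLinearSolveExt, which is not shown in this diff):

# Usage sketch (not part of the diff).
using BatchedRoutines, CUDA, LinearSolve

A = UniformBlockDiagonalOperator(CUDA.rand(Float32, 4, 4, 8))
b = vec(CUDA.rand(Float32, 4, 8))            # flat right-hand side of length size(A, 1)
prob = LinearProblem(A, b)
sol  = solve(prob, QRFactorization())        # pivot-free QR; ColumnNorm is unsupported on CUDA
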
12 changes: 12 additions & 0 deletions ext/BatchedRoutinesComponentArraysForwardDiffExt.jl
@@ -0,0 +1,12 @@
module BatchedRoutinesComponentArraysForwardDiffExt

using BatchedRoutines: BatchedRoutines
using ComponentArrays: ComponentArrays, ComponentArray
using ForwardDiff: ForwardDiff

@inline function BatchedRoutines._restructure(y, x::ComponentArray)
x_data = ComponentArrays.getdata(x)
return ComponentArray(reshape(y, size(x_data)), ComponentArrays.getaxes(x))
end

end
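
_restructure rebuilds a ComponentArray from a flat buffer (for example the dual-seeded vector produced in the ForwardDiff extension below) while preserving the named axes, so structured parameter vectors survive the round trip through the AD machinery. A small sketch of the same pattern using ComponentArrays directly:

# Sketch (not part of the diff).
using ComponentArrays

x  = ComponentArray(a = ones(2), b = zeros(3))
y  = collect(1.0:5.0)                           # flat data with the same length as x
x2 = ComponentArray(reshape(y, size(ComponentArrays.getdata(x))), ComponentArrays.getaxes(x))
x2.a                                            # first two entries, still addressable by name
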
8 changes: 4 additions & 4 deletions ext/BatchedRoutinesFiniteDiffExt.jl
@@ -2,7 +2,7 @@ module BatchedRoutinesFiniteDiffExt

using ADTypes: AutoFiniteDiff
using ArrayInterface: matrix_colors, parameterless_type
using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalMatrix, _assert_type
using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, _assert_type
using FastClosures: @closure
using FiniteDiff: FiniteDiff

@@ -14,15 +14,15 @@ using FiniteDiff: FiniteDiff
ad::AutoFiniteDiff, f::F, x::AbstractVector{T}) where {F, T}
J = FiniteDiff.finite_difference_jacobian(f, x, ad.fdjtype)
(_assert_type(f) && _assert_type(x) && Base.issingletontype(F)) &&
(return UniformBlockDiagonalMatrix(J::parameterless_type(x){T, 2}))
return UniformBlockDiagonalMatrix(J)
(return UniformBlockDiagonalOperator(J::parameterless_type(x){T, 2}))
return UniformBlockDiagonalOperator(J)
end

@inline function BatchedRoutines._batched_jacobian(
ad::AutoFiniteDiff, f::F, x::AbstractMatrix) where {F}
f! = @closure (y, x_) -> copyto!(y, f(x_))
fx = f(x)
J = UniformBlockDiagonalMatrix(similar(
J = UniformBlockDiagonalOperator(similar(
x, promote_type(eltype(fx), eltype(x)), size(fx, 1), size(x, 1), size(x, 2)))
sparsecache = FiniteDiff.JacobianCache(
x, fx, ad.fdjtype; colorvec=matrix_colors(J), sparsity=J)
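The _batched_jacobian methods above compute one Jacobian per batch column and wrap the result in the new operator type. A hedged usage sketch, assuming batched_jacobian(ad, f, x) is the public entry point that dispatches to them:

# Usage sketch (not part of the diff).
using BatchedRoutines, ADTypes, FiniteDiff

f(x) = x .^ 2 .+ sum(x; dims=1)                # each column (batch) only depends on itself
x = randn(Float32, 3, 4)                       # 3 states, 4 batches
J = batched_jacobian(AutoFiniteDiff(), f, x)   # expected: 3x3 blocks, one per batch
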
16 changes: 8 additions & 8 deletions ext/BatchedRoutinesForwardDiffExt.jl
@@ -2,7 +2,7 @@ module BatchedRoutinesForwardDiffExt

using ADTypes: AutoForwardDiff
using ArrayInterface: parameterless_type
using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalMatrix, batched_jacobian,
using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, batched_jacobian,
batched_mul, batched_pickchunksize, _assert_type
using ChainRulesCore: ChainRulesCore
using FastClosures: @closure
@@ -117,7 +117,7 @@ end
else
jac_call = :((y, J) = __batched_value_and_jacobian(ad, f, u, $(Val(CK))))
end
return Expr(:block, jac_call, :(return (y, UniformBlockDiagonalMatrix(J))))
return Expr(:block, jac_call, :(return (y, UniformBlockDiagonalOperator(J))))
end

## Exposed API
@@ -132,8 +132,8 @@ end
end
J = ForwardDiff.jacobian(f, u, cfg)
(_assert_type(f) && _assert_type(u) && Base.issingletontype(F)) &&
(return UniformBlockDiagonalMatrix(J::parameterless_type(u){T, 2}))
return UniformBlockDiagonalMatrix(J)
(return UniformBlockDiagonalOperator(J::parameterless_type(u){T, 2}))
return UniformBlockDiagonalOperator(J)
end

@inline function BatchedRoutines._batched_jacobian(
@@ -211,7 +211,7 @@ end
u_part_next = Dual.(u[idxs_next], dev(Partials.(map(nt, 1:length(idxs_next)))))
end

u_duals = reshape(vcat(u_part_prev, u_part_duals, u_part_next), size(u))
u_duals = BatchedRoutines._restructure(vcat(u_part_prev, u_part_duals, u_part_next), u)
y_duals = f(u_duals)

gs === nothing && return ForwardDiff.partials(y_duals)
@@ -224,20 +224,20 @@ Base.@assume_effects :total BatchedRoutines._assert_type(::Type{<:AbstractArray{

function BatchedRoutines._jacobian_vector_product(ad::AutoForwardDiff, f::F, x, u) where {F}
Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag)
x_dual = _construct_jvp_duals(Tag, x, u)
x_dual = BatchedRoutines._construct_jvp_duals(Tag, x, u)
y_dual = f(x_dual)
return ForwardDiff.partials.(y_dual, 1)
end

function BatchedRoutines._jacobian_vector_product(
ad::AutoForwardDiff, f::F, x, u, p) where {F}
Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag)
x_dual = _construct_jvp_duals(Tag, x, u)
x_dual = BatchedRoutines._construct_jvp_duals(Tag, x, u)
y_dual = f(x_dual, p)
return ForwardDiff.partials.(y_dual, 1)
end

@inline function _construct_jvp_duals(::Type{Tag}, x, u) where {Tag}
@inline function BatchedRoutines._construct_jvp_duals(::Type{Tag}, x, u) where {Tag}
T = promote_type(eltype(x), eltype(u))
partials = ForwardDiff.Partials{1, T}.(tuple.(u))
return ForwardDiff.Dual{Tag, T, 1}.(x, reshape(partials, size(x)))
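The ForwardDiff changes above qualify the previously module-local helpers (_construct_jvp_duals, _restructure) under the BatchedRoutines namespace, so the dual seeding composes with the ComponentArrays extension. A sketch of the forward-mode JVP entry point shown in this diff (it is an internal function, so the call is written fully qualified):

# Usage sketch (not part of the diff).
using BatchedRoutines, ADTypes, ForwardDiff

f(x) = sin.(x) .* x
x = randn(Float32, 3, 4)                       # batch of 4 vectors of length 3
u = randn(Float32, 3, 4)                       # tangent of the same shape
Jv = BatchedRoutines._jacobian_vector_product(AutoForwardDiff(), f, x, u)  # J(x) * u, per batch
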
(The remaining changed files are not shown.)
