This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Migrate to an operator based implementation #11

Merged · 6 commits · Apr 1, 2024
7 changes: 7 additions & 0 deletions Project.toml
@@ -14,18 +14,24 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
BatchedRoutinesCUDAExt = ["CUDA"]
BatchedRoutinesComponentArraysForwardDiffExt = ["ComponentArrays", "ForwardDiff"]
BatchedRoutinesCUDALinearSolveExt = ["CUDA", "LinearSolve"]
BatchedRoutinesFiniteDiffExt = ["FiniteDiff"]
BatchedRoutinesForwardDiffExt = ["ForwardDiff"]
BatchedRoutinesLinearSolveExt = ["LinearSolve"]
BatchedRoutinesReverseDiffExt = ["ReverseDiff"]
BatchedRoutinesZygoteExt = ["Zygote"]

@@ -53,6 +59,7 @@ PrecompileTools = "1.2.0"
Random = "<0.0.1, 1"
ReTestItems = "1.23.1"
ReverseDiff = "1.15"
SciMLOperators = "0.3.8"
StableRNGs = "1.0.1"
Statistics = "1.11.1"
Test = "<0.0.1, 1"
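For context, the `[weakdeps]`/`[extensions]` tables drive Julia's lazy extension loading (Julia ≥ 1.9): each extension module activates only once all of its trigger packages are imported alongside BatchedRoutines. A minimal sketch:

```julia
# Extensions load lazily: importing a weak dependency next to BatchedRoutines
# activates the matching module from the [extensions] table above.
using BatchedRoutines            # core package only; no extensions yet
using FiniteDiff                 # → loads BatchedRoutinesFiniteDiffExt
using CUDA, LinearSolve          # → loads BatchedRoutinesCUDALinearSolveExt (both triggers needed)
```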
7 changes: 5 additions & 2 deletions ext/BatchedRoutinesCUDAExt/BatchedRoutinesCUDAExt.jl
@@ -1,15 +1,18 @@
module BatchedRoutinesCUDAExt

using BatchedRoutines: AbstractBatchedMatrixFactorization, BatchedRoutines,
UniformBlockDiagonalMatrix, batchview, nbatches
UniformBlockDiagonalOperator, batchview, nbatches
using CUDA: CUBLAS, CUDA, CUSOLVER, CuArray, CuMatrix, CuPtr, CuVector, DenseCuArray,
DenseCuMatrix
using ConcreteStructs: @concrete
using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero, mul!

const CuBlasFloat = Union{Float16, Float32, Float64, ComplexF32, ComplexF64}

const CuUniformBlockDiagonalMatrix{T} = UniformBlockDiagonalMatrix{T, <:CuArray{T, 3}}
const CuUniformBlockDiagonalOperator{T} = UniformBlockDiagonalOperator{
T, <:CUDA.AnyCuArray{T, 3}}

include("low_level.jl")

include("batched_mul.jl")
include("factorization.jl")
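For orientation, a hedged sketch of the type this alias matches; it assumes `UniformBlockDiagonalOperator` is constructed from a 3-D array whose third axis indexes the diagonal blocks (consistent with the `UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))` call in the LinearSolve extension below):

```julia
using BatchedRoutines, CUDA

data = CUDA.rand(Float32, 4, 4, 16)        # 16 independent 4×4 blocks
op = UniformBlockDiagonalOperator(data)    # a CuUniformBlockDiagonalOperator{Float32}
nbatches(op)                               # 16 — one block per batch
```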
65 changes: 32 additions & 33 deletions ext/BatchedRoutinesCUDAExt/factorization.jl
@@ -28,14 +28,14 @@

for pT in (:RowMaximum, :RowNonZero, :NoPivot)
@eval begin
function LinearAlgebra.lu!(A::CuUniformBlockDiagonalMatrix, pivot::$pT; kwargs...)
function LinearAlgebra.lu!(A::CuUniformBlockDiagonalOperator, pivot::$pT; kwargs...)
return LinearAlgebra.lu!(A, !(pivot isa NoPivot); kwargs...)
end
end
end

function LinearAlgebra.lu!(
A::CuUniformBlockDiagonalMatrix, pivot::Bool=true; check::Bool=true, kwargs...)
A::CuUniformBlockDiagonalOperator, pivot::Bool=true; check::Bool=true, kwargs...)
pivot_array, info_, factors = CUBLAS.getrf_strided_batched!(A.data, pivot)
info = Array(info_)
check && LinearAlgebra.checknonsingular.(info)
@@ -82,11 +82,15 @@
return print(io, "CuBatchedQR() with Batch Count: $(nbatches(QR))")
end

function LinearAlgebra.qr!(::CuUniformBlockDiagonalMatrix, ::ColumnNorm; kwargs...)
function LinearAlgebra.qr!(A::CuUniformBlockDiagonalOperator; kwargs...)
return LinearAlgebra.qr!(A, NoPivot(); kwargs...)
end

function LinearAlgebra.qr!(::CuUniformBlockDiagonalOperator, ::ColumnNorm; kwargs...)

throw(ArgumentError("ColumnNorm is not supported for batched CUDA QR factorization!"))
end

function LinearAlgebra.qr!(A::CuUniformBlockDiagonalMatrix, ::NoPivot; kwargs...)
function LinearAlgebra.qr!(A::CuUniformBlockDiagonalOperator, ::NoPivot; kwargs...)
τ, factors = CUBLAS.geqrf_batched!(collect(batchview(A)))
return CuBatchedQR{eltype(A)}(factors, τ, size(A))
end
@@ -116,34 +120,29 @@
    return X
end

# Low Level Wrappers
for (fname, elty) in ((:cublasDgetrsBatched, :Float64), (:cublasSgetrsBatched, :Float32),
    (:cublasZgetrsBatched, :ComplexF64), (:cublasCgetrsBatched, :ComplexF32))
    @eval begin
        function getrs_batched!(trans::Char, n, nrhs, Aptrs::CuVector{CuPtr{$elty}},
                lda, p, Bptrs::CuVector{CuPtr{$elty}}, ldb)
            batchSize = length(Aptrs)
            info = Array{Cint}(undef, batchSize)
            CUBLAS.$fname(
                CUBLAS.handle(), trans, n, nrhs, Aptrs, lda, p, Bptrs, ldb, info, batchSize)
            CUDA.unsafe_free!(Aptrs)
            CUDA.unsafe_free!(Bptrs)
            return info
        end
    end
end

function getrs_strided_batched!(trans::Char, F::DenseCuArray{<:Any, 3}, p::DenseCuMatrix,
        B::Union{DenseCuArray{<:Any, 3}, DenseCuMatrix})
    m, n = size(F, 1), size(F, 2)
    m != n && throw(DimensionMismatch("All matrices must be square!"))
    lda = max(1, stride(F, 2))
    ldb = max(1, stride(B, 2))
    nrhs = ifelse(ndims(B) == 2, 1, size(B, 2))

    Fptrs = CUBLAS.unsafe_strided_batch(F)
    Bptrs = CUBLAS.unsafe_strided_batch(B)
    info = getrs_batched!(trans, n, nrhs, Fptrs, lda, p, Bptrs, ldb)

    return B, info
end

# Direct Ldiv
function BatchedRoutines.__internal_backslash(
        op::CuUniformBlockDiagonalOperator{T1}, b::AbstractMatrix{T2}) where {T1, T2}
    T = promote_type(T1, T2)
    return BatchedRoutines.__internal_backslash(T != T1 ? T.(op) : op, T != T2 ? T.(b) : b)
end

function BatchedRoutines.__internal_backslash(
        op::CuUniformBlockDiagonalOperator{T}, b::AbstractMatrix{T}) where {T}
    size(op, 1) != length(b) && throw(DimensionMismatch("size(op, 1) != length(b)"))
    x = similar(b, T, size(BatchedRoutines.getdata(op), 2), nbatches(op))
    m, n = size(op)
    if n < m # Underdetermined: LQ or QR with ColumnNorm
        error("Underdetermined systems are not supported yet! Please open an issue if you \
               care about this feature.")
    elseif n == m # Square: LU with Pivoting
        p, _, F = CUBLAS.getrf_strided_batched!(copy(BatchedRoutines.getdata(op)), true)
        copyto!(x, b)
        getrs_strided_batched!('N', F, p, x)
    else # Overdetermined: QR
        CUBLAS.gels_batched!('N', batchview(copy(BatchedRoutines.getdata(op))),
            [reshape(bᵢ, :, 1) for bᵢ in batchview(b)])
        copyto!(x, view(b, 1:n, :))
    end
    return x
end
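A hedged usage sketch of the solve path above; it assumes Base `\` on the operator is wired to `__internal_backslash` in the parent package (only the CUBLAS plumbing is visible in this diff):

```julia
using BatchedRoutines, CUDA, LinearAlgebra

A = UniformBlockDiagonalOperator(CUDA.rand(Float32, 4, 4, 8))  # 8 square 4×4 blocks
b = CUDA.rand(Float32, 4, 8)                                   # one RHS column per block

x = A \ b                    # square case: batched LU (getrf) + batched solve (getrs);
                             # the operator data is copied before factorization

F = LinearAlgebra.qr!(A)     # CuBatchedQR via geqrf_batched!; ColumnNorm pivoting throws
```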
31 changes: 31 additions & 0 deletions ext/BatchedRoutinesCUDAExt/low_level.jl
@@ -0,0 +1,31 @@
# Low Level Wrappers
for (fname, elty) in ((:cublasDgetrsBatched, :Float64), (:cublasSgetrsBatched, :Float32),
(:cublasZgetrsBatched, :ComplexF64), (:cublasCgetrsBatched, :ComplexF32))
@eval begin
function getrs_batched!(trans::Char, n, nrhs, Aptrs::CuVector{CuPtr{$elty}},
lda, p, Bptrs::CuVector{CuPtr{$elty}}, ldb)
batchSize = length(Aptrs)
info = Array{Cint}(undef, batchSize)
CUBLAS.$fname(
CUBLAS.handle(), trans, n, nrhs, Aptrs, lda, p, Bptrs, ldb, info, batchSize)
CUDA.unsafe_free!(Aptrs)
CUDA.unsafe_free!(Bptrs)
return info
end
end
end

function getrs_strided_batched!(trans::Char, F::DenseCuArray{<:Any, 3}, p::DenseCuMatrix,
B::Union{DenseCuArray{<:Any, 3}, DenseCuMatrix})
m, n = size(F, 1), size(F, 2)
m != n && throw(DimensionMismatch("All matrices must be square!"))
lda = max(1, stride(F, 2))
ldb = max(1, stride(B, 2))
nrhs = ifelse(ndims(B) == 2, 1, size(B, 2))

Fptrs = CUBLAS.unsafe_strided_batch(F)
Bptrs = CUBLAS.unsafe_strided_batch(B)
info = getrs_batched!(trans, n, nrhs, Fptrs, lda, p, Bptrs, ldb)

return B, info
end
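A hedged sketch of how these wrappers compose with the batched LU above; it would have to run inside the extension module, since `getrs_strided_batched!` is module-local (the `getrf_strided_batched!` call mirrors the one in `__internal_backslash`):

```julia
F   = CUDA.rand(Float32, 4, 4, 8)                   # 8 square 4×4 systems
rhs = CUDA.rand(Float32, 4, 8)                      # one RHS column per system

p, _, F = CUBLAS.getrf_strided_batched!(F, true)    # batched LU with pivoting
_, info = getrs_strided_batched!('N', F, p, rhs)    # overwrites rhs with the solutions
all(iszero, info)                                   # nonzero info flags an argument error
```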
24 changes: 24 additions & 0 deletions ext/BatchedRoutinesCUDALinearSolveExt.jl
@@ -0,0 +1,24 @@
module BatchedRoutinesCUDALinearSolveExt

using BatchedRoutines: UniformBlockDiagonalOperator, getdata
using CUDA: CUDA
using LinearAlgebra: LinearAlgebra
using LinearSolve: LinearSolve

const CuUniformBlockDiagonalOperator{T} = UniformBlockDiagonalOperator{
T, <:CUDA.AnyCuArray{T, 3}}

function LinearSolve.init_cacheval(
alg::LinearSolve.SVDFactorization, A::CuUniformBlockDiagonalOperator, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
return nothing
end

function LinearSolve.init_cacheval(
alg::LinearSolve.QRFactorization, A::CuUniformBlockDiagonalOperator, b, u, Pl, Pr,
maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
A_ = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
return LinearAlgebra.qr!(A_) # ignore the pivot since CUDA doesn't support it
end

end
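These overloads plug the operator into LinearSolve's cache machinery: SVD gets no cache, and QR pre-builds a batched factorization cache while dropping the pivot (CUDA's batched QR has none). A hedged sketch of the intended entry point; the problem setup is an assumption, not part of this diff:

```julia
using BatchedRoutines, CUDA, LinearSolve

A = UniformBlockDiagonalOperator(CUDA.rand(Float32, 4, 4, 8))
b = CUDA.rand(Float32, 4 * 8)                 # flattened right-hand side

prob = LinearProblem(A, b)
sol  = solve(prob, QRFactorization())         # routes to the non-pivoted batched CUDA QR
```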
12 changes: 12 additions & 0 deletions ext/BatchedRoutinesComponentArraysForwardDiffExt.jl
@@ -0,0 +1,12 @@
module BatchedRoutinesComponentArraysForwardDiffExt

using BatchedRoutines: BatchedRoutines
using ComponentArrays: ComponentArrays, ComponentArray
using ForwardDiff: ForwardDiff

@inline function BatchedRoutines._restructure(y, x::ComponentArray)
x_data = ComponentArrays.getdata(x)
return ComponentArray(reshape(y, size(x_data)), ComponentArrays.getaxes(x))

end

end
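`_restructure(y, x)` rebuilds a `ComponentArray` carrying `x`'s axes from the flat buffer `y`. A small sketch of the intended behavior (`_restructure` is internal; it is called directly here only for illustration):

```julia
using BatchedRoutines, ComponentArrays, ForwardDiff  # both triggers load this extension

x = ComponentArray(a = zeros(2), b = zeros(3))
y = collect(1.0:5.0)

z = BatchedRoutines._restructure(y, x)
z.a, z.b   # ([1.0, 2.0], [3.0, 4.0, 5.0]) — y's data under x's axes
```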
8 changes: 4 additions & 4 deletions ext/BatchedRoutinesFiniteDiffExt.jl
@@ -2,7 +2,7 @@ module BatchedRoutinesFiniteDiffExt

using ADTypes: AutoFiniteDiff
using ArrayInterface: matrix_colors, parameterless_type
using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalMatrix, _assert_type
using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, _assert_type
using FastClosures: @closure
using FiniteDiff: FiniteDiff

@@ -14,15 +14,15 @@ using FiniteDiff: FiniteDiff
ad::AutoFiniteDiff, f::F, x::AbstractVector{T}) where {F, T}
J = FiniteDiff.finite_difference_jacobian(f, x, ad.fdjtype)
(_assert_type(f) && _assert_type(x) && Base.issingletontype(F)) &&
(return UniformBlockDiagonalMatrix(J::parameterless_type(x){T, 2}))
return UniformBlockDiagonalMatrix(J)
(return UniformBlockDiagonalOperator(J::parameterless_type(x){T, 2}))
return UniformBlockDiagonalOperator(J)
end

@inline function BatchedRoutines._batched_jacobian(
ad::AutoFiniteDiff, f::F, x::AbstractMatrix) where {F}
f! = @closure (y, x_) -> copyto!(y, f(x_))
fx = f(x)
J = UniformBlockDiagonalMatrix(similar(
J = UniformBlockDiagonalOperator(similar(
x, promote_type(eltype(fx), eltype(x)), size(fx, 1), size(x, 1), size(x, 2)))
sparsecache = FiniteDiff.JacobianCache(
x, fx, ad.fdjtype; colorvec=matrix_colors(J), sparsity=J)
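A hedged usage example; it assumes the exported `batched_jacobian(ad, f, x)` mirrors the internal `_batched_jacobian` shown above, and that each column of `x` is one batch element:

```julia
using ADTypes, BatchedRoutines, FiniteDiff

f(x) = x .^ 2                 # elementwise, so each Jacobian block is diagonal
x = rand(3, 5)                # 3 states × 5 batch elements

J = batched_jacobian(AutoFiniteDiff(), f, x)  # UniformBlockDiagonalOperator of 3×3 blocks
```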
16 changes: 8 additions & 8 deletions ext/BatchedRoutinesForwardDiffExt.jl
@@ -2,7 +2,7 @@ module BatchedRoutinesForwardDiffExt

using ADTypes: AutoForwardDiff
using ArrayInterface: parameterless_type
using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalMatrix, batched_jacobian,
using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, batched_jacobian,
batched_mul, batched_pickchunksize, _assert_type
using ChainRulesCore: ChainRulesCore
using FastClosures: @closure
@@ -117,7 +117,7 @@ end
else
jac_call = :((y, J) = __batched_value_and_jacobian(ad, f, u, $(Val(CK))))
end
return Expr(:block, jac_call, :(return (y, UniformBlockDiagonalMatrix(J))))
return Expr(:block, jac_call, :(return (y, UniformBlockDiagonalOperator(J))))
end

## Exposed API
@@ -132,8 +132,8 @@ end
end
J = ForwardDiff.jacobian(f, u, cfg)
(_assert_type(f) && _assert_type(u) && Base.issingletontype(F)) &&
(return UniformBlockDiagonalMatrix(J::parameterless_type(u){T, 2}))
return UniformBlockDiagonalMatrix(J)
(return UniformBlockDiagonalOperator(J::parameterless_type(u){T, 2}))
return UniformBlockDiagonalOperator(J)
end

@inline function BatchedRoutines._batched_jacobian(
@@ -211,7 +211,7 @@ end
u_part_next = Dual.(u[idxs_next], dev(Partials.(map(nt, 1:length(idxs_next)))))
end

u_duals = reshape(vcat(u_part_prev, u_part_duals, u_part_next), size(u))
u_duals = BatchedRoutines._restructure(vcat(u_part_prev, u_part_duals, u_part_next), u)
y_duals = f(u_duals)

gs === nothing && return ForwardDiff.partials(y_duals)
Expand All @@ -224,20 +224,20 @@ Base.@assume_effects :total BatchedRoutines._assert_type(::Type{<:AbstractArray{

function BatchedRoutines._jacobian_vector_product(ad::AutoForwardDiff, f::F, x, u) where {F}
Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag)
x_dual = _construct_jvp_duals(Tag, x, u)
x_dual = BatchedRoutines._construct_jvp_duals(Tag, x, u)
y_dual = f(x_dual)
return ForwardDiff.partials.(y_dual, 1)
end

function BatchedRoutines._jacobian_vector_product(
ad::AutoForwardDiff, f::F, x, u, p) where {F}
Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag)
x_dual = _construct_jvp_duals(Tag, x, u)
x_dual = BatchedRoutines._construct_jvp_duals(Tag, x, u)
y_dual = f(x_dual, p)
return ForwardDiff.partials.(y_dual, 1)
end

@inline function _construct_jvp_duals(::Type{Tag}, x, u) where {Tag}
@inline function BatchedRoutines._construct_jvp_duals(::Type{Tag}, x, u) where {Tag}
T = promote_type(eltype(x), eltype(u))
partials = ForwardDiff.Partials{1, T}.(tuple.(u))
return ForwardDiff.Dual{Tag, T, 1}.(x, reshape(partials, size(x)))
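The dual-construction path now goes through `BatchedRoutines._restructure`, so structured inputs (e.g. `ComponentArray`s) survive the round trip. A hedged sketch of the ForwardDiff route, again assuming `batched_jacobian(ad, f, x)` is the public front end:

```julia
using ADTypes, BatchedRoutines, ForwardDiff

f(x) = tanh.(x)
x = rand(4, 16)                        # 4 states × 16 batch elements

ad = AutoForwardDiff(; chunksize = 4)  # chunk size ≤ per-batch state dimension
J  = batched_jacobian(ad, f, x)        # one 4×4 block per batch element
```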