diff --git a/Project.toml b/Project.toml index 185a324..e2d43b8 100644 --- a/Project.toml +++ b/Project.toml @@ -14,18 +14,24 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" [weakdeps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] BatchedRoutinesCUDAExt = ["CUDA"] +BatchedRoutinesComponentArraysForwardDiffExt = ["ComponentArrays", "ForwardDiff"] +BatchedRoutinesCUDALinearSolveExt = ["CUDA", "LinearSolve"] BatchedRoutinesFiniteDiffExt = ["FiniteDiff"] BatchedRoutinesForwardDiffExt = ["ForwardDiff"] +BatchedRoutinesLinearSolveExt = ["LinearSolve"] BatchedRoutinesReverseDiffExt = ["ReverseDiff"] BatchedRoutinesZygoteExt = ["Zygote"] @@ -53,6 +59,7 @@ PrecompileTools = "1.2.0" Random = "<0.0.1, 1" ReTestItems = "1.23.1" ReverseDiff = "1.15" +SciMLOperators = "0.3.8" StableRNGs = "1.0.1" Statistics = "1.11.1" Test = "<0.0.1, 1" diff --git a/ext/BatchedRoutinesCUDAExt/BatchedRoutinesCUDAExt.jl b/ext/BatchedRoutinesCUDAExt/BatchedRoutinesCUDAExt.jl index 9791553..8f088a3 100644 --- a/ext/BatchedRoutinesCUDAExt/BatchedRoutinesCUDAExt.jl +++ b/ext/BatchedRoutinesCUDAExt/BatchedRoutinesCUDAExt.jl @@ -1,7 +1,7 @@ module BatchedRoutinesCUDAExt using BatchedRoutines: AbstractBatchedMatrixFactorization, BatchedRoutines, - UniformBlockDiagonalMatrix, batchview, nbatches + UniformBlockDiagonalOperator, batchview, nbatches using CUDA: CUBLAS, CUDA, CUSOLVER, CuArray, CuMatrix, CuPtr, CuVector, DenseCuArray, DenseCuMatrix using ConcreteStructs: @concrete @@ -9,7 +9,10 @@ using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNo const CuBlasFloat = Union{Float16, Float32, Float64, ComplexF32, ComplexF64} -const CuUniformBlockDiagonalMatrix{T} = UniformBlockDiagonalMatrix{T, <:CuArray{T, 3}} +const CuUniformBlockDiagonalOperator{T} = UniformBlockDiagonalOperator{ + T, <:CUDA.AnyCuArray{T, 3}} + +include("low_level.jl") include("batched_mul.jl") include("factorization.jl") diff --git a/ext/BatchedRoutinesCUDAExt/factorization.jl b/ext/BatchedRoutinesCUDAExt/factorization.jl index c54f940..85db0ce 100644 --- a/ext/BatchedRoutinesCUDAExt/factorization.jl +++ b/ext/BatchedRoutinesCUDAExt/factorization.jl @@ -28,14 +28,14 @@ end for pT in (:RowMaximum, :RowNonZero, :NoPivot) @eval begin - function LinearAlgebra.lu!(A::CuUniformBlockDiagonalMatrix, pivot::$pT; kwargs...) + function LinearAlgebra.lu!(A::CuUniformBlockDiagonalOperator, pivot::$pT; kwargs...) return LinearAlgebra.lu!(A, !(pivot isa NoPivot); kwargs...) end end end function LinearAlgebra.lu!( - A::CuUniformBlockDiagonalMatrix, pivot::Bool=true; check::Bool=true, kwargs...) + A::CuUniformBlockDiagonalOperator, pivot::Bool=true; check::Bool=true, kwargs...) 
pivot_array, info_, factors = CUBLAS.getrf_strided_batched!(A.data, pivot) info = Array(info_) check && LinearAlgebra.checknonsingular.(info) @@ -82,11 +82,15 @@ function Base.show(io::IO, QR::CuBatchedQR) return print(io, "CuBatchedQR() with Batch Count: $(nbatches(QR))") end -function LinearAlgebra.qr!(::CuUniformBlockDiagonalMatrix, ::ColumnNorm; kwargs...) +function LinearAlgebra.qr!(A::CuUniformBlockDiagonalOperator; kwargs...) + return LinearAlgebra.qr!(A, NoPivot(); kwargs...) +end + +function LinearAlgebra.qr!(::CuUniformBlockDiagonalOperator, ::ColumnNorm; kwargs...) throw(ArgumentError("ColumnNorm is not supported for batched CUDA QR factorization!")) end -function LinearAlgebra.qr!(A::CuUniformBlockDiagonalMatrix, ::NoPivot; kwargs...) +function LinearAlgebra.qr!(A::CuUniformBlockDiagonalOperator, ::NoPivot; kwargs...) τ, factors = CUBLAS.geqrf_batched!(collect(batchview(A))) return CuBatchedQR{eltype(A)}(factors, τ, size(A)) end @@ -116,34 +120,29 @@ function LinearAlgebra.ldiv!(X::CuMatrix, A::CuBatchedQR, b::CuMatrix) return X end -# Low Level Wrappers -for (fname, elty) in ((:cublasDgetrsBatched, :Float64), (:cublasSgetrsBatched, :Float32), - (:cublasZgetrsBatched, :ComplexF64), (:cublasCgetrsBatched, :ComplexF32)) - @eval begin - function getrs_batched!(trans::Char, n, nrhs, Aptrs::CuVector{CuPtr{$elty}}, - lda, p, Bptrs::CuVector{CuPtr{$elty}}, ldb) - batchSize = length(Aptrs) - info = Array{Cint}(undef, batchSize) - CUBLAS.$fname( - CUBLAS.handle(), trans, n, nrhs, Aptrs, lda, p, Bptrs, ldb, info, batchSize) - CUDA.unsafe_free!(Aptrs) - CUDA.unsafe_free!(Bptrs) - return info - end +# Direct Ldiv +function BatchedRoutines.__internal_backslash( + op::CuUniformBlockDiagonalOperator{T1}, b::AbstractMatrix{T2}) where {T1, T2} + T = promote_type(T1, T2) + return __internal_backslash(T != T1 ? T.(op) : op, T != T2 ? T.(b) : b) +end + +function BatchedRoutines.__internal_backslash( + op::CuUniformBlockDiagonalOperator{T}, b::AbstractMatrix{T}) where {T} + size(op, 1) != length(b) && throw(DimensionMismatch("size(op, 1) != length(b)")) + x = similar(b, T, size(BatchedRoutines.getdata(op), 2), nbatches(op)) + m, n = size(op) + if n < m # Underdetermined: LQ or QR with ColumnNorm + error("Underdetermined systems are not supported yet! 
Please open an issue if you \ + care about this feature.") + elseif n == m # Square: LU with Pivoting + p, _, F = CUBLAS.getrf_strided_batched!(copy(BatchedRoutines.getdata(op)), true) + copyto!(x, b) + getrs_strided_batched!('N', F, p, x) + else # Overdetermined: QR + CUBLAS.gels_batched!('N', batchview(copy(BatchedRoutines.getdata(op))), + [reshape(bᵢ, :, 1) for bᵢ in batchview(b)]) + copyto!(x, view(b, 1:n, :)) end -end - -function getrs_strided_batched!(trans::Char, F::DenseCuArray{<:Any, 3}, p::DenseCuMatrix, - B::Union{DenseCuArray{<:Any, 3}, DenseCuMatrix}) - m, n = size(F, 1), size(F, 2) - m != n && throw(DimensionMismatch("All matrices must be square!")) - lda = max(1, stride(F, 2)) - ldb = max(1, stride(B, 2)) - nrhs = ifelse(ndims(B) == 2, 1, size(B, 2)) - - Fptrs = CUBLAS.unsafe_strided_batch(F) - Bptrs = CUBLAS.unsafe_strided_batch(B) - info = getrs_batched!(trans, n, nrhs, Fptrs, lda, p, Bptrs, ldb) - - return B, info + return x end diff --git a/ext/BatchedRoutinesCUDAExt/low_level.jl b/ext/BatchedRoutinesCUDAExt/low_level.jl new file mode 100644 index 0000000..a858e52 --- /dev/null +++ b/ext/BatchedRoutinesCUDAExt/low_level.jl @@ -0,0 +1,31 @@ +# Low Level Wrappers +for (fname, elty) in ((:cublasDgetrsBatched, :Float64), (:cublasSgetrsBatched, :Float32), + (:cublasZgetrsBatched, :ComplexF64), (:cublasCgetrsBatched, :ComplexF32)) + @eval begin + function getrs_batched!(trans::Char, n, nrhs, Aptrs::CuVector{CuPtr{$elty}}, + lda, p, Bptrs::CuVector{CuPtr{$elty}}, ldb) + batchSize = length(Aptrs) + info = Array{Cint}(undef, batchSize) + CUBLAS.$fname( + CUBLAS.handle(), trans, n, nrhs, Aptrs, lda, p, Bptrs, ldb, info, batchSize) + CUDA.unsafe_free!(Aptrs) + CUDA.unsafe_free!(Bptrs) + return info + end + end +end + +function getrs_strided_batched!(trans::Char, F::DenseCuArray{<:Any, 3}, p::DenseCuMatrix, + B::Union{DenseCuArray{<:Any, 3}, DenseCuMatrix}) + m, n = size(F, 1), size(F, 2) + m != n && throw(DimensionMismatch("All matrices must be square!")) + lda = max(1, stride(F, 2)) + ldb = max(1, stride(B, 2)) + nrhs = ifelse(ndims(B) == 2, 1, size(B, 2)) + + Fptrs = CUBLAS.unsafe_strided_batch(F) + Bptrs = CUBLAS.unsafe_strided_batch(B) + info = getrs_batched!(trans, n, nrhs, Fptrs, lda, p, Bptrs, ldb) + + return B, info +end diff --git a/ext/BatchedRoutinesCUDALinearSolveExt.jl b/ext/BatchedRoutinesCUDALinearSolveExt.jl new file mode 100644 index 0000000..d9da190 --- /dev/null +++ b/ext/BatchedRoutinesCUDALinearSolveExt.jl @@ -0,0 +1,24 @@ +module BatchedRoutinesCUDALinearSolveExt + +using BatchedRoutines: UniformBlockDiagonalOperator, getdata +using CUDA: CUDA +using LinearAlgebra: LinearAlgebra +using LinearSolve: LinearSolve + +const CuUniformBlockDiagonalOperator{T} = UniformBlockDiagonalOperator{ + T, <:CUDA.AnyCuArray{T, 3}} + +function LinearSolve.init_cacheval( + alg::LinearSolve.SVDFactorization, A::CuUniformBlockDiagonalOperator, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions) + return nothing +end + +function LinearSolve.init_cacheval( + alg::LinearSolve.QRFactorization, A::CuUniformBlockDiagonalOperator, b, u, Pl, Pr, + maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions) + A_ = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1)) + return LinearAlgebra.qr!(A_) # ignore the pivot since CUDA doesn't support it +end + +end diff --git a/ext/BatchedRoutinesComponentArraysForwardDiffExt.jl b/ext/BatchedRoutinesComponentArraysForwardDiffExt.jl new file mode 100644 index 0000000..e4c239b --- 
/dev/null +++ b/ext/BatchedRoutinesComponentArraysForwardDiffExt.jl @@ -0,0 +1,12 @@ +module BatchedRoutinesComponentArraysForwardDiffExt + +using BatchedRoutines: BatchedRoutines +using ComponentArrays: ComponentArrays, ComponentArray +using ForwardDiff: ForwardDiff + +@inline function BatchedRoutines._restructure(y, x::ComponentArray) + x_data = ComponentArrays.getdata(x) + return ComponentArray(reshape(y, size(x_data)), ComponentArrays.getaxes(x)) +end + +end diff --git a/ext/BatchedRoutinesFiniteDiffExt.jl b/ext/BatchedRoutinesFiniteDiffExt.jl index 6b0fe49..c95a0ac 100644 --- a/ext/BatchedRoutinesFiniteDiffExt.jl +++ b/ext/BatchedRoutinesFiniteDiffExt.jl @@ -2,7 +2,7 @@ module BatchedRoutinesFiniteDiffExt using ADTypes: AutoFiniteDiff using ArrayInterface: matrix_colors, parameterless_type -using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalMatrix, _assert_type +using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, _assert_type using FastClosures: @closure using FiniteDiff: FiniteDiff @@ -14,15 +14,15 @@ using FiniteDiff: FiniteDiff ad::AutoFiniteDiff, f::F, x::AbstractVector{T}) where {F, T} J = FiniteDiff.finite_difference_jacobian(f, x, ad.fdjtype) (_assert_type(f) && _assert_type(x) && Base.issingletontype(F)) && - (return UniformBlockDiagonalMatrix(J::parameterless_type(x){T, 2})) - return UniformBlockDiagonalMatrix(J) + (return UniformBlockDiagonalOperator(J::parameterless_type(x){T, 2})) + return UniformBlockDiagonalOperator(J) end @inline function BatchedRoutines._batched_jacobian( ad::AutoFiniteDiff, f::F, x::AbstractMatrix) where {F} f! = @closure (y, x_) -> copyto!(y, f(x_)) fx = f(x) - J = UniformBlockDiagonalMatrix(similar( + J = UniformBlockDiagonalOperator(similar( x, promote_type(eltype(fx), eltype(x)), size(fx, 1), size(x, 1), size(x, 2))) sparsecache = FiniteDiff.JacobianCache( x, fx, ad.fdjtype; colorvec=matrix_colors(J), sparsity=J) diff --git a/ext/BatchedRoutinesForwardDiffExt.jl b/ext/BatchedRoutinesForwardDiffExt.jl index b37405c..b5b3d3c 100644 --- a/ext/BatchedRoutinesForwardDiffExt.jl +++ b/ext/BatchedRoutinesForwardDiffExt.jl @@ -2,7 +2,7 @@ module BatchedRoutinesForwardDiffExt using ADTypes: AutoForwardDiff using ArrayInterface: parameterless_type -using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalMatrix, batched_jacobian, +using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, batched_jacobian, batched_mul, batched_pickchunksize, _assert_type using ChainRulesCore: ChainRulesCore using FastClosures: @closure @@ -117,7 +117,7 @@ end else jac_call = :((y, J) = __batched_value_and_jacobian(ad, f, u, $(Val(CK)))) end - return Expr(:block, jac_call, :(return (y, UniformBlockDiagonalMatrix(J)))) + return Expr(:block, jac_call, :(return (y, UniformBlockDiagonalOperator(J)))) end ## Exposed API @@ -132,8 +132,8 @@ end end J = ForwardDiff.jacobian(f, u, cfg) (_assert_type(f) && _assert_type(u) && Base.issingletontype(F)) && - (return UniformBlockDiagonalMatrix(J::parameterless_type(u){T, 2})) - return UniformBlockDiagonalMatrix(J) + (return UniformBlockDiagonalOperator(J::parameterless_type(u){T, 2})) + return UniformBlockDiagonalOperator(J) end @inline function BatchedRoutines._batched_jacobian( @@ -211,7 +211,7 @@ end u_part_next = Dual.(u[idxs_next], dev(Partials.(map(nt, 1:length(idxs_next))))) end - u_duals = reshape(vcat(u_part_prev, u_part_duals, u_part_next), size(u)) + u_duals = BatchedRoutines._restructure(vcat(u_part_prev, u_part_duals, u_part_next), u) y_duals = f(u_duals) gs === nothing && return 
ForwardDiff.partials(y_duals)
@@ -224,7 +224,7 @@ Base.@assume_effects :total BatchedRoutines._assert_type(::Type{<:AbstractArray{
 function BatchedRoutines._jacobian_vector_product(ad::AutoForwardDiff, f::F, x, u) where {F}
     Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag)
-    x_dual = _construct_jvp_duals(Tag, x, u)
+    x_dual = BatchedRoutines._construct_jvp_duals(Tag, x, u)
     y_dual = f(x_dual)
     return ForwardDiff.partials.(y_dual, 1)
 end
@@ -232,12 +232,12 @@ end
 function BatchedRoutines._jacobian_vector_product(
         ad::AutoForwardDiff, f::F, x, u, p) where {F}
     Tag = ad.tag === nothing ? typeof(ForwardDiff.Tag(f, eltype(x))) : typeof(ad.tag)
-    x_dual = _construct_jvp_duals(Tag, x, u)
+    x_dual = BatchedRoutines._construct_jvp_duals(Tag, x, u)
     y_dual = f(x_dual, p)
     return ForwardDiff.partials.(y_dual, 1)
 end
-@inline function _construct_jvp_duals(::Type{Tag}, x, u) where {Tag}
+@inline function BatchedRoutines._construct_jvp_duals(::Type{Tag}, x, u) where {Tag}
     T = promote_type(eltype(x), eltype(u))
     partials = ForwardDiff.Partials{1, T}.(tuple.(u))
     return ForwardDiff.Dual{Tag, T, 1}.(x, reshape(partials, size(x)))
diff --git a/ext/BatchedRoutinesLinearSolveExt.jl b/ext/BatchedRoutinesLinearSolveExt.jl
new file mode 100644
index 0000000..42b2808
--- /dev/null
+++ b/ext/BatchedRoutinesLinearSolveExt.jl
@@ -0,0 +1,199 @@
+module BatchedRoutinesLinearSolveExt
+
+using ArrayInterface: ArrayInterface
+using BatchedRoutines: BatchedRoutines, UniformBlockDiagonalOperator, getdata
+using ChainRulesCore: ChainRulesCore, NoTangent
+using FastClosures: @closure
+using LinearAlgebra: LinearAlgebra
+using LinearSolve: LinearSolve, SciMLBase
+
+const CRC = ChainRulesCore
+
+# Overload LinearProblem construction, otherwise the adjoint code runs into problems
+function LinearSolve.LinearProblem(op::UniformBlockDiagonalOperator, b, args...; kwargs...)
+    return LinearSolve.LinearProblem{true}(op, b, args...; kwargs...)
+end
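For orientation, a minimal sketch of the entry point this extension wires up (the 4×4×2 sizes and the choice of `LUFactorization` are illustrative, not taken from the diff):

```julia
using BatchedRoutines, LinearSolve

# Two independent 4×4 blocks stored as a 4×4×2 array; the operator acts like
# the corresponding 8×8 block-diagonal matrix.
op = UniformBlockDiagonalOperator(rand(4, 4, 2))
b = rand(size(op, 1))

# Routed through the `LinearProblem` overload above and then through the
# batched `init_cacheval`/`do_factorization` methods defined below.
prob = LinearProblem(op, b)
sol = solve(prob, LUFactorization())
sol.u  # stacked solution, one block solve per batch
```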
+
+# Default Algorithm
+function LinearSolve.defaultalg(
+        op::UniformBlockDiagonalOperator, b, assump::LinearSolve.OperatorAssumptions{Bool})
+    alg = if assump.issq
+        LinearSolve.DefaultAlgorithmChoice.LUFactorization
+    elseif assump.condition === LinearSolve.OperatorCondition.WellConditioned
+        LinearSolve.DefaultAlgorithmChoice.NormalCholeskyFactorization
+    elseif assump.condition === LinearSolve.OperatorCondition.IllConditioned
+        if LinearSolve.is_underdetermined(op)
+            LinearSolve.DefaultAlgorithmChoice.QRFactorizationPivoted
+        else
+            LinearSolve.DefaultAlgorithmChoice.QRFactorization
+        end
+    elseif assump.condition === LinearSolve.OperatorCondition.VeryIllConditioned
+        if LinearSolve.is_underdetermined(op)
+            LinearSolve.DefaultAlgorithmChoice.QRFactorizationPivoted
+        else
+            LinearSolve.DefaultAlgorithmChoice.QRFactorization
+        end
+    elseif assump.condition === LinearSolve.OperatorCondition.SuperIllConditioned
+        LinearSolve.DefaultAlgorithmChoice.SVDFactorization
+    else
+        error("Special factorization not handled in current default algorithm.")
+    end
+    return LinearSolve.DefaultLinearSolver(alg)
+end
+
+# GenericLUFactorization
+function LinearSolve.init_cacheval(alg::LinearSolve.GenericLUFactorization,
+        A::UniformBlockDiagonalOperator, b, u, Pl, Pr, maxiters::Int,
+        abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
+    A_ = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
+    return LinearAlgebra.generic_lufact!(A_, alg.pivot; check=false)
+end
+
+function LinearSolve.do_factorization(
+        alg::LinearSolve.GenericLUFactorization, A::UniformBlockDiagonalOperator, b, u)
+    return LinearAlgebra.generic_lufact!(A, alg.pivot; check=false)
+end
+
+# LUFactorization
+function LinearSolve.init_cacheval(
+        alg::LinearSolve.LUFactorization, A::UniformBlockDiagonalOperator, b, u, Pl, Pr,
+        maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
+    A_ = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
+    return LinearAlgebra.lu!(A_, alg.pivot; check=false)
+end
+
+function LinearSolve.do_factorization(
+        alg::LinearSolve.LUFactorization, A::UniformBlockDiagonalOperator, b, u)
+    return LinearAlgebra.lu!(A, alg.pivot; check=false)
+end
+
+# QRFactorization
+function LinearSolve.init_cacheval(
+        alg::LinearSolve.QRFactorization, A::UniformBlockDiagonalOperator, b, u, Pl, Pr,
+        maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
+    A_ = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
+    return LinearAlgebra.qr!(A_, alg.pivot)
+end
+
+function LinearSolve.do_factorization(
+        alg::LinearSolve.QRFactorization, A::UniformBlockDiagonalOperator, b, u)
+    alg.inplace && return LinearAlgebra.qr!(A, alg.pivot)
+    return LinearAlgebra.qr(A, alg.pivot)
+end
+
+# CholeskyFactorization
+function LinearSolve.init_cacheval(alg::LinearSolve.CholeskyFactorization,
+        A::UniformBlockDiagonalOperator, b, u, Pl, Pr, maxiters::Int,
+        abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
+    A_ = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
+    return ArrayInterface.cholesky_instance(A_, alg.pivot)
+end
+
+function LinearSolve.do_factorization(
+        alg::LinearSolve.CholeskyFactorization, A::UniformBlockDiagonalOperator, b, u)
+    return LinearAlgebra.cholesky!(A, alg.pivot; check=false)
+end
+
+# NormalCholeskyFactorization
+function LinearSolve.init_cacheval(alg::LinearSolve.NormalCholeskyFactorization,
+        A::UniformBlockDiagonalOperator, b, u, Pl, Pr, maxiters::Int,
+        abstol, reltol, verbose::Bool,
::LinearSolve.OperatorAssumptions)
+    A_ = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
+    return ArrayInterface.cholesky_instance(A_, alg.pivot)
+end
+
+function LinearSolve.solve!(cache::LinearSolve.LinearCache{<:UniformBlockDiagonalOperator},
+        alg::LinearSolve.NormalCholeskyFactorization; kwargs...)
+    A = cache.A
+    if cache.isfresh
+        fact = LinearAlgebra.cholesky!(A' * A, alg.pivot; check=false)
+        cache.cacheval = fact
+        cache.isfresh = false
+    end
+    y = LinearAlgebra.ldiv!(
+        cache.u, LinearSolve.@get_cacheval(cache, :NormalCholeskyFactorization),
+        A' * cache.b)
+    return LinearSolve.SciMLBase.build_linear_solution(alg, y, nothing, cache)
+end
+
+# SVDFactorization
+function LinearSolve.init_cacheval(
+        alg::LinearSolve.SVDFactorization, A::UniformBlockDiagonalOperator, b, u, Pl, Pr,
+        maxiters::Int, abstol, reltol, verbose::Bool, ::LinearSolve.OperatorAssumptions)
+    A_ = UniformBlockDiagonalOperator(similar(getdata(A), 0, 0, 1))
+    return ArrayInterface.svd_instance(A_)
+end
+
+function LinearSolve.do_factorization(
+        alg::LinearSolve.SVDFactorization, A::UniformBlockDiagonalOperator, b, u)
+    return LinearAlgebra.svd!(A; alg.full, alg.alg)
+end
+
+# We need a custom rrule here to prevent spurious gradients for zero blocks
+# Copied from https://github.com/SciML/LinearSolve.jl/blob/7911113c6b14b6897cc356e277ccd5a98faa7dd7/src/adjoint.jl#L31 except the Lazy Arrays part
+function CRC.rrule(::typeof(SciMLBase.solve),
+        prob::SciMLBase.LinearProblem{T1, T2, <:UniformBlockDiagonalOperator},
+        alg::LinearSolve.SciMLLinearSolveAlgorithm, args...;
+        alias_A=LinearSolve.default_alias_A(alg, prob.A, prob.b), kwargs...) where {T1, T2}
+    cache = SciMLBase.init(prob, alg, args...; kwargs...)
+    (; A, sensealg) = cache
+
+    @assert sensealg isa LinearSolve.LinearSolveAdjoint "Currently only `LinearSolveAdjoint` is supported for adjoint sensitivity analysis."
+
+    # Decide if we need to cache `A` and `b` for the reverse pass
+    A_ = A
+    if sensealg.linsolve === missing
+        # We can reuse the factorization so no copy is needed
+        # Krylov methods don't modify `A`, so it's safe to just reuse it
+        # No copy is needed even for the default case
+        if !(alg isa LinearSolve.AbstractFactorization ||
+             alg isa LinearSolve.AbstractKrylovSubspaceMethod ||
+             alg isa LinearSolve.DefaultLinearSolver)
+            A_ = alias_A ?
deepcopy(A) : A + end + else + A_ = deepcopy(A) + end + + sol = SciMLBase.solve!(cache) + + proj_A = CRC.ProjectTo(getdata(A)) + proj_b = CRC.ProjectTo(prob.b) + + ∇linear_solve = @closure ∂sol -> begin + ∂u = ∂sol.u + if sensealg.linsolve === missing + λ = if cache.cacheval isa LinearAlgebra.Factorization + cache.cacheval' \ ∂u + elseif cache.cacheval isa Tuple && + cache.cacheval[1] isa LinearAlgebra.Factorization + first(cache.cacheval)' \ ∂u + elseif alg isa LinearSolve.AbstractKrylovSubspaceMethod + invprob = SciMLBase.LinearProblem(transpose(cache.A), ∂u) + SciMLBase.solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u + elseif alg isa LinearSolve.DefaultLinearSolver + LinearSolve.defaultalg_adjoint_eval(cache, ∂u) + else + invprob = SciMLBase.LinearProblem(transpose(A_), ∂u) # We cached `A` + SciMLBase.solve(invprob, alg; cache.abstol, cache.reltol, cache.verbose).u + end + else + invprob = SciMLBase.LinearProblem(transpose(A_), ∂u) # We cached `A` + λ = SciMLBase.solve( + invprob, sensealg.linsolve; cache.abstol, cache.reltol, cache.verbose).u + end + + uᵀ = reshape(sol.u, 1, :, BatchedRoutines.nbatches(A)) + ∂A = UniformBlockDiagonalOperator(proj_A(BatchedRoutines.batched_mul( + reshape(λ, :, 1, BatchedRoutines.nbatches(A)), -uᵀ))) + ∂b = proj_b(λ) + ∂prob = SciMLBase.LinearProblem(∂A, ∂b, NoTangent()) + + return ( + NoTangent(), ∂prob, NoTangent(), ntuple(Returns(NoTangent()), length(args))...) + end + + return sol, ∇linear_solve +end + +end diff --git a/src/BatchedRoutines.jl b/src/BatchedRoutines.jl index 3e10f60..692b71c 100644 --- a/src/BatchedRoutines.jl +++ b/src/BatchedRoutines.jl @@ -14,6 +14,22 @@ import PrecompileTools: @recompile_invalidations using LinearAlgebra: BLAS, ColumnNorm, LinearAlgebra, NoPivot, RowMaximum, RowNonZero, mul!, pinv using LuxDeviceUtils: LuxDeviceUtils, get_device + using SciMLOperators: SciMLOperators, AbstractSciMLOperator +end + +function __init__() + @static if isdefined(Base.Experimental, :register_error_hint) + Base.Experimental.register_error_hint(MethodError) do io, exc, argtypes, kwargs + if any(Base.Fix2(isa, UniformBlockDiagonalOperator), exc.args) + printstyled(io, "\nHINT: "; bold=true) + printstyled( + io, "`UniformBlockDiagonalOperator` doesn't support AbstractArray \ + operations. If you want this supported, open an issue at \ + https://github.com/LuxDL/BatchedRoutines.jl to discuss it."; + color=:cyan) + end + end + end end const CRC = ChainRulesCore @@ -28,7 +44,9 @@ const BatchedMatrix{T} = AbstractArray{T, 3} include("api.jl") include("helpers.jl") -include("matrix.jl") + +include("operator.jl") +include("factorization.jl") include("impl/batched_mul.jl") include("impl/batched_gmres.jl") @@ -39,9 +57,6 @@ export AutoFiniteDiff, AutoForwardDiff, AutoReverseDiff, AutoZygote export batched_adjoint, batched_gradient, batched_jacobian, batched_pickchunksize, batched_mul, batched_pinv, batched_transpose export batchview, nbatches -export UniformBlockDiagonalMatrix - -# TODO: Ship a custom GMRES routine & if needed some of the other complex nonlinear solve -# routines +export UniformBlockDiagonalOperator end diff --git a/src/api.jl b/src/api.jl index b3fc107..c68c3dd 100644 --- a/src/api.jl +++ b/src/api.jl @@ -3,7 +3,7 @@ batched_jacobian(ad, f::F, x, p) where {F} Use the backend `ad` to compute the Jacobian of `f` at `x` in batched mode. Returns a -[`UniformBlockDiagonalMatrix`](@ref) as the Jacobian. +[`UniformBlockDiagonalOperator`](@ref) as the Jacobian. !!! 
warning @@ -63,6 +63,7 @@ batched_mul!(C, A, B, α=true, β=false) = _batched_mul!(C, A, B, α, β) Transpose the first two dimensions of `X`. """ batched_transpose(X::BatchedMatrix) = PermutedDimsArray(X, (2, 1, 3)) +batched_transpose(X::AbstractMatrix) = reshape(X, 1, size(X, 1), size(X, 2)) """ batched_adjoint(X::AbstractArray{T, 3}) where {T} @@ -101,7 +102,7 @@ batchview(A::AbstractVector{T}) where {T} = isbitstype(T) ? (A,) : A """ batched_pinv(A::AbstractArray{T, 3}) where {T} - batched_pinv(A::UniformBlockDiagonalMatrix) + batched_pinv(A::UniformBlockDiagonalOperator) Compute the pseudo-inverse of `A` in batched mode. """ @@ -109,7 +110,7 @@ Compute the pseudo-inverse of `A` in batched mode. """ batched_inv(A::AbstractArray{T, 3}) where {T} - batched_inv(A::UniformBlockDiagonalMatrix) + batched_inv(A::UniformBlockDiagonalOperator) Compute the inverse of `A` in batched mode. """ diff --git a/src/chainrules.jl b/src/chainrules.jl index 4bd7d2c..7b7655c 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -10,7 +10,7 @@ function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x::AbstractMatrix) wher gradient_ad = AutoZygote() _map_fnₓ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(), x -> batched_gradient(gradient_ad, x_ -> sum(vec(f(x_))[i:i]), x), x, Δᵢ) - ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(_eachrow(Δ))), size(x)) + ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(eachrow(Δ))), size(x)) return NoTangent(), NoTangent(), NoTangent(), ∂x end return J, ∇batched_jacobian @@ -28,13 +28,13 @@ function CRC.rrule(::typeof(batched_jacobian), ad, f::F, x, p) where {F} _map_fnₓ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(), x -> batched_gradient(AutoZygote(), x_ -> sum(vec(f(x_, p))[i:i]), x), x, Δᵢ) - ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(_eachrow(Δ))), size(x)) + ∂x = reshape(mapreduce(_map_fnₓ, +, enumerate(eachrow(Δ))), size(x)) _map_fnₚ = ((i, Δᵢ),) -> _jacobian_vector_product(AutoForwardDiff(), (x, p_) -> batched_gradient(AutoZygote(), p__ -> sum(vec(f(x, p__))[i:i]), p_), x, Δᵢ, p) - ∂p = reshape(mapreduce(_map_fnₚ, +, enumerate(_eachrow(Δ))), size(p)) + ∂p = reshape(mapreduce(_map_fnₚ, +, enumerate(eachrow(Δ))), size(p)) return NoTangent(), NoTangent(), NoTangent(), ∂x, ∂p end @@ -86,55 +86,75 @@ end # batched_mul rrule function CRC.rrule(::typeof(_batched_mul), A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2} - function ∇batched_mul(_Δ) + ∇batched_mul = @closure _Δ -> begin Δ = CRC.unthunk(_Δ) - ∂A = CRC.@thunk begin - tmp = batched_mul(Δ, batched_adjoint(B)) - size(A, 3) == 1 ? sum(tmp; dims=3) : tmp - end - ∂B = CRC.@thunk begin - tmp = batched_mul(batched_adjoint(A), Δ) - size(B, 3) == 1 ? sum(tmp; dims=3) : tmp - end + tmpA = batched_mul(Δ, batched_adjoint(B)) + ∂A = size(A, 3) == 1 ? sum(tmpA; dims=3) : tmpA + tmpB = batched_mul(batched_adjoint(A), Δ) + ∂B = size(B, 3) == 1 ? sum(tmpB; dims=3) : tmpB return (NoTangent(), ∂A, ∂B) end return batched_mul(A, B), ∇batched_mul end -function CRC.rrule(::typeof(*), X::UniformBlockDiagonalMatrix{<:Union{Real, Complex}}, - Y::AbstractMatrix{<:Union{Real, Complex}}) - function ∇times(_Δ) - Δ = CRC.unthunk(_Δ) - ∂X = CRC.@thunk(Δ*batched_adjoint(batched_reshape(Y, :, 1))) - ∂Y = CRC.@thunk begin - res = (X' * Δ) - Y isa UniformBlockDiagonalMatrix ? 
res : dropdims(res.data; dims=2)
-        end
-        return (NoTangent(), ∂X, ∂Y)
+# constructor
+function CRC.rrule(::Type{<:UniformBlockDiagonalOperator}, data)
+    ∇UniformBlockDiagonalOperator = @closure Δ -> begin
+        ∂data = Δ isa UniformBlockDiagonalOperator ? getdata(Δ) :
+                (Δ isa NoTangent ? NoTangent() : Δ)
+        return (NoTangent(), ∂data)
     end
-    return X * Y, ∇times
+    return UniformBlockDiagonalOperator(data), ∇UniformBlockDiagonalOperator
 end

-function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(*),
-        X::AbstractMatrix{<:Union{Real, Complex}},
-        Y::UniformBlockDiagonalMatrix{<:Union{Real, Complex}})
-    _f = @closure (x, y) -> dropdims(
-        batched_mul(reshape(x, :, 1, nbatches(x)), y.data); dims=1)
-    return CRC.rrule_via_ad(cfg, _f, X, Y)
+function CRC.rrule(::typeof(getproperty), op::UniformBlockDiagonalOperator, x::Symbol)
+    @assert x === :data
+    ∇getproperty = @closure Δ -> (NoTangent(), UniformBlockDiagonalOperator(Δ))
+    return op.data, ∇getproperty
 end

-# constructor
-function CRC.rrule(::Type{<:UniformBlockDiagonalMatrix}, data)
-    function ∇UniformBlockDiagonalMatrix(Δ)
-        ∂data = Δ isa UniformBlockDiagonalMatrix ? Δ.data :
-                (Δ isa NoTangent ? NoTangent() : Δ)
-        return (NoTangent(), ∂data)
+# mapreduce fallback rules for UniformBlockDiagonalOperator
+@inline _unsum(x, dy, dims) = broadcast(last ∘ tuple, x, dy)
+@inline _unsum(x, dy, ::Colon) = broadcast(last ∘ tuple, x, Ref(dy))
+
+function CRC.rrule(::typeof(sum), ::typeof(abs2), op::UniformBlockDiagonalOperator{T};
+        dims=:) where {T <: Union{Real, Complex}}
+    y = sum(abs2, op; dims)
+    ∇sum_abs2 = @closure Δ -> begin
+        ∂op = if dims isa Colon
+            UniformBlockDiagonalOperator(2 .* real.(Δ) .* getdata(op))
+        else
+            UniformBlockDiagonalOperator(2 .* real.(getdata(Δ)) .* getdata(op))
+        end
+        return NoTangent(), NoTangent(), ∂op
     end
-    return UniformBlockDiagonalMatrix(data), ∇UniformBlockDiagonalMatrix
+    return y, ∇sum_abs2
 end

-function CRC.rrule(::typeof(getproperty), A::UniformBlockDiagonalMatrix, x::Symbol)
-    @assert x === :data
-    ∇getproperty(Δ) = (NoTangent(), UniformBlockDiagonalMatrix(Δ))
-    return A.data, ∇getproperty
+function CRC.rrule(::typeof(sum), ::typeof(identity), op::UniformBlockDiagonalOperator{T};
+        dims=:) where {T <: Union{Real, Complex}}
+    y = sum(identity, op; dims)
+    project = CRC.ProjectTo(getdata(op))
+    ∇sum_identity = @closure Δ -> begin
+        ∂op = project(_unsum(getdata(op), getdata(Δ), dims))
+        return NoTangent(), NoTangent(), UniformBlockDiagonalOperator(∂op)
+    end
+    return y, ∇sum_identity
+end
+
+# Direct Ldiv
+function CRC.rrule(cfg::CRC.RuleConfig{>:CRC.HasReverseMode}, ::typeof(\),
+        op::UniformBlockDiagonalOperator, b::AbstractMatrix)
+    # We haven't implemented the rrule for least squares yet, so we AD directly through the code
+    size(op, 1) != size(op, 2) && return CRC.rrule_via_ad(cfg, __internal_backslash, op, b)
+    # TODO: reuse the factorization once `factorize(op)` has been implemented
+    u = op \ b
+    proj_A = CRC.ProjectTo(getdata(op))
+    proj_b = CRC.ProjectTo(b)
+    ∇backslash = @closure ∂u -> begin
+        λ = op' \ ∂u
+        ∂A = -batched_mul(λ, batched_adjoint(reshape(u, :, 1, nbatches(u))))
+        return NoTangent(), UniformBlockDiagonalOperator(proj_A(∂A)), proj_b(λ)
+    end
+    return u, ∇backslash
 end
diff --git a/src/factorization.jl b/src/factorization.jl
new file mode 100644
index 0000000..bc0561c
--- /dev/null
+++ b/src/factorization.jl
@@ -0,0 +1,116 @@
+abstract type AbstractBatchedMatrixFactorization{T} <: LinearAlgebra.Factorization{T} end
+
+const AdjointAbstractBatchedMatrixFactorization{T} =
LinearAlgebra.AdjointFactorization{ + T, <:AbstractBatchedMatrixFactorization{T}} +const TransposeAbstractBatchedMatrixFactorization{T} = LinearAlgebra.TransposeFactorization{ + T, <:AbstractBatchedMatrixFactorization{T}} +const AdjointOrTransposeAbstractBatchedMatrixFactorization{T} = Union{ + AdjointAbstractBatchedMatrixFactorization{T}, + TransposeAbstractBatchedMatrixFactorization{T}} + +const AllAbstractBatchedMatrixFactorization{T} = Union{ + AbstractBatchedMatrixFactorization{T}, + AdjointOrTransposeAbstractBatchedMatrixFactorization{T}} + +nbatches(f::AdjointOrTransposeAbstractBatchedMatrixFactorization) = nbatches(parent(f)) +batchview(f::AdjointOrTransposeAbstractBatchedMatrixFactorization) = batchview(parent(f)) +function batchview(f::AdjointOrTransposeAbstractBatchedMatrixFactorization, idx::Int) + return batchview(parent(f), idx) +end + +# First we take inputs and standardize them +function LinearAlgebra.ldiv!(A::AllAbstractBatchedMatrixFactorization, b::AbstractVector) + LinearAlgebra.ldiv!(A, reshape(b, :, nbatches(A))) + return b +end + +function LinearAlgebra.ldiv!( + X::AbstractVector, A::AllAbstractBatchedMatrixFactorization, b::AbstractVector) + LinearAlgebra.ldiv!(reshape(X, :, nbatches(A)), A, reshape(b, :, nbatches(A))) + return X +end + +function Base.:\(A::AllAbstractBatchedMatrixFactorization, b::AbstractVector) + X = similar(b, promote_type(eltype(A), eltype(b)), size(A, 1)) + LinearAlgebra.ldiv!(X, A, b) + return X +end + +function Base.:\(A::AllAbstractBatchedMatrixFactorization, b::AbstractMatrix) + X = similar(b, promote_type(eltype(A), eltype(b)), size(A, 1)) + LinearAlgebra.ldiv!(X, A, vec(b)) + return reshape(X, :, nbatches(b)) +end + +# Now we implement the actual factorizations +## This just loops over the batches and calls the factorization on each, mostly used where +## we don't have native batched factorizations +struct GenericBatchedFactorization{T, A, F} <: AbstractBatchedMatrixFactorization{T} + alg::A + fact::Vector{F} + + function GenericBatchedFactorization(alg, fact) + return GenericBatchedFactorization{eltype(first(fact))}(alg, fact) + end + + function GenericBatchedFactorization{T}(alg::A, fact::Vector{F}) where {T, A, F} + return new{T, A, F}(alg, fact) + end +end + +nbatches(F::GenericBatchedFactorization) = length(F.fact) +batchview(F::GenericBatchedFactorization) = F.fact +batchview(F::GenericBatchedFactorization, idx::Int) = F.fact[idx] +Base.size(F::GenericBatchedFactorization) = size(first(F.fact)) .* length(F.fact) +function Base.size(F::GenericBatchedFactorization, i::Integer) + return size(first(F.fact), i) * length(F.fact) +end + +function LinearAlgebra.issuccess(fact::GenericBatchedFactorization) + return all(LinearAlgebra.issuccess, fact.fact) +end + +function Base.adjoint(fact::GenericBatchedFactorization{T}) where {T} + return GenericBatchedFactorization{T}(fact.alg, adjoint.(fact.fact)) +end + +function Base.show(io::IO, mime::MIME"text/plain", F::GenericBatchedFactorization) + println(io, "GenericBatchedFactorization() with Batch Count: $(nbatches(F))") + Base.printstyled(io, "Factorization Function: "; color=:green) + show(io, mime, F.alg) + Base.printstyled(io, "\nPrototype Factorization: "; color=:green) + show(io, mime, first(F.fact)) +end + +for fact in (:qr, :lu, :cholesky, :generic_lufact, :svd) + fact! = Symbol(fact, :!) + if isdefined(LinearAlgebra, fact) + @eval function LinearAlgebra.$(fact)( + op::UniformBlockDiagonalOperator, args...; kwargs...) + return LinearAlgebra.$(fact!)(copy(op), args...; kwargs...) 
+ end + end + + @eval function LinearAlgebra.$(fact!)( + op::UniformBlockDiagonalOperator, args...; kwargs...) + fact = map(Aᵢ -> LinearAlgebra.$(fact!)(Aᵢ, args...; kwargs...), batchview(op)) + return GenericBatchedFactorization(LinearAlgebra.$(fact!), fact) + end +end + +function LinearAlgebra.ldiv!(A::GenericBatchedFactorization, b::AbstractMatrix) + @assert nbatches(A) == nbatches(b) + for i in 1:nbatches(A) + LinearAlgebra.ldiv!(batchview(A, i), batchview(b, i)) + end + return b +end + +function LinearAlgebra.ldiv!( + X::AbstractMatrix, A::GenericBatchedFactorization, b::AbstractMatrix) + @assert nbatches(A) == nbatches(b) == nbatches(X) + for i in 1:nbatches(A) + LinearAlgebra.ldiv!(batchview(X, i), batchview(A, i), batchview(b, i)) + end + return X +end diff --git a/src/helpers.jl b/src/helpers.jl index 7cc8c89..fd071af 100644 --- a/src/helpers.jl +++ b/src/helpers.jl @@ -119,6 +119,9 @@ function _jacobian_vector_product end function _vector_jacobian_product end function _batched_jacobian end function _batched_gradient end +function _construct_jvp_duals end + +@inline _restructure(y, x) = reshape(y, size(x)) # Test Loaded AD Backend _assert_loaded_backend(::AutoForwardDiff) = @assert _is_extension_loaded(Val(:ForwardDiff)) @@ -148,9 +151,6 @@ end return promote_type(T, eltype(f.x)), false end -# eachrow override -@inline _eachrow(X) = eachrow(X) - # MLUtils.jl has too many unwanted dependencies @inline fill_like(x::AbstractArray, v, ::Type{T}, dims...) where {T} = fill!( similar(x, T, dims...), v) diff --git a/src/matrix.jl b/src/matrix.jl deleted file mode 100644 index c7e4ee2..0000000 --- a/src/matrix.jl +++ /dev/null @@ -1,431 +0,0 @@ -struct UniformBlockDiagonalMatrix{T, D <: AbstractArray{T, 3}} <: AbstractMatrix{T} - data::D -end - -function UniformBlockDiagonalMatrix(X::AbstractMatrix) - return UniformBlockDiagonalMatrix(reshape(X, size(X, 1), size(X, 2), 1)) -end - -nbatches(A::UniformBlockDiagonalMatrix) = size(A.data, 3) -batchview(A::UniformBlockDiagonalMatrix) = batchview(A.data) -batchview(A::UniformBlockDiagonalMatrix, i::Int) = batchview(A.data, i) - -function batched_mul(A::UniformBlockDiagonalMatrix, B::UniformBlockDiagonalMatrix) - return UniformBlockDiagonalMatrix(batched_mul(A.data, B.data)) -end - -function batched_transpose(X::UniformBlockDiagonalMatrix) - return UniformBlockDiagonalMatrix(batched_transpose(X.data)) -end - -# To support ReverseDiff -Base.IndexStyle(::Type{<:UniformBlockDiagonalMatrix}) = IndexLinear() - -Base.transpose(A::UniformBlockDiagonalMatrix) = batched_transpose(A) - -function batched_adjoint(X::UniformBlockDiagonalMatrix) - return UniformBlockDiagonalMatrix(batched_adjoint(X.data)) -end - -Base.adjoint(A::UniformBlockDiagonalMatrix) = batched_adjoint(A) - -function batched_pinv(A::UniformBlockDiagonalMatrix) - return UniformBlockDiagonalMatrix(batched_pinv(A.data)) -end -function batched_inv(A::UniformBlockDiagonalMatrix) - return UniformBlockDiagonalMatrix(batched_inv(A.data)) -end - -function batched_reshape(A::UniformBlockDiagonalMatrix, dims...) 
- return UniformBlockDiagonalMatrix(batched_reshape(A.data, dims...)) -end - -# Adapt -function Adapt.adapt_structure(to, x::UniformBlockDiagonalMatrix) - return UniformBlockDiagonalMatrix(Adapt.adapt(to, parent(x))) -end - -# ArrayInterface -ArrayInterface.fast_matrix_colors(::Type{<:UniformBlockDiagonalMatrix}) = true -function ArrayInterface.fast_scalar_indexing(::Type{<:UniformBlockDiagonalMatrix{ - T, D}}) where {T, D} - return ArrayInterface.fast_scalar_indexing(D) -end -function ArrayInterface.can_setindex(::Type{<:UniformBlockDiagonalMatrix{ - T, D}}) where {T, D} - return ArrayInterface.can_setindex(D) -end - -function ArrayInterface.matrix_colors(A::UniformBlockDiagonalMatrix) - return repeat(1:size(A.data, 2), size(A.data, 3)) -end - -function ArrayInterface.findstructralnz(A::UniformBlockDiagonalMatrix) - I, J, K = size(A.data) - L = I * J * K - i_idxs, j_idxs = Vector{Int}(undef, L), Vector{Int}(undef, L) - - @inbounds for (idx, (i, j, k)) in enumerate(Iterators.product(1:I, 1:J, 1:K)) - i_idxs[idx] = i + (k - 1) * I - j_idxs[idx] = j + (k - 1) * J - end - - return i_idxs, j_idxs -end - -ArrayInterface.has_sparsestruct(::Type{<:UniformBlockDiagonalMatrix}) = true - -# Βase -function Base.size(A::UniformBlockDiagonalMatrix) - return (size(A.data, 1) * size(A.data, 3), size(A.data, 2) * size(A.data, 3)) -end -Base.size(A::UniformBlockDiagonalMatrix, i::Int) = (size(A.data, i) * size(A.data, 3)) - -Base.parent(A::UniformBlockDiagonalMatrix) = A.data - -Base.@propagate_inbounds function Base.getindex( - A::UniformBlockDiagonalMatrix, i::Int, j::Int) - i_, j_, k = _block_indices(A, i, j) - k == -1 && return zero(eltype(A)) - return A.data[i_, j_, k] -end - -Base.@propagate_inbounds function Base.getindex(A::UniformBlockDiagonalMatrix, idx::Int) - return getindex(A, mod1(idx, size(A, 1)), (idx - 1) ÷ size(A, 1) + 1) -end - -Base.@propagate_inbounds function Base.setindex!( - A::UniformBlockDiagonalMatrix, v, i::Int, j::Int) - i_, j_, k = _block_indices(A, i, j) - k == -1 && - !iszero(v) && - throw(ArgumentError("cannot set non-zero value outside of block.")) - A.data[i_, j_, k] = v - return v -end - -Base.@propagate_inbounds function Base.setindex!(A::UniformBlockDiagonalMatrix, v, idx::Int) - return setindex!(A, v, mod1(idx, size(A, 1)), (idx - 1) ÷ size(A, 1) + 1) -end - -function _block_indices(A::UniformBlockDiagonalMatrix, i::Int, j::Int) - all((0, 0) .< (i, j) .<= size(A)) || throw(BoundsError(A, (i, j))) - - M, N, _ = size(A.data) - - i_div = div(i - 1, M) + 1 - !((i_div - 1) * N + 1 ≤ j ≤ i_div * N) && return -1, -1, -1 - - return mod1(i, M), mod1(j, N), i_div -end - -function Base.Matrix(A::UniformBlockDiagonalMatrix) - M = Matrix{eltype(A)}(undef, size(A, 1), size(A, 2)) - L1, L2, _ = size(A.data) - fill!(M, false) - for (i, Aᵢ) in enumerate(batchview(A)) - M[((i - 1) * L1 + 1):(i * L1), ((i - 1) * L2 + 1):(i * L2)] .= Matrix(Aᵢ) - end - return M -end - -Base.Array(A::UniformBlockDiagonalMatrix) = Matrix(A) - -Base.collect(A::UniformBlockDiagonalMatrix) = Matrix(A) - -function Base.similar(A::UniformBlockDiagonalMatrix, ::Type{T}) where {T} - return UniformBlockDiagonalMatrix(similar(A.data, T)) -end - -Base.copy(A::UniformBlockDiagonalMatrix) = UniformBlockDiagonalMatrix(copy(A.data)) - -function Base.copyto!(dest::UniformBlockDiagonalMatrix, src::UniformBlockDiagonalMatrix) - copyto!(dest.data, src.data) - return dest -end - -function Base.fill!(A::UniformBlockDiagonalMatrix, v) - fill!(A.data, v) - return A -end - -@inline function 
_eachrow(X::UniformBlockDiagonalMatrix) - row_fn = @closure i -> begin - M, N, K = size(X.data) - k = (i - 1) ÷ M + 1 - i_ = mod1(i, M) - data = view(X.data, i_, :, k) - if k == 1 - return vcat(data, zeros_like(data, N * (K - 1))) - elseif k == K - return vcat(zeros_like(data, N * (K - 1)), data) - else - return vcat(zeros_like(data, N * (k - 1)), data, zeros_like(data, N * (K - k))) - end - end - return map(row_fn, 1:size(X, 1)) -end - -# Broadcasting -struct UniformBlockDiagonalMatrixStyle{N} <: Broadcast.AbstractArrayStyle{2} end - -function Broadcast.BroadcastStyle( - ::UniformBlockDiagonalMatrixStyle{N}, ::Broadcast.DefaultArrayStyle{M}) where {N, M} - return UniformBlockDiagonalMatrixStyle{max(N, M)}() -end -function Broadcast.BroadcastStyle(::Broadcast.AbstractArrayStyle{M}, - ::UniformBlockDiagonalMatrixStyle{N}) where {M, N} - return UniformBlockDiagonalMatrixStyle{max(M, N)}() -end -function Broadcast.BroadcastStyle(::UniformBlockDiagonalMatrixStyle{M}, - ::UniformBlockDiagonalMatrixStyle{N}) where {M, N} - return UniformBlockDiagonalMatrixStyle{max(M, N)}() -end -function Base.BroadcastStyle(::Type{<:UniformBlockDiagonalMatrix{T}}) where {T} - return UniformBlockDiagonalMatrixStyle{-1}() -end - -@inline function Base.copy(bc::Broadcast.Broadcasted{<:UniformBlockDiagonalMatrixStyle{N}}) where {N} - bc = Broadcast.flatten(bc) - return UniformBlockDiagonalMatrix(bc.f.(__standardize_broadcast_args( - Val(N), bc.axes, bc.args)...)) -end - -@inline function Base.copyto!(dest::UniformBlockDiagonalMatrix, - bc::Broadcast.Broadcasted{<:UniformBlockDiagonalMatrixStyle{N}}) where {N} - bc = Broadcast.flatten(bc) - dest.data .= bc.f.(__standardize_broadcast_args(Val(N), bc.axes, bc.args)...) - return dest -end - -@inline function Broadcast.instantiate(bc::Broadcast.Broadcasted{<:UniformBlockDiagonalMatrixStyle{N}}) where {N} - bc = Broadcast.flatten(bc) - axes_bc = Broadcast.combine_axes(getfield.( - filter(Base.Fix2(isa, UniformBlockDiagonalMatrix), bc.args), :data)...) - args = __standardize_broadcast_args(Val(N), axes_bc, bc.args) - if bc.axes isa Nothing # Not done via dispatch to make it easier to extend instantiate(::Broadcasted{Style}) - axes = Broadcast.combine_axes(args...) - else - axes = bc.axes - Broadcast.check_broadcast_axes(axes, args...) 
- end - return Broadcast.Broadcasted(bc.style, bc.f, args, axes) -end - -@inline __standardize_broadcast_args(N::Val, new_axes, args::Tuple) = __standardize_broadcast_args.( - (N,), (new_axes,), args) -for N in -1:1:3 - @eval @inline __standardize_broadcast_args(::Val{$(N)}, _, x::UniformBlockDiagonalMatrix) = x.data -end -@inline __standardize_broadcast_args(::Val{3}, _, x::AbstractArray{T, 3}) where {T} = x -@inline function __standardize_broadcast_args( - ::Val{3}, new_axes, x::AbstractArray{T, 2}) where {T} - I, J, K = __standardize_axes(new_axes) - ((I == size(x, 1) || I == 1) && (J == size(x, 2) || J == 1)) && - return reshape(x, size(x, 1), size(x, 2), 1) - return __standardize_broadcast_args(Val(2), new_axes, x) -end -@inline __standardize_broadcast_args(::Val, _, x::AbstractArray{T, 1}) where {T} = reshape( - x, 1, 1, length(x)) -@inline function __standardize_broadcast_args( - ::Val{2}, new_axes, x::AbstractArray{T, 2}) where {T} - I, J, K = __standardize_axes(new_axes) - @assert I * K == size(x, 1) && J * K == size(x, 2) - return mapfoldl(_cat3, 1:K; init=_init_array_prototype(x, I, J, 0)) do k - return view(x, ((k - 1) * I + 1):(k * I), ((k - 1) * J + 1):(k * J)) - end -end -@inline function __standardize_broadcast_args(::Val{2}, new_axes, x::Fill{T, 2}) where {T} - I, J, K = __standardize_axes(new_axes) - @assert I * K == size(x, 1) && J * K == size(x, 2) - return Fill(x.value, I, J, K) -end -@inline __standardize_broadcast_args(::Val, _, x) = x - -@inline __standardize_axes(axes::Tuple) = __standardize_axes.(axes) -@inline __standardize_axes(axes::Base.OneTo) = axes.stop -@inline function __standardize_axes(axes::StepRange) - @assert axes.start == 1 && axes.step == 1 - return axes.stop -end - -# Common Math Operations -function Base.mapreduce( - f::F, op::OP, A::UniformBlockDiagonalMatrix; dims=Colon(), kwargs...) where {F, OP} - res = mapreduce(f, op, A.data; dims, kwargs...) 
- dims isa Colon && return res - return UniformBlockDiagonalMatrix(res) -end - -function Base.:*(X::UniformBlockDiagonalMatrix, Y::UniformBlockDiagonalMatrix) - return UniformBlockDiagonalMatrix(batched_mul(X.data, Y.data)) -end - -function Base.:*(X::UniformBlockDiagonalMatrix, Y::AbstractVector) - return (X * reshape(Y, :, 1, nbatches(X))).data |> vec -end - -function Base.:*(X::UniformBlockDiagonalMatrix, Y::AbstractMatrix) - return X * reshape(Y, :, 1, nbatches(Y)) -end - -function Base.:*(X::UniformBlockDiagonalMatrix, Y::AbstractArray{T, 3}) where {T} - return UniformBlockDiagonalMatrix(batched_mul(X.data, Y)) -end - -function Base.:*(X::AbstractArray{T, 3}, Y::UniformBlockDiagonalMatrix) where {T} - return UniformBlockDiagonalMatrix(batched_mul(X, Y.data)) -end - -function Base.:*(X::AbstractArray{T, 2}, Y::UniformBlockDiagonalMatrix) where {T} - C = reshape(X, 1, :, nbatches(X)) * Y - return dropdims(C.data; dims=1) -end - -function LinearAlgebra.mul!(A::AbstractMatrix, B::AbstractMatrix, - C::UniformBlockDiagonalMatrix, α::Number=true, β::Number=false) - A_ = reshape(A, 1, :, nbatches(A)) - B_ = reshape(B, 1, :, nbatches(B)) - batched_mul!(A_, B_, C.data, α, β) - return A -end - -# LinearAlgebra -abstract type AbstractBatchedMatrixFactorization{T} <: LinearAlgebra.Factorization{T} end - -const AdjointAbstractBatchedMatrixFactorization{T} = LinearAlgebra.AdjointFactorization{ - T, <:AbstractBatchedMatrixFactorization{T}} -const TransposeAbstractBatchedMatrixFactorization{T} = LinearAlgebra.TransposeFactorization{ - T, <:AbstractBatchedMatrixFactorization{T}} -const AdjointOrTransposeAbstractBatchedMatrixFactorization{T} = Union{ - AdjointAbstractBatchedMatrixFactorization{T}, - TransposeAbstractBatchedMatrixFactorization{T}} - -const AllAbstractBatchedMatrixFactorization{T} = Union{ - AbstractBatchedMatrixFactorization{T}, - AdjointOrTransposeAbstractBatchedMatrixFactorization{T}} - -nbatches(f::AdjointOrTransposeAbstractBatchedMatrixFactorization) = nbatches(parent(f)) -batchview(f::AdjointOrTransposeAbstractBatchedMatrixFactorization) = batchview(parent(f)) -function batchview(f::AdjointOrTransposeAbstractBatchedMatrixFactorization, idx::Int) - return batchview(parent(f), idx) -end - -function LinearAlgebra.ldiv!(A::AllAbstractBatchedMatrixFactorization, b::AbstractVector) - LinearAlgebra.ldiv!(A, reshape(b, :, nbatches(A))) - return b -end - -function LinearAlgebra.ldiv!( - X::AbstractVector, A::AllAbstractBatchedMatrixFactorization, b::AbstractVector) - LinearAlgebra.ldiv!(reshape(X, :, nbatches(A)), A, reshape(b, :, nbatches(A))) - return X -end - -function LinearAlgebra.:\(A::AllAbstractBatchedMatrixFactorization, b::AbstractVector) - X = similar(b, promote_type(eltype(A), eltype(b)), size(A, 1)) - LinearAlgebra.ldiv!(X, A, b) - return X -end - -function LinearAlgebra.:\(A::AllAbstractBatchedMatrixFactorization, b::AbstractMatrix) - X = similar(b, promote_type(eltype(A), eltype(b)), size(A, 1)) - LinearAlgebra.ldiv!(X, A, vec(b)) - return reshape(X, :, nbatches(b)) -end - -struct GenericBatchedFactorization{T, A, F} <: AbstractBatchedMatrixFactorization{T} - alg::A - fact::Vector{F} - - function GenericBatchedFactorization(alg, fact) - return GenericBatchedFactorization{eltype(first(fact))}(alg, fact) - end - - function GenericBatchedFactorization{T}(alg::A, fact::Vector{F}) where {T, A, F} - return new{T, A, F}(alg, fact) - end -end - -nbatches(F::GenericBatchedFactorization) = length(F.fact) -batchview(F::GenericBatchedFactorization) = F.fact 
-batchview(F::GenericBatchedFactorization, idx::Int) = F.fact[idx] -Base.size(F::GenericBatchedFactorization) = size(first(F.fact)) .* length(F.fact) -function Base.size(F::GenericBatchedFactorization, i::Integer) - return size(first(F.fact), i) * length(F.fact) -end - -function LinearAlgebra.issuccess(fact::GenericBatchedFactorization) - return all(LinearAlgebra.issuccess, fact.fact) -end - -function Base.adjoint(fact::GenericBatchedFactorization{T}) where {T} - return GenericBatchedFactorization{T}(fact.alg, adjoint.(fact.fact)) -end - -function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, F::GenericBatchedFactorization) - println(io, "GenericBatchedFactorization() with Batch Count: $(nbatches(F))") - Base.printstyled(io, "Factorization Function: "; color=:green) - show(io, mime, F.alg) - Base.printstyled(io, "\nPrototype Factorization: "; color=:green) - show(io, mime, first(F.fact)) - return nothing -end - -const PIVOT_TYPES = Dict( - :qr => (:NoPivot, :ColumnNorm), :lu => (:NoPivot, :RowMaximum, :RowNonZero), - :cholesky => (:NoPivot, :RowMaximum)) - -for fact in (:qr, :lu, :cholesky) - fact! = Symbol("$(fact)!") - @eval begin - function LinearAlgebra.$(fact)(A::UniformBlockDiagonalMatrix, args...; kwargs...) - return LinearAlgebra.$(fact!)(copy(A), args...; kwargs...) - end - end - - for pType in PIVOT_TYPES[fact] - @eval begin - function LinearAlgebra.$(fact!)( - A::UniformBlockDiagonalMatrix, pivot::$pType; kwargs...) - fact = map(Aᵢ -> LinearAlgebra.$(fact!)(Aᵢ, pivot; kwargs...), batchview(A)) - return GenericBatchedFactorization(LinearAlgebra.$(fact!), fact) - end - - # Needed to prevent method ambiguities - function LinearAlgebra.$(fact)( - A::UniformBlockDiagonalMatrix, pivot::$pType; kwargs...) - return LinearAlgebra.$(fact!)(copy(A), pivot; kwargs...) 
- end - end - end -end - -function LinearAlgebra.ldiv!(A::GenericBatchedFactorization, b::AbstractMatrix) - @assert nbatches(A) == nbatches(b) - for i in 1:nbatches(A) - LinearAlgebra.ldiv!(batchview(A, i), batchview(b, i)) - end - return b -end - -function LinearAlgebra.ldiv!( - X::AbstractMatrix, A::GenericBatchedFactorization, b::AbstractMatrix) - @assert nbatches(A) == nbatches(b) == nbatches(X) - for i in 1:nbatches(A) - LinearAlgebra.ldiv!(batchview(X, i), batchview(A, i), batchview(b, i)) - end - return X -end - -function LinearAlgebra.mul!( - C::AbstractVector, A::UniformBlockDiagonalMatrix, B::AbstractVector) - LinearAlgebra.mul!(reshape(C, :, 1, nbatches(A)), A, reshape(B, :, 1, nbatches(A))) - return C -end - -function LinearAlgebra.mul!(C::AbstractArray{T1, 3}, A::UniformBlockDiagonalMatrix, - B::AbstractArray{T2, 3}) where {T1, T2} - batched_mul!(C, A.data, B) - return C -end diff --git a/src/operator.jl b/src/operator.jl new file mode 100644 index 0000000..80e6469 --- /dev/null +++ b/src/operator.jl @@ -0,0 +1,288 @@ +struct UniformBlockDiagonalOperator{T, D <: AbstractArray{T, 3}} <: AbstractSciMLOperator{T} + data::D +end + +function UniformBlockDiagonalOperator(X::AbstractMatrix) + return UniformBlockDiagonalOperator(reshape(X, size(X, 1), size(X, 2), 1)) +end + +# SciMLOperators Interface +## Even though it is easily convertible, it is helpful to get warnings +SciMLOperators.isconvertible(::UniformBlockDiagonalOperator) = false + +# BatchedRoutines API +getdata(op::UniformBlockDiagonalOperator) = op.data +getdata(x) = x +nbatches(op::UniformBlockDiagonalOperator) = size(op.data, 3) +batchview(op::UniformBlockDiagonalOperator) = batchview(op.data) +batchview(op::UniformBlockDiagonalOperator, i::Int) = batchview(op.data, i) + +function batched_mul(op1::UniformBlockDiagonalOperator, op2::UniformBlockDiagonalOperator) + return UniformBlockDiagonalOperator(batched_mul(op1.data, op2.data)) +end + +for f in ( + :batched_transpose, :batched_adjoint, :batched_inv, :batched_pinv, :batched_reshape) + @eval function $(f)(op::UniformBlockDiagonalOperator, args...) 
+        return UniformBlockDiagonalOperator($(f)(op.data, args...))
+    end
+end
+
+## Matrix Multiplies
+@inline function Base.:*(
+        op1::UniformBlockDiagonalOperator, op2::UniformBlockDiagonalOperator)
+    return batched_mul(op1, op2)
+end
+
+@inline function Base.:*(op::UniformBlockDiagonalOperator, x::AbstractVector)
+    return (op * reshape(x, :, 1, nbatches(op))) |> vec
+end
+
+@inline function Base.:*(op::UniformBlockDiagonalOperator, x::AbstractMatrix)
+    return dropdims(op * reshape(x, :, 1, nbatches(x)); dims=2)
+end
+
+@inline function Base.:*(op::UniformBlockDiagonalOperator, x::AbstractArray{T, 3}) where {T}
+    return (op * UniformBlockDiagonalOperator(x)) |> getdata
+end
+
+@inline function Base.:*(x::AbstractVector, op::UniformBlockDiagonalOperator)
+    return (reshape(x, 1, :, nbatches(op)) * op) |> vec
+end
+
+@inline function Base.:*(x::AbstractMatrix, op::UniformBlockDiagonalOperator)
+    return dropdims(reshape(x, 1, :, nbatches(x)) * op; dims=1)
+end
+
+@inline function Base.:*(x::AbstractArray{T, 3}, op::UniformBlockDiagonalOperator) where {T}
+    return (UniformBlockDiagonalOperator(x) * op) |> getdata
+end
+
+for f in (:transpose, :adjoint)
+    batched_f = Symbol("batched_", f)
+    @eval (Base.$(f))(op::UniformBlockDiagonalOperator) = $(batched_f)(op)
+end
+
+@inline function Base.size(op::UniformBlockDiagonalOperator)
+    N, M, B = size(op.data)
+    return N * B, M * B
+end
+@inline Base.size(op::UniformBlockDiagonalOperator, i::Int) = size(op.data, i) *
+                                                              size(op.data, 3)
+
+@inline Base.length(op::UniformBlockDiagonalOperator) = prod(size(op))
+
+function Base.show(io::IO, mime::MIME"text/plain", op::UniformBlockDiagonalOperator)
+    print(io, "UniformBlockDiagonalOperator{$(eltype(op.data))} storing ")
+    show(io, mime, op.data)
+end
+
+function Base.mapreduce(f::F, op::OP, A::UniformBlockDiagonalOperator;
+        dims=Colon(), kwargs...) where {F, OP}
+    res = mapreduce(f, op, getdata(A); dims, kwargs...)
+    dims isa Colon && return res
+    return UniformBlockDiagonalOperator(res)
+end
+
+function Base.fill!(op::UniformBlockDiagonalOperator, v)
+    fill!(getdata(op), v)
+    return op
+end
+
+## getindex and setindex!
are supported mostly to allow FiniteDiff to compute the Jacobian
+Base.@propagate_inbounds function Base.getindex(
+        A::UniformBlockDiagonalOperator, i::Int, j::Int)
+    i_, j_, k = _block_indices(A, i, j)
+    k == -1 && return zero(eltype(A))
+    return A.data[i_, j_, k]
+end
+
+Base.@propagate_inbounds function Base.getindex(A::UniformBlockDiagonalOperator, idx::Int)
+    return getindex(A, mod1(idx, size(A, 1)), (idx - 1) ÷ size(A, 1) + 1)
+end
+
+Base.@propagate_inbounds function Base.setindex!(
+        A::UniformBlockDiagonalOperator, v, i::Int, j::Int)
+    i_, j_, k = _block_indices(A, i, j)
+    k == -1 &&
+        !iszero(v) &&
+        throw(ArgumentError("cannot set non-zero value outside of block."))
+    A.data[i_, j_, k] = v
+    return v
+end
+
+Base.@propagate_inbounds function Base.setindex!(
+        A::UniformBlockDiagonalOperator, v, idx::Int)
+    return setindex!(A, v, mod1(idx, size(A, 1)), (idx - 1) ÷ size(A, 1) + 1)
+end
+
+function _block_indices(A::UniformBlockDiagonalOperator, i::Int, j::Int)
+    all((0, 0) .< (i, j) .<= size(A)) || throw(BoundsError(A, (i, j)))
+
+    M, N, _ = size(A.data)
+
+    i_div = div(i - 1, M) + 1
+    !((i_div - 1) * N + 1 ≤ j ≤ i_div * N) && return -1, -1, -1
+
+    return mod1(i, M), mod1(j, N), i_div
+end
+
+@inline function Base.eachrow(X::UniformBlockDiagonalOperator)
+    row_fn = @closure i -> begin
+        M, N, K = size(X.data)
+        k = (i - 1) ÷ M + 1
+        i_ = mod1(i, M)
+        data = view(X.data, i_, :, k)
+        if k == 1
+            return vcat(data, zeros_like(data, N * (K - 1)))
+        elseif k == K
+            return vcat(zeros_like(data, N * (K - 1)), data)
+        else
+            return vcat(zeros_like(data, N * (k - 1)), data, zeros_like(data, N * (K - k)))
+        end
+    end
+    return map(row_fn, 1:size(X, 1))
+end
+
+## Operator --> AbstractArray
+function __copyto!(A::AbstractMatrix, op::UniformBlockDiagonalOperator)
+    N, M, B = size(getdata(op))
+    @assert size(A) == (N * B, M * B)
+    fill!(A, zero(eltype(op)))
+    for (i, Aᵢ) in enumerate(batchview(op))
+        A[((i - 1) * N + 1):(i * N), ((i - 1) * M + 1):(i * M)] .= convert(
+            AbstractMatrix, Aᵢ)
+    end
+end
+
+function Base.convert(
+        ::Type{C}, op::UniformBlockDiagonalOperator) where {C <: AbstractArray}
+    A = similar(op.data, size(op))
+    __copyto!(A, op)
+    return convert(C, A)
+end
+
+Base.Matrix(op::UniformBlockDiagonalOperator) = convert(Matrix, op)
+Base.Array(op::UniformBlockDiagonalOperator) = Matrix(op)
+Base.collect(op::UniformBlockDiagonalOperator) = convert(AbstractMatrix, op)
+
+function Base.copyto!(A::AbstractArray, op::UniformBlockDiagonalOperator)
+    @assert length(A) ≥ length(op)
+    A_ = reshape(view(vec(A), 1:length(op)), size(op))
+    __copyto!(A_, op)
+    return A
+end
+
+@inline function Base.copy(op::UniformBlockDiagonalOperator)
+    return UniformBlockDiagonalOperator(copy(getdata(op)))
+end
+
+## Define some of the common operations like `sum` directly since SciMLOperators doesn't
+## provide a very nice implementation
+@inline function Base.sum(op::UniformBlockDiagonalOperator; kwargs...)
+    return sum(identity, op; kwargs...)
+end
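To make the operator semantics above concrete, a small usage sketch (the 2×2×3 data is illustrative; `op[1, 3]` falls outside every block, so it reads as zero):

```julia
using BatchedRoutines

op = UniformBlockDiagonalOperator(reshape(1.0:12.0, 2, 2, 3))  # three 2×2 blocks

size(op)       # (6, 6): the blocks tile the diagonal of a 6×6 matrix
op[1, 1]       # 1.0, an entry inside the first block
op[1, 3]       # 0.0, off-block entries read as zero
Matrix(op)     # materializes the dense 6×6 block-diagonal matrix
sum(abs2, op)  # reduces over the stored blocks only; off-block zeros add nothing
```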
+
+@inline function Base.sum(f::F, op::UniformBlockDiagonalOperator; dims=Colon()) where {F}
+    return mapreduce(f, +, op; dims)
+end
+
+## Common Operations
+function Base.:+(op1::UniformBlockDiagonalOperator, op2::UniformBlockDiagonalOperator)
+    return UniformBlockDiagonalOperator(getdata(op1) + getdata(op2))
+end
+
+function Base.:-(op1::UniformBlockDiagonalOperator, op2::UniformBlockDiagonalOperator)
+    return UniformBlockDiagonalOperator(getdata(op1) - getdata(op2))
+end
+
+function Base.isapprox(
+        op1::UniformBlockDiagonalOperator, op2::UniformBlockDiagonalOperator; kwargs...)
+    return isapprox(getdata(op1), getdata(op2); kwargs...)
+end
+
+# Adapt
+@inline function Adapt.adapt_structure(to, op::UniformBlockDiagonalOperator)
+    return UniformBlockDiagonalOperator(Adapt.adapt(to, getdata(op)))
+end
+
+# ArrayInterface
+ArrayInterface.fast_matrix_colors(::Type{<:UniformBlockDiagonalOperator}) = true
+function ArrayInterface.fast_scalar_indexing(::Type{<:UniformBlockDiagonalOperator{
+        T, D}}) where {T, D}
+    return ArrayInterface.fast_scalar_indexing(D)
+end
+function ArrayInterface.can_setindex(::Type{<:UniformBlockDiagonalOperator{
+        T, D}}) where {T, D}
+    return ArrayInterface.can_setindex(D)
+end
+
+function ArrayInterface.matrix_colors(A::UniformBlockDiagonalOperator)
+    # Columns in different blocks never share a row, so colors can repeat
+    # across batches.
+    return repeat(1:size(A.data, 2), size(A.data, 3))
+end
+
+function ArrayInterface.findstructralnz(A::UniformBlockDiagonalOperator)
+    I, J, K = size(A.data)
+    L = I * J * K
+    i_idxs, j_idxs = Vector{Int}(undef, L), Vector{Int}(undef, L)
+
+    @inbounds for (idx, (i, j, k)) in enumerate(Iterators.product(1:I, 1:J, 1:K))
+        i_idxs[idx] = i + (k - 1) * I
+        j_idxs[idx] = j + (k - 1) * J
+    end
+
+    return i_idxs, j_idxs
+end
+
+ArrayInterface.has_sparsestruct(::Type{<:UniformBlockDiagonalOperator}) = true
+
+# Linear Algebra Routines
+function LinearAlgebra.mul!(
+        A::Union{AbstractMatrix, UniformBlockDiagonalOperator}, B::AbstractMatrix,
+        C::UniformBlockDiagonalOperator, α::Number=true, β::Number=false)
+    # Left-multiplication: each column of `B` acts as a row vector, so reshape
+    # to `(1, n, B)` batches to satisfy `batched_mul!`.
+    A_ = A isa AbstractArray ? reshape(A, 1, :, nbatches(A)) : getdata(A)
+    B_ = reshape(B, 1, :, nbatches(B))
+    batched_mul!(A_, B_, getdata(C), α, β)
+    return A
+end
+
+function LinearAlgebra.mul!(A::Union{AbstractMatrix, UniformBlockDiagonalOperator},
+        B::UniformBlockDiagonalOperator,
+        C::AbstractMatrix, α::Number=true, β::Number=false)
+    A_ = A isa AbstractArray ? reshape(A, :, 1, nbatches(A)) : getdata(A)
+    C_ = reshape(C, :, 1, nbatches(C))
+    batched_mul!(A_, getdata(B), C_, α, β)
+    return A
+end
+
+function LinearAlgebra.mul!(A::Union{AbstractMatrix, UniformBlockDiagonalOperator},
+        B::UniformBlockDiagonalOperator,
+        C::UniformBlockDiagonalOperator, α::Number=true, β::Number=false)
+    A_ = A isa AbstractArray ? reshape(A, :, 1, nbatches(A)) : getdata(A)
+    batched_mul!(A_, getdata(B), getdata(C), α, β)
+    return A
+end
+
+function LinearAlgebra.mul!(
+        C::AbstractVector, A::UniformBlockDiagonalOperator, B::AbstractVector)
+    LinearAlgebra.mul!(reshape(C, :, 1, nbatches(A)), A, reshape(B, :, 1, nbatches(A)))
+    return C
+end
+
+function LinearAlgebra.mul!(C::AbstractArray{T1, 3}, A::UniformBlockDiagonalOperator,
+        B::AbstractArray{T2, 3}) where {T1, T2}
+    batched_mul!(C, getdata(A), B)
+    return C
+end
+
+# Direct \ operator
+function Base.:\(op::UniformBlockDiagonalOperator, b::AbstractVector)
+    return vec(op \ reshape(b, :, nbatches(op)))
+end
+Base.:\(op::UniformBlockDiagonalOperator, b::AbstractMatrix) = __internal_backslash(op, b)
+
+## This exists to allow direct autodiff through the code, e.g., for non-square systems
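Reviewer note: `\` dispatches the batched solve onto the individual blocks — square blocks get a direct solve, non-square ones fall back to the per-block least-squares path in `__internal_backslash`. A usage sketch under those assumptions:

```julia
op = UniformBlockDiagonalOperator(rand(3, 3, 4))  # four independent 3×3 systems
B = rand(3, 4)                                    # one right-hand side per batch

X = op \ B       # 3×4: X[:, i] ≈ batchview(op)[i] \ B[:, i]
x = op \ vec(B)  # flattened variant; x == vec(X)
```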
+@inline function __internal_backslash(op::UniformBlockDiagonalOperator, b::AbstractMatrix)
+    size(op, 1) != length(b) && throw(DimensionMismatch("size(op, 1) != length(b)"))
+    return mapfoldl(((Aᵢ, bᵢ),) -> Aᵢ \ bᵢ, hcat, zip(batchview(op), batchview(b)))
+end
diff --git a/test/integration_tests.jl b/test/integration_tests.jl
index a03603a..8f6433f 100644
--- a/test/integration_tests.jl
+++ b/test/integration_tests.jl
@@ -1,52 +1,70 @@
 @testitem "LinearSolve" setup=[SharedTestSetup] begin
-    using FiniteDiff, LinearSolve, Zygote
+    using LinearAlgebra, LinearSolve, Zygote
 
     rng = get_stable_rng(1001)
 
     @testset "$mode" for (mode, aType, dev, ongpu) in MODES
         for dims in ((8, 8, 2), (5, 3, 2))
-            A1 = UniformBlockDiagonalMatrix(rand(rng, dims...)) |> dev
-            A2 = Matrix(A1) |> dev
+            A1 = UniformBlockDiagonalOperator(rand(rng, dims...)) |> dev
+            A2 = collect(A1)
             b = rand(rng, size(A1, 1)) |> dev
 
             prob1 = LinearProblem(A1, b)
             prob2 = LinearProblem(A2, b)
 
             if dims[1] == dims[2]
-                solvers = [LUFactorization(), QRFactorization(), KrylovJL_GMRES()]
+                solvers = [LUFactorization(), QRFactorization(),
+                    KrylovJL_GMRES(), svd_factorization(mode), nothing]
             else
-                solvers = [QRFactorization(), KrylovJL_LSMR()]
+                solvers = [
+                    QRFactorization(), KrylovJL_LSMR(), NormalCholeskyFactorization(),
+                    QRFactorization(LinearAlgebra.ColumnNorm()),
+                    svd_factorization(mode), nothing]
             end
 
-            @testset "solver: $(solver)" for solver in solvers
+            if dims[1] == dims[2]
+                test_chainrules_adjoint = (A, b) -> sum(abs2, A \ b)
+
+                ∂A_cr, ∂b_cr = Zygote.gradient(test_chainrules_adjoint, A1, b)
+            else
+                ∂A_cr, ∂b_cr = nothing, nothing
+            end
+
+            @testset "solver: $(nameof(typeof(solver)))" for solver in solvers
+                # FIXME: SVD doesn't define ldiv on CUDA side
+                if mode == "CUDA"
+                    if solver isa SVDFactorization || (solver isa QRFactorization &&
+                        solver.pivot isa LinearAlgebra.ColumnNorm)
+                        # ColumnNorm is not implemented on CUDA
+                        continue
+                    end
+                end
+
                 x1 = solve(prob1, solver)
-                x2 = solve(prob2, solver)
-                @test x1.u ≈ x2.u
+                if !ongpu && !(solver isa NormalCholeskyFactorization)
+                    x2 = solve(prob2, solver)
+                    @test x1.u ≈ x2.u
+                end
+
+                dims[1] != dims[2] && continue
 
                 test_adjoint = function (A, b)
                     sol = solve(LinearProblem(A, b), solver)
                     return sum(abs2, sol.u)
                 end
 
-                dims[1] != dims[2] && continue
-
-                ∂A_fd = FiniteDiff.finite_difference_gradient(
-                    x -> test_adjoint(x, Array(b)), Array(A1))
-                ∂b_fd = FiniteDiff.finite_difference_gradient(
-                    x -> test_adjoint(Array(A1), x), Array(b))
-
                 if solver isa QRFactorization && ongpu
                     @test_broken begin
                         ∂A, ∂b = Zygote.gradient(test_adjoint, A1, b)
-                        @test Array(∂A)≈∂A_fd atol=1e-1 rtol=1e-1
-                        @test Array(∂b)≈∂b_fd atol=1e-1 rtol=1e-1
+                        @test ∂A ≈ ∂A_cr
+                        @test ∂b ≈ ∂b_cr
                     end
                 else
                     ∂A, ∂b = Zygote.gradient(test_adjoint, A1, b)
-                    @test Array(∂A)≈∂A_fd atol=1e-1 rtol=1e-1
-                    @test Array(∂b)≈∂b_fd atol=1e-1 rtol=1e-1
+                    @test ∂A ≈ ∂A_cr
+                    @test ∂b ≈ ∂b_cr
                 end
             end
         end
@@ -72,7 +90,7 @@ end
     loss_function = (model, x, target_jac, ps, st) -> begin
         m = StatefulLuxLayer(model, nothing, st)
         jac_full = batched_jacobian(AutoForwardDiff(; chunksize=4), m, x, ps)
-        return sum(abs2, jac_full .- target_jac)
+        return sum(abs2, jac_full - target_jac)
     end
 
     @test loss_function(model, x, target_jac, ps, st) isa Number
@@ -94,7 +112,7 @@ end
     loss_function2 = (model, x, target_jac, ps, st) -> begin
         m = StatefulLuxLayer(model, ps, st)
         jac_full = batched_jacobian(AutoForwardDiff(; chunksize=4), m, x)
-        return sum(abs2, jac_full .- target_jac)
+        return sum(abs2, jac_full - target_jac)
     end
 
     @test loss_function2(model, x, target_jac, ps, st) isa Number
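Reviewer note: the rewritten test compares the Zygote pullback through `solve` against the plain ChainRules adjoint of `A \ b` instead of finite differences. A condensed standalone version of what it exercises (hypothetical script, assuming LinearSolve and Zygote as imported above):

```julia
using BatchedRoutines, LinearSolve, Zygote

op = UniformBlockDiagonalOperator(rand(8, 8, 2))  # two square 8×8 blocks
b = rand(16)

loss = (A, b) -> sum(abs2, solve(LinearProblem(A, b), LUFactorization()).u)
∂A, ∂b = Zygote.gradient(loss, op, b)

# Reference adjoint straight through the direct solve:
∂A_cr, ∂b_cr = Zygote.gradient((A, b) -> sum(abs2, A \ b), op, b)
∂A ≈ ∂A_cr && ∂b ≈ ∂b_cr
```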
diff --git a/test/shared_testsetup.jl b/test/shared_testsetup.jl
index c3417fc..d69bfc7 100644
--- a/test/shared_testsetup.jl
+++ b/test/shared_testsetup.jl
@@ -32,7 +32,18 @@ end
 
 get_stable_rng(seed=12345) = StableRNG(seed)
 
+# SVD helper until https://github.com/SciML/LinearSolve.jl/issues/488 is resolved
+using LinearSolve: LinearSolve
+
+function svd_factorization(mode)
+    mode == "CPU" && return LinearSolve.SVDFactorization()
+    mode == "CUDA" &&
+        return LinearSolve.SVDFactorization(true, CUDA.CUSOLVER.JacobiAlgorithm())
+    error("Unsupported mode: $mode")
+end
+
 export @jet, @test_gradients, check_approx
 export GROUP, MODES, cpu_testing, cuda_testing, get_default_rng, get_stable_rng
+export svd_factorization
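Reviewer note: the helper is used like any other LinearSolve algorithm in the tests; the CUDA branch requests `CUSOLVER.JacobiAlgorithm()` explicitly (see the linked issue for why the default path is insufficient there). A hypothetical CPU-side usage, mirroring the non-square test case:

```julia
op = UniformBlockDiagonalOperator(rand(5, 3, 2))  # non-square blocks
b = rand(10)

sol = solve(LinearProblem(op, b), svd_factorization("CPU"))
sol.u  # least-squares solution, one block per batch
```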