-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RFC: Allows blocks to be <:AbstractVector or <:Tuple #56
Changes from all commits
ac3ddc2
32f9d07
8b3235c
c9bf915
d5d29da
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,20 @@ | ||
using Documenter, BlockDiagonals | ||
using BlockDiagonals | ||
using Documenter | ||
|
||
makedocs(; | ||
modules=[BlockDiagonals], | ||
format=Documenter.HTML(prettyurls = get(ENV, "CI", nothing) == "true"), | ||
format=Documenter.HTML(prettyurls=false), | ||
pages=[ | ||
"Home" => "index.md", | ||
], | ||
repo="https://github.com/invenia/BlockDiagonals.jl/blob/{commit}{path}#L{line}", | ||
repo="https://github.com/invenia/BlockDiagonals.jl/blob/{commit}{path}#{line}", | ||
sitename="BlockDiagonals.jl", | ||
authors="Invenia Technical Computing", | ||
strict=true, | ||
checkdocs=:exports, | ||
) | ||
|
||
deploydocs(; | ||
repo="github.com/invenia/BlockDiagonals.jl", | ||
push_preview=true, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,19 +1,22 @@ | ||
# Core functionality for the `BlockDiagonal` type | ||
|
||
""" | ||
BlockDiagonal{T, V<:AbstractMatrix{T}} <: AbstractMatrix{T} | ||
BlockDiagonal{T, V} <: AbstractMatrix{T} | ||
BlockDiagonal(blocks::V) -> BlockDiagonal{T,V} | ||
|
||
A matrix with matrices on the diagonal, and zeros off the diagonal. | ||
""" | ||
struct BlockDiagonal{T, V<:AbstractMatrix{T}} <: AbstractMatrix{T} | ||
blocks::Vector{V} | ||
|
||
function BlockDiagonal{T, V}(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}} | ||
return new{T, V}(blocks) | ||
end | ||
!!! info "`V` type" | ||
`blocks::V` should be a `Tuple` or `AbstractVector` where each component (each block) is | ||
`<:AbstractMatrix{T}` for some common element type `T`. | ||
""" | ||
struct BlockDiagonal{T, V} <: AbstractMatrix{T} | ||
blocks::V | ||
end | ||
|
||
function BlockDiagonal(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}} | ||
function BlockDiagonal(blocks::V) where { | ||
T, V<:Union{Tuple{Vararg{<:AbstractMatrix{T}}}, AbstractVector{<:AbstractMatrix{T}}} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you really want all block matrices to have the same eltype? Or could you live with them having different eltypes, and There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
that's a great suggestion to be honest, i hadn't considered it. This currently just keeps There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. trying this locally, and running tests, it does cause some inference issues for
perhaps someone more familiar with ChainRules has some insight on why this might be (@willtebbutt , @mzgubic)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know why this is (I assume the answer contains the word compiler). But I'd rather live with the same eltype than with this I think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, this change would be non-breaking anyway, so let's leave it as a potentially follow-up #62 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that's perfectly fine. But then you may wish to consider to include a constructor that by itself does not yet require equal eltypes, but first determines |
||
} | ||
return BlockDiagonal{T, V}(blocks) | ||
end | ||
|
||
|
@@ -22,15 +25,15 @@ BlockDiagonal(B::BlockDiagonal) = B | |
is_square(A::AbstractMatrix) = size(A, 1) == size(A, 2) | ||
|
||
""" | ||
blocks(B::BlockDiagonal{T, V}) -> Vector{V} | ||
blocks(B::BlockDiagonal{T, V}) -> V | ||
|
||
Return the on-diagonal blocks of B. | ||
""" | ||
blocks(B::BlockDiagonal) = B.blocks | ||
|
||
# BlockArrays-like functions | ||
""" | ||
blocksizes(B::BlockDiagonal) -> Vector{Tuple} | ||
blocksizes(B::BlockDiagonal{T, V}) -> V | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we want this to return a Vector or a Tuple depending on the block type? Or should we force it to return a vector/tuple always? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. my thinking is that we'd use |
||
|
||
Return the size of each on-diagonal block in order. | ||
|
||
|
@@ -151,7 +154,9 @@ function _block_indices(B::BlockDiagonal, i::Integer, j::Integer) | |
p += 1 | ||
j -= ncols[p] | ||
end | ||
i -= sum(nrows[1:(p-1)]) | ||
if !isempty(nrows[1:(p-1)]) | ||
i -= sum(nrows[1:(p-1)]) | ||
end | ||
# if row `i` outside of block `p`, set `p` to place-holder value `-1` | ||
if i <= 0 || i > nrows[p] | ||
p = -1 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
# constructor | ||
function ChainRulesCore.rrule(::Type{<:BlockDiagonal}, blocks::Vector{V}) where {V} | ||
function ChainRulesCore.rrule(::Type{<:BlockDiagonal}, blocks::V) where {V<:AbstractVector} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we exclude the tuple we won't have an There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
correct, that'd be a nice feature to add as a follow-up, but out-of-scope for this PR But my question is a different one. Before this PR, (it's a question about how tightly to constrain |
||
BlockDiagonal_pullback(Δ::Composite) = (NO_FIELDS, Δ.blocks) | ||
return BlockDiagonal(blocks), BlockDiagonal_pullback | ||
end | ||
|
@@ -27,7 +27,7 @@ function ChainRulesCore.rrule( | |
::typeof(*), | ||
bm::BlockDiagonal{T, V}, | ||
v::StridedVector{T} | ||
) where {T<:Union{Real, Complex}, V<:Matrix{T}} | ||
) where {T<:Union{Real, Complex}, V<:Vector{Matrix{T}}} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this okay, @willtebbutt ? |
||
|
||
y = bm * v | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -76,7 +76,7 @@ svdvals_blockwise(B::BlockDiagonal) = mapreduce(svdvals, vcat, blocks(B)) | |
LinearAlgebra.svdvals(B::BlockDiagonal) = sort!(svdvals_blockwise(B); rev=true) | ||
|
||
# `B = U * Diagonal(S) * Vt` with `U` and `Vt` `BlockDiagonal` (`S` only sorted block-wise). | ||
function svd_blockwise(B::BlockDiagonal{T}; full::Bool=false) where T | ||
function svd_blockwise(B::BlockDiagonal{T, <:AbstractVector}; full::Bool=false) where T | ||
U = Matrix{float(T)}[] | ||
S = Vector{float(T)}() | ||
Vt = Matrix{float(T)}[] | ||
|
@@ -88,6 +88,17 @@ function svd_blockwise(B::BlockDiagonal{T}; full::Bool=false) where T | |
end | ||
return BlockDiagonal(U), S, BlockDiagonal(Vt) | ||
end | ||
function svd_blockwise(B::BlockDiagonal{T, <:Tuple}; full::Bool=false) where T | ||
S = Vector{float(T)}() | ||
U_Vt = ntuple(length(blocks(B))) do i | ||
F = svd(getblock(B, i), full=full) | ||
append!(S, F.S) | ||
(F.U, F.Vt) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How come we aren't There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. because above we want rather than |
||
end | ||
U = first.(U_Vt) | ||
Vt = last.(U_Vt) | ||
return BlockDiagonal(U), S, BlockDiagonal(Vt) | ||
end | ||
|
||
function LinearAlgebra.svd(B::BlockDiagonal; full::Bool=false)::SVD | ||
U, S, Vt = svd_blockwise(B, full=full) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we want to constrain
T
to at least aNumber
, or are we happy with anything?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think if
T
isn't Number-like, some methods will fail. But 🤷 Would there be any gain from constraining things?