Implement S(quare) type parameter #114

Open
wants to merge 6 commits into
base: master
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "BlockDiagonals"
uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
authors = ["Invenia Technical Computing Corporation"]
-version = "0.1.36"
+version = "0.1.37"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1 change: 1 addition & 0 deletions src/BlockDiagonals.jl
@@ -15,5 +15,6 @@ include("blockdiagonal.jl")
include("base_maths.jl")
include("chainrules.jl")
include("linalg.jl")
+include("deprecate.jl")

end # end module
11 changes: 7 additions & 4 deletions src/blockdiagonal.jl
@@ -5,16 +5,19 @@

A matrix with matrices on the diagonal, and zeros off the diagonal.
"""
-struct BlockDiagonal{T, V<:AbstractMatrix{T}} <: AbstractMatrix{T}
+struct BlockDiagonal{T, V<:AbstractMatrix{T}, S} <: AbstractMatrix{T}
     blocks::Vector{V}

-    function BlockDiagonal{T, V}(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}}
-        return new{T, V}(blocks)
+    function BlockDiagonal{T, V, S}(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}, S}
+        infer_S = all(is_square.(blocks))
+        S == infer_S || throw(ArgumentError("inferred S $infer_S must be equal to S $S"))
Contributor

this defines the default constructor, meaning we cannot construct a BlockDiagonal without computing all(is_square.(blocks)) every time, which seems like quite an overhead

we're going to be constructing a new BlockDiagonal quite a lot (as the output of most mathematical operations)

can we instead have the default constructor expect the S argument, so that we can pass S when we already know it, and have outer constructors that compute S only when it is not provided?

Collaborator Author

Yeah, that's a good point. I approached this defensively to prevent bugs by preserving the invariant and didn't give much consideration to efficiency.

Doing some minimal benchmarking, it looks like the overhead is small but not negligible. E.g. consider a simple operation (-), where the relative impact of the check is presumably the largest*.

julia> const blocks = [rand(3, 3) for _ in 1:100];

julia> const bd = BlockDiagonal(blocks);

julia> @btime -bd; # master
  5.222 μs (104 allocations: 13.50 KiB)

julia> @btime -bd; # this PR
  5.416 μs (106 allocations: 13.61 KiB)

julia> @btime all(BlockDiagonals.is_square.(blocks));
  185.885 ns (2 allocations: 112 bytes)

So the effect is a slowdown of roughly 4% (about 0.19 μs on top of 5.2 μs). I am somewhat nervous about allowing users to make S inconsistent with the actual state of the blocks, but since most users will just call BlockDiagonal(blocks), maybe that's OK, and it allows us to keep common math operations at their current speed by calling the inner constructor directly. What do you think?

*The constructor alone is orders of magnitude slower of course, but that's probably not a relevant comparison since it will always be done in addition to some other operation on blocks.
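The constructor layout discussed in this thread could be sketched roughly as follows. This is only an illustration under stated assumptions, not the code in this PR: `SketchBlockDiagonal` and the local `is_square` are hypothetical stand-ins, and the inner constructor deliberately trusts whatever `S` it is given.

```julia
# Hypothetical stand-in for the package's is_square helper.
is_square(A::AbstractMatrix) = size(A, 1) == size(A, 2)

# Hypothetical type name, used so the sketch does not clash with BlockDiagonal.
struct SketchBlockDiagonal{T, V<:AbstractMatrix{T}, S}
    blocks::Vector{V}

    # Inner constructor: trusts the caller-supplied S and does not inspect the blocks.
    function SketchBlockDiagonal{T, V, S}(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}, S}
        return new{T, V, S}(blocks)
    end
end

# Outer constructor: S is not known, so compute it once from the blocks.
function SketchBlockDiagonal(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}}
    S = all(is_square, blocks)
    return SketchBlockDiagonal{T, V, S}(blocks)
end

# Two-parameter form: compute S instead of deprecating the call.
function SketchBlockDiagonal{T, V}(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}}
    return SketchBlockDiagonal{T, V, all(is_square, blocks)}(blocks)
end

# Negation keeps every block's shape, so S can be reused without re-checking.
function Base.:-(B::SketchBlockDiagonal{T, V, S}) where {T, V, S}
    return SketchBlockDiagonal{T, V, S}(map(-, B.blocks))
end
```

The trade-off is the one raised above: code that calls the three-parameter constructor directly can construct a value whose `S` disagrees with its blocks, so that form is best reserved for operations known to preserve block shapes.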

+        return new{T, V, S}(blocks)
     end
 end

 function BlockDiagonal(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}}
-    return BlockDiagonal{T, V}(blocks)
+    S = all(is_square.(blocks))
+    return BlockDiagonal{T, V, S}(blocks)
 end

BlockDiagonal(B::BlockDiagonal) = B
6 changes: 3 additions & 3 deletions src/chainrules.jl
@@ -49,9 +49,9 @@ end
 # multiplication
 function ChainRulesCore.rrule(
     ::typeof(*),
-    bm::BlockDiagonal{T, V},
+    bm::BlockDiagonal{T, V, S},
     v::StridedVector{T}
-) where {T<:Union{Real, Complex}, V<:Matrix{T}}
+) where {T<:Union{Real, Complex}, V<:Matrix{T}, S}

     y = bm * v

@@ -72,7 +72,7 @@ function ChainRulesCore.rrule(
)
end

-    b̄m = Tangent{BlockDiagonal{T, V}}(;blocks=Δblocks)
+    b̄m = Tangent{BlockDiagonal{T, V, S}}(;blocks=Δblocks)
     v̄ = InplaceableThunk(X̄ -> mul!(X̄, bm', ȳ, true, true), @thunk(bm' * ȳ))
     return NoTangent(), b̄m, v̄
 end
4 changes: 4 additions & 0 deletions src/deprecate.jl
@@ -0,0 +1,4 @@
+Base.@deprecate(
+    BlockDiagonal{T, V}(blocks) where {T, V},
Contributor

why deprecate this rather than just computing S in this case?

Collaborator Author

I was on the fence, but in the end it seemed simpler to have fewer constructors. If we decide to go with the suggestion on your other comment we should maybe bring this back?

+    BlockDiagonal{T, V, all(is_square.(blocks))}(blocks)
+)
8 changes: 4 additions & 4 deletions src/linalg.jl
@@ -5,6 +5,7 @@ for f in (:adjoint, :eigvecs, :inv, :pinv, :transpose)
end

LinearAlgebra.diag(B::BlockDiagonal) = map(i -> getindex(B, i, i), 1:minimum(size(B)))
+LinearAlgebra.diag(B::BlockDiagonal{T, V, true}) where {T, V} = mapreduce(diag, vcat, B.blocks)
Collaborator Author (@mzgubic, Jul 29, 2022)

this is the first use case

julia> b = BlockDiagonal([rand(3, 3) for _ in 1:100]);

julia> @btime diag(b); # this PR
  20.859 μs (199 allocations: 137.12 KiB)

julia> @btime diag(b); # master
  67.951 μs (298 allocations: 139.84 KiB)

Contributor

what are these timings showing us? is the second timing using the old version of the package?

Collaborator Author

Yep, added the comment now.
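As a rough illustration of why the fast `diag` method is restricted to `S == true`, the sketch below compares concatenating the per-block diagonals with taking the diagonal of the equivalent dense matrix. It works on plain vectors of matrices rather than on `BlockDiagonal`; `fast_diag` and `dense` are hypothetical helper names, and building the dense matrix with `cat(...; dims=(1, 2))` is an assumption made for the example.

```julia
using LinearAlgebra

# Hypothetical helper: concatenate the diagonals of the individual blocks.
fast_diag(blocks) = mapreduce(diag, vcat, blocks)

# Hypothetical helper: materialise the block-diagonal matrix densely.
dense(blocks) = cat(blocks...; dims=(1, 2))

# Square blocks: both routes give the same diagonal.
square_blocks = [[1.0 2.0; 3.0 4.0], [5.0 6.0; 7.0 8.0]]
fast_sq = fast_diag(square_blocks)
slow_sq = diag(dense(square_blocks))

# Non-square blocks: each block's own diagonal no longer lines up with the
# diagonal of the full matrix, so the two routes disagree.
nonsquare_blocks = [ones(1, 2), ones(2, 1)]
fast_ns = fast_diag(nonsquare_blocks)
slow_ns = diag(dense(nonsquare_blocks))
```

With the square blocks the two results coincide, while with the non-square blocks they differ even in length, which is why the generic index-based method has to remain as the fallback.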

LinearAlgebra.det(B::BlockDiagonal) = prod(det, blocks(B))
LinearAlgebra.logdet(B::BlockDiagonal) = sum(logdet, blocks(B))
LinearAlgebra.tr(B::BlockDiagonal) = sum(tr, blocks(B))
@@ -157,12 +158,11 @@ function _mul!(C::BlockDiagonal, A::BlockDiagonal, B::BlockDiagonal, α::Number,
return C
end

+function LinearAlgebra.:\(B::BlockDiagonal{T, V, false}, vm::AbstractVecOrMat) where {T, V}
Collaborator Author

this is the second use case, going from if to dispatch
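The move from a runtime `if` to dispatch on a Bool type parameter can be sketched in isolation like this. It is a simplified illustration, not the package code: `Blocks`, `dense`, and `solve` are hypothetical names, the element type is fixed to `Float64`, and only vector right-hand sides are handled.

```julia
using LinearAlgebra

# Hypothetical wrapper: the Bool parameter S records whether all blocks are square.
struct Blocks{S}
    blocks::Vector{Matrix{Float64}}
end

# Compute S once, at construction time.
Blocks(blocks) = Blocks{all(b -> size(b, 1) == size(b, 2), blocks)}(blocks)

# Materialise the block-diagonal matrix densely.
dense(B::Blocks) = cat(B.blocks...; dims=(1, 2))

# Non-square blocks: fall back to the generic dense solve.
solve(B::Blocks{false}, v::AbstractVector) = dense(B) \ v

# Square blocks: solve one block at a time against the matching rows of v.
function solve(B::Blocks{true}, v::AbstractVector)
    result = similar(v)
    row = 1
    for block in B.blocks
        n = size(block, 1)
        rows = row:(row + n - 1)
        result[rows] = block \ v[rows]
        row += n
    end
    return result
end
```

Because `S` is part of the type, the method is selected by dispatch and the square-block method no longer needs to scan the blocks on every call.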

+    return Matrix(B) \ vm # Fallback on the generic LinearAlgebra method
+end
 function LinearAlgebra.:\(B::BlockDiagonal, vm::AbstractVecOrMat)
     row_i = 1
-    # BlockDiagonals with non-square blocks
-    if !all(is_square, blocks(B))
-        return Matrix(B) \ vm # Fallback on the generic LinearAlgebra method
-    end
     result = similar(vm)
     for block in blocks(B)
         nrow = size(block, 1)
5 changes: 5 additions & 0 deletions test/deprecate.jl
@@ -0,0 +1,5 @@
+@testset "deprecate.jl" begin
+    blocks = [rand(3, 3), rand(3, 3)]
+    @test_deprecated BlockDiagonal{Float64, Matrix{Float64}}(blocks)
+    @test BlockDiagonal(blocks) == BlockDiagonal{Float64, Matrix{Float64}}(blocks)
+end
4 changes: 2 additions & 2 deletions test/linalg.jl
@@ -181,7 +181,7 @@ end
@test C.UL ≈ C.U
@test C.uplo === 'U'
@test C.info == 0
-@test typeof(C) == Cholesky{Float64, BlockDiagonal{Float64, Matrix{Float64}}}
+@test typeof(C) == Cholesky{Float64, BlockDiagonal{Float64, Matrix{Float64}, true}}
@test PDMat(cholesky(BD)) == PDMat(cholesky(Matrix(BD)))

M = BlockDiagonal(map(Matrix, blocks(C.L)))
@@ -192,7 +192,7 @@ end
@test C.UL ≈ C.L
@test C.uplo === 'L'
@test C.info == 0
-@test typeof(C) == Cholesky{Float64, BlockDiagonal{Float64, Matrix{Float64}}}
+@test typeof(C) == Cholesky{Float64, BlockDiagonal{Float64, Matrix{Float64}, true}}

# we didn't think we needed to support this, but #109
d = Diagonal(rand(5))
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -16,4 +16,5 @@ push!(ChainRulesTestUtils.TRANSFORMS_TO_ALT_TANGENTS, x -> @thunk(x))
include("base_maths.jl")
include("chainrules.jl")
include("linalg.jl")
+include("deprecate.jl")
end # tests