-
Notifications
You must be signed in to change notification settings - Fork 11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement S
(quare) type parameter
#114
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
Base.@deprecate( | ||
BlockDiagonal{T, V}(blocks) where {T, V}, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why deprecate this rather than just computing S in this case? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was on the fence, but in the end it seemed simpler to have fewer constructors. If we decide to go with the suggestion on your other comment we should maybe bring this back? |
||
BlockDiagonal{T, V, all(is_square.(blocks))}(blocks) | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ for f in (:adjoint, :eigvecs, :inv, :pinv, :transpose) | |
end | ||
|
||
LinearAlgebra.diag(B::BlockDiagonal) = map(i -> getindex(B, i, i), 1:minimum(size(B))) | ||
LinearAlgebra.diag(B::BlockDiagonal{T, V, true}) where {T, V} = mapreduce(diag, vcat, B.blocks) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is the first use case julia> b = BlockDiagonal([rand(3, 3) for _ in 1:100]);
julia> @btime diag(b); # this PR
20.859 μs (199 allocations: 137.12 KiB)
julia> @btime diag(b); # master
67.951 μs (298 allocations: 139.84 KiB) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what are these timings showing us? is the second timing using the old version of the package? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep, added the comment now. |
||
LinearAlgebra.det(B::BlockDiagonal) = prod(det, blocks(B)) | ||
LinearAlgebra.logdet(B::BlockDiagonal) = sum(logdet, blocks(B)) | ||
LinearAlgebra.tr(B::BlockDiagonal) = sum(tr, blocks(B)) | ||
|
@@ -157,12 +158,11 @@ function _mul!(C::BlockDiagonal, A::BlockDiagonal, B::BlockDiagonal, α::Number, | |
return C | ||
end | ||
|
||
function LinearAlgebra.:\(B::BlockDiagonal{T, V, false}, vm::AbstractVecOrMat) where {T, V} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is the second use case, going from |
||
return Matrix(B) \ vm # Fallback on the generic LinearAlgebra method | ||
end | ||
function LinearAlgebra.:\(B::BlockDiagonal, vm::AbstractVecOrMat) | ||
row_i = 1 | ||
# BlockDiagonals with non-square blocks | ||
if !all(is_square, blocks(B)) | ||
return Matrix(B) \ vm # Fallback on the generic LinearAlgebra method | ||
end | ||
result = similar(vm) | ||
for block in blocks(B) | ||
nrow = size(block, 1) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
@testset "deprecate.jl" begin | ||
blocks = [rand(3, 3), rand(3, 3)] | ||
@test_deprecated BlockDiagonal{Float64, Matrix{Float64}}(blocks) | ||
@test BlockDiagonal(blocks) == BlockDiagonal{Float64, Matrix{Float64}}(blocks) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this defines the default constructor meaning we cannot constrcut a
BlockDiagonal
without having to computeall(is_square.(blocks))
every time... which seems like quite an overheadwe're going to be constructing a new
BlockDiagonal
quite a lot (as the output most mathematical operations)can we instead have the default constructor expect the
S
arg and have a constructor that allows us passingS
when we already know it (and similarly have outer-constructors that compute theS
only when not provided)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, that's a good point. I approached this defensively to prevent bugs by preserving the invariant and didn't give much consideration to efficiency.
Doing some minimal benchmarking it looks like the overhead is small, but not negligible. E.g. consider a simple operation (
-
) where presumably the impact of checking is the relatively the largest*.so the effect is about 5% slowdown. I am somewhat nervous of allowing the users to make
S
inconsistent with the actual state of the blocks, but since most users will just callBlockDiagonal(blocks)
maybe that's ok, and allows us to speed up (keep at the same speed really) common math operations by calling the inner constructor directly. What do you think?*The constructor alone is orders of magnitude slower of course, but that's probably not a relevant comparison since it will always be done in addition to some other operation on blocks.