From 7dce60dfb095cc3044ce107505ca54950e53a76f Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 29 Jul 2022 15:56:42 +0200 Subject: [PATCH 1/6] implement square type parameter --- src/BlockDiagonals.jl | 1 + src/blockdiagonal.jl | 11 +++++++---- src/chainrules.jl | 6 +++--- src/deprecate.jl | 4 ++++ test/linalg.jl | 4 ++-- 5 files changed, 17 insertions(+), 9 deletions(-) create mode 100644 src/deprecate.jl diff --git a/src/BlockDiagonals.jl b/src/BlockDiagonals.jl index 1d69dfe..377045c 100644 --- a/src/BlockDiagonals.jl +++ b/src/BlockDiagonals.jl @@ -11,6 +11,7 @@ import ChainRulesCore.ProjectTo export BlockDiagonal, blocks export blocksize, blocksizes, nblocks +include("deprecate.jl") include("blockdiagonal.jl") include("base_maths.jl") include("chainrules.jl") diff --git a/src/blockdiagonal.jl b/src/blockdiagonal.jl index 96e58fe..5fa5e18 100644 --- a/src/blockdiagonal.jl +++ b/src/blockdiagonal.jl @@ -5,16 +5,19 @@ A matrix with matrices on the diagonal, and zeros off the diagonal. """ -struct BlockDiagonal{T, V<:AbstractMatrix{T}} <: AbstractMatrix{T} +struct BlockDiagonal{T, V<:AbstractMatrix{T}, S} <: AbstractMatrix{T} blocks::Vector{V} - function BlockDiagonal{T, V}(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}} - return new{T, V}(blocks) + function BlockDiagonal{T, V, S}(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}, S} + infer_S = all(is_square.(blocks)) + S == infer_S || throw(ArgumentError("inferred S $infer_S must be equal to S $S")) + return new{T, V, S}(blocks) end end function BlockDiagonal(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}} - return BlockDiagonal{T, V}(blocks) + S = all(is_square.(blocks)) + return BlockDiagonal{T, V, S}(blocks) end BlockDiagonal(B::BlockDiagonal) = B diff --git a/src/chainrules.jl b/src/chainrules.jl index 7b383e8..3f1c70a 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -49,9 +49,9 @@ end # multiplication function ChainRulesCore.rrule( ::typeof(*), - bm::BlockDiagonal{T, V}, + bm::BlockDiagonal{T, V, S}, v::StridedVector{T} - ) where {T<:Union{Real, Complex}, V<:Matrix{T}} + ) where {T<:Union{Real, Complex}, V<:Matrix{T}, S} y = bm * v @@ -72,7 +72,7 @@ function ChainRulesCore.rrule( ) end - b̄m = Tangent{BlockDiagonal{T, V}}(;blocks=Δblocks) + b̄m = Tangent{BlockDiagonal{T, V, S}}(;blocks=Δblocks) v̄ = InplaceableThunk(X̄ -> mul!(X̄, bm', ȳ, true, true), @thunk(bm' * ȳ)) return NoTangent(), b̄m, v̄ end diff --git a/src/deprecate.jl b/src/deprecate.jl new file mode 100644 index 0000000..66c5505 --- /dev/null +++ b/src/deprecate.jl @@ -0,0 +1,4 @@ +Base.@deprecate( + BlockDiagonal{T, V}(blocks) where {T, V}, + BlockDiagonal{T, V, all(is_square.(blocks))}(blocks) +) diff --git a/test/linalg.jl b/test/linalg.jl index d9f8977..0112eaf 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -181,7 +181,7 @@ end @test C.UL ≈ C.U @test C.uplo === 'U' @test C.info == 0 - @test typeof(C) == Cholesky{Float64, BlockDiagonal{Float64, Matrix{Float64}}} + @test typeof(C) == Cholesky{Float64, BlockDiagonal{Float64, Matrix{Float64}, true}} @test PDMat(cholesky(BD)) == PDMat(cholesky(Matrix(BD))) M = BlockDiagonal(map(Matrix, blocks(C.L))) @@ -192,7 +192,7 @@ end @test C.UL ≈ C.L @test C.uplo === 'L' @test C.info == 0 - @test typeof(C) == Cholesky{Float64, BlockDiagonal{Float64, Matrix{Float64}}} + @test typeof(C) == Cholesky{Float64, BlockDiagonal{Float64, Matrix{Float64}, true}} # we didn't think we needed to support this, but #109 d = Diagonal(rand(5)) From 83838aa2bec08424c8bed6d2705e93c2981679b8 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 29 Jul 2022 16:09:09 +0200 Subject: [PATCH 2/6] test deprecate.jl --- test/deprecate.jl | 5 +++++ test/runtests.jl | 1 + 2 files changed, 6 insertions(+) create mode 100644 test/deprecate.jl diff --git a/test/deprecate.jl b/test/deprecate.jl new file mode 100644 index 0000000..8a762c2 --- /dev/null +++ b/test/deprecate.jl @@ -0,0 +1,5 @@ +@testset "deprecate.jl" begin + blocks = [rand(3, 3), rand(3, 3)] + @test_deprecated BlockDiagonal{Float64, Matrix{Float64}}(blocks) + @test BlockDiagonal(blocks) == BlockDiagonal{Float64, Matrix{Float64}}(blocks) +end diff --git a/test/runtests.jl b/test/runtests.jl index 79e1ebe..2603a21 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -16,4 +16,5 @@ push!(ChainRulesTestUtils.TRANSFORMS_TO_ALT_TANGENTS, x -> @thunk(x)) include("base_maths.jl") include("chainrules.jl") include("linalg.jl") + include("deprecate.jl") end # tests From 62e27451b525db6f3a0a8afaaf489434cc2508d1 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 29 Jul 2022 16:09:20 +0200 Subject: [PATCH 3/6] v0.1.37 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 4dd1bc4..f4b67ad 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "BlockDiagonals" uuid = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0" authors = ["Invenia Technical Computing Corporation"] -version = "0.1.36" +version = "0.1.37" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From af32d8ae1c10e4aad49b42b2cf64bf3a765c525f Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 29 Jul 2022 16:14:01 +0200 Subject: [PATCH 4/6] dispatch on S for ' --- src/linalg.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/linalg.jl b/src/linalg.jl index 1d2b0d5..f7b641b 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -157,12 +157,11 @@ function _mul!(C::BlockDiagonal, A::BlockDiagonal, B::BlockDiagonal, α::Number, return C end +function LinearAlgebra.:\(B::BlockDiagonal{T, V, false}, vm::AbstractVecOrMat) where {T, V} + return Matrix(B) \ vm # Fallback on the generic LinearAlgebra method +end function LinearAlgebra.:\(B::BlockDiagonal, vm::AbstractVecOrMat) row_i = 1 - # BlockDiagonals with non-square blocks - if !all(is_square, blocks(B)) - return Matrix(B) \ vm # Fallback on the generic LinearAlgebra method - end result = similar(vm) for block in blocks(B) nrow = size(block, 1) From 3de63b21d408cabf8120afcdb83a5613ffe1c294 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 29 Jul 2022 16:15:12 +0200 Subject: [PATCH 5/6] fix include order --- src/BlockDiagonals.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/BlockDiagonals.jl b/src/BlockDiagonals.jl index 377045c..899e2dc 100644 --- a/src/BlockDiagonals.jl +++ b/src/BlockDiagonals.jl @@ -11,10 +11,10 @@ import ChainRulesCore.ProjectTo export BlockDiagonal, blocks export blocksize, blocksizes, nblocks -include("deprecate.jl") include("blockdiagonal.jl") include("base_maths.jl") include("chainrules.jl") include("linalg.jl") +include("deprecate.jl") end # end module From 56906cce915cc59ce30c6644d68aec705f3103a4 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Fri, 29 Jul 2022 16:22:36 +0200 Subject: [PATCH 6/6] improve diag --- src/linalg.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/linalg.jl b/src/linalg.jl index f7b641b..dfdb4ed 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -5,6 +5,7 @@ for f in (:adjoint, :eigvecs, :inv, :pinv, :transpose) end LinearAlgebra.diag(B::BlockDiagonal) = map(i -> getindex(B, i, i), 1:minimum(size(B))) +LinearAlgebra.diag(B::BlockDiagonal{T, V, true}) where {T, V} = mapreduce(diag, vcat, B.blocks) LinearAlgebra.det(B::BlockDiagonal) = prod(det, blocks(B)) LinearAlgebra.logdet(B::BlockDiagonal) = sum(logdet, blocks(B)) LinearAlgebra.tr(B::BlockDiagonal) = sum(tr, blocks(B))