From 5477835782155ae5d74d711d8ccc6da9fd396c0e Mon Sep 17 00:00:00 2001 From: Nick Robinson Date: Mon, 15 Feb 2021 19:51:44 +0000 Subject: [PATCH] Improve tests in blockdiagonal.jl --- test/blockdiagonal.jl | 39 ++++++++++++++++++--------------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/test/blockdiagonal.jl b/test/blockdiagonal.jl index 79f2cc9..9d3dc9f 100644 --- a/test/blockdiagonal.jl +++ b/test/blockdiagonal.jl @@ -17,18 +17,14 @@ end blocks2 = [rand(rng, N1, N1), rand(rng, N3, N3), rand(rng, N2, N2)] blocks3 = [rand(rng, N1, N1), rand(rng, N2, N2), rand(rng, N2, N2)] - @testset "$T" for (T, (b1, b2, b3)) in ( - Tuple => (BlockDiagonal(Tuple(blocks1)), BlockDiagonal(Tuple(blocks2)), BlockDiagonal(Tuple(blocks3))), - Vector => (BlockDiagonal(blocks1), BlockDiagonal(blocks2), BlockDiagonal(blocks3)), - ) - A = rand(rng, N, N + N1) - B = rand(rng, N + N1, N + N2) - A′, B′ = A', B' - a = rand(rng, N) - b = rand(rng, N + N1) + @testset for V in (Tuple, Vector) + b1 = BlockDiagonal(V(blocks1)) + b2 = BlockDiagonal(V(blocks2)) + N = size(b1, 1) @testset "AbstractArray" begin - X = rand(2, 2); Y = rand(3, 3) + X = rand(2, 2) + Y = rand(3, 3) @test size(b1) == (N, N) @test size(b1, 1) == N && size(b1, 2) == N @@ -53,7 +49,7 @@ end end @testset "parent" begin - @test parent(b1) isa Union{Tuple,AbstractVector} + @test parent(b1) isa V @test eltype(parent(b1)) <: AbstractMatrix @test parent(BlockDiagonal([X, Y])) == [X, Y] @test parent(BlockDiagonal((X, Y))) == (X, Y) @@ -66,7 +62,7 @@ end end @testset "setindex!" begin - X = BlockDiagonal([rand(Float32, 5, 5), rand(Float32, 3, 3)]) + X = BlockDiagonal(V([rand(Float32, 5, 5), rand(Float32, 3, 3)])) X[10] = Int(10) @test X[10] === Float32(10.0) X[3, 3] = Int(9) @@ -78,14 +74,15 @@ end @testset "ChainRules" begin @testset "BlockDiagonal" begin - x = [randn(1, 2), randn(2, 2)] - x̄ = [randn(1, 2), randn(2, 2)] - ȳ = Composite{typeof(BlockDiagonal(x))}(blocks=[randn(1, 2), randn(2, 2)]) + x = V([randn(1, 2), randn(2, 2)]) + x̄ = V([randn(1, 2), randn(2, 2)]) + + ȳ = Composite{typeof(BlockDiagonal(x))}(blocks=V([randn(1, 2), randn(2, 2)])) rrule_test(BlockDiagonal, ȳ, (x, x̄)) end @testset "Matrix" begin - D = BlockDiagonal([randn(1, 2), randn(2, 2)]) - D̄ = Composite{typeof(D)}((blocks=[randn(1, 2), randn(2, 2)]), ) + D = BlockDiagonal(V([randn(1, 2), randn(2, 2)])) + D̄ = Composite{typeof(D)}((blocks=V([randn(1, 2), randn(2, 2)])),) Ȳ = randn(size(D)) rrule_test(Matrix, Ȳ, (D, D̄)) end @@ -98,9 +95,9 @@ end end @testset "blocks size" begin - B = BlockDiagonal([rand(3, 3), rand(4, 4)]) + B = BlockDiagonal(V([rand(3, 3), rand(4, 4)])) @test nblocks(B) == 2 - @test blocksizes(B) == [(3, 3), (4, 4)] + @test blocksizes(B) == V([(3, 3), (4, 4)]) @test blocksize(B, 2) == blocksizes(B)[2] == blocksize(B, 2, 2) end @@ -124,8 +121,8 @@ end @testset "Non-Square Matrix" begin A1 = ones(2, 4) A2 = 2 * ones(3, 2) - B1 = BlockDiagonal([A1, A2]) - B2 = [A1 zeros(2, 2); zeros(3, 4) A2] + B1 = BlockDiagonal(V([A1, A2])) + B2 = [A1 zeros(2, 2); zeros(3, 4) A2] @test B1 == B2 # Dimension check