Skip to content

Commit

Permalink
fix issue involving bandwidths larger than dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
max-vassili3v committed Jul 20, 2024
1 parent 0282c46 commit 174b393
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
16 changes: 13 additions & 3 deletions src/generic/AbstractBandedMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -414,21 +414,31 @@ function LinearAlgebra.vcat(x::AbstractBandedMatrix...)

#instantiate the returned banded matrix with zeros and required bandwidths/dimensions
m = size(x[1], 2)
l, u = bandwidths(x[1])
l,u = -m, typemin(Int64)
n = 0
isempty = true

Check warning on line 419 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L416-L419

Added lines #L416 - L419 were not covered by tests

#Check for dimension error and calculate bandwidths
for A in x
if size(A, 2) != m
sizes = Tuple(size(b, 2) for b in x)
throw(DimensionMismatch("number of columns of each matrix must match (got $sizes)"))

Check warning on line 425 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L422-L425

Added lines #L422 - L425 were not covered by tests
end

u = max(u, bandwidth(A, 2) - n)
l = max(l, n + bandwidth(A, 1))
l_A, u_A = bandwidths(A)
if l_A + u_A >= 0
isempty = false
u = max(u, min(m - 1, u_A) - n)
l = max(l, min(size(A, 1) - 1, l_A) + n)

Check warning on line 432 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L428-L432

Added lines #L428 - L432 were not covered by tests
end

n += size(A, 1)
end

Check warning on line 436 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L435-L436

Added lines #L435 - L436 were not covered by tests

type = promote_type(eltype.(x)...)
if isempty
return BandedMatrix{type}(undef, (n, m), bandwidths(Zeros(1)))

Check warning on line 440 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L438-L440

Added lines #L438 - L440 were not covered by tests
end
ret = BandedMatrix(Zeros{type}(n, m), (l, u))

Check warning on line 442 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L442

Added line #L442 was not covered by tests

#Populate the banded matrix
Expand Down
16 changes: 9 additions & 7 deletions test/test_cat.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module TestCat

using BandedMatrices, LinearAlgebra, Test, Random, FillArrays
using BandedMatrices, LinearAlgebra, Test, Random, FillArrays, SparseArrays

@testset "vcat" begin
@testset "banded matrices" begin
Expand All @@ -14,13 +14,15 @@ using BandedMatrices, LinearAlgebra, Test, Random, FillArrays
@test eltype(vcat(b, c)) == Float64
@test vcat(b, c) == vcat(Matrix(b), Matrix(c))

for i = 1:3
a = brand(Float64, rand(1:10), 5, rand(1:10),rand(-4:4))
b = brand(Float64, rand(1:10), 5, rand(1:10),rand(-4:4))
c = brand(Float64, rand(1:10), 5, rand(1:10),rand(-4:4))
for i in ((1,2), (-3,4), (0,-1))
a = BandedMatrix(ones(Float64, rand(1:10), 5), i)
b = BandedMatrix(ones(Int64, rand(1:10), 5), i)
c = BandedMatrix(ones(Int32, rand(1:10), 5), i)
d = vcat(a, b, c)
@test d == vcat(Matrix(a), Matrix(b), Matrix(c))
@test bandwidths(d) == (bandwidth(c, 1) + size(a, 1) + size(b, 1), bandwidth(a, 2))
sd = vcat(sparse(a), sparse(b), sparse(c))
@test eltype(d) == Float64
@test d == sd
@test bandwidths(d) == bandwidths(sd)
end
end

Expand Down

0 comments on commit 174b393

Please sign in to comment.