Skip to content

Commit

Permalink
Colon-integer indexing bug (#452)
Browse files Browse the repository at this point in the history
* fix bug

* fix unit tests
  • Loading branch information
max-vassili3v authored Aug 21, 2024
1 parent d8e10c7 commit c6c3912
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 16 deletions.
4 changes: 2 additions & 2 deletions src/banded/BandedMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ end
@propagate_inbounds function getindex(A::BandedMatrix, ::Colon, j::Int)
@boundscheck checkbounds(A, axes(A,1), j)
r = similar(A, axes(A,1))
r[firstindex(r):colstart(A,j)-1] .= zero(eltype(r))
r[firstindex(r):min(size(A, 1), colstart(A,j)-1)] .= zero(eltype(r))
# broadcasted assignment is currently faster than setindex
# see https://github.com/JuliaLang/julia/issues/40962#issuecomment-1921340377
# may need revisiting in the future
Expand All @@ -439,7 +439,7 @@ end
@propagate_inbounds function getindex(A::BandedMatrix, k::Int, ::Colon)
@boundscheck checkbounds(A, k, axes(A,2))
r = similar(A, axes(A,2))
r[firstindex(r):rowstart(A,k)-1] .= zero(eltype(r))
r[firstindex(r):min(size(A, 2), rowstart(A,k)-1)] .= zero(eltype(r))
r[rowrange(A,k)] = @view A.data[data_rowrange(A,k)]
r[rowstop(A,k)+1:end] .= zero(eltype(r))
return r
Expand Down
32 changes: 18 additions & 14 deletions test/test_indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ import BandedMatrices: rowstart, rowstop, colstart, colstop,
end

@testset "vector - BandRange/Colon - integer" begin
a = BandedMatrix(Ones{Int}(5, 7), (2, 1))
a = BandedMatrix(Ones{Int}(5, 8), (2, 1))
# 5x7 BandedMatrices.BandedMatrix{Float64}:
# 1.0 1.0 0 0 0 0 0 0
# 1.0 1.0 1.0 0 0 0 0 0
Expand All @@ -224,11 +224,11 @@ import BandedMatrices: rowstart, rowstop, colstart, colstop,
a[BandRange, 5] = [15, 16]
a[BandRange, 6] = [17]

@test a == [ 1 4 0 0 0 0 0;
2 5 8 0 0 0 0;
3 6 9 12 0 0 0;
0 7 10 13 15 0 0;
0 0 11 14 16 17 0]
@test a == [ 1 4 0 0 0 0 0 0;
2 5 8 0 0 0 0 0;
3 6 9 12 0 0 0 0;
0 7 10 13 15 0 0 0;
0 0 11 14 16 17 0 0]

@test a[BandRange, 1] == @view(a[BandRange, 1]) == [1, 2, 3]
@test a[BandRange, 2] == @view(a[BandRange, 2]) == [4, 5, 6, 7]
Expand All @@ -237,6 +237,7 @@ import BandedMatrices: rowstart, rowstop, colstart, colstop,
@test a[BandRange, 5] == @view(a[BandRange, 5]) == [15, 16]
@test a[BandRange, 6] == @view(a[BandRange, 6]) == [17]
@test a[BandRange, 7] == @view(a[BandRange, 7]) == Int[]
@test a[BandRange, 8] == @view(a[BandRange, 8]) == Int[]

@test a[:, 1] == view(a, :, 1) == [1,2,3,0,0]
@test a[:, 2] == view(a, :, 2) == [4,5,6,7,0]
Expand All @@ -245,11 +246,12 @@ import BandedMatrices: rowstart, rowstop, colstart, colstop,
@test a[:, 5] == view(a, :, 5) == [0,0,0,15,16]
@test a[:, 6] == view(a, :, 6) == [0,0,0,0,17]
@test a[:, 7] == view(a, :, 7) == [0,0,0,0,0]
@test a[:, 8] == view(a, :, 8) == [0,0,0,0,0]

@test_throws BoundsError a[:, 0] = [1, 2, 3]
@test_throws DimensionMismatch a[:, 1] = [1, 2, 3]
@test_throws BoundsError a[BandRange, 0] = [1, 2, 3]
@test_throws BoundsError a[BandRange, 8] = [1, 2, 3]
@test_throws BoundsError a[BandRange, 9] = [1, 2, 3]
@test_throws DimensionMismatch a[BandRange, 1] = [1, 2]
end

Expand Down Expand Up @@ -299,7 +301,7 @@ import BandedMatrices: rowstart, rowstop, colstart, colstop,
end

@testset "vector - integer - BandRange/Colon" begin
a = BandedMatrix(Ones{Int}(7, 5), (1, 2))
a = BandedMatrix(Ones{Int}(8, 5), (1, 2))
# 5x7 BandedMatrices.BandedMatrix{Float64}:
# 1.0 1.0 0 0 0 0 0 0
# 1.0 1.0 1.0 0 0 0 0 0
Expand All @@ -315,11 +317,11 @@ import BandedMatrices: rowstart, rowstop, colstart, colstop,
a[5, BandRange] = [15, 16]
a[6, BandRange] = [17]

@test a == [ 1 4 0 0 0 0 0;
2 5 8 0 0 0 0;
3 6 9 12 0 0 0;
0 7 10 13 15 0 0;
0 0 11 14 16 17 0]'
@test a == [ 1 4 0 0 0 0 0 0;
2 5 8 0 0 0 0 0;
3 6 9 12 0 0 0 0;
0 7 10 13 15 0 0 0;
0 0 11 14 16 17 0 0]'

@test a[1, BandRange] == @view(a[1, BandRange]) == [1, 2, 3]
@test a[2, BandRange] == @view(a[2, BandRange]) == [4, 5, 6, 7]
Expand All @@ -328,6 +330,7 @@ import BandedMatrices: rowstart, rowstop, colstart, colstop,
@test a[5, BandRange] == @view(a[5, BandRange]) == [15, 16]
@test a[6, BandRange] == @view(a[6, BandRange]) == [17]
@test a[7, BandRange] == @view(a[7, BandRange]) == Int[]
@test a[8, BandRange] == @view(a[7, BandRange]) == Int[]

@test a[1, :] == @view(a[1, :]) == [1,2,3,0,0]
@test a[2, :] == @view(a[2, :]) == [4,5,6,7,0]
Expand All @@ -336,11 +339,12 @@ import BandedMatrices: rowstart, rowstop, colstart, colstop,
@test a[5, :] == @view(a[5, :]) == [0,0,0,15,16]
@test a[6, :] == @view(a[6, :]) == [0,0,0,0,17]
@test a[7, :] == @view(a[7, :]) == [0,0,0,0,0]
@test a[8, :] == @view(a[7, :]) == [0,0,0,0,0]

@test_throws BoundsError a[0, :] = [1, 2, 3]
@test_throws DimensionMismatch a[1, :] = [1, 2, 3]
@test_throws BoundsError a[0, BandRange] = [1, 2, 3]
@test_throws BoundsError a[8, BandRange] = [1, 2, 3]
@test_throws BoundsError a[9, BandRange] = [1, 2, 3]
@test_throws DimensionMismatch a[1, BandRange] = [1, 2]
end

Expand Down

0 comments on commit c6c3912

Please sign in to comment.