Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement sum #446

Merged
merged 14 commits into from
Jul 16, 2024
3 changes: 2 additions & 1 deletion src/BandedMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import Base: axes, axes1, getproperty, getindex, setindex!, *, +, -, ==, <, <=,
>=, /, \, adjoint, transpose, showerror, convert, size, view,
unsafe_indices, first, last, size, length, unsafe_length, step, to_indices,
to_index, show, fill!, similar, copy, promote_rule, real, imag,
copyto!, Array
copyto!, Array, sum

using Base.Broadcast: AbstractArrayStyle, DefaultArrayStyle, Broadcasted
import Base.Broadcast: BroadcastStyle, broadcasted
Expand Down Expand Up @@ -99,4 +99,5 @@ end
include("precompile.jl")



end #module
42 changes: 42 additions & 0 deletions src/banded/BandedMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1007,3 +1007,45 @@
l,u = bandwidths(A)
_BandedMatrix(reshape(resize!(vec(copy(bandeddata(A))), (l+u+1)*m), l+u+1, m), n, l,u)
end

function sum(A::BandedMatrix; dims=:)
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
if dims isa Colon
l, u = bandwidths(A)
ret = zero(eltype(A))
if l + u < 0
return ret
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
end
n, m = size(A)
for i = 1:n, j = rowrange(A, i)
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
ret += A[i, j]
end
ret

Check warning on line 1022 in src/banded/BandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/banded/BandedMatrix.jl#L1022

Added line #L1022 was not covered by tests
elseif dims > 2
A

Check warning on line 1024 in src/banded/BandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/banded/BandedMatrix.jl#L1024

Added line #L1024 was not covered by tests
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
elseif dims == 2
l, u = bandwidths(A)
n, m = size(A)
ret = zeros(eltype(A), n, 1)
if l + u < 0
return ret
end
for i = 1:n, j = rowrange(A, i)
ret[i, 1] += A[i, j]
end
ret

Check warning on line 1035 in src/banded/BandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/banded/BandedMatrix.jl#L1035

Added line #L1035 was not covered by tests
elseif dims == 1
l, u = bandwidths(A)
n, m = size(A)
ret = zeros(eltype(A), 1, m)
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
if l + u < 0
return ret
end
for i = 1:m, j = colrange(A, i)
ret[1, i] += A[j, i]
end
ret

Check warning on line 1046 in src/banded/BandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/banded/BandedMatrix.jl#L1046

Added line #L1046 was not covered by tests
else
throw(ArgumentError("dimension must be ≥ 1, got $dims"))
end
end

1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ include("test_symbanded.jl")
include("test_tribanded.jl")
include("test_interface.jl")
include("test_miscs.jl")
include("test_sum.jl")
20 changes: 20 additions & 0 deletions test/test_sum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
module TestSum

using Test, BandedMatrices, Random

r = brand(Float64,rand(1:10_000),rand(1:10_000),rand(-20:100),rand(-20:100))
dlfivefifty marked this conversation as resolved.
Show resolved Hide resolved
empty_r = brand(Float64,rand(1:1_000),rand(1:1_000),rand(1:100),rand(-200:-101))
n,m = size(empty_r)
matr = Matrix(r)
@testset "sum" begin
@test sum(empty_r) ≈ 0
@test sum(empty_r; dims = 2) ≈ zeros(n,1)
@test sum(empty_r; dims = 1) ≈ zeros(1,m)
@test sum(r) ≈ sum(matr) rtol = 1e-10
@test sum(r; dims=2) ≈ sum(matr; dims=2) rtol = 1e-10
@test sum(r; dims=1) ≈ sum(matr; dims=1) rtol = 1e-10
@test sum(r; dims=3) == r
@test_throws ArgumentError sum(r; dims=0)
end

end
Loading