Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement sum #446

Merged
merged 14 commits into from
Jul 16, 2024
3 changes: 2 additions & 1 deletion src/BandedMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import Base: axes, axes1, getproperty, getindex, setindex!, *, +, -, ==, <, <=,
>=, /, \, adjoint, transpose, showerror, convert, size, view,
unsafe_indices, first, last, size, length, unsafe_length, step, to_indices,
to_index, show, fill!, similar, copy, promote_rule, real, imag,
copyto!, Array
copyto!, Array, sum, sum!

using Base.Broadcast: AbstractArrayStyle, DefaultArrayStyle, Broadcasted
import Base.Broadcast: BroadcastStyle, broadcasted
Expand Down Expand Up @@ -99,4 +99,5 @@ end
include("precompile.jl")



end #module
1 change: 1 addition & 0 deletions src/banded/BandedMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1007,3 +1007,4 @@ function resize(A::BandedSubBandedMatrix, n::Integer, m::Integer)
l,u = bandwidths(A)
_BandedMatrix(reshape(resize!(vec(copy(bandeddata(A))), (l+u+1)*m), l+u+1, m), n, l,u)
end

65 changes: 65 additions & 0 deletions src/generic/AbstractBandedMatrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,3 +336,68 @@
copy(A::Adjoint{T,<:AbstractBandedMatrix}) where T = copy(parent(A))'
copy(A::Transpose{T,<:AbstractBandedMatrix}) where T = transpose(copy(parent(A)))
end

function sum!(ret::AbstractArray, A::AbstractBandedMatrix)

Check warning on line 340 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L340

Added line #L340 was not covered by tests
#Behaves similarly to Base.sum!
fill!(ret, zero(eltype(ret)))
n,m = size(A)
s = size(ret)
l = length(s)

Check warning on line 345 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L342-L345

Added lines #L342 - L345 were not covered by tests
#Check for singleton dimension and perform respective sum
if s[1] == 1 && (l == 1 || s[2]==1)
for j = 1:m, i = colrange(A, j)
ret .+= A[i, j]
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add tests fro this special case

elseif s[1] == n && (l == 1 || s[2]==1)
for i = 1:n, j = rowrange(A, i)
ret[i, 1] += A[i, j]
end
elseif s[1] == 1 && s[2] == m
for j = 1:m, i = colrange(A, j)
ret[1, j] += A[i, j]
end
elseif s[1] == n && s[2] == m
copyto!(ret,A)

Check warning on line 360 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L347-L360

Added lines #L347 - L360 were not covered by tests
else
throw(DimensionMismatch("reduction on matrix of size ($n, $m) with output size $s"))

Check warning on line 362 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L362

Added line #L362 was not covered by tests
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add test using @test_throws

end
#return the value to mimic Base.sum!
ret

Check warning on line 365 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L365

Added line #L365 was not covered by tests
end

function sum(A::AbstractBandedMatrix; dims=:)
if dims isa Colon
l, u = bandwidths(A)
ret = zero(eltype(A))
if l + u < 0
return ret

Check warning on line 373 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L368-L373

Added lines #L368 - L373 were not covered by tests
end
n, m = size(A)
for j = 1:m, i = colrange(A, j)
ret += A[i, j]
end
ret
elseif dims > 2
A
elseif dims == 2
l, u = bandwidths(A)
n, m = size(A)
ret = zeros(eltype(A), n, 1)
if l + u < 0
return ret

Check warning on line 387 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L375-L387

Added lines #L375 - L387 were not covered by tests
end
sum!(ret, A)
ret
elseif dims == 1
l, u = bandwidths(A)
n, m = size(A)
ret = zeros(eltype(A), 1, m)
if l + u < 0
return ret

Check warning on line 396 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L389-L396

Added lines #L389 - L396 were not covered by tests
end
sum!(ret, A)
ret

Check warning on line 399 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L398-L399

Added lines #L398 - L399 were not covered by tests
else
throw(ArgumentError("dimension must be ≥ 1, got $dims"))

Check warning on line 401 in src/generic/AbstractBandedMatrix.jl

View check run for this annotation

Codecov / codecov/patch

src/generic/AbstractBandedMatrix.jl#L401

Added line #L401 was not covered by tests
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ include("test_symbanded.jl")
include("test_tribanded.jl")
include("test_interface.jl")
include("test_miscs.jl")
include("test_sum.jl")
2 changes: 1 addition & 1 deletion test/test_broadcasting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Random.seed!(0)
@test identity.(A) isa BandedMatrix
@test bandwidths(identity.(A)) == bandwidths(A)

@test (z -> exp(z)-1).(A) == (z -> exp(z)-1).(Matrix(A))
@test (z -> exp(z)-1).(A) (z -> exp(z)-1).(Matrix(A)) # for some reason == is breaking on Mac CI
@test (z -> exp(z)-1).(A) isa BandedMatrix
@test bandwidths((z -> exp(z)-1).(A)) == bandwidths(A)

Expand Down
34 changes: 34 additions & 0 deletions test/test_sum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
module TestSum

using Test, BandedMatrices, Random

Random.seed!(0)
r = brand(rand(1:10_000),rand(1:10_000),rand(-20:100),rand(-20:100))
empty_r = brand(rand(1:1_000),rand(1:1_000),rand(1:100),rand(-200:-101))
n,m = size(empty_r)
matr = Matrix(r)
@testset "sum" begin
@test sum(empty_r) == 0
@test sum(empty_r; dims = 2) == zeros(n,1)
@test sum(empty_r; dims = 1) == zeros(1,m)

@test sum(r) ≈ sum(matr) rtol = 1e-10
@test sum(r; dims=2) ≈ sum(matr; dims=2) rtol = 1e-10
@test sum(r; dims=1) ≈ sum(matr; dims=1) rtol = 1e-10
@test sum(r; dims=3) == r
@test_throws ArgumentError sum(r; dims=0)

v = [1.0]
sum!(v, r)
@test v == sum!(v, Matrix(r))
n2, m2 = size(r)
v = ones(n2)
@test sum!(v, r) == sum!(v, Matrix(r))
V = zeros(1,m2)
@test sum!(V, r) === V ≈ sum!(zeros(1,m2), Matrix(r))
V = zeros(n2,m2)
@test sum!(V, r) === V == r
@test_throws DimensionMismatch sum!(zeros(Float64, n2 + 1, m2 + 1), r)
end

end
Loading