Skip to content

Commit

Permalink
make cumsum result type consistent with sum, make cumsum! and cumsum …
Browse files Browse the repository at this point in the history
…use same codepath for 1d arrays (see discussion in JuliaLang#9650)
  • Loading branch information
stevengj committed Jan 7, 2015
1 parent 3de4d53 commit b19d8f8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
22 changes: 12 additions & 10 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1452,11 +1452,11 @@ symdiff(a) = a
symdiff(a, b) = union(setdiff(a,b), setdiff(b,a))
symdiff(a, b, rest...) = symdiff(a, symdiff(b, rest...))

_cumsum_type{T<:Number}(v::AbstractArray{T}) = typeof(+zero(T))
_cumsum_type(v) = typeof(v[1]+v[1])
_cumsum_type{T<:Number}(v::AbstractArray{T}) = typeof(r_promote(AddFun, zero(T)::T))
_cumsum_type(v) = typeof(r_promote(AddFun, v[1]))

for (f, fp, op) = ((:cumsum, :cumsum_pairwise!, :+),
(:cumprod, :cumprod_pairwise!, :*) )
for (f, f!, fp, op) = ((:cumsum, :cumsum!, :cumsum_pairwise!, :+),
(:cumprod, :cumprod!, :cumprod_pairwise!, :*) )
# in-place cumsum of c = s+v[range(i1,n)], using pairwise summation
@eval function ($fp){T}(v::AbstractVector, c::AbstractVector{T}, s, i1, n)
local s_::T # for sum(v[range(i1,n)]), i.e. sum without s
Expand All @@ -1475,16 +1475,18 @@ for (f, fp, op) = ((:cumsum, :cumsum_pairwise!, :+),
return s_
end

@eval function ($f)(v::AbstractVector)
@eval function ($f!)(result::AbstractVector, v::AbstractVector)
n = length(v)
if n == 0; return result; end
($fp)(v, result, $(op==:+ ? :(zero(v[1])) : :(one(v[1]))), 1, n)
return result
end

@eval function ($f)(v::AbstractVector)
c = $(op===:+ ? (:(similar(v,_cumsum_type(v)))) :
(:(similar(v))))
if n == 0; return c; end
($fp)(v, c, $(op==:+ ? :(zero(v[1])) : :(one(v[1]))), 1, n)
return c
return ($f!)(c, v)
end

@eval ($f)(A::AbstractArray) = ($f)(A, 1)
end

for (f, op) = ((:cummin, :min), (:cummax, :max))
Expand Down
7 changes: 7 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -978,4 +978,11 @@ a = [ [ 1 0 0 ], [ 0 0 0 ] ]
# issue #9648
let x = fill(1.5f0, 10^7)
@test abs(1.5f7 - cumsum(x)[end]) < 3*eps(1.5f7)
@test cumsum(x) == cumsum!(similar(x), x)
end

# cumsum type consistency (discussed in #9650)
let x = Uint8[1,2,3,4,6,7]
@test eltype(cumsum(x)) == typeof(sum(x)) == eltype(cumsum(reshape(x,3,2)))
@test cumsum(x) == cumsum!(similar(x),x) == cumsum!(similar(x,Int), x)
end

0 comments on commit b19d8f8

Please sign in to comment.