From abeb33c4c2393ddac3f8f1a3b1b89fe8599ceb8b Mon Sep 17 00:00:00 2001 From: Dahua Lin Date: Thu, 12 Jun 2014 16:26:10 -0500 Subject: [PATCH 1/7] updated wsum for vectors --- src/weights.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/weights.jl b/src/weights.jl index 82eb1cebe98862..66c074f25c7a39 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -21,11 +21,15 @@ isempty(wv::WeightVec) = isempty(wv.values) ##### Weighted sum ##### -# 1D weighted sum/mean +## weighted sum over vectors + +wsum(v::AbstractVector, w::AbstractVector) = dot(v, w) wsum(v::AbstractArray, w::AbstractVector) = dot(vec(v), w) + +# Note: the methods for BitArray and SparseMatrixCSC are to avoid ambiguities Base.sum(v::BitArray, w::WeightVec) = wsum(v, values(w)) Base.sum(v::SparseMatrixCSC, w::WeightVec) = wsum(v, values(w)) -Base.sum(v::AbstractArray, w::WeightVec) = wsum(v, values(w)) +Base.sum(v::AbstractArray, w::WeightVec) = dot(v, values(w)) # General Cartesian-based weighted sum across dimensions import Base.Cartesian: @ngenerate, @nloops, @nref From b467a4921bc0bbbd65b44e8d09f75affdeb46022 Mon Sep 17 00:00:00 2001 From: Dahua Lin Date: Thu, 12 Jun 2014 16:36:02 -0500 Subject: [PATCH 2/7] move the import of Base.Cartesian upfront --- src/StatsBase.jl | 1 + src/weights.jl | 1 - 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/src/StatsBase.jl b/src/StatsBase.jl index 8f0636e30aff84..2a15a85641c985 100644 --- a/src/StatsBase.jl +++ b/src/StatsBase.jl @@ -2,6 +2,7 @@ module StatsBase import Base: length, isempty, eltype, values, sum, mean, mean!, show, quantile import Base: rand, rand! import Base.LinAlg: BlasReal + import Base.Cartesian: @ngenerate, @nloops, @nref export diff --git a/src/weights.jl b/src/weights.jl index 66c074f25c7a39..12596646060b41 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -32,7 +32,6 @@ Base.sum(v::SparseMatrixCSC, w::WeightVec) = wsum(v, values(w)) Base.sum(v::AbstractArray, w::WeightVec) = dot(v, values(w)) # General Cartesian-based weighted sum across dimensions -import Base.Cartesian: @ngenerate, @nloops, @nref @ngenerate N typeof(r) function wsum!{T,N,S,W<:Real}(r::AbstractArray{T,N}, v::AbstractArray{S,N}, w::AbstractVector{W}, dim::Int) 1 <= dim <= N || error("dim = $dim not in range [1,$N]") From fcd0e8222e645057912c6e9f7071f5f0835e994c Mon Sep 17 00:00:00 2001 From: Dahua Lin Date: Fri, 13 Jun 2014 15:58:49 -0500 Subject: [PATCH 3/7] new implementation of wsum --- REQUIRE | 3 +- src/StatsBase.jl | 6 +- src/common.jl | 11 +++ src/weights.jl | 187 ++++++++++++++++++++++++++++++++++++++--------- test/means.jl | 50 ------------- test/weights.jl | 151 ++++++++++++++++++++++++++++++++++++++ 6 files changed, 319 insertions(+), 89 deletions(-) create mode 100644 test/weights.jl diff --git a/REQUIRE b/REQUIRE index 2e43c1e196c86c..3acdc1d02db99d 100644 --- a/REQUIRE +++ b/REQUIRE @@ -1 +1,2 @@ -julia 0.3- \ No newline at end of file +julia 0.3- +ArrayViews 0.4.6- diff --git a/src/StatsBase.jl b/src/StatsBase.jl index 2a15a85641c985..a320edef563a84 100644 --- a/src/StatsBase.jl +++ b/src/StatsBase.jl @@ -1,8 +1,10 @@ module StatsBase + using ArrayViews + import Base: length, isempty, eltype, values, sum, mean, mean!, show, quantile import Base: rand, rand! - import Base.LinAlg: BlasReal - import Base.Cartesian: @ngenerate, @nloops, @nref + import Base.LinAlg: BlasReal, BlasFloat + import Base.Cartesian: @ngenerate, @nloops, @nref, @nextract export diff --git a/src/common.jl b/src/common.jl index d8c8dcfc6c6e15..d8e7e02b9f55f0 100644 --- a/src/common.jl +++ b/src/common.jl @@ -26,3 +26,14 @@ fptype{T<:Union(Float32,Bool,Int8,Uint8,Int16,Uint16)}(::Type{T}) = Float32 fptype{T<:Union(Float64,Int64,Uint64,Int128,Uint128)}(::Type{T}) = Float64 fptype(::Type{Complex64}) = Complex64 fptype(::Type{Complex128}) = Complex128 + +## auxiliary functions + +function numslices(A::AbstractArray, d::Int) # the number of d-dimensional slices + ns = 1 + for i=d+1:ndims(A) + ns *= size(A,i) + end + return ns::Int +end + diff --git a/src/weights.jl b/src/weights.jl index 12596646060b41..7365b347ece817 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -31,50 +31,165 @@ Base.sum(v::BitArray, w::WeightVec) = wsum(v, values(w)) Base.sum(v::SparseMatrixCSC, w::WeightVec) = wsum(v, values(w)) Base.sum(v::AbstractArray, w::WeightVec) = dot(v, values(w)) -# General Cartesian-based weighted sum across dimensions -@ngenerate N typeof(r) function wsum!{T,N,S,W<:Real}(r::AbstractArray{T,N}, v::AbstractArray{S,N}, - w::AbstractVector{W}, dim::Int) - 1 <= dim <= N || error("dim = $dim not in range [1,$N]") - for i = 1:N - (i == dim && size(r, i) == 1 && size(v, i) == length(w)) || size(r, i) == size(v, i) || error(DimensionMismatch("")) +## wsum along dimension +# +# Brief explanation of the algorithm: +# ------------------------------------ +# +# 1. _wsum! provides the core implementation, which assumes that +# the dimensions of all input arguments are consistent, and no +# dimension checking is performed therein. +# +# wsum and wsum! perform argument checking and call _wsum! +# internally. +# +# 2. _wsum! adopt a Cartesian based implementation for general +# sub types of AbstractArray. Particularly, a faster routine +# that keeps a local accumulator will be used when dim = 1. +# +# The internal function that implements this is _wsum_general! +# +# 3. _wsum! is specialized for following cases: +# (a) A is a vector: we invoke the vector version wsum above. +# The internal function that implements this is _wsum1! +# +# (b) A is a dense matrix with eltype <: BlasReal: we call gemv! +# The internal function that implements this is _wsum2_blas! +# +# (c) A is a contiguous array with eltype <: BlasReal: +# dim == 1: treat A like a matrix of size (d1, d2 x ... x dN) +# dim == N: treat A like a matrix of size (d1 x ... x d(N-1), dN) +# otherwise: decompose A into multiple pages, and apply _wsum2! +# for each +# +# (d) A is a general dense array with eltype <: BlasReal: +# dim <= 2: delegate to (a) and (b) +# otherwise, decompose A into multiple pages +# + +function _wsum1!(R::AbstractArray, A::AbstractVector, w::AbstractVector, init::Bool) + r = wsum(A, w) + if init + R[1] = r + else + R[1] += r + end + return R +end + +function _wsum2_blas!{T<:BlasReal}(R::StridedVector{T}, A::StridedMatrix{T}, w::StridedVector{T}, dim::Int, init::Bool) + beta = ifelse(init, zero(T), one(T)) + trans = dim == 1 ? 'T' : 'N' + BLAS.gemv!(trans, one(T), A, w, beta, R) + return R +end + +function _wsumN!{T<:BlasReal,N}(R::ContiguousArray{T}, A::ContiguousArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) + if dim == 1 + m = size(A, 1) + n = div(length(A), m) + _wsum2_blas!(view(R,:), reshape_view(A, (m, n)), w, 1, init) + elseif dim == N + n = size(A, N) + m = div(length(A), n) + _wsum2_blas!(view(R,:), reshape_view(A, (m, n)), w, 2, init) + else # 1 < dim < N + m = 1 + for i = 1:dim-1; m *= size(A, i); end + n = size(A, dim) + k = 1 + for i = dim+1:N; k *= size(A, i); end + Av = reshape_view(A, (m, n, k)) + Rv = reshape_view(R, (m, k)) + for i = 1:k + _wsum2_blas!(view(Rv,:,i), view(Av,:,:,i), w, 2, init) + end + end + return R +end + +function _wsumN!{T<:BlasReal,N}(R::ContiguousArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) + @assert N >= 3 + if dim <= 2 + m = size(A, 1) + n = size(A, 2) + npages = 1 + for i = 3:N + npages *= size(A, i) + end + rlen = ifelse(dim == 1, n, m) + Rv = reshape_view(R, (rlen, npages)) + for i = 1:npages + _wsum2_blas!(view(Rv,:,i), view(A,:,:,i), w, dim, init) + end + else + _wsum_general!(R, A, w, dim, init) end - fill!(r, 0) - weight = zero(W) - @nloops N i v d->(if d == dim - weight = w[i_d] - j_d = 1 - else - j_d = i_d - end) @inbounds (@nref N r j) += (@nref N v i)*weight - r + return R end -# Weighted sum via `A_mul_B!`/`At_mul_B!` for first and last -# dimensions of compatible arrays. `vec` and `reshape` are only -# guaranteed not to make a copy for Arrays, so only supports Arrays if -# these calls may be necessary. -function wsum!{W<:Real}(r::Union(Array, AbstractVector), v::Union(Array, AbstractMatrix), w::AbstractVector{W}, dim::Int) +# General Cartesian-based weighted sum across dimensions +@ngenerate N typeof(R) function _wsum_general!{T,RT,WT,N}(R::AbstractArray{RT}, + A::AbstractArray{T,N}, w::AbstractVector{WT}, dim::Int, init::Bool) + init && fill!(R, zero(RT)) + wi = zero(WT) if dim == 1 - m = size(v, 1) - n = div(length(v), m) - (length(r) == n && length(w) == m) || throw(DimensionMismatch("")) - At_mul_B!(vec(r), isa(v, AbstractMatrix) ? v : reshape(v, m, n), w) - elseif dim == ndims(v) - n = size(v, ndims(v)) - m = div(length(v), n) - (length(r) == m && length(w) == n) || throw(DimensionMismatch("")) - A_mul_B!(vec(r), isa(v, AbstractMatrix) ? v : reshape(v, m, n), w) + @nextract N sizeR d->size(R,d) + sizA1 = size(A, 1) + @nloops N i d->(d>1? (1:size(A,d)) : (1:1)) d->(j_d = sizeR_d==1 ? 1 : i_d) begin + @inbounds r = (@nref N R j) + for i_1 = 1:sizA1 + @inbounds r += (@nref N A i) * w[i_1] + end + @inbounds (@nref N R j) = r + end else - invoke(wsum!, (AbstractArray, AbstractArray, typeof(w), Int), r, v, w, dim) + @nloops N i A d->(if d == dim + wi = w[i_d] + j_d = 1 + else + j_d = i_d + end) @inbounds (@nref N R j) += (@nref N A i) * wi end - r + return R +end + + +# N = 1 +_wsum!{T<:BlasReal}(R::ContiguousArray{T}, A::DenseArray{T,1}, w::StridedVector{T}, dim::Int, init::Bool) = + _wsum1!(R, A, w, init) + +# N = 2 +_wsum!{T<:BlasReal}(R::ContiguousArray{T}, A::DenseArray{T,2}, w::StridedVector{T}, dim::Int, init::Bool) = + (_wsum2_blas!(view(R,:), A, w, dim, init); R) + +# N >= 3 +_wsum!{T<:BlasReal,N}(R::ContiguousArray{T}, A::DenseArray{T,N}, w::StridedVector{T}, dim::Int, init::Bool) = + _wsumN!(R, A, w, dim, init) + +_wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector, dim::Int, init::Bool) = _wsum_general!(R, A, w, dim, init) + +## wsum! and wsum + +wsumtype{T,W}(::Type{T}, ::Type{W}) = typeof(zero(T) * zero(W) + zero(T) * zero(W)) + +function wsum!{T,N}(R::AbstractArray, A::AbstractArray{T,N}, w::AbstractVector, dim::Int; init::Bool=true) + 1 <= dim <= N || error("dim should be within [1, $N]") + ndims(R) <= N || error("ndims(R) should not exceed $N") + length(w) == size(A,dim) || throw(DimensionMismatch("Inconsistent array dimension.")) + # TODO: more careful examination of R's size + _wsum!(R, A, w, dim, init) +end + +function wsum{T<:Number,W<:Real}(A::AbstractArray{T}, w::AbstractVector{W}, dim::Int) + length(w) == size(A,dim) || throw(DimensionMismatch("Inconsistent array dimension.")) + _wsum!(Array(wsumtype(T,W), Base.reduced_dims(size(A), dim)), A, w, dim, true) end -Base.sum!{W<:Real}(r::AbstractArray, v::AbstractArray, w::WeightVec{W}, dim::Int) = - wsum!(r, v, values(w), dim) +# extended sum! and wsum -wsum{T<:Number,W<:Real}(v::AbstractArray{T}, w::AbstractVector{W}, dim::Int) = - wsum!(Array(typeof(zero(T)*zero(W) + zero(T)*zero(W)), Base.reduced_dims(size(v), dim)), v, w, dim) +Base.sum!{W<:Real}(R::AbstractArray, A::AbstractArray, w::WeightVec{W}, dim::Int; init::Bool=true) = + wsum!(R, A, values(w), dim; init=init) -Base.sum{T<:Number,W<:Real}(v::AbstractArray{T}, w::WeightVec{W}, dim::Int) = wsum(v, values(w), dim) +Base.sum{T<:Number,W<:Real}(A::AbstractArray{T}, w::WeightVec{W}, dim::Int) = wsum(A, values(w), dim) diff --git a/test/means.jl b/test/means.jl index 8d69b7c46bfc81..d3274178528865 100644 --- a/test/means.jl +++ b/test/means.jl @@ -18,53 +18,3 @@ using Base.Test @test_approx_eq trimmean([-100, 2, 3, 7, 200], 0.4) 4.0 @test_approx_eq trimmean([-100, 2, 3, 7, 200], 0.8) 3.0 -@test_approx_eq sum([1.0, 2.0, 3.0], weights([1/3, 1/3, 1/3])) 2.0 -@test_approx_eq sum([1.0, 2.0, 3.0], weights([1.0, 0.0, 0.0])) 1.0 -@test_approx_eq sum([1.0, 2.0, 3.0], weights([0.0, 1.0, 0.0])) 2.0 -@test_approx_eq sum([1.0, 2.0, 3.0], weights([0.0, 0.0, 1.0])) 3.0 -@test_approx_eq sum([1.0, 2.0, 3.0], weights([0.5, 0.0, 0.5])) 2.0 -@test_approx_eq sum([1.0, 2.0, 3.0], weights([0.5, 0.5, 0.0])) 1.5 -@test_approx_eq sum([1.0, 2.0, 3.0], weights([0.0, 0.5, 0.5])) 2.5 - -@test_approx_eq sum(1:3, weights([1/3, 1/3, 1/3])) 2.0 -@test_approx_eq sum(1:3, weights([1.0, 0.0, 0.0])) 1.0 -@test_approx_eq sum(1:3, weights([0.0, 1.0, 0.0])) 2.0 -@test_approx_eq sum(1:3, weights([0.0, 0.0, 1.0])) 3.0 -@test_approx_eq sum(1:3, weights([0.5, 0.0, 0.5])) 2.0 -@test_approx_eq sum(1:3, weights([0.5, 0.5, 0.0])) 1.5 -@test_approx_eq sum(1:3, weights([0.0, 0.5, 0.5])) 2.5 -@test_approx_eq sum(1:3, weights([1.0, 1.0, 0.5])) 4.5 -@test_approx_eq mean(1:3, weights([1.0, 1.0, 0.5])) 1.8 - -a = [1. 2. 3.; 4. 5. 6.] - -@test size(mean(a, weights(ones(2)), 1)) == (1, 3) -@test_approx_eq sum(a, weights([1.0, 1.0]), 1) [5.0, 7.0, 9.0] -@test_approx_eq mean(a, weights([1.0, 1.0]), 1) [2.5, 3.5, 4.5] -@test_approx_eq sum(a, weights([1.0, 0.0]), 1) [1.0, 2.0, 3.0] -@test_approx_eq sum(a, weights([0.0, 1.0]), 1) [4.0, 5.0, 6.0] - -@test size(mean(a, weights(ones(3)), 2)) == (2, 1) -@test_approx_eq wsum!(zeros(1, 2), a, [1.0, 1.0, 1.0], 2) [6.0 15.0] -@test_approx_eq wsum(a, [1.0, 1.0, 1.0], 2) [6.0 15.0] -@test_approx_eq sum!(zeros(1, 2), a, weights([1.0, 1.0, 1.0]), 2) [6.0 15.0] -@test_approx_eq sum(a, weights([1.0, 1.0, 1.0]), 2) [6.0 15.0] -@test_approx_eq mean(a, weights([1.0, 1.0, 1.0]), 2) [2.0 5.0] -@test_approx_eq sum(a, weights([1.0, 0.0, 0.0]), 2) [1.0 4.0] -@test_approx_eq sum(a, weights([0.0, 0.0, 1.0]), 2) [3.0 6.0] - -@test_throws ErrorException mean(a, weights(ones(3)), 3) -@test_throws DimensionMismatch mean(a, weights(ones(2)), 2) -@test_throws DimensionMismatch mean!(ones(1, 1), a, weights(ones(3)), 2) - -a = reshape(1.0:27.0, 3, 3, 3) - -for wt in ([1.0, 1.0, 1.0], [1.0, 0.2, 0.0], [0.2, 0.0, 1.0]) - @test_approx_eq sum(a, weights(wt), 1) sum(a.*reshape(wt, length(wt), 1, 1), 1) - @test_approx_eq sum(a, weights(wt), 2) sum(a.*reshape(wt, 1, length(wt), 1), 2) - @test_approx_eq sum(a, weights(wt), 3) sum(a.*reshape(wt, 1, 1, length(wt)), 3) - @test_approx_eq mean(a, weights(wt), 1) sum(a.*reshape(wt, length(wt), 1, 1), 1)/sum(wt) - @test_approx_eq mean(a, weights(wt), 2) sum(a.*reshape(wt, 1, length(wt), 1), 2)/sum(wt) - @test_approx_eq mean(a, weights(wt), 3) sum(a.*reshape(wt, 1, 1, length(wt)), 3)/sum(wt) - @test_throws ErrorException mean(a, weights(wt), 4) -end diff --git a/test/weights.jl b/test/weights.jl new file mode 100644 index 00000000000000..453e89fe61d69e --- /dev/null +++ b/test/weights.jl @@ -0,0 +1,151 @@ +## WeightVec + +using ArrayViews +using StatsBase +using Base.Test + +@test isa(weights([1, 2, 3]), WeightVec{Int}) +@test isa(weights([1., 2., 3.]), WeightVec{Float64}) + +@test isempty(weights(Float64[])) + +w = [1., 2., 3.] +wv = weights(w) +@test eltype(wv) == Float64 +@test length(wv) == 3 +@test values(wv) === w +@test sum(wv) == 6.0 +@test !isempty(wv) + +## wsum + +x = [6., 8., 9.] +w = [2., 3., 4.] + +@test wsum(Float64[], Float64[]) === 0.0 +@test wsum(x, w) == 72.0 + +## wsum along dimensions + +@test wsum(x, w, 1) == [72.0] + +x = rand(6, 8) +w1 = rand(6) +w2 = rand(8) + +@test size(wsum(x, w1, 1)) == (1, 8) +@test size(wsum(x, w2, 2)) == (6, 1) + +@test_approx_eq wsum(x, w1, 1) sum(x .* w1, 1) +@test_approx_eq wsum(x, w2, 2) sum(x .* w2', 2) + +x = rand(6, 5, 4) +w1 = rand(6) +w2 = rand(5) +w3 = rand(4) + +@test size(wsum(x, w1, 1)) == (1, 5, 4) +@test size(wsum(x, w2, 2)) == (6, 1, 4) +@test size(wsum(x, w3, 3)) == (6, 5, 1) + +@test_approx_eq wsum(x, w1, 1) sum(x .* w1, 1) +@test_approx_eq wsum(x, w2, 2) sum(x .* w2', 2) +@test_approx_eq wsum(x, w3, 3) sum(x .* reshape(w3, 1, 1, 4), 3) + +v = view(x, 2:4, :, :) + +@test_approx_eq wsum(v, w1[1:3], 1) sum(v .* w1[1:3], 1) +@test_approx_eq wsum(v, w2, 2) sum(v .* w2', 2) +@test_approx_eq wsum(v, w3, 3) sum(v .* reshape(w3, 1, 1, 4), 3) + +## wsum for Arrays with non-BlasReal elements + +x = rand(1:100, 6, 8) +w1 = rand(6) +w2 = rand(8) + +@test_approx_eq wsum(x, w1, 1) sum(x .* w1, 1) +@test_approx_eq wsum(x, w2, 2) sum(x .* w2', 2) + +## wsum! + +x = rand(6) +w = rand(6) + +r = ones(1) +@test wsum!(r, x, w, 1; init=true) === r +@test_approx_eq r [dot(x, w)] + +r = ones(1) +@test wsum!(r, x, w, 1; init=false) === r +@test_approx_eq r [dot(x, w) + 1.0] + +x = rand(6, 8) +w1 = rand(6) +w2 = rand(8) + +r = ones(1, 8) +@test wsum!(r, x, w1, 1; init=true) === r +@test_approx_eq r sum(x .* w1, 1) + +r = ones(1, 8) +@test wsum!(r, x, w1, 1; init=false) === r +@test_approx_eq r sum(x .* w1, 1) .+ 1.0 + +r = ones(6) +@test wsum!(r, x, w2, 2; init=true) === r +@test_approx_eq r sum(x .* w2', 2) + +r = ones(6) +@test wsum!(r, x, w2, 2; init=false) === r +@test_approx_eq r sum(x .* w2', 2) .+ 1.0 + +x = rand(8, 6, 5) +w1 = rand(8) +w2 = rand(6) +w3 = rand(5) + +r = ones(1, 6, 5) +@test wsum!(r, x, w1, 1; init=true) === r +@test_approx_eq r sum(x .* w1, 1) + +r = ones(1, 6, 5) +@test wsum!(r, x, w1, 1; init=false) === r +@test_approx_eq r sum(x .* w1, 1) .+ 1.0 + +r = ones(8, 1, 5) +@test wsum!(r, x, w2, 2; init=true) === r +@test_approx_eq r sum(x .* w2', 2) + +r = ones(8, 1, 5) +@test wsum!(r, x, w2, 2; init=false) === r +@test_approx_eq r sum(x .* w2', 2) .+ 1.0 + +r = ones(8, 6) +@test wsum!(r, x, w3, 3; init=true) === r +@test_approx_eq r sum(x .* reshape(w3, (1, 1, 5)), 3) + +r = ones(8, 6) +@test wsum!(r, x, w3, 3; init=false) === r +@test_approx_eq r sum(x .* reshape(w3, (1, 1, 5)), 3) .+ 1.0 + + +## the sum and mean syntax + +@test_approx_eq sum([1.0, 2.0, 3.0], weights([1.0, 0.5, 0.5])) 3.5 +@test_approx_eq sum(1:3, weights([1.0, 1.0, 0.5])) 4.5 + +@test_approx_eq mean([1:3], weights([1.0, 1.0, 0.5])) 1.8 +@test_approx_eq mean(1:3, weights([1.0, 1.0, 0.5])) 1.8 + +a = reshape(1.0:27.0, 3, 3, 3) +for wt in ([1.0, 1.0, 1.0], [1.0, 0.2, 0.0], [0.2, 0.0, 1.0]) + @test_approx_eq sum(a, weights(wt), 1) sum(a.*reshape(wt, length(wt), 1, 1), 1) + @test_approx_eq sum(a, weights(wt), 2) sum(a.*reshape(wt, 1, length(wt), 1), 2) + @test_approx_eq sum(a, weights(wt), 3) sum(a.*reshape(wt, 1, 1, length(wt)), 3) + @test_approx_eq mean(a, weights(wt), 1) sum(a.*reshape(wt, length(wt), 1, 1), 1)/sum(wt) + @test_approx_eq mean(a, weights(wt), 2) sum(a.*reshape(wt, 1, length(wt), 1), 2)/sum(wt) + @test_approx_eq mean(a, weights(wt), 3) sum(a.*reshape(wt, 1, 1, length(wt)), 3)/sum(wt) + @test_throws ErrorException mean(a, weights(wt), 4) +end + From 0ac825a414a1883fa6442d6aa300c6f7033fa6dd Mon Sep 17 00:00:00 2001 From: Dahua Lin Date: Fri, 13 Jun 2014 16:51:40 -0500 Subject: [PATCH 4/7] move weighted mean to src/weights.jl (with minor modification) --- src/means.jl | 15 --------------- src/weights.jl | 21 +++++++++++++++++++++ 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/means.jl b/src/means.jl index 7c1d23aad0d580..2ad474232fe23e 100644 --- a/src/means.jl +++ b/src/means.jl @@ -42,18 +42,3 @@ function trimmean(x::RealArray, p::FloatingPoint) end end -# Weighted means - -function wmean{T<:Number}(v::AbstractArray{T}, w::AbstractArray) - Base.depwarn("wmean is deprecated, use mean(v, weights(w)) instead.", :wmean) - mean(v, weights(w)) -end - -Base.mean(v::AbstractArray, w::WeightVec) = sum(v, w) / sum(w) - -Base.mean!(r::AbstractArray, v::AbstractArray, w::WeightVec, dim::Int) = - scale!(Base.sum!(r, v, w, dim), inv(sum(w))) - -Base.mean{T<:Number,W<:Real}(v::AbstractArray{T}, w::WeightVec{W}, dim::Int) = - mean!(Array(typeof((zero(T)*zero(W) + zero(T)*zero(W)) / one(W)), Base.reduced_dims(size(v), dim)), v, w, dim) - diff --git a/src/weights.jl b/src/weights.jl index 7365b347ece817..9ad4369edeb3df 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -172,6 +172,7 @@ _wsum!(R::AbstractArray, A::AbstractArray, w::AbstractVector, dim::Int, init::Bo ## wsum! and wsum wsumtype{T,W}(::Type{T}, ::Type{W}) = typeof(zero(T) * zero(W) + zero(T) * zero(W)) +wsumtype{T<:BlasReal}(::Type{T}, ::Type{T}) = T function wsum!{T,N}(R::AbstractArray, A::AbstractArray{T,N}, w::AbstractVector, dim::Int; init::Bool=true) 1 <= dim <= N || error("dim should be within [1, $N]") @@ -193,3 +194,23 @@ Base.sum!{W<:Real}(R::AbstractArray, A::AbstractArray, w::WeightVec{W}, dim::Int Base.sum{T<:Number,W<:Real}(A::AbstractArray{T}, w::WeightVec{W}, dim::Int) = wsum(A, values(w), dim) + +###### Weighted means ##### + +function wmean{T<:Number}(v::AbstractArray{T}, w::AbstractVector) + Base.depwarn("wmean is deprecated, use mean(v, weights(w)) instead.", :wmean) + mean(v, weights(w)) +end + +Base.mean(v::AbstractArray, w::WeightVec) = sum(v, w) / sum(w) + +Base.mean!(R::AbstractArray, A::AbstractArray, w::WeightVec, dim::Int) = + scale!(Base.sum!(R, A, w, dim), inv(sum(w))) + +wmeantype{T,W}(::Type{T}, ::Type{W}) = typeof((zero(T)*zero(W) + zero(T)*zero(W)) / one(W)) +wmeantype{T<:BlasReal}(::Type{T}, ::Type{T}) = T + +Base.mean{T<:Number,W<:Real}(A::AbstractArray{T}, w::WeightVec{W}, dim::Int) = + mean!(Array(wmeantype(T, W), Base.reduced_dims(size(A), dim)), A, w, dim) + + From 2edb2ca02461b06342fe000877c4933a8639e5a0 Mon Sep 17 00:00:00 2001 From: Dahua Lin Date: Fri, 13 Jun 2014 16:52:52 -0500 Subject: [PATCH 5/7] add weights.jl to runtests.jl --- runtests.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/runtests.jl b/runtests.jl index 128275226915ca..e16d2d1ef0288a 100644 --- a/runtests.jl +++ b/runtests.jl @@ -1,6 +1,7 @@ using StatsBase tests = ["mathfuns", + "weights", "means", "scalarstats", "counts", From 338b939f64b9754c0e973f1be7086949f8a445af Mon Sep 17 00:00:00 2001 From: Dahua Lin Date: Fri, 13 Jun 2014 16:59:27 -0500 Subject: [PATCH 6/7] getindex for weight vector --- src/weights.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/weights.jl b/src/weights.jl index 9ad4369edeb3df..ee77b377d12b89 100644 --- a/src/weights.jl +++ b/src/weights.jl @@ -18,6 +18,8 @@ values(wv::WeightVec) = wv.values sum(wv::WeightVec) = wv.sum isempty(wv::WeightVec) = isempty(wv.values) +Base.getindex(wv::WeightVec, i) = getindex(wv.values, i) + ##### Weighted sum ##### From 87b8722ea034136c8c5daf2db216d1b4b8c3e731 Mon Sep 17 00:00:00 2001 From: Dahua Lin Date: Fri, 13 Jun 2014 17:04:14 -0500 Subject: [PATCH 7/7] no need for numslices --- src/common.jl | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/src/common.jl b/src/common.jl index d8e7e02b9f55f0..d8c8dcfc6c6e15 100644 --- a/src/common.jl +++ b/src/common.jl @@ -26,14 +26,3 @@ fptype{T<:Union(Float32,Bool,Int8,Uint8,Int16,Uint16)}(::Type{T}) = Float32 fptype{T<:Union(Float64,Int64,Uint64,Int128,Uint128)}(::Type{T}) = Float64 fptype(::Type{Complex64}) = Complex64 fptype(::Type{Complex128}) = Complex128 - -## auxiliary functions - -function numslices(A::AbstractArray, d::Int) # the number of d-dimensional slices - ns = 1 - for i=d+1:ndims(A) - ns *= size(A,i) - end - return ns::Int -end -