Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Commit

Permalink
Try #612:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Mar 13, 2020
2 parents ae239d7 + 1214a37 commit 10001d7
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 0 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
AbstractFFTs = "0.4, 0.5"
Expand Down
1 change: 1 addition & 0 deletions src/CuArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ include("accumulate.jl")
include("linalg.jl")
include("nnlib.jl")
include("iterator.jl")
include("statistics.jl")

include("deprecated.jl")

Expand Down
13 changes: 13 additions & 0 deletions src/statistics.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import Statistics

Statistics._var(A::CuArray, corrected::Bool, mean, dims) =
sum((A .- something(mean, Statistics.mean(A, dims=dims))).^2, dims=dims)/(prod(size(A)[[dims...]])-corrected)

Statistics._var(A::CuArray, corrected::Bool, mean, ::Colon) =
sum((A .- something(mean, Statistics.mean(A))).^2)/(length(A)-corrected)

Statistics._std(A::CuArray, corrected::Bool, mean, dims) =
sqrt.(Statistics.var(A; corrected=corrected, mean=mean, dims=dims))

Statistics._std(A::CuArray, corrected::Bool, mean, ::Colon) =
sqrt.(Statistics.var(A; corrected=corrected, mean=mean, dims=:))
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ include("iterator.jl")

include("forwarddiff.jl")
include("nnlib.jl")
include("statistics.jl")

if haskey(ENV, "CI")
CuArrays.memory_status()
Expand Down
18 changes: 18 additions & 0 deletions test/statistics.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
@testset "Statistics" begin

using Statistics

@testset "std" begin
@test testf(std, rand(10))
@test testf(std, rand(10,1,2))
@test testf(std, rand(10,1,2), corrected=true)
@test testf(std, rand(10,1,2), dims=1)
end
@testset "var" begin
@test testf(var, rand(10))
@test testf(var, rand(10,1,2))
@test testf(var, rand(10,1,2), corrected=true)
@test testf(var, rand(10,1,2), dims=1)
end

end

0 comments on commit 10001d7

Please sign in to comment.