Skip to content

Commit

Permalink
JuliaGPU#362: More tests and documentation; simplify the dims
Browse files Browse the repository at this point in the history
  • Loading branch information
Evgeny Tankhilevich committed Aug 28, 2019
1 parent 70714d4 commit b0662eb
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
13 changes: 10 additions & 3 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,11 @@ function Base._mapreducedim!(f, op, R::CuArray{T}, A::CuArray{T}) where {T}
end

import Base.minimum, Base.maximum, Base.reduce
_initarray(x::CuArray{T}, ::Colon, init) where {T} = fill!(similar(x, T, Base.reduced_indices(x, 1:ndims(x))), init)
_initarray(x::CuArray{T}, dims, init) where {T} = fill!(similar(x, T, Base.reduced_indices(x, dims)), init)

_reduced_dims(x::CuArray, ::Colon) = Tuple(ones(Int, ndims(x)))
_reduced_dims(x::CuArray, dims) = Base.reduced_indices(x, dims)

reduce(op, x::CuArray; dims=:, init) where {T} = _reduce(op, x, init, dims)
_initarray(x::CuArray{T}, dims, init) where {T} = fill!(similar(x, T, _reduced_dims(x, dims)), init)

function _reduce(op, x::CuArray, init, ::Colon)
mx = _initarray(x, :, init)
Expand All @@ -114,5 +114,12 @@ function _reduce(op, x::CuArray, init, dims)
Base._mapreducedim!(identity, op, mx, x)
end

"""
reduce(op, x::CuArray; dims=:, init)
The initial value `init` is mandatory for `reduce` on `CuArray`'s. It must be a neutral element for `op`.
"""
reduce(op, x::CuArray; dims=:, init) = _reduce(op, x, init, dims)

minimum(x::CuArray{T}; dims=:) where {T} = _reduce(min, x, typemax(T), dims)
maximum(x::CuArray{T}; dims=:) where {T} = _reduce(max, x, typemin(T), dims)
17 changes: 13 additions & 4 deletions test/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,25 @@ end

@test testf(x -> minimum(x), rand(2, 3))
@test testf(x -> minimum(x, dims=2), rand(2, 3))
@test testf(x -> minimum(x, dims=(2,3)), rand(2, 3, 4))
@test testf(x -> minimum(x, dims=(2, 3)), rand(2, 3, 4))

@test testf(x -> maximum(x), rand(2, 3))
@test testf(x -> maximum(x, dims=2), rand(2, 3))
@test testf(x -> maximum(x, dims=(2,3)), rand(2, 3, 4))
@test testf(x -> maximum(x, dims=(2, 3)), rand(2, 3, 4))

myreducer(x1, x2) = x1+x2 # bypass optimisations for sum()
@test testf(x -> reduce(myreducer, x, dims=(2,3), init=0.0), rand(2, 3, 4))
myreducer(x1, x2) = x1 + x2 # bypass optimisations for sum()
@test testf(x -> reduce(myreducer, x, dims=(2, 3), init=0.0), rand(2, 3, 4))
@test testf(x -> reduce(myreducer, x, init=0.0), rand(2, 3))
@test testf(x -> reduce(myreducer, x, dims=2, init=42.0), rand(2, 3))

ex = ErrorException("Please supply a neutral element for &. E.g: mapreduce(f, &, A; init = 1)")
@test_throws ex mapreduce(t -> t > 0.5, &, cu(rand(2, 3)))
@test testf(x -> mapreduce(t -> t > 0.5, &, x, init=true), rand(2, 3))

ex = UndefKeywordError(:init)
cub = map(t -> t > 0.5, cu(rand(2, 3)))
@test_throws ex reduce(|, cub)
@test testf(x -> reduce(|, x, init=false), map(t -> t > 0.5, cu(rand(2, 3))))
end

@testset "0D" begin
Expand Down

0 comments on commit b0662eb

Please sign in to comment.