Skip to content

Commit

Permalink
Merge pull request #145 from ssz66666/master
Browse files Browse the repository at this point in the history
Fix `Base.mapreduce` to match the new signature in 0.7/1.0
  • Loading branch information
maleadt authored Aug 27, 2018
2 parents d98ca89 + 9664691 commit e10a9b9
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 22 deletions.
26 changes: 18 additions & 8 deletions src/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,24 @@ end

## showing

Base.show(io::IO, x::GPUArray) = Base.show(io, Array(x))
Base.show(io::IO, x::LinearAlgebra.Adjoint{<:Any,<:GPUArray}) =
Base.show(io, LinearAlgebra.adjoint(Array(x.parent)))
Base.show(io::IO, x::LinearAlgebra.Transpose{<:Any,<:GPUArray}) =
Base.show(io, LinearAlgebra.transpose(Array(x.parent)))

Base.show_vector(io::IO, x::GPUArray) = Base.show_vector(io, Array(x))

for (atype, op) in
[(:(GPUArray), :(Array)),
(:(LinearAlgebra.Adjoint{<:Any,<:GPUArray}), :(x->LinearAlgebra.adjoint(Array(parent(x))))),
(:(LinearAlgebra.Transpose{<:Any,<:GPUArray}), :(x->LinearAlgebra.transpose(Array(parent(x)))))]
@eval begin
# for display
Base.print_array(io::IO, X::($atype)) =
Base.print_array(io,($op)(X))

# for show
Base._show_nonempty(io::IO, X::($atype), prefix::String) =
Base._show_nonempty(io,($op)(X),prefix)
Base._show_empty(io::IO, X::($atype)) =
Base._show_empty(io,($op)(X))
Base.show_vector(io::IO, v::($atype), args...) =
Base.show_vector(io,($op)(v),args...)
end
end

# memory operations

Expand Down
38 changes: 24 additions & 14 deletions src/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
# functions in base implemented with a direct loop need to be overloaded to use mapreduce


Base.any(A::GPUArray{Bool}) = mapreduce(identity, |, false, A)
Base.all(A::GPUArray{Bool}) = mapreduce(identity, &, true, A)
Base.count(pred, A::GPUArray) = Int(mapreduce(pred, +, 0, A))
Base.any(A::GPUArray{Bool}) = mapreduce(identity, |, A; init = false)
Base.all(A::GPUArray{Bool}) = mapreduce(identity, &, A; init = true)
Base.count(pred, A::GPUArray) = Int(mapreduce(pred, +, A; init = 0))

Base.:(==)(A::GPUArray, B::GPUArray) = Bool(mapreduce(==, &, true, A, B))
Base.:(==)(A::GPUArray, B::GPUArray) = Bool(mapreduce(==, &, A, B; init = true))

# hack to get around of fetching the first element of the GPUArray
# as a startvalue, which is a bit complicated with the current reduce implementation
function startvalue(f, T)
error("Please supply a starting value for mapreduce. E.g: mapreduce(func, $f, 1, A)")
error("Please supply a starting value for mapreduce. E.g: mapreduce(func, $f, A; init = 1)")
end
startvalue(::typeof(+), T) = zero(T)
startvalue(::typeof(Base.add_sum), T) = zero(T)
Expand Down Expand Up @@ -50,20 +50,30 @@ gpu_promote_type(::typeof(Base.mul_prod), ::Type{T}) where {T<:Number} = typeof(
gpu_promote_type(::typeof(max), ::Type{T}) where {T<: WidenReduceResult} = T
gpu_promote_type(::typeof(min), ::Type{T}) where {T<: WidenReduceResult} = T

function Base.mapreduce(f::Function, op::Function, A::GPUArray{T, N}) where {T, N}
function Base.mapreduce(f::Function, op::Function, A::GPUArray{T, N}; dims = :, init...) where {T, N}
mapreduce_impl(f, op, init.data, A, dims)
end

function mapreduce_impl(f, op, ::NamedTuple{()}, A::GPUArray{T, N}, ::Colon) where {T, N}
OT = gpu_promote_type(op, T)
v0 = startvalue(op, OT) # TODO do this better
mapreduce(f, op, v0, A)
acc_mapreduce(f, op, v0, A, ())
end
function acc_mapreduce end
function Base.mapreduce(f, op, v0, A::GPUArray, B::GPUArray, C::Number)
acc_mapreduce(f, op, v0, A, (B, C))

function mapreduce_impl(f, op, nt::NamedTuple{(:init,)}, A::GPUArray{T, N}, ::Colon) where {T, N}
acc_mapreduce(f, op, nt.init, A, ())
end
function Base.mapreduce(f, op, v0, A::GPUArray, B::GPUArray)
acc_mapreduce(f, op, v0, A, (B,))

function mapreduce_impl(f, op, nt, A::GPUArray{T, N}, dims) where {T, N}
Base._mapreduce_dim(f, op, nt, A, dims)
end
function Base.mapreduce(f, op, v0, A::GPUArray)
acc_mapreduce(f, op, v0, A, ())

function acc_mapreduce end
function Base.mapreduce(f, op, A::GPUArray, B::GPUArray, C::Number; init)
acc_mapreduce(f, op, init, A, (B, C))
end
function Base.mapreduce(f, op, A::GPUArray, B::GPUArray; init)
acc_mapreduce(f, op, init, A, (B,))
end

@generated function mapreducedim_kernel(state, f, op, R, A, range::NTuple{N, Any}) where N
Expand Down
18 changes: 18 additions & 0 deletions src/testsuite/io.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,29 @@ function test_io(AT)
@testset "showing" begin
io = IOBuffer()
A = AT(Int64[1])
B = AT(Int64[1 2;3 4]) # vectors and non-vector arrays showing
# are handled differently in base/arrayshow.jl

show(io, MIME("text/plain"), A)
seekstart(io)
@test String(take!(io)) == "1-element $AT{Int64,1}:\n 1"

show(io, A)
seekstart(io)
msg = String(take!(io)) # result of e.g. `print` differs on 32bit and 64bit machines
# due to different definition of `Int` type
# print([1]) shows as [1] on 64bit but Int64[1] on 32bit
@test msg == "[1]" || msg == "Int64[1]"

show(io, MIME("text/plain"), B)
seekstart(io)
@test String(take!(io)) == "2×2 $AT{Int64,2}:\n 1 2\n 3 4"

show(io, B)
seekstart(io)
msg = String(take!(io))
@test msg == "[1 2; 3 4]" || msg == "Int64[1 2; 3 4]"

show(io, MIME("text/plain"), A')
seekstart(io)
msg = String(take!(io)) # the printing of Adjoint depends on global state
Expand Down
9 changes: 9 additions & 0 deletions src/testsuite/mapreduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,15 @@ function test_mapreduce(AT)
x = T(y)
@test sum(y, dims = 2) Array(sum(x, dims = 2))
@test sum(y, dims = 1) Array(sum(x, dims = 1))

y = rand(range, N, N)
x = T(y)
_zero = zero(ET)
_addone(z) = z + one(ET)
@test mapreduce(_addone, +, y; dims = 2, init = _zero)
Array(mapreduce(_addone, +, x; dims = 2, init = _zero))
@test mapreduce(_addone, +, y; init = _zero)
mapreduce(_addone, +, x; init = _zero)
end
end
@testset "sum maximum minimum prod" begin
Expand Down

0 comments on commit e10a9b9

Please sign in to comment.