diff --git a/src/abstractarray.jl b/src/abstractarray.jl index b5560960..7148924a 100644 --- a/src/abstractarray.jl +++ b/src/abstractarray.jl @@ -56,14 +56,24 @@ end ## showing -Base.show(io::IO, x::GPUArray) = Base.show(io, Array(x)) -Base.show(io::IO, x::LinearAlgebra.Adjoint{<:Any,<:GPUArray}) = - Base.show(io, LinearAlgebra.adjoint(Array(x.parent))) -Base.show(io::IO, x::LinearAlgebra.Transpose{<:Any,<:GPUArray}) = - Base.show(io, LinearAlgebra.transpose(Array(x.parent))) - -Base.show_vector(io::IO, x::GPUArray) = Base.show_vector(io, Array(x)) - +for (atype, op) in + [(:(GPUArray), :(Array)), + (:(LinearAlgebra.Adjoint{<:Any,<:GPUArray}), :(x->LinearAlgebra.adjoint(Array(parent(x))))), + (:(LinearAlgebra.Transpose{<:Any,<:GPUArray}), :(x->LinearAlgebra.transpose(Array(parent(x)))))] + @eval begin + # for display + Base.print_array(io::IO, X::($atype)) = + Base.print_array(io,($op)(X)) + + # for show + Base._show_nonempty(io::IO, X::($atype), prefix::String) = + Base._show_nonempty(io,($op)(X),prefix) + Base._show_empty(io::IO, X::($atype)) = + Base._show_empty(io,($op)(X)) + Base.show_vector(io::IO, v::($atype), args...) = + Base.show_vector(io,($op)(v),args...) + end +end # memory operations diff --git a/src/mapreduce.jl b/src/mapreduce.jl index 87672532..2f1c7ac5 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -3,16 +3,16 @@ # functions in base implemented with a direct loop need to be overloaded to use mapreduce -Base.any(A::GPUArray{Bool}) = mapreduce(identity, |, false, A) -Base.all(A::GPUArray{Bool}) = mapreduce(identity, &, true, A) -Base.count(pred, A::GPUArray) = Int(mapreduce(pred, +, 0, A)) +Base.any(A::GPUArray{Bool}) = mapreduce(identity, |, A; init = false) +Base.all(A::GPUArray{Bool}) = mapreduce(identity, &, A; init = true) +Base.count(pred, A::GPUArray) = Int(mapreduce(pred, +, A; init = 0)) -Base.:(==)(A::GPUArray, B::GPUArray) = Bool(mapreduce(==, &, true, A, B)) +Base.:(==)(A::GPUArray, B::GPUArray) = Bool(mapreduce(==, &, A, B; init = true)) # hack to get around of fetching the first element of the GPUArray # as a startvalue, which is a bit complicated with the current reduce implementation function startvalue(f, T) - error("Please supply a starting value for mapreduce. E.g: mapreduce(func, $f, 1, A)") + error("Please supply a starting value for mapreduce. E.g: mapreduce(func, $f, A; init = 1)") end startvalue(::typeof(+), T) = zero(T) startvalue(::typeof(Base.add_sum), T) = zero(T) @@ -50,20 +50,30 @@ gpu_promote_type(::typeof(Base.mul_prod), ::Type{T}) where {T<:Number} = typeof( gpu_promote_type(::typeof(max), ::Type{T}) where {T<: WidenReduceResult} = T gpu_promote_type(::typeof(min), ::Type{T}) where {T<: WidenReduceResult} = T -function Base.mapreduce(f::Function, op::Function, A::GPUArray{T, N}) where {T, N} +function Base.mapreduce(f::Function, op::Function, A::GPUArray{T, N}; dims = :, init...) where {T, N} + mapreduce_impl(f, op, init.data, A, dims) +end + +function mapreduce_impl(f, op, ::NamedTuple{()}, A::GPUArray{T, N}, ::Colon) where {T, N} OT = gpu_promote_type(op, T) v0 = startvalue(op, OT) # TODO do this better - mapreduce(f, op, v0, A) + acc_mapreduce(f, op, v0, A, ()) end -function acc_mapreduce end -function Base.mapreduce(f, op, v0, A::GPUArray, B::GPUArray, C::Number) - acc_mapreduce(f, op, v0, A, (B, C)) + +function mapreduce_impl(f, op, nt::NamedTuple{(:init,)}, A::GPUArray{T, N}, ::Colon) where {T, N} + acc_mapreduce(f, op, nt.init, A, ()) end -function Base.mapreduce(f, op, v0, A::GPUArray, B::GPUArray) - acc_mapreduce(f, op, v0, A, (B,)) + +function mapreduce_impl(f, op, nt, A::GPUArray{T, N}, dims) where {T, N} + Base._mapreduce_dim(f, op, nt, A, dims) end -function Base.mapreduce(f, op, v0, A::GPUArray) - acc_mapreduce(f, op, v0, A, ()) + +function acc_mapreduce end +function Base.mapreduce(f, op, A::GPUArray, B::GPUArray, C::Number; init) + acc_mapreduce(f, op, init, A, (B, C)) +end +function Base.mapreduce(f, op, A::GPUArray, B::GPUArray; init) + acc_mapreduce(f, op, init, A, (B,)) end @generated function mapreducedim_kernel(state, f, op, R, A, range::NTuple{N, Any}) where N diff --git a/src/testsuite/io.jl b/src/testsuite/io.jl index 0e75a2b1..619afe46 100644 --- a/src/testsuite/io.jl +++ b/src/testsuite/io.jl @@ -3,11 +3,29 @@ function test_io(AT) @testset "showing" begin io = IOBuffer() A = AT(Int64[1]) + B = AT(Int64[1 2;3 4]) # vectors and non-vector arrays showing + # are handled differently in base/arrayshow.jl show(io, MIME("text/plain"), A) seekstart(io) @test String(take!(io)) == "1-element $AT{Int64,1}:\n 1" + show(io, A) + seekstart(io) + msg = String(take!(io)) # result of e.g. `print` differs on 32bit and 64bit machines + # due to different definition of `Int` type + # print([1]) shows as [1] on 64bit but Int64[1] on 32bit + @test msg == "[1]" || msg == "Int64[1]" + + show(io, MIME("text/plain"), B) + seekstart(io) + @test String(take!(io)) == "2×2 $AT{Int64,2}:\n 1 2\n 3 4" + + show(io, B) + seekstart(io) + msg = String(take!(io)) + @test msg == "[1 2; 3 4]" || msg == "Int64[1 2; 3 4]" + show(io, MIME("text/plain"), A') seekstart(io) msg = String(take!(io)) # the printing of Adjoint depends on global state diff --git a/src/testsuite/mapreduce.jl b/src/testsuite/mapreduce.jl index c0a6473c..010110cb 100644 --- a/src/testsuite/mapreduce.jl +++ b/src/testsuite/mapreduce.jl @@ -21,6 +21,15 @@ function test_mapreduce(AT) x = T(y) @test sum(y, dims = 2) ≈ Array(sum(x, dims = 2)) @test sum(y, dims = 1) ≈ Array(sum(x, dims = 1)) + + y = rand(range, N, N) + x = T(y) + _zero = zero(ET) + _addone(z) = z + one(ET) + @test mapreduce(_addone, +, y; dims = 2, init = _zero) ≈ + Array(mapreduce(_addone, +, x; dims = 2, init = _zero)) + @test mapreduce(_addone, +, y; init = _zero) ≈ + mapreduce(_addone, +, x; init = _zero) end end @testset "sum maximum minimum prod" begin