-
Notifications
You must be signed in to change notification settings - Fork 81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Transpose
and Adjoint
in broadcast better
#148
Conversation
I wonder if this approach can be adapted to allow arbitrarily nested wrapper types. Currently it doesn't: julia> using CuArrays, Adapt
julia> CuArrays.allowscalar(false);
julia> xs = Adapt.adapt(CuArray{Complex{Float32}},[1.0+0.0im]);
julia> ys = transpose(xs)
1×1 LinearAlgebra.Transpose{Complex{Float64},CuArray{Complex{Float64},1}}:
1.0 + 0.0im
julia> ys .= 1
[ Info: getfield(GPUArrays, Symbol("##19#20"))(), (LinearAlgebra.Transpose{Complex{Float64},CuArray{Complex{Float64},1}}, Base.Broadcast.Broadcasted{Nothing,Tuple{Base.OneTo{Int64},Base.OneTo{Int64}},typeof(identity),Tuple{Int64}})
1×1 LinearAlgebra.Transpose{Complex{Float64},CuArray{Complex{Float64},1}}:
1.0 + 0.0im
julia> zs = adjoint(ys)
1×1 LinearAlgebra.Adjoint{Complex{Float64},LinearAlgebra.Transpose{Complex{Float64},CuArray{Complex{Float64},1}}}:
Error showing value of type LinearAlgebra.Adjoint{Complex{Float64},LinearAlgebra.Transpose{Complex{Float64},CuArray{Complex{Float64},1}}}:
ERROR: scalar getindex is disabled
julia> zs .= 1
ERROR: scalar setindex! is disabled |
I am not sure we can handle the recursive case without injecting an abstraction into the type-hierarchy. But everything gets fuzzy once you start thinking about |
src/mapreduce.jl
Outdated
@@ -5,7 +5,10 @@ | |||
|
|||
Base.any(A::GPUArray{Bool}) = mapreduce(identity, |, false, A) | |||
Base.all(A::GPUArray{Bool}) = mapreduce(identity, &, true, A) | |||
Base.count(pred, A::GPUArray) = Int(mapreduce(pred, +, 0, A)) | |||
|
|||
Base.any(f::Function, A::GPUArray) = mapreduce(f, |, false, A) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just noticed that Base.any
and Base.all
have defined slightly different semantics from the mapreduce-based implementation here when I was trying to incorporate the changes into #145 :
- They are short-circuiting evaluations in
Base
. This changes semantics when the predicate has side-effects. Maybe we can just make it very clear in documentation that the GPU implementation is not short-circuiting? Base
implementations allowmissing
values, and deal with them by following three-valued logic. Should we follow theBase
behaviour here? It might end up with something like:
Base.any(pred, A::GPUArray) = let t =
begin
mapreduce(x -> let v = pred(x)
ismissing(v) ? (false, true) : (v && true, false) end, # enforce v is `Bool` or `missing`
((v1,m1),(v2,m2)) -> (v1 || v2, m1 || m2),
A; init = (false, false))
end
t[1] ? true : (t[2] ? missing : false)
end
I was thinking about marking wrapped arrays as GPU compatible by using another wrapper struct, so that we can get a dispatch target. Essentially this is using a concrete wrapper type (called It will be necessary to re-implement all Base. methods that directly access its content, like what is done for We will need to special case for all these wrappers we want to support vectorisation, but I think that is okay, because only very few wrappers, like I've got a minimal example to demonstrate the idea. I'll really appreciate your opinion as I think it would be really useful to get some simple wrappers like abstract type MyInterface{T} end # like AbstractArray
abstract type MyGPUStruct{T} <: MyInterface{T} end # like GPUArray
struct MyStruct{T} <: MyGPUStruct{T} # like CuArray
name::String
x::T
end
struct MyWrapper{T,AT <: MyInterface{T}} <: MyInterface{T} # like LinearAlgebra.Transpose
parent::AT
end
struct MyWrapper2{T,AT <: MyInterface{T}} <: MyInterface{T} # any Wrapper that we want it to fall back to base
parent2::AT
end
# methods in Base to create wrappers, like view, transpose and adjoint
wrap1(x::MyInterface) = MyWrapper(x)
wrap2(x::MyInterface) = MyWrapper2(x)
# all methods in `Base` to be overridden,
# like what is already done with GPUArray itself
function get_name(::MyInterface)
error("not implemented")
end
#think of get_name as broadcast, fill!, copyto! or show, etc
# we want to override it for OnGPU later
get_name(o::MyStruct) = begin println("MyStruct"); o.name end
get_name(o::MyWrapper) = begin println("Wrapper"); get_name(o.parent) end
get_name(o::MyWrapper2) = begin println("Wrapper2"); get_name(o.parent2) end
#methods that simply fall back to base
get_x(o::MyStruct) = o.x
get_x(o::MyWrapper) = get_x(o.parent)
get_x(o::MyWrapper) = get_x(o.parent2)
# our marker wrapper that marks an MyInterface like thing
# to be GPU compatible. It only propagates to the next layer of
# wrapper if only the next layer of wrapping
# is also GPU compatible
# type param BE remembers the backend type to be able to
# dispatch to specific GPU kernels
# type param AT reveals the type of underlying array
# so that you can also dispatch on that
struct OnGPU{T,BE<:MyGPUStruct{T},AT<:MyInterface{T}} <: MyInterface{T}
content::AT
end
# need to override methods that create wrappers we care about,
# e.g. Base.view, transpose, etc.
wrap1(x::BE) where {T,BE <: MyGPUStruct{T}} = begin
y = MyWrapper(x)
OnGPU{T,BE,typeof(y)}(y)
end
wrap1(x::OnGPU{T,BE,AT}) where {T,BE,AT} = begin
y = MyWrapper(x.content)
OnGPU{T,BE,typeof(y)}(y)
end
# capture vector operation methods like fill!
# for broadcast to work we need to fix
# BroadcastStyle(::Type{OnGPU{T,<:MyGPUStruct,<:MyInterface}})
get_name(x::OnGPU) = get_name_gpu(x.content)
# methods that fall back to base
get_x(x::OnGPU) = get_x(x.content)
# GPU methods which are generic enough to be handled in
# GPUArrays rather than in backend
# in this case, simply unwrapping MyWrapper and pass it onto
# get_name_gpu(::MyGPUStruct)
get_name_gpu(x::MyWrapper) = get_name_gpu(x.parent)
# default behaviour should be falling back to base
get_name_gpu(x::MyGPUStruct) = error("fall back to base")
# hopefully implemented in GPU backend, don't know how difficult
# it would be to handle recursively nested wrappers that wrap
# a GPUArray(CuArray).
get_name_gpu(x::MyStruct) = get_name(x)
x = MyStruct("x",1) # uses gpu method
y = wrap1(x) # dispatch to gpu method
z = wrap1(y) # dispatch to gpu method
w = wrap2(z) # fall back to base for non-compatible wrappers
v = wrap1(w) # fall back to base again
get_name(x)
get_name(y)
get_name(z)
get_name(w)
get_name(v) |
@ssz66666 Interesting design! But wouldn't it break use of Base/user code that expects wrappers? eg. For now, I've worked a little on this PR, at least improving the use of non-nested Transpose/Adjoint/SubArray. Let's merge this first and figure out a better wrapper design after that? |
Agreed with @maleadt , solving the whole wrapper support design problem seems out of scope of this PR. Also, my proposal does break functions expecting a wrapper directly. An obvious use case is |
When we tag we should bump the major version due to the change in API |
@maleadt as discussed earlier today, I didn't quite finish it so the CuArray changes are at https://github.com/JuliaGPU/CuArrays.jl/tree/vc/things
@SimonDanisch thoughts on the backend addition? The issue is that with wrapper types like
Adjoint
/Transpose
/SubArray
there is no easy dispatch target anymore so something trait-like might work quite well.@andreasnoack would it make sense to introduce
Adjoint{T, A} <: Wrapper{T, A<:AbstractArray{T}} <: AbstractArray{T}
into the base hierarchy? Right now we need to enumerate and special treat all the wrapper types.