Skip to content

Commit

Permalink
Merge pull request #18 from JuliaGPU/tb/gpuarrays
Browse files Browse the repository at this point in the history
Update to latest AbstractGPU stack
  • Loading branch information
maleadt authored Oct 12, 2020
2 parents 703d018 + 0828692 commit 193b0f7
Show file tree
Hide file tree
Showing 7 changed files with 266 additions and 152 deletions.
16 changes: 12 additions & 4 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b"

[[GPUArrays]]
deps = ["AbstractFFTs", "Adapt", "LinearAlgebra", "Printf", "Random", "Serialization"]
git-tree-sha1 = "da6398282abd2a8c0dc3e55b49d984fcc2c582e5"
git-tree-sha1 = "e39817aafb64a0794817a1e5126d042d0b26f700"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "5.2.1"
version = "6.0.1"

[[GPUCompiler]]
deps = ["DataStructures", "InteractiveUtils", "LLVM", "Libdl", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "1b19d415fc3581ff0ed2f57875fca16b5190060a"
deps = ["DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Scratch", "Serialization", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "8b4f39f74d89d3d9de68b5fc5b090858838caf6d"
repo-rev = "bc13183"
repo-url = "https://github.com/JuliaGPU/GPUCompiler.jl.git"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.7.3"

Expand Down Expand Up @@ -140,6 +142,12 @@ git-tree-sha1 = "f459073ffe7b1247ea784df16b7059ed54bf0734"
uuid = "6ac6d60f-d740-5983-97d7-a4482c0689f4"
version = "2020.2.0+1"

[[Scratch]]
deps = ["Dates"]
git-tree-sha1 = "ad4b278adb62d185bbcb6864dc24959ab0627bf6"
uuid = "6c6a2e73-6563-6170-7368-637461726353"
version = "1.0.3"

[[Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"

Expand Down
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ oneAPI_Level_Zero_Loader_jll = "13eca655-d68d-5b81-8367-6d99d727ab01"
[compat]
Adapt = "2.0"
CEnum = "0.4"
GPUArrays = "5.2"
GPUArrays = "6.0"
GPUCompiler = "0.7"
LLVM = "3"
julia = "1.5"
Expand Down
108 changes: 98 additions & 10 deletions lib/level-zero/pointer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export ZePtr, ZE_NULL, PtrOrZePtr


#
# device pointer
# Device pointer
#

"""
Expand Down Expand Up @@ -56,6 +56,11 @@ Base.cconvert(::Type{<:ZePtr}, x) = x
# fallback for unsafe_convert
Base.unsafe_convert(::Type{P}, x::ZePtr) where {P<:ZePtr} = convert(P, x)

# from arrays
# Obtain a pointer with a different element type `S` from an array of `T`:
# first ask the array for its own `ZePtr{T}`, then convert that to `ZePtr{S}`.
Base.unsafe_convert(::Type{ZePtr{S}}, a::AbstractArray{T}) where {S,T} =
    convert(ZePtr{S}, Base.unsafe_convert(ZePtr{T}, a))
function Base.unsafe_convert(::Type{ZePtr{T}}, a::AbstractArray{T}) where {T}
    # Same-eltype fallback: array types that can hand out a `ZePtr` are expected
    # to add a more specific method; anything reaching this one fails loudly.
    error("conversion to pointer not defined for $(typeof(a))")
end

## limited pointer arithmetic & comparison

Expand All @@ -73,7 +78,7 @@ Base.:(+)(x::Integer, y::ZePtr) = y + x


#
# device or host pointer
# Host or device pointer
#

"""
Expand Down Expand Up @@ -110,15 +115,98 @@ function Base.cconvert(::Type{PtrOrZePtr{T}}, val) where {T}
end

function Base.unsafe_convert(::Type{PtrOrZePtr{T}}, val) where {T}
# FIXME: this is expensive; optimize using isapplicable?
ptr = try
ptr = if Core.Compiler.return_type(Base.unsafe_convert,
Tuple{Type{Ptr{T}}, typeof(val)}) !== Union{}
Base.unsafe_convert(Ptr{T}, val)
catch
try
Base.unsafe_convert(ZePtr{T}, val)
catch
throw(ArgumentError("cannot convert to either a host or device pointer"))
end
elseif Core.Compiler.return_type(Base.unsafe_convert,
Tuple{Type{ZePtr{T}}, typeof(val)}) !== Union{}
Base.unsafe_convert(ZePtr{T}, val)
else
throw(ArgumentError("cannot convert to either a host or device pointer"))
end

return Base.bitcast(PtrOrZePtr{T}, ptr)
end


#
# Device reference objects
#

# `ZeRef{T}` is declared as a primitive type with the width of a native pointer:
# 64 bits on 64-bit platforms, 32 bits otherwise.
if Sys.WORD_SIZE == 64
    primitive type ZeRef{T} 64 end
else
    primitive type ZeRef{T} 32 end
end

# general methods for ZeRef{T} type
function Base.eltype(x::Type{<:ZeRef{T}}) where {T}
    # Dispatch may leave the static parameter `T` unbound; report `Any` then.
    if @isdefined(T)
        return T
    else
        return Any
    end
end

function Base.convert(::Type{ZeRef{T}}, x::ZeRef{T}) where {T}
    # already of the requested type: hand the value back unchanged
    return x
end

# conversion for the actual ccall
function Base.unsafe_convert(::Type{ZeRef{T}}, x::ZeRef{T}) where {T}
    # go through the `ZePtr{T}` conversion, then reinterpret the bits as `ZeRef{T}`
    devptr = Base.unsafe_convert(ZePtr{T}, x)
    return Base.bitcast(ZeRef{T}, devptr)
end
function Base.unsafe_convert(::Type{ZeRef{T}}, x) where {T}
    # same path for any other value that knows how to convert to `ZePtr{T}`
    devptr = Base.unsafe_convert(ZePtr{T}, x)
    return Base.bitcast(ZeRef{T}, devptr)
end

# ZeRef from literal pointer
function Base.convert(::Type{ZeRef{T}}, x::ZePtr{T}) where {T}
    # a literal pointer is passed through as-is; no wrapper object is created
    return x
end

# indirect constructors using ZeRef
# NOTE(review): `ZeArray` is not defined in the code visible here; confirm it is
# the array type provided elsewhere in the package.

# Wrap `x` in a one-element array and reference that element.
function ZeRef(x::Any)
    storage = ZeArray([x])
    return ZeRefArray(storage)
end

# Same, but with the element type forced to `T`.
function ZeRef{T}(x) where {T}
    storage = ZeArray(T[x])
    return ZeRefArray{T}(storage)
end

# Reference to a single uninitialized element of type `T`.
function ZeRef{T}() where {T}
    storage = ZeArray{T}(undef, 1)
    return ZeRefArray(storage)
end
function Base.convert(::Type{ZeRef{T}}, x) where {T}
    # generic fallback: defer to the `ZeRef{T}` constructor
    return ZeRef{T}(x)
end


## ZeRef object backed by a device array at index i

struct ZeRefArray{T,A<:AbstractArray{T}} <: Ref{T}
    x::A    # backing array
    i::Int  # index into `x` of the referenced element

    # fully-parameterized inner constructor; `new` converts the arguments to
    # the declared field types
    function ZeRefArray{T,A}(arr, idx) where {T,A<:AbstractArray{T}}
        return new{T,A}(arr, idx)
    end
end
# Outer constructors: infer the array type (and, in the second method, the
# element type as well) from the argument; the index defaults to 1.
function ZeRefArray{T}(x::AbstractArray{T}, i::Int=1) where {T}
    return ZeRefArray{T,typeof(x)}(x, i)
end
function ZeRefArray(x::AbstractArray{T}, i::Int=1) where {T}
    return ZeRefArray{T}(x, i)
end
function Base.convert(::Type{ZeRef{T}}, x::AbstractArray{T}) where {T}
    # an array converts to a reference to its first element
    return ZeRefArray(x, 1)
end

# Pointer to the referenced element: `pointer` of the backing array at the
# stored index.
Base.unsafe_convert(P::Type{ZePtr{T}}, b::ZeRefArray{T}) where {T} = pointer(b.x, b.i)

# `Any`-typed references: additionally convert the element pointer to `P`.
Base.unsafe_convert(P::Type{ZePtr{Any}}, b::ZeRefArray{Any}) =
    convert(P, pointer(b.x, b.i))

# Untyped (`Cvoid`) pointer: take the typed pointer first, then convert it.
function Base.unsafe_convert(::Type{ZePtr{Cvoid}}, b::ZeRefArray{T}) where {T}
    typed = Base.unsafe_convert(ZePtr{T}, b)
    return convert(ZePtr{Cvoid}, typed)
end


## Union with all ZeRef 'subtypes'

# Parametric alias for the values accepted as a device reference to `T`:
# a raw `ZePtr{T}` or a `ZeRefArray{T}`.
const ZeRefs = Union{ZePtr{T}, ZeRefArray{T}} where {T}


## RefOrZeRef

# `RefOrZeRef{T}` is declared as a primitive type with the width of a native
# pointer: 64 bits on 64-bit platforms, 32 bits otherwise.
if Sys.WORD_SIZE == 64
    primitive type RefOrZeRef{T} 64 end
else
    primitive type RefOrZeRef{T} 32 end
end

function Base.convert(
    ::Type{RefOrZeRef{T}}, x::Union{RefOrZeRef{T}, Ref{T}, ZeRef{T}, ZeRefs{T}}
) where {T}
    # values that are already some kind of reference to `T` pass through untouched
    return x
end

# prefer conversion to CPU ref: this is generally cheaper
function Base.convert(::Type{RefOrZeRef{T}}, x) where {T}
    # anything else is boxed in a host-side `Ref{T}`
    return Ref{T}(x)
end
function Base.unsafe_convert(::Type{RefOrZeRef{T}}, x::Ref{T}) where {T}
    # host reference: take its `Ptr{T}` and reinterpret the bits
    hostptr = Base.unsafe_convert(Ptr{T}, x)
    return Base.bitcast(RefOrZeRef{T}, hostptr)
end
function Base.unsafe_convert(::Type{RefOrZeRef{T}}, x) where {T}
    # generic fallback: also goes through the host `Ptr{T}` conversion
    hostptr = Base.unsafe_convert(Ptr{T}, x)
    return Base.bitcast(RefOrZeRef{T}, hostptr)
end

# support conversion from GPU ref
function Base.unsafe_convert(::Type{RefOrZeRef{T}}, x::ZeRefs{T}) where {T}
    # device-side reference: go through `ZePtr{T}` instead of `Ptr{T}`
    devptr = Base.unsafe_convert(ZePtr{T}, x)
    return Base.bitcast(RefOrZeRef{T}, devptr)
end

# support conversion from arrays
function Base.convert(::Type{RefOrZeRef{T}}, x::Array{T}) where {T}
    # plain `Array`s become host references
    return convert(Ref{T}, x)
end
function Base.convert(::Type{RefOrZeRef{T}}, x::AbstractArray{T}) where {T}
    # every other array type is routed to the `ZeRef{T}` conversion
    return convert(ZeRef{T}, x)
end
function Base.unsafe_convert(P::Type{RefOrZeRef{T}}, b::ZeRefArray{T}) where T
    # array-backed reference: convert to `ZeRef{T}` first, then reinterpret
    devref = Base.unsafe_convert(ZeRef{T}, b)
    return Base.bitcast(RefOrZeRef{T}, devref)
end
Loading

0 comments on commit 193b0f7

Please sign in to comment.