This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit 443b22e: Call cudaconvert in user code, allowing new definitions.
MikeInnes authored and maleadt committed Jul 31, 2017
1 parent 91f0e34
Showing 3 changed files with 24 additions and 65 deletions.
2 changes: 1 addition & 1 deletion src/device/array.jl
@@ -68,7 +68,7 @@ function Base.convert(::Type{CuDeviceArray{T,N,AS.Global}}, a::CuArray{T,N}) whe
     ptr = Base.unsafe_convert(Ptr{T}, owned_ptr)
     CuDeviceArray{T,N,AS.Global}(a.shape, DevicePtr{T,AS.Global}(ptr))
 end
-cudaconvert(::Type{CuArray{T,N}}) where {T,N} = CuDeviceArray{T,N,AS.Global}
+cudaconvert(a::CuArray{T,N}) where {T,N} = convert(CuDeviceArray{T,N,AS.Global}, a)


 ## indexing
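For context, a brief sketch of how the new method behaves compared to the old one; the host array construction is illustrative and assumes the `CuArray(::Array)` constructor available in the package at the time:

    a = CuArray(rand(Float32, 16))   # host-side handle to device memory
    d = cudaconvert(a)               # now yields a CuDeviceArray{Float32,1,AS.Global} value
    # previously, cudaconvert(CuArray{Float32,1}) merely returned the *type* CuDeviceArray{Float32,1,AS.Global}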
45 changes: 16 additions & 29 deletions src/execution.jl
@@ -9,36 +9,22 @@ using Base.Iterators: filter
 # Auxiliary
 #

-# NOTE: this method cannot be extended, because it is used in a generated function
-cudaconvert(::Type{T}) where {T} = T
-
-# Convert the arguments to a kernel function to their CUDA representation, and figure out
-# what types to specialize the kernel function for.
-function convert_arguments(args, types)
-    argtypes = DataType[types...]
-    argexprs = Union{Expr,Symbol}[args...]
-
-    # convert types to their CUDA representation
-    for i in 1:length(argexprs)
-        t = argtypes[i]
-        ct = cudaconvert(t)
-        if ct != t
-            argtypes[i] = ct
-            if ct <: Ptr
-                argexprs[i] = :( Base.unsafe_convert($ct, $(argexprs[i])) )
-            else
-                argexprs[i] = :( convert($ct, $(argexprs[i])) )
-            end
-        end
-    end
-
-    for argtype in argtypes
-        if argtype.layout == C_NULL || !Base.datatype_pointerfree(argtype)
-            error("don't know how to handle argument of type $argtype")
-        end
-    end
+"""
+    cudaconvert(x)
+
+This function is called for every argument to be passed to a kernel, allowing it to be
+converted to a GPU-friendly format. By default, the function does nothing and returns the
+input object `x` as-is.
+
+For `CuArray` objects, a corresponding `CuDeviceArray` object in global space is returned,
+which implements GPU-compatible array functionality.
+"""
+function cudaconvert(x::T) where {T}
+    if T.layout == C_NULL || !Base.datatype_pointerfree(T)
+        error("don't know how to handle argument of type $T")
+    end

-    return argexprs, argtypes
+    return x
 end

 function emit_cudacall(func, dims, shmem, stream, types, args)
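Because `cudaconvert` is now an ordinary function applied to argument values, downstream code can add methods for its own types, which is the point of this commit. A minimal sketch under assumed names (`ScaledArray` and `ScaledDeviceArray` are hypothetical wrapper types, not part of the package; the `CUDAnative` module name is inferred from the file paths):

    import CUDAnative: cudaconvert, AS   # assumed module name; adjust to the actual package

    struct ScaledArray{T,N}              # host-side wrapper around a GPU array
        data::CuArray{T,N}
        scale::T
    end

    struct ScaledDeviceArray{T,N}        # GPU-friendly counterpart, safe to pass to a kernel
        data::CuDeviceArray{T,N,AS.Global}
        scale::T
    end

    # convert the wrapper by converting its GPU-incompatible field
    cudaconvert(a::ScaledArray{T,N}) where {T,N} =
        ScaledDeviceArray{T,N}(cudaconvert(a.data), a.scale)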
@@ -100,7 +86,8 @@ macro cuda(config::Expr, callexpr::Expr)
     stream = length(config.args)==4 ? esc(pop!(config.args)) : :(CuDefaultStream())
     shmem = length(config.args)==3 ? esc(pop!(config.args)) : :(0)
     dims = esc(config)
-    return :(generated_cuda($dims, $shmem, $stream, $(map(esc, callexpr.args)...)))
+    args = :(cudaconvert.(($(map(esc, callexpr.args)...),)))
+    return :(generated_cuda($dims, $shmem, $stream, $args...))
 end

# Compile and execute a CUDA kernel from a Julia function
@@ -109,7 +96,7 @@ const compilecache = Dict{UInt, CuFunction}()
 @generated function generated_cuda{F<:Core.Function,N}(dims::Tuple{CuDim, CuDim}, shmem, stream,
                                                        func::F, args::Vararg{Any,N})
     arg_exprs = [:( args[$i] ) for i in 1:N]
-    arg_exprs, arg_types = convert_arguments(arg_exprs, args)
+    arg_types = args

     # compile the function, if necessary
     @gensym cuda_fun
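Taken together, these hunks move argument conversion out of the generated function and into the macro expansion itself, so the generated function simply specializes on the already-converted types. Roughly (a simplified sketch, not the literal expansion):

    # what the user writes
    @cuda (blocks, threads) kernel(a, x)

    # approximately what it expands to after this commit: every call argument,
    # including the kernel function itself, goes through cudaconvert before
    # generated_cuda specializes on the converted values
    generated_cuda((blocks, threads), 0, CuDefaultStream(),
                   cudaconvert.((kernel, a, x))...)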
42 changes: 7 additions & 35 deletions src/reflection.jl
@@ -109,43 +109,19 @@ code_sass(func::ANY, types::ANY=Tuple; kwargs...) = code_sass(STDOUT, func, type
 # @code_* replacements
 #

+function gen_call_with_extracted_types(f, ex)
+    :($f($(esc(ex.args[1])), Base.typesof(cudaconvert.(($(esc.(ex.args[2:end])...),))...)))
+end
+
 for (fname,kernel_arg) in [(:code_lowered, false), (:code_typed, false), (:code_warntype, false),
                            (:code_llvm, true), (:code_ptx, true), (:code_sass, false)]
-    # types often need to be converted (eg. CuArray -> CuDeviceArray),
-    # so generate a type-converting wrapper, and a macro to call it
-    fname_wrapper = Symbol(fname, :_cputyped)
-
-    if kernel_arg
-        # some reflection functions take a `kernel` argument, indicating whether
-        # kernel function or device function conventions should be used
-        @eval begin
-            function $fname_wrapper(func, types, kernel::Bool)
-                _, arg_types =
-                    convert_arguments(fill(Symbol(), length(types.parameters)),
-                                      types.parameters)
-                $fname(func, arg_types; kernel=kernel)
-            end
-        end
-    else
-        @eval begin
-            function $fname_wrapper(func, types)
-                _, arg_types =
-                    convert_arguments(fill(Symbol(), length(types.parameters)),
-                                      types.parameters)
-                $fname(func, arg_types)
-            end
-        end
-    end
-
+    # TODO: test the kernel_arg-based behavior
     @eval begin
         @doc $"""
             $fname

        Extracts the relevant function call from any `@cuda` invocation, evaluates the
        arguments to the function or macro call, determines their types (taking into account
        GPU-specific type conversions), and calls $fname on the resulting expression.

        Can be applied to a pure function call, or a call prefixed with the `@cuda` macro.
        In that case, kernel code generation conventions are used (wrt. argument conversions,
        return values, etc).
@@ -162,14 +138,10 @@ for (fname,kernel_arg) in [(:code_lowered, false), (:code_typed, false), (:code_
             kernel = false
         end

-        wrapper(func, types) = $kernel_arg ? $fname_wrapper(func, types, kernel) :
-                                             $fname_wrapper(func, types)
+        wrapper(func, types) = $kernel_arg ? $fname(func, types, kernel = kernel) :
+                                             $fname(func, types)

-        if Base.VERSION >= v"0.7.0-DEV.481"
-            Base.gen_call_with_extracted_types(__module__, wrapper, ex0)
-        else
-            Base.gen_call_with_extracted_types(wrapper, ex0)
-        end
+        gen_call_with_extracted_types(wrapper, ex0)
         end
     end
 end
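With the reflection macros routed through the local `gen_call_with_extracted_types`, argument types are determined via `cudaconvert`, so a `CuArray` argument is reported as its device-side counterpart. A hedged usage sketch (the kernel body and array construction are illustrative only):

    function increment(a)
        i = threadIdx().x
        a[i] += 1f0
        return nothing
    end

    a = CuArray(zeros(Float32, 16))
    @code_ptx increment(a)   # PTX is generated for a CuDeviceArray{Float32,1,AS.Global} argument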
