Skip to content
This repository was archived by the owner on May 27, 2021. It is now read-only.

Commit

Permalink
Simplify execution.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Aug 1, 2017
1 parent d72fcf7 commit 8f80c05
Showing 1 changed file with 28 additions and 49 deletions.
77 changes: 28 additions & 49 deletions src/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@ export @cuda, nearest_warpsize, cudaconvert
using Base.Iterators: filter


#
# Auxiliary
#

"""
cudaconvert(x)
Expand All @@ -27,17 +23,6 @@ function cudaconvert(x::T) where {T}
return x
end

function emit_cudacall(func, dims, shmem, stream, types, args)
# TODO: can we handle non-isbits types?
all(t -> isbits(t) && sizeof(t) > 0, types) ||
error("can only pass bitstypes of size > 0 to CUDA kernels")

return quote
Profile.@launch begin
cudacall($func, $dims[1], $dims[2], $shmem, $stream, Tuple{$(types...)}, $(args...))
end
end
end

# fast lookup of global world age
world_age() = ccall(:jl_get_tls_world_age, UInt, ())
Expand All @@ -50,12 +35,9 @@ function method_age(f, tt)
return -1
end

isghosttype(dt) = !dt.mutable && sizeof(dt) == 0

isghosttype(dt) = !dt.mutable && sizeof(dt) == 0

#
# @cuda macro
#

"""
@cuda (gridDim::CuDim, blockDim::CuDim, [shmem::Int], [stream::CuStream]) func(args...)
Expand Down Expand Up @@ -87,56 +69,53 @@ macro cuda(config::Expr, callexpr::Expr)
shmem = length(config.args)==3 ? esc(pop!(config.args)) : :(0)
dims = esc(config)
args = :(cudaconvert.(($(map(esc, callexpr.args)...),)))
return :(generated_cuda($dims, $shmem, $stream, $args...))
return :(_cuda($dims, $shmem, $stream, $args...))
end

# Compile and execute a CUDA kernel from a Julia function
const agecache = Dict{UInt, UInt}()
const compilecache = Dict{UInt, CuFunction}()
@generated function generated_cuda{F<:Core.Function,N}(dims::Tuple{CuDim, CuDim}, shmem, stream,
func::F, args::Vararg{Any,N})
@generated function _cuda(dims::Tuple{CuDim, CuDim}, shmem, stream,
func::F, args::Vararg{Any,N}) where {F<:Core.Function,N}
arg_exprs = [:( args[$i] ) for i in 1:N]
arg_types = args

# compile the function, if necessary
@gensym cuda_fun
precomp_key = hash(tuple(func, arg_types...)) # precomputable part of the key
kernel_compilation = quote
# filter out ghost arguments
real_args = map(t->!isghosttype(t), arg_types)
real_arg_types = map(x->x[2], filter(x->x[1], zip(real_args, arg_types)))
real_arg_exprs = map(x->x[2], filter(x->x[1], zip(real_args, arg_exprs)))

precomp_key = hash(tuple(func, arg_types...)) # precomputable part of the keys
quote
Base.@_inline_meta

# look-up the method age
key = hash(($precomp_key, world_age()))
if haskey(agecache, key)
age = agecache[key]
key1 = hash(($precomp_key, world_age()))
if haskey(agecache, key1)
age = agecache[key1]
else
age = method_age(func, $arg_types)
agecache[key] = age
agecache[key1] = age
end

# compile the function
ctx = CuCurrentContext()
key = hash(($precomp_key, age, ctx))
if haskey(compilecache, key)
$cuda_fun = compilecache[key]
key2 = hash(($precomp_key, age, ctx))
if haskey(compilecache, key2)
cuda_fun = compilecache[key2]
else
$cuda_fun, _ = cufunction(device(ctx), func, $arg_types)
compilecache[key] = $cuda_fun
cuda_fun, _ = cufunction(device(ctx), func, $arg_types)
compilecache[key2] = cuda_fun
end
end

# filter out non-concrete args
concrete = map(t->!isghosttype(t), arg_types)
arg_types = map(x->x[2], filter(x->x[1], zip(concrete, arg_types)))
arg_exprs = map(x->x[2], filter(x->x[1], zip(concrete, arg_exprs)))

kernel_call = emit_cudacall(cuda_fun, :(dims), :(shmem), :(stream),
arg_types, arg_exprs)

quote
Base.@_inline_meta
$kernel_compilation
$kernel_call
# call the kernel
Profile.@launch begin
cudacall(cuda_fun, dims[1], dims[2], shmem, stream,
Tuple{$(real_arg_types...)}, $(real_arg_exprs...))
end
end
end


"""
Return the nearest number of threads that is a multiple of the warp size of a device:
Expand Down

0 comments on commit 8f80c05

Please sign in to comment.