diff --git a/src/execution.jl b/src/execution.jl index acd5524e..f107c901 100644 --- a/src/execution.jl +++ b/src/execution.jl @@ -5,10 +5,6 @@ export @cuda, nearest_warpsize, cudaconvert using Base.Iterators: filter -# -# Auxiliary -# - """ cudaconvert(x) @@ -27,17 +23,6 @@ function cudaconvert(x::T) where {T} return x end -function emit_cudacall(func, dims, shmem, stream, types, args) - # TODO: can we handle non-isbits types? - all(t -> isbits(t) && sizeof(t) > 0, types) || - error("can only pass bitstypes of size > 0 to CUDA kernels") - - return quote - Profile.@launch begin - cudacall($func, $dims[1], $dims[2], $shmem, $stream, Tuple{$(types...)}, $(args...)) - end - end -end # fast lookup of global world age world_age() = ccall(:jl_get_tls_world_age, UInt, ()) @@ -50,12 +35,9 @@ function method_age(f, tt) return -1 end -isghosttype(dt) = !dt.mutable && sizeof(dt) == 0 +isghosttype(dt) = !dt.mutable && sizeof(dt) == 0 -# -# @cuda macro -# """ @cuda (gridDim::CuDim, blockDim::CuDim, [shmem::Int], [stream::CuStream]) func(args...) @@ -87,56 +69,53 @@ macro cuda(config::Expr, callexpr::Expr) shmem = length(config.args)==3 ? esc(pop!(config.args)) : :(0) dims = esc(config) args = :(cudaconvert.(($(map(esc, callexpr.args)...),))) - return :(generated_cuda($dims, $shmem, $stream, $args...)) + return :(_cuda($dims, $shmem, $stream, $args...)) end -# Compile and execute a CUDA kernel from a Julia function const agecache = Dict{UInt, UInt}() const compilecache = Dict{UInt, CuFunction}() -@generated function generated_cuda{F<:Core.Function,N}(dims::Tuple{CuDim, CuDim}, shmem, stream, - func::F, args::Vararg{Any,N}) +@generated function _cuda(dims::Tuple{CuDim, CuDim}, shmem, stream, + func::F, args::Vararg{Any,N}) where {F<:Core.Function,N} arg_exprs = [:( args[$i] ) for i in 1:N] arg_types = args - # compile the function, if necessary - @gensym cuda_fun - precomp_key = hash(tuple(func, arg_types...)) # precomputable part of the key - kernel_compilation = quote + # filter out ghost arguments + real_args = map(t->!isghosttype(t), arg_types) + real_arg_types = map(x->x[2], filter(x->x[1], zip(real_args, arg_types))) + real_arg_exprs = map(x->x[2], filter(x->x[1], zip(real_args, arg_exprs))) + + precomp_key = hash(tuple(func, arg_types...)) # precomputable part of the keys + quote + Base.@_inline_meta + # look-up the method age - key = hash(($precomp_key, world_age())) - if haskey(agecache, key) - age = agecache[key] + key1 = hash(($precomp_key, world_age())) + if haskey(agecache, key1) + age = agecache[key1] else age = method_age(func, $arg_types) - agecache[key] = age + agecache[key1] = age end # compile the function ctx = CuCurrentContext() - key = hash(($precomp_key, age, ctx)) - if haskey(compilecache, key) - $cuda_fun = compilecache[key] + key2 = hash(($precomp_key, age, ctx)) + if haskey(compilecache, key2) + cuda_fun = compilecache[key2] else - $cuda_fun, _ = cufunction(device(ctx), func, $arg_types) - compilecache[key] = $cuda_fun + cuda_fun, _ = cufunction(device(ctx), func, $arg_types) + compilecache[key2] = cuda_fun end - end - - # filter out non-concrete args - concrete = map(t->!isghosttype(t), arg_types) - arg_types = map(x->x[2], filter(x->x[1], zip(concrete, arg_types))) - arg_exprs = map(x->x[2], filter(x->x[1], zip(concrete, arg_exprs))) - - kernel_call = emit_cudacall(cuda_fun, :(dims), :(shmem), :(stream), - arg_types, arg_exprs) - quote - Base.@_inline_meta - $kernel_compilation - $kernel_call + # call the kernel + Profile.@launch begin + cudacall(cuda_fun, dims[1], dims[2], shmem, stream, + Tuple{$(real_arg_types...)}, $(real_arg_exprs...)) + end end end + """ Return the nearest number of threads that is a multiple of the warp size of a device: