Skip to content
This repository has been archived by the owner on May 27, 2021. It is now read-only.

Commit

Permalink
Merge pull request #86 from JuliaGPU/tb/llvm_cuprintf
Browse files Browse the repository at this point in the history
Implement @cuprint using LLVM
  • Loading branch information
maleadt authored Aug 1, 2017
2 parents fffbba8 + a0e2a59 commit 1d371d8
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 60 deletions.
23 changes: 4 additions & 19 deletions src/cgutils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,23 +116,6 @@ macro wrap(call, attrs="")
end
end

# Make a string literal safe to embed in LLVM IR.
#
# This is a custom, simplified version of Base.escape_string, replacing non-printable
# characters with their two-digit hex code.
function escape_llvm_string(io, s::AbstractString, esc::AbstractString)
i = start(s)
while !done(s,i)
c, j = next(s,i)
c == '\\' ? print(io, "\\\\") :
c in esc ? print(io, '\\', c) :
isprint(c) ? print(io, c) :
print(io, "\\", hex(c, 2))
i = j
end
end
escape_llvm_string(s::AbstractString) = sprint(endof(s), escape_llvm_string, s, "\"")


# julia.h: jl_datatype_align
Base.@pure function datatype_align(::Type{T}) where {T}
Expand All @@ -149,7 +132,8 @@ end


# create an LLVM function, given its return (LLVM) type and a vector of argument types
function create_llvmf(ret::LLVMType, params::Vector{LLVMType}, name::String="")::LLVM.Function
function create_llvmf(ret::LLVMType=LLVM.VoidType(jlctx[]), params::Vector{LLVMType}=LLVMType[],
name::String="")
mod = LLVM.Module("llvmcall", jlctx[])

llvmf_typ = LLVM.FunctionType(ret, params)
Expand All @@ -161,7 +145,8 @@ end

# call an LLVM function, given its return (Julia) type, a tuple-type for the arguments,
# and an expression yielding a tuple of the actual argument values.
function call_llvmf(llvmf::LLVM.Function, ret::Type, params::Type, args::Expr)
function call_llvmf(llvmf::LLVM.Function, ret::Type=Void, params::Type=Tuple{},
args::Expr=:())
quote
Base.@_inline_meta
Base.llvmcall(LLVM.ref($llvmf), $ret, $params, $args...)
Expand Down
86 changes: 48 additions & 38 deletions src/device/intrinsics/output.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,48 +17,58 @@ to match, eg. printing a 64-bit Julia integer requires the `%ld` formatting stri
"""
macro cuprintf(fmt::String, args...)
# NOTE: we can't pass fmt by Val{}, so save it in a global buffer
push!(cuprintf_fmts, "$fmt\0")
push!(cuprintf_fmts, "$fmt")
id = length(cuprintf_fmts)

return :(generated_cuprintf(Val{$id}, $(map(esc, args)...)))
return :(_cuprintf(Val{$id}, $(map(esc, args)...)))
end

function emit_vprintf(id::Integer, argtypes, args...)
fmt = cuprintf_fmts[id]
fmtlen = length(fmt)

llvm_argtypes = [llvmtypes[jltype] for jltype in argtypes]

decls = Vector{String}()
push!(decls, """declare i32 @vprintf(i8*, i8*)""")
push!(decls, """%print$(id)_argtyp = type { $(join(llvm_argtypes, ", ")) }""")
push!(decls, """@print$(id)_fmt = private unnamed_addr constant [$fmtlen x i8] c"$(escape_llvm_string(fmt))", align 1""")

ir = Vector{String}()
push!(ir, """%args = alloca %print$(id)_argtyp""")
arg = 0
tmp = length(args)+1
for jltype in argtypes
llvmtype = llvmtypes[jltype]
push!(ir, """%$tmp = getelementptr inbounds %print$(id)_argtyp, %print$(id)_argtyp* %args, i32 0, i32 $arg""")
push!(ir, """store $llvmtype %$arg, $llvmtype* %$tmp, align 4""")
arg+=1
tmp+=1
end
push!(ir, """%argptr = bitcast %print$(id)_argtyp* %args to i8*""")
push!(ir, """%$tmp = call i32 @vprintf(i8* getelementptr inbounds ([$fmtlen x i8], [$fmtlen x i8]* @print$(id)_fmt, i32 0, i32 0), i8* %argptr)""")
push!(ir, """ret void""")

return quote
Base.@_inline_meta
Base.llvmcall(($(join(decls, "\n")),
$(join(ir, "\n"))),
Void, Tuple{$argtypes...}, $(args...)
)
@generated function _cuprintf(::Type{Val{id}}, argspec...) where {id}
arg_exprs = [:( argspec[$i] ) for i in 1:length(argspec)]
arg_types = [argspec...]

# TODO: needs to adhere to C vararg promotion (short -> int, float -> double, etc.)

T_void = LLVM.VoidType(jlctx[])
T_int32 = LLVM.Int32Type(jlctx[])
T_pint8 = LLVM.PointerType(LLVM.Int8Type(jlctx[]))

# create functions
param_types = LLVMType[convert.(LLVMType, arg_types)...]
llvmf = create_llvmf(T_int32, param_types)
mod = LLVM.parent(llvmf)

# generate IR
Builder(jlctx[]) do builder
entry = BasicBlock(llvmf, "entry", jlctx[])
position!(builder, entry)

fmt = globalstring_ptr!(builder, cuprintf_fmts[id])

# construct and fill args buffer
if isempty(argspec)
buffer = LLVM.PointerNull(T_pint8)
else
argtypes = LLVM.StructType("vprintf_args", jlctx[])
elements!(argtypes, param_types)

args = alloca!(builder, argtypes, "args")
for (i, param) in enumerate(parameters(llvmf))
p = struct_gep!(builder, args, i-1)
store!(builder, param, p)
end

buffer = bitcast!(builder, args, T_pint8)
end

# invoke vprintf and return
vprintf_typ = LLVM.FunctionType(T_int32, [T_pint8, T_pint8])
vprintf = LLVM.Function(mod, "vprintf", vprintf_typ)
chars = call!(builder, vprintf, [fmt, buffer])

ret!(builder, chars)
end
end

@generated function generated_cuprintf{ID}(::Type{Val{ID}}, argspec...)
args = [:( argspec[$i] ) for i in 1:length(argspec)]
return emit_vprintf(ID, argspec, args...)
arg_tuple = Expr(:tuple, arg_exprs...)
call_llvmf(llvmf, Int32, Tuple{arg_types...}, arg_tuple)
end
6 changes: 3 additions & 3 deletions src/execution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ end
const agecache = Dict{UInt, UInt}()
const compilecache = Dict{UInt, CuFunction}()
@generated function _cuda(dims::Tuple{CuDim, CuDim}, shmem, stream,
func::F, args::Vararg{Any,N}) where {F<:Core.Function,N}
arg_exprs = [:( args[$i] ) for i in 1:N]
arg_types = args
func::Core.Function, argspec...)
arg_exprs = [:( argspec[$i] ) for i in 1:length(argspec)]
arg_types = argspec

# filter out ghost arguments
real_args = map(t->!isghosttype(t), arg_types)
Expand Down

0 comments on commit 1d371d8

Please sign in to comment.