Skip to content

Commit

Permalink
Revert "Implement NVVMReflect in Julia. (#280)"
Browse files Browse the repository at this point in the history
This reverts commit 64296ac.
  • Loading branch information
maleadt committed Feb 18, 2022
1 parent f423f1a commit 5223e58
Showing 1 changed file with 0 additions and 83 deletions.
83 changes: 0 additions & 83 deletions src/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,6 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
add_library_info!(pm, triple(mod))
add_transform_info!(pm, tm)

# TODO: need to run this earlier; optimize_module! is called after addOptimizationPasses!
add!(pm, FunctionPass("NVVMReflect", nvvm_reflect!))

# needed by GemmKernels.jl-like code
speculative_execution_if_has_branch_divergence!(pm)

Expand Down Expand Up @@ -395,83 +392,3 @@ function hide_trap!(mod::LLVM.Module)
end
return changed
end

# Replace occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect with an integer.
#
# NOTE: this is the same as LLVM's NVVMReflect pass, which we cannot use because it is
# not exported. It is meant to be added to a pass pipeline automatically, by
# calling adjustPassManager, but we don't use a PassManagerBuilder so cannot do so.
const NVVM_REFLECT_FUNCTION = "__nvvm_reflect"
function nvvm_reflect!(fun::LLVM.Function)
job = current_job::CompilerJob
mod = LLVM.parent(fun)
ctx = context(fun)
changed = false
@timeit_debug to "nvvmreflect" begin

# find and sanity check the nnvm-reflect function
# TODO: also handle the llvm.nvvm.reflect intrinsic
haskey(LLVM.functions(mod), NVVM_REFLECT_FUNCTION) || return false
reflect_function = LLVM.functions(mod)[NVVM_REFLECT_FUNCTION]
isdeclaration(reflect_function) || error("_reflect function should not have a body")
reflect_typ = return_type(eltype(llvmtype(reflect_function)))
isa(reflect_typ, LLVM.IntegerType) || error("_reflect's return type should be integer")

to_remove = []
for use in uses(reflect_function)
call = user(use)
isa(call, LLVM.CallInst) || continue
length(operands(call)) == 2 || error("Wrong number of operands to __nvvm_reflect function")

# decode the string argument
str = operands(call)[1]
isa(str, LLVM.ConstantExpr) || error("Format of __nvvm__reflect function not recognized")
sym = operands(str)[1]
isa(sym, LLVM.GlobalVariable) || error("Format of __nvvm__reflect function not recognized")
sym_op = operands(sym)[1]
isa(sym_op, LLVM.ConstantArray) || error("Format of __nvvm__reflect function not recognized")
chars = convert.(Ref(UInt8), collect(sym_op))
reflect_arg = String(chars[1:end-1])

# handle possible cases
# XXX: put some of these property in the compiler job?
# and/or first set the "nvvm-reflect-*" module flag like Clang does?
fast_math = Base.JLOptions().fast_math == 1
# NOTE: we follow nvcc's --use_fast_math
reflect_val = if reflect_arg == "__CUDA_FTZ"
# single-precision denormals support
ConstantInt(reflect_typ, fast_math ? 1 : 0)
elseif reflect_arg == "__CUDA_PREC_DIV"
# single-precision floating-point division and reciprocals.
ConstantInt(reflect_typ, fast_math ? 0 : 1)
elseif reflect_arg == "__CUDA_PREC_SQRT"
# single-precision denormals support
ConstantInt(reflect_typ, fast_math ? 0 : 1)
elseif reflect_arg == "__CUDA_FMAD"
# contraction of floating-point multiplies and adds/subtracts into
# floating-point multiply-add operations (FMAD, FFMA, or DFMA)
ConstantInt(reflect_typ, fast_math ? 1 : 0)
elseif reflect_arg == "__CUDA_ARCH"
ConstantInt(reflect_typ, job.target.cap.major*100 + job.target.cap.minor*10)
else
@warn "Unknown __nvvm_reflect argument: $reflect_arg. Please file an issue."
end

replace_uses!(call, reflect_val)
push!(to_remove, call)
end

# remove the calls to the function
for val in to_remove
@assert isempty(uses(val))
unsafe_delete!(LLVM.parent(val), val)
end

# maybe also delete the function
if isempty(uses(reflect_function))
unsafe_delete!(mod, reflect_function)
end

end
return changed
end

0 comments on commit 5223e58

Please sign in to comment.