Skip to content

Commit

Permalink
NFC clean-up.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Sep 21, 2024
1 parent 8c6e38c commit b77553c
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 29 deletions.
22 changes: 9 additions & 13 deletions src/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,11 +89,11 @@ end
function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module)
errors = IRError[]

function is_valid_double_use(inst::LLVM.Instruction, errors)
# Metal does not support double precision, except for logging
function is_illegal_double(val)
T_bad = LLVM.DoubleType()

if value_type(inst) != T_bad || all(param->value_type(param) != T_bad, operands(inst))
return
if value_type(val) != T_bad
return false
end

function used_for_logging(use::LLVM.Use)
Expand All @@ -104,21 +104,17 @@ function validate_ir(job::CompilerJob{MetalCompilerTarget}, mod::LLVM.Module)
return true
end
end

return false
end

if all(used_for_logging, uses(inst))
return
if all(used_for_logging, uses(val))
return false
end

bt = backtrace(inst)
err = ("use of double value", bt, inst)
push!(errors, err)
return true
end
append!(errors, check_ir_values(mod, is_illegal_double, "use of double value"))

# Metal never supports double precision
append!(errors, check_ir_values(mod, is_valid_double_use))
# Metal never supports 128-bit integers
append!(errors, check_ir_values(mod, LLVM.IntType(128)))

errors
Expand Down
26 changes: 10 additions & 16 deletions src/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -317,24 +317,18 @@ function check_ir!(job, errors::Vector{IRError}, inst::LLVM.CallInst)
return errors
end

# helper function to check if a LLVM module uses values of a certain type

function check_illegal_value_type(inst::LLVM.Instruction, errors, T_bad::LLVMType)
if value_type(inst) == T_bad || any(param->value_type(param) == T_bad, operands(inst))
bt = backtrace(inst)
err = ("use of $(string(T_bad)) value", bt, inst)
push!(errors, err)
end
end

check_ir_values(mod::LLVM.Module, T_bad::LLVMType) = check_ir_values(mod, (x,errs)->check_illegal_value_type(x, errs, T_bad))

function check_ir_values(mod::LLVM.Module, T_bad)
# helper function to check for illegal values in an LLVM module
function check_ir_values(mod::LLVM.Module, predicate, msg="value")
errors = IRError[]

for fun in functions(mod), bb in blocks(fun), inst in instructions(bb)
T_bad(inst, errors)
if predicate(inst) || any(predicate, operands(inst))
bt = backtrace(inst)
push!(errors, (msg, bt, inst))
end
end

return errors
end
## shorthand to check for illegal value types
function check_ir_values(mod::LLVM.Module, T_bad::LLVMType)
check_ir_values(mod, val -> value_type(val) == T_bad, "use of $(string(T_bad)) value")
end

0 comments on commit b77553c

Please sign in to comment.