Skip to content

Commit

Permalink
Vector-Codegen support for llvm-pure-intrin
Browse files Browse the repository at this point in the history
  • Loading branch information
rutkoor committed May 10, 2024
1 parent fffd168 commit deefe06
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/tir/op/builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ TIR_DEFINE_BUILTIN_FUNC(call_llvm_intrin)
TIR_DEFINE_BUILTIN_FUNC(call_llvm_pure_intrin)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
.set_attr<TScriptDtypePrintLocation>("TScriptDtypePrintLocation",
Integer(ScriptDtypePrintLocation::kFirst));
Integer(ScriptDtypePrintLocation::kFirst))
.set_attr<TVectorizable>("TVectorizable", true);

TIR_DEFINE_BUILTIN_FUNC(call_spirv_pure_glsl450)
.set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure));
Expand Down
22 changes: 21 additions & 1 deletion src/tir/transforms/vectorize_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,27 @@ class Vectorizer : public StmtMutator, public ExprFunctor<PrimExpr(const PrimExp
}
} else {
int lane = 0;
Array<PrimExpr> new_args = MutateArray(op->args, &lane);
Array<PrimExpr> new_args;
if (op->op.same_as(builtin::call_llvm_pure_intrin())) {
// op->args[1], will give us total number of arguments to intrinsic
int num_signature = Downcast<IntImm>(op->args[1])->value;
Array<PrimExpr> op_expr_args;
for (int i = 0; i < num_signature; i++) {
// Collect all intrinsic arguments
op_expr_args.push_back(op->args[i + 2]);
}
// Generate RAMP nodes for intrinsic arguments
Array<PrimExpr> updated_args = MutateArray(op_expr_args, &lane);
// Collect Intrinsic ID and no. of argument
for (int i = 0; i < 2; i++) {
new_args.push_back(op->args[i]);
}
// Collect updated intrinsic arguments
for (int i = 0; i < num_signature; i++) {
new_args.push_back(updated_args[i]);
}
} else
new_args = MutateArray(op->args, &lane);
// normal code path.
if (op->args.same_as(new_args)) {
return GetRef<PrimExpr>(op);
Expand Down
58 changes: 58 additions & 0 deletions tests/python/tir-transform/test_tir_transform_vectorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,5 +488,63 @@ def main(A: T.Buffer((16,), "float32")):
tvm.tir.transform.VectorizeLoop()(Mod)


@pytest.mark.parametrize(
"extent, vec_str, target",
[(4, "float32x4", simple_target)],
)
def test_vectorize_llvm_pure_intrin(extent, vec_str, target):
@I.ir_module
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
for j in T.vectorized(extent):
A[j] = T.call_llvm_pure_intrin(
"float32", "llvm.sqrt", tvm.tir.const(1, "uint"), B[j]
)

@I.ir_module
class After:
@T.prim_func
def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin(
vec_str, "llvm.sqrt", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)]
)

with tvm.target.Target(target):
mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
mod = tvm.build(mod, target)


@pytest.mark.parametrize(
"extent, vec_str, target",
[(4, "int32x4", simple_target)],
)
def test_vectorize_llvm_pure_intrin_fail(extent, vec_str, target):
@I.ir_module
class Before:
@T.prim_func
def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
for j in T.vectorized(extent):
A[j] = T.call_llvm_pure_intrin(
"int32", "llvm.lround", tvm.tir.const(1, "uint"), B[j]
)

@I.ir_module
class After:
@T.prim_func
def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
A[T.Ramp(0, 1, extent)] = T.call_llvm_pure_intrin(
vec_str, "llvm.lround", tvm.tir.const(1, "uint"), B[T.Ramp(0, 1, extent)]
)

with pytest.raises(Exception) as e_info:
with tvm.target.Target(target):
mod = tvm.tir.transform.VectorizeLoop()(Before)
ex = tvm.build(mod, target)
tvm.ir.assert_structural_equal(mod, After)
assert "Intrinsic does not support vectors" in e_info.value.args[0]


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit deefe06

Please sign in to comment.