From cd3984af45864e19528ca04535d2ace8bc00d4cc Mon Sep 17 00:00:00 2001 From: Xiyou Zhou Date: Sat, 1 May 2021 04:27:06 -0700 Subject: [PATCH] [Target][Legalization]Add Tir Level Legalization Function Registration And Update Intrinsic Lowering Pass (#7936) --- src/target/intrin_rule.cc | 24 +++-- src/target/llvm/intrin_rule_llvm.cc | 109 ++++++++++++----------- src/target/spirv/intrin_rule_spirv.cc | 53 +++++------ src/tir/transforms/lower_intrin.cc | 43 ++++----- tests/python/unittest/test_tir_intrin.py | 69 +++++++++++++- 5 files changed, 188 insertions(+), 110 deletions(-) diff --git a/src/target/intrin_rule.cc b/src/target/intrin_rule.cc index bfc3fe6fcc8c8..e697d9b602730 100644 --- a/src/target/intrin_rule.cc +++ b/src/target/intrin_rule.cc @@ -112,19 +112,25 @@ TVM_REGISTER_OP("tir.ceil") TVM_REGISTER_OP("tir.round") .set_attr("default.FLowerIntrinsic", DispatchPureExtern); +TVM_REGISTER_OP("tir.pow").set_attr("default.FLowerIntrinsic", + DispatchPureExtern); + +} // namespace intrin + +namespace legalize { + +using namespace tir; + TVM_REGISTER_OP("tir.rsqrt") - .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { + .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); ICHECK(call != nullptr); auto one = make_const(call->args[0].dtype(), 1); return one / sqrt(call->args[0]); }); -TVM_REGISTER_OP("tir.pow").set_attr("default.FLowerIntrinsic", - DispatchPureExtern); - TVM_REGISTER_OP("tir.sigmoid") - .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { + .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); ICHECK(call != nullptr); auto one = make_const(call->args[0].dtype(), 1); @@ -132,21 +138,21 @@ TVM_REGISTER_OP("tir.sigmoid") }); TVM_REGISTER_OP("tir.isfinite") - .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { + .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); ICHECK(call != nullptr); return isfinite(call->args[0]); }); TVM_REGISTER_OP("tir.isinf") - .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { + .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { const CallNode* call = e.as(); ICHECK(call != nullptr); return isinf(call->args[0]); }); TVM_REGISTER_OP("tir.q_multiply_shift") - .set_attr("default.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { + .set_attr("default.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; const tir::CallNode* call = e.as(); @@ -222,6 +228,6 @@ TVM_REGISTER_OP("tir.q_multiply_shift") } }); -} // namespace intrin +} // namespace legalize } // namespace codegen } // namespace tvm diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc index 2d30c20306856..adbd1056d962b 100644 --- a/src/target/llvm/intrin_rule_llvm.cc +++ b/src/target/llvm/intrin_rule_llvm.cc @@ -30,6 +30,7 @@ namespace tvm { namespace codegen { namespace llvm { +namespace intrin { using tir::FLowerIntrinsic; TVM_REGISTER_OP("tir.prefetch") @@ -43,20 +44,6 @@ TVM_REGISTER_OP("tir.exp2") .set_attr("llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::exp2, 1>); -// TODO(tvm-team): migrate the legalization transformations as a separate -// set of rules in TIR that can be shared across backends. -TVM_REGISTER_OP("tir.exp10") - .set_attr("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { - using tir::make_const; - using tir::make_zero; - const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr ln10 = make_const(x.dtype(), 2.302585093); - PrimExpr ret = exp(x * ln10); - return ret; - }); - TVM_REGISTER_OP("tir.fma").set_attr( "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::fmuladd, 3>); @@ -99,8 +86,37 @@ TVM_REGISTER_OP("tir.nearbyint") .set_attr("llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>); +TVM_REGISTER_OP("tir.pow").set_attr( + "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); + +TVM_REGISTER_OP("tir.popcount") + .set_attr("llvm.FLowerIntrinsic", + DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); + +TVM_REGISTER_OP("tir.cos").set_attr( + "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); + +TVM_REGISTER_OP("tir.sin").set_attr( + "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); +} // namespace intrin + +namespace legalize { +using tir::FLegalize; + +TVM_REGISTER_OP("tir.exp10") + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + using tir::make_const; + using tir::make_zero; + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr ln10 = make_const(x.dtype(), 2.302585093); + PrimExpr ret = exp(x * ln10); + return ret; + }); + TVM_REGISTER_OP("tir.tanh") - .set_attr("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; using tir::make_zero; const tir::CallNode* call = e.as(); @@ -118,28 +134,16 @@ TVM_REGISTER_OP("tir.tanh") return tir::Select(x >= make_zero(x.dtype()), tanh_pos, tanh_neg); }); -TVM_REGISTER_OP("tir.pow").set_attr( - "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::pow, 2>); - -TVM_REGISTER_OP("tir.popcount") - .set_attr("llvm.FLowerIntrinsic", - DispatchLLVMPureIntrin<::llvm::Intrinsic::ctpop, 1>); - -TVM_REGISTER_OP("tir.tan").set_attr("llvm.FLowerIntrinsic", - [](const PrimExpr& e) -> PrimExpr { - const tir::CallNode* call = - e.as(); - ICHECK(call != nullptr); - const PrimExpr& x = call->args[0]; - PrimExpr tan_x = sin(x) / cos(x); - return tan_x; - }); - -TVM_REGISTER_OP("tir.cos").set_attr( - "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::cos, 1>); +TVM_REGISTER_OP("tir.tan").set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr); + const PrimExpr& x = call->args[0]; + PrimExpr tan_x = sin(x) / cos(x); + return tan_x; +}); TVM_REGISTER_OP("tir.cosh") - .set_attr("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; using tir::make_zero; const tir::CallNode* call = e.as(); @@ -153,11 +157,8 @@ TVM_REGISTER_OP("tir.cosh") return ret; }); -TVM_REGISTER_OP("tir.sin").set_attr( - "llvm.FLowerIntrinsic", DispatchLLVMPureIntrin<::llvm::Intrinsic::sin, 1>); - TVM_REGISTER_OP("tir.sinh") - .set_attr("llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { + .set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { using tir::make_const; using tir::make_zero; const tir::CallNode* call = e.as(); @@ -171,21 +172,21 @@ TVM_REGISTER_OP("tir.sinh") return ret; }); -TVM_REGISTER_OP("tir.clz").set_attr( - "llvm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { - const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); - ICHECK_EQ(call->args.size(), 1); - Array cargs; - cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz)); - cargs.push_back(IntImm(DataType::UInt(32), 2)); - cargs.push_back(call->args[0]); - cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef - // LLVM requires that the return type must match the first argument type - auto clz = tir::Call(call->args[0]->dtype, tir::builtin::call_llvm_intrin(), cargs); - return cast(call->dtype, clz); - }); - +TVM_REGISTER_OP("tir.clz").set_attr("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr { + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr); + ICHECK_EQ(call->args.size(), 1); + Array cargs; + cargs.push_back(IntImm(DataType::UInt(32), ::llvm::Intrinsic::ctlz)); + cargs.push_back(IntImm(DataType::UInt(32), 2)); + cargs.push_back(call->args[0]); + cargs.push_back(IntImm(DataType::Int(1), 1)); // is_zero_undef + // LLVM requires that the return type must match the first argument type + auto clz = tir::Call(call->args[0]->dtype, tir::builtin::call_llvm_intrin(), cargs); + return cast(call->dtype, clz); +}); + +} // namespace legalize } // namespace llvm } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc index fa38f8fb01078..eca7c4ce17006 100644 --- a/src/target/spirv/intrin_rule_spirv.cc +++ b/src/target/spirv/intrin_rule_spirv.cc @@ -30,8 +30,6 @@ namespace tvm { namespace codegen { namespace spirv { -using tir::FLowerIntrinsic; - // num_signature means number of arguments used to query signature template PrimExpr CallGLSLIntrin(PrimExpr e, const Array& args) { @@ -59,6 +57,8 @@ inline PrimExpr DispatchGLSLPureIntrin(const PrimExpr& e) { return CallGLSLIntrin(e); } +namespace intrin { +using tir::FLowerIntrinsic; TVM_REGISTER_OP("tir.floor") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); @@ -98,29 +98,6 @@ TVM_REGISTER_OP("tir.pow").set_attr("vulkan.FLowerIntrinsic", TVM_REGISTER_OP("tir.tanh") .set_attr("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin); -TVM_REGISTER_OP("tir.clz").set_attr( - "vulkan.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr { - const tir::CallNode* call = e.as(); - ICHECK(call != nullptr); - ICHECK_EQ(call->args.size(), 1); - PrimExpr arg = call->args[0]; - PrimExpr msb; - if (arg.dtype().bits() == 64) { - // SPIR-V FindUMsb intrinsic only supports 32 bit input - auto int32 = DataType::Int(32); - PrimExpr arg_hi32 = tvm::tir::Cast(int32, arg >> 32); - PrimExpr arg_lo32 = tvm::tir::Cast(int32, arg); - PrimExpr msb_hi = CallGLSLIntrin(e, {arg_hi32}); - PrimExpr msb_lo = CallGLSLIntrin(e, {arg_lo32}); - msb = tvm::if_then_else(arg_hi32 == 0, msb_lo, msb_hi + 32); - } else if (arg.dtype().bits() == 32) { - msb = CallGLSLIntrin(e); - } else { - LOG(FATAL) << "SPIR-V clz only supports a 32 bit or 64 bit integer."; - } - return PrimExpr(arg.dtype().bits() - 1) - msb; - }); - // WebGPU rules. TVM_REGISTER_OP("tir.floor") .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); @@ -151,7 +128,33 @@ TVM_REGISTER_OP("tir.pow").set_attr("webgpu.FLowerIntrinsic", TVM_REGISTER_OP("tir.tanh") .set_attr("webgpu.FLowerIntrinsic", DispatchGLSLPureIntrin); +} // namespace intrin +namespace legalize { +using tir::FLegalize; +TVM_REGISTER_OP("tir.clz").set_attr( + "vulkan.FLegalize", [](const PrimExpr& e) -> PrimExpr { + const tir::CallNode* call = e.as(); + ICHECK(call != nullptr); + ICHECK_EQ(call->args.size(), 1); + PrimExpr arg = call->args[0]; + PrimExpr msb; + if (arg.dtype().bits() == 64) { + // SPIR-V FindUMsb intrinsic only supports 32 bit input + auto int32 = DataType::Int(32); + PrimExpr arg_hi32 = tvm::tir::Cast(int32, arg >> 32); + PrimExpr arg_lo32 = tvm::tir::Cast(int32, arg); + PrimExpr msb_hi = CallGLSLIntrin(e, {arg_hi32}); + PrimExpr msb_lo = CallGLSLIntrin(e, {arg_lo32}); + msb = tvm::if_then_else(arg_hi32 == 0, msb_lo, msb_hi + 32); + } else if (arg.dtype().bits() == 32) { + msb = CallGLSLIntrin(e); + } else { + LOG(FATAL) << "SPIR-V clz only supports a 32 bit or 64 bit integer."; + } + return PrimExpr(arg.dtype().bits() - 1) - msb; + }); +} // namespace legalize } // namespace spirv } // namespace codegen } // namespace tvm diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 4101891db6998..2555002d29b05 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -39,33 +39,34 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { public: using IRMutatorWithAnalyzer::VisitExpr_; using IRMutatorWithAnalyzer::VisitStmt_; + using FLowerGeneral = runtime::TypedPackedFunc; IntrinInjecter(arith::Analyzer* analyzer, std::string target, std::string mtriple = "") : IRMutatorWithAnalyzer(analyzer) { - std::vector patterns_; - patterns_.push_back(target + ".FLowerIntrinsic"); - + std::vector patterns; + patterns.push_back(target + ".FLowerIntrinsic"); + patterns.push_back(target + ".FLegalize"); bool is_llvm_aarch64 = (mtriple.find("aarch64") != std::string::npos); if (is_llvm_aarch64) { - patterns_.push_back(target + ".aarch64.FLowerIntrinsic"); - } - - patterns_.push_back("default.FLowerIntrinsic"); - - fma_ = runtime::Registry::Get("tvm.intrin.rule." + target + ".fma"); - if (target == "stackvm") { - support_bitwise_op_ = false; + patterns.push_back(target + ".aarch64.FLowerIntrinsic"); + patterns.push_back(target + ".aarch64.FLegalize"); } - - for (const std::string& pattern : patterns_) - if (Op::HasAttrMap(pattern)) - lower_intrin_maps_.push_back(Op::GetAttrMap(pattern)); + patterns.push_back("default.FLowerIntrinsic"); + patterns.push_back("default.FLegalize"); + + for (const std::string& pattern : patterns) + if (Op::HasAttrMap(pattern)) { + attr_maps_.push_back(Op::GetAttrMap(pattern)); + if (fma_ == nullptr) { + fma_ = (*attr_maps_.rbegin()).get(Op::Get("tir.fma"), nullptr); + } + } } PrimExpr VisitExpr_(const CallNode* op) final { if (auto* ptr_op = op->op.as()) { - for (const auto& f_lower_intrin_map : lower_intrin_maps_) { - FLowerIntrinsic f = f_lower_intrin_map.get(GetRef(ptr_op), nullptr); + for (const auto& f_attr_map : attr_maps_) { + FLowerGeneral f = f_attr_map.get(GetRef(ptr_op), nullptr); if (f != nullptr) { PrimExpr e = GetRef(op); PrimExpr r = f(e); @@ -269,7 +270,7 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { PrimExpr rhs = SwapBroadcastCast(b); if (fma_ != nullptr && op->dtype.is_float()) { - PrimExpr r = (*fma_)(Call(op->dtype, builtin::fma(), {lhs, rhs, c})); + PrimExpr r = fma_(Call(op->dtype, builtin::fma(), {lhs, rhs, c})); if (r.defined()) return this->VisitExpr(r); } else { if (!lhs.same_as(a) || !rhs.same_as(b)) { @@ -280,9 +281,9 @@ class IntrinInjecter : public tvm::arith::IRMutatorWithAnalyzer { return IRMutatorWithAnalyzer::VisitExpr_(op); } - // patterns - std::vector> lower_intrin_maps_; - const PackedFunc* fma_{nullptr}; + // attribute maps, shared only when FLegalize == FLowerIntrinsic + std::vector> attr_maps_; + FLowerGeneral fma_{nullptr}; bool support_bitwise_op_{true}; }; diff --git a/tests/python/unittest/test_tir_intrin.py b/tests/python/unittest/test_tir_intrin.py index 8512d1c311eb7..79b2819212b78 100644 --- a/tests/python/unittest/test_tir_intrin.py +++ b/tests/python/unittest/test_tir_intrin.py @@ -16,9 +16,10 @@ # under the License. import tvm import tvm.testing -from tvm import te +from tvm import te, tir from tvm import topi from tvm.contrib import utils, clang +from tvm.script import ty import numpy as np import ctypes import math @@ -184,6 +185,71 @@ def clz_np(x, dtype): np.testing.assert_equal(b.asnumpy(), ref) +@tvm.script.tir +class Module: + def test_tir_fma(A: ty.handle, B: ty.handle, C: ty.handle, d: ty.handle) -> None: + # function attr dict + tir.func_attr({"global_symbol": "test_fma", "tir.noalias": True}) + n = tir.var("int32") + stride = tir.var("int32") + stride_1 = tir.var("int32") + stride_2 = tir.var("int32") + stride_3 = tir.var("int32") + A_1 = tir.match_buffer( + A, + [n], + strides=[stride], + elem_offset=0, + align=128, + offset_factor=1, + type="auto", + ) + B_1 = tir.match_buffer( + B, + [n], + strides=[stride_1], + elem_offset=0, + align=128, + offset_factor=1, + type="auto", + ) + C_1 = tir.match_buffer( + C, + [n], + strides=[stride_2], + elem_offset=0, + align=128, + offset_factor=1, + type="auto", + ) + d_1 = tir.match_buffer( + d, + [n], + strides=[stride_3], + elem_offset=0, + align=128, + offset_factor=1, + type="auto", + ) + # body + for i in tir.serial(0, n): + d_1.data[(i * stride_3)] = ( + tir.load("float32", A_1.data, (i * stride)) + * tir.load("float32", B_1.data, (i * stride_1)) + ) + tir.load("float32", C_1.data, (i * stride_2)) + + +def test_fma(): + opt = tvm.transform.Sequential( + [ + tvm.tir.transform.Apply(lambda f: f.with_attr("target", tvm.target.Target("llvm"))), + tvm.tir.transform.LowerIntrin(), + ] + ) + mod = opt(Module()) + assert mod["test_tir_fma"].body.body.value.op.name == "tir.call_llvm_pure_intrin" + + if __name__ == "__main__": test_nearbyint() test_unary_intrin() @@ -191,3 +257,4 @@ def clz_np(x, dtype): test_binary_intrin() test_ldexp() test_clz() + test_fma()