From fcc31ee20ddfafd47f566bf98ff40a9f684d12eb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 7 Apr 2022 04:55:36 +0900 Subject: [PATCH] tensorize worked --- src/meta_schedule/postproc/rewrite_vnni.cc | 9 +++++++-- src/meta_schedule/schedule_rule/multi_level_tiling.cc | 3 +-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/meta_schedule/postproc/rewrite_vnni.cc b/src/meta_schedule/postproc/rewrite_vnni.cc index e462dde0949d..2d49a9c277ed 100644 --- a/src/meta_schedule/postproc/rewrite_vnni.cc +++ b/src/meta_schedule/postproc/rewrite_vnni.cc @@ -49,8 +49,10 @@ void CollectTensorized(const tir::Schedule& sch, const String& func_name, tir::StmtSRef block_sref = sch->GetSRef(block); if (Optional intrin_name = tir::GetAnn(block_sref, tir::attr::meta_schedule_auto_tensorize)) { - tasks.push_back(std::make_tuple(block_sref->StmtAs()->name_hint, - func_name, intrin_name.value())); + std::string block_name = block_sref->StmtAs()->name_hint; + if (block_name.find("init") == std::string::npos) { + tasks.push_back(std::make_tuple(block_name, func_name, intrin_name.value())); + } } } return true; @@ -59,6 +61,7 @@ void CollectTensorized(const tir::Schedule& sch, const String& func_name, } bool RewriteVNNINode::Apply(const tir::Schedule& sch) { + LOG(INFO) << "Apply RewriteVNNI " << sch->mod(); std::vector tasks; for (const auto& kv : sch->mod()->functions) { GlobalVar g_var = kv.first; @@ -73,11 +76,13 @@ bool RewriteVNNINode::Apply(const tir::Schedule& sch) { String intrin_name = std::get<2>(task); sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize); sch->Tensorize(block_rv, intrin_name); + LOG(INFO) << "After tensorize: " << sch->mod(); } return true; } Postproc RewriteVNNI() { + LOG(INFO) << "RewriteVNNI is called"; ObjectPtr n = make_object(); return Postproc(n); } diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 0732e097b62c..c3ef8523c46c 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -453,9 +453,8 @@ inline std::vector MultiLevelTilingNode::TileForVNNI(State state) const { const std::string intrin_name = "dot_16x1x16_uint8_int8_int32_cascadelake"; Optional tiled_loop_rv = TilingwithTensorIntrin(state.sch, block_rv, intrin_name); ICHECK(tiled_loop_rv.defined()); - LOG(INFO) << "After TilingwithTensorIntrin" << state.sch->mod(); state.block_rv = state.sch->Blockize(tiled_loop_rv.value()); - state.sch->Annotate(block_rv, tir::attr::meta_schedule_auto_tensorize, String(intrin_name)); + state.sch->Annotate(state.block_rv, tir::attr::meta_schedule_auto_tensorize, String(intrin_name)); result.push_back(state); return result; }