Skip to content

Commit

Permalink
tensorize worked
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 7, 2022
1 parent 2b53437 commit fcc31ee
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
9 changes: 7 additions & 2 deletions src/meta_schedule/postproc/rewrite_vnni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,10 @@ void CollectTensorized(const tir::Schedule& sch, const String& func_name,
tir::StmtSRef block_sref = sch->GetSRef(block);
if (Optional<String> intrin_name =
tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
tasks.push_back(std::make_tuple(block_sref->StmtAs<tir::BlockNode>()->name_hint,
func_name, intrin_name.value()));
std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
if (block_name.find("init") == std::string::npos) {
tasks.push_back(std::make_tuple(block_name, func_name, intrin_name.value()));
}
}
}
return true;
Expand All @@ -59,6 +61,7 @@ void CollectTensorized(const tir::Schedule& sch, const String& func_name,
}

bool RewriteVNNINode::Apply(const tir::Schedule& sch) {
LOG(INFO) << "Apply RewriteVNNI " << sch->mod();
std::vector<BlockPosition> tasks;
for (const auto& kv : sch->mod()->functions) {
GlobalVar g_var = kv.first;
Expand All @@ -73,11 +76,13 @@ bool RewriteVNNINode::Apply(const tir::Schedule& sch) {
String intrin_name = std::get<2>(task);
sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize);
sch->Tensorize(block_rv, intrin_name);
LOG(INFO) << "After tensorize: " << sch->mod();
}
return true;
}

Postproc RewriteVNNI() {
LOG(INFO) << "RewriteVNNI is called";
ObjectPtr<RewriteVNNINode> n = make_object<RewriteVNNINode>();
return Postproc(n);
}
Expand Down
3 changes: 1 addition & 2 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -453,9 +453,8 @@ inline std::vector<State> MultiLevelTilingNode::TileForVNNI(State state) const {
const std::string intrin_name = "dot_16x1x16_uint8_int8_int32_cascadelake";
Optional<LoopRV> tiled_loop_rv = TilingwithTensorIntrin(state.sch, block_rv, intrin_name);
ICHECK(tiled_loop_rv.defined());
LOG(INFO) << "After TilingwithTensorIntrin" << state.sch->mod();
state.block_rv = state.sch->Blockize(tiled_loop_rv.value());
state.sch->Annotate(block_rv, tir::attr::meta_schedule_auto_tensorize, String(intrin_name));
state.sch->Annotate(state.block_rv, tir::attr::meta_schedule_auto_tensorize, String(intrin_name));
result.push_back(state);
return result;
}
Expand Down

0 comments on commit fcc31ee

Please sign in to comment.