Skip to content

Commit

Permalink
introduce TileForIntrin
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 7, 2022
1 parent b87ef32 commit 4284a47
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 11 deletions.
8 changes: 8 additions & 0 deletions src/meta_schedule/schedule_rule/auto_tensorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,13 @@ Optional<LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::Blo
return reorder_suffix[0];
}

tir::BlockRV TileForIntrin(tir::Schedule sch, tir::BlockRV block, const std::string& intrin_name) {
Optional<tir::LoopRV> tiled_loop_rv = TilingwithTensorIntrin(sch, block, intrin_name);
ICHECK(tiled_loop_rv.defined());
tir::BlockRV outer_block = sch->Blockize(tiled_loop_rv.value());
sch->Annotate(outer_block, tir::attr::meta_schedule_auto_tensorize, String(intrin_name));
return outer_block;
}

} // namespace meta_schedule
} // namespace tvm
3 changes: 3 additions & 0 deletions src/meta_schedule/schedule_rule/auto_tensorize.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ namespace meta_schedule {

Optional<tir::LoopRV> TilingwithTensorIntrin(const tir::Schedule& sch, const tir::BlockRV& block_rv,
const String& intrin_name);

tir::BlockRV TileForIntrin(tir::Schedule sch, tir::BlockRV block, const std::string& intrin_name);

} // namespace meta_schedule
} // namespace tvm

Expand Down
15 changes: 4 additions & 11 deletions src/meta_schedule/schedule_rule/multi_level_tiling_vnni.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,13 @@
namespace tvm {
namespace meta_schedule {

std::vector<State> TileForVNNI(State state) {
const std::string intrin_name = "dot_16x4_vnni";
Optional<tir::LoopRV> tiled_loop_rv =
TilingwithTensorIntrin(state.sch, state.block_rv, intrin_name);
ICHECK(tiled_loop_rv.defined());
state.block_rv = state.sch->Blockize(tiled_loop_rv.value());
state.sch->Annotate(state.block_rv, tir::attr::meta_schedule_auto_tensorize, String(intrin_name));
return {state};
}

class MultiLevelTilingVNNINode : public MultiLevelTilingNode {
protected:
virtual std::vector<State> ApplySubRules(std::vector<State> states) {
states = SubRule(std::move(states), [&](State state) { return TileForVNNI(state); });
states = SubRule(std::move(states), [&](State state) {
state.block_rv = TileForIntrin(state.sch, state.block_rv, "dot_16x4_vnni");
return std::vector<State>(1, state);
});
return MultiLevelTilingNode::ApplySubRules(states);
}

Expand Down

0 comments on commit 4284a47

Please sign in to comment.