Commit

address comments

vinx13 committed Jul 12, 2022
1 parent 8dd2de5 commit 8bd7a05
Showing 7 changed files with 50 additions and 51 deletions.
2 changes: 0 additions & 2 deletions include/tvm/meta_schedule/schedule_rule.h
@@ -182,10 +182,8 @@ class ScheduleRule : public runtime::ObjectRef {
    * the map should be names of tensor intrinsics, must be registered via TensorIntrin.register(...)
    * beforehand
    * \param structure The tiling structure. Recommended:
-   * - 'SSRSRS' on CPU
    * - 'SSSRRSRS' on GPU
    * \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
-   * - NullOpt on CPU
    * - [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
    * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
    * \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
2 changes: 0 additions & 2 deletions python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -148,11 +148,9 @@ class MultiLevelTilingTensorCore(ScheduleRule):
         TensorIntrin.register(...) beforehand
     structure : str
         The tiling structure. Recommended:
-        - 'SSRSRS' on CPU
         - 'SSSRRSRS' on GPU
     tile_binds : Optional[List[str]]
         For each level of tiles, which thread axis it is bound to. Recommended:
-        - None on CPU
         - [blockIdx.y, vthread.x, threadIdx.y] on GPU
     max_innermost_factor : Optional[int]
         The maximum size of the innermost factor. None means no limit
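For context, a minimal sketch of constructing this rule with the GPU recommendations from the docstrings above. The parameter names follow the rule's full docstring; the intrin-group keys and the intrinsic names are placeholders assumed here, and any real intrinsics must be registered via TensorIntrin.register(...) beforehand:

```python
# Sketch only, not from this commit. The intrinsic names are hypothetical
# placeholders; register real ones via tvm.tir.TensorIntrin.register(...).
from tvm.meta_schedule.schedule_rule import MultiLevelTilingTensorCore

rule = MultiLevelTilingTensorCore(
    intrin_groups=[
        {
            "init": "my_wmma_fill",      # placeholder intrinsic names
            "load_a": "my_wmma_load_a",
            "load_b": "my_wmma_load_b",
            "compute": "my_wmma_sync",
            "store": "my_wmma_store",
        }
    ],
    structure="SSSRRSRS",  # recommended on GPU
    tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],  # recommended on GPU
    max_innermost_factor=64,  # arbitrary example; None means no limit
)
```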
14 changes: 8 additions & 6 deletions src/meta_schedule/postproc/rewrite_reduction_block.cc
@@ -140,15 +140,17 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) {
     if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize).defined()) {
       // Remove tensorization annotation as it shouldn't be propagated to the init block.
       sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize);
     }
-    if (Optional<String> tensorize_init =
-            tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize_init)) {
+    Optional<String> tensorize_init =
+        tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize_init);
     // The annotation of tensorization of the init statement should be moved to the init block
     // after 'DecomposeReduction'.
+    // Annotate to hint `RewriteTensorize` postprocessor even if tensorize_init is NullOpt.
     sch->Annotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize,
-                  tensorize_init.value());
-    sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize_init);
-    sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize_init);
-    }
+                  tensorize_init.value_or(""));
+    if (tensorize_init.defined()) {
+      sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize_init);
+      sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize_init);
+    }
     ++rewritten;
   }
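In Python terms, the new annotation hand-off after DecomposeReduction looks roughly like this. A sketch only, not the C++ postprocessor itself; it assumes `sch` is a tvm.tir.Schedule and the block name "compute" is illustrative:

```python
# Illustrative sketch of the new flow: annotate the init block even when no
# init intrinsic is given, so the RewriteTensorize postprocessor (next file)
# still sees a hint; only clean up the init-annotation when it existed.
block = sch.get_block("compute")        # hypothetical reduction block name
k = sch.get_loops(block)[-1]            # loop to decompose the init at
init = sch.decompose_reduction(block, k)

ann = sch.get(block).annotations
key = "meta_schedule.auto_tensorize_init"
tensorize_init = ann[key] if key in ann else None

# An empty string is a deliberate hint, standing in for NullOpt.
sch.annotate(init, "meta_schedule.auto_tensorize", tensorize_init or "")
if tensorize_init is not None:
    sch.unannotate(block, key)
    sch.unannotate(init, key)
```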
32 changes: 17 additions & 15 deletions src/meta_schedule/postproc/rewrite_tensorize.cc
@@ -38,21 +38,23 @@ void CollectTensorizationJobs(
     std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
     if (Optional<String> intrin_name =
             tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
-      jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) {
-        try {
-          sch->Tensorize(block, intrin_name.value());
-        } catch (const std::exception& e) {
-          LOG(WARNING) << "Tensorize failed with error " << e.what();
-        }
-      });
-    } else if (block_name.find("init") && vectorize_init_loop) {
-      jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) {
-        Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
-        ICHECK(child_blocks.size() == 1);
-        Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
-        ICHECK(init_loops.size() == 1);
-        sch->Vectorize(init_loops[0]);
-      });
+      if (intrin_name.value() != "") {
+        jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) {
+          try {
+            sch->Tensorize(block, intrin_name.value());
+          } catch (const std::exception& e) {
+            LOG(WARNING) << "Tensorize failed with error " << e.what();
+          }
+        });
+      } else if (block_name.find("init") && vectorize_init_loop) {
+        jobs->emplace_back(block_name, func_name, [sch, block_name](tir::BlockRV block) {
+          Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
+          ICHECK(child_blocks.size() == 1) << block_name << child_blocks;
+          Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
+          ICHECK(init_loops.size() == 1);
+          sch->Vectorize(init_loops[0]);
+        });
+      }
     }
   }
 });
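The dispatch now performed per annotated block, sketched against the Python tir.Schedule API. This assumes `sch`, a BlockRV `block`, its `block_name`, the annotation value `intrin`, and the `vectorize_init_loop` flag are in scope:

```python
# Sketch of the per-block dispatch above, not the postprocessor itself.
if intrin != "":
    try:
        sch.tensorize(block, intrin)  # intrinsic must be registered beforehand
    except Exception as err:          # tensorization is best-effort
        print(f"Tensorize failed with error {err}")
elif "init" in block_name and vectorize_init_loop:
    (child,) = sch.get_child_blocks(block)  # mirrors ICHECK: exactly one child
    (init_loop,) = sch.get_loops(child)     # mirrors ICHECK: exactly one loop
    sch.vectorize(init_loop)
```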
25 changes: 0 additions & 25 deletions src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -117,31 +117,6 @@ class State : public ObjectRef {
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
 };
 
-class TensorCoreStateNode : public StateNode {
- public:
-  /*! \brief The Tensor Core reindex block A for Tensor Core computation */
-  tir::BlockRV tensor_core_reindex_A;
-  /*! \brief The Tensor Core reindex block B for Tensor Core computation */
-  tir::BlockRV tensor_core_reindex_B;
-  /*! \brief The Tensor Core reindex store block for Tensor Core computation */
-  tir::BlockRV tensor_core_reindex_store;
-
-  State Copy() const final;
-
-  static constexpr const char* _type_key = "meta_schedule.TensorCoreState";
-  TVM_DECLARE_FINAL_OBJECT_INFO(TensorCoreStateNode, StateNode);
-};
-
-class TensorCoreState : public State {
- public:
-  explicit TensorCoreState(tir::Schedule sch, tir::BlockRV block_rv,
-                           Array<Array<tir::LoopRV>> tiles = {});
-
-  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode);
-};
-
-struct AutoTensorizationState : public State {};
-
 /*!
  * \brief Helper to apply a sub-rule to a list of auto scheduling states
  * \tparam FLambda The type of the sub-rule functor
24 changes: 24 additions & 0 deletions src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
@@ -40,6 +40,30 @@ struct TensorCoreIntrinGroup {
   String store_intrin;
 };
 
+class TensorCoreStateNode : public StateNode {
+ public:
+  /*! \brief The Tensor Core reindex block A for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_A;
+  /*! \brief The Tensor Core reindex block B for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_B;
+  /*! \brief The Tensor Core reindex store block for Tensor Core computation */
+  tir::BlockRV tensor_core_reindex_store;
+
+  State Copy() const final;
+
+  static constexpr const char* _type_key = "meta_schedule.TensorCoreState";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TensorCoreStateNode, StateNode);
+};
+
+class TensorCoreState : public State {
+ public:
+  explicit TensorCoreState(tir::Schedule sch, tir::BlockRV block_rv,
+                           Array<Array<tir::LoopRV>> tiles = {});
+
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode);
+};
+
+
 TVM_REGISTER_OBJECT_TYPE(TensorCoreStateNode);
 
 TensorCoreState::TensorCoreState(Schedule sch, BlockRV block_rv, Array<Array<LoopRV>> tiles) {
@@ -361,7 +361,7 @@ def main(
                 )
                 T.reads()
                 T.writes(compute_local[i, j])
-                T.block_attr({"meta_schedule.auto_tensorize": "dp4a"})
+                T.block_attr({"meta_schedule.auto_tensorize": ""})
                 with T.block("compute_init"):
                     T.reads()
                     T.writes(compute_local[i, j])
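The emptied attribute can be checked from the schedule side. A sketch only: `Module` stands for the TVMScript module in this test, and the block name "compute_o" is a hypothetical stand-in for the outer block, whose name is elided above:

```python
# Sketch: verify the annotation that RewriteTensorize will see. `Module` and
# the block name "compute_o" are assumptions, not taken from this commit.
import tvm

sch = tvm.tir.Schedule(Module)
block = sch.get_block("compute_o")
assert sch.get(block).annotations["meta_schedule.auto_tensorize"] == ""
```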
