Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Allow MultiLevelTilingTensorCore rule to specify multiple tensor intrin groups #12113

Merged
merged 1 commit into from
Jul 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,13 +174,13 @@ class ScheduleRule : public runtime::ObjectRef {
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);

/*!
* \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
* intrinsics
* \param intrin_group A group of tensor core intrinsics. The map should contains key "init",
* "load_a", "load_b", "compute", "store", which represent the tensor intrin for initialization,
* loading operand A, loading operand B, tensor core computation, storing the result. The value of
* the map should be names of tensor intrinsics, must be registerd via TensorIntrin.register(...)
* beforehand
* \brief Extension of MultiLevelTiling for auto-tensorizing with multiple groups of candidate
* tensor core intrinsics
* \param intrin_groups A list of groups of tensor core intrinsics. Each map should contain the
* keys "init", "load_a", "load_b", "compute", "store", which represent the tensor intrin for
* initialization, loading operand A, loading operand B, tensor core computation, storing the
* result. The values of each map should be names of tensor intrinsics, which must be registered
* via TensorIntrin.register(...) beforehand
* \param structure The tiling structure. Recommended:
* - 'SSSRRSRS' on GPU
* \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
Expand All @@ -193,9 +193,10 @@ class ScheduleRule : public runtime::ObjectRef {
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule MultiLevelTilingTensorCore(
Map<String, String> intrin_group, String structure, Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
Array<Map<String, String>> intrin_groups, String structure,
Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor,
Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
Optional<Map<String, ObjectRef>> reuse_write);

/*!
* \brief Create a rule: add-rfactor to some blocks if needed
Expand Down
16 changes: 8 additions & 8 deletions python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,15 @@ def __init__(

@register_object("meta_schedule.MultiLevelTilingTensorCore")
class MultiLevelTilingTensorCore(ScheduleRule):
"""Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
intrinsics.
"""Extension of MultiLevelTiling for auto-tensorizing with multiple groups of candidate tensor
core intrinsics.
Parameters
----------
intrin_group : Mapping[str, str]
A group of tensor core intrinsics. The map should contains key "init", "load_a", "load_b",
"compute", "store", which represent the tensor intrin for initialization, loading operand A,
loading operand B, tensor core computation, storing the result.
intrin_groups : List[Mapping[str, str]]
A list of groups of tensor core intrinsics. Each map should contain the keys "init",
"load_a", "load_b", "compute", "store", which represent the tensor intrin for initialization,
loading operand A, loading operand B, tensor core computation, storing the result.
The values of each map should be names of tensor intrinsics, which must be registered via
TensorIntrin.register(...) beforehand
structure : str
Expand All @@ -165,7 +165,7 @@ class MultiLevelTilingTensorCore(ScheduleRule):

def __init__(
self,
intrin_group: Mapping[str, str],
intrin_groups: List[Mapping[str, str]],
structure: str,
tile_binds: Optional[List[str]] = None,
max_innermost_factor: Optional[int] = None,
Expand All @@ -175,7 +175,7 @@ def __init__(
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleMultiLevelTilingTensorCore, # type: ignore # pylint: disable=no-member
intrin_group,
intrin_groups,
structure,
tile_binds,
max_innermost_factor,
Expand Down
26 changes: 19 additions & 7 deletions python/tvm/meta_schedule/testing/schedule_rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Default schedule rules"""
from typing import List, Union
from tvm.meta_schedule.schedule_rule import (
AddRFactor,
AutoBind,
Expand Down Expand Up @@ -114,18 +115,29 @@ def multi_level_tiling(target: Target) -> ScheduleRule:

def multi_level_tiling_tensor_core(
target: Target,
write_reuse_scope="shared",
in_dtype="float16",
out_dtype="float32",
trans_b=False,
write_reuse_scope: str = "shared",
in_dtype: Union[str, List[str]] = "float16",
out_dtype: Union[str, List[str]] = "float32",
trans_b: Union[bool, List[bool]] = False,
) -> ScheduleRule:
"""Default schedule rules with multi-level tiling and reuse for tensor core"""
assert write_reuse_scope in ["shared", "global"]
if not isinstance(in_dtype, list):
in_dtype = [in_dtype]
if not isinstance(out_dtype, list):
out_dtype = [out_dtype]
if not isinstance(trans_b, list):
trans_b = [trans_b]

if target.kind.name == "cuda":
intrin_groups = [
tensor_intrin.get_wmma_intrin_group(write_reuse_scope, _in_dtype, _out_dtype, _trans_b)
for _in_dtype in in_dtype
for _out_dtype in out_dtype
for _trans_b in trans_b
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given a compute, isn't there only one group that's valid?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, this is to allow one single rule to apply to different workloads (different dtypes).

]
return MultiLevelTilingTensorCore(
intrin_group=tensor_intrin.get_wmma_intrin_group(
write_reuse_scope, in_dtype, out_dtype, trans_b
),
intrin_groups=intrin_groups,
structure="SSSRRSRS",
tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
max_innermost_factor=4, # 64 // tensor intrin size
Expand Down
145 changes: 93 additions & 52 deletions src/meta_schedule/schedule_rule/multi_level_tiling_tensor_core.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,42 @@ struct TensorCoreIntrinGroup {
String load_b_intrin;
String compute_intrin;
String store_intrin;

/*! \brief Create TensorCoreIntrinGroup from config in a map. The map should contain the
* following keys:
* - init
* - load_a
* - load_b
* - compute
* - store
* The values of the keys should be the names of the corresponding intrinsics and should be
* registered via TensorIntrin.Register beforehand.
*/
static TensorCoreIntrinGroup FromConfig(const Map<String, String>& config);
};

TensorCoreIntrinGroup TensorCoreIntrinGroup::FromConfig(const Map<String, String>& config) {
  // Fetch the intrinsic name stored under `key`. A missing key is reported as a ValueError, and
  // the name is passed to TensorIntrin::Get to check that the intrinsic exists.
  auto lookup_intrin = [&config](String key) -> String {
    CHECK(config.count(key)) << "ValueError: " << key << " is not set.";
    String name = config.at(key);
    // Check the existence of the intrin
    tir::TensorIntrin::Get(name);
    return name;
  };

  TensorCoreIntrinGroup group;
  group.init_intrin = lookup_intrin("init");
  group.load_a_intrin = lookup_intrin("load_a");
  group.load_b_intrin = lookup_intrin("load_b");
  group.compute_intrin = lookup_intrin("compute");
  group.store_intrin = lookup_intrin("store");
  return group;
}

class TensorCoreStateNode : public StateNode {
public:
/*! \brief The tensor core intrinsic group. */
TensorCoreIntrinGroup intrin_group;
/*! \brief The auto tensorization mapping info. */
tir::AutoTensorizeMappingInfo mapping_info{nullptr};
/*! \brief The Tensor Core reindex block A for Tensor Core computation */
tir::BlockRV tensor_core_reindex_A;
/*! \brief The Tensor Core reindex block B for Tensor Core computation */
Expand All @@ -57,16 +89,21 @@ class TensorCoreStateNode : public StateNode {

class TensorCoreState : public State {
public:
explicit TensorCoreState(tir::Schedule sch, tir::BlockRV block_rv,
Array<Array<tir::LoopRV>> tiles = {});
explicit TensorCoreState(TensorCoreIntrinGroup intrin_group,
tir::AutoTensorizeMappingInfo mapping_info, Schedule sch,
BlockRV block_rv, Array<Array<tir::LoopRV>> tiles = {});

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode);
};

TVM_REGISTER_OBJECT_TYPE(TensorCoreStateNode);

TensorCoreState::TensorCoreState(Schedule sch, BlockRV block_rv, Array<Array<LoopRV>> tiles) {
TensorCoreState::TensorCoreState(TensorCoreIntrinGroup intrin_group,
tir::AutoTensorizeMappingInfo mapping_info, Schedule sch,
BlockRV block_rv, Array<Array<LoopRV>> tiles) {
ObjectPtr<TensorCoreStateNode> node = make_object<TensorCoreStateNode>();
node->intrin_group = intrin_group;
node->mapping_info = mapping_info;
node->sch = std::move(sch);
node->block_rv = std::move(block_rv);
node->tiles = std::move(tiles);
Expand Down Expand Up @@ -116,16 +153,12 @@ class MultiLevelTilingTensorCoreNode : public MultiLevelTilingNode {
const String& intrin_name) const;

public:
/*! \brief The tensor core intrin group to apply */
TensorCoreIntrinGroup intrin_group;
/*! \brief The candidate tensor core intrin groups to apply */
std::vector<TensorCoreIntrinGroup> intrin_groups;
static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingTensorCore";
TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingTensorCoreNode, MultiLevelTilingNode);

private:
/*!
* \brief The mapping info for auto tensorization
*/
tir::AutoTensorizeMappingInfo mapping_info_{nullptr};
};

// Entry of the mega rule; Inherited from ScheduleRuleNode
Expand All @@ -135,21 +168,36 @@ Array<Schedule> MultiLevelTilingTensorCoreNode::Apply(const Schedule& sch,
return {sch};
}

Optional<tir::AutoTensorizeMappingInfo> mapping_info =
tir::GetAutoTensorizeMappingInfo(sch->state(), sch->GetSRef(block_rv),
tir::TensorIntrin::Get(intrin_group.compute_intrin)->desc);
if (!mapping_info.defined()) {
std::unordered_map<int, tir::AutoTensorizeMappingInfo> intrin_group_to_mapping_info;
for (int i = 0, n = intrin_groups.size(); i < n; ++i) {
TensorCoreIntrinGroup intrin_group = intrin_groups[i];
Optional<tir::AutoTensorizeMappingInfo> mapping_info = tir::GetAutoTensorizeMappingInfo(
sch->state(), sch->GetSRef(block_rv),
tir::TensorIntrin::Get(intrin_groups[i].compute_intrin)->desc);
if (mapping_info.defined()) {
intrin_group_to_mapping_info.emplace(i, mapping_info.value());
}
}

if (intrin_group_to_mapping_info.empty()) {
// No tensor intrinsics can be applied.
return {sch};
}
mapping_info_ = mapping_info.value();

// Create a copy of the schedule so that we can roll back transformations if tensorization
// Save the original schedule so that we can roll back transformations if tensorization
// fails.
Schedule original_sch = sch->Copy();
sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure);

Schedule original_sch = sch;

std::vector<State> initial_states;
for (const auto& kv : intrin_group_to_mapping_info) {
const TensorCoreIntrinGroup& intrin_group = intrin_groups[kv.first];
const tir::AutoTensorizeMappingInfo& mapping_info = kv.second;
Schedule new_sch = sch->Copy();
new_sch->Annotate(block_rv, tir::attr::meta_schedule_tiling_structure, structure);
initial_states.push_back(TensorCoreState(intrin_group, mapping_info, new_sch, block_rv));
}
Array<Schedule> results;
for (auto&& state : ApplySubRules({TensorCoreState(sch, block_rv)})) {
for (auto&& state : ApplySubRules(initial_states)) {
results.push_back(std::move(state->sch));
}
if (results.empty()) {
Expand Down Expand Up @@ -196,7 +244,7 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddWriteReuseTensorCore(
AnnotateCooperativeFetching(&sch, state->write_reuse[0]);
}
sch->ReverseComputeInline(state->tensor_core_reindex_store);
TileAndAnnotateTensorize(&sch, cache_write, intrin_group.store_intrin);
TileAndAnnotateTensorize(&sch, cache_write, state->intrin_group.store_intrin);
return {state};
}

Expand All @@ -212,8 +260,8 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore(
TileAndAnnotateTensorize(&sch, cache_read, intrin_name);
};

f_tensorize_load(0, "wmma.matrix_a", intrin_group.load_a_intrin);
f_tensorize_load(1, "wmma.matrix_b", intrin_group.load_b_intrin);
f_tensorize_load(0, "wmma.matrix_a", state->intrin_group.load_a_intrin);
f_tensorize_load(1, "wmma.matrix_b", state->intrin_group.load_b_intrin);
sch->ComputeInline(state->tensor_core_reindex_A);
sch->ComputeInline(state->tensor_core_reindex_B);

Expand All @@ -238,6 +286,7 @@ std::vector<State> MultiLevelTilingTensorCoreNode::AddReadReuseTensorCore(
Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
TensorCoreStateNode* state, const String& intrin_name) const {
BlockRV block_rv = state->block_rv;
const tir::AutoTensorizeMappingInfo& mapping_info = state->mapping_info;
tir::StmtSRef block_sref = state->sch->GetSRef(state->block_rv);

// Add reindex stages
Expand All @@ -258,38 +307,38 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
// Transform the layout of reindex buffers accordingly.
// The index map defines the mapping for the computation block. We need to extract the sub index
// map to transform the load and store block.
ICHECK_EQ(mapping_info_->mappings.size(), 1U); // assume only one mapping is present
const tir::IndexMap& index_map = mapping_info_->mappings[0];
ICHECK_EQ(mapping_info->mappings.size(), 1U); // assume only one mapping is present
const tir::IndexMap& index_map = mapping_info->mappings[0];

// Find the correspondence between block iters and the iters in the index map.
std::unordered_map<tir::Var, tir::Var, ObjectPtrHash, ObjectPtrEqual> lhs_to_index_map_src;
std::unordered_map<tir::Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> rhs_to_index_map_tgt;
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> unmapped_index_map_src;
ICHECK_EQ(mapping_info_->lhs_iters.size(), index_map->initial_indices.size());
for (int i = 0; i < static_cast<int>(mapping_info_->lhs_iters.size()); ++i) {
lhs_to_index_map_src[mapping_info_->lhs_iters[i]->var] = index_map->initial_indices[i];
ICHECK_EQ(mapping_info->lhs_iters.size(), index_map->initial_indices.size());
for (int i = 0; i < static_cast<int>(mapping_info->lhs_iters.size()); ++i) {
lhs_to_index_map_src[mapping_info->lhs_iters[i]->var] = index_map->initial_indices[i];
}
// The number of result iters in the index map is equal or more than the number of rhs (the
// tensor intrin) iters. When there are extra iters, these iters represent unmapped iters from the
// lhs. They will be skipped during pattern matching for tensorization.
// An example of such case is batch matmul, the batch dimension is kept after layout
// transformations and it will be kept as a outer loop after tensorization.
// tensor intrin) iters. When there are extra iters, these iters represent unmapped iters from
// the lhs. They will be skipped during pattern matching for tensorization. An example of such
// case is batch matmul, the batch dimension is kept after layout transformations and it will be
// kept as a outer loop after tensorization.
int offset = static_cast<int>(index_map->final_indices.size()) -
static_cast<int>(mapping_info_->rhs_iters.size());
static_cast<int>(mapping_info->rhs_iters.size());
ICHECK_GE(offset, 0);
for (int i = 0; i < offset; ++i) {
const tir::VarNode* var_ptr = index_map->final_indices[i].as<tir::VarNode>();
ICHECK(var_ptr != nullptr);
unmapped_index_map_src.insert(GetRef<tir::Var>(var_ptr));
}
for (int i = offset; i < static_cast<int>(index_map->final_indices.size()); ++i) {
rhs_to_index_map_tgt[mapping_info_->rhs_iters[i - offset]->var] = index_map->final_indices[i];
rhs_to_index_map_tgt[mapping_info->rhs_iters[i - offset]->var] = index_map->final_indices[i];
}

auto f_get_sub_index_map = [&](const tir::Buffer& lhs_buffer, const tir::Region& lhs_region) {
std::vector<tir::Var> sub_index_map_src;
std::vector<PrimExpr> sub_index_map_tgt;
const tir::Buffer& rhs_buffer = mapping_info_->lhs_buffer_map[lhs_buffer];
const tir::Buffer& rhs_buffer = mapping_info->lhs_buffer_map[lhs_buffer];
for (const Range& range : lhs_region) {
ICHECK(tir::is_one(range->extent));
const tir::VarNode* var_ptr = range->min.as<tir::VarNode>();
Expand All @@ -300,8 +349,8 @@ Optional<LoopRV> MultiLevelTilingTensorCoreNode::TransformWithTensorIntrin(
sub_index_map_tgt.push_back(lhs_representer);
}
}
for (size_t i = 0; i < mapping_info_->rhs_buffer_indices[rhs_buffer].size(); ++i) {
const tir::VarNode* var = mapping_info_->rhs_buffer_indices[rhs_buffer][i].as<tir::VarNode>();
for (size_t i = 0; i < mapping_info->rhs_buffer_indices[rhs_buffer].size(); ++i) {
const tir::VarNode* var = mapping_info->rhs_buffer_indices[rhs_buffer][i].as<tir::VarNode>();
ICHECK(var != nullptr);
sub_index_map_tgt.push_back(rhs_to_index_map_tgt[GetRef<tir::Var>(var)]);
}
Expand Down Expand Up @@ -345,7 +394,7 @@ inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorizat
TensorCoreState state) const {
// Do reindex and layout transformations.
Optional<LoopRV> transformed_loop_rv =
TransformWithTensorIntrin(state.operator->(), intrin_group.compute_intrin);
TransformWithTensorIntrin(state.operator->(), state->intrin_group.compute_intrin);
if (!transformed_loop_rv.defined()) {
// The workload can't be tensorized.
return {};
Expand All @@ -356,32 +405,24 @@ inline std::vector<State> MultiLevelTilingTensorCoreNode::TransformForTensorizat

// Add annotations for post processors.
state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize,
intrin_group.compute_intrin);
state->intrin_group.compute_intrin);
state->sch->Annotate(state->block_rv, tir::attr::meta_schedule_auto_tensorize_init,
intrin_group.init_intrin);
state->intrin_group.init_intrin);
state->sch->Annotate(state->block_rv, tir::attr::warp_execution, Bool(true));
return {std::move(state)};
}

ScheduleRule ScheduleRule::MultiLevelTilingTensorCore(
Map<String, String> intrin_group, String structure, Optional<Array<String>> tile_binds,
Array<Map<String, String>> intrin_groups, String structure, Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
auto node = MultiLevelTilingInitCommon<MultiLevelTilingTensorCoreNode>(
structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);

auto f_initialize_intrin = [&intrin_group](String key_name, String* intrin_name) {
CHECK(intrin_group.count(key_name)) << "ValueError: " << key_name << " is not set.";
*intrin_name = intrin_group.at(key_name);
// Check the existence of the intrin
tir::TensorIntrin::Get(*intrin_name);
};
f_initialize_intrin("init", &node->intrin_group.init_intrin);
f_initialize_intrin("load_a", &node->intrin_group.load_a_intrin);
f_initialize_intrin("load_b", &node->intrin_group.load_b_intrin);
f_initialize_intrin("compute", &node->intrin_group.compute_intrin);
f_initialize_intrin("store", &node->intrin_group.store_intrin);

node->intrin_groups.reserve(intrin_groups.size());
for (const auto& intrin_group_config : intrin_groups) {
node->intrin_groups.emplace_back(TensorCoreIntrinGroup::FromConfig(intrin_group_config));
}
return ScheduleRule(node);
}

Expand Down
Loading