[MetaSchedule] Add MultiLevelTilingTensorCore rule for auto-tensorization on CUDA #12059

Merged
26 changes: 26 additions & 0 deletions include/tvm/meta_schedule/schedule_rule.h
@@ -173,6 +173,32 @@ class ScheduleRule : public runtime::ObjectRef {
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);

/*!
* \brief Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
* intrinsics
* \param intrin_group A group of tensor core intrinsics. The map should contain the keys "init",
* "load_a", "load_b", "compute" and "store", which represent the tensor intrinsics for
* initialization, loading operand A, loading operand B, tensor core computation and storing the
* result, respectively. The values of the map should be names of tensor intrinsics that have been
* registered via TensorIntrin.register(...) beforehand.
* \param structure The tiling structure. Recommended:
* - 'SSRSRS' on CPU
* - 'SSSRRSRS' on GPU
* \param tile_binds For each level of tiles, which thread axis it is bound to. Recommended:
* - NullOpt on CPU
* - [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
* \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
* \param vector_load_lens The lengths of vector lanes to try in vectorized cooperative fetching.
* NullOpt means vectorization is disabled
* \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
* \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule MultiLevelTilingTensorCore(
Map<String, String> intrin_group, String structure, Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);

/*!
* \brief Create a rule: add-rfactor to some blocks if needed
* \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
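For illustration (not part of the diff), a minimal sketch of an intrin_group for fp16 matmul with fp32 accumulation, using the WMMA intrinsic names that the testing helper later in this PR also relies on; importing the tensor_intrin module is assumed to register these intrinsics via TensorIntrin.register(...):

# Sketch only: an intrin_group mapping the five required roles to registered WMMA intrinsics.
from tvm.tir import tensor_intrin  # importing is assumed to register the CUDA WMMA intrinsics

wmma_intrin_group = {
    "init": tensor_intrin.WMMA_FILL_16x16x16_F32_INTRIN,
    "load_a": tensor_intrin.WMMA_LOAD_16x16x16_F16_A_INTRIN,
    "load_b": tensor_intrin.WMMA_LOAD_16x16x16_F16_B_INTRIN,
    "compute": tensor_intrin.WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
    "store": tensor_intrin.WMMA_STORE_16x16x16_F32_SHARED_INTRIN,
}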
4 changes: 4 additions & 0 deletions include/tvm/tir/stmt.h
@@ -1524,6 +1524,10 @@ constexpr const char* meta_schedule_auto_tensorize = "meta_schedule.auto_tensori

/*! \brief Mark that a block is a preprocessor block for layout rewrite. */
constexpr const char* meta_schedule_layout_rewrite_preproc = "meta_schedule.layout_rewrite_preproc";
/*!
* \brief Mark that the init statement of a block should be further rewritten using tensorization.
*/
constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_tensorize_init";

/*!
* \brief Mark that a block is executed by a warp. This implies the extent of threadIdx.x is
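As a hedged sketch of how the new annotation is intended to be used (the function and intrinsic names here are illustrative, not part of the diff): a schedule rule marks the reduction block, and the RewriteReductionBlock postprocessor later moves the init annotation onto the decomposed init block.

# Sketch: request auto-tensorization of a reduction block and of its init statement.
from tvm import tir

def request_tensor_core(sch: tir.Schedule, block: tir.schedule.BlockRV) -> None:
    sch.annotate(block, "meta_schedule.auto_tensorize", "wmma_sync_16x16x16_f16f16f32")
    # New in this PR: ask for the init statement to be tensorized (e.g. with a wmma_fill
    # intrinsic) once DecomposeReduction has split it into its own block.
    sch.annotate(block, "meta_schedule.auto_tensorize_init", "wmma_fill_16x16x16_f32")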
7 changes: 6 additions & 1 deletion python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -23,7 +23,12 @@
from .auto_bind import AutoBind
from .auto_inline import AutoInline
from .cross_thread_reduction import CrossThreadReduction
from .multi_level_tiling import MultiLevelTiling, MultiLevelTilingWithIntrin, ReuseType
from .multi_level_tiling import (
MultiLevelTiling,
MultiLevelTilingWithIntrin,
ReuseType,
MultiLevelTilingTensorCore,
)
from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
from .random_compute_location import RandomComputeLocation
from .schedule_rule import PyScheduleRule, ScheduleRule
56 changes: 55 additions & 1 deletion python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Multi-level tiling with reuse."""
from typing import Any, Dict, List, NamedTuple, Optional
from typing import Any, Dict, List, Mapping, NamedTuple, Optional

from tvm._ffi import register_object

@@ -131,3 +131,57 @@ def __init__(
reuse_read.as_dict() if reuse_read is not None else None,
reuse_write.as_dict() if reuse_write is not None else None,
)


@register_object("meta_schedule.MultiLevelTilingTensorCore")
class MultiLevelTilingTensorCore(ScheduleRule):
"""Extension of MultiLevelTiling for auto-tensorizing with a single group of tensor core
intrinsics.

Parameters
----------
intrin_group : Mapping[str, str]
A group of tensor core intrinsics. The map should contain the keys "init", "load_a", "load_b",
"compute" and "store", which represent the tensor intrinsics for initialization, loading operand A,
loading operand B, tensor core computation and storing the result, respectively.
The values of the map should be names of tensor intrinsics that have been registered via
TensorIntrin.register(...) beforehand.
structure : str
The tiling structure. Recommended:
- 'SSRSRS' on CPU
- 'SSSRRSRS' on GPU
tile_binds : Optional[List[str]]
For each level of tiles, which thread axis it is bound to. Recommended:
- None on CPU
- [blockIdx.y, blockIdx.x, threadIdx.y] on GPU
max_innermost_factor : Optional[int]
The maximum size of the innermost factor. None means no limit
vector_load_lens : Optional[List[int]]
The lengths of vector lanes to try in vectorized cooperative fetching.
None means vectorization is disabled
reuse_read : Optional[ReuseType]
Data reuse configuration for reading. None means no reuse.
reuse_write : Optional[ReuseType]
Data reuse configuration for writing. None means no reuse.
"""

def __init__(
self,
intrin_group: Mapping[str, str],
structure: str,
tile_binds: Optional[List[str]] = None,
max_innermost_factor: Optional[int] = None,
vector_load_lens: Optional[List[int]] = None,
reuse_read: Optional[ReuseType] = None,
reuse_write: Optional[ReuseType] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleMultiLevelTilingTensorCore, # type: ignore # pylint: disable=no-member
intrin_group,
structure,
tile_binds,
max_innermost_factor,
vector_load_lens,
reuse_read.as_dict() if reuse_read is not None else None,
reuse_write.as_dict() if reuse_write is not None else None,
)
34 changes: 34 additions & 0 deletions python/tvm/meta_schedule/testing/schedule_rule.py
@@ -26,6 +26,8 @@
ReuseType,
ScheduleRule,
)
from tvm.meta_schedule.schedule_rule.multi_level_tiling import MultiLevelTilingTensorCore
from tvm.tir import tensor_intrin
from tvm.target import Target


@@ -110,6 +112,38 @@ def multi_level_tiling(target: Target) -> ScheduleRule:
raise NotImplementedError(f"{target.kind.name} is not supported")


def multi_level_tiling_tensor_core(target: Target, scope: str = "shared") -> ScheduleRule:
"""Default schedule rule for multi-level tiling with data reuse, targeting Tensor Core"""
assert scope in ["shared", "global"]
if target.kind.name == "cuda":
return MultiLevelTilingTensorCore(
intrin_group={
"init": tensor_intrin.WMMA_FILL_16x16x16_F32_INTRIN,
"load_a": tensor_intrin.WMMA_LOAD_16x16x16_F16_A_INTRIN,
"load_b": tensor_intrin.WMMA_LOAD_16x16x16_F16_B_INTRIN,
"compute": tensor_intrin.WMMA_SYNC_16x16x16_f16f16f32_INTRIN,
"store": tensor_intrin.WMMA_STORE_16x16x16_F32_SHARED_INTRIN
if scope == "shared"
else tensor_intrin.WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN,
},
structure="SSSRRSRS",
tile_binds=["blockIdx.y", "blockIdx.x", "threadIdx.y"],
max_innermost_factor=4, # 64 // tensor intrin size
vector_load_lens=[1, 2, 3, 4],
reuse_read=ReuseType(
req="must",
levels=[4],
scope="shared",
),
reuse_write=ReuseType(
req="must" if scope == "shared" else "no",
levels=[2],
scope="shared",
),
)
raise NotImplementedError(f"{target.kind.name} is not supported")
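
A hypothetical use of the helper above (the plain "cuda" target is illustrative; any CUDA target works):

from tvm.target import Target

# Accumulate through shared memory (default), or store wmma fragments directly to global memory.
rule_shared = multi_level_tiling_tensor_core(Target("cuda"))
rule_global = multi_level_tiling_tensor_core(Target("cuda"), scope="global")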


def random_compute_location(target: Target) -> ScheduleRule:
"""Default schedule rules for with random-compute-location"""
if target.kind.name == "llvm":
8 changes: 4 additions & 4 deletions python/tvm/tir/tensor_intrin/cuda.py
@@ -769,15 +769,15 @@ def wmma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
*get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, False),
)

WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a_trans"
WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN = "wmma_load_16x16x16_f16_a_trans"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_A_INTRIN,
WMMA_LOAD_16x16x16_F16_A_TRANS_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared", False, True),
)

WMMA_LOAD_16x16x16_F16_B_INTRIN = "wmma_load_16x16x16_f16_b_trans"
WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN = "wmma_load_16x16x16_f16_b_trans"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_B_INTRIN,
WMMA_LOAD_16x16x16_F16_B_TRANS_INTRIN,
*get_wmma_load_intrin(16, 16, 16, "float16", "shared", True, True),
)

15 changes: 15 additions & 0 deletions src/meta_schedule/postproc/rewrite_reduction_block.cc
@@ -135,6 +135,21 @@ bool RewriteReductionBlockNode::Apply(const tir::Schedule& sch) {
tir::BlockRV block_rv = GetRVFromSRef(sch, block_sref, global_var_name);
Array<tir::LoopRV> loop_rvs = sch->GetLoops(block_rv);
tir::BlockRV init_block_rv = sch->DecomposeReduction(block_rv, loop_rvs[decompose_point]);

// Rewrite auto tensorization related annotations
if (tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize).defined()) {
// Remove tensorization annotation as it shouldn't be propagated to the init block.
sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize);
}
if (Optional<String> tensorize_init =
tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize_init)) {
// The annotation of tensorization of the init statement should be moved to the init block
// after 'DecomposeReduction'.
sch->Annotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize,
tensorize_init.value());
sch->Unannotate(block_rv, tir::attr::meta_schedule_auto_tensorize_init);
sch->Unannotate(init_block_rv, tir::attr::meta_schedule_auto_tensorize_init);
}
++rewritten;
}
if (rewritten == 0) {
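A Python-level sketch of what this change amounts to, expressed with schedule primitives on a toy fp16 matmul (module and intrinsic names are illustrative; the postprocessor performs these steps in C++ on blocks annotated by the schedule rule):

import tvm
from tvm import te, tir

# Toy reduction block, only so there is something to decompose.
A = te.placeholder((16, 16), "float16", name="A")
B = te.placeholder((16, 16), "float16", name="B")
k = te.reduce_axis((0, 16), name="k")
C = te.compute(
    (16, 16),
    lambda i, j: te.sum(A[i, k].astype("float32") * B[k, j].astype("float32"), axis=k),
    name="C",
)
sch = tir.Schedule(te.create_prim_func([A, B, C]))
block = sch.get_block("C")
sch.annotate(block, "meta_schedule.auto_tensorize", "wmma_sync_16x16x16_f16f16f32")
sch.annotate(block, "meta_schedule.auto_tensorize_init", "wmma_fill_16x16x16_f32")

i, j, k_loop = sch.get_loops(block)
init_block = sch.decompose_reduction(block, k_loop)
# Mirror the postprocessor: the inherited compute annotation is dropped from the init block,
# the init block is re-annotated with the "init" intrinsic, and the helper key is removed.
sch.unannotate(init_block, "meta_schedule.auto_tensorize")
sch.annotate(init_block, "meta_schedule.auto_tensorize", "wmma_fill_16x16x16_f32")
sch.unannotate(block, "meta_schedule.auto_tensorize_init")
sch.unannotate(init_block, "meta_schedule.auto_tensorize_init")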
34 changes: 16 additions & 18 deletions src/meta_schedule/postproc/rewrite_tensorize.cc
@@ -35,26 +35,24 @@ void CollectTensorizationJobs(
tir::PostOrderVisit(func->body, [=, &jobs](const ObjectRef& obj) {
if (const auto* block = obj.as<tir::BlockNode>()) {
tir::StmtSRef block_sref = sch->GetSRef(block);
std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
if (Optional<String> intrin_name =
tir::GetAnn<String>(block_sref, tir::attr::meta_schedule_auto_tensorize)) {
std::string block_name = block_sref->StmtAs<tir::BlockNode>()->name_hint;
if (block_name.find("init") == std::string::npos) {
jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) {
try {
sch->Tensorize(block, intrin_name.value());
} catch (const std::exception& e) {
LOG(WARNING) << "Tensorize failed with error " << e.what();
}
});
} else if (vectorize_init_loop) {
jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) {
Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
ICHECK(child_blocks.size() == 1);
Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
ICHECK(init_loops.size() == 1);
sch->Vectorize(init_loops[0]);
});
}
jobs->emplace_back(block_name, func_name, [sch, intrin_name](tir::BlockRV block) {
try {
sch->Tensorize(block, intrin_name.value());
} catch (const std::exception& e) {
LOG(WARNING) << "Tensorize failed with error " << e.what();
}
});
} else if (block_name.find("init") && vectorize_init_loop) {
Review comment (Member):

Do we ever hit this condition after your change in rewrite_reduction_block.cc?

To vectorize init loop, should we switch to using tir::attr::meta_schedule_auto_tensorize_init?

Reply (Member Author, @vinx13, Jul 12, 2022):

In rewrite_reduction_block, tir::attr::meta_schedule_auto_tensorize will be removed from the init block by default, unless the original reduction block is annotated with tir::attr::meta_schedule_auto_tensorize_init. In that case, tir::attr::meta_schedule_auto_tensorize_init is renamed to tir::attr::meta_schedule_auto_tensorize, so that rewrite_tensorize only needs to check a single annotation. However, I hit another issue: checking block_name.find("init") is not safe. I changed the logic here a bit; let me know if that makes sense to you.

jobs->emplace_back(block_name, func_name, [sch](tir::BlockRV block) {
Array<BlockRV> child_blocks = sch->GetChildBlocks(block);
ICHECK(child_blocks.size() == 1);
Array<LoopRV> init_loops = sch->GetLoops(child_blocks[0]);
ICHECK(init_loops.size() == 1);
sch->Vectorize(init_loops[0]);
});
}
}
});
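For reference, a rough Python rendering of the fallback branch discussed above (assuming, as the ICHECKs do, that the init block has exactly one child block with a single loop):

# Sketch: vectorize the init loop of a block whose init statement was not auto-tensorized.
def vectorize_init_loop(sch, block):
    child_blocks = sch.get_child_blocks(block)
    assert len(child_blocks) == 1
    init_loops = sch.get_loops(child_blocks[0])
    assert len(init_loops) == 1
    sch.vectorize(init_loops[0])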
3 changes: 2 additions & 1 deletion src/meta_schedule/schedule_rule/auto_inline.cc
@@ -143,7 +143,8 @@ inline InlineType AutoInlineNode::CheckInline(const tir::Schedule& sch,
Array<tir::StmtSRef> producer_srefs = GetProducers(state, block_sref);
if (producer_srefs.size() == 1 &&
tir::IsCompleteBlock(sch->state(), producer_srefs[0], scope_block) &&
CanReverseComputeInline(state, block_sref)) {
CanReverseComputeInline(state, block_sref) &&
!GetAnn<String>(producer_srefs[0], tir::attr::meta_schedule_auto_tensorize).defined()) {
return InlineType::kInlineIntoProducer;
}
}
28 changes: 18 additions & 10 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -61,6 +61,8 @@ using tir::IterVarType;
using tir::LoopRV;
using tir::Schedule;

TVM_REGISTER_OBJECT_TYPE(StateNode);

State::State(tir::Schedule sch, tir::BlockRV block_rv, Array<Array<tir::LoopRV>> tiles) {
ObjectPtr<StateNode> node = make_object<StateNode>();
node->sch = std::move(sch);
@@ -133,6 +135,7 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
new_state->sch->ReverseComputeAt(consumer_rvs[0], loop_rv, true);
results.push_back(std::move(new_state));
}
state->write_reuse.emplace(0, consumer_rvs[0]);
results.push_back(state);
return results;
} else {
@@ -146,6 +149,7 @@ std::vector<State> MultiLevelTilingNode::AddWriteReuse(State state) const {
BlockRV write_cache =
state->sch->CacheWrite(/*block_rv=*/state->block_rv, /*read_buffer_index=*/0,
/*storage_scope=*/config.scope);
state->write_reuse.emplace(0, write_cache);
for (int level : levels) {
State new_state = state->Copy();
const LoopRV& loop_rv = new_state->tiles[level - 1].back();
@@ -247,22 +251,26 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
Array<LoopRV> buffer_loops = sch->GetLoops(cache_read_block);
LoopRV fused = sch->Fuse(Array<LoopRV>{buffer_loops.end() - buffer_ndim, //
buffer_loops.end()});
// Annotate cooperative fetching
if (!vector_load_lens.empty()) {
int n = vector_load_lens.size();
double prob = 1.0 / n;
tir::ExprRV vector_load_len =
sch->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens),
Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch,
vector_load_len);
}
AnnotateCooperativeFetching(&sch, cache_read_block);
new_state->read_reuse.emplace(i, cache_read_block);
}
results.push_back(std::move(new_state));
}
return results;
}

void MultiLevelTilingNode::AnnotateCooperativeFetching(Schedule* sch,
const tir::BlockRV& block) const {
if (!vector_load_lens.empty()) {
int n = vector_load_lens.size();
double prob = 1.0 / n;
tir::ExprRV vector_load_len =
(*sch)->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens),
Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
(*sch)->Annotate(block, tir::attr::meta_schedule_cooperative_fetch, vector_load_len);
}
}

// Constructor

ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<String>> tile_binds,
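A hedged Python rendering of the extracted AnnotateCooperativeFetching helper, to show what it samples and records (the candidate lengths come from vector_load_lens):

# Sketch: sample one vectorized load length and annotate the cache-read block with it.
def annotate_cooperative_fetching(sch, block, vector_load_lens):
    if vector_load_lens:
        n = len(vector_load_lens)
        vector_load_len = sch.sample_categorical(
            candidates=list(vector_load_lens),
            probs=[1.0 / n] * n,
        )
        sch.annotate(block, "meta_schedule.cooperative_fetch", vector_load_len)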
35 changes: 34 additions & 1 deletion src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -22,6 +22,7 @@
#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/tir/schedule/schedule.h>

#include <unordered_map>
#include <utility>
#include <vector>

@@ -93,6 +94,10 @@ class StateNode : public Object {
tir::BlockRV block_rv;
/*! \brief The loop tiles */
Array<Array<tir::LoopRV>> tiles;
/*! \brief The mapping from buffer index to read cache block. */
std::unordered_map<int, tir::BlockRV> read_reuse;
/*! \brief The mapping from buffer index to write cache block. */
std::unordered_map<int, tir::BlockRV> write_reuse;

/*!
* \brief Create a copy of the state. The underlying schedule is copied. Schedule rules that
@@ -112,6 +117,31 @@
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
};

class TensorCoreStateNode : public StateNode {
public:
/*! \brief The Tensor Core reindex block A for Tensor Core computation */
tir::BlockRV tensor_core_reindex_A;
/*! \brief The Tensor Core reindex block B for Tensor Core computation */
tir::BlockRV tensor_core_reindex_B;
/*! \brief The Tensor Core reindex store block for Tensor Core computation */
tir::BlockRV tensor_core_reindex_store;

State Copy() const final;

static constexpr const char* _type_key = "meta_schedule.TensorCoreState";
TVM_DECLARE_FINAL_OBJECT_INFO(TensorCoreStateNode, StateNode);
};

class TensorCoreState : public State {
public:
explicit TensorCoreState(tir::Schedule sch, tir::BlockRV block_rv,
Array<Array<tir::LoopRV>> tiles = {});

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorCoreState, State, TensorCoreStateNode);
};

struct AutoTensorizationState : public State {};

/*!
* \brief Helper to apply a sub-rule to a list of auto scheduling states
* \tparam FLambda The type of the sub-rule functor
@@ -148,11 +178,14 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
void InitializeWithTuneContext(const TuneContext& context) final;

// Entry of the mega rule; Inherited from ScheduleRuleNode
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final;
Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) override;

protected:
virtual std::vector<State> ApplySubRules(std::vector<State> states);

// Annotate a block to use cooperative fetching
void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const;

public:
/*!
* \brief The tiling structure. Recommended: