Skip to content

Commit

Permalink
VNNI -> WithIntrin
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Apr 7, 2022
1 parent 4284a47 commit 823797e
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 33 deletions.
14 changes: 8 additions & 6 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,14 @@ class ScheduleRule : public runtime::ObjectRef {
Optional<Map<String, ObjectRef>> reuse_read, //
Optional<Map<String, ObjectRef>> reuse_write);

TVM_DLL static ScheduleRule MultiLevelTilingVNNI(String structure, //
Optional<Array<String>> tile_binds, //
Optional<Integer> max_innermost_factor, //
Optional<Array<Integer>> vector_load_lens, //
Optional<Map<String, ObjectRef>> reuse_read, //
Optional<Map<String, ObjectRef>> reuse_write);
TVM_DLL static ScheduleRule MultiLevelTilingWithIntrin(
String intrin_name, //
String structure, //
Optional<Array<String>> tile_binds, //
Optional<Integer> max_innermost_factor, //
Optional<Array<Integer>> vector_load_lens, //
Optional<Map<String, ObjectRef>> reuse_read, //
Optional<Map<String, ObjectRef>> reuse_write);

/*!
* \brief Create a rule: add-rfactor to some blocks if needed
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/schedule_rule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from .add_rfactor import AddRFactor
from .auto_inline import AutoInline
from .cross_thread_reduction import CrossThreadReduction
from .multi_level_tiling import MultiLevelTiling, MultiLevelTilingVNNI, ReuseType
from .multi_level_tiling import MultiLevelTiling, MultiLevelTilingWithIntrin, ReuseType
from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
from .random_compute_location import RandomComputeLocation
from .schedule_rule import PyScheduleRule, ScheduleRule
8 changes: 5 additions & 3 deletions python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ def __init__(
)


@register_object("meta_schedule.MultiLevelTilingVNNI")
class MultiLevelTilingVNNI(ScheduleRule):
@register_object("meta_schedule.MultiLevelTilingWithIntrin")
class MultiLevelTilingWithIntrin(ScheduleRule):
"""Multi-level tiling with reuse.
Parameters
Expand All @@ -111,6 +111,7 @@ class MultiLevelTilingVNNI(ScheduleRule):

def __init__(
self,
intrin_name: str,
structure: str,
tile_binds: Optional[List[str]] = None,
max_innermost_factor: Optional[int] = None,
Expand All @@ -119,7 +120,8 @@ def __init__(
reuse_write: Optional[ReuseType] = None,
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.ScheduleRuleMultiLevelTilingVNNI, # type: ignore # pylint: disable=no-member
_ffi_api.ScheduleRuleMultiLevelTilingWithIntrin, # type: ignore # pylint: disable=no-member
intrin_name,
structure,
tile_binds,
max_innermost_factor,
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _sch_rules() -> List[ScheduleRule]:
disallow_op=["tir.exp"],
),
M.AddRFactor(max_jobs_per_core=16, max_innermost_factor=64),
M.MultiLevelTilingVNNI(
M.MultiLevelTiling(
structure="SSRSRS",
tile_binds=None,
max_innermost_factor=64,
Expand Down
6 changes: 4 additions & 2 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "multi_level_tiling.h"

#include "../utils.h"
#include "tvm/meta_schedule/schedule_rule.h"

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -233,7 +234,7 @@ std::vector<State> MultiLevelTilingNode::AddReadReuse(State state) const {
if (!vector_load_lens.empty()) {
int n = vector_load_lens.size();
double prob = 1.0 / n;
tir::ExprRV vector_load_len =
tir::ExprRV vector_load_len =
sch->SampleCategorical(support::AsArray<int, Integer>(vector_load_lens),
Array<FloatImm>(n, FloatImm(DataType::Float(64), prob)));
sch->Annotate(cache_read_block, tir::attr::meta_schedule_cooperative_fetch,
Expand All @@ -254,8 +255,9 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<Str
Optional<Array<Integer>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read,
Optional<Map<String, ObjectRef>> reuse_write) {
return MultiLevelTilingInitCommon<MultiLevelTilingNode>(
auto node = MultiLevelTilingInitCommon<MultiLevelTilingNode>(
structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
return ScheduleRule(node);
}

TVM_REGISTER_NODE_TYPE(MultiLevelTilingNode);
Expand Down
13 changes: 7 additions & 6 deletions src/meta_schedule/schedule_rule/multi_level_tiling.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <tvm/meta_schedule/schedule_rule.h>
#include <tvm/tir/schedule/schedule.h>

#include "../../support/array.h"

namespace tvm {
Expand Down Expand Up @@ -178,11 +179,11 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
};

template <typename NodeType>
ScheduleRule MultiLevelTilingInitCommon(String structure, Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor,
Optional<Array<Integer>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read,
Optional<Map<String, ObjectRef>> reuse_write) {
ObjectPtr<NodeType> MultiLevelTilingInitCommon(String structure, Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor,
Optional<Array<Integer>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read,
Optional<Map<String, ObjectRef>> reuse_write) {
ObjectPtr<NodeType> n = make_object<NodeType>();
n->structure = structure;
n->tile_binds = tile_binds.value_or({});
Expand All @@ -204,7 +205,7 @@ ScheduleRule MultiLevelTilingInitCommon(String structure, Optional<Array<String>
}
n->thread_warp_size_ = -1;
n->max_threads_per_block_ = -1;
return ScheduleRule(n);
return n;
}

} // namespace meta_schedule
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,34 +24,37 @@
namespace tvm {
namespace meta_schedule {

class MultiLevelTilingVNNINode : public MultiLevelTilingNode {
class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode {
protected:
virtual std::vector<State> ApplySubRules(std::vector<State> states) {
states = SubRule(std::move(states), [&](State state) {
state.block_rv = TileForIntrin(state.sch, state.block_rv, "dot_16x4_vnni");
state.block_rv = TileForIntrin(state.sch, state.block_rv, intrin_name);
return std::vector<State>(1, state);
});
return MultiLevelTilingNode::ApplySubRules(states);
}

public:
static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingVNNI";
TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingVNNINode, MultiLevelTilingNode);
static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWithIntrin";
TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWithIntrinNode, MultiLevelTilingNode);

String intrin_name;
};

ScheduleRule ScheduleRule::MultiLevelTilingVNNI(String structure,
Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor,
Optional<Array<Integer>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read,
Optional<Map<String, ObjectRef>> reuse_write) {
return MultiLevelTilingInitCommon<MultiLevelTilingVNNINode>(
ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin(
String intrin_name, String structure, Optional<Array<String>> tile_binds,
Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
ICHECK(tir::TensorIntrin::Get(intrin_name).defined());
auto node = MultiLevelTilingInitCommon<MultiLevelTilingWithIntrinNode>(
structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
node->intrin_name = intrin_name;
return ScheduleRule(node);
}

TVM_REGISTER_NODE_TYPE(MultiLevelTilingVNNINode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingVNNI")
.set_body_typed(ScheduleRule::MultiLevelTilingVNNI);
TVM_REGISTER_NODE_TYPE(MultiLevelTilingWithIntrinNode);
TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWithIntrin")
.set_body_typed(ScheduleRule::MultiLevelTilingWithIntrin);

} // namespace meta_schedule
} // namespace tvm

0 comments on commit 823797e

Please sign in to comment.