Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MetaSchedule] Extend tune_tir to support tuning of specific blocks. #12342

Merged
merged 8 commits into from
Aug 9, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/tvm/meta_schedule/space_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ class SpaceGenerator : public runtime::ObjectRef {
* to blocks in post-DFS order.
* \return The design space generator created.
*/
TVM_DLL static SpaceGenerator PostOrderApply();
TVM_DLL static SpaceGenerator PostOrderApply(runtime::PackedFunc f_block_filter = nullptr);
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode);
};

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/meta_schedule/space_generator/post_order_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class PostOrderApply(SpaceGenerator):
rules to blocks in post-DFS order.
"""

def __init__(self):
def __init__(self, filter_fn=None):
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
"""Constructor"""
self.__init_handle_by_constructor__(
_ffi_api.SpaceGeneratorPostOrderApply, # type: ignore # pylint: disable=no-member
_ffi_api.SpaceGeneratorPostOrderApply, filter_fn # type: ignore # pylint: disable=no-member
)
26 changes: 26 additions & 0 deletions python/tvm/meta_schedule/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from tvm.ir import IRModule
from tvm.ir.transform import PassContext
from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
from tvm.runtime import Module, NDArray, vm
from tvm.target import Target
from tvm.te import Tensor, create_prim_func
Expand Down Expand Up @@ -364,6 +365,7 @@ def tune_tir(
cost_model: Optional[CostModel] = None,
measure_callbacks: Optional[List[MeasureCallback]] = None,
space: Optional[FnSpaceGenerator] = None,
blocks: Optional[List[str]] = None,
sch_rules: Optional[FnScheduleRule] = None,
postprocs: Optional[FnPostproc] = None,
mutator_probs: Optional[FnMutatorProb] = None,
Expand Down Expand Up @@ -392,6 +394,21 @@ def tune_tir(
The cost model to use.
measure_callbacks : Optional[List[MeasureCallback]]
The callbacks used during tuning.
space : Optional[FnSpaceGenerator]
The space generator to use.
blocks : Optional[List[str]]
A list of block names to tune. If provided, other blocks
jwfromm marked this conversation as resolved.
Show resolved Hide resolved
will not be optimized.
sch_rules : Optional[FnScheduleRule]
The search rules to use.
postprocs : Optional[FnPostproc]
The postprocessors to use.
mutator_probs : Optional[FnMutatorProb]
The probability distribution to use different mutators.
task_name : str
The name of the function to extract schedules from.
num_threads : Optional[int]
The number of threads to use

Returns
-------
Expand All @@ -407,6 +424,15 @@ def tune_tir(
params=[{"log_dir": log_dir, "logger_name": __name__ + f".task_{task_name}"}],
)

if blocks is not None:
assert space is None, "Only one of blocks and space can be specified."
jwfromm marked this conversation as resolved.
Show resolved Hide resolved
# Create a filter function to identify named blocks.
def _filter_fn(block, target_names) -> bool:
jwfromm marked this conversation as resolved.
Show resolved Hide resolved
return block.name_hint in target_names

# Create a space generator that targets specific blocks.
space = PostOrderApply(filter_fn=lambda block: _filter_fn(block, blocks))

# pylint: disable=protected-access
mod = default_config.mod(mod)
target = default_config.target(target)
Expand Down
30 changes: 24 additions & 6 deletions src/meta_schedule/space_generator/post_order_apply.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ namespace meta_schedule {
/*! \brief Collecting all the blocks */
class BlockCollector : public tir::StmtVisitor {
public:
static Array<tir::BlockRV> Collect(const tir::Schedule& sch) { //
return BlockCollector(sch).Run();
static Array<tir::BlockRV> Collect(const tir::Schedule& sch,
const runtime::PackedFunc f_block_filter = nullptr) { //
return BlockCollector(sch, f_block_filter).Run();
}

private:
Expand All @@ -48,19 +49,32 @@ class BlockCollector : public tir::StmtVisitor {
return results;
}
/*! \brief Constructor */
explicit BlockCollector(const tir::Schedule& sch) : sch_(sch) {}
explicit BlockCollector(const tir::Schedule& sch,
const runtime::PackedFunc f_block_filter = nullptr)
: sch_(sch), f_block_filter_(f_block_filter) {}
/*! \brief Override the Stmt visiting behaviour */
void VisitStmt_(const tir::BlockNode* block) override {
tir::StmtVisitor::VisitStmt_(block);
CHECK(block_names_.count(block->name_hint) == 0)
<< "Duplicated block name " << block->name_hint << " in function " << func_name_
<< " not supported!";
block_names_.insert(block->name_hint);
blocks_to_collect_.push_back(block->name_hint);

// If filter function is provided, use it to selectively collect blocks.
// Otherwise collect all blocks.
Bool collect_block = Bool(true);
if (f_block_filter_ != nullptr) {
collect_block = f_block_filter_(GetRef<tir::Block>(block));
}
if (collect_block) {
blocks_to_collect_.push_back(block->name_hint);
}
}

/*! \brief The schedule to be collected */
const tir::Schedule& sch_;
/*! \brief An optional packed func that allows only certain blocks to be collected. */
const runtime::PackedFunc f_block_filter_;
/*! \brief The set of func name and block name pair */
std::unordered_set<String> block_names_;
/* \brief The list of blocks to collect in order */
Expand All @@ -81,6 +95,9 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
Array<ScheduleRule> sch_rules_{nullptr};
/*! \brief The logging function to use. */
PackedFunc logging_func;
/*! \brief Optional block names to target. If not specified all blocks will have spaces generated.
*/
runtime::PackedFunc f_block_filter_ = nullptr;

void VisitAttrs(tvm::AttrVisitor* v) {
// `rand_state_` is not visited
Expand All @@ -107,7 +124,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
Array<tir::Schedule> result{sch};
// Enumerate the schedule rules first because you can
// always concat multiple schedule rules as one
Array<tir::BlockRV> all_blocks = BlockCollector::Collect(sch);
Array<tir::BlockRV> all_blocks = BlockCollector::Collect(sch, f_block_filter_);
Array<Optional<ScheduleRule>> rules{NullOpt};
rules.insert(rules.end(), sch_rules_.begin(), sch_rules_.end());
for (Optional<ScheduleRule> sch_rule : rules) {
Expand Down Expand Up @@ -177,8 +194,9 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode);
};

SpaceGenerator SpaceGenerator::PostOrderApply() {
SpaceGenerator SpaceGenerator::PostOrderApply(runtime::PackedFunc f_block_filter) {
ObjectPtr<PostOrderApplyNode> n = make_object<PostOrderApplyNode>();
n->f_block_filter_ = f_block_filter;
return SpaceGenerator(n);
}

Expand Down
82 changes: 59 additions & 23 deletions tests/python/unittest/test_meta_schedule_post_order_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,29 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
return result


@derived_object
class TrinityDoubleRule(PyScheduleRule):
def _initialize_with_tune_context(self, context: "TuneContext") -> None:
pass

def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
if _is_root(sch, block):
return [sch]
new_sch = sch.copy()
i, j = new_sch.get_loops(block=block)
i_0, i_1 = new_sch.split(loop=i, factors=[16, 64])
j_0, j_1 = new_sch.split(loop=j, factors=[64, 16])
new_sch.reorder(i_0, j_0, i_1, j_1)
result = [new_sch]
new_sch = sch.copy()
i, j = new_sch.get_loops(block=block)
i_0, i_1 = new_sch.split(loop=i, factors=[2, 512])
j_0, j_1 = new_sch.split(loop=j, factors=[2, 512])
new_sch.reorder(i_0, j_0, i_1, j_1)
result.append(new_sch)
return result


@derived_object
class ReorderScheduleRule(PyScheduleRule):
def _initialize_with_tune_context(self, context: "TuneContext") -> None:
Expand Down Expand Up @@ -283,28 +306,6 @@ def test_meta_schedule_post_order_apply_duplicate_matmul():


def test_meta_schedule_post_order_apply_remove_block():
@derived_object
class TrinityDouble(PyScheduleRule):
def _initialize_with_tune_context(self, context: "TuneContext") -> None:
pass

def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
if _is_root(sch, block):
return [sch]
new_sch = sch.copy()
i, j = new_sch.get_loops(block=block)
i_0, i_1 = new_sch.split(loop=i, factors=[16, 64])
j_0, j_1 = new_sch.split(loop=j, factors=[64, 16])
new_sch.reorder(i_0, j_0, i_1, j_1)
result = [new_sch]
new_sch = sch.copy()
i, j = new_sch.get_loops(block=block)
i_0, i_1 = new_sch.split(loop=i, factors=[2, 512])
j_0, j_1 = new_sch.split(loop=j, factors=[2, 512])
new_sch.reorder(i_0, j_0, i_1, j_1)
result.append(new_sch)
return result

@derived_object
class RemoveBlock(PyScheduleRule):
def _initialize_with_tune_context(self, context: "TuneContext") -> None:
Expand Down Expand Up @@ -342,7 +343,7 @@ def correct_trace(a, b, c, d):
target=Target("llvm"),
task_name="Remove Block Task",
space_generator=PostOrderApply(),
sch_rules=[RemoveBlock(), TrinityDouble()],
sch_rules=[RemoveBlock(), TrinityDoubleRule()],
)
post_order_apply = context.space_generator
schs = post_order_apply.generate_design_space(mod)
Expand Down Expand Up @@ -385,5 +386,40 @@ def custom_search_space_func(sch: Schedule, _: BlockRV) -> List[Schedule]:
assert called


def test_target_blocks_search_space():
# Test that specific blocks of trinity matmul can be targeted.
def filter_fn(block, target_names) -> bool:
return block.name_hint in target_names

def _get_sch(filter_fn):
mod = TrinityMatmul
context = TuneContext(
mod=mod,
target=Target("llvm"),
task_name="Custom Search Space Task",
space_generator=PostOrderApply(filter_fn=filter_fn),
sch_rules=[TrinityDoubleRule()],
)
post_order_apply = context.space_generator
schs = post_order_apply.generate_design_space(mod)
return schs

# Start by checking that by default each block has a space generated.
schs = _get_sch(None)
assert len(schs) == 8

# Next check that we can target a specific block and only get its' revelant schedules.
schs = _get_sch(lambda block: filter_fn(block, ["B"]))
assert len(schs) == 2

## Check that extracting two blocks works.
schs = _get_sch(lambda block: filter_fn(block, ["A", "C"]))
assert len(schs) == 4

## Finally check that all blocks can be extracted by name.
schs = _get_sch(lambda block: filter_fn(block, ["A", "B", "C"]))
assert len(schs) == 8


if __name__ == "__main__":
tvm.testing.main()
1 change: 1 addition & 0 deletions tests/python/unittest/test_meta_schedule_tune_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def test_tune_matmul_cpu():
max_trials_global=32,
),
work_dir=work_dir,
blocks=["update"],
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
)
if sch is None:
print("No valid schedule found!")
Expand Down