[MetaSchedule] Extend tune_tir to support tuning of specific blocks. (apache#12342)

* Added optional target blocks.

* Checkpoint for debugging.

* Building with packedfunc filter.

* Extended tune_tir API to support named blocks.

* Remove accidental import.

* Improve integration test.

* Change names for more consistency.

* Update integration test.
Josh Fromm authored and Mikael Sevenier committed Aug 12, 2022
1 parent e5c9890 commit f66e4e1
Showing 6 changed files with 171 additions and 39 deletions.
2 changes: 1 addition & 1 deletion include/tvm/meta_schedule/space_generator.h
@@ -153,7 +153,7 @@ class SpaceGenerator : public runtime::ObjectRef {
* to blocks in post-DFS order.
* \return The design space generator created.
*/
TVM_DLL static SpaceGenerator PostOrderApply();
TVM_DLL static SpaceGenerator PostOrderApply(runtime::PackedFunc f_block_filter = nullptr);
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode);
};

12 changes: 10 additions & 2 deletions python/tvm/meta_schedule/space_generator/post_order_apply.py
@@ -27,10 +27,18 @@ class PostOrderApply(SpaceGenerator):
"""
PostOrderApply is the design space generator that generates design spaces by applying schedule
rules to blocks in post-DFS order.
Parameters
----------
f_block_filter : Optional[function]
An optional callback function used to filter which blocks have schedules generated
for them. The function should take a block and return True if a schedule should
be generated, or False if that block should be skipped. If no function is provided,
all blocks will have schedules generated.
"""

def __init__(self):
def __init__(self, f_block_filter=None):
"""Constructor"""
self.__init_handle_by_constructor__(
_ffi_api.SpaceGeneratorPostOrderApply, # type: ignore # pylint: disable=no-member
_ffi_api.SpaceGeneratorPostOrderApply, f_block_filter # type: ignore # pylint: disable=no-member
)
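
For illustration, a minimal sketch of how the new f_block_filter parameter can be wired up (the helper name keep_named_blocks and the block names "A" and "B" are hypothetical; as described in the docstring, the filter takes a block and returns a bool):

from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply

def keep_named_blocks(block) -> bool:
    # Collect only blocks whose name is in an allow-list; all other blocks are skipped.
    return block.name_hint in {"A", "B"}

space_gen = PostOrderApply(f_block_filter=keep_named_blocks)

Leaving the filter unset (the default) keeps the previous behaviour of generating design spaces for every block.
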
27 changes: 27 additions & 0 deletions python/tvm/meta_schedule/tune.py
@@ -24,6 +24,7 @@

from tvm.ir import IRModule
from tvm.ir.transform import PassContext
from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply
from tvm.runtime import Module, NDArray, vm
from tvm.target import Target
from tvm.te import Tensor, create_prim_func
@@ -364,6 +365,7 @@ def tune_tir(
cost_model: Optional[CostModel] = None,
measure_callbacks: Optional[List[MeasureCallback]] = None,
space: Optional[FnSpaceGenerator] = None,
blocks: Optional[List[str]] = None,
sch_rules: Optional[FnScheduleRule] = None,
postprocs: Optional[FnPostproc] = None,
mutator_probs: Optional[FnMutatorProb] = None,
@@ -392,6 +394,22 @@ def tune_tir(
The cost model to use.
measure_callbacks : Optional[List[MeasureCallback]]
The callbacks used during tuning.
space : Optional[FnSpaceGenerator]
The space generator to use.
blocks : Optional[List[str]]
A list of block names specifying blocks to be tuned. Note that if
the list is not None, blocks outside this list will not be tuned.
Only one of this argument and space may be provided.
sch_rules : Optional[FnScheduleRule]
The search rules to use.
postprocs : Optional[FnPostproc]
The postprocessors to use.
mutator_probs : Optional[FnMutatorProb]
The probability distribution used to select different mutators.
task_name : str
The name of the function to extract schedules from.
num_threads : Optional[int]
The number of threads to use.
Returns
-------
@@ -407,6 +425,15 @@
params=[{"log_dir": log_dir, "logger_name": __name__ + f".task_{task_name}"}],
)

if blocks is not None:
assert space is None, "Cannot specify blocks to tune when a search space is given."
# Create a filter function to identify named blocks.
def _f_block_filter(block, target_names) -> bool:
return block.name_hint in target_names

# Create a space generator that targets specific blocks.
space = PostOrderApply(f_block_filter=lambda block: _f_block_filter(block, blocks))

# pylint: disable=protected-access
mod = default_config.mod(mod)
target = default_config.target(target)
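
A usage sketch of the new blocks argument, mirroring the integration test added at the bottom of this diff (the two_step PrimFunc and the tuning budget values are taken from that test; everything else is illustrative):

import tempfile
from tvm import meta_schedule as ms
from tvm.target import Target

# two_step is a TIR PrimFunc with blocks "A" and "B" (see the test file below);
# with blocks=["A"], only block "A" will have schedules generated and tuned.
with tempfile.TemporaryDirectory() as work_dir:
    sch = ms.tune_tir(
        mod=two_step,
        target=Target("llvm --num-cores=16"),
        config=ms.TuneConfig(
            strategy="replay_trace",
            num_trials_per_iter=32,
            max_trials_per_task=32,
            max_trials_global=32,
        ),
        work_dir=work_dir,
        blocks=["A"],
    )

Because blocks is implemented by constructing a PostOrderApply space generator internally, it cannot be combined with an explicit space argument.
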
30 changes: 24 additions & 6 deletions src/meta_schedule/space_generator/post_order_apply.cc
@@ -24,8 +24,9 @@ namespace meta_schedule {
/*! \brief Collecting all the blocks */
class BlockCollector : public tir::StmtVisitor {
public:
static Array<tir::BlockRV> Collect(const tir::Schedule& sch) { //
return BlockCollector(sch).Run();
static Array<tir::BlockRV> Collect(const tir::Schedule& sch,
const runtime::PackedFunc f_block_filter = nullptr) { //
return BlockCollector(sch, f_block_filter).Run();
}

private:
@@ -48,19 +49,32 @@ class BlockCollector : public tir::StmtVisitor {
return results;
}
/*! \brief Constructor */
explicit BlockCollector(const tir::Schedule& sch) : sch_(sch) {}
explicit BlockCollector(const tir::Schedule& sch,
const runtime::PackedFunc f_block_filter = nullptr)
: sch_(sch), f_block_filter_(f_block_filter) {}
/*! \brief Override the Stmt visiting behaviour */
void VisitStmt_(const tir::BlockNode* block) override {
tir::StmtVisitor::VisitStmt_(block);
CHECK(block_names_.count(block->name_hint) == 0)
<< "Duplicated block name " << block->name_hint << " in function " << func_name_
<< " not supported!";
block_names_.insert(block->name_hint);
blocks_to_collect_.push_back(block->name_hint);

// If filter function is provided, use it to selectively collect blocks.
// Otherwise collect all blocks.
Bool collect_block = Bool(true);
if (f_block_filter_ != nullptr) {
collect_block = f_block_filter_(GetRef<tir::Block>(block));
}
if (collect_block) {
blocks_to_collect_.push_back(block->name_hint);
}
}

/*! \brief The schedule to be collected */
const tir::Schedule& sch_;
/*! \brief An optional packed func that allows only certain blocks to be collected. */
const runtime::PackedFunc f_block_filter_;
/*! \brief The set of func name and block name pair */
std::unordered_set<String> block_names_;
/* \brief The list of blocks to collect in order */
@@ -81,6 +95,9 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
Array<ScheduleRule> sch_rules_{nullptr};
/*! \brief The logging function to use. */
PackedFunc logging_func;
/*! \brief An optional block filter function. If not specified, all blocks will have spaces
 * generated. */
runtime::PackedFunc f_block_filter_ = nullptr;

void VisitAttrs(tvm::AttrVisitor* v) {
// `rand_state_` is not visited
@@ -107,7 +124,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
Array<tir::Schedule> result{sch};
// Enumerate the schedule rules first because you can
// always concat multiple schedule rules as one
Array<tir::BlockRV> all_blocks = BlockCollector::Collect(sch);
Array<tir::BlockRV> all_blocks = BlockCollector::Collect(sch, f_block_filter_);
Array<Optional<ScheduleRule>> rules{NullOpt};
rules.insert(rules.end(), sch_rules_.begin(), sch_rules_.end());
for (Optional<ScheduleRule> sch_rule : rules) {
@@ -177,8 +194,9 @@ class PostOrderApplyNode : public SpaceGeneratorNode {
TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode);
};

SpaceGenerator SpaceGenerator::PostOrderApply() {
SpaceGenerator SpaceGenerator::PostOrderApply(runtime::PackedFunc f_block_filter) {
ObjectPtr<PostOrderApplyNode> n = make_object<PostOrderApplyNode>();
n->f_block_filter_ = f_block_filter;
return SpaceGenerator(n);
}

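
The collection behaviour above can be approximated in Python with the following rough sketch (collect_filtered_blocks is illustrative, not a TVM API): blocks are visited in post-DFS order, the optional filter is consulted for each one, and the surviving names are turned into BlockRVs. The C++ implementation additionally rejects duplicate block names within a function.

from tvm import tir
from tvm.tir.stmt_functor import post_order_visit

def collect_filtered_blocks(sch: tir.Schedule, func_name: str = "main", f_block_filter=None):
    """Rough Python analogue of BlockCollector::Collect above (illustrative only)."""
    names = []

    def visit(node):
        # Collect every block unless a filter is given and it rejects the block.
        if isinstance(node, tir.Block):
            if f_block_filter is None or f_block_filter(node):
                names.append(node.name_hint)

    post_order_visit(sch.mod[func_name].body, visit)
    return [sch.get_block(name, func_name=func_name) for name in names]
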
82 changes: 59 additions & 23 deletions tests/python/unittest/test_meta_schedule_post_order_apply.py
@@ -195,6 +195,29 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
return result


@derived_object
class TrinityDoubleRule(PyScheduleRule):
def _initialize_with_tune_context(self, context: "TuneContext") -> None:
pass

def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
if _is_root(sch, block):
return [sch]
new_sch = sch.copy()
i, j = new_sch.get_loops(block=block)
i_0, i_1 = new_sch.split(loop=i, factors=[16, 64])
j_0, j_1 = new_sch.split(loop=j, factors=[64, 16])
new_sch.reorder(i_0, j_0, i_1, j_1)
result = [new_sch]
new_sch = sch.copy()
i, j = new_sch.get_loops(block=block)
i_0, i_1 = new_sch.split(loop=i, factors=[2, 512])
j_0, j_1 = new_sch.split(loop=j, factors=[2, 512])
new_sch.reorder(i_0, j_0, i_1, j_1)
result.append(new_sch)
return result


@derived_object
class ReorderScheduleRule(PyScheduleRule):
def _initialize_with_tune_context(self, context: "TuneContext") -> None:
@@ -283,28 +306,6 @@ def test_meta_schedule_post_order_apply_duplicate_matmul():


def test_meta_schedule_post_order_apply_remove_block():
@derived_object
class TrinityDouble(PyScheduleRule):
def _initialize_with_tune_context(self, context: "TuneContext") -> None:
pass

def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]:
if _is_root(sch, block):
return [sch]
new_sch = sch.copy()
i, j = new_sch.get_loops(block=block)
i_0, i_1 = new_sch.split(loop=i, factors=[16, 64])
j_0, j_1 = new_sch.split(loop=j, factors=[64, 16])
new_sch.reorder(i_0, j_0, i_1, j_1)
result = [new_sch]
new_sch = sch.copy()
i, j = new_sch.get_loops(block=block)
i_0, i_1 = new_sch.split(loop=i, factors=[2, 512])
j_0, j_1 = new_sch.split(loop=j, factors=[2, 512])
new_sch.reorder(i_0, j_0, i_1, j_1)
result.append(new_sch)
return result

@derived_object
class RemoveBlock(PyScheduleRule):
def _initialize_with_tune_context(self, context: "TuneContext") -> None:
@@ -342,7 +343,7 @@ def correct_trace(a, b, c, d):
target=Target("llvm"),
task_name="Remove Block Task",
space_generator=PostOrderApply(),
sch_rules=[RemoveBlock(), TrinityDouble()],
sch_rules=[RemoveBlock(), TrinityDoubleRule()],
)
post_order_apply = context.space_generator
schs = post_order_apply.generate_design_space(mod)
@@ -385,5 +386,40 @@ def custom_search_space_func(sch: Schedule, _: BlockRV) -> List[Schedule]:
assert called


def test_target_blocks_search_space():
# Test that specific blocks of trinity matmul can be targeted.
def filter_fn(block, target_names) -> bool:
return block.name_hint in target_names

def _get_sch(filter_fn):
mod = TrinityMatmul
context = TuneContext(
mod=mod,
target=Target("llvm"),
task_name="Custom Search Space Task",
space_generator=PostOrderApply(f_block_filter=filter_fn),
sch_rules=[TrinityDoubleRule()],
)
post_order_apply = context.space_generator
schs = post_order_apply.generate_design_space(mod)
return schs

# Start by checking that by default each block has a space generated.
schs = _get_sch(None)
assert len(schs) == 8

# Next check that we can target a specific block and only get its relevant schedules.
schs = _get_sch(lambda block: filter_fn(block, ["B"]))
assert len(schs) == 2

# Check that extracting two blocks works.
schs = _get_sch(lambda block: filter_fn(block, ["A", "C"]))
assert len(schs) == 4

# Finally check that all blocks can be extracted by name.
schs = _get_sch(lambda block: filter_fn(block, ["A", "B", "C"]))
assert len(schs) == 8


if __name__ == "__main__":
tvm.testing.main()
57 changes: 50 additions & 7 deletions tests/python/unittest/test_meta_schedule_tune_tir.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
# pylint: disable=missing-docstring,no-member,invalid-name,unused-variable
import logging
import tempfile
import numpy as np
@@ -23,20 +23,19 @@
import tvm

from tvm import meta_schedule as ms
from tvm.meta_schedule import TuneConfig, tune_tir
from tvm.meta_schedule import TuneContext, TuneConfig, tune_tir
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.local_rpc import LocalRPC
from tvm.meta_schedule.schedule_rule import PyScheduleRule
from tvm.meta_schedule.utils import derived_object
from tvm.script import tir as T
from tvm.target import Target
from tvm.tir import Schedule
from tvm.tir.schedule import BlockRV, Schedule

logging.basicConfig()
logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG)


# pylint: disable=no-member,invalid-name,unused-variable


@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, [128, 128])
@@ -50,7 +49,19 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk]


# pylint: enable=no-member,invalid-name,unused-variable
@T.prim_func
def two_step(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(a, (1024, 1024), "float32")
B = T.alloc_buffer((1024, 1024), "float32")
C = T.match_buffer(c, (1024, 1024), "float32")
for i, j in T.grid(1024, 1024):
with T.block("A"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(1024, 1024):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 3.0


@pytest.mark.skip("Integration test")
@@ -74,6 +85,37 @@ def test_tune_matmul_cpu():
print(sch.trace)


@pytest.mark.skip("Integration test")
def test_tune_block_cpu():
@derived_object
class RemoveBlock(PyScheduleRule):
def _initialize_with_tune_context(self, context: TuneContext) -> None:
pass

def apply(self, sch: Schedule, block: BlockRV):
if sch.get(block).name_hint == "root":
return [sch]
sch = sch.copy()
sch.compute_inline(block)
return [sch]

with tempfile.TemporaryDirectory() as work_dir:
sch: Schedule = tune_tir(
mod=two_step,
target=Target("llvm --num-cores=16"),
config=TuneConfig(
strategy="replay_trace",
num_trials_per_iter=32,
max_trials_per_task=32,
max_trials_global=32,
),
work_dir=work_dir,
blocks=["A"],
sch_rules=lambda *args: [RemoveBlock()],
)
assert sch is not None


@pytest.mark.skip("Integration test")
def test_tune_matmul_cuda():
with tempfile.TemporaryDirectory() as work_dir:
@@ -141,3 +183,4 @@ def f_timer(rt_mod, dev, input_data):
test_tune_matmul_cpu()
test_tune_matmul_cuda()
test_tune_run_module_via_rpc()
test_tune_block_cpu()
