diff --git a/include/tvm/meta_schedule/space_generator.h b/include/tvm/meta_schedule/space_generator.h index f7d6cac31cab..2df040e5d941 100644 --- a/include/tvm/meta_schedule/space_generator.h +++ b/include/tvm/meta_schedule/space_generator.h @@ -153,7 +153,7 @@ class SpaceGenerator : public runtime::ObjectRef { * to blocks in post-DFS order. * \return The design space generator created. */ - TVM_DLL static SpaceGenerator PostOrderApply(); + TVM_DLL static SpaceGenerator PostOrderApply(runtime::PackedFunc f_block_filter = nullptr); TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(SpaceGenerator, ObjectRef, SpaceGeneratorNode); }; diff --git a/python/tvm/meta_schedule/space_generator/post_order_apply.py b/python/tvm/meta_schedule/space_generator/post_order_apply.py index 80f372a448f5..6e2a2c52b1a1 100644 --- a/python/tvm/meta_schedule/space_generator/post_order_apply.py +++ b/python/tvm/meta_schedule/space_generator/post_order_apply.py @@ -27,10 +27,18 @@ class PostOrderApply(SpaceGenerator): """ PostOrderApply is the design space generator that generates design spaces by applying schedule rules to blocks in post-DFS order. + + Parameters + ---------- + f_block_filter : Optional[function] + An optional callback function that is used to filter which blocks have schedules generated + for them. The function should take in a block and return True if a schedule should + be generated or False if that block should be skipped. If no function is provided + all blocks will have schedules generated. 
""" - def __init__(self): + def __init__(self, f_block_filter=None): """Constructor""" self.__init_handle_by_constructor__( - _ffi_api.SpaceGeneratorPostOrderApply, # type: ignore # pylint: disable=no-member + _ffi_api.SpaceGeneratorPostOrderApply, f_block_filter # type: ignore # pylint: disable=no-member ) diff --git a/python/tvm/meta_schedule/tune.py b/python/tvm/meta_schedule/tune.py index fbbe24b32e4d..447fb56637ef 100644 --- a/python/tvm/meta_schedule/tune.py +++ b/python/tvm/meta_schedule/tune.py @@ -24,6 +24,7 @@ from tvm.ir import IRModule from tvm.ir.transform import PassContext +from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply from tvm.runtime import Module, NDArray, vm from tvm.target import Target from tvm.te import Tensor, create_prim_func @@ -364,6 +365,7 @@ def tune_tir( cost_model: Optional[CostModel] = None, measure_callbacks: Optional[List[MeasureCallback]] = None, space: Optional[FnSpaceGenerator] = None, + blocks: Optional[List[str]] = None, sch_rules: Optional[FnScheduleRule] = None, postprocs: Optional[FnPostproc] = None, mutator_probs: Optional[FnMutatorProb] = None, @@ -392,6 +394,22 @@ def tune_tir( The cost model to use. measure_callbacks : Optional[List[MeasureCallback]] The callbacks used during tuning. + space : Optional[FnSpaceGenerator] + The space generator to use. + blocks : Optional[List[str]] + A list of block names specifying blocks to be tuned. Note that if + the list is not None, blocks outside this list will not be tuned. + Only one of this argument and space may be provided. + sch_rules : Optional[FnScheduleRule] + The search rules to use. + postprocs : Optional[FnPostproc] + The postprocessors to use. + mutator_probs : Optional[FnMutatorProb] + The probability distribution to use different mutators. + task_name : str + The name of the function to extract schedules from. 
+ num_threads : Optional[int] + The number of threads to use Returns ------- @@ -407,6 +425,15 @@ def tune_tir( params=[{"log_dir": log_dir, "logger_name": __name__ + f".task_{task_name}"}], ) + if blocks is not None: + assert space is None, "Can not specify blocks to tune when a search space is given." + # Create a filter function to identify named blocks. + def _f_block_filter(block, target_names) -> bool: + return block.name_hint in target_names + + # Create a space generator that targets specific blocks. + space = PostOrderApply(f_block_filter=lambda block: _f_block_filter(block, blocks)) + # pylint: disable=protected-access mod = default_config.mod(mod) target = default_config.target(target) diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index 50b49943f5ff..51dea2c2fe90 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -24,8 +24,9 @@ namespace meta_schedule { /*! \brief Collecting all the blocks */ class BlockCollector : public tir::StmtVisitor { public: - static Array Collect(const tir::Schedule& sch) { // - return BlockCollector(sch).Run(); + static Array Collect(const tir::Schedule& sch, + const runtime::PackedFunc f_block_filter = nullptr) { // + return BlockCollector(sch, f_block_filter).Run(); } private: @@ -48,7 +49,9 @@ class BlockCollector : public tir::StmtVisitor { return results; } /*! \brief Constructor */ - explicit BlockCollector(const tir::Schedule& sch) : sch_(sch) {} + explicit BlockCollector(const tir::Schedule& sch, + const runtime::PackedFunc f_block_filter = nullptr) + : sch_(sch), f_block_filter_(f_block_filter) {} /*! 
\brief Override the Stmt visiting behaviour */ void VisitStmt_(const tir::BlockNode* block) override { tir::StmtVisitor::VisitStmt_(block); @@ -56,11 +59,22 @@ class BlockCollector : public tir::StmtVisitor { << "Duplicated block name " << block->name_hint << " in function " << func_name_ << " not supported!"; block_names_.insert(block->name_hint); - blocks_to_collect_.push_back(block->name_hint); + + // If filter function is provided, use it to selectively collect blocks. + // Otherwise collect all blocks. + Bool collect_block = Bool(true); + if (f_block_filter_ != nullptr) { + collect_block = f_block_filter_(GetRef(block)); + } + if (collect_block) { + blocks_to_collect_.push_back(block->name_hint); + } } /*! \brief The schedule to be collected */ const tir::Schedule& sch_; + /*! \brief An optional packed func that allows only certain blocks to be collected. */ + const runtime::PackedFunc f_block_filter_; /*! \brief The set of func name and block name pair */ std::unordered_set block_names_; /* \brief The list of blocks to collect in order */ @@ -81,6 +95,9 @@ class PostOrderApplyNode : public SpaceGeneratorNode { Array sch_rules_{nullptr}; /*! \brief The logging function to use. */ PackedFunc logging_func; + /*! \brief Optional block names to target. If not specified all blocks will have spaces generated. 
+ */ + runtime::PackedFunc f_block_filter_ = nullptr; void VisitAttrs(tvm::AttrVisitor* v) { // `rand_state_` is not visited @@ -107,7 +124,7 @@ class PostOrderApplyNode : public SpaceGeneratorNode { Array result{sch}; // Enumerate the schedule rules first because you can // always concat multiple schedule rules as one - Array all_blocks = BlockCollector::Collect(sch); + Array all_blocks = BlockCollector::Collect(sch, f_block_filter_); Array> rules{NullOpt}; rules.insert(rules.end(), sch_rules_.begin(), sch_rules_.end()); for (Optional sch_rule : rules) { @@ -177,8 +194,9 @@ class PostOrderApplyNode : public SpaceGeneratorNode { TVM_DECLARE_FINAL_OBJECT_INFO(PostOrderApplyNode, SpaceGeneratorNode); }; -SpaceGenerator SpaceGenerator::PostOrderApply() { +SpaceGenerator SpaceGenerator::PostOrderApply(runtime::PackedFunc f_block_filter) { ObjectPtr n = make_object(); + n->f_block_filter_ = f_block_filter; return SpaceGenerator(n); } diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index 21d29ac74d82..97a49602fb26 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -195,6 +195,29 @@ def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: return result +@derived_object +class TrinityDoubleRule(PyScheduleRule): + def _initialize_with_tune_context(self, context: "TuneContext") -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: + if _is_root(sch, block): + return [sch] + new_sch = sch.copy() + i, j = new_sch.get_loops(block=block) + i_0, i_1 = new_sch.split(loop=i, factors=[16, 64]) + j_0, j_1 = new_sch.split(loop=j, factors=[64, 16]) + new_sch.reorder(i_0, j_0, i_1, j_1) + result = [new_sch] + new_sch = sch.copy() + i, j = new_sch.get_loops(block=block) + i_0, i_1 = new_sch.split(loop=i, factors=[2, 512]) + j_0, j_1 = new_sch.split(loop=j, factors=[2, 
512]) + new_sch.reorder(i_0, j_0, i_1, j_1) + result.append(new_sch) + return result + + @derived_object class ReorderScheduleRule(PyScheduleRule): def _initialize_with_tune_context(self, context: "TuneContext") -> None: @@ -283,28 +306,6 @@ def test_meta_schedule_post_order_apply_duplicate_matmul(): def test_meta_schedule_post_order_apply_remove_block(): - @derived_object - class TrinityDouble(PyScheduleRule): - def _initialize_with_tune_context(self, context: "TuneContext") -> None: - pass - - def apply(self, sch: Schedule, block: BlockRV) -> List[Schedule]: - if _is_root(sch, block): - return [sch] - new_sch = sch.copy() - i, j = new_sch.get_loops(block=block) - i_0, i_1 = new_sch.split(loop=i, factors=[16, 64]) - j_0, j_1 = new_sch.split(loop=j, factors=[64, 16]) - new_sch.reorder(i_0, j_0, i_1, j_1) - result = [new_sch] - new_sch = sch.copy() - i, j = new_sch.get_loops(block=block) - i_0, i_1 = new_sch.split(loop=i, factors=[2, 512]) - j_0, j_1 = new_sch.split(loop=j, factors=[2, 512]) - new_sch.reorder(i_0, j_0, i_1, j_1) - result.append(new_sch) - return result - @derived_object class RemoveBlock(PyScheduleRule): def _initialize_with_tune_context(self, context: "TuneContext") -> None: @@ -342,7 +343,7 @@ def correct_trace(a, b, c, d): target=Target("llvm"), task_name="Remove Block Task", space_generator=PostOrderApply(), - sch_rules=[RemoveBlock(), TrinityDouble()], + sch_rules=[RemoveBlock(), TrinityDoubleRule()], ) post_order_apply = context.space_generator schs = post_order_apply.generate_design_space(mod) @@ -385,5 +386,40 @@ def custom_search_space_func(sch: Schedule, _: BlockRV) -> List[Schedule]: assert called +def test_target_blocks_search_space(): + # Test that specific blocks of trinity matmul can be targeted. 
+ def filter_fn(block, target_names) -> bool: + return block.name_hint in target_names + + def _get_sch(filter_fn): + mod = TrinityMatmul + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Custom Search Space Task", + space_generator=PostOrderApply(f_block_filter=filter_fn), + sch_rules=[TrinityDoubleRule()], + ) + post_order_apply = context.space_generator + schs = post_order_apply.generate_design_space(mod) + return schs + + # Start by checking that by default each block has a space generated. + schs = _get_sch(None) + assert len(schs) == 8 + + # Next check that we can target a specific block and only get its relevant schedules. + schs = _get_sch(lambda block: filter_fn(block, ["B"])) + assert len(schs) == 2 + + ## Check that extracting two blocks works. + schs = _get_sch(lambda block: filter_fn(block, ["A", "C"])) + assert len(schs) == 4 + + ## Finally check that all blocks can be extracted by name. + schs = _get_sch(lambda block: filter_fn(block, ["A", "B", "C"])) + assert len(schs) == 8 + + if __name__ == "__main__": tvm.testing.main() diff --git a/tests/python/unittest/test_meta_schedule_tune_tir.py b/tests/python/unittest/test_meta_schedule_tune_tir.py index 0e8c205230e6..6ab5f9b8c5c4 100644 --- a/tests/python/unittest/test_meta_schedule_tune_tir.py +++ b/tests/python/unittest/test_meta_schedule_tune_tir.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License.
-# pylint: disable=missing-docstring +# pylint: disable=missing-docstring,no-member,invalid-name,unused-variable import logging import tempfile import numpy as np @@ -23,20 +23,19 @@ import tvm from tvm import meta_schedule as ms -from tvm.meta_schedule import TuneConfig, tune_tir +from tvm.meta_schedule import TuneContext, TuneConfig, tune_tir from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc from tvm.meta_schedule.testing.local_rpc import LocalRPC +from tvm.meta_schedule.schedule_rule import PyScheduleRule +from tvm.meta_schedule.utils import derived_object from tvm.script import tir as T from tvm.target import Target -from tvm.tir import Schedule +from tvm.tir.schedule import BlockRV, Schedule logging.basicConfig() logging.getLogger("tvm.meta_schedule").setLevel(logging.DEBUG) -# pylint: disable=no-member,invalid-name,unused-variable - - @T.prim_func def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: A = T.match_buffer(a, [128, 128]) @@ -50,7 +49,19 @@ def matmul(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vj, vk] -# pylint: enable=no-member,invalid-name,unused-variable +@T.prim_func +def two_step(a: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.alloc_buffer((1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j in T.grid(1024, 1024): + with T.block("A"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(1024, 1024): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + C[vi, vj] = B[vi, vj] + 3.0 @pytest.mark.skip("Integration test") @@ -74,6 +85,37 @@ def test_tune_matmul_cpu(): print(sch.trace) +@pytest.mark.skip("Integration test") +def test_tune_block_cpu(): + @derived_object + class RemoveBlock(PyScheduleRule): + def _initialize_with_tune_context(self, context: TuneContext) -> None: + pass + + def apply(self, sch: Schedule, block: BlockRV): + if 
sch.get(block).name_hint == "root": + return [sch] + sch = sch.copy() + sch.compute_inline(block) + return [sch] + + with tempfile.TemporaryDirectory() as work_dir: + sch: Schedule = tune_tir( + mod=two_step, + target=Target("llvm --num-cores=16"), + config=TuneConfig( + strategy="replay_trace", + num_trials_per_iter=32, + max_trials_per_task=32, + max_trials_global=32, + ), + work_dir=work_dir, + blocks=["A"], + sch_rules=lambda *args: [RemoveBlock()], + ) + assert sch is not None + + @pytest.mark.skip("Integration test") def test_tune_matmul_cuda(): with tempfile.TemporaryDirectory() as work_dir: @@ -141,3 +183,4 @@ def f_timer(rt_mod, dev, input_data): test_tune_matmul_cpu() test_tune_matmul_cuda() test_tune_run_module_via_rpc() + test_tune_block_cpu()