[MetaSchedule] Developer Ergonomics Enhancement II
Follow-up of #11622, per discussion with @Kathryn-cat

- [x] Allow using a string `"default"` in `TuneContext` to quickly specify a set of target-specific
rules
- [x] Enhance detection of `ScheduleFn` in `TuneContext` to make it easier for users to quickly try
out template-driven scheduling on TIR.

Next PR:
- Add `TuneContext.tune` to allow tuning directly without a task scheduler.
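
A rough usage sketch of the two ergonomics changes in the checklist above; the TIR workload, target string, block names, and schedule steps are illustrative placeholders, not part of this commit:

```python
# Sketch only: assumes a TVM build with MetaSchedule; names below are illustrative.
from tvm.meta_schedule import TuneContext
from tvm.script import tir as T
from tvm.target import Target
from tvm.tir import Schedule


@T.prim_func
def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
    A = T.match_buffer(a, (128, 128), "float32")
    B = T.match_buffer(b, (128, 128), "float32")
    C = T.match_buffer(c, (128, 128), "float32")
    for i, j, k in T.grid(128, 128, 128):
        with T.block("C"):
            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
            with T.init():
                C[vi, vj] = T.float32(0)
            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]


def schedule_fn(sch: Schedule) -> None:
    # A plain callable is now detected and wrapped into a ScheduleFn space generator.
    block = sch.get_block("C")
    i, j, k = sch.get_loops(block)
    i0, i1 = sch.split(i, factors=[None, 16])
    sch.reorder(i0, j, i1, k)


ctx = TuneContext(
    mod=matmul,                   # a PrimFunc is promoted to an IRModule
    target=Target("llvm --num-cores=4"),
    space_generator=schedule_fn,  # callable -> ScheduleFn
    sch_rules="default",          # "default" -> target-specific schedule rules
    postprocs="default",
    mutator_probs="default",
)
spaces = ctx.generate_design_space()
print(spaces[0].trace)
```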

Co-Authored-By: Kathryn (Jinqi) Chen <65606304+Kathryn-cat@users.noreply.github.com>
junrushao and Kathryn-cat committed Jun 15, 2022
1 parent 1312658 commit 34ddf43
Showing 27 changed files with 70 additions and 68 deletions.
6 changes: 3 additions & 3 deletions python/tvm/meta_schedule/space_generator/__init__.py
@@ -19,7 +19,7 @@
Meta Schedule design space generators that generate design
spaces for the generation of measure candidates.
"""
from .space_generator import SpaceGenerator, PySpaceGenerator
from .space_generator_union import SpaceGeneratorUnion
from .schedule_fn import ScheduleFn
from .post_order_apply import PostOrderApply
from .schedule_fn import SCH_FN_TYPE, ScheduleFn
from .space_generator import PySpaceGenerator, SpaceGenerator
from .space_generator_union import SpaceGeneratorUnion
13 changes: 6 additions & 7 deletions python/tvm/meta_schedule/space_generator/schedule_fn.py
@@ -30,18 +30,17 @@
if TYPE_CHECKING:
from ..tune_context import TuneContext

SCH_FN_TYPE = Union[ # pylint: disable=invalid-name
Callable[[Schedule], None], # No output
Callable[[Schedule], Schedule], # Single output
Callable[[Schedule], List[Schedule]], # Multiple outputs
]


@derived_object
class ScheduleFn(PySpaceGenerator):
"""A design space generator with design spaces specified by a schedule function."""

# Multiple cases of schedule functions supported
SCH_FN_TYPE = Union[
Callable[[Schedule], None], # No output
Callable[[Schedule], Schedule], # Single output
Callable[[Schedule], List[Schedule]], # Multiple outputs
]

def __init__(self, sch_fn: SCH_FN_TYPE):
"""Constructor.
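For reference, the three callable shapes admitted by `SCH_FN_TYPE` look roughly like this; the block names and schedule steps are illustrative, not from this commit:

```python
from typing import List

from tvm.meta_schedule.space_generator import ScheduleFn
from tvm.tir import Schedule


def sch_fn_inplace(sch: Schedule) -> None:
    """Mutate the given schedule in place; return nothing."""
    sch.compute_inline(sch.get_block("B"))  # illustrative block name


def sch_fn_single(sch: Schedule) -> Schedule:
    """Return a single (possibly copied) schedule as the design space."""
    new_sch = sch.copy()
    new_sch.compute_inline(new_sch.get_block("B"))
    return new_sch


def sch_fn_multi(sch: Schedule) -> List[Schedule]:
    """Return several alternative schedules as the design space."""
    return [sch.copy(), sch.copy()]


# Any of the three can be wrapped explicitly ...
space_gen = ScheduleFn(sch_fn_multi)
# ... or, after this commit, passed directly as TuneContext(space_generator=sch_fn_multi).
```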
3 changes: 1 addition & 2 deletions python/tvm/meta_schedule/tune.py
@@ -88,7 +88,7 @@ class TuneConfig(NamedTuple):
search_strategy_config: Optional[Dict[str, Any]] = None
logger_config: Optional[Dict[str, Any]] = None

def create_strategy(self, **kwargs):
def create_strategy(self):
"""Create search strategy from configuration"""
cls_tbl = {
"evolutionary": EvolutionarySearch,
@@ -111,7 +111,6 @@ def create_strategy(self, **kwargs):
return cls_tbl[self.strategy](
num_trials_per_iter=self.num_trials_per_iter,
max_trials_per_task=max_trials_per_task,
**kwargs,
**config,
)

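With `**kwargs` removed, every strategy option now has to come from the config itself (or from `search_strategy_config`). A hedged sketch of the call, with illustrative trial counts:

```python
from tvm.meta_schedule import TuneConfig

config = TuneConfig(
    strategy="evolutionary",
    num_trials_per_iter=64,
    max_trials_per_task=512,
    max_trials_global=2048,
)
# No extra keyword arguments are accepted anymore; the strategy is built
# purely from the TuneConfig fields and search_strategy_config.
strategy = config.create_strategy()
```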
85 changes: 59 additions & 26 deletions python/tvm/meta_schedule/tune_context.py
@@ -17,7 +17,7 @@
"""Meta Schedule tuning context."""

import logging
from typing import TYPE_CHECKING, Dict, List, Optional
from typing import TYPE_CHECKING, Dict, List, Optional, Union

from tvm import IRModule
from tvm._ffi import register_object
@@ -36,7 +36,8 @@
from .runner import RunnerResult
from .schedule_rule import ScheduleRule
from .search_strategy import MeasureCandidate, SearchStrategy
from .space_generator import SpaceGenerator
from .space_generator import SCH_FN_TYPE, ScheduleFn, SpaceGenerator
from .tune import TuneConfig


@register_object("meta_schedule.TuneContext")
@@ -54,16 +55,24 @@ class TuneContext(Object):
The workload to be optimized.
target : Optional[Target] = None
The target to be optimized for.
space_generator : Optional[SpaceGenerator] = None
space_generator : Union[None, SCH_FN_TYPE, SpaceGenerator] = None
The design space generator.
search_strategy : Optional[SearchStrategy] = None
search_strategy : Union[None, TuneConfig, SearchStrategy] = None
The search strategy.
sch_rules: Optional[List[ScheduleRule]] = None,
If None, the strategy is left blank.
If TuneConfig, the strategy is initialized with TuneConfig.create_strategy().
sch_rules: Union[None, str, List[ScheduleRule]] = None,
The schedule rules.
postprocs: Optional[List[Postproc]] = None,
If None, use an empty list of rules.
If "default", use target-default rules.
postprocs: Union[None, str, List[Postproc]] = None,
The postprocessors.
mutator_probs: Optional[Dict[Mutator, float]]
If None, use an empty list of rules.
if "default", use target-default rules.
mutator_probs: Union[None, str, Dict[Mutator, float]]
Mutators and their probability mass.
If None, use an empty list of rules.
if "default", use target-default rules.
task_name : Optional[str] = None
The name of the tuning task.
logger : logging.Logger
@@ -99,24 +108,53 @@ def __init__(
mod: Optional[IRModule] = None,
*,
target: Optional[Target] = None,
space_generator: Optional["SpaceGenerator"] = None,
search_strategy: Optional["SearchStrategy"] = None,
sch_rules: Optional[List["ScheduleRule"]] = None,
postprocs: Optional[List["Postproc"]] = None,
mutator_probs: Optional[Dict["Mutator", float]] = None,
space_generator: Union[None, "SCH_FN_TYPE", "ScheduleFn", "SpaceGenerator"] = None,
search_strategy: Union[None, "SearchStrategy", "TuneConfig"] = None,
sch_rules: Union[None, str, List["ScheduleRule"]] = None,
postprocs: Union[None, str, List["Postproc"]] = None,
mutator_probs: Union[None, str, Dict["Mutator", float]] = None,
task_name: str = "main",
logger: Optional[logging.Logger] = None,
rand_state: int = -1,
num_threads: Optional[int] = None,
):
# pylint: disable=import-outside-toplevel
from . import default_config
from .space_generator import ScheduleFn
from .tune import TuneConfig

# pylint: enable=import-outside-toplevel
if isinstance(mod, PrimFunc):
mod = IRModule.from_expr(mod)
if num_threads is None:
num_threads = cpu_count()
if callable(space_generator):
space_generator = ScheduleFn(space_generator)
if isinstance(search_strategy, TuneConfig):
search_strategy = search_strategy.create_strategy()
if isinstance(sch_rules, str):
if sch_rules == "default":
if target is None:
raise ValueError("target is required when sch_rules is 'default'")
sch_rules = default_config.schedule_rules(None, target)
else:
raise ValueError("sch_rules should be a list of ScheduleRule or 'default'")
if isinstance(postprocs, str):
if postprocs == "default":
if target is None:
raise ValueError("target is required when postprocs is 'default'")
postprocs = default_config.postproc(None, target)
else:
raise ValueError("postprocs should be a list of Postproc or 'default'")
if isinstance(mutator_probs, str):
if mutator_probs == "default":
if target is None:
raise ValueError("target is required when mutator_probs is 'default'")
mutator_probs = default_config.mutator_probs(None, target)
if logger is None:
self.logger = logging.getLogger(__name__)
else:
self.logger = None
if num_threads is None:
num_threads = cpu_count()
self.__init_handle_by_constructor__(
_ffi_api.TuneContext, # type: ignore # pylint: disable=no-member
mod,
@@ -131,9 +169,6 @@ def __init__(
rand_state,
num_threads,
)

def initialize(self):
"""Initialize the tuning context"""
_ffi_api.TuneContextInitialize(self) # type: ignore # pylint: disable=no-member

def generate_design_space(self) -> List[Schedule]:
@@ -157,7 +192,7 @@ def generate_design_space(self) -> List[Schedule]:

def pre_tuning(
self,
design_spaces: List[Schedule],
design_spaces: Optional[List[Schedule]] = None,
database: Optional["Database"] = None,
cost_model: Optional["CostModel"] = None,
) -> None:
@@ -167,7 +202,7 @@ def pre_tuning(
Parameters
----------
design_spaces : List[Schedule]
design_spaces : Optional[List[Schedule]]
The design spaces used during tuning process.
database : Optional[Database] = None
The database used during tuning process.
@@ -179,6 +214,8 @@ def pre_tuning(
"search_strategy is not provided."
"Please construct TuneContext with search_strategy"
)
if design_spaces is None:
design_spaces = self.generate_design_space()
return self.search_strategy.pre_tuning(design_spaces, database, cost_model)

def post_tuning(self) -> None:
@@ -191,7 +228,7 @@ def post_tuning(self) -> None:
"search_strategy is not provided."
"Please construct TuneContext with search_strategy"
)
_ffi_api.SearchStrategyPostTuning(self) # type: ignore # pylint: disable=no-member
return self.search_strategy.post_tuning()

def generate_measure_candidates(self) -> Optional[List["MeasureCandidate"]]:
"""Generate a batch of measure candidates from design spaces for measurement.
@@ -208,7 +245,7 @@ def generate_measure_candidates(self) -> Optional[List["MeasureCandidate"]]:
"search_strategy is not provided."
"Please construct TuneContext with search_strategy"
)
return _ffi_api.SearchStrategyGenerateMeasureCandidates(self) # type: ignore # pylint: disable=no-member
return self.search_strategy.generate_measure_candidates()

def notify_runner_results(
self,
@@ -231,8 +268,4 @@ def notify_runner_results(
"search_strategy is not provided."
"Please construct TuneContext with search_strategy"
)
_ffi_api.SearchStrategyNotifyRunnerResults( # type: ignore # pylint: disable=no-member
self,
measure_candidates,
results,
)
return self.search_strategy.notify_runner_results(measure_candidates, results)
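
Taken together, the `TuneContext` methods above now delegate straight to the attached search strategy, and `pre_tuning` falls back to `generate_design_space()` when no design spaces are given. A rough hand-driven loop built on top of these methods might look like the sketch below; `ctx` and `measure` are placeholders, and some strategies additionally require a database and cost model in `pre_tuning`:

```python
# Sketch only: `ctx` is a TuneContext constructed with a search strategy
# (e.g. search_strategy=TuneConfig(...)); `measure` stands in for the
# builder/runner measurement step, which this commit does not touch.
ctx.pre_tuning()  # design_spaces now defaults to ctx.generate_design_space()
while True:
    candidates = ctx.generate_measure_candidates()
    if candidates is None:  # the strategy has exhausted its trial budget
        break
    results = measure(candidates)  # placeholder: returns List[RunnerResult]
    ctx.notify_runner_results(candidates, results)
ctx.post_tuning()
```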
@@ -173,7 +173,6 @@ def test_conv2d_winograd_cpu():
target,
),
)
context.initialize()
post_order_apply = context.space_generator
(sch,) = post_order_apply.generate_design_space(mod)
decisions = dict(
@@ -290,7 +290,6 @@ def test_conv2d_winograd_cuda():
None, Target("cuda")
),
)
context.initialize()
post_order_apply = context.space_generator
(sch,) = post_order_apply.generate_design_space(mod)
decisions = dict(
2 changes: 1 addition & 1 deletion tests/python/unittest/test_meta_schedule_integration.py
@@ -221,7 +221,7 @@ def traverse(t):
mod,
target="llvm",
params=params,
filter_func=filter_func,
te_filter_func=filter_func,
)
expected_task_names = [
"fused_" + s
@@ -69,7 +69,6 @@ def _make_mutator(target: Target) -> Mutator:
MutateComputeLocation(): 1.0,
},
)
ctx.initialize()
return list(ctx.mutator_probs.keys())[0]


@@ -87,7 +87,6 @@ def _make_mutator(target: Target, max_jobs_per_core: int) -> Mutator:
MutateParallel(max_jobs_per_core): 1.0,
},
)
ctx.initialize()
return list(ctx.mutator_probs.keys())[0]


@@ -70,7 +70,6 @@ def _make_mutator(target: Target) -> Mutator:
MutateThreadBinding(): 1.0,
},
)
ctx.initialize()
return list(ctx.mutator_probs.keys())[0]


@@ -73,7 +73,6 @@ def _make_mutator(target: Target) -> Mutator:
target=target,
mutator_probs={MutateTileSize(): 1.0},
)
ctx.initialize()
return list(ctx.mutator_probs.keys())[0]


@@ -92,7 +92,6 @@ def _make_mutator(target: Target) -> Mutator:
MutateUnroll(): 1.0,
},
)
ctx.initialize()
return list(ctx.mutator_probs.keys())[0]


6 changes: 0 additions & 6 deletions tests/python/unittest/test_meta_schedule_post_order_apply.py
@@ -223,7 +223,6 @@ def test_meta_schedule_post_order_apply():
space_generator=PostOrderApply(),
sch_rules=[WowSoFancyScheduleRule()],
)
context.initialize()
post_order_apply = context.space_generator
schs = post_order_apply.generate_design_space(mod)
assert len(schs) == 1
@@ -240,7 +239,6 @@ def test_meta_schedule_post_order_apply_double():
space_generator=PostOrderApply(),
sch_rules=[DoubleScheduleRule()],
)
context.initialize()
post_order_apply = context.space_generator
schs = post_order_apply.generate_design_space(mod)
assert len(schs) == 2
@@ -258,7 +256,6 @@ def test_meta_schedule_post_order_apply_multiple():
space_generator=PostOrderApply(),
sch_rules=[DoubleScheduleRule(), ReorderScheduleRule()],
)
context.initialize()
post_order_apply = context.space_generator
schs = post_order_apply.generate_design_space(mod)
assert len(schs) == 4
@@ -276,7 +273,6 @@ def test_meta_schedule_post_order_apply_duplicate_matmul():
space_generator=PostOrderApply(),
sch_rules=[WowSoFancyScheduleRule()],
)
context.initialize()
post_order_apply = context.space_generator
with pytest.raises(
TVMError,
@@ -348,7 +344,6 @@ def correct_trace(a, b, c, d):
space_generator=PostOrderApply(),
sch_rules=[RemoveBlock(), TrinityDouble()],
)
context.initialize()
post_order_apply = context.space_generator
schs = post_order_apply.generate_design_space(mod)
assert len(schs) == 4
@@ -376,7 +371,6 @@ def test_meta_schedule_custom_search_space():
space_generator=PostOrderApply(),
sch_rules=[],
)
context.initialize()
post_order_apply = context.space_generator
post_order_apply.generate_design_space(mod)
called = False
@@ -37,7 +37,6 @@ def _create_context(mod, target) -> TuneContext:
],
task_name="test",
)
ctx.initialize()
return ctx


@@ -39,7 +39,6 @@ def _create_context(mod, target) -> TuneContext:
],
task_name="test",
)
ctx.initialize()
return ctx


@@ -37,7 +37,6 @@ def _create_context(mod, target) -> TuneContext:
],
task_name="test",
)
ctx.initialize()
return ctx


@@ -457,7 +457,6 @@ def _create_context(mod, target, postprocs):
postprocs=postprocs,
task_name="test",
)
ctx.initialize()
return ctx


@@ -38,7 +38,6 @@ def _create_context(mod, target) -> TuneContext:
],
task_name="test",
)
ctx.initialize()
return ctx


@@ -41,7 +41,6 @@ def _create_context(mod, target) -> TuneContext:
],
task_name="test",
)
ctx.initialize()
return ctx


@@ -33,7 +33,6 @@ def _create_context(mod, target, rule) -> TuneContext:
sch_rules=[rule],
task_name="test",
)
ctx.initialize()
return ctx

