From 5279d132718290c8740d3b9c4781915070dbf5ee Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 12 Sep 2022 13:22:05 -0700 Subject: [PATCH] [MetaSchedule][Test] Migrate AddRFactor to use SEqual This PR migrates the usage of `check_trace` to `check_sketches`, which prefers structural equality of TIRs instead of string equality of traces. --- .../meta_schedule/testing/schedule_rule.py | 15 +- python/tvm/tir/schedule/testing.py | 8 +- .../schedule_rule/add_rfactor.cc | 5 +- src/tir/schedule/primitive/sampling.cc | 4 +- ...meta_schedule_schedule_rule_add_rfactor.py | 142 ++++++++++++------ 5 files changed, 109 insertions(+), 65 deletions(-) diff --git a/python/tvm/meta_schedule/testing/schedule_rule.py b/python/tvm/meta_schedule/testing/schedule_rule.py index 46df4b95ce07b..8b43034e35231 100644 --- a/python/tvm/meta_schedule/testing/schedule_rule.py +++ b/python/tvm/meta_schedule/testing/schedule_rule.py @@ -28,7 +28,9 @@ ReuseType, ScheduleRule, ) -from tvm.meta_schedule.schedule_rule.multi_level_tiling import MultiLevelTilingTensorCore +from tvm.meta_schedule.schedule_rule.multi_level_tiling import ( + MultiLevelTilingTensorCore, +) from tvm.target import Target @@ -64,13 +66,6 @@ def auto_inline(target: Target) -> ScheduleRule: raise NotImplementedError(f"{target.kind.name} is not supported") -def add_rfactor(target: Target) -> ScheduleRule: - """Default schedule rules for with add_rfactor""" - if target.kind.name == "llvm": - return AddRFactor(max_jobs_per_core=16, max_innermost_factor=64) - raise NotImplementedError(f"{target.kind.name} is not supported") - - def cross_thread_reduction(target: Target) -> ScheduleRule: """Default schedule rules for with cross-thread reduction""" if target.kind.name == "cuda": @@ -131,7 +126,9 @@ def multi_level_tiling_tensor_core( trans_b = [trans_b] if target.kind.name == "cuda": - from tvm.tir.tensor_intrin import cuda # pylint: disable=import-outside-toplevel + from tvm.tir.tensor_intrin import ( + cuda, # pylint:
disable=import-outside-toplevel + ) intrin_groups = [ cuda.get_wmma_intrin_group(write_reuse_scope, _in_dtype, _out_dtype, _trans_b) diff --git a/python/tvm/tir/schedule/testing.py b/python/tvm/tir/schedule/testing.py index 3689f756e83cf..891a7436691d2 100644 --- a/python/tvm/tir/schedule/testing.py +++ b/python/tvm/tir/schedule/testing.py @@ -15,12 +15,12 @@ # specific language governing permissions and limitations # under the License. """Testing utilities for the TensorIR schedule API""" -from typing import Union, Sequence +from typing import Sequence, Union import tvm -from tvm.ir import IRModule, structural_equal +from tvm.ir import IRModule, assert_structural_equal, structural_equal from tvm.tir import PrimFunc -from tvm.tir.schedule import Trace, Schedule +from tvm.tir.schedule import Schedule, Trace def verify_trace_roundtrip( @@ -70,7 +70,7 @@ def verify_trace_roundtrip( assert text_format in ("json", "python"), f"Unknown text format: {text_format}" # Step 2. Verify that the round-trip produced the same scheduling - assert structural_equal(new_sch.mod, sch.mod) + assert_structural_equal(new_sch.mod, sch.mod) # Step 3. Check the consistency of the text format between the old and new traces py_repr = "\n".join(trace.as_python()) diff --git a/src/meta_schedule/schedule_rule/add_rfactor.cc b/src/meta_schedule/schedule_rule/add_rfactor.cc index 5ef2ac3aad367..cf87f24ac2336 100644 --- a/src/meta_schedule/schedule_rule/add_rfactor.cc +++ b/src/meta_schedule/schedule_rule/add_rfactor.cc @@ -90,8 +90,7 @@ Array AddRFactorNode::Apply(const tir::Schedule& sch, const tir:: // Split the fused reduction loop. 
Array factors = sch->SamplePerfectTile(fused_reduce_loop, 2, max_innermost_factor); - const Array& split_loops = - sch->Split(fused_reduce_loop, {factors.begin(), factors.end()}); + Array split_loops = sch->Split(fused_reduce_loop, {factors.begin(), factors.end()}); Array res; for (const tir::LoopRV& split_loop : split_loops) { @@ -104,7 +103,7 @@ Array AddRFactorNode::Apply(const tir::Schedule& sch, const tir:: // Annotate that the rfactor block, which is now the producer of the original block, needs to // be considered by the rule Random-Compute-Location. - sch_tmp->Annotate(block_rv, tir::attr::meta_schedule_random_compute_producer, Bool(true)); + sch_tmp->Annotate(block_rv, tir::attr::meta_schedule_random_compute_producer, Integer(1)); res.push_back(sch_tmp); } catch (const tvm::runtime::Error& e) { } diff --git a/src/tir/schedule/primitive/sampling.cc b/src/tir/schedule/primitive/sampling.cc index b1001a7f94550..ec12b045d3f0f 100644 --- a/src/tir/schedule/primitive/sampling.cc +++ b/src/tir/schedule/primitive/sampling.cc @@ -338,7 +338,9 @@ std::vector SamplePerfectTile( } else { // Case 3. Use fresh new sampling result result = SamplePerfectTile(rand_state, *extent, n_splits, max_innermost_factor); - ICHECK_LE(result.back(), max_innermost_factor); + if (max_innermost_factor != -1) { + ICHECK_LE(result.back(), max_innermost_factor); + } } *decision = support::AsArray(result); return result; diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py index a39c8aea5fb6a..17f42654fcf7b 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_add_rfactor.py @@ -15,62 +15,108 @@ # specific language governing permissions and limitations # under the License. 
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring - -from tvm.meta_schedule.space_generator.post_order_apply import PostOrderApply +from tvm import meta_schedule as ms from tvm.meta_schedule.testing import te_workload -from tvm.meta_schedule.testing.schedule_rule import add_rfactor -from tvm.meta_schedule.testing.space_generation import check_trace -from tvm.meta_schedule.tune_context import TuneContext +from tvm.meta_schedule.testing.space_generation import check_sketches +from tvm.script import tir as T from tvm.target import Target -from tvm.te.operation import create_prim_func +from tvm.te import create_prim_func -def _create_context(mod, target, rule) -> TuneContext: - ctx = TuneContext( - mod=mod, - target=target, - space_generator=PostOrderApply(), - sch_rules=[rule], - task_name="test", - ) - return ctx +def test_cpu_matmul(): + @T.prim_func + def cpu_matmul_0( + A: T.Buffer[(4, 512), "float32"], + B: T.Buffer[(512, 4), "float32"], + C: T.Buffer[(4, 4), "float32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i0, i1, i2 in T.grid(4, 4, 512): + with T.block("C"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(A[i, k], B[k, j]) + T.writes(C[i, j]) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + A[i, k] * B[k, j] + @T.prim_func + def cpu_matmul_1( + A: T.Buffer[(4, 512), "float32"], + B: T.Buffer[(512, 4), "float32"], + C: T.Buffer[(4, 4), "float32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + C_rf = T.alloc_buffer([4, 4, 128], dtype="float32") + for i0, i1, i2_0, i2_1 in T.grid(4, 4, 4, 128): + with T.block("C_rf"): + vi2_1, i, j, vi2_0 = T.axis.remap("SSSR", [i2_1, i0, i1, i2_0]) + T.reads(A[i, vi2_0 * 128 + vi2_1], B[vi2_0 * 128 + vi2_1, j]) + T.writes(C_rf[i, j, vi2_1]) + with T.init(): + C_rf[i, j, vi2_1] = T.float32(0) + C_rf[i, j, vi2_1] = ( + C_rf[i, j, vi2_1] + A[i, vi2_0 * 128 + vi2_1] * B[vi2_0 * 128 + vi2_1, j] 
+ ) + for i0, i1, i2_1 in T.grid(4, 4, 128): + with T.block("C"): + vi2_1, i, j = T.axis.remap("RSS", [i2_1, i0, i1]) + T.reads(C_rf[i, j, vi2_1]) + T.writes(C[i, j]) + T.block_attr({"meta_schedule.random_compute_producer": 1}) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + C_rf[i, j, vi2_1] -def test_cpu_matmul(): - expected = [ - [], - [ - 'b0 = sch.get_block(name="C", func_name="main")', - "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", - "l6, l7 = sch.split(loop=l3, factors=[v4, v5], preserve_unit_iters=True)", - "b8 = sch.rfactor(loop=l7, factor_axis=2)", - 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', - ], - [ - 'b0 = sch.get_block(name="C", func_name="main")', - "l1, l2, l3 = sch.get_loops(block=b0)", - "v4, v5 = sch.sample_perfect_tile(loop=l3, n=2, max_innermost_factor=64)", - "l6, l7 = sch.split(loop=l3, factors=[v4, v5], preserve_unit_iters=True)", - "b8 = sch.rfactor(loop=l6, factor_axis=2)", - 'sch.annotate(block_or_loop=b0, ann_key="meta_schedule.random_compute_producer", ann_val=1)', - ], + @T.prim_func + def cpu_matmul_2( + A: T.Buffer[(4, 512), "float32"], + B: T.Buffer[(512, 4), "float32"], + C: T.Buffer[(4, 4), "float32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + C_rf = T.alloc_buffer([4, 4, 4], dtype="float32") + for i0, i1, i2_0, i2_1 in T.grid(4, 4, 4, 128): + with T.block("C_rf"): + vi2_0, i, j, vi2_1 = T.axis.remap("SSSR", [i2_0, i0, i1, i2_1]) + T.reads(A[i, vi2_0 * 128 + vi2_1], B[vi2_0 * 128 + vi2_1, j]) + T.writes(C_rf[i, j, vi2_0]) + with T.init(): + C_rf[i, j, vi2_0] = T.float32(0) + C_rf[i, j, vi2_0] = ( + C_rf[i, j, vi2_0] + A[i, vi2_0 * 128 + vi2_1] * B[vi2_0 * 128 + vi2_1, j] + ) + for i0, i1, i2_0 in T.grid(4, 4, 4): + with T.block("C"): + vi2_0, i, j = T.axis.remap("RSS", [i2_0, i0, i1]) + T.reads(C_rf[i, j, vi2_0]) + T.writes(C[i, j]) + 
T.block_attr({"meta_schedule.random_compute_producer": 1}) + with T.init(): + C[i, j] = T.float32(0) + C[i, j] = C[i, j] + C_rf[i, j, vi2_0] + + decision_0 = [] # type: ignore + decision_1 = [ + ("SamplePerfectTile", [4, 128]), + ] + decision_2 = [ + ("SamplePerfectTile", [4, 128]), ] - target = Target("llvm --num-cores=32") - ctx = _create_context( - create_prim_func( - te_workload.matmul( - n=4, - m=4, - k=512, - ) - ), - target=target, - rule=add_rfactor(target=target), + mod = create_prim_func(te_workload.matmul(n=4, m=4, k=512)) + actual = ms.TuneContext( + mod=mod, + target=Target("llvm --num-cores=32"), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ms.schedule_rule.AddRFactor()], + task_name="test", + ).generate_design_space() + check_sketches( + mod, + sketches=actual, + expected_mods=[cpu_matmul_0, cpu_matmul_1, cpu_matmul_2], + expected_decisions=[decision_0, decision_1, decision_2], ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 3 - check_trace(spaces, expected) if __name__ == "__main__":