From d213ac6cfb5e5a32ac61935ca3157ecb308ab379 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Tue, 11 Apr 2023 14:02:44 -0400 Subject: [PATCH] [TIR][Schedule] Method returning the function being worked on PR #11999 introduces the sugar method `work_on` to TIR Schedule, with a field `func_working_on_` newly added to the ScheduleNode. In some cases we may want to know which function a ScheduleNode is working on, which is not supported previously. Therefore, this PR introduces a method to ScheduleNode that returns the function (more accurately, GlobalVar) currently being worked on. With this we are able to know the function being worked on. --- include/tvm/tir/schedule/schedule.h | 2 ++ python/tvm/tir/schedule/schedule.py | 7 ++++++- src/tir/schedule/concrete_schedule.h | 1 + src/tir/schedule/schedule.cc | 2 ++ tests/python/unittest/test_tir_schedule_utilities.py | 1 + 5 files changed, 12 insertions(+), 1 deletion(-) diff --git a/include/tvm/tir/schedule/schedule.h b/include/tvm/tir/schedule/schedule.h index 570560c62d8c..e7b7e1f45340 100644 --- a/include/tvm/tir/schedule/schedule.h +++ b/include/tvm/tir/schedule/schedule.h @@ -115,6 +115,8 @@ class ScheduleNode : public runtime::Object { virtual ScheduleState state() const = 0; /*! \return The internally maintained trace of scheduling program execution */ virtual Optional trace() const = 0; + /*! \return The GlobalVar of the func that the schedule is currently working on */ + virtual Optional func_working_on() const = 0; /*! * \brief Instruct the schedule to work on a function in the IRModule. * diff --git a/python/tvm/tir/schedule/schedule.py b/python/tvm/tir/schedule/schedule.py index 68f0b9454cb1..7221fa48b0b9 100644 --- a/python/tvm/tir/schedule/schedule.py +++ b/python/tvm/tir/schedule/schedule.py @@ -19,7 +19,7 @@ from tvm._ffi import register_object as _register_object from tvm.error import TVMError, register_error -from tvm.ir import IRModule, PrimExpr +from tvm.ir import GlobalVar, IRModule, PrimExpr from tvm.runtime import Object, String from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc @@ -207,6 +207,11 @@ def trace(self) -> Optional[Trace]: """Returns the internally maintained trace of scheduling program execution""" return _ffi_api.ScheduleGetTrace(self) # type: ignore # pylint: disable=no-member + @property + def func_working_on(self) -> Optional[GlobalVar]: + """Returns the GlobalVar of the func that the schedule is currently working on""" + return _ffi_api.ScheduleGetFuncWorkingOn(self) # type: ignore # pylint: disable=no-member + def work_on(self, func_name: str) -> None: """Instruct the schedule to work on a function in the IRModule. diff --git a/src/tir/schedule/concrete_schedule.h b/src/tir/schedule/concrete_schedule.h index 227288b232d9..d68683c45fd8 100644 --- a/src/tir/schedule/concrete_schedule.h +++ b/src/tir/schedule/concrete_schedule.h @@ -64,6 +64,7 @@ class ConcreteScheduleNode : public ScheduleNode { public: ScheduleState state() const final { return state_; } Optional trace() const override { return NullOpt; } + Optional func_working_on() const final { return func_working_on_; } void WorkOn(const String& func_name) final; Schedule Copy() override; void Seed(support::LinearCongruentialEngine::TRandState seed) final; diff --git a/src/tir/schedule/schedule.cc b/src/tir/schedule/schedule.cc index a0e39b74d31b..ce28c39a81f1 100644 --- a/src/tir/schedule/schedule.cc +++ b/src/tir/schedule/schedule.cc @@ -50,6 +50,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") // .set_body_method(&ScheduleNode::state); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") // .set_body_method(&ScheduleNode::trace); +TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetFuncWorkingOn") // + .set_body_method(&ScheduleNode::func_working_on); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") // .set_body_method(&ScheduleNode::Copy); TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") // diff --git a/tests/python/unittest/test_tir_schedule_utilities.py b/tests/python/unittest/test_tir_schedule_utilities.py index 53ee6a58cd9a..0ce2f0ea914d 100644 --- a/tests/python/unittest/test_tir_schedule_utilities.py +++ b/tests/python/unittest/test_tir_schedule_utilities.py @@ -193,6 +193,7 @@ def test_tir_schedule_work_on(): sch.get_block(name="init") sch.work_on(func_name="vector_add") sch.get_block(name="init") + assert sch.func_working_on == sch.mod.get_global_var("vector_add") def test_tir_schedule_get_loops(use_block_name):