Skip to content

Commit

Permalink
[TIR][Schedule] Method returning the function being worked on
Browse files Browse the repository at this point in the history
PR apache#11999 introduces the sugar method `work_on` to TIR Schedule, with
a field `func_working_on_` newly added to the ScheduleNode. In some
cases we may want to know which function a ScheduleNode is working on,
which is not supported previously.

Therefore, this PR introduces a method to ScheduleNode that returns
the function (more accurately, GlobalVar) currently being worked on.
With this we are able to know the function being worked on.
  • Loading branch information
MasterJH5574 committed Apr 11, 2023
1 parent eeae66b commit d213ac6
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 1 deletion.
2 changes: 2 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ class ScheduleNode : public runtime::Object {
virtual ScheduleState state() const = 0;
/*! \return The internally maintained trace of scheduling program execution */
virtual Optional<Trace> trace() const = 0;
/*! \return The GlobalVar of the func that the schedule is currently working on */
virtual Optional<GlobalVar> func_working_on() const = 0;
/*!
* \brief Instruct the schedule to work on a function in the IRModule.
*
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from tvm._ffi import register_object as _register_object
from tvm.error import TVMError, register_error
from tvm.ir import IRModule, PrimExpr
from tvm.ir import GlobalVar, IRModule, PrimExpr
from tvm.runtime import Object, String
from tvm.tir import Block, Buffer, FloatImm, For, IntImm, PrimFunc

Expand Down Expand Up @@ -207,6 +207,11 @@ def trace(self) -> Optional[Trace]:
"""Returns the internally maintained trace of scheduling program execution"""
return _ffi_api.ScheduleGetTrace(self) # type: ignore # pylint: disable=no-member

@property
def func_working_on(self) -> Optional[GlobalVar]:
"""Returns the GlobalVar of the func that the schedule is currently working on"""
return _ffi_api.ScheduleGetFuncWorkingOn(self) # type: ignore # pylint: disable=no-member

def work_on(self, func_name: str) -> None:
"""Instruct the schedule to work on a function in the IRModule.
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class ConcreteScheduleNode : public ScheduleNode {
public:
ScheduleState state() const final { return state_; }
Optional<Trace> trace() const override { return NullOpt; }
Optional<GlobalVar> func_working_on() const final { return func_working_on_; }
void WorkOn(const String& func_name) final;
Schedule Copy() override;
void Seed(support::LinearCongruentialEngine::TRandState seed) final;
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetState") //
.set_body_method<Schedule>(&ScheduleNode::state);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetTrace") //
.set_body_method<Schedule>(&ScheduleNode::trace);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleGetFuncWorkingOn") //
.set_body_method<Schedule>(&ScheduleNode::func_working_on);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleCopy") //
.set_body_method<Schedule>(&ScheduleNode::Copy);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") //
Expand Down
1 change: 1 addition & 0 deletions tests/python/unittest/test_tir_schedule_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ def test_tir_schedule_work_on():
sch.get_block(name="init")
sch.work_on(func_name="vector_add")
sch.get_block(name="init")
assert sch.func_working_on == sch.mod.get_global_var("vector_add")


def test_tir_schedule_get_loops(use_block_name):
Expand Down

0 comments on commit d213ac6

Please sign in to comment.