Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Add sugar method Schedule.work_on #11999

Merged
merged 2 commits into from
Jul 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,21 @@ class ScheduleNode : public runtime::Object {
virtual ScheduleState state() const = 0;
/*! \return The internally maintained trace of scheduling program execution */
virtual Optional<Trace> trace() const = 0;
/*!
* \brief Instruct the schedule to work on a function in the IRModule.
*
* By default, the schedule works on the function with the name "main", or the only function in
* the IRModule if there is only one. If there is multiple functions in the IRModule, and none of
* their names are "main", users will have to call this method to explicitly specify which
* function to work on.
*
* This sugar function will guide the `GetBlock` method if its `func_name` is not specified.
*
* \param func_name The name of the function to be working on
*
* \sa GetBlock
*/
virtual void WorkOn(const String& func_name) = 0;
/*!
* \brief Returns a copy of the schedule, including both its state and its symbol table,
* guaranteeing that
Expand Down Expand Up @@ -231,12 +246,19 @@ class ScheduleNode : public runtime::Object {
/******** Schedule: Get blocks & loops ********/
/*!
* \brief Retrieve a block in a specific function with its name
*
* By default, if `func_name` is not specified, the schedule will search for the block in the
* function that is currently being "worked on". To switch the function to be worked on, use
* `WorkOn` before calling this method.
*
* \param name The name of the block to be retrieved
* \param func_name The name of the function
* \return The block retrieved
* \note Indexing error is raised if 0 or multiple blocks exist with the specific name
*
* \sa WorkOn
*/
virtual BlockRV GetBlock(const String& name, const String& func_name = "main") = 0;
virtual BlockRV GetBlock(const String& name, const Optional<String>& func_name = NullOpt) = 0;
/*!
* \brief Get the parent loops of the block in its scope, from outer to inner
* \param block_rv The query block
Expand Down
25 changes: 23 additions & 2 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,23 @@ def trace(self) -> Optional[Trace]:
"""Returns the internally maintained trace of scheduling program execution"""
return _ffi_api.ScheduleGetTrace(self) # type: ignore # pylint: disable=no-member

def work_on(self, func_name: str) -> None:
"""Instruct the schedule to work on a function in the IRModule.

By default, the schedule works on the function with the name "main", or the only function in
the IRModule if there is only one. If there is multiple functions in the IRModule, and none
of their names are "main", users will have to call this method to explicitly specify which
function to work on.

This sugar function will guide the `GetBlock` method if its `func_name` is not specified.

Parameters
----------
func_name : str
The name of the function to work on.
"""
_ffi_api.ScheduleWorkOn(self, func_name) # type: ignore # pylint: disable=no-member

def copy(self) -> "Schedule":
"""Returns a copy of the schedule, including both the state and the symbol table,
* guaranteeing that
Expand Down Expand Up @@ -403,15 +420,19 @@ def sample_compute_location(
def get_block(
self,
name: str,
func_name: str = "main",
func_name: Optional[str] = None,
) -> BlockRV:
"""Retrieve a block in a specific function with its name

By default, if `func_name` is not specified, the schedule will search for the block in the
function that is currently being "worked on". To switch the function to be worked on, use
`work_on` before calling this method.

Parameters
----------
name : str
The name of the block
func_name : str = "main"
func_name : Optional[str] = None
The name of the function

Returns
Expand Down
41 changes: 41 additions & 0 deletions src/meta_schedule/arg_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,47 @@
namespace tvm {
namespace meta_schedule {

/*!
* \brief Find the entry function of the given IRModule, i.e, functions marked by
* `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc.
* \param mod The IRModule to find the entry function.
* \return The entry function.
*/
inline tir::PrimFunc FindEntryFunc(const IRModule& mod) {
// Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc`
int num_prim_func = 0;
const tir::PrimFuncNode* main_func = nullptr;
const tir::PrimFuncNode* last_func = nullptr;
for (const auto& kv : mod->functions) {
GlobalVar gv = kv.first;
BaseFunc base_func = kv.second;
if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
last_func = func;
if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
return GetRef<tir::PrimFunc>(func);
}
if (gv->name_hint == "main") {
main_func = func;
}
++num_prim_func;
}
}
// Priority 2: PrimFunc whose name is `main`
if (main_func != nullptr) {
return GetRef<tir::PrimFunc>(main_func);
}
// Priority 3: The only PrimFunc in the IRModule
if (num_prim_func == 0) {
LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: "
<< tir::AsTVMScript(mod);
}
if (num_prim_func > 1) {
LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but none of them are "
"annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`"
<< tir::AsTVMScript(mod);
}
return GetRef<tir::PrimFunc>(last_func);
}
/******** ArgInfo ********/

ArgInfo ArgInfo::FromJSON(const ObjectRef& json_obj) {
Expand Down
3 changes: 2 additions & 1 deletion src/meta_schedule/mutator/mutate_parallel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ const BlockRVNode* GetInstGetBlockOutput(const Instruction& inst) {
std::vector<std::vector<int64_t>> AnalyzeParallel(const ScheduleState& self,
const String& block_name, const String& func_name,
int64_t limit) {
Array<StmtSRef> block_srefs = tir::GetBlocks(self, block_name, func_name);
Array<StmtSRef> block_srefs =
tir::GetBlocks(self, block_name, self->mod->GetGlobalVar(func_name));
ICHECK_EQ(block_srefs.size(), 1);
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_srefs[0]);
ScopeBlockLoopInfo info = GetScopeBlockLoopInfo(GetRef<Block>(block));
Expand Down
42 changes: 0 additions & 42 deletions src/meta_schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,48 +174,6 @@ inline String SHash2Hex(const ObjectRef& obj) {
return os.str();
}

/*!
* \brief Find the entry function of the given IRModule, i.e, functions marked by
* `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc.
* \param mod The IRModule to find the entry function.
* \return The entry function.
*/
inline tir::PrimFunc FindEntryFunc(const IRModule& mod) {
// Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc`
int num_prim_func = 0;
const tir::PrimFuncNode* main_func = nullptr;
const tir::PrimFuncNode* last_func = nullptr;
for (const auto& kv : mod->functions) {
GlobalVar gv = kv.first;
BaseFunc base_func = kv.second;
if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
last_func = func;
if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
return GetRef<tir::PrimFunc>(func);
}
if (gv->name_hint == "main") {
main_func = func;
}
++num_prim_func;
}
}
// Priority 2: PrimFunc whose name is `main`
if (main_func != nullptr) {
return GetRef<tir::PrimFunc>(main_func);
}
// Priority 3: The only PrimFunc in the IRModule
if (num_prim_func == 0) {
LOG(FATAL) << "ValueError: Cannot find any PrimFunc in the given IRModule: "
<< tir::AsTVMScript(mod);
}
if (num_prim_func > 1) {
LOG(FATAL) << "ValueError: Multiple PrimFuncs exist in the IRModule, but none of them are "
"annotated with `kIsEntryFunc`, i.e. `tir.is_entry_func`"
<< tir::AsTVMScript(mod);
}
return GetRef<tir::PrimFunc>(last_func);
}

/*!
* \brief Fork a random state into another, i.e. PRNG splitting.
* The given random state is also mutated.
Expand Down
9 changes: 9 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl
*/
StmtSRef GetSRefTreeRoot(const StmtSRef& sref);

/*!
* \brief Find the entry function of the given IRModule, i.e, functions marked by
* `tir::attr::kIsEntryFunc`, whose name is `main` or being the only PrimeFunc.
* \param mod The IRModule to find the entry function.
* \param result_g_var The result GlobalVar of the entry function.
* \return The entry function.
*/
const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var);

/******** Scope ********/
/*!
* \brief Checks if scope the specified sref is in is a stage-pipeline and return it
Expand Down
41 changes: 41 additions & 0 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,47 @@ const PrimFuncNode* GetRootPrimFunc(const IRModule& mod, const StmtNode* root_bl
throw;
}

const PrimFuncNode* FindEntryFunc(const IRModule& mod, GlobalVar* result_g_var) {
GlobalVar result = NullValue<GlobalVar>();
// Priority 1: PrimFunc marked as `tir::attr::kIsEntryFunc`
int num_prim_func = 0;
const tir::PrimFuncNode* main_func = nullptr;
const tir::PrimFuncNode* last_func = nullptr;
for (const auto& kv : mod->functions) {
GlobalVar gv = kv.first;
BaseFunc base_func = kv.second;
if (const auto* func = base_func.as<tir::PrimFuncNode>()) {
last_func = func;
if (func->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
if (result_g_var != nullptr) {
*result_g_var = gv;
}
return func;
}
if (gv->name_hint == "main") {
main_func = func;
result = gv;
}
++num_prim_func;
}
}
// Priority 2: PrimFunc whose name is `main`
if (main_func != nullptr) {
if (result_g_var != nullptr) {
*result_g_var = result;
}
return main_func;
}
// Priority 3: The only PrimFunc in the IRModule
if (num_prim_func == 1) {
if (result_g_var != nullptr) {
*result_g_var = result;
}
return last_func;
}
return nullptr;
}

/******** Scope ********/

StmtSRef GetScopeRoot(const ScheduleState& self, const StmtSRef& sref,
Expand Down
25 changes: 23 additions & 2 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRa
n->symbol_table_ = {};
n->analyzer_ = std::make_unique<arith::Analyzer>();
n->Seed(seed);
GlobalVar gv = NullValue<GlobalVar>();
if (FindEntryFunc(mod, &gv) != nullptr) {
n->func_working_on_ = gv;
} else {
n->func_working_on_ = NullOpt;
}
return Schedule(std::move(n));
}

Expand Down Expand Up @@ -177,13 +183,18 @@ class ScheduleCopier {
std::unordered_map<const StmtSRefNode*, StmtSRef> old2new_;
};

void ConcreteScheduleNode::WorkOn(const String& func_name) {
this->func_working_on_ = this->state_->mod->GetGlobalVar(func_name);
}

void ConcreteScheduleNode::Copy(ScheduleState* new_state, TSymbolTable* new_symbol_table) const {
ScheduleCopier::Copy(this, new_state, new_symbol_table);
new_state->get()->DebugVerify();
}

Schedule ConcreteScheduleNode::Copy() {
ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
n->func_working_on_ = this->func_working_on_;
n->error_render_level_ = this->error_render_level_;
ConcreteScheduleNode::Copy(&n->state_, &n->symbol_table_);
n->analyzer_ = std::make_unique<arith::Analyzer>(); // new analyzer needed because it is stateful
Expand Down Expand Up @@ -251,7 +262,7 @@ LoopRV ConcreteScheduleNode::SampleComputeLocation(const BlockRV& block_rv,

/******** Schedule: Get blocks & loops ********/

BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) {
BlockRV ConcreteScheduleNode::GetBlock(const String& name, const Optional<String>& func_name) {
class NotSingleResult : public ScheduleError {
public:
explicit NotSingleResult(String name, IRModule mod, const Array<StmtSRef>& blocks)
Expand Down Expand Up @@ -286,7 +297,17 @@ BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_na
IRModule mod_;
Array<Block> blocks_;
};
Array<StmtSRef> blocks = tir::GetBlocks(this->state_, name, func_name);
GlobalVar gv = NullValue<GlobalVar>();
if (func_name.defined()) {
gv = state_->mod->GetGlobalVar(func_name.value());
} else if (func_working_on_.defined()) {
gv = this->func_working_on_.value();
} else {
LOG(FATAL) << "ValueError: `get_block` does not know which function to be working on. Please "
"specify the function name explicitly, or call `work_on` to specify the function "
"before using `get_block`.";
}
Array<StmtSRef> blocks = tir::GetBlocks(this->state_, name, gv);
if (blocks.size() != 1) {
TVM_TIR_SCHEDULE_BEGIN();
throw NotSingleResult(name, this->state_->mod, blocks);
Expand Down
8 changes: 6 additions & 2 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class ConcreteScheduleNode : public ScheduleNode {
protected:
/*! \brief The internal state of scheduling */
ScheduleState state_;
/*! \brief The function to be worked on. */
Optional<GlobalVar> func_working_on_;
/*! \brief The level of error rendering */
ScheduleErrorRenderLevel error_render_level_;
/*! \brief A symbol table that maps random variables to concrete StmtSRef/Integers */
Expand All @@ -50,17 +52,19 @@ class ConcreteScheduleNode : public ScheduleNode {
public:
void VisitAttrs(tvm::AttrVisitor* v) {
// `state_` is not visited
// `func_working_on_` is not visited
// `error_render_level_` is not visited
// `symbol_table_` is not visited
// `analyzer_` is not visited
// `rand_state_` is not visited
// `rgnd_state_` is not visited
}

virtual ~ConcreteScheduleNode() = default;

public:
ScheduleState state() const final { return state_; }
Optional<Trace> trace() const override { return NullOpt; }
void WorkOn(const String& func_name) final;
Schedule Copy() override;
void Seed(support::LinearCongruentialEngine::TRandState seed) final;
support::LinearCongruentialEngine::TRandState ForkSeed() final;
Expand Down Expand Up @@ -89,7 +93,7 @@ class ConcreteScheduleNode : public ScheduleNode {
LoopRV SampleComputeLocation(const BlockRV& block_rv,
Optional<Integer> decision = NullOpt) override;
/******** Schedule: Get blocks & loops ********/
BlockRV GetBlock(const String& name, const String& func_name = "main") override;
BlockRV GetBlock(const String& name, const Optional<String>& func_name) override;
Array<LoopRV> GetLoops(const BlockRV& block_rv) override;
Array<BlockRV> GetChildBlocks(const BlockRV& block_rv) override;
Array<BlockRV> GetChildBlocks(const LoopRV& loop_rv) override;
Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ TVM_DLL tir::StmtSRef SampleComputeLocation(
* \brief Retrieves blocks in a specific function with its name
* \param self The schedule state
* \param name The name of the blocks to be retrieved
* \param func_name The name of the function
* \param gvar The function to be retrieved
* \return A list of blocks with the specific name
*/
Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const String& func_name);
Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const GlobalVar& gv);
/*!
* \brief Gets the parent loops of the block in its scope, from outer to inner
* \param self The schedule state
Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/primitive/get_block_loop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
namespace tvm {
namespace tir {

Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const String& func_name) {
Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const GlobalVar& gv) {
struct Finder : public StmtVisitor {
explicit Finder(const ScheduleState& self, const String& name) : self_(self), name_(name) {}

Expand All @@ -39,7 +39,7 @@ Array<StmtSRef> GetBlocks(const ScheduleState& self, const String& name, const S
Array<StmtSRef> results_;
};

BaseFunc func = self->mod->Lookup(func_name);
BaseFunc func = self->mod->Lookup(gv);
const auto* prim_func = TVM_TYPE_AS(prim_func, func, PrimFuncNode);
Finder finder(self, name);
finder(prim_func->body);
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSeed") //
.set_body_method<Schedule>(&ScheduleNode::Seed);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleForkSeed") //
.set_body_method<Schedule>(&ScheduleNode::ForkSeed);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleWorkOn") //
.set_body_method<Schedule>(&ScheduleNode::WorkOn);

/**************** (FFI) Constructor ****************/

Expand Down
Loading