diff --git a/include/tvm/ir/transform.h b/include/tvm/ir/transform.h new file mode 100644 index 000000000000..b9d9d855475b --- /dev/null +++ b/include/tvm/ir/transform.h @@ -0,0 +1,330 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/transform.h + * + * This file implements a pass manager. The pass manager manages a sequence + * of IRModule -> IRModule transformation passes over a particlar unit of AST. The + * design is largely inspired from LLVM's pass manager and modern deep learning + * frameworks that perform tensor->tensor transformations. + * + * The responsibilities of a traditional compiler pass manager usually involves: + * - Organizing the execution order of optimization passes though not + * necessarily in the optimal sequence. + * - Collecting required analysis information and keep them up-to-date. + * - Reducing the effort required to implement new passes for compiler + * developers, etc. + * + * Similar to LLVM's pass manager, we designed the Relay pass manager to work + * different granularity, i.e. module level, function level, and even sequential + * passe that contains a host of passes. + * + * However, we also extend the functionality of the traditional pass manager + * with the consideration of requirements/convention from deep learning + * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass + * manager performs the IRModule -> IRModule transformation. All + * different types of passes, including the sequential-level pass object, are + * essentially pass objects. This design, therefore, effectively provides users + * a consistent and convenient interface, i.e. Pass, to play with. It offers a + * means to ease the development and testing of Relay passes. For example, with + * the pass manager, external users will be able to have custom passes correctly + * scheduled without having to modify a single handcrafted pass order. + * + * In the future we need to describe constraints between passes. For example, + * we may want to preserve dependencies between different passes and validate + * them on the completion of a certain pass. + * + * We also need to store side information and import the error reporting system. + */ +#ifndef TVM_IR_TRANSFORM_H_ +#define TVM_IR_TRANSFORM_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace transform { + +/*! + * \brief PassContextNode contains the information that a pass can rely on, + * such as analysis results. + * \sa PassContext + */ +class PassContextNode : public Object { + public: + /*! + * \brief The error reporter used to notify users why an optimization fails. + */ + ErrorReporter err_reporter; + + /*! \brief The default optimization level. */ + int opt_level{2}; + + /*! \brief CPU is the default fallback device for heterogeneous execution. */ + int fallback_device{static_cast(kDLCPU)}; + + /*! \brief The list of required passes. */ + tvm::Array required_pass; + /*! \brief The list of disabled passes. */ + tvm::Array disabled_pass; + + PassContextNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("opt_level", &opt_level); + v->Visit("fallback_device", &fallback_device); + v->Visit("required_pass", &required_pass); + v->Visit("disabled_pass", &disabled_pass); + } + + static constexpr const char* _type_key = "relay.PassContext"; + TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, Object); +}; + +/*! + * \brief PassContext that is used to configure the pass behavior. + * + * \code + * + * auto new_ctx = PassContext::Create(); + * ctx->opt_level = 2; + * ctx->fallback_device = kDLCPU; + * With scope(ctx); + * // pass context in effect. + * + * \endcode + * \sa PassContextNode + */ +class PassContext : public ObjectRef { + public: + PassContext() {} + explicit PassContext(ObjectPtr<::tvm::Object> n) : ObjectRef(n) {} + /*! + * \brief const accessor. + * \return const access pointer. + */ + const PassContextNode* operator->() const { + CHECK(get() != nullptr); + return static_cast(get()); + } + /*! + * \brief mutable accessor. + * \return mutable access pointer. + */ + PassContextNode* operator->() { + CHECK(get() != nullptr); + return static_cast(get_mutable()); + } + /*! + * \brief Construct a PassContext containing the default configurations. + * \return The new PassContext. + */ + TVM_DLL static PassContext Create(); + /*! + * \brief Get the default pass context in the current scope. + * \return The pass context. + */ + TVM_DLL static PassContext Current(); + + // accessor. + using ContainerType = PassContextNode; + class Internal; + + private: + // The entry of a pass context scope. + TVM_DLL void EnterWithScope(); + // The exit of a pass context scope. + TVM_DLL void ExitWithScope(); + + // Classes to get the Python `with` like syntax. + friend class Internal; + friend class tvm::With; +}; + +/*! + * \brief Meta data that will be used to help optimization and analysis. + * \sa PassInfo + */ +class PassInfoNode : public Object { + public: + /*! \brief The minimal optimization level that this pass will be enabled. */ + int opt_level; + + /*! \brief The name of an optimization/analysis pass. */ + std::string name; + + /*! \brief The passes that are required to perform the current pass. */ + tvm::Array required; + + PassInfoNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("opt_level", &opt_level); + v->Visit("name", &name); + v->Visit("required", &required); + } + + static constexpr const char* _type_key = "relay.PassInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, Object); +}; + +/* + * \brief Managed reference class for PassInfoNode + * \sa PassInfoNode + */ +class PassInfo : public ObjectRef { + public: + /*! + * \brief Constructor + * \param opt_level The optimization level + * \param name Name of the pass. + * \param required The passes that are required to perform the current pass. + */ + TVM_DLL PassInfo(int opt_level, + std::string name, + tvm::Array required); + + TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); +}; + +/*! + * \brief PassNode is the base type of differnt types of optimization passes. + * It is designed as a pure class and implemented by different pass subclasses + * at different granularity of Relay nodes. + */ +class PassNode : public Object { + public: + virtual ~PassNode() {} + /*! + * \brief Get the pass information/meta data. */ + virtual PassInfo Info() const = 0; + + /*! + * \brief Transform mod using the default PassContext in the current scope. + * + * \param mod The module that an optimization pass runs on. + * + * \return The transformed module. + */ + IRModule operator()(const IRModule& mod) const { + return this->operator()(mod, PassContext::Current()); + } + + /*! + * \brief Transform mod using a functor under a given pass context. + * + * \param mod The module that an optimization pass runs on. + * \param pass_ctx The pass context that can provide information for the optimization. + * + * \return The transformed module. + */ + virtual IRModule operator()(const IRModule& mod, + const PassContext& pass_ctx) const = 0; + + void VisitAttrs(tvm::AttrVisitor* v) {} + + static constexpr const char* _type_key = "relay.Pass"; + TVM_DECLARE_BASE_OBJECT_INFO(PassNode, Object); +}; + +class Pass : public ObjectRef { + public: + /*! + * \brief Transform mod using the default PassContext in the current scope. + * + * \param mod The module that an optimization pass runs on. + * + * \return The transformed module. + */ + IRModule operator()(const IRModule& mod) const { + const PassNode* node = operator->(); + CHECK(node != nullptr); + return node->operator()(mod); + } + /*! + * \brief Transform mod using a functor under a given pass context. + * + * \param mod The module that an optimization pass runs on. + * \param pass_ctx The pass context that can provide information for the optimization. + * + * \return The transformed module. + */ + IRModule operator()(const IRModule& mod, + const PassContext& pass_ctx) const { + const PassNode* node = operator->(); + CHECK(node != nullptr); + return node->operator()(mod, pass_ctx); + } + + TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode); +}; + +class SequentialNode; + +class Sequential : public Pass { + public: + /*! + * \brief The constructor of `Sequential`. + * + * \param passes The passes to apply. + * \param pass_info The pass metadata. + */ + TVM_DLL Sequential(tvm::Array passes, PassInfo pass_info); + + /*! + * \brief The constructor of `Sequential`. + * + * \param passes The passes to apply. + * \param name The name of a sequential pass. It's defaulted to "sequential". + * This allows users to only provide a list of passes and execute them + * under a given context. + */ + TVM_DLL Sequential(tvm::Array passes, std::string name = "sequential"); + + Sequential() = default; + explicit Sequential(tvm::ObjectPtr<::tvm::Object> n) : Pass(n) {} + + const SequentialNode* operator->() const; + using ContainerType = Sequential; +}; + +/* + * \brief Create a module pass. + * + * \param pass_func The packed function that contains the optimization. + * \param opt_level The optimization level of the module pass. + * \param name The name of the module pass. + * \param required The list of the passes that the module pass is dependent on. + * + * \return The created module pass. + */ +Pass CreateModulePass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, + const std::string& name, + const tvm::Array& required); + +} // namespace transform +} // namespace tvm + +#endif // TVM_IR_TRANSFORM_H_ diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index a740ea40571e..294ffb995c5e 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -19,320 +19,31 @@ /*! * \file tvm/relay/transform.h - * - * This file implements a pass manager. The pass manager manages a sequence - * of Relay-to-Relay transformation passes over a particlar unit of AST. The - * design is largely inspired from LLVM's pass manager and modern deep learning - * frameworks that perform tensor->tensor transformations. - * - * The responsibilities of a traditional compiler pass manager usually involves: - * - Organizing the execution order of optimization passes though not - * necessarily in the optimal sequence. - * - Collecting required analysis information and keep them up-to-date. - * - Reducing the effort required to implement new passes for compiler - * developers, etc. - * - * Similar to LLVM's pass manager, we designed the Relay pass manager to work - * different granularity, i.e. module level, function level, and even sequential - * passe that contains a host of passes. - * - * However, we also extend the functionality of the traditional pass manager - * with the consideration of requirements/convention from deep learning - * frameworks, such as Pytorch and Gluon, etc. Each pass in the Relay pass - * manager performs the Relay.Module -> Relay.Module transformation. All - * different types of passes, including the sequential-level pass object, are - * essentially pass objects. This design, therefore, effectively provides users - * a consistent and convenient interface, i.e. Pass, to play with. It offers a - * means to ease the development and testing of Relay passes. For example, with - * the pass manager, external users will be able to have custom passes correctly - * scheduled without having to modify a single handcrafted pass order. - * - * In the future we need to describe constraints between passes. For example, - * we may want to preserve dependencies between different passes and validate - * them on the completion of a certain pass. - * - * We also need to store side information and import the error reporting system. + * \brief Relay specific transformation passes. */ #ifndef TVM_RELAY_TRANSFORM_H_ #define TVM_RELAY_TRANSFORM_H_ -#include #include #include -#include +#include #include -#include -#include #include +#include + #include -#include -#include namespace tvm { namespace relay { namespace transform { -/* - * \brief The context of pass. - */ -class PassContext; - -/*! - * \brief PassContextNode contains the information that a pass can rely on, - * such as analysis results. - */ -class PassContextNode : public RelayNode { - public: - /*! - * \brief The error reporter used to notify users why an optimization fails. - */ - ErrorReporter err_reporter; - - /*! \brief The default optimization level. */ - int opt_level{2}; - - /*! \brief CPU is the default fallback device for heterogeneous execution. */ - int fallback_device{static_cast(kDLCPU)}; - - /*! \brief The list of required passes. */ - tvm::Array required_pass; - /*! \brief The list of disabled passes. */ - tvm::Array disabled_pass; - - PassContextNode() = default; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("opt_level", &opt_level); - v->Visit("fallback_device", &fallback_device); - v->Visit("required_pass", &required_pass); - v->Visit("disabled_pass", &disabled_pass); - } - - static constexpr const char* _type_key = "relay.PassContext"; - TVM_DECLARE_FINAL_OBJECT_INFO(PassContextNode, RelayNode); -}; - -/*! - * \brief PassContext that is used to configure the pass behavior. - * - * \code - * - * auto new_ctx = PassContext::Create(); - * ctx->opt_level = 2; - * ctx->fallback_device = kDLCPU; - * With scope(ctx); - * // pass context in effect. - * - * \endcode - */ -class PassContext : public ObjectRef { - public: - PassContext() {} - explicit PassContext(ObjectPtr<::tvm::Object> n) : ObjectRef(n) {} - /*! - * \brief const accessor. - * \return const access pointer. - */ - const PassContextNode* operator->() const { - CHECK(get() != nullptr); - return static_cast(get()); - } - /*! - * \brief mutable accessor. - * \return mutable access pointer. - */ - PassContextNode* operator->() { - CHECK(get() != nullptr); - return static_cast(get_mutable()); - } - /*! - * \brief Construct a PassContext containing the default configurations. - * \return The new PassContext. - */ - TVM_DLL static PassContext Create(); - /*! - * \brief Get the default pass context in the current scope. - * \return The pass context. - */ - TVM_DLL static PassContext Current(); - - // accessor. - using ContainerType = PassContextNode; - class Internal; - - private: - // The entry of a pass context scope. - TVM_DLL void EnterWithScope(); - // The exit of a pass context scope. - TVM_DLL void ExitWithScope(); - - // Classes to get the Python `with` like syntax. - friend class Internal; - friend class tvm::With; -}; - -/* - * \brief The meta data of a pass. - * - * PassInfo can be extended conveniently in the future if more meta information - * is needed. - */ -class PassInfo; - -/*! - * \brief PassInfoNode contains meta data that will be used to help optimization - * and analysis. - */ -class PassInfoNode : public RelayNode { - public: - /*! \brief The minimal optimization level that this pass will be enabled. */ - int opt_level; - - /*! \brief The name of an optimization/analysis pass. */ - std::string name; - - /*! \brief The passes that are required to perform the current pass. */ - tvm::Array required; - - PassInfoNode() = default; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("opt_level", &opt_level); - v->Visit("name", &name); - v->Visit("required", &required); - } - - TVM_DLL static PassInfo make(int opt_level, - std::string name, - tvm::Array required); - - static constexpr const char* _type_key = "relay.PassInfo"; - TVM_DECLARE_FINAL_OBJECT_INFO(PassInfoNode, RelayNode); -}; - -class PassInfo : public ObjectRef { - public: - TVM_DEFINE_OBJECT_REF_METHODS(PassInfo, ObjectRef, PassInfoNode); -}; - -class Pass; - -/*! - * \brief PassNode is the base type of differnt types of optimization passes. - * It is designed as a pure class and implemented by different pass subclasses - * at different granularity of Relay nodes. - */ -class PassNode : public RelayNode { - public: - virtual ~PassNode() {} - /*! - * \brief Get the pass information/meta data. */ - virtual PassInfo Info() const = 0; - - /*! - * \brief Transform mod using the default PassContext in the current scope. - * - * \param mod The module that an optimization pass runs on. - * - * \return The transformed module. - */ - IRModule operator()(const IRModule& mod) const { - return this->operator()(mod, PassContext::Current()); - } - - /*! - * \brief Transform mod using a functor under a given pass context. - * - * \param mod The module that an optimization pass runs on. - * \param pass_ctx The pass context that can provide information for the optimization. - * - * \return The transformed module. - */ - virtual IRModule operator()(const IRModule& mod, - const PassContext& pass_ctx) const = 0; - - void VisitAttrs(tvm::AttrVisitor* v) {} - - static constexpr const char* _type_key = "relay.Pass"; - TVM_DECLARE_BASE_OBJECT_INFO(PassNode, RelayNode); -}; - -class Pass : public ObjectRef { - public: - /*! - * \brief Transform mod using the default PassContext in the current scope. - * - * \param mod The module that an optimization pass runs on. - * - * \return The transformed module. - */ - IRModule operator()(const IRModule& mod) const { - const PassNode* node = operator->(); - CHECK(node != nullptr); - return node->operator()(mod); - } - /*! - * \brief Transform mod using a functor under a given pass context. - * - * \param mod The module that an optimization pass runs on. - * \param pass_ctx The pass context that can provide information for the optimization. - * - * \return The transformed module. - */ - IRModule operator()(const IRModule& mod, - const PassContext& pass_ctx) const { - const PassNode* node = operator->(); - CHECK(node != nullptr); - return node->operator()(mod, pass_ctx); - } - - TVM_DEFINE_OBJECT_REF_METHODS(Pass, ObjectRef, PassNode); -}; - -class SequentialNode; - -class Sequential : public Pass { - public: - /*! - * \brief The constructor of `Sequential`. - * - * \param passes The passes to apply. - * \param pass_info The pass metadata. - */ - TVM_DLL Sequential(tvm::Array passes, PassInfo pass_info); - - /*! - * \brief The constructor of `Sequential`. - * - * \param passes The passes to apply. - * \param name The name of a sequential pass. It's defaulted to "sequential". - * This allows users to only provide a list of passes and execute them - * under a given context. - */ - TVM_DLL Sequential(tvm::Array passes, std::string name = "sequential"); - - Sequential() = default; - explicit Sequential(tvm::ObjectPtr<::tvm::Object> n) : Pass(n) {} - - const SequentialNode* operator->() const; - using ContainerType = Sequential; -}; - -/* - * \brief Create a module pass. - * - * \param pass_func The packed function that contains the optimization. - * \param opt_level The optimization level of the module pass. - * \param name The name of the module pass. - * \param required The list of the passes that the module pass is dependent on. - * - * \return The created module pass. - */ -Pass CreateModulePass( - const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required); +using Pass = tvm::transform::Pass; +using PassNode = tvm::transform::PassNode; +using PassInfo = tvm::transform::PassInfo; +using PassInfoNode = tvm::transform::PassInfoNode; +using PassContext = tvm::transform::PassContext; +using PassContextNode = tvm::transform::PassContextNode; +using Sequential = tvm::transform::Sequential; /* * \brief Create a function pass. diff --git a/src/relay/pass/pass_manager.cc b/src/ir/transform.cc similarity index 69% rename from src/relay/pass/pass_manager.cc rename to src/ir/transform.cc index bcd4451d6a3b..bf180cfadb90 100644 --- a/src/relay/pass/pass_manager.cc +++ b/src/ir/transform.cc @@ -18,48 +18,52 @@ */ /*! - * \file src/relay/pass/pass_manager.cc - * \brief Relay pass manager implementation. + * \file src/ir/transform.cc + * \brief Infrastructure for transformation passes. */ #include -#include -#include +#include #include +#include +#include + +// TODO(tqchen): Update to use String container after it is merged. +#include -#include #include #include namespace tvm { -namespace relay { namespace transform { +using tvm::runtime::TVMArgs; +using tvm::runtime::TVMRetValue; using tvm::NodePrinter; -struct RelayPassContextThreadLocalEntry { +struct PassContextThreadLocalEntry { /*! \brief The default pass context. */ PassContext default_context; /*! \brief The current pass context. */ std::stack context_stack; - RelayPassContextThreadLocalEntry() { + PassContextThreadLocalEntry() { default_context = PassContext(make_object()); } }; /*! \brief Thread local store to hold the pass context. */ -typedef dmlc::ThreadLocalStore +typedef dmlc::ThreadLocalStore RelayPassContextThreadLocalStore; void PassContext::EnterWithScope() { - RelayPassContextThreadLocalEntry* entry = + PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); entry->context_stack.push(*this); } void PassContext::ExitWithScope() { - RelayPassContextThreadLocalEntry* entry = + PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); CHECK(!entry->context_stack.empty()); CHECK(entry->context_stack.top().same_as(*this)); @@ -67,7 +71,7 @@ void PassContext::ExitWithScope() { } PassContext PassContext::Current() { - RelayPassContextThreadLocalEntry* entry = + PassContextThreadLocalEntry* entry = RelayPassContextThreadLocalStore::Get(); if (!entry->context_stack.empty()) { return entry->context_stack.top(); @@ -121,84 +125,16 @@ class ModulePassNode : public PassNode { */ PassInfo Info() const override { return pass_info; } - TVM_DLL static ModulePass make( - runtime::TypedPackedFunc pass_func, - PassInfo pass_info); - static constexpr const char* _type_key = "relay.ModulePass"; TVM_DECLARE_FINAL_OBJECT_INFO(ModulePassNode, PassNode); }; class ModulePass : public Pass { public: - TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode); -}; - -class FunctionPass; - -/*! - * \brief Function-level passes are used to implement various global - * optimizations for a given Relay module. It fetches one function at a time - * from the function list in the module for optimization. - * - * Note that the scope of passes at this level is a Relay function. Therefore, - * we cannot add or delete a function through these passes as they are not aware - * of the global information. - */ -class FunctionPassNode : public PassNode { - public: - /* \brief The pass meta data.*/ - PassInfo pass_info; - - /*! \brief The packed pass function sketches the real optimization. For - * instance, we can implement a pass that works on a Relay function as a - * `pass_func` and let it run on a given module. The same `pass_func` will - * then be applied on each function in the module. - */ - runtime::TypedPackedFunc pass_func; - - FunctionPassNode() = default; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("pass_info", &pass_info); - } - - /*! - * \brief Run a function pass on given pass context. - * - * \param mod The module that an optimization pass is applied on. - * \param mod The context that an optimization pass executes on. - * - * \return Return the updated module. - */ - IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final; - - /*! - * \brief Get the pass information/meta data. - */ - PassInfo Info() const override { return pass_info; } - - TVM_DLL static FunctionPass make( - runtime::TypedPackedFunc pass_func, - PassInfo pass_info); - - static constexpr const char* _type_key = "relay.FunctionPass"; - TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode); - - private: - /* - * \brief Check if a function should be skipped for optimization. - * - * \param func The target function to be checked. - * - * \return Return true if the function will be skipped, otherwise false. - */ - bool SkipFunction(const Function& func) const; -}; + ModulePass(runtime::TypedPackedFunc pass_func, + PassInfo pass_info); -class FunctionPass : public Pass { - public: - TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode); + TVM_DEFINE_OBJECT_REF_METHODS(ModulePass, Pass, ModulePassNode); }; /*! @@ -267,28 +203,28 @@ class SequentialNode : public PassNode { TVM_DECLARE_FINAL_OBJECT_INFO(SequentialNode, PassNode); }; -PassInfo PassInfoNode::make(int opt_level, - std::string name, - tvm::Array required) { +PassInfo::PassInfo(int opt_level, + std::string name, + tvm::Array required) { auto pass_info = make_object(); pass_info->opt_level = opt_level; pass_info->name = std::move(name); pass_info->required = std::move(required); - return PassInfo(pass_info); + data_ = std::move(pass_info); } -ModulePass ModulePassNode::make( +ModulePass::ModulePass( runtime::TypedPackedFunc pass_func, PassInfo pass_info) { auto n = make_object(); n->pass_func = std::move(pass_func); n->pass_info = std::move(pass_info); - return ModulePass(n); + data_ = std::move(n); } // Module -> Module optimizations. IRModule ModulePassNode::operator()(const IRModule& mod, - const PassContext& pass_ctx) const { + const PassContext& pass_ctx) const { const PassInfo& pass_info = Info(); DLOG(INFO) << "Executing module pass : " << pass_info->name @@ -300,51 +236,6 @@ IRModule ModulePassNode::operator()(const IRModule& mod, return updated_mod; } -FunctionPass FunctionPassNode::make( - runtime::TypedPackedFunc pass_func, - PassInfo pass_info) { - auto n = make_object(); - n->pass_func = std::move(pass_func); - n->pass_info = std::move(pass_info); - return FunctionPass(n); -} - -// Perform Module -> Module optimizations at the Function level. -IRModule FunctionPassNode::operator()(const IRModule& mod, - const PassContext& pass_ctx) const { - const PassInfo& pass_info = Info(); - CHECK(mod.defined()); - DLOG(INFO) << "Executing function pass : " - << pass_info->name - << " with opt level: " - << pass_info->opt_level; - - // Execute the pass function and return a new module. - IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports()); - std::vector > updates; - for (const auto& it : updated_mod->functions) { - // only picks up relay::Function - if (auto* n = it.second.as()) { - Function func = GetRef(n); - auto updated_func = SkipFunction(func) - ? func - : pass_func(func, updated_mod, pass_ctx); - updates.push_back({it.first, updated_func}); - } - } - - for (const auto& pair : updates) { - updated_mod->Add(pair.first, pair.second, true); - } - return updated_mod; -} - -bool FunctionPassNode::SkipFunction(const Function& func) const { - ObjectRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization); - const ir::IntImmNode* pval = skip_opt.as(); - return (pval && pval->value != 0) || (!func->UseDefaultCompiler()); -} - Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { auto n = make_object(); n->passes = std::move(passes); @@ -355,7 +246,7 @@ Sequential::Sequential(tvm::Array passes, PassInfo pass_info) { Sequential::Sequential(tvm::Array passes, std::string name) { auto n = make_object(); n->passes = std::move(passes); - PassInfo pass_info = PassInfoNode::make(2, std::move(name), {}); + PassInfo pass_info = PassInfo(2, std::move(name), {}); n->pass_info = std::move(pass_info); data_ = std::move(n); } @@ -433,23 +324,16 @@ Pass CreateModulePass( int opt_level, const std::string& name, const tvm::Array& required) { - PassInfo pass_info = PassInfoNode::make(opt_level, name, required); - return ModulePassNode::make(pass_func, pass_info); -} - -Pass CreateFunctionPass( - const runtime::TypedPackedFunc& pass_func, - int opt_level, - const std::string& name, - const tvm::Array& required) { - PassInfo pass_info = PassInfoNode::make(opt_level, name, required); - return FunctionPassNode::make(pass_func, pass_info); + PassInfo pass_info = PassInfo(opt_level, name, required); + return ModulePass(pass_func, pass_info); } TVM_REGISTER_NODE_TYPE(PassInfoNode); TVM_REGISTER_GLOBAL("relay._transform.PassInfo") -.set_body_typed(PassInfoNode::make); +.set_body_typed([](int opt_level, std::string name, tvm::Array required) { + return PassInfo(opt_level, name, required); +}); TVM_REGISTER_GLOBAL("relay._transform.Info") .set_body([](TVMArgs args, TVMRetValue* ret) { @@ -474,7 +358,11 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) TVM_REGISTER_NODE_TYPE(ModulePassNode); TVM_REGISTER_GLOBAL("relay._transform.MakeModulePass") -.set_body_typed(ModulePassNode::make); +.set_body_typed( + [](runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { + return ModulePass(pass_func, pass_info); +}); TVM_REGISTER_GLOBAL("relay._transform.RunPass") .set_body([](TVMArgs args, TVMRetValue* ret) { @@ -491,19 +379,6 @@ TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) << " at the optimization level " << info->opt_level; }); -TVM_REGISTER_NODE_TYPE(FunctionPassNode); - -TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass") -.set_body_typed(FunctionPassNode::make); - -TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) -.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { - auto* node = static_cast(ref.get()); - const PassInfo info = node->Info(); - p->stream << "Run Function pass: " << info->name - << " at the optimization level " << info->opt_level; -}); - TVM_REGISTER_NODE_TYPE(SequentialNode); TVM_REGISTER_GLOBAL("relay._transform.Sequential") @@ -512,7 +387,7 @@ TVM_REGISTER_GLOBAL("relay._transform.Sequential") int opt_level = args[1]; std::string name = args[2]; tvm::Array required = args[3]; - PassInfo pass_info = PassInfoNode::make(opt_level, name, required); + PassInfo pass_info = PassInfo(opt_level, name, required); *ret = Sequential(passes, pass_info); }); @@ -589,5 +464,4 @@ TVM_REGISTER_GLOBAL("relay._transform.ExitPassContext") .set_body_typed(PassContext::Internal::ExitScope); } // namespace transform -} // namespace relay } // namespace tvm diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc new file mode 100644 index 000000000000..02ae273350bc --- /dev/null +++ b/src/relay/ir/transform.cc @@ -0,0 +1,170 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file relay/ir/transform.cc + * \brief Relay specific transformation passes. + */ +#include +#include +#include +#include + + +namespace tvm { +namespace relay { +namespace transform { + +class FunctionPass; + +/*! + * \brief Function-level passes are used to implement various global + * optimizations for a given Relay module. It fetches one function at a time + * from the function list in the module for optimization. + * + * Note that the scope of passes at this level is a Relay function. Therefore, + * we cannot add or delete a function through these passes as they are not aware + * of the global information. + */ +class FunctionPassNode : public PassNode { + public: + /* \brief The pass meta data.*/ + PassInfo pass_info; + + /*! \brief The packed pass function sketches the real optimization. For + * instance, we can implement a pass that works on a Relay function as a + * `pass_func` and let it run on a given module. The same `pass_func` will + * then be applied on each function in the module. + */ + runtime::TypedPackedFunc pass_func; + + FunctionPassNode() = default; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pass_info", &pass_info); + } + + /*! + * \brief Run a function pass on given pass context. + * + * \param mod The module that an optimization pass is applied on. + * \param mod The context that an optimization pass executes on. + * + * \return Return the updated module. + */ + IRModule operator()(const IRModule& mod, const PassContext& pass_ctx) const final; + + /*! + * \brief Get the pass information/meta data. + */ + PassInfo Info() const override { return pass_info; } + + TVM_DLL static FunctionPass make( + runtime::TypedPackedFunc pass_func, + PassInfo pass_info); + + static constexpr const char* _type_key = "relay.FunctionPass"; + TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode); + + private: + /* + * \brief Check if a function should be skipped for optimization. + * + * \param func The target function to be checked. + * + * \return Return true if the function will be skipped, otherwise false. + */ + bool SkipFunction(const Function& func) const; +}; + +class FunctionPass : public Pass { + public: + TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode); +}; + +FunctionPass FunctionPassNode::make( + runtime::TypedPackedFunc pass_func, + PassInfo pass_info) { + auto n = make_object(); + n->pass_func = std::move(pass_func); + n->pass_info = std::move(pass_info); + return FunctionPass(n); +} + +// Perform Module -> Module optimizations at the Function level. +IRModule FunctionPassNode::operator()(const IRModule& mod, + const PassContext& pass_ctx) const { + const PassInfo& pass_info = Info(); + CHECK(mod.defined()); + DLOG(INFO) << "Executing function pass : " + << pass_info->name + << " with opt level: " + << pass_info->opt_level; + + // Execute the pass function and return a new module. + IRModule updated_mod = IRModule(mod->functions, mod->type_definitions, mod->Imports()); + std::vector > updates; + for (const auto& it : updated_mod->functions) { + // only picks up relay::Function + if (auto* n = it.second.as()) { + Function func = GetRef(n); + auto updated_func = SkipFunction(func) + ? func + : pass_func(func, updated_mod, pass_ctx); + updates.push_back({it.first, updated_func}); + } + } + + for (const auto& pair : updates) { + updated_mod->Add(pair.first, pair.second, true); + } + return updated_mod; +} + +bool FunctionPassNode::SkipFunction(const Function& func) const { + ObjectRef skip_opt = FunctionGetAttr(func, attr::kSkipOptimization); + const ir::IntImmNode* pval = skip_opt.as(); + return (pval && pval->value != 0) || (!func->UseDefaultCompiler()); +} + +Pass CreateFunctionPass( + const runtime::TypedPackedFunc& pass_func, + int opt_level, + const std::string& name, + const tvm::Array& required) { + PassInfo pass_info = PassInfo(opt_level, name, required); + return FunctionPassNode::make(pass_func, pass_info); +} + +TVM_REGISTER_NODE_TYPE(FunctionPassNode); + +TVM_REGISTER_GLOBAL("relay._transform.MakeFunctionPass") +.set_body_typed(FunctionPassNode::make); + +TVM_STATIC_IR_FUNCTOR(NodePrinter, vtable) +.set_dispatch([](const ObjectRef& ref, NodePrinter* p) { + auto* node = static_cast(ref.get()); + const PassInfo info = node->Info(); + p->stream << "Run Function pass: " << info->name + << " at the optimization level " << info->opt_level; +}); + +} // namespace transform +} // namespace relay +} // namespace tvm