Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Relay] Pass manager #2546

Merged
merged 32 commits into from
Mar 12, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
ebf4c29
initial commit
Jan 22, 2019
4630afc
add python frontend and module tests
zhiics Jan 30, 2019
c894f04
add unit tests for function pass and optimize interface
zhiics Jan 31, 2019
0880027
add ExprPass
zhiics Jan 31, 2019
4122079
remove PassState and pass context for run
zhiics Feb 14, 2019
49ae421
add required_passes
zhiics Feb 14, 2019
b1adacc
return module
zhiics Feb 15, 2019
42a3227
remove move
zhiics Feb 15, 2019
c1c6d07
fix minor reviews
zhiics Feb 19, 2019
fd22d34
remove optimizer, optimizer->pass_manager, make pass a the base class…
zhiics Feb 20, 2019
4cf5843
remove deleted files
zhiics Feb 20, 2019
04ef13e
move resolvedependency to sequential pass, use ir_pass namespace
zhiics Feb 21, 2019
4cd4bd1
add todo
zhiics Feb 21, 2019
d98af5a
add disabled passes in sequetialpass
zhiics Feb 21, 2019
ccb5197
fix minor
zhiics Feb 21, 2019
e5da540
fix currying doc
zhiics Feb 21, 2019
c021126
remove pass_kind from passnode
zhiics Feb 25, 2019
42c5619
remove pass kind from test
zhiics Feb 25, 2019
7cfd1b6
fix doc
zhiics Feb 27, 2019
8c4d548
fix per @tqchen's comments
zhiics Mar 2, 2019
15e1d6c
remove pass_manager.py create separate classes
zhiics Mar 2, 2019
02d7de7
simplify pass_func
zhiics Mar 3, 2019
550ad63
inline using passfunc
zhiics Mar 3, 2019
699e5b3
update doc
zhiics Mar 3, 2019
dc2b30c
disable test_quantize_pass for now
zhiics Mar 5, 2019
49df272
create PassInfo class to contain the meta data
zhiics Mar 8, 2019
810cbac
flatten passinfo for interface
zhiics Mar 8, 2019
ee359fb
retrigger ci
zhiics Mar 8, 2019
a0863b4
remove required method
zhiics Mar 9, 2019
c8fb6b5
make Pass python class lighter
zhiics Mar 10, 2019
6df5c7d
create pass -> decorator
zhiics Mar 11, 2019
e78f4e2
make the api consistent for all classes
zhiics Mar 11, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 58 additions & 20 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ namespace relay {

namespace pass {

// Define pass context.
/*
* \brief The context of pass.
*/
class PassContext;

/*!
Expand All @@ -82,20 +84,60 @@ class PassContextNode : public RelayNode {

TVM_DEFINE_NODE_REF(PassContext, PassContextNode)

class Pass;
/*
* \brief The meta data of a pass.
*
* PassInfo can be extended conveniently in the future if more meta information
* is needed.
*/
class PassInfo;

/*!
* \brief PassNode is the base type of differnt types of optimization passes.
* It is implemented by different pass subclasses at different granularity of
* Relay nodes.
* \brief PassInfoNode contains meta data that will be used to help optimization
* and analysis.
*/
class PassNode : public RelayNode {
class PassInfoNode : public RelayNode {
public:
/*! \brief The name of an optimization/analysis pass. */
std::string name;

/*! \brief The minimal optimization level that this pass will be enabled. */
int opt_level;

/*! \brief The passes that are required to perform the current pass. */
tvm::Array<tvm::Expr> required;

PassInfoNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) final {
v->Visit("name", &name);
v->Visit("opt_level", &opt_level);
v->Visit("required", &required);
}

TVM_DLL static PassInfo make(std::string name, int opt_level,
tvm::Array<tvm::Expr> required);

static constexpr const char* _type_key = "relay.PassInfo";
TVM_DECLARE_NODE_TYPE_INFO(PassInfoNode, RelayNode);
};

TVM_DEFINE_NODE_REF(PassInfo, PassInfoNode)

class Pass;

/*!
* \brief PassNode is the base type of differnt types of optimization passes.
* It is designed as a pure class and implemented by different pass subclasses
* at different granularity of Relay nodes.
*/
class PassNode : public RelayNode {
public:
/*
* \brief Get the pass information/meta data.
*/
virtual PassInfo GetPassInfo() const = 0;

/*!
* \brief Set the context information for a pass.
*
Expand All @@ -117,10 +159,7 @@ class PassNode : public RelayNode {
*/
virtual Module operator()(const Module& mod) const = 0;

void VisitAttrs(tvm::AttrVisitor* v) override {
v->Visit("name", &name);
v->Visit("opt_level", &opt_level);
}
void VisitAttrs(tvm::AttrVisitor* v) override {}

static constexpr const char* _type_key = "relay.Pass";
TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
Expand All @@ -141,40 +180,39 @@ class Pass : public NodeRef {
/*
* \brief Create a module pass.
*
* \param name The name of the module pass.
* \param opt_level The optimization level of the module pass.
* \param pass_info The encapsulated pass meta data that will be used to set up
* a pass.
* \param pass_func The packed function that contains the optimization.
*
* \return The created module pass.
*/
Pass CreateModulePass(
const std::string& name, int opt_level,
const PassInfo& pass_info,
const runtime::TypedPackedFunc<Module(Module, PassContext)>& pass_func);

/*
* \brief Create a function pass.
*
* \param name The name of the function pass.
* \param opt_level The optimization level of the function pass.
* \param pass_info The encapsulated pass meta data that will be used to set up
* a pass.
* \param pass_func The packed function that contains the optimization.
*
* \return The created function pass.
*/
Pass CreateFunctionPass(
const std::string& name, int opt_level,
const PassInfo& pass_info,
const runtime::TypedPackedFunc<Function(Function, PassContext)>& pass_func);
/*
* \brief Create a sequential pass.
*
* \param name The name of the sequential pass.
* \param opt_level The optimization level of the sequential pass. It could be
* the highest opt_level of the list of passes.
* \param pass_info The encapsulated pass meta data that will be used to set up
* a pass.
* \param passes The optimization passes will be performed.
* \param disabled The disabled passes.
*
* \return The created sequential pass.
*/
Pass CreateSequentialPass(const std::string& name, int opt_level,
Pass CreateSequentialPass(const PassInfo& pass_info,
const tvm::Array<Pass>& passes,
const tvm::Array<tvm::Expr>& disabled);

Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@
load_param_dict = param_dict.load_param_dict

# Pass manager
PassInfo = ir_pass.PassInfo
PassContext = ir_pass.PassContext
Pass = ir_pass.Pass
ModulePass = ir_pass.ModulePass
Expand Down
33 changes: 19 additions & 14 deletions python/tvm/relay/_ir_pass.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,40 +8,45 @@ class PassContext(NodeBase):
def __init__(self):
...


class Pass(NodeBase):
class PassInfo(NodeBase):
name = ... # type: str
opt_level = ... # type: int
required = ... # type: list

def __init__(self, name, opt_level, required)
# type: (str, int, list) -> None


class Pass(NodeBase):
def __init__(self):
...


class ModulePass(Pass):
name = ... # type: str
opt_level = ... # type: int
pass_info = ... # type: list
pass_func = ... # type: Callable

def __init__(self, name, opt_level, pass_func):
# type: (str, int, Callable) -> None
def __init__(self, pass_info, pass_func):
# type: (list, Callable) -> None
...


class FunctionPass(Pass):
name = ... # type: str
opt_level = ... # type: int
pass_info = ... # type: list
pass_func = ... # type: Callable

def __init__(self, name, opt_level, pass_func):
# type: (str, int, Callable) -> None
def __init__(self, pass_info, pass_func):
# type: (list, Callable) -> None
...


class SequentialPass(Pass):
name = ... # type: str
opt_level = ... # type: int
pass_info = ... # type: list
passes = ... # type: list
disabled = ... # type: list

def __init__(self, name, opt_level, passes, disabled):
# type: (str, int, list, list) -> None
def __init__(self, pass_info, passes, disabled):
# type: (str, list, list, list) -> None
...


Expand Down
90 changes: 44 additions & 46 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,23 @@
from .module import Module


@register_relay_node
class PassInfo(RelayNode):
"""The class that contains the meta data required by a pass. It is the
container of information needed by running an optimization or analysis.
This class can be extended by adding new members when more meta data is
needed.
"""

def __init__(self, name, opt_level, required=None):
required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("required must be the list or tuple type that " +
"contains a host of dependent pass namees.")
self.__init_handle_by_constructor__(_ir_pass.PassInfo, name, opt_level,
required)


@register_relay_node
class PassContext(RelayNode):
"""The basis where a Relay optimization/analysis runs on.
Expand Down Expand Up @@ -86,19 +103,16 @@ class ModulePass(Pass):

Parameters
----------
name : str
The pass name.

opt_level : int
The optimization level of this pass.
pass_info : PassInfo
The meta data required by a module class.

pass_func : Callable[(tvm.relay.Module, PassContext) -> tvm.relay.Module]
The callback function that sketches a certain optimization.
"""

def __init__(self, name, opt_level, pass_func):
self.__init_handle_by_constructor__(_ir_pass.CreateModulePass, name,
opt_level, pass_func)
def __init__(self, pass_info, pass_func):
self.__init_handle_by_constructor__(_ir_pass.CreateModulePass,
pass_info, pass_func)

def __call__(self, mod):
"""Execute a module pass.
Expand All @@ -122,20 +136,17 @@ class FunctionPass(Pass):

Parameters
----------
name : str
The pass name.

opt_level : int
The optimization level of this pass.
pass_info : PassInfo
The meta data required by a function class.

pass_func : Callable[(tvm.relay.Function, PassContext) ->
tvm.relay.Function]
The callback function that sketches a certain optimization.
"""

def __init__(self, name, opt_level, pass_func):
self.__init_handle_by_constructor__(_ir_pass.CreateFunctionPass, name,
opt_level, pass_func)
def __init__(self, pass_info, pass_func):
self.__init_handle_by_constructor__(_ir_pass.CreateFunctionPass,
pass_info, pass_func)

def __call__(self, mod):
"""Execute a function pass.
Expand All @@ -159,11 +170,8 @@ class SequentialPass(Pass):

Parameters
----------
name : str
The pass name.

opt_level : int
The optimization level of this pass.
pass_info : PassInfo
The meta data required by a sequential class.

passes : List[Pass]
The pass candidates to be executed.
Expand All @@ -172,12 +180,12 @@ class SequentialPass(Pass):
The list of passes that are disabled.
"""

def __init__(self, name, opt_level, passes, disabled=None):
def __init__(self, pass_info, passes, disabled=None):
disabled = disabled if disabled else []
if not isinstance(disabled, (list, tuple)):
raise TypeError("disabled must be a list or tuple of pass names")
self.__init_handle_by_constructor__(_ir_pass.CreateSequentialPass,
name, opt_level, passes, disabled)
pass_info, passes, disabled)

def __call__(self, mod):
"""Execute a sequence of passes.
Expand All @@ -195,16 +203,13 @@ def __call__(self, mod):
return _ir_pass.RunSequentialPass(self, mod)
zhiics marked this conversation as resolved.
Show resolved Hide resolved


def create_module_pass(pass_name, opt_level, pass_func):
def create_module_pass(pass_info, pass_func):
"""Create a module pass using a defined optimization function from Python.

Parameters
----------
pass_name : str
The name of the pass.

opt_level : int
The optimization level of this pass.
pass_info : PassInfo
The meta data required by a module class.

pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
Expand All @@ -218,20 +223,17 @@ def create_module_pass(pass_name, opt_level, pass_func):
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")

return _ir_pass.CreateModulePass(pass_name, opt_level, pass_func)
return _ir_pass.CreateModulePass(pass_info, pass_func)


def create_function_pass(pass_name, opt_level, pass_func):
def create_function_pass(pass_info, pass_func):
"""Create a function pass using a defined optimization function from
Python.

Parameters
----------
pass_name : str
The name of the pass.

opt_level : int
The optimization level of this pass.
pass_info : PassInfo
The meta data required by a function class.

pass_func : Optional[Callable[(Module/Function, PassContext) ->
Module/Function]]
Expand All @@ -245,21 +247,17 @@ def create_function_pass(pass_name, opt_level, pass_func):
if not isinstance(pass_func, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")

return _ir_pass.CreateFunctionPass(pass_name, opt_level, pass_func)
return _ir_pass.CreateFunctionPass(pass_info, pass_func)


def create_sequential_pass(pass_name, opt_level, sequential_passes,
disabled=None):
def create_sequential_pass(pass_info, sequential_passes, disabled=None):
"""Create a sequential pass using a defined optimization function from
Python.

Parameters
----------
pass_name : str
The name of the pass.

opt_level : int
The optimization level of this pass.
pass_info : PassInfo
zhiics marked this conversation as resolved.
Show resolved Hide resolved
The meta data required by a sequential class.

sequential_passes : Optional[List[Pass]]
A sequence of passes candidate for optimization.
Expand All @@ -279,8 +277,8 @@ def create_sequential_pass(pass_name, opt_level, sequential_passes,
if not isinstance(disabled, (list, tuple)):
raise TypeError("disabled must be a list or tuple of pass names")

return _ir_pass.CreateSequentialPass(pass_name, opt_level,
sequential_passes, disabled)
return _ir_pass.CreateSequentialPass(pass_info, sequential_passes,
disabled)


def post_order_visit(expr, fvisit):
Expand Down
Loading