Skip to content

Commit

Permalink
fix per @tqchen's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics committed Mar 2, 2019
1 parent 11040fc commit 018c2fb
Show file tree
Hide file tree
Showing 7 changed files with 488 additions and 480 deletions.
184 changes: 184 additions & 0 deletions include/tvm/relay/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,202 @@
* Copyright (c) 2018 by Contributors
* \file tvm/relay/pass.h
* \brief The set of Relay passes written in C++.
*
* This file also implements a pass manager. The pass manager manages a sequence
* of Relay-to-Relay transformation passes over a particlar unit of AST. The
* design is largely inspired from LLVM's pass manager.
*
* The responsibilities of a pass manager usually involves:
* - Organizing the execution order of optimization passes though not
* necessarily in the optimal sequence.
* - Collecting required analysis information and keep them up-to-date.
* - Reducing the effort required to implement new passes for compiler
* developers, etc.
*
* TODO(jroesch, zhiics): We are currently using a very simple design for the
* pass manager, i.e. it executes a specific pass or sequence of passes.
*
* In the future we need to describe constraints between passes. For example,
* we may want to preserve dependencies between different passes and validate
* them on the completion of a certain pass.
*
* We also need to store side information and import the error reporting system.
*/
#ifndef TVM_RELAY_PASS_H_
#define TVM_RELAY_PASS_H_

#include <tvm/attrs.h>
#include <tvm/ir.h>
#include <tvm/packed_func_ext.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/module.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/type.h>

#include <string>
#include <vector>

namespace tvm {
namespace relay {

namespace pass {

// Forward declaration
class ModulePass;
class FunctionPass;
class SequentialPass;

// Define pass context.
class PassContext;

/*!
* \brief PassContextNode contains the information that a pass can rely on, such as
* analysis results.
*/
class PassContextNode : public RelayNode {
public:
/*!
* \brief The error reporter used to notify users why an optimization fails.
*/
ErrorReporter err_reporter_;

PassContextNode() = default;

void VisitAttrs(tvm::AttrVisitor* v) final {
}

TVM_DLL static PassContext make();

static constexpr const char* _type_key = "relay.PassContext";
TVM_DECLARE_NODE_TYPE_INFO(PassContextNode, RelayNode);
};

class PassContext : public NodeRef {
public:
PassContext() = default;
explicit PassContext(NodePtr<tvm::Node> p) : NodeRef(p) {}

const PassContextNode* operator->() const {
return static_cast<PassContextNode*>(this->node_.get());
}

using ContainerType = PassContextNode;
};

// We use currying here. It runs on a Relay node type NodeT and yields a new
// node with the same type. The Relay module is captured for optimizations as
// most of the current Relay optimizations are module to module. Currying
// sketches the optimization, i.e. how we want to mutate an AST, and it is
// passed as packed functions that will be invoked when called by `run`.
//
// For example, PassFunc<Function> indicates we perform a Function to Function
// transformation on the given Module.
template <typename NodeT,
typename = std::enable_if<(std::is_same<NodeT, Module>::value ||
std::is_same<NodeT, Function>::value)>>
using PassFunc =
runtime::TypedPackedFunc<runtime::TypedPackedFunc<NodeT(NodeT)>(
const Module& mod)>;

class Pass;

/*!
* \brief PassNode is the base type of differnt types of optimization passes.
* It is implemented by different pass subclasses at different granularity of
* Relay nodes.
*/
class PassNode : public RelayNode {
public:
/*! \brief The name of an optimization/analysis pass. */
std::string name;
/*! \brief The minimal optimization level that this pass will be enabled. */
int opt_level;

/*!
* \brief Set the context information for a pass.
*
* \param pass_ctx The context information for a certain pass.
*/
virtual void SetContext(const PassContext& pass_ctx) = 0;

/*!
* \brief Get the required passes for this pass as a vector of std::string.
*/
virtual std::vector<std::string> Required() const = 0;

/*!
* \brief Execute the optimization pass using a functor. This functor invokes
* the `run` method to perform a real optimization on a certain type
* of node.
*
* \param mod The module that an optimization pass runs on.
*
* \return The updated module.
*/
virtual Module operator()(const Module& mod) const = 0;

void VisitAttrs(tvm::AttrVisitor* v) override {
v->Visit("name", &name);
v->Visit("opt_level", &opt_level);
}

static constexpr const char* _type_key = "relay.Pass";
TVM_DECLARE_BASE_NODE_INFO(PassNode, RelayNode);
};

class Pass : public NodeRef {
public:
Pass() = default;
explicit Pass(NodePtr<tvm::Node> p) : NodeRef(p) {}

PassNode* operator->() const {
return static_cast<PassNode*>(this->node_.get());
}

using ContainerType = PassNode;
};

/*
* \brief Create a module pass.
*
* \param name The name of the module pass.
* \param opt_level The optimization level of the module pass.
* \param pass_func The curried packed function that contains the optimization.
*
* \return The created module pass.
*/
ModulePass CreateModulePass(const std::string& name, int opt_level,
const PassFunc<Module>& pass_func);

/*
* \brief Create a function pass.
*
* \param name The name of the function pass.
* \param opt_level The optimization level of the function pass.
* \param pass_func The curried packed function that contains the optimization.
*
* \return The created function pass.
*/
FunctionPass CreateFunctionPass(const std::string& name, int opt_level,
const PassFunc<Function>& pass_func);
/*
* \brief Create a sequential pass.
*
* \param name The name of the sequential pass.
* \param opt_level The optimization level of the sequential pass. It could be
* the highest opt_level of the list of passes.
* \param passes The optimization passes will be performed.
* \param disabled The disabled passes.
*
* \return The created sequential pass.
*/
SequentialPass CreateSequentialPass(const std::string& name, int opt_level,
const tvm::Array<Pass>& passes,
const tvm::Array<tvm::Expr>& disabled);

} // namespace pass

/*!
* \brief Infer the type of an expression.
*
Expand Down
Loading

0 comments on commit 018c2fb

Please sign in to comment.