Skip to content

Commit

Permalink
[Refactor] Unify the shared pass prefix between vm and graph (apache#…
Browse files Browse the repository at this point in the history
  • Loading branch information
YuchenJin authored and ylc committed Jan 13, 2022
1 parent 5882181 commit aa55095
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 102 deletions.
52 changes: 1 addition & 51 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -313,57 +313,7 @@ class RelayBuildModule : public runtime::ModuleNode {
relay_module_ptr->Update(main_glb_var, new_main);
}

Array<Pass> pass_seqs;
Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
pass_seqs.push_back(transform::ToBasicBlockNormalForm());

// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());

// Legalize pass is restricted to homogeneous execution for now.
if (targets.size() == 1) {
pass_seqs.push_back(transform::Legalize());
}

pass_seqs.push_back(transform::SimplifyInference());

// Convert Dynamic ops to static versions
pass_seqs.push_back(transform::DynamicToStatic());

PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
*rv = false;
if (expr.as<CallNode>()) {
auto call_node = expr.as<CallNode>();
auto op_node = call_node->op.as<OpNode>();
if (op_node->name == "cast") {
auto attrs = call_node->attrs.as<CastAttrs>();
if (attrs->dtype == DataType::Int(32)) {
*rv = true;
}
}
}
});
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::SimplifyExpr());
pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeCast());
pass_seqs.push_back(transform::CanonicalizeOps());

// Alter layout transformation is only applied to homogeneous execution yet.
if (targets.size() == 1) {
pass_seqs.push_back(transform::InferType());
pass_seqs.push_back(transform::AlterOpLayout());
}

// Fast math optimizations.
pass_seqs.push_back(transform::FastMath());
pass_seqs.push_back(transform::FoldConstant());
Array<Pass> pass_seqs = GetPassPrefix(targets, false);

if (targets.size() == 1) {
const auto& target = (*targets.begin()).second;
Expand Down
67 changes: 67 additions & 0 deletions src/relay/backend/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@

#include "utils.h"

#include <tvm/relay/qnn/transform.h>

namespace tvm {
namespace relay {
namespace backend {
Expand Down Expand Up @@ -120,6 +122,71 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
<< ",\n relay_primfuncs=" << node->relay_primfuncs << ")";
});

// Assemble the optimization pass prefix shared by the VM and graph backends.
//
// `targets` is the device-type-to-Target map; passes restricted to homogeneous
// execution are only appended when it holds exactly one entry. `is_vm` selects
// the VM-specific variations (EtaExpand instead of DynamicToStatic, an extra
// InlinePrimitives, and no InferType before AlterOpLayout).
Array<Pass> GetPassPrefix(const Map<tvm::Integer, tvm::Target>& targets, bool is_vm) {
  const bool is_homogeneous = targets.size() == 1;

  Array<Pass> pass_seqs;
  Array<runtime::String> entry_functions{"main"};
  pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
  pass_seqs.push_back(transform::ToBasicBlockNormalForm());
  // Run all dialect legalization passes.
  pass_seqs.push_back(relay::qnn::transform::Legalize());

  // Legalize pass is restricted to homogeneous execution for now.
  if (is_homogeneous) {
    pass_seqs.push_back(transform::Legalize());
  }

  pass_seqs.push_back(transform::SimplifyInference());

  if (is_vm) {
    // eta expand to support constructors in argument position
    pass_seqs.push_back(transform::EtaExpand(
        /* expand_constructor */ true, /* expand_global_var */ false));
  } else {
    // Convert Dynamic ops to static versions
    pass_seqs.push_back(transform::DynamicToStatic());
  }

  // Predicate handed to EliminateCommonSubexpr: yields true only for a call to
  // the "cast" operator whose target dtype is int32, false for everything else.
  PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
    Expr expr = args[0];
    // Set the default first. Assigning it after the checks would overwrite a
    // positive result and make the predicate always return false.
    *rv = false;
    if (const auto* call_node = expr.as<CallNode>()) {
      // The callee is not necessarily an operator; as<OpNode>() is null then.
      const auto* op_node = call_node->op.as<OpNode>();
      if (op_node != nullptr && op_node->name == "cast") {
        const auto* attrs = call_node->attrs.as<CastAttrs>();
        if (attrs != nullptr && attrs->dtype == DataType::Int(32)) {
          *rv = true;
        }
      }
    }
  });
  pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
  pass_seqs.push_back(transform::SimplifyExpr());
  if (is_vm) {
    pass_seqs.push_back(transform::InlinePrimitives());
  }
  pass_seqs.push_back(transform::CombineParallelConv2D(3));
  pass_seqs.push_back(transform::CombineParallelDense(3));
  pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
  pass_seqs.push_back(transform::FoldConstant());
  pass_seqs.push_back(transform::FoldScaleAxis());
  pass_seqs.push_back(transform::CanonicalizeCast());
  pass_seqs.push_back(transform::CanonicalizeOps());

  // Alter layout transformation is only applied to homogeneous execution yet.
  if (is_homogeneous) {
    if (!is_vm) {
      pass_seqs.push_back(transform::InferType());
    }
    pass_seqs.push_back(transform::AlterOpLayout());
  }

  // Fast math optimizations.
  pass_seqs.push_back(transform::FastMath());
  pass_seqs.push_back(transform::FoldConstant());
  return pass_seqs;
}

} // namespace backend
} // namespace relay
} // namespace tvm
17 changes: 17 additions & 0 deletions src/relay/backend/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@

namespace tvm {
namespace relay {
namespace transform {
// Forward declaration only; the pass is defined elsewhere. It is declared here
// so that backend::GetPassPrefix can append it when building the VM pipeline.
Pass InlinePrimitives();
}  // namespace transform

namespace backend {
using Pass = tvm::transform::Pass;

/*!
* \brief The static storage information produced by memory planning.
Expand Down Expand Up @@ -410,6 +415,18 @@ inline bool IsCompileEngineCacheDisabled() {
.value();
}

/*!
 * \brief Get the leading sequence of Relay optimization passes for a backend.
 *
 * The vm and graph backends run nearly the same initial passes, differing only in a few places.
 * This function assembles that shared pass prefix in one place and applies the backend-specific
 * differences selected by \p is_vm.
 *
 * \param targets The device type to `Target` mapping. Passes restricted to homogeneous execution
 * are only included when this map holds exactly one entry.
 * \param is_vm A boolean indicating if the passes are used for vm (true) or graph runtime (false).
 * \return An array of passes.
 */
Array<Pass> GetPassPrefix(const Map<tvm::Integer, tvm::Target>& targets, bool is_vm);

} // namespace backend
} // namespace relay
} // namespace tvm
Expand Down
52 changes: 1 addition & 51 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1042,57 +1042,7 @@ IRModule VMCompiler::OptimizeModule(IRModule mod, const TargetsMap& targets_arg,
mod->Add(gvar, f);
}

Array<Pass> pass_seqs;
Array<runtime::String> entry_functions{"main"};
pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
pass_seqs.push_back(transform::ToBasicBlockNormalForm());
// Run all dialect legalization passes.
pass_seqs.push_back(relay::qnn::transform::Legalize());

// Legalize pass is restricted to homogeneous execution for now.
if (targets.size() == 1) {
pass_seqs.push_back(transform::Legalize());
}

// eta expand to support constructors in argument position
pass_seqs.push_back(transform::EtaExpand(
/* expand_constructor */ true, /* expand_global_var */ false));

pass_seqs.push_back(transform::SimplifyInference());
PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
Expr expr = args[0];
if (expr.as<CallNode>()) {
auto call_node = expr.as<CallNode>();
auto op_node = call_node->op.as<OpNode>();
if (op_node->name == "cast") {
auto attrs = call_node->attrs.as<CastAttrs>();
if (attrs->dtype == DataType::Int(32)) {
*rv = true;
}
}
}
*rv = false;
});
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
pass_seqs.push_back(transform::SimplifyExpr());
pass_seqs.push_back(transform::InlinePrimitives());

pass_seqs.push_back(transform::CombineParallelConv2D(3));
pass_seqs.push_back(transform::CombineParallelDense(3));
pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
pass_seqs.push_back(transform::FoldConstant());
pass_seqs.push_back(transform::FoldScaleAxis());
pass_seqs.push_back(transform::CanonicalizeCast());
pass_seqs.push_back(transform::CanonicalizeOps());

// Alter layout transformation is only applied to homogeneous execution yet.
if (targets.size() == 1) {
pass_seqs.push_back(transform::AlterOpLayout());
}

// Fast math optimizations.
pass_seqs.push_back(transform::FastMath());
pass_seqs.push_back(transform::FoldConstant());
Array<Pass> pass_seqs = relay::backend::GetPassPrefix(targets, true);

if (targets_.size() > 1) {
// Handle heterogeneous compilation.
Expand Down

0 comments on commit aa55095

Please sign in to comment.