Skip to content

Commit

Permalink
Clean up LowerTEPass and pass IRModule Attrs through passes (apache#8914
Browse files Browse the repository at this point in the history
)
  • Loading branch information
electriclilies authored and ylc committed Jan 13, 2022
1 parent 01aa50a commit 2b3ec0d
Show file tree
Hide file tree
Showing 12 changed files with 77 additions and 85 deletions.
20 changes: 18 additions & 2 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class IRModuleNode : public Object {
v->Visit("global_var_map_", &global_var_map_);
v->Visit("global_type_var_map_", &global_type_var_map_);
v->Visit("source_map", &source_map);
v->Visit("attrs", &attrs);
}

TVM_DLL bool SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const;
Expand Down Expand Up @@ -277,6 +278,12 @@ class IRModuleNode : public Object {
*/
TVM_DLL void Update(const IRModule& other);

/*!
* \brief Create a shallow copy of this IRModule.
* \returns The shallow copy of the IRModule.
*/
TVM_DLL IRModule ShallowCopy();

/*!
* \brief Import Relay code from the file at path.
* \param path The path of the Relay code to import.
Expand Down Expand Up @@ -348,12 +355,14 @@ class IRModule : public ObjectRef {
* \brief constructor
* \param functions Functions in the module.
* \param type_definitions Type definitions in the module.
* \param import_set Set of imported files in the module
* \param import_set Set of imported files in the module.
* \param map The module source map.
* \param attrs The module attributes.
*/
TVM_DLL explicit IRModule(Map<GlobalVar, BaseFunc> functions,
Map<GlobalTypeVar, TypeData> type_definitions = {},
std::unordered_set<String> import_set = {}, parser::SourceMap map = {});
std::unordered_set<String> import_set = {}, parser::SourceMap map = {},
DictAttrs attrs = {});

/*! \brief default constructor */
IRModule() : IRModule(Map<GlobalVar, BaseFunc>({})) {}
Expand Down Expand Up @@ -415,6 +424,13 @@ class IRModule : public ObjectRef {
*/
TVM_DLL static IRModule FromText(const String& text, const String& source_path);

/*!
* \brief Create a shallow copy of an IRModule.
* \param mod The module to copy.
* \return The copied module.
*/
IRModule ShallowCopyIRModule(IRModule mod);

/*! \brief Declare the container type. */
using ContainerType = IRModuleNode;

Expand Down
11 changes: 10 additions & 1 deletion src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@ namespace tvm {

IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
tvm::Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<String> import_set, parser::SourceMap source_map) {
std::unordered_set<String> import_set, parser::SourceMap source_map,
DictAttrs attrs) {
auto n = make_object<IRModuleNode>();
n->functions = std::move(functions);
n->type_definitions = std::move(type_definitions);
Expand All @@ -52,6 +53,7 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,
n->constructor_tag_map_ = {};
n->import_set_ = std::move(import_set);
n->source_map = source_map;
n->attrs = std::move(attrs);

for (const auto& kv : n->functions) {
// set global var map
Expand All @@ -72,6 +74,7 @@ IRModule::IRModule(tvm::Map<GlobalVar, BaseFunc> functions,

bool IRModuleNode::SEqualReduce(const IRModuleNode* other, SEqualReducer equal) const {
if (functions.size() != other->functions.size()) return false;
if (!equal(this->attrs, other->attrs)) return false;
for (const auto& kv : this->functions) {
if (!other->ContainGlobalVar(kv.first->name_hint)) return false;
if (!equal(kv.second, other->Lookup(kv.first->name_hint))) return false;
Expand Down Expand Up @@ -112,6 +115,7 @@ void IRModuleNode::SHashReduce(SHashReducer hash_reduce) const {
temp.emplace_back(kv.first->name_hint, kv.second);
}
reduce_temp();
hash_reduce(this->attrs);
}

bool IRModuleNode::ContainGlobalVar(const String& name) const {
Expand Down Expand Up @@ -361,6 +365,11 @@ void IRModuleNode::Update(const IRModule& mod) {
}
}

IRModule IRModuleNode::ShallowCopy() {
return IRModule(this->functions, this->type_definitions, this->Imports(), this->source_map,
this->attrs);
}

std::pair<IRModule, GlobalVar> IRModule::FromExprInContext(
const RelayExpr& expr, const tvm::Map<GlobalVar, BaseFunc>& global_funcs,
const tvm::Map<GlobalTypeVar, TypeData>& type_definitions,
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {

Optional<Array<tvm::runtime::Module>> external_modules =
lowered_mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
ICHECK(external_modules) << "Attribute \"external_modules\" should be set at this point.";
ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point.";

// This is the point where we separate the functions in the module by target
ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
Expand Down
13 changes: 5 additions & 8 deletions src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -241,26 +241,23 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
ICHECK(main_func_info) << "The attribute \"main_func_info\" should be set at this point.";
function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());

// Get only the Relay functions out of the lowered module so we can run type inference on them
IRModule main_module = tec::GetMainModule(lowered_mod);
main_module = relay::transform::InferType()(main_module);
relay::Function main_func = Downcast<relay::Function>(main_module->Lookup("main"));
Function lowered_main_func = Downcast<Function>(lowered_mod->Lookup("main"));

// Now that we have lowered all operators to TIR code, we can proceed with compilation.
//
// We need to unfortunately re-plan as the previous results have been invalidated by lowering
// we will fix this in future refactors.
memory_plan_ = GraphPlanMemory(main_func);
memory_plan_ = GraphPlanMemory(lowered_main_func);

// The graph planner also can not handle planning calls to global variables to we must remap

// First we convert all the parameters into input nodes.
for (auto param : main_func->params) {
for (auto param : lowered_main_func->params) {
auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs());
var_map_[param.get()] = AddNode(node_ptr, param);
}

heads_ = VisitExpr(main_func->body);
heads_ = VisitExpr(lowered_main_func->body);
std::ostringstream os;

dmlc::JSONWriter writer(&os);
Expand All @@ -277,7 +274,7 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<

Optional<Array<tvm::runtime::Module>> external_modules =
lowered_mod->GetAttr<Array<tvm::runtime::Module>>("external_mods");
ICHECK(external_modules) << "Attribute \"external_modules\" should be set at this point.";
ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point.";

// This is the point where we separate the functions in the module by target
ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
Expand Down
52 changes: 20 additions & 32 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -292,14 +292,8 @@ InterpreterState::InterpreterState(Expr current_expr, InterpreterState::Stack st
class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
public:
// TODO(mbs, electriclilies): Collapse mod and per_target_module once IRModule subsumes
// LoweredModule.
Interpreter(IRModule mod, Map<Target, IRModule> per_target_module, Device device, Target target)
: mod_(mod),
per_target_module_(per_target_module),
device_(device),
target_(target),
debug_op_(Op::Get("debug")) {}
Interpreter(IRModule unified_mod, Device device, Target target)
: unified_mod_(unified_mod), device_(device), target_(target), debug_op_(Op::Get("debug")) {}

template <typename T>
T WithFrame(const Frame& fr, const std::function<T()>& f) {
Expand All @@ -316,7 +310,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
ObjectRef VisitExpr_(const VarNode* var_node) final { return Lookup(GetRef<Var>(var_node)); }

ObjectRef VisitExpr_(const GlobalVarNode* op) final {
return Eval(mod_->Lookup(GetRef<GlobalVar>(op)));
return Eval(unified_mod_->Lookup(GetRef<GlobalVar>(op)));
}

ObjectRef VisitExpr_(const OpNode* id) override {
Expand Down Expand Up @@ -387,9 +381,9 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,

// Project out just the function(s) we need.
IRModule lowered_projected_mod;
Map<Target, IRModule> per_target_module = tec::GetPerTargetModules(unified_mod_);
std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
per_target_module_std_map =
backend::TargetModuleMapToTargetStrModuleMap(per_target_module_);
per_target_module_std_map = backend::TargetModuleMapToTargetStrModuleMap(per_target_module);
auto mod_itr = per_target_module_std_map.find(target);
ICHECK(mod_itr != per_target_module_std_map.end())
<< "No target module for target '" << target->str() << "'";
Expand Down Expand Up @@ -876,13 +870,11 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
}

private:
// Main module. All expressions are eval'ed w.r.t. the definitions in this module. This module
// may contain calls to TIR functions bound in a per_target_module_ below.
IRModule mod_;
// Map from target key to lowered TIR functions derived from mod_.
// Note that primitives are implicitly executed on target_, while shape functions are implicitly
// executed on the default 'cpu' host. Thus this map has at most two entries.
Map<Target, IRModule> per_target_module_;
// Unified module. Functions are annotated with their target.
// All expressions are eval'ed w.r.t. the definitions in this module.
// This module contains functions that used to be in main_module and the per_target_module (TIR
// functions) in one module.
IRModule unified_mod_;
// Cached packed functions for the primitives and shape functions, keyed by target and
// global var name.
std::unordered_map<std::pair<Target, std::string>, PackedFunc, PairHash> compiled_packed_funcs_;
Expand All @@ -902,7 +894,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
* rewritten \p mod and target-specific modules containing bindings for all TIR primitive
* functions needed by the rewritten module.
*/
std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device, Target target) {
IRModule Prepare(IRModule mod, Device device, Target target) {
// Things to initialize to pass into tec::LowerTEPass
// We only have one device-specific target.
tec::TargetMap targets = {{device.device_type, target}};
Expand Down Expand Up @@ -930,8 +922,7 @@ std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device,
With<transform::PassContext> ctx(pass_ctx);
mod = seq(mod);

// Lower all primitive functions reachable from expr.
return {tec::GetMainModule(mod), tec::GetPerTargetModules(mod)};
return mod;
}

/*! \brief Check if an expression could be changed by \p Prepare.
Expand Down Expand Up @@ -1020,11 +1011,9 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De
// and can just eval it directly.
expr_to_eval = expr;
}
std::pair<IRModule, Map<Target, IRModule>> main_and_lowered =
Prepare(mod_with_expr, device, target);
std::shared_ptr<Interpreter> intrp = std::make_shared<Interpreter>(
/*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device,
target);
IRModule lowered_mod = Prepare(mod_with_expr, device, target);

std::shared_ptr<Interpreter> intrp = std::make_shared<Interpreter>(lowered_mod, device, target);

//
// Step 2: Evaluate target function to a closure.
Expand Down Expand Up @@ -1063,12 +1052,11 @@ ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
std::unordered_set<String> import_set, Device device, Target target) {
std::pair<IRModule, GlobalVar> mod_and_global =
IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set);
std::pair<IRModule, Map<Target, IRModule>> main_and_lowered =
Prepare(mod_and_global.first, device, target);
Interpreter intrp(
/*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device,
target);
Expr expr_to_eval = main_and_lowered.first->GetGlobalVar(mod_and_global.second->name_hint);

IRModule mod = Prepare(mod_and_global.first, device, target);

Interpreter intrp(mod, device, target);
Expr expr_to_eval = mod->GetGlobalVar(mod_and_global.second->name_hint);
if (expr.as<BaseFuncNode>() == nullptr) {
// TODO(mbs): IRModule::FromExpr will implicitly close over the free vars of expr
// unless it is a function, so we must reverse that in the expression to eval.
Expand Down
27 changes: 5 additions & 22 deletions src/relay/backend/te_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -900,8 +900,9 @@ Map<Target, IRModule> GetPerTargetModules(IRModule mod) {

// Put the function in per_target_modules
if (!per_target_modules.count(target.value())) {
// Initialize the IRModule for this target and add the function
IRModule target_module;
// Initialize the IRModule for this target with the attributes from the input IRModule
IRModule target_module = IRModule({}, {}, {}, {}, mod->attrs);
// Add the function to the IRModule
target_module->Add(var, func);
per_target_modules[target.value()] = target_module;
} else {
Expand All @@ -918,33 +919,15 @@ Map<Target, IRModule> GetPerTargetModules(IRModule mod) {
return per_target_modules;
}

IRModule GetMainModule(IRModule mod) {
IRModule main_module;
// Copy the type defs
for (const auto& kv : mod->type_definitions) {
main_module->AddTypeDef(kv.first, kv.second);
}
// Copy all Relay functions (we don't include PrimFuncs)
for (auto kv : mod->functions) {
const GlobalVar& var = kv.first;
const BaseFunc& func = kv.second;
if (func->IsInstance<tvm::relay::FunctionNode>()) {
main_module->Add(var, func);
}
}
return main_module;
}

Pass LowerTEPass(TargetMap targets, DeviceMap device_context_map,
backend::StaticMemoryPlan memory_plan, const String& module_name,
std::function<void(Function)> process_fn) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule module,
PassContext ctx) {
return LowerTE(module, targets, device_context_map, memory_plan, module_name, process_fn);
};
// TODO(@electriclilies, mbs): Fold InferType() pass into LowerTEPass since it will always need to
// be called afterwards
return tvm::transform::CreateModulePass(pass_func, 1, "LowerTE", {});
return tvm::transform::Sequential(
{tvm::transform::CreateModulePass(pass_func, 0, "LowerTE", {}), InferType()});
}
} // namespace tec
} // namespace relay
Expand Down
7 changes: 0 additions & 7 deletions src/relay/backend/te_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,6 @@ Target GetTargetFromInteger(DLDeviceType dev_type, TargetMap targets);
*/
Map<Target, IRModule> GetPerTargetModules(IRModule mod);

/*!
* \brief Utility to extract all the Relay functions from an IRModule, with no PrimFuncs.
* \param mod The IRModule to extract the Relay functions from
* \return An IRModule containing only the Relay functions that are in the input mod (no PrimFuncs)
*/
IRModule GetMainModule(IRModule mod);

/*! \brief Lower an IRModule's primitive functions to TIR.
*
* This is the "back half" of the Relay compiler which lowers "primitive functions"
Expand Down
4 changes: 1 addition & 3 deletions src/relay/ir/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx)
DLOG(INFO) << "Executing function pass : " << pass_info->name
<< " with opt level: " << pass_info->opt_level;

// Execute the pass function and return a new module.
IRModule updated_mod =
IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map);
IRModule updated_mod = mod->ShallowCopy();

std::vector<std::pair<GlobalVar, Function> > updates;
for (const auto& it : updated_mod->functions) {
Expand Down
5 changes: 4 additions & 1 deletion src/relay/transforms/partition_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
*/

#include <tvm/ir/error.h>
#include <tvm/ir/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/expr.h>
Expand Down Expand Up @@ -509,7 +510,9 @@ class NameMangleExtFuncs : public MixedModeMutator {

// Walk the tree and mangle the functions. Then replace compiler functions
// with mangled functions in the module
IRModule new_module = IRModule({}, module_->type_definitions, module_->Imports());
IRModule new_module = module_->ShallowCopy();
new_module->functions = {};

for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
Expand Down
2 changes: 1 addition & 1 deletion src/relay/transforms/to_basic_block_normal_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ IRModule ToBasicBlockNormalForm(const IRModule& mod) {
DLOG(INFO) << "ToBBlock:" << std::endl << mod;

// Create a new module by shallow copy.
auto mod_ = IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map);
IRModule mod_ = mod->ShallowCopy();

tvm::Map<GlobalVar, Function> updates;
auto funcs = mod_->functions;
Expand Down
17 changes: 10 additions & 7 deletions src/relay/transforms/type_infer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,17 @@ class TypeInferencer : private ExprFunctor<Type(const Expr&)>,
this->EmitFatal(Diagnostic::Error(op->span) << "Cannot do type inference on global variables "
<< "without a module");
}

if (mod_->ContainGlobalVar(var->name_hint)) {
relay::Function e = Downcast<Function>(mod_->Lookup(var));
return e->checked_type();
} else {
return op->checked_type_;
BaseFunc func = mod_->Lookup(var->name_hint);

if (func->IsInstance<FunctionNode>()) {
relay::Function relay_func = Downcast<Function>(func);
return relay_func->checked_type();
}
}
// Return op->checked_type if the module doesn't contain the GlobalVar or the function is a
// PrimFunc (we don't typecheck PrimFuncs)
return op->checked_type_;
}

Type VisitExpr_(const ConstantNode* op) final { return op->tensor_type(); }
Expand Down Expand Up @@ -822,8 +826,7 @@ Pass InferType() {
[=](IRModule mod, const PassContext& pass_ctx) {
DLOG(INFO) << "tvm::relay::transform::InferType";
// Execute the pass function and return a new module.
IRModule updated_mod =
IRModule(mod->functions, mod->type_definitions, mod->Imports(), mod->source_map);
IRModule updated_mod = mod->ShallowCopy();

pass_ctx->diag_ctx = DiagnosticContext::Default(updated_mod);

Expand Down
2 changes: 2 additions & 0 deletions tests/python/relay/test_backend_compile_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ def get_func(shape):
engine.dump()


# Note: Once compile engine is removed, we should keep this test so that
# we make sure that opt_level=0 passes are being called correctly.
def test_compile_placeholder_bypass():
engine = relay.backend.compile_engine.get()
x = relay.var("x", shape=(2, 3))
Expand Down

0 comments on commit 2b3ec0d

Please sign in to comment.