[Relay] Replace compile engine with TE compiler in the VM #8501

Merged
merged 22 commits into from
Aug 9, 2021
Changes from 21 commits
b9c95fd
[VM] Add imports to new TE in VM compiler
mikepapadim Jul 14, 2021
451d838
[VM] Add comments to compile engine usages
mikepapadim Jul 14, 2021
11ff63b
Merge branch 'main' of https://github.com/apache/tvm into vm_te_migra…
mikepapadim Jul 15, 2021
07947cc
Merge branch 'main' of https://github.com/apache/tvm into vm_te_migra…
mikepapadim Jul 16, 2021
d53fd1e
[VM] Replace depreceated CachedFunc of compile_engine with TE_compiler
mikepapadim Jul 16, 2021
37f344e
[VM] rm compiler engine compiler.cc
mikepapadim Jul 16, 2021
0f428cf
[VM] Replace compile engine with TECompiler in memory allocator
mikepapadim Jul 16, 2021
88707fa
[VM] Add relay interface to te_compiler
mikepapadim Jul 16, 2021
ea022e1
Merge branch 'main' of https://github.com/apache/tvm into vm_te_migra…
mikepapadim Jul 19, 2021
a33e069
Merge branch 'main' of https://github.com/apache/tvm into vm_te_migra…
mikepapadim Jul 19, 2021
9fd6552
Merge branch 'vm_te_migration' of github.com:mikepapadim/tvm into vm_…
mikepapadim Jul 20, 2021
452815d
[Relay] Fix linting errors
mikepapadim Jul 20, 2021
c4d08d2
Merge branch 'main' of https://github.com/apache/tvm into vm_te_migra…
mikepapadim Jul 20, 2021
aed5d3b
Merge branch 'main' of https://github.com/apache/tvm into vm_te_migra…
mikepapadim Jul 21, 2021
a30cf6a
Move TEcompiler to VMCompilerContext; add global func into IRmodule w…
YuchenJin Jul 22, 2021
7dd9282
add back the check
YuchenJin Jul 22, 2021
3234a77
skip the check for ext func in tecompiler
YuchenJin Jul 26, 2021
9ca3ffe
skip tvm::build for external functions
YuchenJin Aug 2, 2021
0e823ce
trigger ci
YuchenJin Aug 3, 2021
c6f7d26
retrigger ci
YuchenJin Aug 4, 2021
03d1fdf
retrigger ci
YuchenJin Aug 5, 2021
2aa3114
remove the unnecessary loop in tecompiler
YuchenJin Aug 8, 2021
5 changes: 5 additions & 0 deletions src/relay/backend/build_module.cc
@@ -540,6 +540,11 @@ class RelayBuildModule : public runtime::ModuleNode {

  auto lowered_funcs = executor_codegen_->GetIRModule();

+ // No need to build for external functions.
+ if (lowered_funcs.find("ext_dev") != lowered_funcs.end()) {
+   lowered_funcs.Set("ext_dev", IRModule());
+ }
+
  // Generate a placeholder function that attaches linked params as its arguments.
  if (target_host->GetAttr<Bool>("link-params").value_or(Bool(false))) {
    CHECK(pf != nullptr) << "Unable to link-params with no target_host and no llvm codegen.";
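The guard added in this hunk can be illustrated with a small standalone sketch. It is not TVM code: `ToyModule`, `ToyLoweredFuncs`, `ClearExternalEntry`, `CountFunctionsToBuild`, and the function names `fused_add` and `dnnl_0` are hypothetical stand-ins, and a plain `std::map` is assumed in place of the map returned by `GetIRModule()`. It only shows the idea of keeping the `"ext_dev"` key while emptying its module, so that a later generic build step has nothing to do for external functions.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for an IRModule: just the names of the functions it holds.
using ToyModule = std::vector<std::string>;
// Hypothetical stand-in for the target -> module map produced by the codegen.
using ToyLoweredFuncs = std::map<std::string, ToyModule>;

// Mirrors the shape of the added guard: if an "ext_dev" entry exists,
// replace it with an empty module instead of erasing the key.
void ClearExternalEntry(ToyLoweredFuncs* lowered) {
  auto it = lowered->find("ext_dev");
  if (it != lowered->end()) {
    it->second = ToyModule();
  }
}

// Counts the functions that a generic build step would still see.
std::size_t CountFunctionsToBuild(const ToyLoweredFuncs& lowered) {
  std::size_t n = 0;
  for (const auto& kv : lowered) {
    n += kv.second.size();
  }
  return n;
}
```

In this toy model, a map holding one `"llvm"` function and one `"ext_dev"` function reports two functions before `ClearExternalEntry` and one afterwards, while the `"ext_dev"` key itself remains present.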
4 changes: 4 additions & 0 deletions src/relay/backend/te_compiler.cc
@@ -195,6 +195,7 @@ class TECompilerImpl : public TECompilerNode {
    auto target = Target("ext_dev");
    auto global_var = GlobalVar(func_name);
    global_var->checked_type_ = key->source_func->checked_type();
+   ir_module->Add(global_var, key->source_func);
    value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
    return value;
  }
@@ -349,6 +350,9 @@ class LowerTensorExpr : public ExprMutator {
  Map<GlobalVar, tir::PrimFunc> prim_fns;

  for (auto prim_fn : ext_func->funcs->functions) {
+   if (prim_fn.second->GetAttr<String>(attr::kCompiler).defined()) {
Member

Just checking: when is this untrue for an external function that has gone through Lower?

Contributor

Hi @Mousius, after TECompilerImpl::Lower() is called on an external function, the TECompiler just wraps it in a CachedFunc and returns; the external function is not lowered at that point. It is lowered later by the external codegen, via TECompilerNode::LowerExternalFunctions(). I added the line ir_module->Add(global_var, key->source_func); to LowerInternal, so this is untrue when some function in the funcs of the CachedFuncNode has not been lowered yet.

Member

Hi @YuchenJin, thanks for getting back to me. I asked because this loop is gated by the kCompiler attribute:

if (func->GetAttr<String>(attr::kCompiler).defined()) {

And then when it goes into LowerInternal it'll end up in this block (with your alterations):

if (key->source_func->GetAttr<String>(attr::kCompiler).defined()) {
  auto ir_module = IRModule();
  const auto name_node = key->source_func->GetAttr<String>(tvm::attr::kGlobalSymbol);
  ICHECK(name_node.defined()) << "External function has not been attached a name yet.";
  auto func_name = GetUniqueName(name_node.value(), &name_map_);
  auto target = Target("ext_dev");
  auto global_var = GlobalVar(func_name);
  global_var->checked_type_ = key->source_func->checked_type();
  ir_module->Add(global_var, key->source_func);
  value->cached_func = CachedFunc(target, global_var, {}, {}, te::Schedule(), {}, ir_module);
  return value;
}

Which means when we enter this loop, the only function in the IRModule should be a function taken from key->source_func which has the attr on it?

Contributor

Thanks @Mousius for elaborating! I think you are right: we can probably remove this loop, because for external functions the IRModule should not contain a PrimFunc anyway, so there is nothing to check. What do you think?

Member

Let's try removing the loop; if the tests still pass, we can be confident it wasn't necessary 😸

Contributor

Yeah, tests still pass after removing it. @jroesch, could you review this PR? :)

+     continue;
+   }
    CHECK(prim_fn.second.as<tir::PrimFuncNode>()) << "must be a prim fn";
    prim_fns.Set(prim_fn.first, Downcast<tir::PrimFunc>(prim_fn.second));
  }
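The review thread above describes a two-phase flow: Lower() only wraps an external function in a cached entry, and the actual lowering happens later in LowerExternalFunctions(). The following is a minimal standalone sketch of that flow, not TVM code. `ToyFunc`, `ToyCachedFunc`, `ToyCompiler`, and `IsLowered` are hypothetical, the `external` flag stands in for the kCompiler attribute, and the `"llvm"` target for non-external functions is an assumption made only for illustration.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for a Relay function.
struct ToyFunc {
  std::string name;
  bool external;  // models "carries the kCompiler attribute"
};

// Hypothetical stand-in for a cached lowering result.
struct ToyCachedFunc {
  std::string target;
  std::string symbol;
  bool lowered;
};

class ToyCompiler {
 public:
  // Models Lower(): an external function is only wrapped and cached,
  // a regular function is treated as lowered immediately.
  ToyCachedFunc Lower(const ToyFunc& f) {
    auto it = cache_.find(f.name);
    if (it != cache_.end()) return it->second;
    ToyCachedFunc cf{f.external ? "ext_dev" : "llvm", f.name, !f.external};
    if (f.external) pending_external_.push_back(f.name);
    cache_.emplace(f.name, cf);
    return cf;
  }

  // Models LowerExternalFunctions(): the deferred step that handles
  // every external function seen so far and returns their names.
  std::vector<std::string> LowerExternalFunctions() {
    std::vector<std::string> done;
    for (const auto& name : pending_external_) {
      cache_.at(name).lowered = true;
      done.push_back(name);
    }
    pending_external_.clear();
    return done;
  }

  bool IsLowered(const std::string& name) const { return cache_.at(name).lowered; }

 private:
  std::map<std::string, ToyCachedFunc> cache_;
  std::vector<std::string> pending_external_;
};
```

In this model, an external function is reported as not lowered right after `Lower()` and as lowered only once `LowerExternalFunctions()` has run, which is the ordering the thread relies on when reasoning about what the cached module can contain.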
14 changes: 5 additions & 9 deletions src/relay/backend/vm/compiler.cc
@@ -45,7 +45,6 @@
  #include <vector>

  #include "../../../target/source/codegen_source_base.h"
- #include "../../backend/compile_engine.h"
  #include "../../op/op_common.h"
  #include "../../transforms/pass_utils.h"
  #include "../utils.h"
@@ -79,6 +78,7 @@ namespace vm {
  using namespace tvm::runtime;
  using namespace tvm::runtime::vm;
  using namespace relay::transform;
+ using namespace tec;

  // (@jroesch): VM passes, eventually declare as passes.
  bool IsClosure(const Function& func);
@@ -253,7 +253,6 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
        ExprDeviceMap expr_device_map)
      : last_register_(0),
        registers_num_(0),
-       engine_(CompileEngine::Global()),
        context_(context),
        target_host_(target_host),
        expr_device_map_(std::move(expr_device_map)) {
@@ -465,7 +464,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
  void EmitShapeFunc(Function func, Array<Expr> inputs, Array<Expr> outputs) {
    // Lower shape function
    CCacheKey key(func, target_host_);
-   auto cfunc = engine_->LowerShapeFunc(key);
+   auto cfunc = context_->compiler->LowerShapeFunc(key);
    int op_index = -1;
    // pick the only function inside the context
    ICHECK_EQ(cfunc->funcs->functions.size(), 1);
@@ -551,7 +550,7 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {

  CCacheKey key(func, target);
  auto mangle_fn = [](String name) { return name; };
- auto cfunc = engine_->Lower(key, mangle_fn);
+ auto cfunc = context_->compiler->Lower(key, mangle_fn);

  auto op_index = -1;
  if (func->GetAttr<String>(attr::kCompiler).defined()) {
@@ -857,8 +856,6 @@ class VMFunctionCompiler : ExprFunctor<void(const Expr& expr)> {
  size_t last_register_;
  /*! \brief Total number of virtual registers allocated. */
  size_t registers_num_;
- /*! \brief Compiler engine to lower primitive functions. */
- CompileEngine engine_;
  /*! \brief Global shared meta data */
  VMCompilerContext* context_;
  /*! \brief Target devices. */
@@ -1184,8 +1181,8 @@ void VMCompiler::Codegen() {
    }
  }

- auto compile_engine = CompileEngine::Global();
- auto ext_mods = compile_engine->LowerExternalFunctions();
+ auto ext_mods = context_.compiler->LowerExternalFunctions();
+
  runtime::Module lib;
  if (funcs.size() > 0) {
    lib = tvm::build(funcs, target_host_);
@@ -1196,7 +1193,6 @@ void VMCompiler::Codegen() {
    }
    lib = codegen::CreateMetadataModule(params_, lib, ext_mods, target_host_, runtime::Metadata());
    exec_->SetLib(lib);
-   CompileEngine::Global()->Clear();
  }

  ExprDeviceMap VMCompiler::AnalyzeContext() const {
7 changes: 5 additions & 2 deletions src/relay/backend/vm/compiler.h
@@ -43,8 +43,9 @@

  #include "../../../runtime/vm/naive_allocator.h"
  #include "../../../runtime/vm/profiler/vm.h"
- #include "../../backend/compile_engine.h"
  #include "../../transforms/pass_utils.h"
+ #include "../te_compiler.h"
+ #include "../te_compiler_cache.h"

  namespace tvm {
  namespace relay {
@@ -75,12 +76,14 @@ struct VMCompilerContext {
  TagMap tag_map;
  // Map from global var to a unique integer
  GlobalMap global_map;
+ // TEcompiler for lowering
+ tec::TECompiler compiler;
  // List of constants
  std::vector<NDArray> constants;
  // Device type for constants
  std::vector<Index> const_device_type;
  // List of cached functions
- std::vector<CachedFunc> cached_funcs;
+ std::vector<tec::CachedFunc> cached_funcs;
  // The functions that have been lowered.
  std::unordered_map<tir::PrimFunc, size_t, ObjectPtrHash, ObjectPtrEqual> seen_funcs;
  };
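The compiler.cc and compiler.h hunks replace calls on a process-wide `CompileEngine::Global()` with a compiler held in `VMCompilerContext`, and drop the explicit `CompileEngine::Global()->Clear()` at the end of `Codegen()`. The sketch below is one possible way to picture the difference, not TVM code: `ToyLoweringCache`, `ToyContext`, `CompileWithGlobal`, and `CompileWithContext` are hypothetical, and the cache uses plain value semantics, whereas the diff accesses the real member with `->`, which this sketch does not model.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>

// Hypothetical lowering cache: maps a function key to a small integer id.
class ToyLoweringCache {
 public:
  int Lower(const std::string& key) {
    auto it = entries_.find(key);
    if (it != entries_.end()) return it->second;
    int id = static_cast<int>(entries_.size());
    entries_.emplace(key, id);
    return id;
  }
  std::size_t size() const { return entries_.size(); }
  void Clear() { entries_.clear(); }

  // Process-wide instance, modelling the old global-singleton style.
  static ToyLoweringCache& Global() {
    static ToyLoweringCache instance;
    return instance;
  }

 private:
  std::map<std::string, int> entries_;
};

// Hypothetical compilation context that owns its own cache.
struct ToyContext {
  ToyLoweringCache compiler;
};

// Global style: entries accumulate across calls until Clear() is invoked.
std::size_t CompileWithGlobal(const std::string& fn) {
  ToyLoweringCache::Global().Lower(fn);
  return ToyLoweringCache::Global().size();
}

// Context style: the cache lives and dies with the context.
std::size_t CompileWithContext(const std::string& fn) {
  ToyContext ctx;
  ctx.compiler.Lower(fn);
  return ctx.compiler.size();
}
```

In this toy model the global cache grows from one compilation to the next and has to be cleared by hand, while a context-owned cache starts empty for each context. Whether that is the motivation behind the change is not stated in the page; the sketch only illustrates the general trade-off.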
10 changes: 7 additions & 3 deletions src/relay/transforms/memory_alloc.cc
@@ -41,14 +41,16 @@
  #include <unordered_set>
  #include <vector>

- #include "../backend/compile_engine.h"
+ #include "../backend/te_compiler.h"
+ #include "../backend/te_compiler_cache.h"
  #include "../op/memory/memory.h"
  #include "../op/vm/vm.h"
  #include "./pass_utils.h"
  #include "let_list.h"
  #include "pattern_utils.h"

  using namespace tvm::runtime;
+ using namespace tvm::relay::tec;

  namespace tvm {
  namespace relay {
@@ -271,9 +273,11 @@ class DialectRewriter : public ExprMutator {
  Array<Expr> EmitShapeFunc(LetList* scope, const Function& func,
                            const std::vector<Expr>& new_args) {
    Array<Expr> shape_func_ins;
-   auto engine = CompileEngine::Global();
+
+   TECompiler compiler;
+
    CCacheKey key(func, target_host_);
-   auto cfunc = engine->LowerShapeFunc(key);
+   auto cfunc = compiler->LowerShapeFunc(key);
    auto input_states = cfunc->shape_func_param_states;

    Array<Integer> is_inputs;