Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BYOC] RelayToTIR custom codegen passes can still depend on dynamic shape functions #11619

Merged
merged 1 commit into from
Jun 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1064,9 +1064,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {

mod = transform::ToANormalForm()(mod);

IRModule lowered_mod = tec::LowerTEPass(
mod_name,
[this, workspace_byte_alignment](BaseFunc func) {
IRModule lowered_mod =
tec::LowerTE(mod_name, config_, [this, workspace_byte_alignment](BaseFunc func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
Expand All @@ -1078,8 +1077,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// execute as a further pass, instead writing data to the
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment);
},
config_)(mod);
})(mod);

auto lowered_main = lowered_mod->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
Expand Down
27 changes: 12 additions & 15 deletions src/relay/backend/graph_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,22 +217,19 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<
mod = WithAttr(mod, "main_func_info", func_info);
}

IRModule lowered_mod = tec::LowerTEPass(
mod_name_,
[this](BaseFunc func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
if (func->GetAttr<String>(attr::kCompiler).defined()) {
UpdateConstants(func, &params_);
}
IRModule lowered_mod = tec::LowerTE(mod_name_, config_, [this](BaseFunc func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
if (func->GetAttr<String>(attr::kCompiler).defined()) {
UpdateConstants(func, &params_);
}

// TODO(@areusch, @jroesch): We should refactor this to
// execute as a further pass, instead writing data to the
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_);
},
config_)(mod);
// TODO(@areusch, @jroesch): We should refactor this to
// execute as a further pass, instead writing data to the
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_);
})(mod);

Optional<backend::FunctionInfo> main_func_info =
lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");
Expand Down
3 changes: 1 addition & 2 deletions src/relay/backend/interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -960,8 +960,7 @@ IRModule Prepare(IRModule mod, const CompilationConfig& config) {
// eta expand to support constructors in argument position.
transform::EtaExpand(
/*expand_constructor=*/true, /*expand_global_var=*/false),
transform::InferType(),
tec::LowerTEPass(/*module_name=*/"intrp", [](BaseFunc func) { /* no-op */ }, config)});
transform::InferType(), tec::LowerTE(/*module_name=*/"intrp", config)});

transform::PassContext pass_ctx = transform::PassContext::Current();
With<transform::PassContext> ctx(pass_ctx);
Expand Down
329 changes: 221 additions & 108 deletions src/relay/backend/te_compiler.cc

Large diffs are not rendered by default.

32 changes: 9 additions & 23 deletions src/relay/backend/te_compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
*/

/*!
* \file relay/backend/tir_compiler.h
* * \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns.
* \file relay/backend/te_compiler.h
* \brief Internal compilation layer which lowers Relay "primitive functions" to TIR PrimFns.
*
*
* This represents the new design of the Relay compilation flow and will replace the interface
Expand Down Expand Up @@ -173,36 +173,22 @@ backend::FunctionInfo UpdateMainWorkspaceSize(const IRModule& mod, const Compila
*/
Map<Target, IRModule> GetPerTargetModules(IRModule mod);

/*! \brief Lower an IRModule's primitive functions to TIR.
*
* This is the "back half" of the Relay compiler which lowers "primitive functions"
* to TE expressions, schedules them, and then to TIR.
*
* \param module The IRModule.
* \param memory_plan The memory plan used during lowering
* \param module_name The name of this module
* \param process_fn Callback allowing one-level up code generators to process
* each function that we lower
* \return The lowered module, see above.
*/
IRModule LowerTE(
const IRModule& module, backend::StaticMemoryPlan memory_plan, const String& module_name,
ProcessFn process_fn = [](BaseFunc f) {});
/*! \brief Process function which ignores its argument and does nothing. */
inline void DefaultProcessFn(BaseFunc) {}

/*!
* \brief Pass to lower an IRModule's primitive functions to TIR.
*
* This is the "back half" of the Relay compiler which lowers "primitive functions"
* to TE expressions, schedules them, and then to TIR. It annotates all functions
* with their target.
* to TE expressions, schedules them, and emits PrimFuncs.
*
* \param module_name The name of this module
* \param process_fn Callback allowing one-level up code generators to process
* each function that we lower
* \param module_name The name of this module, used as a prefix for generated globals.
* \param config All available targets.
* \param process_fn Callback allowing one-level up code generators to process
* each function that we lower (default is no-op).
* \returns The pass which lowers primitive functions to TIR
*/
transform::Pass LowerTEPass(String module_name, ProcessFn process_fn, CompilationConfig config);
transform::Pass LowerTE(String module_name, CompilationConfig config,
ProcessFn process_fn = DefaultProcessFn);

} // namespace tec
} // namespace relay
Expand Down
24 changes: 10 additions & 14 deletions src/relay/backend/vm/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1039,13 +1039,11 @@ transform::Sequential VMCompiler::FuseAndLowerOperators(const CompilationConfig&
// Give each "primitive" Function a hash.
pass_seqs.push_back(LabelOps());
// Lower "primitive" Functions to PrimFuncs and rewrite calls.
pass_seqs.push_back(tec::LowerTEPass(/*module_name=*/"vm_mod",
[this](const BaseFunc& func) {
if (func->GetAttr<String>(attr::kCompiler).defined()) {
backend::UpdateConstants(func, &params_);
}
},
config));
pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config, [this](const BaseFunc& func) {
if (func->GetAttr<String>(attr::kCompiler).defined()) {
backend::UpdateConstants(func, &params_);
}
}));
// Since lowered functions are bound in the IRModule, we can now eliminate any unused
// let-bound functions.
pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));
Expand Down Expand Up @@ -1090,13 +1088,11 @@ IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
pass_seqs.push_back(transform::LabelOps());

// Lower all functions annotated as "primitive" by FuseOps.
pass_seqs.push_back(tec::LowerTEPass(/*module_name=*/"vm_mod",
[this](const BaseFunc& func) {
if (func->GetAttr<String>(attr::kCompiler).defined()) {
backend::UpdateConstants(func, &params_);
}
},
config_));
pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config_, [this](const BaseFunc& func) {
if (func->GetAttr<String>(attr::kCompiler).defined()) {
backend::UpdateConstants(func, &params_);
}
}));

// Since lowered functions are bound in the IRModule, we can now eliminate any unused
// let-bound functions.
Expand Down
51 changes: 0 additions & 51 deletions src/relay/transforms/compiler_function_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,42 +81,6 @@ class Outliner : public MixedModeMutator {
IRModule mod_;
};

/*!
 * \brief Mutator which redirects calls to global "Compiler" functions through the
 * 'call_lowered' calling convention.
 *
 * A call is rewritten only when its callee is a global variable bound in the module to a
 * Relay function whose "Compiler" attribute is accepted by the filter. In every other case
 * \p post is returned unchanged.
 */
class CallRewriter : public MixedModeMutator {
 public:
  CallRewriter(std::string compiler_filter, IRModule mod)
      : compiler_filter_(std::move(compiler_filter)), mod_(std::move(mod)) {}

  Expr Rewrite_(const CallNode* pre, const Expr& post) final {
    Call call = Downcast<Call>(post);

    // Only calls whose operator is a global variable are candidates.
    const auto* callee_var = call->op.as<GlobalVarNode>();
    if (callee_var == nullptr) {
      return post;
    }

    // The global must be bound to a Relay function in the module.
    const auto* callee = mod_->Lookup(GetRef<GlobalVar>(callee_var)).as<FunctionNode>();
    if (callee == nullptr) {
      return post;
    }

    // The callee must carry a "Compiler" attribute, and it must match the filter (if any).
    Optional<String> compiler = callee->GetAttr<String>(attr::kCompiler);
    if (!compiler.defined()) {
      return post;
    }
    if (!compiler_filter_.empty() && !(compiler.value() == compiler_filter_)) {
      return post;
    }

    // Redirect the call to the global named by the callee's "global_symbol" attribute,
    // carrying the original call attributes along under the "relay_attrs" metadata key.
    Optional<String> symbol_name = callee->GetAttr<String>(tvm::attr::kGlobalSymbol);
    ICHECK(symbol_name.defined());
    GlobalVar target = mod_->GetGlobalVar(symbol_name.value());
    CallLoweredAttrs lowered_attrs;
    lowered_attrs.metadata.Set("relay_attrs", call->attrs);
    return CallLowered(target, call->args, lowered_attrs, call->span);
  }

 private:
  /*!
   * \brief If non-empty, only calls to functions whose "Compiler" attribute equals this
   * value are rewritten.
   */
  std::string compiler_filter_;
  /*! \brief Module used to resolve global variables to their definitions. */
  IRModule mod_;
};

} // namespace

GlobalSymbolCache::~GlobalSymbolCache() = default;
Expand Down Expand Up @@ -169,20 +133,6 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
[compiler_filter = std::move(compiler_filter)](IRModule mod, transform::PassContext ctx) {
IRModule output_mod = mod->ShallowCopy();

// First pass, rewrite the calls.
// We have to do this before marking functions as 'extern' to know which calls to rewrite!
for (const auto& kv : mod->functions) {
if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
Expr new_body =
CallRewriter(compiler_filter, output_mod).VisitExpr(function_node->body);
Function new_function =
WithFields(GetRef<Function>(function_node), /*opt_params=*/{}, new_body);
output_mod->Update(kv.first, new_function);
}
}

// Second pass, mark functions as 'extern'.
for (const auto& kv : mod->functions) {
if (const auto* function_node = kv.second.as<FunctionNode>()) {
Optional<String> opt_compiler = function_node->GetAttr<String>(attr::kCompiler);
Expand All @@ -197,7 +147,6 @@ transform::Pass MarkCompilerFunctionsAsExtern(std::string compiler_filter) {
}
}
}

return output_mod;
};

Expand Down
11 changes: 4 additions & 7 deletions src/relay/transforms/compiler_function_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,8 @@
*
* - \p MarkCompilerFunctionsAsExtern will replace global functions with a matching "Compiler"
* attribute with the same function with just an "Extern" attribute, signalling the function
* has been dealt with. Calls to such functions will be rewritten to use the 'call_lowered'
* calling convention. Can be used after lowering to cleanup the IRModule.
*
* Note that the above behaviour is hard coded within the TECompiler, but is only available to
* external codegen using the Function-at-a-time "relay.ext.toolchain" extension point.
* has been dealt with. However, calls to such functions will be left unchanged. Can be used
* after lowering to clean up the IRModule.
*/

#ifndef TVM_RELAY_TRANSFORMS_COMPILER_FUNCTION_UTILS_H_
Expand Down Expand Up @@ -118,8 +115,8 @@ transform::Pass OutlineCompilerFunctionsWithExistingGlobalSymbols(std::string co

/*!
* \brief A pass to mark all global functions which have a "Compiler" attribute matching
* compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute, and
* rewrite all calls to such functions to use the 'call_lowered' calling convention.
* compiler_filter as 'extern' by replacing all attributes with a single "Extern" attribute.
* Calls to such functions are not changed.
*
* If \p compiler_filter is non-empty only functions with that as their attribute value are
* outlined.
Expand Down
Loading