Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AOT] Avoid call_extern() with incorrect argument count #15301

Merged
merged 1 commit into from
Jul 18, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 37 additions & 1 deletion src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,32 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// call_extern calling convention with optional context
if (has_c_device_api_context) {
device_context = device_contexts_.Get(global_var).value();
args.push_back(device_context);

// call_extern has no further legalization steps, and
// requires the number of arguments to match exactly. For
// internal calls, conditionally append the device context.
bool requires_device_context = [&]() -> bool {
Optional<Integer> opt = num_arguments_.Get(global_var);
if (!opt.defined()) {
// For external calls, we must trust that the user has
// supplied a kernel that accepts a device_context
// argument.
return true;
}
int num_callee_params = opt.value()->value;
int num_args = call_lowered_props.arguments.size();
if (num_callee_params == num_args) {
return false;
} else if (num_callee_params == num_args + 1) {
return true;
} else {
LOG(FATAL) << "Callee " << global_var << " requires " << num_callee_params
<< ", but is called with " << num_args << " arguments.";
}
}();
if (requires_device_context) {
args.push_back(device_context);
}
}
func_call = tir::Evaluate(AddCheckReturn(
tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args)));
Expand Down Expand Up @@ -1007,6 +1032,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
Map<String, tir::Var> devices_;
/*! \brief map of GlobalVars to C Device API contexts */
Map<GlobalVar, tir::Var> device_contexts_;
/*! \brief map of GlobalVars to the number of arguments they require */
Map<GlobalVar, Integer> num_arguments_;
/*! \brief input and output variables belonging to the main function signature */
Array<tir::Var> main_signature_;
/*! \brief input and output variables belonging to the main function signature */
Expand Down Expand Up @@ -1183,6 +1210,15 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}

CollectDeviceVariables(lowered_mod->GetAttr<Map<GlobalVar, String>>("device_contexts").value());
num_arguments_ = [&]() -> Map<GlobalVar, Integer> {
Map<GlobalVar, Integer> arg_count;
for (const auto& [gvar, func] : lowered_mod->functions) {
if (const auto* prim_func = func.as<tir::PrimFuncNode>()) {
arg_count.Set(gvar, prim_func->params.size());
}
}
return arg_count;
}();
VisitExpr(lowered_main_func->body);

// Create the runner function. Please note that the function is not legal yet
Expand Down
Loading