[AOT] Avoid call_extern() with incorrect argument count (apache#15301)
Prior to this commit, if device initialization is required, the AOT
main function produced a `call_extern()` that included the device
context as input.  This commit updates the AOT main function to
provide the device context only if the function being called accepts a
device context as input.

If an extra device context argument is included at the call site, the
C codegen would produce a function signature that includes the device
context for the caller's compilation unit, but a signature without the
device context for the callee's compilation unit.  While this can
compile and run in some cases, it is undefined behavior for the
signature to vary between compilation units, and should be avoided.

This was initially discovered while debugging
apache#14985, in which changes to the
lowering flow resulted in the caller and callee being within the same
compilation unit.
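
The call-site decision described above can be sketched outside of TVM as follows. This is only a minimal illustration, not the code from the commit: `ParamCounts` and `RequiresDeviceContext` are hypothetical names, a `std::map` keyed by function name stands in for the map of `GlobalVar` to parameter count, and an exception stands in for `LOG(FATAL)`.

```cpp
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for the callee-to-parameter-count map.
using ParamCounts = std::map<std::string, std::size_t>;

// Returns true if the device context should be appended to the call's
// arguments. `num_args` is the argument count before the context is added.
inline bool RequiresDeviceContext(const ParamCounts& counts,
                                  const std::string& callee,
                                  std::size_t num_args) {
  auto it = counts.find(callee);
  if (it == counts.end()) {
    // Unknown callee: treat it as external and trust that the
    // user-supplied kernel accepts a device context.
    return true;
  }
  const std::size_t num_params = it->second;
  if (num_params == num_args) {
    return false;  // Callee has no slot for the device context.
  }
  if (num_params == num_args + 1) {
    return true;  // Exactly one extra parameter: the device context.
  }
  // Any other difference cannot be fixed by adding or omitting the context.
  throw std::invalid_argument("Callee " + callee + " requires " +
                              std::to_string(num_params) +
                              " arguments, but is called with " +
                              std::to_string(num_args) + " arguments.");
}
```

With `counts` holding `{"with_ctx", 3}` and `{"without_ctx", 2}`, a two-argument call to `with_ctx` gets the context appended and a two-argument call to `without_ctx` does not. A call to a name missing from the map is treated as external and always receives the context.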
Lunderberg authored and junrushao committed Jul 27, 2023
1 parent a67ccba commit 5e8d95d
Showing 1 changed file with 37 additions and 1 deletion.
38 changes: 37 additions & 1 deletion src/relay/backend/aot_executor_codegen.cc
@@ -454,7 +454,32 @@ class AOTExecutorCodegen : public MixedModeVisitor {
       // call_extern calling convention with optional context
       if (has_c_device_api_context) {
         device_context = device_contexts_.Get(global_var).value();
-        args.push_back(device_context);
+
+        // call_extern has no further legalization steps, and
+        // requires the number of arguments to match exactly.  For
+        // internal calls, conditionally append the device context.
+        bool requires_device_context = [&]() -> bool {
+          Optional<Integer> opt = num_arguments_.Get(global_var);
+          if (!opt.defined()) {
+            // For external calls, we must trust that the user has
+            // supplied a kernel that accepts a device_context
+            // argument.
+            return true;
+          }
+          int num_callee_params = opt.value()->value;
+          int num_args = call_lowered_props.arguments.size();
+          if (num_callee_params == num_args) {
+            return false;
+          } else if (num_callee_params == num_args + 1) {
+            return true;
+          } else {
+            LOG(FATAL) << "Callee " << global_var << " requires " << num_callee_params
+                       << ", but is called with " << num_args << " arguments.";
+          }
+        }();
+        if (requires_device_context) {
+          args.push_back(device_context);
+        }
       }
       func_call = tir::Evaluate(AddCheckReturn(
           tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args)));
@@ -1007,6 +1032,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
   Map<String, tir::Var> devices_;
   /*! \brief map of GlobalVars to C Device API contexts */
   Map<GlobalVar, tir::Var> device_contexts_;
+  /*! \brief map of GlobalVars to the number of arguments they require */
+  Map<GlobalVar, Integer> num_arguments_;
   /*! \brief input and output variables belonging to the main function signature */
   Array<tir::Var> main_signature_;
   /*! \brief input and output variables belonging to the main function signature */
@@ -1183,6 +1210,15 @@ class AOTExecutorCodegen : public MixedModeVisitor {
     }
 
     CollectDeviceVariables(lowered_mod->GetAttr<Map<GlobalVar, String>>("device_contexts").value());
+    num_arguments_ = [&]() -> Map<GlobalVar, Integer> {
+      Map<GlobalVar, Integer> arg_count;
+      for (const auto& [gvar, func] : lowered_mod->functions) {
+        if (const auto* prim_func = func.as<tir::PrimFuncNode>()) {
+          arg_count.Set(gvar, prim_func->params.size());
+        }
+      }
+      return arg_count;
+    }();
     VisitExpr(lowered_main_func->body);
 
     // Create the runner function. Please note that the function is not legal yet
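
The last hunk records a parameter count only for functions that are `tir::PrimFunc`s, so any other callee is later treated as external. A minimal sketch of that collection step outside of TVM might look like the following. The types `PrimFunc`, `OtherFunc`, and `Function`, and the helper `CollectParamCounts`, are hypothetical stand-ins for illustration and are not TVM APIs.

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-ins for the kinds of function held by a lowered module.
struct PrimFunc {
  std::vector<std::string> params;  // Parameter names; only the count matters.
};
struct OtherFunc {};  // Any function whose signature is not inspected.
using Function = std::variant<PrimFunc, OtherFunc>;

// Builds a map from function name to parameter count, skipping every
// function that is not a PrimFunc.
inline std::map<std::string, std::size_t> CollectParamCounts(
    const std::map<std::string, Function>& functions) {
  std::map<std::string, std::size_t> arg_count;
  for (const auto& [name, func] : functions) {
    if (const auto* prim_func = std::get_if<PrimFunc>(&func)) {
      arg_count[name] = prim_func->params.size();
    }
  }
  return arg_count;
}
```

Leaving non-`PrimFunc` entries out of the map, rather than storing a sentinel count, means a failed lookup at the call site is itself the signal that the callee is external.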