Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AOT] Enable A-Normal Form in the AOT executor #11091

Merged
merged 4 commits into from
May 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 78 additions & 24 deletions src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
VisitExpr(func);
CreateStorage(call_node);
for (const Expr& arg : args) {
GetStorage(arg);
VisitExpr(arg);
}
AssignReturnSid(GetRef<Expr>(call_node));
}
Expand All @@ -126,7 +126,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
for (const auto& param : func_node->params) {
CreateStorage(param.get());
}
GetStorage(func_node->body);
VisitExpr(func_node->body);
}

void VisitExpr_(const GlobalVarNode* op) final {
Expand Down Expand Up @@ -168,7 +168,9 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; }

void PreVisitLetBinding_(const Var& var, const Expr& value) final {
LOG(FATAL) << "let is not supported.";
VisitExpr(value);
StorageInfo si = GetStorage(value);
storage_device_map_[var] = si;
}

private:
Expand Down Expand Up @@ -219,7 +221,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
Expr true_expr = IgnoreOnDevice(expr);
VisitExpr(true_expr);
auto it = storage_device_map_.find(true_expr);
ICHECK(it != storage_device_map_.end());
ICHECK(it != storage_device_map_.end()) << "Could not find " << true_expr->GetTypeKey() << " "
<< PrettyPrint(true_expr) << " in storage device map";
return it->second;
}

Expand Down Expand Up @@ -335,6 +338,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
*/
std::vector<tir::Var> PackSid(Expr expr) {
std::vector<tir::Var> buffer_vars;

    ICHECK(storage_device_map_.find(expr) != storage_device_map_.end())
        << "Storage map did not contain expr " << PrettyPrint(expr);
StorageInfo& sinfo = storage_device_map_[expr];

// Note that an expression can have multiple sids associated with it
Expand Down Expand Up @@ -599,6 +605,12 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}

void VisitExpr_(const CallNode* call_node) override {
OnDeviceProps on_device_props = GetOnDeviceProps(call_node);
if (on_device_props.body.defined()) {
VisitExpr(on_device_props.body);
return;
}

DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node);
CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);

Expand Down Expand Up @@ -626,6 +638,11 @@ class AOTExecutorCodegen : public MixedModeVisitor {
Expr expr = GetRef<Expr>(op);
StorageInfo& sinfo = storage_device_map_[expr];

// Let bound vars refer to a value, so these should not be considered "output" vars.
if (let_bound_vars_.find(GetRef<Var>(op)) != let_bound_vars_.end()) {
return;
}

// If the Var node is an output node we need to copy the content of the variable to the output
// It's safe to check the SID here because Var StorageToken are never reallocated
auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]);
Expand All @@ -646,6 +663,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {

void VisitExpr_(const ConstantNode* op) override {
Expr expr = GetRef<Expr>(op);
ICHECK(storage_device_map_.find(expr) != storage_device_map_.end())
<< "Storage map did not contain constant expr " << PrettyPrint(expr);
StorageInfo& sinfo = storage_device_map_[expr];
std::stringstream ss;
ss << "constant_" << constant_map_.size();
Expand Down Expand Up @@ -674,12 +693,20 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}

void VisitExpr_(const LetNode* op) override {
// TODO(giuseros): support Let nodes in AOT
LOG(FATAL) << "Let not yet implemented in AOT";
auto pre_visit = [this](const LetNode* op) {
let_bound_vars_.insert(op->var);
this->VisitExpr(op->value);
};
auto post_visit = [this](const LetNode* op) {
this->VisitExpr(op->body);
this->visit_counter_[op] += 1;
};
ExpandANormalForm(op, pre_visit, post_visit);
}

void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); }
void VisitExpr_(const OpNode* op) override {
if (GetRef<Op>(op) != CallLoweredOp()) {
if (GetRef<Op>(op) != CallLoweredOp() && GetRef<Op>(op) != OnDeviceOp()) {
LOG(FATAL) << "All OpNodes except for call_lowered should have been expanded";
}
}
Expand Down Expand Up @@ -731,6 +758,12 @@ class AOTExecutorCodegen : public MixedModeVisitor {
continue;
}

      // Make sure it hasn't already been allocated; this can happen
      // with let-bound var/value pairs.
if (allocated.find(sid) != allocated.end()) {
continue;
}

allocated[sid] = constant_map_.count(sids_table_[sid]);

// TODO(giuseros): we should allocate this once outside the PrimFunc
Expand Down Expand Up @@ -775,21 +808,36 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}

/*!
* brief Access IO vars using the buffer vars and
* \brief Access IO vars using the buffer vars and
* not the actual var.
*/
tir::Var GetBufferVarForIO(int index) { return main_buffer_map_[main_signature_[index]]->data; }

/*!
* brief Create tir::Var for input/output while updating
* the buffer_maps.
* \brief Create tir::Var for input/output while updating the buffer_maps.
*
* \param expr The expression to evaluate.
* \param original_name The name of the tir::Var.
* \param use_unique_name Whether to generate a new unique name where a name conflicts.
*/
void CreateIOVar(const Expr& expr, const std::string& original_name,
bool use_unique_name = true) {
if (expr->IsInstance<TupleNode>()) {
Tuple tuple = Downcast<Tuple>(expr);
for (unsigned i = 0; i < tuple->fields.size(); i++) {
CreateIOVar(tuple->fields[i], original_name);
CreateIOVar(expr->checked_type(), original_name, use_unique_name);
}

/*!
* \brief Create tir::Var for input/output while updating the buffer_maps.
*
* \param expr The expression to evaluate.
* \param original_name The name of the tir::Var.
* \param use_unique_name Whether to generate a new unique name where a name conflicts.
*/
void CreateIOVar(const Type& type, const std::string& original_name,
bool use_unique_name = true) {
if (type->IsInstance<TupleTypeNode>()) {
TupleType tuple_type = Downcast<TupleType>(type);
for (unsigned i = 0; i < tuple_type->fields.size(); i++) {
CreateIOVar(tuple_type->fields[i], original_name);
}
} else {
std::string name = original_name;
Expand All @@ -798,19 +846,20 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}
tir::Var var = tir::Var(name, DataType::Handle());
main_signature_.push_back(var);
auto tensor_type = expr->checked_type().as<TensorTypeNode>();
auto tensor_type = type.as<TensorTypeNode>();
ICHECK(tensor_type) << "Expected TensorType node but was " << type->GetTypeKey();
DataType elem_type = tensor_type->dtype;
tir::Var buffer_var =
tir::Var(name + "_buffer_var", PointerType(PrimType(elem_type), "global"));
tir::Buffer buffer = tir::Buffer(buffer_var, elem_type, tensor_type->shape, {}, 0,
name + "_buffer", 16, 1, tir::BufferType::kDefault);
main_buffer_map_.Set(var, buffer);
io_tensor_types_.Set(var, Downcast<TensorType>(expr->checked_type()));
io_tensor_types_.Set(var, Downcast<TensorType>(type));
}
}

/*!
* brief Create a unique name for I/O Var
* \brief Create a unique name for I/O Var
*/
std::string GetUniqueIOVarName(std::string name) {
if (io_var_names_.find(name) == io_var_names_.end()) {
Expand All @@ -823,7 +872,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}

/*!
* brief Calculate workspace sizes for PrimFuncs in the IRModule
* \brief Calculate workspace sizes for PrimFuncs in the IRModule
*/
Map<String, FunctionInfo> CalculateWorkspaceSizes(
const IRModule& lowered_mod, const Map<String, FunctionInfo>& function_metadata) {
Expand Down Expand Up @@ -852,7 +901,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}

/*!
* brief Run USMP to plan memory for lowered IRModule
* \brief Run USMP to plan memory for lowered IRModule.
*/
IRModule PlanMemoryWithUSMP(const IRModule& mod) {
VLOG(1) << "Planning memory with USMP for module:" << std::endl << PrettyPrint(mod);
Expand Down Expand Up @@ -888,7 +937,7 @@ class AOTExecutorCodegen : public MixedModeVisitor {
}

/*!
* brief Run StorageRewrite to plan memory for lowered IRModule
* \brief Run StorageRewrite to plan memory for lowered IRModule.
*/
IRModule PlanMemoryWithStorageRewrite(const IRModule& mod) {
Executor executor_config = mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
Expand Down Expand Up @@ -966,6 +1015,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
std::vector<int> return_sid_;
/*! \brief This is per IO var name counter to aid the generating unique names */
std::unordered_map<std::string, int> io_var_names_;
/*! \brief A set of variables that are let bound. */
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> let_bound_vars_;

public:
AOTExecutorCodegen(runtime::Module* mod, const Array<Target>& targets)
Expand Down Expand Up @@ -1011,6 +1062,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
<< ") is not one of the expected values";
}

mod = transform::ToANormalForm()(mod);

IRModule lowered_mod = tec::LowerTEPass(
mod_name,
[this, workspace_byte_alignment](BaseFunc func) {
Expand Down Expand Up @@ -1071,12 +1124,13 @@ class AOTExecutorCodegen : public MixedModeVisitor {
// If output tensor names were provided use them
if (auto opt = func->GetAttr<Array<String>>("output_tensor_names")) {
Array<String> output_tensor_names = opt.value();
if (lowered_main_func->body->IsInstance<TupleNode>()) {
Tuple output_tuple = Downcast<Tuple>(lowered_main_func->body);
for (unsigned i = 0; i < output_tuple->fields.size(); i++) {
Expr output_expr = lowered_main_func->body;
if (output_expr->checked_type()->IsInstance<TupleTypeNode>()) {
TupleType output_tuple_type = Downcast<TupleType>(output_expr->checked_type());
for (unsigned i = 0; i < output_tuple_type->fields.size(); i++) {
// AoT Executor Codegen does not create these names,
// thus should be used as they are provided.
CreateIOVar(output_tuple->fields[i], output_tensor_names[i],
CreateIOVar(output_tuple_type->fields[i], output_tensor_names[i],
/*use_unique_name = */ false);
}
} else {
Expand Down
78 changes: 59 additions & 19 deletions src/relay/backend/contrib/cmsisnn/relay_to_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -655,19 +655,61 @@ class RelayToTIRVisitor : public MixedModeMutator {
return Call(new_global_var, call->args, call->attrs, call->type_args, call->span);
}

Expr Rewrite_(const CallNode* pre, const Expr& post) override {
if (const CallNode* call = post.as<CallNode>()) {
auto* func = call->op.as<FunctionNode>();
if (func == nullptr) {
return post;
Expr VisitExpr_(const LetNode* op) final {
auto pre_visit = [this](const LetNode* op) {
Expr var = this->VisitExpr(op->var);
Expr value = this->VisitExpr(op->value);
// outlineable function no longer needs let binding
if (this->CanOutlineExpr(value)) {
this->memo_[var] = value;
}
};
auto post_visit = [this](const LetNode* op) {
// Rely on the Memoizer to cache pre-visit values
Expr value = this->VisitExpr(op->value);
Expr body = this->VisitExpr(op->body);
auto expr = GetRef<Expr>(op);
// drop the let binding
if (this->CanOutlineExpr(value)) {
this->memo_[expr] = this->VisitExpr(op->body);
} else {
Var var = Downcast<Var>(this->VisitExpr(op->var));
if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) {
this->memo_[expr] = expr;
} else {
this->memo_[expr] = Let(var, value, body);
}
}
};
ExpandANormalForm(op, pre_visit, post_visit);
return memo_[GetRef<Expr>(op)];
}

auto codegen_name = func->GetAttr<String>(attr::kCompiler);
if (codegen_name.defined() && codegen_name == "cmsis-nn") {
const CallNode* inner_call = func->body.as<CallNode>();
bool CanOutlineExpr(const Expr& expr) {
    // TODO(@lhutton1): This behaviour is similar to the OutlineCompilerFunctions pass;
    // we could reuse this functionality by separating outlining and lowering in this
    // pass.
if (!expr->IsInstance<FunctionNode>()) {
return false;
}
const auto* func = expr.as<FunctionNode>();
auto codegen_name = func->GetAttr<String>(attr::kCompiler);
if (!codegen_name.defined() || codegen_name != "cmsis-nn") {
return false;
}
return true;
}

Expr Rewrite_(const CallNode* pre, const Expr& post) override {
if (const auto* call = post.as<CallNode>()) {
if (CanOutlineExpr(call->op)) {
const auto* func = call->op.as<FunctionNode>();
ICHECK(func) << "Expected function node but was " << call->op->GetTypeKey();
const auto codegen_name = func->GetAttr<String>(attr::kCompiler);
auto global_func_name = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
GlobalVar new_global_var(global_func_name.value());

const CallNode* inner_call = func->body.as<CallNode>();
if (!inner_call) {
return CallToFuncWithoutCompilerAttr(new_global_var, GetRef<Call>(call),
GetRef<Function>(func));
Expand All @@ -684,21 +726,20 @@ class RelayToTIRVisitor : public MixedModeMutator {

if (comp_name == "cmsis-nn.qnn_softmax") {
EmitSoftMax(new_global_var, composite_func->body);
}
if (comp_name == "cmsis-nn.qnn_mul") {
} else if (comp_name == "cmsis-nn.qnn_mul") {
EmitMul(new_global_var, composite_func->body);
}
if (comp_name == "cmsis-nn.qnn_add") {
} else if (comp_name == "cmsis-nn.qnn_add") {
EmitAdd(new_global_var, composite_func->body);
}
if (comp_name == "cmsis-nn.qnn_conv2d") {
} else if (comp_name == "cmsis-nn.qnn_conv2d") {
EmitConv2D(new_global_var, composite_func->body);
}
if (comp_name == "cmsis-nn.qnn_fully_connected") {
} else if (comp_name == "cmsis-nn.qnn_fully_connected") {
EmitFullyConnected(new_global_var, composite_func->body);
}
if (comp_name == "cmsis-nn.qnn_avg_pool2d" || comp_name == "cmsis-nn.qnn_max_pool2d") {
} else if (comp_name == "cmsis-nn.qnn_avg_pool2d" ||
comp_name == "cmsis-nn.qnn_max_pool2d") {
EmitPool2D(new_global_var, composite_func->body, comp_name.value());
} else {
return CallToFuncWithoutCompilerAttr(new_global_var, GetRef<Call>(call),
GetRef<Function>(func));
}

Array<Expr> args;
Expand All @@ -709,7 +750,6 @@ class RelayToTIRVisitor : public MixedModeMutator {
return Call(new_global_var, args, call->attrs, call->type_args, call->span);
}
}

return post;
}

Expand Down
Loading