diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index c981f9d62b19..60f108aacf66 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -107,7 +107,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { VisitExpr(func); CreateStorage(call_node); for (const Expr& arg : args) { - GetStorage(arg); + VisitExpr(arg); } AssignReturnSid(GetRef(call_node)); } @@ -126,7 +126,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { for (const auto& param : func_node->params) { CreateStorage(param.get()); } - GetStorage(func_node->body); + VisitExpr(func_node->body); } void VisitExpr_(const GlobalVarNode* op) final { @@ -168,7 +168,9 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; } void PreVisitLetBinding_(const Var& var, const Expr& value) final { - LOG(FATAL) << "let is not supported."; + VisitExpr(value); + StorageInfo si = GetStorage(value); + storage_device_map_[var] = si; } private: @@ -219,7 +221,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { Expr true_expr = IgnoreOnDevice(expr); VisitExpr(true_expr); auto it = storage_device_map_.find(true_expr); - ICHECK(it != storage_device_map_.end()); + ICHECK(it != storage_device_map_.end()) << "Could not find " << true_expr->GetTypeKey() << " " + << PrettyPrint(true_expr) << " in storage device map"; return it->second; } @@ -335,6 +338,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { */ std::vector PackSid(Expr expr) { std::vector buffer_vars; + + ICHECK(storage_device_map_.find(expr) != storage_device_map_.end()) + << "Storage map did not contain constant expr " << PrettyPrint(expr); StorageInfo& sinfo = storage_device_map_[expr]; // Note that an expression can have multiple sids associated with it @@ -599,6 +605,12 @@ class 
AOTExecutorCodegen : public MixedModeVisitor { } void VisitExpr_(const CallNode* call_node) override { + OnDeviceProps on_device_props = GetOnDeviceProps(call_node); + if (on_device_props.body.defined()) { + VisitExpr(on_device_props.body); + return; + } + DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node); CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); @@ -626,6 +638,11 @@ class AOTExecutorCodegen : public MixedModeVisitor { Expr expr = GetRef(op); StorageInfo& sinfo = storage_device_map_[expr]; + // Let bound vars refer to a value, so these should not be considered "output" vars. + if (let_bound_vars_.find(GetRef(op)) != let_bound_vars_.end()) { + return; + } + // If the Var node is an output node we need to copy the content of the variable to the output // It's safe to check the SID here because Var StorageToken are never reallocated auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]); @@ -646,6 +663,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { void VisitExpr_(const ConstantNode* op) override { Expr expr = GetRef(op); + ICHECK(storage_device_map_.find(expr) != storage_device_map_.end()) + << "Storage map did not contain constant expr " << PrettyPrint(expr); StorageInfo& sinfo = storage_device_map_[expr]; std::stringstream ss; ss << "constant_" << constant_map_.size(); @@ -674,12 +693,20 @@ class AOTExecutorCodegen : public MixedModeVisitor { } void VisitExpr_(const LetNode* op) override { - // TODO(giuseros): support Let nodes in AOT - LOG(FATAL) << "Let not yet implemented in AOT"; + auto pre_visit = [this](const LetNode* op) { + let_bound_vars_.insert(op->var); + this->VisitExpr(op->value); + }; + auto post_visit = [this](const LetNode* op) { + this->VisitExpr(op->body); + this->visit_counter_[op] += 1; + }; + ExpandANormalForm(op, pre_visit, post_visit); } + void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); } void VisitExpr_(const OpNode* op) 
override { - if (GetRef(op) != CallLoweredOp()) { + if (GetRef(op) != CallLoweredOp() && GetRef(op) != OnDeviceOp()) { LOG(FATAL) << "All OpNodes except for call_lowered should have been expanded"; } } @@ -731,6 +758,12 @@ class AOTExecutorCodegen : public MixedModeVisitor { continue; } + // Make sure it hasn't already been allocated, this can happen + // with let-bound var/value pairs. + if (allocated.find(sid) != allocated.end()) { + continue; + } + allocated[sid] = constant_map_.count(sids_table_[sid]); // TODO(giuseros): we should allocate this once outside the PrimFunc @@ -775,21 +808,36 @@ class AOTExecutorCodegen : public MixedModeVisitor { } /*! - * brief Access IO vars using the buffer vars and + * \brief Access IO vars using the buffer vars and * not the actual var. */ tir::Var GetBufferVarForIO(int index) { return main_buffer_map_[main_signature_[index]]->data; } /*! - * brief Create tir::Var for input/output while updating - * the buffer_maps. + * \brief Create tir::Var for input/output while updating the buffer_maps. + * + * \param expr The expression to evaluate. + * \param original_name The name of the tir::Var. + * \param use_unique_name Whether to generate a new unique name where a name conflicts. */ void CreateIOVar(const Expr& expr, const std::string& original_name, bool use_unique_name = true) { - if (expr->IsInstance()) { - Tuple tuple = Downcast(expr); - for (unsigned i = 0; i < tuple->fields.size(); i++) { - CreateIOVar(tuple->fields[i], original_name); + CreateIOVar(expr->checked_type(), original_name, use_unique_name); + } + + /*! + * \brief Create tir::Var for input/output while updating the buffer_maps. + * + * \param type The type of the expression to evaluate. + * \param original_name The name of the tir::Var. + * \param use_unique_name Whether to generate a new unique name where a name conflicts. 
+ */ + void CreateIOVar(const Type& type, const std::string& original_name, + bool use_unique_name = true) { + if (type->IsInstance()) { + TupleType tuple_type = Downcast(type); + for (unsigned i = 0; i < tuple_type->fields.size(); i++) { + CreateIOVar(tuple_type->fields[i], original_name); } } else { std::string name = original_name; @@ -798,19 +846,20 @@ class AOTExecutorCodegen : public MixedModeVisitor { } tir::Var var = tir::Var(name, DataType::Handle()); main_signature_.push_back(var); - auto tensor_type = expr->checked_type().as(); + auto tensor_type = type.as(); + ICHECK(tensor_type) << "Expected TensorType node but was " << type->GetTypeKey(); DataType elem_type = tensor_type->dtype; tir::Var buffer_var = tir::Var(name + "_buffer_var", PointerType(PrimType(elem_type), "global")); tir::Buffer buffer = tir::Buffer(buffer_var, elem_type, tensor_type->shape, {}, 0, name + "_buffer", 16, 1, tir::BufferType::kDefault); main_buffer_map_.Set(var, buffer); - io_tensor_types_.Set(var, Downcast(expr->checked_type())); + io_tensor_types_.Set(var, Downcast(type)); } } /*! - * brief Create a unique name for I/O Var + * \brief Create a unique name for I/O Var */ std::string GetUniqueIOVarName(std::string name) { if (io_var_names_.find(name) == io_var_names_.end()) { @@ -823,7 +872,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { } /*! - * brief Calculate workspace sizes for PrimFuncs in the IRModule + * \brief Calculate workspace sizes for PrimFuncs in the IRModule */ Map CalculateWorkspaceSizes( const IRModule& lowered_mod, const Map& function_metadata) { @@ -852,7 +901,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { } /*! - * brief Run USMP to plan memory for lowered IRModule + * \brief Run USMP to plan memory for lowered IRModule. */ IRModule PlanMemoryWithUSMP(const IRModule& mod) { VLOG(1) << "Planning memory with USMP for module:" << std::endl << PrettyPrint(mod); @@ -888,7 +937,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { } /*! 
- * brief Run StorageRewrite to plan memory for lowered IRModule + * \brief Run StorageRewrite to plan memory for lowered IRModule. */ IRModule PlanMemoryWithStorageRewrite(const IRModule& mod) { Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); @@ -966,6 +1015,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { std::vector return_sid_; /*! \brief This is per IO var name counter to aid the generating unique names */ std::unordered_map io_var_names_; + /*! \brief A set of variables that are let bound. */ + std::unordered_set let_bound_vars_; public: AOTExecutorCodegen(runtime::Module* mod, const Array& targets) @@ -1011,6 +1062,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { << ") is not one of the expected values"; } + mod = transform::ToANormalForm()(mod); + IRModule lowered_mod = tec::LowerTEPass( mod_name, [this, workspace_byte_alignment](BaseFunc func) { @@ -1071,12 +1124,13 @@ class AOTExecutorCodegen : public MixedModeVisitor { // If output tensor names were provided use them if (auto opt = func->GetAttr>("output_tensor_names")) { Array output_tensor_names = opt.value(); - if (lowered_main_func->body->IsInstance()) { - Tuple output_tuple = Downcast(lowered_main_func->body); - for (unsigned i = 0; i < output_tuple->fields.size(); i++) { + Expr output_expr = lowered_main_func->body; + if (output_expr->checked_type()->IsInstance()) { + TupleType output_tuple_type = Downcast(output_expr->checked_type()); + for (unsigned i = 0; i < output_tuple_type->fields.size(); i++) { // AoT Executor Codegen does not create these names, // thus should be used as they are provided. 
- CreateIOVar(output_tuple->fields[i], output_tensor_names[i], + CreateIOVar(output_tuple_type->fields[i], output_tensor_names[i], /*use_unique_name = */ false); } } else { diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index 722e7c69d9ab..210175817f9c 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -655,19 +655,61 @@ class RelayToTIRVisitor : public MixedModeMutator { return Call(new_global_var, call->args, call->attrs, call->type_args, call->span); } - Expr Rewrite_(const CallNode* pre, const Expr& post) override { - if (const CallNode* call = post.as()) { - auto* func = call->op.as(); - if (func == nullptr) { - return post; + Expr VisitExpr_(const LetNode* op) final { + auto pre_visit = [this](const LetNode* op) { + Expr var = this->VisitExpr(op->var); + Expr value = this->VisitExpr(op->value); + // outlineable function no longer needs let binding + if (this->CanOutlineExpr(value)) { + this->memo_[var] = value; + } + }; + auto post_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + Expr value = this->VisitExpr(op->value); + Expr body = this->VisitExpr(op->body); + auto expr = GetRef(op); + // drop the let binding + if (this->CanOutlineExpr(value)) { + this->memo_[expr] = this->VisitExpr(op->body); + } else { + Var var = Downcast(this->VisitExpr(op->var)); + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + this->memo_[expr] = expr; + } else { + this->memo_[expr] = Let(var, value, body); + } } + }; + ExpandANormalForm(op, pre_visit, post_visit); + return memo_[GetRef(op)]; + } - auto codegen_name = func->GetAttr(attr::kCompiler); - if (codegen_name.defined() && codegen_name == "cmsis-nn") { - const CallNode* inner_call = func->body.as(); + bool CanOutlineExpr(const Expr& expr) { + // TODO(@lhutton1): This behaviour is similar to the 
OutlineCompilerFunctions pass + // we could reuse this functionality by separating outlining and lowering in this + // pass. + if (!expr->IsInstance()) { + return false; + } + const auto* func = expr.as(); + auto codegen_name = func->GetAttr(attr::kCompiler); + if (!codegen_name.defined() || codegen_name != "cmsis-nn") { + return false; + } + return true; + } + + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + if (const auto* call = post.as()) { + if (CanOutlineExpr(call->op)) { + const auto* func = call->op.as(); + ICHECK(func) << "Expected function node but was " << call->op->GetTypeKey(); + const auto codegen_name = func->GetAttr(attr::kCompiler); auto global_func_name = func->GetAttr(tvm::attr::kGlobalSymbol); GlobalVar new_global_var(global_func_name.value()); + const CallNode* inner_call = func->body.as(); if (!inner_call) { return CallToFuncWithoutCompilerAttr(new_global_var, GetRef(call), GetRef(func)); @@ -684,21 +726,20 @@ class RelayToTIRVisitor : public MixedModeMutator { if (comp_name == "cmsis-nn.qnn_softmax") { EmitSoftMax(new_global_var, composite_func->body); - } - if (comp_name == "cmsis-nn.qnn_mul") { + } else if (comp_name == "cmsis-nn.qnn_mul") { EmitMul(new_global_var, composite_func->body); - } - if (comp_name == "cmsis-nn.qnn_add") { + } else if (comp_name == "cmsis-nn.qnn_add") { EmitAdd(new_global_var, composite_func->body); - } - if (comp_name == "cmsis-nn.qnn_conv2d") { + } else if (comp_name == "cmsis-nn.qnn_conv2d") { EmitConv2D(new_global_var, composite_func->body); - } - if (comp_name == "cmsis-nn.qnn_fully_connected") { + } else if (comp_name == "cmsis-nn.qnn_fully_connected") { EmitFullyConnected(new_global_var, composite_func->body); - } - if (comp_name == "cmsis-nn.qnn_avg_pool2d" || comp_name == "cmsis-nn.qnn_max_pool2d") { + } else if (comp_name == "cmsis-nn.qnn_avg_pool2d" || + comp_name == "cmsis-nn.qnn_max_pool2d") { EmitPool2D(new_global_var, composite_func->body, comp_name.value()); + } else { + return 
CallToFuncWithoutCompilerAttr(new_global_var, GetRef(call), + GetRef(func)); } Array args; @@ -709,7 +750,6 @@ class RelayToTIRVisitor : public MixedModeMutator { return Call(new_global_var, args, call->attrs, call->type_args, call->span); } } - return post; } diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc index dfcf54f7b76c..47c80b47c579 100644 --- a/src/relay/backend/contrib/ethosu/codegen.cc +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -57,28 +57,81 @@ class OutlineCompilerFunctionsMutator : public MixedModeMutator { explicit OutlineCompilerFunctionsMutator(const IRModule& mod, const std::string& compiler_name) : mod_(mod), compiler_name_(compiler_name) {} + Expr VisitExpr_(const LetNode* op) final { + auto pre_visit = [this](const LetNode* op) { + Expr var = this->VisitExpr(op->var); + Expr value = this->VisitExpr(op->value); + + // Outlineable function no longer needs let binding + if (this->CanOutlineExpr(value)) { + this->memo_[var] = value; + } + }; + auto post_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + Expr value = this->VisitExpr(op->value); + Expr body = this->VisitExpr(op->body); + auto expr = GetRef(op); + + // Drop the let binding + if (this->CanOutlineExpr(value)) { + this->memo_[expr] = this->VisitExpr(op->body); + } else { + Var var = Downcast(this->VisitExpr(op->var)); + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + this->memo_[expr] = expr; + } else { + this->memo_[expr] = Let(var, value, body); + } + } + }; + ExpandANormalForm(op, pre_visit, post_visit); + return memo_[GetRef(op)]; + } + Expr Rewrite_(const CallNode* pre, const Expr& post) override { Call call = Downcast(post); - if (call->op->IsInstance()) { + if (CanOutlineExpr(call->op)) { Function func = Downcast(call->op); - auto compiler = func->GetAttr(attr::kCompiler); - if (compiler.defined() && compiler == compiler_name_) { - auto gv_name 
= func->GetAttr("global_symbol").value_or(""); - ICHECK_NE(gv_name, "") - << "Function to be outlined must have global_symbol attribute, but didn't."; - GlobalVar gv(gv_name); - if (func->checked_type_.defined()) { - gv->checked_type_ = func->checked_type(); - } - mod_->Update(gv, func); - return Call(gv, call->args, call->attrs, call->type_args); + auto gv_name = func->GetAttr("global_symbol").value_or(""); + ICHECK_NE(gv_name, "") + << "Function to be outlined must have global_symbol attribute, but didn't."; + GlobalVar gv(gv_name); + if (func->checked_type_.defined()) { + gv->checked_type_ = func->checked_type(); } + mod_->Update(gv, func); + return Call(gv, call->args, call->attrs, call->type_args); } return post; } private: + /*! + * \brief Check if the expr is a function and has the same + * compiler name as compiler_name_. + * + * \param expr The input expr. + * \return True if is outlineable else False. + */ + bool CanOutlineExpr(const Expr& expr) { + if (!expr->IsInstance()) { + return false; + } + Function func = Downcast(expr); + auto compiler = func->GetAttr(attr::kCompiler); + if (!compiler.defined()) { + return false; + } + if (compiler != compiler_name_) { + return false; + } + return true; + } + + /*! \brief The module that the pass will run on. */ IRModule mod_; + /*! \brief The name of the compiler to enable outlining on external functions for. 
*/ std::string compiler_name_; }; @@ -188,7 +241,7 @@ class RemoveRedundantIdentities : public MixedModeMutator { const auto* call_tt = call->checked_type_.as(); const auto* identity_arg_tt = identity_arg->checked_type_.as(); - CHECK(call_tt && identity_arg_tt) + ICHECK(call_tt && identity_arg_tt) << "InferType should be run before RemoveRedundantIdentities"; // we can only remove the identity operation if the second non-compute operation diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index 86f55caf9342..c498baa6d11d 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -94,23 +94,67 @@ class ConvertAddToSubtract : public MixedModeMutator { ir_module_->Add(new_global_var, replacement_func); } + Expr VisitExpr_(const LetNode* op) final { + auto pre_visit = [this](const LetNode* op) { + Expr var = this->VisitExpr(op->var); + Expr value = this->VisitExpr(op->value); + + // Outlineable function no longer needs let binding + if (this->CanLowerExpr(value)) { + this->memo_[var] = value; + } + }; + auto post_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + Expr value = this->VisitExpr(op->value); + Expr body = this->VisitExpr(op->body); + auto expr = GetRef(op); + + // Drop the let binding + if (this->CanLowerExpr(value)) { + this->memo_[expr] = this->VisitExpr(op->body); + } else { + Var var = Downcast(this->VisitExpr(op->var)); + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + this->memo_[expr] = expr; + } else { + this->memo_[expr] = Let(var, value, body); + } + } + }; + ExpandANormalForm(op, pre_visit, post_visit); + return memo_[GetRef(op)]; + } + + bool CanLowerExpr(const Expr& expr) { + const auto* func = expr.as(); + if (func == nullptr) { + return false; + } + auto func_name = 
func->GetAttr(::tvm::attr::kGlobalSymbol); + if (!func_name.defined()) { + return false; + } + if (func_name != "replace_add_with_subtract") { + return false; + } + return true; + } + Expr Rewrite_(const CallNode* pre, const Expr& post) override { if (const CallNode* call = post.as()) { - auto* func = call->op.as(); - if (func == nullptr) { - return post; - } + if (CanLowerExpr(call->op)) { + auto* func = call->op.as(); + auto func_name = func->GetAttr(::tvm::attr::kGlobalSymbol); - auto func_name = func->GetAttr(::tvm::attr::kGlobalSymbol); - if (func_name.defined() && func_name == "replace_add_with_subtract") { // Introduce a new global var to map the function to and copy the source type // over for InferType GlobalVar new_global_var(func_name.value()); new_global_var->checked_type_ = func->checked_type(); ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef(func)); - // Since we are replacing the Relay function with a call to a TIR function, we must use the - // call_lowered op. + // Since we are replacing the Relay function with a call to a TIR function, we must use + // the call_lowered op. 
CallLoweredAttrs attrs; attrs.metadata.Set("relay_attrs", call->attrs); ICHECK(call->type_args.empty()) << "lowered functions cannot be polymorphic"; diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 70d74ea92377..71b57aed81f6 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -678,6 +678,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { if (prim_func.defined()) { // Leaving let var scope primitive_functions_.erase(pre_let_node->var.get()); + // Drop the let node + return post_let_node->body; } return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node); } diff --git a/src/relay/transforms/target_hooks.cc b/src/relay/transforms/target_hooks.cc index b0ac883623d2..0022baf881ba 100644 --- a/src/relay/transforms/target_hooks.cc +++ b/src/relay/transforms/target_hooks.cc @@ -61,25 +61,19 @@ class TargetHookVisitor : public tvm::relay::MixedModeVisitor { ExpandANormalForm(op, pre_visit, post_visit); } - void VisitExpr_(const CallNode* call) override { - // Descend the call tree - for (auto arg : call->args) { - VisitExpr(arg); + void VisitExpr_(const FunctionNode* func) override { + ExprVisitor::VisitExpr_(func); + if (!func->GetAttr(attr::kCompiler).defined()) { + return; } - - if (const FunctionNode* func = call->op.as()) { - if (!func->GetAttr(attr::kCompiler).defined()) { - return; - } - String code_gen_name = func->GetAttr(attr::kCompiler).value(); - Optional target_kind = tvm::TargetKind::Get(code_gen_name); - if (!target_kind || !target_attr_map_.count(target_kind.value())) { - return; - } - Pass custom_target_pass = target_attr_map_[target_kind.value()]; - if (std::find(pass_list_.begin(), pass_list_.end(), custom_target_pass) == pass_list_.end()) { - pass_list_.push_back(custom_target_pass); - } + String code_gen_name = func->GetAttr(attr::kCompiler).value(); + Optional target_kind = tvm::TargetKind::Get(code_gen_name); + if (!target_kind || 
!target_attr_map_.count(target_kind.value())) { + return; + } + Pass custom_target_pass = target_attr_map_[target_kind.value()]; + if (std::find(pass_list_.begin(), pass_list_.end(), custom_target_pass) == pass_list_.end()) { + pass_list_.push_back(custom_target_pass); } } }; diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 3c44d2bf1bc8..2991cc01fc92 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -36,6 +36,7 @@ from tvm.relay.backend import Executor, Runtime from tvm.micro import model_library_format as mlf from tvm.micro import export_model_library_format +from tvm.ir.instrument import pass_instrument from aot_test_utils import ( AOTTestModel, AOT_DEFAULT_RUNNER, @@ -1027,5 +1028,51 @@ def test_aot_codegen_checks_returns(): ) +def test_aot_uses_anf(): + """Checks that A-Normal Form is being used in the AOT lowering pipeline.""" + x = relay.var("x", shape=(1, 10, 10, 10)) + y = relay.var("y", shape=(1, 10, 10, 10)) + z = relay.add(x, y) + func = relay.Function([x, y], z) + + @pass_instrument + class CheckANFRuns: + def __init__(self): + self.did_run_anf = False + + def run_before_pass(self, _, info): + if info.name == "ToANormalForm": + self.did_run_anf = True + if info.name == "LowerTE": + assert self.did_run_anf, "ToANormalForm pass should run before LowerTE." 
+ + check_run_anf = CheckANFRuns() + + model = AOTTestModel(module=IRModule.from_expr(func), inputs=None, outputs=None) + runtime = Runtime("crt") + executor = Executor( + "aot", + { + "workspace-byte-alignment": 8, + "interface-api": "c", + "unpacked-api": True, + }, + ) + config = {"tir.disable_vectorize": True} + + with tvm.transform.PassContext(opt_level=3, config=config, instruments=[check_run_anf]): + tvm.relay.build( + model.module, + tvm.target.Target("c"), + executor=executor, + runtime=runtime, + workspace_memory_pools=None, + params=model.params, + mod_name=model.name, + ) + + assert check_run_anf.did_run_anf, "Expected ToANormalForm pass to have run." + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:]))