From ea69f71c9c67c796e7ea2769c61cd13e1b544edf Mon Sep 17 00:00:00 2001 From: mbs-octoml Date: Fri, 17 Jun 2022 14:01:07 -0700 Subject: [PATCH] [BYOC] Handle constants in IRModule-at-a-time external codegen I tried to do to the TensorRT integration what #11631 did to the CUTLASS integration, viz: - Make sure all compilation options are passed in Target instances. This helps Collage. - Use a custom pass invoked via RelayToTIRTargetHooks instead of the relay.ext.$toolchain mechanism. This helps us decouple external codegen from lowering. This PR collects the prep for that change: - TensorRT uses the JSONSerializer visitor to encode each partition function. Previously, when the visitor encountered a Constant it simply generated and recorded a name for the constant. Then, completely separately, and via a callback in TECompiler, the function is visited again in the same order and with the same name generation convention by a ConstantUpdater to actually collect the bindings, which are then encoded into a ConstLoaderModule to be made available at runtime. However if all TensorRT compilation is to be done by a stand-alone pass there's no TECompiler callback hackery available. So I've added a "const_name_to_constant" attribute to the IRModule of type Map so that named constants can be accumulated throughout compilation by any pass which needs to do so. Then the Graph, AOT and VM executors are all updated to merge those constants into the final runtime artifact (Compare with "Constants", the equivalent attribute for extracting TIR AllocateConsts.) - The TensorRT tests use the create_executor interface but it wasn't quite ready for the new more general form of passing list-of-targets. - I want TensorRT compilation to work out of the box without the need for any special targets if all the default options should apply. Go back and make the CUTLASS integration I did follow the same convention. 
- To test this I also switched the 'demo' "ccompiler" external codegen target to IRModule-at-a-time style. This means we can test most of external codegen machinery in one place without depending on any target which may not be enabled in CI (eg TensorRT): - Target instances are plumbed correctly so compile-time options are available. - External modules are conveyed to the final export library. - Constant bindings are conveyed to the metadata module. --- cmake/modules/contrib/CODEGENC.cmake | 2 +- include/tvm/ir/module.h | 30 +- include/tvm/tir/stmt.h | 6 +- python/tvm/relay/backend/interpreter.py | 2 +- python/tvm/relay/backend/vm.py | 5 +- python/tvm/relay/build_module.py | 41 ++- python/tvm/relay/transform/transform.py | 4 +- python/tvm/tir/stmt.py | 2 +- src/relay/backend/aot_executor_codegen.cc | 39 ++- src/relay/backend/build_module.cc | 12 +- .../contrib/arm_compute_lib/codegen.cc | 9 +- src/relay/backend/contrib/bnns/codegen.cc | 8 +- .../backend/contrib/codegen_c/codegen.cc | 281 ++++++++++++------ .../backend/contrib/codegen_c/codegen_c.h | 13 +- src/relay/backend/contrib/codegen_c/target.cc | 43 +++ .../contrib/codegen_json/codegen_json.h | 46 ++- src/relay/backend/contrib/cutlass/codegen.cc | 34 ++- src/relay/backend/contrib/dnnl/codegen.cc | 8 +- .../contrib/example_target_hooks/target.cc | 1 - src/relay/backend/contrib/tensorrt/codegen.cc | 9 +- .../backend/contrib/verilator/codegen.cc | 9 +- src/relay/backend/graph_executor_codegen.cc | 39 +-- src/relay/backend/te_compiler.cc | 4 +- src/relay/backend/utils.h | 8 +- src/relay/backend/vm/compiler.cc | 28 +- .../transforms/compiler_function_utils.cc | 34 ++- .../transforms/compiler_function_utils.h | 13 +- src/relay/transforms/target_hooks.cc | 7 +- src/target/metadata_module.cc | 2 + src/tir/transforms/extract_constants.cc | 6 +- tests/python/relay/test_external_codegen.py | 40 +-- .../transform/test_compiler_function_utils.py | 40 +++ .../python/unittest/test_custom_datatypes.py | 3 +- 
.../test_tir_transform_extract_constants.py | 5 +- 34 files changed, 571 insertions(+), 262 deletions(-) create mode 100644 src/relay/backend/contrib/codegen_c/target.cc diff --git a/cmake/modules/contrib/CODEGENC.cmake b/cmake/modules/contrib/CODEGENC.cmake index 275c32514eba..412fa3e8ffc5 100644 --- a/cmake/modules/contrib/CODEGENC.cmake +++ b/cmake/modules/contrib/CODEGENC.cmake @@ -15,6 +15,6 @@ # specific language governing permissions and limitations # under the License. -tvm_file_glob(GLOB CSOURCE_RELAY_CONTRIB_SRC src/relay/backend/contrib/codegen_c/codegen.cc) +tvm_file_glob(GLOB CSOURCE_RELAY_CONTRIB_SRC src/relay/backend/contrib/codegen_c/*.cc) list(APPEND COMPILER_SRCS ${CSOURCE_RELAY_CONTRIB_SRC}) diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index b78f16a84f02..f73f2230df4d 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -479,8 +479,10 @@ TVM_DLL String AsText(const ObjectRef& node, bool show_meta_data = true, namespace attr { +// Following are attributes for IRModule only. + /*! - * \brief Executor targetted by the module + * \brief Executor targeted by the module * * Type: Executor * @@ -516,9 +518,31 @@ constexpr const char* kWorkspaceMemoryPools = "workspace_memory_pools"; constexpr const char* kConstantMemoryPools = "constant_memory_pools"; /* - * \brief Module attribute for tir constants + * \brief All the runtime::NDArrays extracted from PrimFunc tir::AllocateConst nodes. The + * node will record the index into this array. See also kConstNameToConstant below, which is + * the analog for Realy Functions. + * + * Type: Array + */ +constexpr const char* kConstants = "constants"; + +/*! + * \brief All the runtime::Modules accumulated during compilation by external codegen. These + * modules must be either directly linked or captured in the final compilation artifact. + * + * Type: Array + */ +constexpr const char* kExternalMods = "external_mods"; + +/*! 
+ * \brief All the named runtime::NDArrays accumulated during compilation by external codegen. + * Generally the associated runtime::Module will indicate it requires bindings for these names, + * and during module initialization these bindings will be recovered from a ConstLoaderModule. + * See also kConstantsArray above, which is the analog for PrimFuncs. + * + * Type: Map */ -constexpr const char* kConstantsArray = "Constants"; +constexpr const char* kConstNameToConstant = "const_name_to_constant"; } // namespace attr } // namespace tvm diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 4c8a3076a20b..ec9a8b29334d 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -599,9 +599,9 @@ class AllocateConstNode : public StmtNode { /*! \brief The optional data associated to the constant. */ Optional data; - /*! \brief If the PrimFunc containing the Stmt is added to IRModule, - this is an optional index to indicate the index within - "Constants" attribute, that is a Array of IRModule. + /*! + * \brief If the PrimFunc containing the Stmt is added to IRModule, this is an optional index + * to indicate the index within "constants" attribute, that is a Array of IRModule. */ Optional irmod_storage_idx; /*! \brief The type of the buffer. */ diff --git a/python/tvm/relay/backend/interpreter.py b/python/tvm/relay/backend/interpreter.py index 819e5eda41f5..020736beb5c4 100644 --- a/python/tvm/relay/backend/interpreter.py +++ b/python/tvm/relay/backend/interpreter.py @@ -195,7 +195,7 @@ class Interpreter(Executor): The runtime device to run the code on. target : tvm.Target - The target option to build the function. + The target option to build the function. Only homogeneous execution is supported. CAUTION: Despite the API the module is prepared upon each call to evaluate rather than once in create_executor. 
diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index d4a82cd8d427..bc11d43cb0ca 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -198,8 +198,9 @@ class VMExecutor(Executor): device : :py:class:`~tvm.runtime.Device` The runtime device to run the code on. - target : :py:class:`Target` - The target option to build the function. + target : any multi-target like object, see Target.canon_multi_target + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. """ def __init__(self, mod, device, target): diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index 1353d8c5f595..32ad6c70794c 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -570,8 +570,9 @@ class GraphExecutor(_interpreter.Executor): device : :py:class:`Device` The runtime device to run the code on. - target : :py:class:`Target` - The target option to build the function. + target : any multi-target like object, see Target.canon_multi_target + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. """ def __init__(self, mod, device, target): @@ -630,8 +631,9 @@ class AotExecutor(_interpreter.Executor): device : :py:class:`Device` The runtime device to run the code on. - target : :py:class:`Target` - The target option to build the function. + target : any multi-target like object, see Target.canon_multi_target + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. 
""" def __init__(self, mod, device, target): @@ -639,7 +641,6 @@ def __init__(self, mod, device, target): self.mod = mod self.device = device self.target = target - assert target.attrs.get("executor", "graph") == "aot" def _make_executor(self, expr=None): if expr: @@ -719,8 +720,11 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N device : :py:class:`Device` The device to execute the code. - target : :py:class:`tvm.Target` - The corresponding context + target : any multi-target like object, see Target.canon_multi_target + For homogeneous compilation, the unique build target. + For heterogeneous compilation, a dictionary or list of possible build targets. + CAUTION: Though this API allows multiple targets, it does not allow multiple devices, so + heterogenous compilation is not yet supported. params : dict of str to NDArray Input parameters to the graph that do not change @@ -730,24 +734,31 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N ------- executor : :py:class:`~tvm.relay.backend.interpreter.Executor` """ + raw_targets = Target.canon_multi_target(target) if mod is None: mod = IRModule() if device is not None: - assert device.device_type == _nd.device(str(target), 0).device_type + assert device.device_type == raw_targets[0].kind.device_type else: - device = _nd.device(str(target), 0) + # Derive the default device from the first target. 
+ device = _nd.device(raw_targets[0].kind.device_type, 0) if params is not None: mod = IRModule.from_expr(bind_params_by_name(mod["main"], params)) - if isinstance(target, str): - target = Target(target) + assert "executor" not in raw_targets[0].attrs or raw_targets[0].attrs["executor"] == kind + if kind == "debug": - return _interpreter.Interpreter(mod, device, target) + assert len(raw_targets) == 1, "The interpreter currently only supports a single target" + return _interpreter.Interpreter(mod, device, raw_targets[0]) if kind == "graph": - return GraphExecutor(mod, device, target) + return GraphExecutor(mod, device, raw_targets) if kind == "vm": - return VMExecutor(mod, device, target) + return VMExecutor(mod, device, raw_targets) if kind == "aot": - return AotExecutor(mod, device, target) + # The AOT requires the executor as a target attribute. + # (The compilation paths for the other executors currently do not always provide this + # attribute, hence the above generic assert is more forgiving). + assert "executor" in raw_targets[0].attrs + return AotExecutor(mod, device, raw_targets) raise RuntimeError("unknown execution strategy: {0}".format(kind)) diff --git a/python/tvm/relay/transform/transform.py b/python/tvm/relay/transform/transform.py index c931289d40c6..d7979a757171 100644 --- a/python/tvm/relay/transform/transform.py +++ b/python/tvm/relay/transform/transform.py @@ -1386,7 +1386,7 @@ def OutlineCompilerFunctionsWithExistingGlobalSymbols(compiler_filter=""): Parameters ---------- compiler_filter : String - If non-empty, the 'compiler' attribute to filter on. + If non-empty, the "Compiler" attribute to filter on. Returns ------- @@ -1412,7 +1412,7 @@ def MarkCompilerFunctionsAsExtern(compiler_filter=""): Parameters ---------- compiler_filter : String - If non-empty, the 'compiler' attribute to filter on. + If non-empty, the "Compiler" attribute to filter on. 
Returns ------- diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 301bfa73c818..063439e068a4 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -358,7 +358,7 @@ class AllocateConst(Stmt): data_or_idx : Union[NDArray, int] If an NDArray, this is the const data associated with the constant. If an integer, this is the index into the - "Constants" attribute of the `IRModule` that contains the + "constants" attribute of the `IRModule` that contains the `AllocateConst`. body : Stmt diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 5020e79714b2..ae60970b78af 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1167,11 +1167,19 @@ class AOTExecutorCodegen : public MixedModeVisitor { // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need // to run the LegalizePackedCalls pass. LoweredOutput ret; - ret.params = std::unordered_map>(); - for (auto param : params_) { - ret.params.emplace(std::make_pair( - param.first, - std::make_pair(static_cast(param_storage_ids_[param.first]), param.second))); + + // Collect any constants extracted by external codegen. + ret.params = std::unordered_map(); + Map const_name_to_constant = + lowered_mod->GetAttr>(tvm::attr::kConstNameToConstant) + .value_or({}); + for (const auto& kv : const_name_to_constant) { + ICHECK(ret.params.emplace(kv.first, kv.second).second); + } + + // Collect any constants extracted during lowering. 
+ for (const auto& kv : params_) { + ICHECK(ret.params.emplace(kv.first, kv.second).second); } // AoT Executor codegen works completely on TIR beyond this point, hence removing relay main @@ -1212,9 +1220,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { lowered_mod = pack_calls(lowered_mod); } - Optional> external_modules = - lowered_mod->GetAttr>("external_mods"); - ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point."; + // Collect any runtime modules generated by external codegen. + ret.external_mods = + lowered_mod->GetAttr>(tvm::attr::kExternalMods).value_or({}); // This is the point where we separate the functions in the module by target VLOG(1) << "lowered module:" << std::endl << PrettyPrint(lowered_mod); @@ -1227,8 +1235,6 @@ class AOTExecutorCodegen : public MixedModeVisitor { << PrettyPrint(kv.second); } - ret.external_mods = external_modules.value(); - // Extract USMP metadata to pass onto metadata sources Map pool_var_info; std::vector pool_vars; @@ -1316,11 +1322,6 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { String key = args[0]; *rv = get_param_by_name(key); }); - } else if (name == "get_param_id") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - String key = args[0]; - *rv = get_param_id(key); - }); } else if (name == "get_irmodule") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_irmodule(); }); @@ -1362,17 +1363,11 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { runtime::NDArray get_param_by_name(String key) { auto it = this->output_.params.find(key); CHECK(it != this->output_.params.end()) << "no such parameter " << key; - return (*it).second.second; + return (*it).second; } Array get_external_modules() { return output_.external_mods; } - int get_param_id(String key) { - auto it = this->output_.params.find(key); - CHECK(it != this->output_.params.end()) << "no such parameter " << key; - return 
(*it).second.first; - } - Map get_irmodule() { return this->output_.lowered_funcs; } std::shared_ptr codegen_; diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 628dee0844ec..9a68b567305d 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -86,17 +86,6 @@ struct ExecutorCodegen { return ret; } - std::unordered_map GetParamIds() { - std::unordered_map ret; - auto names = CallFunc>("list_params_name", nullptr); - for (const auto& expr : names) { - // Implicit cast from runtime::String to std::string - std::string key = expr; - ret[key] = CallFunc("get_param_id", key); - } - return ret; - } - Array GetExternalModules() { return CallFunc>("get_external_modules", nullptr); } @@ -478,6 +467,7 @@ class RelayBuildModule : public runtime::ModuleNode { for (size_t i = 0; i < variables.size(); i++) { auto it = ret_.params.find(variables[i].operator std::string()); if (it != ret_.params.end()) { + VLOG(1) << "constant '" << variables[i] << "' has been captured in external module"; ret_.params.erase(it); } } diff --git a/src/relay/backend/contrib/arm_compute_lib/codegen.cc b/src/relay/backend/contrib/arm_compute_lib/codegen.cc index 842ede3bf20b..81a5b5bbd9d8 100644 --- a/src/relay/backend/contrib/arm_compute_lib/codegen.cc +++ b/src/relay/backend/contrib/arm_compute_lib/codegen.cc @@ -392,10 +392,15 @@ runtime::Module ACLCompiler(const ObjectRef& ref) { ACLJSONSerializer serializer(func_name, func); serializer.serialize(); std::string graph_json = serializer.GetJSON(); - auto param_names = serializer.GetParams(); + + // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes + // a callback which calls backend::UpdateConstants to capture the map before the function + // 'disappears' into lowered form, on the assumption the visit order and thus constant + // names match those generated by the JSONSerializer. 
+ const auto* pf = runtime::Registry::Get("runtime.arm_compute_lib_runtime_create"); ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create"; - runtime::Module lib = (*pf)(func_name, graph_json, param_names); + runtime::Module lib = (*pf)(func_name, graph_json, serializer.const_names()); return lib; } diff --git a/src/relay/backend/contrib/bnns/codegen.cc b/src/relay/backend/contrib/bnns/codegen.cc index 72c32fb5b19e..3791773ad67d 100644 --- a/src/relay/backend/contrib/bnns/codegen.cc +++ b/src/relay/backend/contrib/bnns/codegen.cc @@ -136,11 +136,15 @@ runtime::Module BNNSCompiler(const ObjectRef& ref) { BNNSJSONSerializer serializer(func_name, func); serializer.serialize(); std::string graph_json = serializer.GetJSON(); - auto params = serializer.GetParams(); + + // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes + // a callback which calls backend::UpdateConstants to capture the map before the function + // 'disappears' into lowered form, on the assumption the visit order and thus constant + // names match those generated by the JSONSerializer. const auto* pf = runtime::Registry::Get("runtime.BNNSJSONRuntimeCreate"); ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create"; - auto mod = (*pf)(func_name, graph_json, params); + auto mod = (*pf)(func_name, graph_json, serializer.const_names()); return mod; } diff --git a/src/relay/backend/contrib/codegen_c/codegen.cc b/src/relay/backend/contrib/codegen_c/codegen.cc index fd1c39bb9283..ee8724fe92fe 100644 --- a/src/relay/backend/contrib/codegen_c/codegen.cc +++ b/src/relay/backend/contrib/codegen_c/codegen.cc @@ -16,17 +16,17 @@ * specific language governing permissions and limitations * under the License. 
*/ -#include + #include #include #include #include #include -#include #include #include +#include "../../../transforms/compiler_function_utils.h" #include "../../utils.h" #include "codegen_c.h" @@ -34,30 +34,62 @@ namespace tvm { namespace relay { namespace contrib { -using namespace backend; +/*! \brief Return the "ccompiler" Target instance to use to guide compilation. */ +Target GetCCompilerTarget() { + Target target = Target::Current(/*allow_not_defined=*/true); + if (!target.defined() || target->kind->name != "ccompiler") { + // Use the default compilation options if no specific "ccompiler" target was given + // in the overall targets list. In that case target_hooks.cc will invoke the custom pass + // without pushing any target instance onto the implicit target stack. + target = Target("ccompiler"); + } + return target; +} /*! - * \brief An example codegen that is only used for quick prototyping and testing - * purpose. Only several binary options are covered. Users - * may need to extend them to cover more operators. + * \brief Emits C/C++ code for a single function. + * + * For testing and demonstration only, only a few binary operators are supported. */ -class CodegenC : public MemoizedExprTranslator>, public CodegenCBase { +class CodegenC : public backend::MemoizedExprTranslator>, public CodegenCBase { public: - explicit CodegenC(const std::string& id) { this->ext_func_id_ = id; } + CodegenC(std::unordered_map* const_name_to_constant, + Array* const_names, bool* needs_extra_headers, std::string ext_func_id) + : const_name_to_constant_(const_name_to_constant), + const_names_(const_names), + needs_extra_headers_(needs_extra_headers), + ext_func_id_(std::move(ext_func_id)) {} - std::vector VisitExprDefault_(const Object* op) final { + /*! + * \brief Emit the source code that invokes C compiler compatible wrappers. + * + * \return The emitted code. 
+ */ + std::string JIT(const std::vector& out) override { + if (!ext_func_args_.empty()) { + *needs_extra_headers_ = true; + } + // Write function macros + for (auto decl : func_decl_) { + code_stream_ << decl << "\n"; + } + return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_name_, out); + } + + private: + std::vector VisitExprDefault_(const Object* op) override { LOG(FATAL) << "C codegen doesn't support: " << op->GetTypeKey(); return {}; } - std::vector VisitExpr_(const VarNode* node) final { + std::vector VisitExpr_(const VarNode* node) override { ext_func_args_.push_back(GetRef(node)); Output output; output.name = node->name_hint(); return {output}; } - std::vector VisitExpr_(const TupleNode* node) final { + std::vector VisitExpr_(const TupleNode* node) override { std::vector outs; for (auto field : node->fields) { auto res = VisitExpr(field); @@ -67,7 +99,7 @@ class CodegenC : public MemoizedExprTranslator>, public Code return outs; } - std::vector VisitExpr_(const TupleGetItemNode* op) final { + std::vector VisitExpr_(const TupleGetItemNode* op) override { auto res = VisitExpr(op->tuple); ICHECK_GT(res.size(), static_cast(op->index)); @@ -76,19 +108,21 @@ class CodegenC : public MemoizedExprTranslator>, public Code return {res[op->index]}; } - std::vector VisitExpr_(const ConstantNode* cn) final { + std::vector VisitExpr_(const ConstantNode* cn) override { std::ostringstream decl_stream; std::ostringstream buf_stream; Output output; // Get const: static_cast(gcc_0_consts[0]->data) - output.name = CreateDataReference(ext_func_id_, const_idx_); + size_t const_id = const_name_to_constant_->size(); + output.name = CreateDataReference(ext_func_id_, const_id); const auto* type_node = cn->checked_type().as(); ICHECK(type_node); const auto& dtype = GetDtypeString(type_node); // Generate the global variable for needed ndarrays if (const_array_name_.empty()) { + *needs_extra_headers_ = true; const_array_name_ = 
CreateNDArrayPool(ext_func_id_); std::string checker = CreateInitChecker(ext_func_id_); ext_func_body_.insert(ext_func_body_.begin(), checker); @@ -97,14 +131,14 @@ class CodegenC : public MemoizedExprTranslator>, public Code ICHECK(dtype == "float" || dtype == "int") << "Only float and int are supported for now."; output.dtype = dtype; - std::string const_var_name = CreateConstVar(ext_func_id_, const_idx_); - const_vars_.push_back(const_var_name); - const_idx_++; + std::string const_var_name = CreateConstVar(ext_func_id_, const_id); + const_name_to_constant_->emplace(const_var_name, cn->data); + const_names_->push_back(const_var_name); return {output}; } - std::vector VisitExpr_(const CallNode* call) final { + std::vector VisitExpr_(const CallNode* call) override { std::ostringstream macro_stream; std::ostringstream decl_stream; std::ostringstream buf_stream; @@ -114,17 +148,17 @@ class CodegenC : public MemoizedExprTranslator>, public Code // Make function declaration macro_stream << "CSOURCE_BINARY_OP_" << call->args.size() << "D(" << func_name << ", "; - if (IsOp(call, "add")) { + if (backend::IsOp(call, "add")) { macro_stream << "+"; - } else if (IsOp(call, "subtract")) { + } else if (backend::IsOp(call, "subtract")) { macro_stream << "-"; - } else if (IsOp(call, "multiply")) { + } else if (backend::IsOp(call, "multiply")) { macro_stream << "*"; } else { LOG(FATAL) << "Unrecognized op"; } - auto in_shape = GetShape(call->args[0]->checked_type()); + auto in_shape = backend::GetShape(call->args[0]->checked_type()); for (size_t i = 0; i < in_shape.size(); ++i) { macro_stream << ", " << in_shape[i]; } @@ -152,7 +186,7 @@ class CodegenC : public MemoizedExprTranslator>, public Code } std::string out = "buf_" + std::to_string(buf_idx_++); - auto out_shape = GetShape(call->checked_type()); + auto out_shape = backend::GetShape(call->checked_type()); int out_size = 1; for (size_t i = 0; i < out_shape.size(); ++i) { out_size *= out_shape[i]; @@ -175,27 +209,21 @@ class 
CodegenC : public MemoizedExprTranslator>, public Code } /*! - * \brief Emit the source code that invokes C compiler compatible wrappers. - * - * \return The emitted code. + * \brief The accumulated constant name to constant mapping. Shared between all generated + * functions. */ - std::string JIT(const std::vector& out) { - // Write function macros - for (auto decl : func_decl_) { - code_stream_ << decl << "\n"; - } - return JitImpl(ext_func_id_, ext_func_args_, buf_decl_, ext_func_body_, const_array_name_, out); - } - - private: - /*! \brief The function id that represents a C source function. */ - std::string ext_func_id_ = ""; - /*! \brief The index of a wrapped C function. */ + std::unordered_map* const_name_to_constant_; + /*! \brief The accumulated constant names, in the order they were generated. */ + Array* const_names_; + /*! \brief Set to true if the ndarray and packed function headers are required. */ + bool* needs_extra_headers_; + /*! \brief Name of the global function currently being compiled. */ + std::string ext_func_id_; + + /*! \brief The index of the next available wrapped C function. */ int func_idx = 0; - /*! \brief The index of allocated buffers. */ + /*! \brief The index of the next available allocated buffers. */ int buf_idx_ = 0; - /*! \brief The index of global constants. */ - int const_idx_ = 0; /*! \brief The arguments of a C compiler compatible function. */ Array ext_func_args_; /*! \brief The statements of a C compiler compatible function. */ @@ -206,53 +234,55 @@ class CodegenC : public MemoizedExprTranslator>, public Code std::vector func_decl_; /*! \brief The declaration statements of buffers. */ std::vector buf_decl_; - /*! \brief The variable name to constant mapping. */ - Array const_vars_; - - friend class CSourceCodegen; }; -class CSourceCodegen : public CSourceModuleCodegenBase { +/*! \brief Emits C/C++ code for a module. 
*/ +class CodegenCModule { public: - std::tuple, String, String> GenCFunc(const Function& func) { - ICHECK(func.defined()) << "Input error: expect a Relay function."; - CodegenC builder(GetExtSymbol(func)); - auto out = builder.VisitExpr(func->body); - return std::make_tuple(builder.const_vars_, builder.ext_func_id_, builder.JIT(out)); - } + CodegenCModule(Target target, IRModule mod) : target_(std::move(target)), mod_(std::move(mod)) {} - runtime::Module CreateCSourceModule(const ObjectRef& ref) override { - ICHECK(ref->IsInstance()); - auto res = GenCFunc(Downcast(ref)); - Array variables = std::get<0>(res); - String func_name = std::get<1>(res); - - Optional opt_target = Target::Current(); - if (opt_target.defined() && opt_target.value()->kind->name == "ccompiler") { - Optional header = opt_target.value()->GetAttr("header"); - if (header.defined() && !header.value().empty()) { - code_stream_ << header.value().c_str() << "\n"; + runtime::Module CreateCSourceModule() { + for (const auto& kv : mod_->functions) { + if (const auto* function_node = GetCCompilerFunctionNode(kv.second)) { + GenCFunc(GetRef(function_node)); } } + return Finalize(); + } + + /*! \brief Returns the accumulated constant name to constant mapping. */ + const std::unordered_map& const_name_to_constant() const { + return const_name_to_constant_; + } + + private: + /*! \brief Emits the standard C/C++ header into \p os. */ + void EmitPreamble(std::ostringstream& os) { + // Custom header, if any. + Optional header = target_->GetAttr("header"); + if (header.defined() && !header.value().empty()) { + os << header.value().c_str() << "\n"; + } + + // Standard includes. 
+ os << "#include \n"; + os << "#include \n"; + os << "#include \n"; + os << "#include \n"; + os << "#include \n"; - // Create headers - code_stream_ << "#include \n"; - code_stream_ << "#include \n"; - code_stream_ << "#include \n"; - code_stream_ << "#include \n"; - code_stream_ << "#include \n"; - if (!variables.empty()) { + if (needs_extra_headers_) { // This segment would be generated in C++ because of the usage // of tvm::runtime::Array. This is not ideal, but this to demonstrate // constant copying process used packed imports in other external // codegen. Moreover, in microTVM we dont expect this part to be generated. - code_stream_ << "#ifdef __cplusplus\n"; - code_stream_ << "#include \n"; - code_stream_ << "#include \n"; - code_stream_ << "#endif\n"; + os << "#ifdef __cplusplus\n"; + os << "#include \n"; + os << "#include \n"; + os << "#endif\n"; } - // Append some common macro for operator definition. + // Define some macros to help operator implementations. const char* operator_macro = R"op_macro( #define CSOURCE_BINARY_OP_1D(p_ID_, p_OP_, p_DIM1_, p_DTYPE) \ void p_ID_(p_DTYPE* a, p_DTYPE* b, p_DTYPE* out) { \ @@ -272,38 +302,97 @@ class CSourceCodegen : public CSourceModuleCodegenBase { } )op_macro"; - code_stream_ << operator_macro << "\n\n"; - code_stream_ << std::get<2>(res); - std::string code = code_stream_.str(); + os << operator_macro << "\n\n"; + } + + void GenCFunc(const Function& function) { + ICHECK(function.defined()) << "Input error: expect a Relay function."; + std::string ext_func_id = backend::GetExtSymbol(function); + CodegenC builder(&const_name_to_constant_, &const_names_, &needs_extra_headers_, ext_func_id); + std::vector out = builder.VisitExpr(function->body); + code_stream_ << builder.JIT(out); + func_names_.push_back(ext_func_id); + } + + /*! \brief Returns function if it is tagged with "Compiler=ccompiler". 
*/ + static const FunctionNode* GetCCompilerFunctionNode(const Expr& expr) { + if (const auto* function_node = expr.as()) { + Optional opt_compiler = function_node->GetAttr(attr::kCompiler); + if (opt_compiler.defined() && opt_compiler.value() == "ccompiler") { + return function_node; + } + } + return nullptr; + } + + runtime::Module Finalize() { + std::ostringstream os; + EmitPreamble(os); + os << code_stream_.str(); + std::string code = os.str(); + + VLOG(1) << "CodegenCModule generated:" << std::endl << code; // Create a CSource module const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); ICHECK(pf != nullptr) << "Cannot find csource module to create the external runtime module"; - return (*pf)(code, "c", Array{func_name}, variables); + return (*pf)(code, "c", func_names_, const_names_); } - private: + /*! \brief "ccompiler" Target with compilation options to use. */ + Target target_; + /*! \brief Module we are compiling. */ + IRModule mod_; + + /*! \brief True if we need to include the ndarray and packed function headers. */ + bool needs_extra_headers_ = false; + /*! \brief The accumulated constant name to constant mapping. */ + std::unordered_map const_name_to_constant_; + /*! \brief The accumulated constant names, in the order they were generated. */ + Array const_names_; + /*! \brief The accumulated function names. */ + Array func_names_; + /*! + * \brief The accumulated code stream containing all function definitions. + * (Does not include the preamble.) + */ std::ostringstream code_stream_; }; -/*! - * \brief The external compiler/codegen tool. It takes a Relay expression/module and - * compile it into a runtime module. - * - * The external codegen tool should have been registered similiarly to LLVM, - * CUDA, etc, under TVM, so the generated code could be packed in a runtime - * module. This module simplifies code serialization and invocation. 
- */ -runtime::Module CCompiler(const ObjectRef& ref) { - CSourceCodegen csource; - return csource.CreateCSourceModule(ref); -} +/*! \brief The actual translation pass. */ +transform::Pass CCompilerImpl() { + auto pass_func = [=](IRModule mod, const transform::PassContext& pass_ctx) { + VLOG(1) << "CCompilerImpl input:" << std::endl << PrettyPrint(mod); + Target target = GetCCompilerTarget(); + + // Emit the C/C++ code and package it as a CSourceModule. + CodegenCModule codegen(target, mod); + runtime::Module runtime_mod = codegen.CreateCSourceModule(); + + // Capture the new runtime module. + Array external_mods = + mod->GetAttr>(tvm::attr::kExternalMods).value_or({}); + external_mods.push_back(runtime_mod); + + // Capture the new constants. + Map const_name_to_constant = + mod->GetAttr>(tvm::attr::kConstNameToConstant).value_or({}); + for (const auto& kv : codegen.const_name_to_constant()) { + ICHECK_EQ(const_name_to_constant.count(kv.first), 0); + const_name_to_constant.Set(kv.first, kv.second); + } -TVM_REGISTER_GLOBAL("relay.ext.ccompiler").set_body_typed(CCompiler); + return WithAttrs(mod, {{tvm::attr::kExternalMods, external_mods}, + {tvm::attr::kConstNameToConstant, const_name_to_constant}}); + }; + return tvm::transform::CreateModulePass(pass_func, 0, "CCompilerImpl", {}); +} -TVM_REGISTER_TARGET_KIND("ccompiler", kDLCPU) - .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) - .add_attr_option("header", String("")); // value is prepended to every output CModule +transform::Pass CCompilerPass() { + return transform::Sequential( + {transforms::OutlineCompilerFunctionsWithExistingGlobalSymbols("ccompiler"), CCompilerImpl(), + transforms::MarkCompilerFunctionsAsExtern("ccompiler")}); +} } // namespace contrib } // namespace relay diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 49a5bca068d1..1ee72c149f1a 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ 
b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -409,7 +409,7 @@ class CodegenCBase { * * \return The created reference */ - std::string CreateDataReference(const std::string& symbol, int const_id) const { + std::string CreateDataReference(const std::string& symbol, size_t const_id) const { return "(float*)(" + symbol + "_consts[" + std::to_string(const_id) + "]->data)"; } @@ -421,8 +421,8 @@ class CodegenCBase { * * \return The created variable name */ - std::string CreateConstVar(const std::string& symbol, int const_id) const { - return symbol + "_const_" + std::to_string(const_id++); + std::string CreateConstVar(const std::string& symbol, size_t const_id) const { + return symbol + "_const_" + std::to_string(const_id); } /*! \brief The external function source code stream. */ @@ -433,7 +433,14 @@ class CodegenCBase { int indent_{0}; }; +/*! + * \brief A pass to translate all "Primitive" Relay functions with "Compiler=ccompiler" to + * a \p CSourceModule. + */ +transform::Pass CCompilerPass(); + } // namespace contrib } // namespace relay } // namespace tvm + #endif // TVM_RELAY_BACKEND_CONTRIB_CODEGEN_C_CODEGEN_C_H_ diff --git a/src/relay/backend/contrib/codegen_c/target.cc b/src/relay/backend/contrib/codegen_c/target.cc new file mode 100644 index 000000000000..623057ac1762 --- /dev/null +++ b/src/relay/backend/contrib/codegen_c/target.cc @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include "./codegen_c.h" + +namespace tvm { +namespace relay { +namespace contrib { + +/*! + * \brief This demonstration external codegen target emits C/C++ for compilation by the native c + * compiler on CPU. + * - Patterns: None, functions must be explicitly marked as "Primitive" and "Compiler=ccompiler". + * - Custom compiler: relay/backend/contrib/codegen_c/codegen.cc + */ +TVM_REGISTER_TARGET_KIND("ccompiler", kDLCPU) + .set_attr(tvm::attr::kIsExternalCodegen, Bool(true)) + .set_attr(tvm::attr::kRelayToTIR, CCompilerPass()) + // Value is prepended to every output CModule. + .add_attr_option("header", String("")); + +} // namespace contrib +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/contrib/codegen_json/codegen_json.h b/src/relay/backend/contrib/codegen_json/codegen_json.h index 4966f3f01c7d..de6d0f74061b 100644 --- a/src/relay/backend/contrib/codegen_json/codegen_json.h +++ b/src/relay/backend/contrib/codegen_json/codegen_json.h @@ -33,6 +33,8 @@ #include #include #include +#include +#include #include #include "../../../../runtime/contrib/json/json_node.h" @@ -150,7 +152,8 @@ class JSONSerializer : public MemoizedExprTranslator(func_); @@ -162,8 +165,18 @@ class JSONSerializer : public MemoizedExprTranslatorbody); } - /*!\brief Return the required params. */ - Array GetParams() const { return params_; } + /*! + * \brief Returns the accumulated map from constant names to the NDArray they must be bound to + * at runtime. Also referred to a 'params' elsewhere in the code. 
+ */ + const std::unordered_map& const_name_to_constant() const { + return const_name_to_constant_; + } + + /*! + * \brief Return the constant names in order they were encountered during translation. + */ + const Array& const_names() const { return const_names_; } /*!\brief Return the generated json. */ std::string GetJSON() { @@ -245,11 +258,15 @@ class JSONSerializer : public MemoizedExprTranslator(vn)]; } - std::vector VisitExpr_(const ConstantNode* cn) { - std::string name = symbol_ + "_const_" + std::to_string(params_.size()); - params_.push_back(name); - auto node = std::make_shared(name, "const" /* op_type_ */); - return AddNode(node, GetRef(cn)); + std::vector VisitExpr_(const ConstantNode* constant_node) { + std::string name = symbol_ + "_const_" + std::to_string(const_names_.size()); + VLOG(1) << "Will require parameter '" << name + << "' to be supplied by the ConstLoaderModule at runtime"; + ICHECK_EQ(const_name_to_constant_.count(name), 0); + const_name_to_constant_.emplace(name, constant_node->data); + const_names_.push_back(name); + auto node = std::make_shared(name, /*op_type=*/"const"); + return AddNode(node, GetRef(constant_node)); } std::vector VisitExpr_(const TupleNode* tn) { @@ -340,8 +357,17 @@ class JSONSerializer : public MemoizedExprTranslator nodes_; /*! \brief Output of the JSON graph. */ std::vector heads_; - /*! \brief The list of required constants. */ - Array params_; + /*! + * \brief A map from constant names to NDArrays for each Constant encountered during + * translation to JSON. The JSON will record only the constant name. The actual NDArray must + * be made available at runtime from a ConstLoaderModule. + */ + std::unordered_map const_name_to_constant_; + /*! + * \brief The domain of the above map, but in order the constants were encountered during + * translation. 
+ */ + Array const_names_; }; } // namespace contrib diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index 772007792ae6..de2934173b5f 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -43,6 +43,18 @@ namespace cutlass { namespace { +/*! \brief Return the "cutlass" Target instance to use to guide compilation. */ +Target GetCutlassTarget() { + Target target = Target::Current(/*allow_not_defined=*/true); + if (!target.defined() || target->kind->name != "cutlass") { + // Use the default CUTLASS compilation options if no specific "cutlass" target was given + // in the overall targets list. In that case target_hooks.cc will invoke the custom pass + // without pushing any target instance onto the implicit target stack. + target = Target("cutlass"); + } + return target; +} + using Str2StrMap = std::unordered_map; static Str2StrMap dtype_map = {{"float16", "cutlass::half_t"}, @@ -563,7 +575,7 @@ class CodegenCutlass : public backend::MemoizedExprTranslatorExitScope(); code_stream_ << "}\n"; - this->GenerateBackendCFunc(ext_func_id_, ext_func_args_, const_array_name_, out, true); + this->GenerateBackendCFunc(ext_func_id_, ext_func_args_, /*const_arr_name=*/"", out, true); return code_stream_.str(); } @@ -769,7 +781,7 @@ class CodegenCutlass : public backend::MemoizedExprTranslator attrs_; /*! @@ -781,8 +793,6 @@ class CodegenCutlass : public backend::MemoizedExprTranslator ext_func_args_; /*! \brief Statement of the function that will be compiled using CUTLASS kernels. */ std::vector ext_func_body_; - /*! \brief The array declared to store the constant values. */ - std::string const_array_name_; /*! \brief The declaration of intermediate buffers. 
*/ std::vector buf_decl_; }; // class CodegenCutlass @@ -863,14 +873,14 @@ class CutlassModuleCodegen { const auto* pf = runtime::Registry::Get("runtime.CSourceModuleCreate"); ICHECK(pf != nullptr) << "Cannot find CSource module to create the external runtime module"; VLOG(1) << "Generated CUTLASS code:" << std::endl << code_stream_.str(); - return (*pf)(code_stream_.str(), "cu", func_names_, const_vars_); + return (*pf)(code_stream_.str(), "cu", func_names_, /*const_vars=*/Array()); } /*! * \brief Returns \p expr as function if it is a \p Function with "Compiler" attribute * value "cutlass". */ - const FunctionNode* GetCutlassFunctionNode(const Expr& expr) { + static const FunctionNode* GetCutlassFunctionNode(const Expr& expr) { if (const auto* function_node = expr.as()) { Optional opt_compiler = function_node->GetAttr(attr::kCompiler); if (opt_compiler.defined() && opt_compiler.value() == "cutlass") { @@ -886,8 +896,6 @@ class CutlassModuleCodegen { std::ostringstream code_stream_; /*! \brief The accumulated function names. */ Array func_names_; - /*! \brief The accumulated constant names. */ - Array const_vars_; }; // CutlassModuleCodegen /*! 
@@ -899,14 +907,12 @@ transform::Pass CompileForCutlassImpl() { VLOG(1) << "CompileForCutlass input:" << std::endl << PrettyPrint(mod); const auto* pf = runtime::Registry::Get("relay.ext.cutlass.compile_for_cutlass"); ICHECK(pf != nullptr) << "Cannot find compile_for_cutlass function"; - Optional opt_cutlass_target = Target::Current(); - ICHECK(opt_cutlass_target.defined()) << "Expecting Target::Current to be available"; - ICHECK_EQ(opt_cutlass_target.value()->kind->name, "cutlass"); - runtime::Module runtime_mod = (*pf)(mod, opt_cutlass_target.value()); + Target target = GetCutlassTarget(); + runtime::Module runtime_mod = (*pf)(mod, target); Array external_mods = - mod->GetAttr>("external_mods", Array()).value(); + mod->GetAttr>(tvm::attr::kExternalMods).value_or({}); external_mods.push_back(runtime_mod); - return WithAttr(mod, "external_mods", external_mods); + return WithAttr(mod, tvm::attr::kExternalMods, external_mods); }; return tvm::transform::CreateModulePass(pass_func, 0, "CompileForCutlass", {}); } diff --git a/src/relay/backend/contrib/dnnl/codegen.cc b/src/relay/backend/contrib/dnnl/codegen.cc index f17cdafa76a5..2f47c23a7cf9 100644 --- a/src/relay/backend/contrib/dnnl/codegen.cc +++ b/src/relay/backend/contrib/dnnl/codegen.cc @@ -585,11 +585,15 @@ runtime::Module DNNLCompiler(const ObjectRef& ref) { DNNLJSONSerializer serializer(func_name, func); serializer.serialize(); std::string graph_json = serializer.GetJSON(); - auto params = serializer.GetParams(); + + // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes + // a callback which calls backend::UpdateConstants to capture the map before the function + // 'disappears' into lowered form, on the assumption the visit order and thus constant + // names match those generated by the JSONSerializer. 
const auto* pf = runtime::Registry::Get("runtime.DNNLJSONRuntimeCreate"); ICHECK(pf != nullptr) << "Cannot find JSON runtime module to create"; - auto mod = (*pf)(func_name, graph_json, params); + auto mod = (*pf)(func_name, graph_json, serializer.const_names()); return mod; #else DNNLModuleCodegen dnnl; diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc b/src/relay/backend/contrib/example_target_hooks/target.cc index 19bfa8c68298..b01c23ed806a 100644 --- a/src/relay/backend/contrib/example_target_hooks/target.cc +++ b/src/relay/backend/contrib/example_target_hooks/target.cc @@ -1,4 +1,3 @@ - /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. See the NOTICE file diff --git a/src/relay/backend/contrib/tensorrt/codegen.cc b/src/relay/backend/contrib/tensorrt/codegen.cc index 149cc485c752..e08cd240d4d1 100644 --- a/src/relay/backend/contrib/tensorrt/codegen.cc +++ b/src/relay/backend/contrib/tensorrt/codegen.cc @@ -318,11 +318,16 @@ runtime::Module TensorRTCompiler(const ObjectRef& ref) { serializer.serialize(); std::string graph_json = serializer.GetJSON(); VLOG(1) << "TensorRT JSON:" << std::endl << graph_json; - auto param_names = serializer.GetParams(); + + // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes + // a callback which calls backend::UpdateConstants to capture the map before the function + // 'disappears' into lowered form, on the assumption the visit order and thus constant + // names match those generated by the JSONSerializer. 
+ const auto* pf = runtime::Registry::Get("runtime.tensorrt_runtime_create"); ICHECK(pf != nullptr) << "Cannot find TensorRT runtime module create function."; VLOG(1) << "Creating tensorrt runtime::Module for '" << func_name << "'"; - runtime::Module lib = (*pf)(func_name, graph_json, param_names); + runtime::Module lib = (*pf)(func_name, graph_json, serializer.const_names()); return lib; } diff --git a/src/relay/backend/contrib/verilator/codegen.cc b/src/relay/backend/contrib/verilator/codegen.cc index 2c29896d1b0e..2e6fb1326314 100644 --- a/src/relay/backend/contrib/verilator/codegen.cc +++ b/src/relay/backend/contrib/verilator/codegen.cc @@ -111,10 +111,15 @@ runtime::Module VerilatorBackend(const ObjectRef& ref) { VerilatorJSONSerializer serializer(func_name, func); serializer.serialize(); std::string graph_json = serializer.GetJSON(); - auto params = serializer.GetParams(); + + // Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes + // a callback which calls backend::UpdateConstants to capture the map before the function + // 'disappears' into lowered form, on the assumption the visit order and thus constant + // names match those generated by the JSONSerializer. 
// Create runtime object - auto n = make_object(func_name, graph_json, params); + auto n = make_object(func_name, graph_json, + serializer.const_names()); // Get Verilator compiler options auto ctx = transform::PassContext::Current(); diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index af426e5c71cb..faf9d2899fc3 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -259,21 +259,31 @@ class GraphExecutorCodegen : public backend::MemoizedExprTranslator>(); - for (auto param : params_) { - ret.params.emplace(std::make_pair( - param.first, - std::make_pair(static_cast(param_storage_ids_[param.first]), param.second))); + + // Collect any runtime modules generated by external codegen. + ret.external_mods = + lowered_mod->GetAttr>(tvm::attr::kExternalMods).value_or({}); + + // Collect any constants extracted by external codegen. + ret.params = std::unordered_map(); + Map const_name_to_constant = + lowered_mod->GetAttr>(tvm::attr::kConstNameToConstant) + .value_or({}); + for (const auto& kv : const_name_to_constant) { + VLOG(1) << "constant '" << kv.first << "' contributed by external codegen"; + ICHECK(ret.params.emplace(kv.first, kv.second).second); } - ret.function_metadata = std::move(function_metadata_); - Optional> external_modules = - lowered_mod->GetAttr>("external_mods"); - ICHECK(external_modules) << "Attribute \"external_mods\" should be set at this point."; + // Collect any constants extracted during lowering. 
+ for (const auto& kv : params_) { + VLOG(1) << "constant '" << kv.first << "' contributed by TECompiler"; + ICHECK(ret.params.emplace(kv.first, kv.second).second); + } + + ret.function_metadata = std::move(function_metadata_); // This is the point where we separate the functions in the module by target ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod); - ret.external_mods = external_modules.value(); ret.metadata = ExecutorCodegenMetadata({} /* inputs */, {} /* input_tensor_types */, {} /* outputs */, {} /* output_tensor_types */, {} /* pools */, {} /* devices */, @@ -650,14 +660,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { String key = args[0]; auto it = this->output_.params.find(key); CHECK(it != this->output_.params.end()) << "no such parameter " << key; - *rv = (*it).second.second; - }); - } else if (name == "get_param_id") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - String key = args[0]; - auto it = this->output_.params.find(key); - CHECK(it != this->output_.params.end()) << "no such parameter " << key; - *rv = (*it).second.first; + *rv = (*it).second; }); } else if (name == "get_irmodule") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index e9491b0a8901..4390e90b2cf3 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -1196,7 +1196,7 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr // annotate the module with the resulting runtime modules. // TODO(mbs): runtime modules should be first class rather than attributes. 
Array external_mods = - module->GetAttr>("external_mods", Array()).value(); + module->GetAttr>(tvm::attr::kExternalMods).value_or({}); Array new_external_mods = compiler->LowerExternalFunctions(); VLOG(1) << "capturing " << external_mods.size() << " existing and " << new_external_mods.size() << " new external modules"; @@ -1218,7 +1218,7 @@ IRModule LowerTE(const IRModule& module, const String& module_name, ProcessFn pr device_contexts.Set(kv.first, kv.second); // copy-on-write. } - updated_module = WithAttrs(updated_module, {{"external_mods", std::move(external_mods)}, + updated_module = WithAttrs(updated_module, {{tvm::attr::kExternalMods, std::move(external_mods)}, {"device_contexts", std::move(device_contexts)}}); if (backend::IsAutoSchedulerEnabled()) { diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index 67924a7835fb..d6fae8c72b5e 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -223,7 +223,11 @@ struct LoweredOutput { Map lowered_funcs; Array external_mods; Map function_metadata; - std::unordered_map> params; + /*! + * \brief Map from constant names (allocated by the codegen as constants are encountered) + * to the constant's value. 
+ */ + std::unordered_map params; ExecutorCodegenMetadata metadata; }; @@ -249,7 +253,7 @@ struct ConstantUpdater : public ExprVisitor { void VisitExpr_(const ConstantNode* cn) final { std::string name = symbol_ + "_const_" + std::to_string(const_idx_++); - VLOG(1) << "Binding " << name << " to constant of type " << PrettyPrint(cn->checked_type()); + VLOG(1) << "binding '" << name << "' to constant of type " << PrettyPrint(cn->checked_type()); (*params_)[name] = cn->data; } diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 7371fd1f8083..a8bd3df32a90 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -1166,11 +1166,27 @@ void VMCompiler::Codegen() { for (const auto& kv : per_tvm_target_modules) { ICHECK(kv.first->kind->device_type != kDLExtDev); } - Array ext_mods = - context_.module->GetAttr>("external_mods", Array()) - .value(); - VLOG(0) << "have " << per_tvm_target_modules.size() << " targets to build and " << ext_mods.size() - << " external runtime modules"; + + // Retrieve all external runtime modules accumulated by external codegen (both function-at-a-time + // and IRModule-at-a-time). + Array external_mods = + context_.module->GetAttr>(tvm::attr::kExternalMods).value_or({}); + + // Retrieve any constant bindings accumulated by external codegen (by IRModule-at-a-time passes). + Map const_name_to_constant = + context_.module->GetAttr>(tvm::attr::kConstNameToConstant) + .value_or({}); + + VLOG(0) << "have " << per_tvm_target_modules.size() << " targets to build, " + << external_mods.size() << " external runtime modules, " << const_name_to_constant.size() + << " external constants, and " << params_.size() << " local constants"; + + // Any constant bindings must be merged into the overall 'params' map we've directly accumulated + // via the TECompiler callback. 
+ for (const auto& kv : const_name_to_constant) { + ICHECK_EQ(params_.count(kv.first), 0); + params_.emplace(kv.first, kv.second); + } runtime::Module lib; if (per_tvm_target_modules.empty()) { @@ -1183,7 +1199,7 @@ void VMCompiler::Codegen() { } lib = - codegen::CreateMetadataModule(params_, lib, ext_mods, config_->host_target, + codegen::CreateMetadataModule(params_, lib, external_mods, config_->host_target, Runtime::Create("cpp"), Executor::Create("graph"), // DNS HACK relay::backend::ExecutorCodegenMetadata()); exec_->SetLib(lib); diff --git a/src/relay/transforms/compiler_function_utils.cc b/src/relay/transforms/compiler_function_utils.cc index 1b0f002f1def..0df9f5ee294c 100644 --- a/src/relay/transforms/compiler_function_utils.cc +++ b/src/relay/transforms/compiler_function_utils.cc @@ -50,7 +50,7 @@ const FunctionNode* AsFunctionNode(const Expr& expr, const std::string& compiler } /*! - * \brief Rewrite calls to inlined "Compiler" functions to global functions. The given + * \brief Rewrite calls to inlined and let-bound "Compiler" functions to global functions. The given * module will be extended with the newly outlined functions. */ class Outliner : public MixedModeMutator { @@ -58,6 +58,38 @@ class Outliner : public MixedModeMutator { Outliner(GlobalSymbolCache* cache, std::string compiler_filter, IRModule mod) : cache_(cache), compiler_filter_(std::move(compiler_filter)), mod_(std::move(mod)) {} + Expr VisitExpr_(const LetNode* op) final { + auto pre_visit = [this](const LetNode* op) { + Expr var = this->VisitExpr(op->var); + Expr value = this->VisitExpr(op->value); + + if (AsFunctionNode(value, compiler_filter_)) { + // Inline on-the-fly if the let-bound value is a function of interest. 
+ this->memo_[var] = value; + } + }; + auto post_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + Expr value = this->VisitExpr(op->value); + Expr body = this->VisitExpr(op->body); + auto expr = GetRef(op); + + if (AsFunctionNode(value, compiler_filter_)) { + // The let binding is no longer needed since inlined on-the-fly above. + this->memo_[expr] = this->VisitExpr(op->body); + } else { + Var var = Downcast(this->VisitExpr(op->var)); + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + this->memo_[expr] = expr; + } else { + this->memo_[expr] = Let(var, value, body); + } + } + }; + ExpandANormalForm(op, pre_visit, post_visit); + return memo_[GetRef(op)]; + } + Expr Rewrite_(const CallNode* pre, const Expr& post) final { Call new_call = Downcast(post); if (const auto* function_node = AsFunctionNode(new_call->op, compiler_filter_)) { diff --git a/src/relay/transforms/compiler_function_utils.h b/src/relay/transforms/compiler_function_utils.h index 6664594fc0a0..aa98430318a6 100644 --- a/src/relay/transforms/compiler_function_utils.h +++ b/src/relay/transforms/compiler_function_utils.h @@ -95,9 +95,10 @@ class ExistingGlobalSymbolCache : public GlobalSymbolCache { }; /*! - * \brief A pass to outline all literal functions in direct call positions which have a "Compiler" - * attribute. The given \p GlobalSymbolCache is used to determine a unique global symbol for each - * function, which is also assigned to the "global_symbol" attribute of the new global function. + * \brief A pass to outline all let-bound and literal functions in direct call positions which have + * a "Compiler" attribute. The given \p GlobalSymbolCache is used to determine a unique global + * symbol for each function, which is also assigned to the "global_symbol" attribute of the new + * global function. * * At most one function with the same global symbol is outlined. 
* @@ -108,9 +109,9 @@ transform::Pass OutlineCompilerFunctions(std::shared_ptr cach std::string compiler_filter = ""); /*! - * \brief A pass to outline all literal functions in direct call positions which have a "Compiler" - * attribute. The functions are bound to unique global vars according to their existing - * "global_symbol" attribute. At most one function with the same global symbol is outlined. + * \brief A pass to outline all let-bound and literal functions in direct call positions which have + * a "Compiler" attribute. The functions are bound to unique global vars according to their + * existing "global_symbol" attribute. At most one function with the same global symbol is outlined. * * If \p compiler_filter is non-empty only functions with that as their attribute value are * outlined. diff --git a/src/relay/transforms/target_hooks.cc b/src/relay/transforms/target_hooks.cc index 00953a1907e1..f52e95b2adbf 100644 --- a/src/relay/transforms/target_hooks.cc +++ b/src/relay/transforms/target_hooks.cc @@ -148,7 +148,7 @@ class TargetHookVisitor : public MixedModeVisitor { Pass RelayToTIRTargetHook(CompilationConfig config) { auto pass_func = [config = std::move(config)](IRModule mod, const PassContext& pass_ctx) { - VLOG(1) << "Before:" << std::endl << PrettyPrint(mod); + VLOG(1) << "RelayToTIRTargetHook before:" << std::endl << PrettyPrint(mod); TargetHookVisitor target_hook_visitor(mod, config); std::vector custom_passes = target_hook_visitor.Visit(); for (const auto& custom_pass : custom_passes) { @@ -161,11 +161,14 @@ Pass RelayToTIRTargetHook(CompilationConfig config) { mod = custom_pass.pass(mod); } else { // Invoke the pass. + // Note that there may be a non-external codegen target in scope. Each custom pass + // must be prepared to handle this, eg by creating a default target instance if the + // current target is either null or of a generic kind such as 'cuda' or 'llvm'. 
VLOG(0) << "Invoking custom pass for target kind '" << custom_pass.target_kind_name << "'"; mod = custom_pass.pass(mod); } } - VLOG(1) << "After:" << std::endl << PrettyPrint(mod); + VLOG(1) << "RelayToTIRTargetHook after:" << std::endl << PrettyPrint(mod); return mod; }; return tvm::transform::CreateModulePass(pass_func, 0, "RelayToTIRTargetHook", {}); diff --git a/src/target/metadata_module.cc b/src/target/metadata_module.cc index e5ca82d5c099..ec301d10812f 100644 --- a/src/target/metadata_module.cc +++ b/src/target/metadata_module.cc @@ -215,6 +215,8 @@ runtime::Module CreateMetadataModule( String symbol = pf_sym(); Array variables = pf_var(); for (size_t i = 0; i < variables.size(); i++) { + VLOG(1) << "From module of type '" << mod->type_key() << "' found const var '" + << variables[i] << "' for symbol '" << symbol << "'"; symbol_const_vars.push_back(variables[i].operator std::string()); } ICHECK_EQ(const_vars_by_symbol.count(symbol), 0U) << "Found duplicated symbol: " << symbol; diff --git a/src/tir/transforms/extract_constants.cc b/src/tir/transforms/extract_constants.cc index 237f923516da..f9e620ba3322 100644 --- a/src/tir/transforms/extract_constants.cc +++ b/src/tir/transforms/extract_constants.cc @@ -80,14 +80,14 @@ tvm::transform::Pass ExtractPrimFuncConstants() { } auto* attrs = m->attrs.CopyOnWrite(); ConstArrayType constant_array_ = - (attrs->dict.count(tvm::attr::kConstantsArray)) - ? Downcast(attrs->dict[tvm::attr::kConstantsArray]) + (attrs->dict.count(tvm::attr::kConstants)) + ? 
Downcast(attrs->dict[tvm::attr::kConstants]) : ConstArrayType(); Applicator a = Applicator(); func->body = a.Apply(func->body, constant_array_); const ConstArrayType constant_list = a.constant_array_; if (constant_list.size()) { - attrs->dict.Set(tvm::attr::kConstantsArray, constant_list); + attrs->dict.Set(tvm::attr::kConstants, constant_list); } return GetRef(func); }; diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index 4f451a125184..873475ac1ce7 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -235,37 +235,29 @@ def make_mod(): @pytest.mark.skipif(sys.platform == "win32", reason="Skip test on Windows for now") -def test_extern_gcc_consts(): - @tvm._ffi.register_func("relay.ext.ccompiler.constant_updater") - def constant_updater(expr, symbol): - """A dummy constant updater just to test that a custom one works.""" - return {"ccompiler_0_p0": tvm.nd.array(y0_data)} - - x = relay.var("x", shape=(8, 8)) - y0_data = np.random.uniform(0, 1, (8, 8)).astype("float32") +@pytest.mark.parametrize("check_result", [check_graph_executor_result, check_vm_result]) +def test_extern_gcc_consts(check_result): + shape = (8, 8) + dtype = "float32" + x = relay.var("x", shape=shape) + y0_data = np.random.uniform(0, 1, shape).astype(dtype) - x0 = relay.var("x0", shape=(8, 8)) - y0_const = relay.const(y0_data, "float32") + x0 = relay.var("x0", shape=shape) + y0_const = relay.const(y0_data, dtype) z = x0 + y0_const f = relay.Function([x0], z) f = set_external_func_attr(f, "ccompiler", "ccompiler_0") call = relay.Call(f, [x]) mod = tvm.IRModule.from_expr(call) - with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): - compiler = relay.backend.vm.VMCompiler() - compiler.lower(mod, "llvm") - compiler.codegen() - params = compiler.get_params() - assert len(params) == 1 - assert "ccompiler_0_p0" in params.keys() - - with 
tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]): - _, _, params = relay.build(mod, target="llvm") - assert len(params) == 1 - assert "ccompiler_0_p0" in params.keys() - - tvm._ffi.registry.remove_global_func("relay.ext.ccompiler.constant_updater") + # Note that while the VMCompiler get_params() will return all 'parameters' from both + # TVM and external codegen compiled code, the GraphExecutor.get_params() will return only + # those from non-external modules. So in the following we'll test by execution rather than + # test by inspection. + x_data = np.random.rand(*shape).astype(dtype) + inputs = {"x": x_data} + expected_result = x_data + y0_data + check_result(mod, inputs, shape, expected_result, target="llvm") @pytest.mark.skipif( diff --git a/tests/python/relay/transform/test_compiler_function_utils.py b/tests/python/relay/transform/test_compiler_function_utils.py index 66abeff8ab29..b1056f60b82b 100644 --- a/tests/python/relay/transform/test_compiler_function_utils.py +++ b/tests/python/relay/transform/test_compiler_function_utils.py @@ -75,6 +75,39 @@ def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float1 ) +def original_mod_let_bound(): + return tvm.parser.parse( + """ + #[version = "0.0.5"] + def @main(%x0 : Tensor[(1600, 768), float16], %x3 : Tensor[(600, 32, 64), float16]) -> (Tensor[(1600, 2304), float16], Tensor[(600, 32, 32), float16]) { + let %f = fn(%y_0_i0: Tensor[(1600, 768), float16], %y_0_i1: Tensor[(2304, 768), float16], %y_0_i2: Tensor[(2304), float16], + Inline=1, Compiler="cutlass", global_symbol="tvmgen_default_cutlass_main_0", Primitive=1) -> Tensor[(1600, 2304), float16] { + %4 = fn (%FunctionVar_0_0: Tensor[(1600, 768), float16], %FunctionVar_0_1: Tensor[(2304, 768), float16], %FunctionVar_0_2: Tensor[(2304), float16], + PartitionedFromPattern="nn.dense_add_", Composite="cutlass.dense_bias") -> Tensor[(1600, 2304), float16] { + %5 = nn.dense(%FunctionVar_0_0, %FunctionVar_0_1, units=2304); + 
add(%5, %FunctionVar_0_2) + }; + %4(%y_0_i0, %y_0_i1, %y_0_i2) + }; + %1 = %f(%x0, meta[relay.Constant][0], meta[relay.Constant][1]); + %2 = fn(%y_3_i0: Tensor[(600, 32, 64), float16], %y_3_i1: Tensor[(600, 32, 64), float16], + Inline=1, Compiler="cublas", global_symbol="tvmgen_default_cublas_main_3", Primitive=1) -> Tensor[(600, 32, 32), float16] { + %6 = fn (%FunctionVar_0_01: Tensor[(600, 32, 64), float16], %FunctionVar_0_11: Tensor[(600, 32, 64), float16], + PartitionedFromPattern="nn.batch_matmul_", Composite="cublas.batch_matmul") -> Tensor[(600, 32, 32), float16] { + nn.batch_matmul(%FunctionVar_0_01, %FunctionVar_0_11, out_dtype="float16", transpose_b=True) + }; + %6(%y_3_i0, %y_3_i1) + }; + %3 = %2(%x3, meta[relay.Constant][2]); + (%1, %3) + } + """, + "from_string", + None, + metatable, + ) + + def expected_outlined_mod(): return tvm.parser.parse( """ @@ -175,6 +208,13 @@ def test_outline_compiler_functions_with_existing_global_symbols(): tvm.ir.assert_structural_equal(actual_outlined_mod, expected_outlined_mod(), map_free_vars=True) +def test_outline_let_bound_compiler_functions_with_existing_global_symbols(): + actual_outlined_mod = tvm.relay.transform.OutlineCompilerFunctionsWithExistingGlobalSymbols( + "cutlass" + )(original_mod_let_bound()) + tvm.ir.assert_structural_equal(actual_outlined_mod, expected_outlined_mod(), map_free_vars=True) + + def test_mark_compiler_functions_as_extern(): actual_extern_mod = tvm.relay.transform.MarkCompilerFunctionsAsExtern("cutlass")( expected_outlined_mod() diff --git a/tests/python/unittest/test_custom_datatypes.py b/tests/python/unittest/test_custom_datatypes.py index b135973718bc..e3cff18c51f8 100644 --- a/tests/python/unittest/test_custom_datatypes.py +++ b/tests/python/unittest/test_custom_datatypes.py @@ -21,6 +21,7 @@ import pytest import tvm import tvm.topi.testing +import tvm.testing from tvm import relay from tvm.relay.testing.layers import batch_norm_infer from tvm.target.datatype import ( @@ -560,4 +561,4 
@@ def test_posites2(): if __name__ == "__main__": - pytest.main([__file__]) + tvm.testing.main() diff --git a/tests/python/unittest/test_tir_transform_extract_constants.py b/tests/python/unittest/test_tir_transform_extract_constants.py index cb49e7286fbb..82f4f6515c09 100644 --- a/tests/python/unittest/test_tir_transform_extract_constants.py +++ b/tests/python/unittest/test_tir_transform_extract_constants.py @@ -18,6 +18,7 @@ import tvm from tvm import tir from tvm.script import tir as T +import tvm.testing @tvm.script.ir_module @@ -49,7 +50,7 @@ def constant3(a: T.handle) -> None: def test_const_extraction(): mod = tvm.tir.transform.ExtractPrimFuncConstants()(Module4) - constants = mod.attrs["Constants"] + constants = mod.attrs["constants"] assert len(constants) == 2 def _visit(stmt): @@ -63,4 +64,4 @@ def _visit(stmt): if __name__ == "__main__": - test_const_extraction() + tvm.testing.main()