From 431a7d6c0b7e5ae71b411c500836b136322f9fbf Mon Sep 17 00:00:00 2001
From: Alexander Pivovarov
Date: Thu, 18 Mar 2021 16:44:45 -0700
Subject: [PATCH 1/7] Default value for graph_runtime Init
 lookup_linked_param_func (#7676)

---
 src/runtime/graph/graph_runtime.cc | 5 +++--
 src/runtime/graph/graph_runtime.h  | 5 +++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc
index 7e98acb6fb3ee..5c7b756961684 100644
--- a/src/runtime/graph/graph_runtime.cc
+++ b/src/runtime/graph/graph_runtime.cc
@@ -66,10 +66,11 @@ void GraphRuntime::Run() {
  * processor.
  * \param ctxs The context of the host and devices where graph nodes will be
  * executed on.
- * \param lookup_linked_param_func Linked parameter lookup function.
+ * \param lookup_linked_param_func Linked parameter lookup function. Default is nullptr.
  */
 void GraphRuntime::Init(const std::string& graph_json, tvm::runtime::Module module,
-                        const std::vector<TVMContext>& ctxs, PackedFunc lookup_linked_param_func) {
+                        const std::vector<TVMContext>& ctxs,
+                        const PackedFunc lookup_linked_param_func) {
   std::istringstream is(graph_json);
   dmlc::JSONReader reader(&is);
   this->Load(&reader);
diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h
index a1e2ee3b5d744..e417d2aa4bfcc 100644
--- a/src/runtime/graph/graph_runtime.h
+++ b/src/runtime/graph/graph_runtime.h
@@ -93,11 +93,12 @@ class TVM_DLL GraphRuntime : public ModuleNode {
    * executed on.
    * \param lookup_linked_param_func If given, a PackedFunc invoked to lookup linked parameters
    * by storage_id. If not given, linked parameters are looked-up using an internal implementation,
-   * which is not compatible with RPCModules.
+   * which is not compatible with RPCModules. Default is nullptr.
    */
   void Init(const std::string& graph_json, tvm::runtime::Module module,
-            const std::vector<TVMContext>& ctxs, const PackedFunc lookup_linked_param_func);
+            const std::vector<TVMContext>& ctxs,
+            const PackedFunc lookup_linked_param_func = nullptr);

   /*!
    * \brief Get the input index given the name of input.

From e467748ef46836d94fd48a9673e8f05b305afe4c Mon Sep 17 00:00:00 2001
From: eric
Date: Fri, 19 Mar 2021 12:13:04 +0900
Subject: [PATCH 2/7] [CPP_RPC] allow user supplied work dir (#7670)

* [CPP_RPC] allow user supplied work dir

* clang format
---
 apps/cpp_rpc/main.cc       | 10 +++++++++-
 apps/cpp_rpc/rpc_env.cc    | 35 ++++++++++++++++++++---------------
 apps/cpp_rpc/rpc_env.h     |  2 +-
 apps/cpp_rpc/rpc_server.cc | 21 ++++++++++++---------
 apps/cpp_rpc/rpc_server.h  |  3 ++-
 5 files changed, 44 insertions(+), 27 deletions(-)

diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc
index e381dd2b261b9..0663c378819e2 100644
--- a/apps/cpp_rpc/main.cc
+++ b/apps/cpp_rpc/main.cc
@@ -55,6 +55,7 @@ static const string kUsage =
     "--tracker     - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n"
     "--key         - The key used to identify the device type in tracker. Default=\"\"\n"
     "--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n"
+    "--work-dir    - Custom work directory. Default=\"\"\n"
     "--silent      - Whether to run in silent mode. Default=False\n"
     "\n"
     "  Example\n"
@@ -70,6 +71,7 @@ static const string kUsage =
  * \arg tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
  * \arg key The key used to identify the device type in tracker. Default=""
  * \arg custom_addr Custom IP Address to Report to RPC Tracker. Default=""
+ * \arg work_dir Custom work directory. Default=""
 * \arg silent Whether run in silent mode. Default=False
 */
 struct RpcServerArgs {
   string host = "0.0.0.0";
   int port = 9090;
   int port_end = 9099;
   string tracker;
   string key;
   string custom_addr;
+  string work_dir;
   bool silent = false;
 #if defined(WIN32)
   std::string mmap_path;
 #endif
 };
@@ -96,6 +99,7 @@ void PrintArgs(const RpcServerArgs& args) {
   LOG(INFO) << "tracker = " << args.tracker;
   LOG(INFO) << "key = " << args.key;
   LOG(INFO) << "custom_addr = " << args.custom_addr;
+  LOG(INFO) << "work_dir = " << args.work_dir;
   LOG(INFO) << "silent = " << ((args.silent) ? ("True") : ("False"));
 }
@@ -238,6 +242,10 @@ void ParseCmdArgs(int argc, char* argv[], struct RpcServerArgs& args) {
     dmlc::InitLogging("--minloglevel=0");
   }
 #endif
+  const string work_dir = GetCmdOption(argc, argv, "--work-dir=");
+  if (!work_dir.empty()) {
+    args.work_dir = work_dir;
+  }
 }
@@ -274,7 +282,7 @@ int RpcServer(int argc, char* argv[]) {
 #endif

   RPCServerCreate(args.host, args.port, args.port_end, args.tracker, args.key, args.custom_addr,
-                  args.silent);
+                  args.work_dir, args.silent);
   return 0;
 }

diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc
index ea19cfa3979d3..5f703e1dc2b01 100644
--- a/apps/cpp_rpc/rpc_env.cc
+++ b/apps/cpp_rpc/rpc_env.cc
@@ -39,7 +39,6 @@ int mkdir(const char* path, int /* ignored */) { return _mkdir(path); }
 #include <iostream>
 #include <string>
 #include <vector>
-
 #include "../../src/support/utils.h"
 #include "rpc_env.h"
@@ -85,25 +84,31 @@ void CleanDir(const std::string& dirname);
  */
 std::string BuildSharedLibrary(std::string file_in);

-RPCEnv::RPCEnv() {
+RPCEnv::RPCEnv(const std::string& wd) {
+  if (wd != "") {
+    base_ = wd + "/.cache";
+    mkdir(wd.c_str(), 0777);
+    mkdir(base_.c_str(), 0777);
+  } else {
 #if defined(ANDROID) || defined(__ANDROID__)
-  char cwd[PATH_MAX];
-  auto cmdline = fopen("/proc/self/cmdline", "r");
-  fread(cwd, 1, sizeof(cwd), cmdline);
-  fclose(cmdline);
-  base_ = "/data/data/" + std::string(cwd) + "/cache/rpc";
+    char cwd[PATH_MAX];
+    auto cmdline = fopen("/proc/self/cmdline", "r");
+    fread(cwd, 1, sizeof(cwd), cmdline);
+    fclose(cmdline);
+    base_ = "/data/data/" + std::string(cwd) + "/cache/rpc";
 #elif !defined(_WIN32)
-  char cwd[PATH_MAX];
-  if (getcwd(cwd, sizeof(cwd))) {
-    base_ = std::string(cwd) + "/rpc";
-  } else {
-    base_ = "./rpc";
-  }
+    char cwd[PATH_MAX];
+    if (getcwd(cwd, sizeof(cwd))) {
+      base_ = std::string(cwd) + "/rpc";
+    } else {
+      base_ = "./rpc";
+    }
 #else
-  base_ = "./rpc";
+    base_ = "./rpc";
 #endif
+    mkdir(base_.c_str(), 0777);
+  }

-  mkdir(base_.c_str(), 0777);
   TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([this](TVMArgs args, TVMRetValue* rv) {
     *rv = this->GetPath(args[0]);
   });
diff --git a/apps/cpp_rpc/rpc_env.h b/apps/cpp_rpc/rpc_env.h
index 50ef3835e015d..dbb0a62d2c5d1 100644
--- a/apps/cpp_rpc/rpc_env.h
+++ b/apps/cpp_rpc/rpc_env.h
@@ -39,7 +39,7 @@ struct RPCEnv {
   /*!
    * \brief Constructor Init The RPC Environment initialize function
    */
-  RPCEnv();
+  RPCEnv(const std::string& work_dir = "");
   /*!
    * \brief GetPath To get the workpath from packed function
    * \param name The file name
diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc
index a4028ff61eca8..52b5da965b4cb 100644
--- a/apps/cpp_rpc/rpc_server.cc
+++ b/apps/cpp_rpc/rpc_server.cc
@@ -98,14 +98,15 @@ class RPCServer {
   * \brief Constructor.
   */
  RPCServer(std::string host, int port, int port_end, std::string tracker_addr, std::string key,
-           std::string custom_addr)
+           std::string custom_addr, std::string work_dir)
      : host_(std::move(host)),
        port_(port),
        my_port_(0),
        port_end_(port_end),
        tracker_addr_(std::move(tracker_addr)),
        key_(std::move(key)),
-        custom_addr_(std::move(custom_addr)) {}
+        custom_addr_(std::move(custom_addr)),
+        work_dir_(std::move(work_dir)) {}

  /*!
   * \brief Destructor.
@@ -174,7 +175,7 @@
     const pid_t worker_pid = fork();
     if (worker_pid == 0) {
       // Worker process
-      ServerLoopProc(conn, addr);
+      ServerLoopProc(conn, addr, work_dir_);
       _exit(0);
     }
@@ -201,7 +202,7 @@
     } else {
       auto pid = fork();
       if (pid == 0) {
-        ServerLoopProc(conn, addr);
+        ServerLoopProc(conn, addr, work_dir_);
         exit(0);
       }
       // Wait for the result
@@ -308,9 +309,10 @@
   * \param sock The socket information
   * \param addr The socket address information
   */
-  static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr) {
+  static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr,
+                             std::string work_dir) {
    // Server loop
-    const auto env = RPCEnv();
+    const auto env = RPCEnv(work_dir);
    RPCServerLoop(int(sock.sockfd));
    LOG(INFO) << "Finish serving " << addr.AsString();
    env.CleanUp();
@@ -339,6 +341,7 @@
  std::string tracker_addr_;
  std::string key_;
  std::string custom_addr_;
+  std::string work_dir_;
  support::TCPSocket listen_sock_;
  support::TCPSocket tracker_sock_;
 };
@@ -370,19 +373,19 @@ void ServerLoopFromChild(SOCKET socket) {
  * silent mode. Default=True
  */
 void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr,
-                     std::string key, std::string custom_addr, bool silent) {
+                     std::string key, std::string custom_addr, std::string work_dir, bool silent) {
   if (silent) {
     // Only errors and fatal is logged
     dmlc::InitLogging("--minloglevel=2");
   }
   // Start the rpc server
   RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key),
-                std::move(custom_addr));
+                std::move(custom_addr), std::move(work_dir));
   rpc.Start();
 }

 TVM_REGISTER_GLOBAL("rpc.ServerCreate").set_body([](TVMArgs args, TVMRetValue* rv) {
-  RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
+  RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]);
 });
 }  // namespace runtime
 }  // namespace tvm
diff --git a/apps/cpp_rpc/rpc_server.h b/apps/cpp_rpc/rpc_server.h
index 7a4bda5d65c41..e4565d095b2e1 100644
--- a/apps/cpp_rpc/rpc_server.h
+++ b/apps/cpp_rpc/rpc_server.h
@@ -48,11 +48,12 @@ void ServerLoopFromChild(SOCKET socket);
  * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
  * \param key The key used to identify the device type in tracker. Default=""
  * \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
+ * \param work_dir Custom work directory. Default=""
 * \param silent Whether run in silent mode. Default=True
 */
 void RPCServerCreate(std::string host = "", int port = 9090, int port_end = 9099,
                      std::string tracker_addr = "", std::string key = "",
-                     std::string custom_addr = "", bool silent = true);
+                     std::string custom_addr = "", std::string work_dir = "", bool silent = true);
 }  // namespace runtime
 }  // namespace tvm
 #endif  // TVM_APPS_CPP_RPC_SERVER_H_

From 2ee860e902e77f45996a5585fc09c5e5c29788e1 Mon Sep 17 00:00:00 2001
From: Dmitriy Smirnov
Date: Fri, 19 Mar 2021 06:47:45 +0000
Subject: [PATCH 3/7] [TFLite] Cast operator adapted for MLIR-based convertor
 (#7639)

* [TFLite] Cast operator adapted for MLIR-based convertor

Cast operator now can be executed in MLIR-based version.
Unit test updated

Change-Id: I30e5c1c9d69355116b560af8f6d0582b2d593538

* Comment added

Change-Id: I3e2d29ef201283de337168d0b82679b63ca2fcf4
---
 python/tvm/relay/frontend/tflite.py          | 17 ++++++++++++-----
 tests/python/frontend/tflite/test_forward.py | 19 ++++++++++++++-----
 2 files changed, 26 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relay/frontend/tflite.py b/python/tvm/relay/frontend/tflite.py
index d6f704703cae8..a5c9a586e2753 100644
--- a/python/tvm/relay/frontend/tflite.py
+++ b/python/tvm/relay/frontend/tflite.py
@@ -2336,11 +2336,18 @@ def convert_cast(self, op):
         input_tensor = input_tensors[0]
         in_expr = self.get_expr(input_tensor.tensor_idx)

-        assert op.BuiltinOptionsType() == BuiltinOptions.CastOptions
-        op_options = op.BuiltinOptions()
-        cast_options = CastOptions()
-        cast_options.Init(op_options.Bytes, op_options.Pos)
-        cast_dtype = cast_options.OutDataType()
+        # MLIR-based converter outputs no BuiltinOptions for Cast operator. In this
+        # case the output type can be derived from the Cast operator output tensor.
+        # When TOCO converter is used there will be "normal" BuiltinOptions.CastOptions
+        # with output type.
+        if op.BuiltinOptions() is not None:
+            assert op.BuiltinOptionsType() == BuiltinOptions.CastOptions
+            op_options = op.BuiltinOptions()
+            cast_options = CastOptions()
+            cast_options.Init(op_options.Bytes, op_options.Pos)
+            cast_dtype = cast_options.OutDataType()
+        else:
+            cast_dtype = self.get_output_tensors(op)[0].tensor.Type()

         out = _op.cast(in_expr, self.get_tensor_type_str(cast_dtype))

diff --git a/tests/python/frontend/tflite/test_forward.py b/tests/python/frontend/tflite/test_forward.py
index 0d02c15f2eb82..7c12cd3365cab 100644
--- a/tests/python/frontend/tflite/test_forward.py
+++ b/tests/python/frontend/tflite/test_forward.py
@@ -647,19 +647,28 @@ def test_forward_transpose():
 # ----


-def _test_cast(data, cast_dtype):
+def _test_cast(data, cast_dtype, use_mlir=False):
     """ One iteration of CAST """
     with tf.Graph().as_default():
         in_data = array_ops.placeholder(shape=data.shape, dtype=data.dtype)
         out = math_ops.cast(in_data, cast_dtype)
-        compare_tflite_with_tvm(data, "Placeholder:0", [in_data], [out])
+        compare_tflite_with_tvm(
+            data, "Placeholder:0", [in_data], [out], experimental_new_converter=use_mlir
+        )


 def test_forward_cast():
     """ CAST """
-    _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.int32)
-    _test_cast(np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8)
-    _test_cast(np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64)
+    for use_mlir in [False, True]:
+        _test_cast(
+            np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.int32, use_mlir=use_mlir
+        )
+        _test_cast(
+            np.arange(6.0, dtype=np.float32).reshape((1, 6)), cast_dtype=tf.uint8, use_mlir=use_mlir
+        )
+        _test_cast(
+            np.arange(6.0, dtype=np.int32).reshape((1, 6)), cast_dtype=tf.int64, use_mlir=use_mlir
+        )


 #######################################################################

From 570767f78851fbc0472c230adcb2c98e47bad0e8 Mon Sep 17 00:00:00 2001
From: Trevor Morris
Date: Fri, 19 Mar 2021 01:09:45 -0700
Subject: [PATCH 4/7] Free TensorRT engine and context (#7702)

---
 src/runtime/contrib/tensorrt/tensorrt_runtime.cc | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
index 3f87f8d00ee66..e28c5a8c61d04 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_runtime.cc
@@ -109,6 +109,14 @@ class TensorRTRuntime : public JSONRuntimeBase {
   }

 #ifdef TVM_GRAPH_RUNTIME_TENSORRT
+  /*! \brief Destroy engines and contexts. */
+  ~TensorRTRuntime() {
+    for (auto& it : trt_engine_cache_) {
+      it.second.context->destroy();
+      it.second.engine->destroy();
+    }
+  }
+
   /*! \brief Run inference using built engine. */
   void Run() override {
     BuildEngine();

From 35b43e1837cd5fcd688798cc3bf60ccc7f08bfbc Mon Sep 17 00:00:00 2001
From: Josh Fromm
Date: Fri, 19 Mar 2021 02:32:58 -0700
Subject: [PATCH 5/7] Change behavior of onnx importer to throw when user
 provides an input not in the graph. (#7699)

---
 python/tvm/relay/frontend/onnx.py          |  7 +++-
 tests/python/frontend/onnx/test_forward.py | 39 ++++++++++++++++++----
 2 files changed, 39 insertions(+), 7 deletions(-)

diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 391eaaab5f64b..fab4ae889dd7e 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -2914,7 +2914,7 @@ def from_onnx(self, graph, opset, get_output_expr=False):
             else:
                 self._num_input += 1
                 if i_name in self._shape:
-                    i_shape = self._shape[i_name]
+                    i_shape = self._shape.pop(i_name)
                 else:
                     if "?" in str(i_shape):
                         warning_msg = (
@@ -2929,6 +2929,11 @@ def from_onnx(self, graph, opset, get_output_expr=False):
                     dtype = d_type
                 self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype)
             self._inputs[i_name] = self._nodes[i_name]
+        assert (
+            len(self._shape) == 0
+        ), "User specified the shape for inputs that weren't found in the graph: " + str(
+            self._shape
+        )
         # get list of unsupported ops
         convert_map = _get_convert_map(opset)
         unsupported_ops = set()
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index 177bed66f466b..5a6216ac705de 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -19,6 +19,7 @@
 from onnx import helper, TensorProto, mapping, numpy_helper
 import torch
 import torchvision
+import pytest
 import tvm.topi.testing
 import tvm
 from tvm import relay
@@ -57,7 +58,7 @@ def get_tvm_output_with_vm(
         mod = relay.transform.DynamicToStatic()(mod)

     ex = relay.create_executor("vm", mod=mod, ctx=ctx, target=target)
-    result = ex.evaluate()(*input_data)
+    result = ex.evaluate()(*input_data, **params)
     if isinstance(result, tvm.runtime.NDArray):
         return result.asnumpy()
     return [r.asnumpy() for r in result]
@@ -500,7 +501,7 @@ def test_squeeze():
     model = helper.make_model(graph, producer_name="squeeze_test")
     x = np.random.uniform(size=in_shape).astype("float32")
-    verify_with_ort_with_inputs(model, [x], [out_shape])
+    verify_with_ort_with_inputs(model, [x], [out_shape], opset=11)


 @tvm.testing.uses_gpu
@@ -538,7 +539,7 @@ def test_unsqueeze():
     )

     model = helper.make_model(graph, producer_name="squeeze_test")
-    verify_with_ort(model, [in_shape])
+    verify_with_ort(model, [in_shape], opset=11)


 def verify_gather(in_shape, indices, axis, dtype):
@@ -1584,7 +1585,7 @@ def verify_pad_v11(indata, pads, mode="constant", value=0.0):
     pads = np.array(pads)
     #  onnx graph
     if mode in ["edge", "reflect"]:
-        inputs = [indata, pads]
+        inputs = [indata]
         outdata = np.pad(indata, pad_width=np_pads, mode=mode)
         node = helper.make_node("Pad", inputs=["input", "pads"], outputs=["output"], mode=mode)
         graph = helper.make_graph(
@@ -1600,7 +1601,7 @@ def verify_pad_v11(indata, pads, mode="constant", value=0.0):
             ],
         )
     else:
-        inputs = [indata, pads, np.array([value]).astype("float32")]
+        inputs = [indata]
         outdata = np.pad(indata, pad_width=np_pads, mode="constant", constant_values=value)
         node = helper.make_node(
             "Pad", inputs=["input", "pads", "constant_value"], outputs=["output"], mode="constant"
         )
@@ -1663,7 +1664,7 @@ def verify_reduce_func(func, data, axis, keepdims):

     model = helper.make_model(graph, producer_name="reduce_test")

-    verify_with_ort_with_inputs(model, [data], [outshape])
+    verify_with_ort_with_inputs(model, [data], [outshape], opset=11)


 @tvm.testing.uses_gpu
@@ -4089,6 +4090,31 @@ def verify_cumsum(indata, axis, exclusive=0, reverse=0, type="float32"):
     verify_cumsum(data, 1, 1, 1, type="int32")


+def test_wrong_input():
+    node = helper.make_node(
+        "Softplus",
+        inputs=["X"],
+        outputs=["Y"],
+    )
+
+    graph = helper.make_graph(
+        [node],
+        "softplus_test",
+        inputs=[helper.make_tensor_value_info("X", TensorProto.FLOAT, list([5]))],
+        outputs=[helper.make_tensor_value_info("Y", TensorProto.FLOAT, list([5]))],
+    )
+    model = helper.make_model(graph, producer_name="softplus_test")
+
+    # Check that the graph can import correctly with proper shape definitions.
+    correct_shape_dict = {"X": [5]}
+    relay.frontend.from_onnx(model, shape=correct_shape_dict)
+
+    # Check that an assertion is triggered when an input not in the graph is provided.
+    wrong_shape_dict = {"Z": [5]}
+    with pytest.raises(AssertionError):
+        relay.frontend.from_onnx(model, shape=wrong_shape_dict)
+
+
 if __name__ == "__main__":
     test_flatten()
     test_reshape()
@@ -4167,3 +4193,4 @@
     test_maxunpool()
     test_softplus()
     test_cumsum()
+    test_wrong_input()

From 9a29141db81a4128f49c84b5a2ad50325eb6c7bd Mon Sep 17 00:00:00 2001
From: masahi
Date: Sat, 20 Mar 2021 05:44:05 +0900
Subject: [PATCH 6/7] [Vulkan] Workaround for zero size allocation (#7691)

---
 src/runtime/vulkan/vulkan.cc | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/runtime/vulkan/vulkan.cc b/src/runtime/vulkan/vulkan.cc
index ff1b82f930d73..f56318aee94d7 100644
--- a/src/runtime/vulkan/vulkan.cc
+++ b/src/runtime/vulkan/vulkan.cc
@@ -120,6 +120,10 @@
   std::vector<uint32_t> GetComputeQueueFamilies(VkPhysicalDevice phy_dev);
   void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment,
                        DLDataType type_hint) final {
+    if (nbytes == 0) {
+      // Vulkan seems to have issues if we return nullptr on zero size alloc
+      nbytes = 1;
+    }
     const auto& vctx = context(ctx.device_id);
     VkBufferCreateInfo info;
     info.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;

From aa494cfbfd0943855889444f37e7f032b0b58051 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Fri, 19 Mar 2021 15:19:25 -0700
Subject: [PATCH 7/7] [AutoScheduler] Add function name in message (#7703)

* [AutoScheduler] Add function name in message

* fix
---
 python/tvm/auto_scheduler/dispatcher.py      | 49 +++++++++++--------
 .../tvm/auto_scheduler/relay_integration.py  |  7 ++-
 src/relay/backend/compile_engine.cc          |  2 +-
 3 files changed, 35 insertions(+), 23 deletions(-)

diff --git a/python/tvm/auto_scheduler/dispatcher.py b/python/tvm/auto_scheduler/dispatcher.py
index 6a25960fe7b74..c843dcfccdf05 100644
--- a/python/tvm/auto_scheduler/dispatcher.py
+++ b/python/tvm/auto_scheduler/dispatcher.py
@@ -50,7 +50,7 @@ class DispatchContext(object):
     def __init__(self):
         self._old_ctx = DispatchContext.current

-    def query(self, target, workload_key, has_complex_op, dag):
+    def query(self, target, workload_key, has_complex_op, dag, func_name):
         """
         Query the context to get the specific config for a workload.
         If cannot find the result inside this context, this function will query it
@@ -66,15 +66,17 @@
             Whether this workload has at least one complex op.
         dag: ComputeDAG
             The ComputeDAG of the workload.
+        func_name: str
+            The function name of this workload.

         Returns
         -------
         state : StateObject
             The state that stores schedule configuration for the workload
         """
-        ret = self._query_inside(target, workload_key)
+        ret = self._query_inside(target, workload_key, func_name)
         if ret is None:
-            ret = self._old_ctx.query(target, workload_key, has_complex_op, dag)
+            ret = self._old_ctx.query(target, workload_key, has_complex_op, dag, func_name)
         return ret

     def update(self, target, workload_key, state):
@@ -92,7 +94,7 @@
         """
         raise NotImplementedError()

-    def _query_inside(self, target, workload_key):
+    def _query_inside(self, target, workload_key, func_name):
         """
         Query the context to get the specific config for a workload.
         This function only query config inside this context.
@@ -103,6 +105,8 @@
             The current target
         workload_key : str
             The current workload_key.
+        func_name: str
+            The function name of this workload.

         Returns
         -------
@@ -241,7 +245,7 @@

         logger.debug("Finish loading %d records", counter)

-    def _query_inside(self, target, workload_key):
+    def _query_inside(self, target, workload_key, func_name):
         if target is None:
             raise RuntimeError(
                 "Need a target context to find the history best. "
@@ -343,18 +347,20 @@
             records, n_lines=None, include_compatible=True
         )

-    def query(self, target, workload_key, has_complex_op, dag):
+    def query(self, target, workload_key, has_complex_op, dag, func_name):
         if has_complex_op or self.sample_simple_workloads:
-            ret = self._query_inside(target, workload_key)
+            ret = self._query_inside(target, workload_key, func_name)
         else:
-            ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
+            ret = super(ApplyHistoryBestOrSample, self)._query_inside(
+                target, workload_key, func_name
+            )

         if ret is None:
-            ret = self._old_ctx.query(target, workload_key, has_complex_op, dag)
+            ret = self._old_ctx.query(target, workload_key, has_complex_op, dag, func_name)
         return ret

-    def _query_inside(self, target, workload_key):
-        ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
+    def _query_inside(self, target, workload_key, func_name):
+        ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key, func_name)
         if ret is not None:
             return ret
@@ -386,7 +392,9 @@

         # Load the sampled records and query again.
         self.load(log_file)
-        ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
+        ret = super(ApplyHistoryBestOrSample, self)._query_inside(
+            target, workload_key, func_name
+        )

         del measure_ctx
         return ret
@@ -411,18 +419,19 @@
         # a set to prevent print duplicated message
         self.messages = set()

-    def query(self, target, workload_key, has_complex_op, dag):
+    def query(self, target, workload_key, has_complex_op, dag, func_name):
         key = (str(target), workload_key)
         if key in self.memory:
             return self.memory[key]

         if self.verbose == 2 or (has_complex_op and self.verbose == 1):
             msg = (
-                "-----------------------------------\n"
-                "Cannot find tuned schedules for target=%s, workload_key=%s. "
-                "A fallback TOPI schedule is used, "
-                "which may bring great performance regression or even compilation failure. "
-                "Compute DAG info:\n%s" % (target, workload_key, dag)
+                f"-----------------------------------\n"
+                f"{func_name}\n"
+                f"Cannot find tuned schedules for target={target}, workload_key={workload_key}. "
" + f"A fallback TOPI schedule is used, " + f"which may bring great performance regression or even compilation failure. " + f"Compute DAG info:\n{dag}" ) if msg not in self.messages: self.messages.add(msg) @@ -434,8 +443,8 @@ def query(self, target, workload_key, has_complex_op, dag): self.memory[key] = state return state - def _query_inside(self, target, workload_key): - _ = target = workload_key + def _query_inside(self, target, workload_key, func_name): + _ = target = workload_key = func_name raise RuntimeError("This function should never be called") def update(self, target, workload_key, state): diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py index 6cce30f2f5599..366d3d021d9e5 100644 --- a/python/tvm/auto_scheduler/relay_integration.py +++ b/python/tvm/auto_scheduler/relay_integration.py @@ -256,7 +256,7 @@ def traverse(t): @tvm._ffi.register_func("auto_scheduler.relay_integration.auto_schedule_topi_compute") -def auto_schedule_topi(outs): +def auto_schedule_topi(func_name, outs): """Use auto-scheduler to schedule any topi compute function. Note: This is used internally for relay integration. Do @@ -264,6 +264,9 @@ def auto_schedule_topi(outs): Parameters ---------- + func_name: str + The name of the function being scheduled. + outs: List[Tensor] The output tensors of topi compute functions @@ -289,7 +292,7 @@ def auto_schedule_topi(outs): target = tvm.target.Target.current() dispatch_ctx = DispatchContext.current - state = dispatch_ctx.query(target, key, has_complex_op, dag) + state = dispatch_ctx.query(target, key, has_complex_op, dag, func_name) schedule = None env = TracingEnvironment.current diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index ae975a5f32401..f492b70565ace 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -157,7 +157,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute"); ICHECK(fauto_schedule != nullptr) << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered"; - ObjectRef obj = (*fauto_schedule)(tensor_outs); + ObjectRef obj = (*fauto_schedule)(String(cache_node->func_name), tensor_outs); if (obj.defined()) { schedule = Downcast(obj); }