Skip to content

Commit

Permalink
Merge branch 'main' into andrewzhaoluo-add-cumprod
Browse files Browse the repository at this point in the history
* main:
  [AutoScheduler] Add function name in message (apache#7703)
  [Vulkan] Workaround for zero size allocation (apache#7691)
  Change behavior of onnx importer to throw when user provides an input no in the graph. (apache#7699)
  Free TensorRT engine and context (apache#7702)
  [TFLite] Cast operator adapted for MLIR-based convertor (apache#7639)
  [CPP_RPC] allow user supplied work dir (apache#7670)
  Default value for graph_runtime Init lookup_linked_param_func (apache#7676)
  • Loading branch information
Andrew Zhao Luo authored and Andrew Zhao Luo committed Mar 19, 2021
2 parents 78ee787 + aa494cf commit 04f0f41
Show file tree
Hide file tree
Showing 16 changed files with 162 additions and 71 deletions.
10 changes: 9 additions & 1 deletion apps/cpp_rpc/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ static const string kUsage =
"--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n"
"--key - The key used to identify the device type in tracker. Default=\"\"\n"
"--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n"
"--work-dir - Custom work directory. Default=\"\"\n"
"--silent - Whether to run in silent mode. Default=False\n"
"\n"
" Example\n"
Expand All @@ -70,6 +71,7 @@ static const string kUsage =
* \arg tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
* \arg key The key used to identify the device type in tracker. Default=""
* \arg custom_addr Custom IP Address to Report to RPC Tracker. Default=""
* \arg work_dir Custom work directory. Default=""
* \arg silent Whether run in silent mode. Default=False
*/
struct RpcServerArgs {
Expand All @@ -79,6 +81,7 @@ struct RpcServerArgs {
string tracker;
string key;
string custom_addr;
string work_dir;
bool silent = false;
#if defined(WIN32)
std::string mmap_path;
Expand All @@ -96,6 +99,7 @@ void PrintArgs(const RpcServerArgs& args) {
LOG(INFO) << "tracker = " << args.tracker;
LOG(INFO) << "key = " << args.key;
LOG(INFO) << "custom_addr = " << args.custom_addr;
LOG(INFO) << "work_dir = " << args.work_dir;
LOG(INFO) << "silent = " << ((args.silent) ? ("True") : ("False"));
}

Expand Down Expand Up @@ -238,6 +242,10 @@ void ParseCmdArgs(int argc, char* argv[], struct RpcServerArgs& args) {
dmlc::InitLogging("--minloglevel=0");
}
#endif
const string work_dir = GetCmdOption(argc, argv, "--work-dir=");
if (!work_dir.empty()) {
args.work_dir = work_dir;
}
}

/*!
Expand Down Expand Up @@ -274,7 +282,7 @@ int RpcServer(int argc, char* argv[]) {
#endif

RPCServerCreate(args.host, args.port, args.port_end, args.tracker, args.key, args.custom_addr,
args.silent);
args.work_dir, args.silent);
return 0;
}

Expand Down
35 changes: 20 additions & 15 deletions apps/cpp_rpc/rpc_env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ int mkdir(const char* path, int /* ignored */) { return _mkdir(path); }
#include <iostream>
#include <string>
#include <vector>

#include "../../src/support/utils.h"
#include "rpc_env.h"

Expand Down Expand Up @@ -85,25 +84,31 @@ void CleanDir(const std::string& dirname);
*/
std::string BuildSharedLibrary(std::string file_in);

RPCEnv::RPCEnv() {
RPCEnv::RPCEnv(const std::string& wd) {
if (wd != "") {
base_ = wd + "/.cache";
mkdir(wd.c_str(), 0777);
mkdir(base_.c_str(), 0777);
} else {
#if defined(ANDROID) || defined(__ANDROID__)
char cwd[PATH_MAX];
auto cmdline = fopen("/proc/self/cmdline", "r");
fread(cwd, 1, sizeof(cwd), cmdline);
fclose(cmdline);
base_ = "/data/data/" + std::string(cwd) + "/cache/rpc";
char cwd[PATH_MAX];
auto cmdline = fopen("/proc/self/cmdline", "r");
fread(cwd, 1, sizeof(cwd), cmdline);
fclose(cmdline);
base_ = "/data/data/" + std::string(cwd) + "/cache/rpc";
#elif !defined(_WIN32)
char cwd[PATH_MAX];
if (getcwd(cwd, sizeof(cwd))) {
base_ = std::string(cwd) + "/rpc";
} else {
base_ = "./rpc";
}
char cwd[PATH_MAX];
if (getcwd(cwd, sizeof(cwd))) {
base_ = std::string(cwd) + "/rpc";
} else {
base_ = "./rpc";
}
#else
base_ = "./rpc";
base_ = "./rpc";
#endif
mkdir(base_.c_str(), 0777);
}

mkdir(base_.c_str(), 0777);
TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetPath(args[0]);
});
Expand Down
2 changes: 1 addition & 1 deletion apps/cpp_rpc/rpc_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ struct RPCEnv {
/*!
* \brief Constructor Init The RPC Environment initialize function
*/
RPCEnv();
RPCEnv(const std::string& word_dir = "");
/*!
* \brief GetPath To get the workpath from packed function
* \param name The file name
Expand Down
21 changes: 12 additions & 9 deletions apps/cpp_rpc/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,15 @@ class RPCServer {
* \brief Constructor.
*/
RPCServer(std::string host, int port, int port_end, std::string tracker_addr, std::string key,
std::string custom_addr)
std::string custom_addr, std::string work_dir)
: host_(std::move(host)),
port_(port),
my_port_(0),
port_end_(port_end),
tracker_addr_(std::move(tracker_addr)),
key_(std::move(key)),
custom_addr_(std::move(custom_addr)) {}
custom_addr_(std::move(custom_addr)),
work_dir_(std::move(work_dir)) {}

/*!
* \brief Destructor.
Expand Down Expand Up @@ -174,7 +175,7 @@ class RPCServer {
const pid_t worker_pid = fork();
if (worker_pid == 0) {
// Worker process
ServerLoopProc(conn, addr);
ServerLoopProc(conn, addr, work_dir_);
_exit(0);
}

Expand All @@ -201,7 +202,7 @@ class RPCServer {
} else {
auto pid = fork();
if (pid == 0) {
ServerLoopProc(conn, addr);
ServerLoopProc(conn, addr, work_dir_);
exit(0);
}
// Wait for the result
Expand Down Expand Up @@ -308,9 +309,10 @@ class RPCServer {
* \param sock The socket information
* \param addr The socket address information
*/
static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr) {
static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr,
std::string work_dir) {
// Server loop
const auto env = RPCEnv();
const auto env = RPCEnv(work_dir);
RPCServerLoop(int(sock.sockfd));
LOG(INFO) << "Finish serving " << addr.AsString();
env.CleanUp();
Expand Down Expand Up @@ -339,6 +341,7 @@ class RPCServer {
std::string tracker_addr_;
std::string key_;
std::string custom_addr_;
std::string work_dir_;
support::TCPSocket listen_sock_;
support::TCPSocket tracker_sock_;
};
Expand Down Expand Up @@ -370,19 +373,19 @@ void ServerLoopFromChild(SOCKET socket) {
* silent mode. Default=True
*/
void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr,
std::string key, std::string custom_addr, bool silent) {
std::string key, std::string custom_addr, std::string work_dir, bool silent) {
if (silent) {
// Only errors and fatal is logged
dmlc::InitLogging("--minloglevel=2");
}
// Start the rpc server
RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key),
std::move(custom_addr));
std::move(custom_addr), std::move(work_dir));
rpc.Start();
}

TVM_REGISTER_GLOBAL("rpc.ServerCreate").set_body([](TVMArgs args, TVMRetValue* rv) {
RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]);
});
} // namespace runtime
} // namespace tvm
3 changes: 2 additions & 1 deletion apps/cpp_rpc/rpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,12 @@ void ServerLoopFromChild(SOCKET socket);
* \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
* \param key The key used to identify the device type in tracker. Default=""
* \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
* \param work_dir Custom work directory. Default=""
* \param silent Whether run in silent mode. Default=True
*/
void RPCServerCreate(std::string host = "", int port = 9090, int port_end = 9099,
std::string tracker_addr = "", std::string key = "",
std::string custom_addr = "", bool silent = true);
std::string custom_addr = "", std::string work_dir = "", bool silent = true);
} // namespace runtime
} // namespace tvm
#endif // TVM_APPS_CPP_RPC_SERVER_H_
49 changes: 29 additions & 20 deletions python/tvm/auto_scheduler/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class DispatchContext(object):
def __init__(self):
self._old_ctx = DispatchContext.current

def query(self, target, workload_key, has_complex_op, dag):
def query(self, target, workload_key, has_complex_op, dag, func_name):
"""
Query the context to get the specific config for a workload.
If cannot find the result inside this context, this function will query it
Expand All @@ -66,15 +66,17 @@ def query(self, target, workload_key, has_complex_op, dag):
Whether this workload has at least one complex op.
dag: ComputeDAG
The ComputeDAG of the workload.
func_name: str
The function name of this workload.
Returns
-------
state : StateObject
The state that stores schedule configuration for the workload
"""
ret = self._query_inside(target, workload_key)
ret = self._query_inside(target, workload_key, func_name)
if ret is None:
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag)
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag, func_name)
return ret

def update(self, target, workload_key, state):
Expand All @@ -92,7 +94,7 @@ def update(self, target, workload_key, state):
"""
raise NotImplementedError()

def _query_inside(self, target, workload_key):
def _query_inside(self, target, workload_key, func_name):
"""
Query the context to get the specific config for a workload.
This function only query config inside this context.
Expand All @@ -103,6 +105,8 @@ def _query_inside(self, target, workload_key):
The current target
workload_key : str
The current workload_key.
func_name: str
The function name of this workload.
Returns
-------
Expand Down Expand Up @@ -241,7 +245,7 @@ def load(self, records, n_lines=None):

logger.debug("Finish loading %d records", counter)

def _query_inside(self, target, workload_key):
def _query_inside(self, target, workload_key, func_name):
if target is None:
raise RuntimeError(
"Need a target context to find the history best. "
Expand Down Expand Up @@ -343,18 +347,20 @@ def __init__(
records, n_lines=None, include_compatible=True
)

def query(self, target, workload_key, has_complex_op, dag):
def query(self, target, workload_key, has_complex_op, dag, func_name):
if has_complex_op or self.sample_simple_workloads:
ret = self._query_inside(target, workload_key)
ret = self._query_inside(target, workload_key, func_name)
else:
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
ret = super(ApplyHistoryBestOrSample, self)._query_inside(
target, workload_key, func_name
)

if ret is None:
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag)
ret = self._old_ctx.query(target, workload_key, has_complex_op, dag, func_name)
return ret

def _query_inside(self, target, workload_key):
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
def _query_inside(self, target, workload_key, func_name):
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key, func_name)
if ret is not None:
return ret

Expand Down Expand Up @@ -386,7 +392,9 @@ def _query_inside(self, target, workload_key):

# Load the sampled records and query again.
self.load(log_file)
ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
ret = super(ApplyHistoryBestOrSample, self)._query_inside(
target, workload_key, func_name
)

del measure_ctx
return ret
Expand All @@ -411,18 +419,19 @@ def __init__(self):
# a set to prevent print duplicated message
self.messages = set()

def query(self, target, workload_key, has_complex_op, dag):
def query(self, target, workload_key, has_complex_op, dag, func_name):
key = (str(target), workload_key)
if key in self.memory:
return self.memory[key]

if self.verbose == 2 or (has_complex_op and self.verbose == 1):
msg = (
"-----------------------------------\n"
"Cannot find tuned schedules for target=%s, workload_key=%s. "
"A fallback TOPI schedule is used, "
"which may bring great performance regression or even compilation failure. "
"Compute DAG info:\n%s" % (target, workload_key, dag)
f"-----------------------------------\n"
f"{func_name}\n"
f"Cannot find tuned schedules for target={target}, workload_key={workload_key}. "
f"A fallback TOPI schedule is used, "
f"which may bring great performance regression or even compilation failure. "
f"Compute DAG info:\n{dag}"
)
if msg not in self.messages:
self.messages.add(msg)
Expand All @@ -434,8 +443,8 @@ def query(self, target, workload_key, has_complex_op, dag):
self.memory[key] = state
return state

def _query_inside(self, target, workload_key):
_ = target = workload_key
def _query_inside(self, target, workload_key, func_name):
_ = target = workload_key = func_name
raise RuntimeError("This function should never be called")

def update(self, target, workload_key, state):
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,17 @@ def traverse(t):


@tvm._ffi.register_func("auto_scheduler.relay_integration.auto_schedule_topi_compute")
def auto_schedule_topi(outs):
def auto_schedule_topi(func_name, outs):
"""Use auto-scheduler to schedule any topi compute function.
Note: This is used internally for relay integration. Do
not use this as a general user-facing API.
Parameters
----------
func_name: str
The name of the function being scheduled.
outs: List[Tensor]
The output tensors of topi compute functions
Expand All @@ -289,7 +292,7 @@ def auto_schedule_topi(outs):
target = tvm.target.Target.current()

dispatch_ctx = DispatchContext.current
state = dispatch_ctx.query(target, key, has_complex_op, dag)
state = dispatch_ctx.query(target, key, has_complex_op, dag, func_name)
schedule = None

env = TracingEnvironment.current
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2914,7 +2914,7 @@ def from_onnx(self, graph, opset, get_output_expr=False):
else:
self._num_input += 1
if i_name in self._shape:
i_shape = self._shape[i_name]
i_shape = self._shape.pop(i_name)
else:
if "?" in str(i_shape):
warning_msg = (
Expand All @@ -2929,6 +2929,11 @@ def from_onnx(self, graph, opset, get_output_expr=False):
dtype = d_type
self._nodes[i_name] = new_var(i_name, shape=i_shape, dtype=dtype)
self._inputs[i_name] = self._nodes[i_name]
assert (
len(self._shape) == 0
), "User specified the shape for inputs that weren't found in the graph: " + str(
self._shape
)
# get list of unsupported ops
convert_map = _get_convert_map(opset)
unsupported_ops = set()
Expand Down
Loading

0 comments on commit 04f0f41

Please sign in to comment.