diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index 7196f209712ee..6fcd3713f566a 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -458,7 +458,7 @@ def connect(url, port, key="", session_timeout=0, session_constructor_args=None) Additional key to match server session_timeout : float, optional - The duration of the session, allows server to kill + The duration of the session in seconds, allows server to kill the connection when duration is longer than this value. When duration is zero, it means the request must always be kept alive. diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index d0de0520a6741..4fc90ac27a43d 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -28,6 +28,7 @@ from tvm._ffi import base as _base from .object import Object from . import _ffi_api, container +from ..rpc.base import RPC_SESS_MASK def _convert(arg, cargs): @@ -341,6 +342,9 @@ def __init__(self, exe, device, memory_cfg=None): self._exec = exe self._init = self.module["init"] self._invoke = self.module["invoke"] + self._invoke_stateful = self.module["invoke_stateful"] + self._get_output = self.module["get_output"] + self._get_num_outputs = self.module["get_num_outputs"] self._set_input = self.module["set_input"] self._setup_device(device, memory_cfg) @@ -356,7 +360,7 @@ def _setup_device(self, dev, memory_cfg): devs = [dev] # CPU is required for executing shape functions - if not any(c.device_type == tvm.cpu().device_type for c in devs): + if not any(c.device_type % RPC_SESS_MASK == tvm.cpu().device_type for c in devs): devs.append(tvm.cpu()) default_alloc_type = VirtualMachine.POOLED_ALLOCATOR @@ -374,7 +378,7 @@ def _setup_device(self, dev, memory_cfg): ) init_args = [] for device in devs: - init_args.append(device.device_type) + init_args.append(device.device_type % RPC_SESS_MASK) init_args.append(device.device_id) alloc_type = memory_cfg[device] if device in memory_cfg else default_alloc_type init_args.append(alloc_type) @@ -455,3 +459,34 @@ def run(self, *args, **kwargs): The output. """ return self.invoke("main", *args, **kwargs) + + def invoke_stateful(self, func_name, *args, **kwargs): + """Invoke a function and ignore the returned result. + + Use this function when running over rpc because it is currently + impossible to return a ADT object over rpc. To get the outputs, use + :py:func`get_outputs`. + + Parameters + ---------- + func_name : str + The name of the function. + + args : list[tvm.runtime.NDArray] or list[np.ndarray] + The arguments to the function. + + kwargs: dict of str to tvm.runtime.NDArray or np.ndarray + Named arguments to the function. + """ + if args or kwargs: + self.set_input(func_name, *args, **kwargs) + self._invoke_stateful(func_name) + + def get_outputs(self): + """Get the outputs from a call to :py:func`invoke_stateful`. + + Returns + ------- + outputs : List[NDArray] + """ + return [self._get_output(i) for i in range(self._get_num_outputs())] diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index b5768146b3f76..236be8e56a705 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -477,16 +477,17 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { TVMArgs args = RecvPackedSeq(); this->SwitchToState(kWaitForAsyncCallback); - GetServingSession()->AsyncCallFunc(reinterpret_cast(call_handle), args.values, - args.type_codes, args.size(), - [this](RPCCode status, TVMArgs args) { - if (status == RPCCode::kException) { - this->ReturnException(args.values[0].v_str); - } else { - this->ReturnPackedSeq(args); - } - this->SwitchToState(kRecvPacketNumBytes); - }); + GetServingSession()->AsyncCallFunc( + reinterpret_cast(call_handle), args.values, args.type_codes, args.size(), + [this](RPCCode status, TVMArgs args) { + if (status == RPCCode::kException) { + this->ReturnException(args.values[0].v_str); + } else { + ValidateArguments(args.values, args.type_codes, args.size()); + this->ReturnPackedSeq(args); + } + this->SwitchToState(kRecvPacketNumBytes); + }); } void HandleInitServer() { @@ -637,7 +638,7 @@ RPCCode RPCEndpoint::HandleUntilReturnEvent(bool client_mode, RPCSession::FEncod if (handler_->CanCleanShutdown()) { return RPCCode::kShutdown; } else { - LOG(FATAL) << "Channel closes before we get neded bytes"; + LOG(FATAL) << "Channel closes before we get needed bytes"; } } } @@ -794,7 +795,7 @@ void RPCEndpoint::CallFunc(RPCSession::PackedFuncHandle h, const TVMValue* arg_v handler_->SendPackedSeq(arg_values, arg_type_codes, num_args, true); code = HandleUntilReturnEvent(true, encode_return); - ICHECK(code == RPCCode::kReturn) << "code=" << static_cast(code); + ICHECK(code == RPCCode::kReturn) << "code=" << RPCCodeToString(code); } void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes) { diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 76ca009bc7417..a0edb3bad93a4 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -132,6 +132,21 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, *rv = Invoke(func, func_args); } }); + } else if (name == "invoke_stateful") { + // TODO(tkonolige, jroesch, tqchen): invoke_stateful and get_output are + // stop-gap measure to allow using vm over a remote connection. + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + PackedFunc invoke = GetFunction("invoke", sptr_to_self); + TVMRetValue rv_; + invoke.CallPacked(args, &rv_); + }); + } else if (name == "get_output") { + return TypedPackedFunc([this](int64_t index) { + return Downcast(Downcast(this->return_register_)[index]); + }); + } else if (name == "get_num_outputs") { + return TypedPackedFunc( + [this]() -> int64_t { return Downcast(this->return_register_).size(); }); } else if (name == "init") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.size() % 3, 0); @@ -165,8 +180,21 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, for (int i = 1; i < args.size(); ++i) { Index device_type = vm_func.params_device_type[i - 1]; Device dev = GetDevice(device_type); - ObjectRef obj = CopyTo(args[i], dev); - func_args[i - 1] = obj; + + if (args[i].type_code() == kTVMDLTensorHandle) { + // Automatically convert input DLTensors to NDArray + DLTensor* tensor = args[i]; + std::vector shape; + for (int64_t i = 0; i < tensor->ndim; i++) { + shape.push_back(tensor->shape[i]); + } + NDArray ary = NDArray::Empty(shape, tensor->dtype, dev); + ary.CopyFrom(tensor); + func_args[i - 1] = ary; + } else { + ObjectRef obj = CopyTo(args[i], dev); + func_args[i - 1] = obj; + } } inputs_.erase(func_name); inputs_.emplace(func_name, func_args); diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 7ca06c5c97e04..58985832fb359 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -16,6 +16,7 @@ # under the License. import numpy as np import pytest +import time import tvm from tvm import runtime @@ -823,8 +824,12 @@ def test_vm_rpc(): path = temp.relpath("vm_library.so") vm_exec.mod.export_library(path) - # Use LocalRPC for testing. - remote = rpc.LocalSession() + # Use local rpc server for testing. + # Server must use popen so it doesn't inherit the current process state. It + # will crash otherwise. + server = rpc.Server("localhost", port=9120, use_popen=True) + time.sleep(2) + remote = rpc.connect(server.host, server.port, session_timeout=10) # Upload the serialized Executable. remote.upload(path) @@ -837,10 +842,16 @@ def test_vm_rpc(): np_input = np.random.uniform(size=(10, 1)).astype("float32") input_tensor = tvm.nd.array(np_input, ctx) # Invoke its "main" function. - out = vm_factory.invoke("main", [input_tensor]) + out = vm_factory.invoke("main", input_tensor) # Check the result. np.testing.assert_allclose(out.asnumpy(), np_input + np_input) + # delete tensors before the server shuts down so we don't throw errors. + del input_tensor + del out + + server.terminate() + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/unittest/test_runtime_profiling.py b/tests/python/unittest/test_runtime_profiling.py index c4d855a3c7d3b..dd2f58ec77f45 100644 --- a/tests/python/unittest/test_runtime_profiling.py +++ b/tests/python/unittest/test_runtime_profiling.py @@ -33,7 +33,7 @@ def test_vm(target, dev): vm = profiler_vm.VirtualMachineProfiler(exe, dev) data = np.random.rand(1, 1, 28, 28).astype("float32") - report = vm.profile([data], func_name="main") + report = vm.profile(data, func_name="main") assert "fused_nn_softmax" in report assert "Total time" in report diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index 613c7cbdf34f3..00b63af486460 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -65,6 +65,7 @@ if python -c "import tvm; from tvm.relay.op.contrib.ethosn import ethosn_availab fi run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-contrib tests/python/contrib +# forked is needed because the global registry gets contaminated TVM_TEST_TARGETS="${TVM_RELAY_TEST_TARGETS:-llvm;cuda}" \ run_pytest ctypes ${TVM_INTEGRATION_TESTSUITE_NAME}-relay tests/python/relay