From aa3bcd9d3374878c5e958b842f51bfd82f0ebd9e Mon Sep 17 00:00:00 2001 From: Mohamad Katanbaf Date: Wed, 4 May 2022 17:10:18 -0700 Subject: [PATCH] Implemented rpc logging (#10967) Co-authored-by: Mohamad --- CMakeLists.txt | 1 + python/tvm/micro/session.py | 1 + python/tvm/rpc/client.py | 13 +- .../crt/microtvm_rpc_server/rpc_server.cc | 2 - src/runtime/micro/micro_session.cc | 8 + src/runtime/minrpc/minrpc_interfaces.h | 93 +++ src/runtime/minrpc/minrpc_logger.cc | 291 ++++++++ src/runtime/minrpc/minrpc_logger.h | 296 ++++++++ src/runtime/minrpc/minrpc_server.h | 649 +++++++++++------- src/runtime/minrpc/minrpc_server_logging.h | 166 +++++ src/runtime/rpc/rpc_channel_logger.h | 183 +++++ src/runtime/rpc/rpc_endpoint.h | 2 + src/runtime/rpc/rpc_socket_impl.cc | 21 +- tests/python/unittest/test_runtime_rpc.py | 23 +- 14 files changed, 1474 insertions(+), 275 deletions(-) create mode 100644 src/runtime/minrpc/minrpc_interfaces.h create mode 100644 src/runtime/minrpc/minrpc_logger.cc create mode 100644 src/runtime/minrpc/minrpc_logger.h create mode 100644 src/runtime/minrpc/minrpc_server_logging.h create mode 100644 src/runtime/rpc/rpc_channel_logger.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 90cc0f95185d..7023caf97eb5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -318,6 +318,7 @@ list(APPEND COMPILER_SRCS "src/target/datatype/myfloat/myfloat.cc") tvm_file_glob(GLOB RUNTIME_SRCS src/runtime/*.cc src/runtime/vm/*.cc + src/runtime/minrpc/*.cc ) if(BUILD_FOR_HEXAGON) diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py index 4f754d9d442c..4c38476207ba 100644 --- a/python/tvm/micro/session.py +++ b/python/tvm/micro/session.py @@ -133,6 +133,7 @@ def __enter__(self): int(timeouts.session_start_timeout_sec * 1e6), int(timeouts.session_established_timeout_sec * 1e6), self._cleanup, + False, ) ) self.device = self._rpc.cpu(0) diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py index 4e6c9025383f..eddc324b3390 100644 --- a/python/tvm/rpc/client.py +++ b/python/tvm/rpc/client.py @@ -459,7 +459,9 @@ def request_and_run(self, key, func, priority=1, session_timeout=0, max_retry=2) ) -def connect(url, port, key="", session_timeout=0, session_constructor_args=None): +def connect( + url, port, key="", session_timeout=0, session_constructor_args=None, enable_logging=False +): """Connect to RPC Server Parameters @@ -483,6 +485,9 @@ def connect(url, port, key="", session_timeout=0, session_constructor_args=None) The first element of the list is always a string specifying the name of the session constructor, the following args are the positional args to that function. + enable_logging: boolean + flag to enable/disable logging. Logging is disabled by default. + Returns ------- sess : RPCSession @@ -503,9 +508,9 @@ def connect(url, port, key="", session_timeout=0, session_constructor_args=None) .. code-block:: python client_via_proxy = rpc.connect( - proxy_server_url, proxy_server_port, proxy_server_key, + proxy_server_url, proxy_server_port, proxy_server_key, enable_logging session_constructor_args=[ - "rpc.Connect", internal_url, internal_port, internal_key]) + "rpc.Connect", internal_url, internal_port, internal_key, internal_logging]) """ try: @@ -514,7 +519,7 @@ def connect(url, port, key="", session_timeout=0, session_constructor_args=None) session_constructor_args = session_constructor_args if session_constructor_args else [] if not isinstance(session_constructor_args, (list, tuple)): raise TypeError("Expect the session constructor to be a list or tuple") - sess = _ffi_api.Connect(url, port, key, *session_constructor_args) + sess = _ffi_api.Connect(url, port, key, enable_logging, *session_constructor_args) except NameError: raise RuntimeError("Please compile with USE_RPC=1") return RPCSession(sess) diff --git a/src/runtime/crt/microtvm_rpc_server/rpc_server.cc b/src/runtime/crt/microtvm_rpc_server/rpc_server.cc index ac10c82b580c..b7bae243ecf0 100644 --- a/src/runtime/crt/microtvm_rpc_server/rpc_server.cc +++ b/src/runtime/crt/microtvm_rpc_server/rpc_server.cc @@ -193,8 +193,6 @@ class MicroRPCServer { } // namespace runtime } // namespace tvm -void* operator new[](size_t count, void* ptr) noexcept { return ptr; } - extern "C" { static microtvm_rpc_server_t g_rpc_server = nullptr; diff --git a/src/runtime/micro/micro_session.cc b/src/runtime/micro/micro_session.cc index 9e6664ff5984..6911c2021ac1 100644 --- a/src/runtime/micro/micro_session.cc +++ b/src/runtime/micro/micro_session.cc @@ -38,6 +38,7 @@ #include "../../support/str_escape.h" #include "../rpc/rpc_channel.h" +#include "../rpc/rpc_channel_logger.h" #include "../rpc/rpc_endpoint.h" #include "../rpc/rpc_session.h" #include "crt_config.h" @@ -404,6 +405,13 @@ TVM_REGISTER_GLOBAL("micro._rpc_connect").set_body([](TVMArgs args, TVMRetValue* throw std::runtime_error(ss.str()); } std::unique_ptr channel(micro_channel); + bool enable_logging = false; + if (args.num_args > 7) { + enable_logging = args[7]; + } + if (enable_logging) { + channel.reset(new RPCChannelLogging(std::move(channel))); + } auto ep = RPCEndpoint::Create(std::move(channel), args[0], "", args[6]); auto sess = CreateClientSession(ep); *rv = CreateRPCSessionModule(sess); diff --git a/src/runtime/minrpc/minrpc_interfaces.h b/src/runtime/minrpc/minrpc_interfaces.h new file mode 100644 index 000000000000..a45dee9f2c35 --- /dev/null +++ b/src/runtime/minrpc/minrpc_interfaces.h @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_MINRPC_MINRPC_INTERFACES_H_ +#define TVM_RUNTIME_MINRPC_MINRPC_INTERFACES_H_ + +#include + +#include "rpc_reference.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief Return interface used in ExecInterface to generate and send the responses. + */ +class MinRPCReturnInterface { + public: + virtual ~MinRPCReturnInterface() {} + /*! * \brief sends a response to the client with kTVMNullptr in payload. */ + virtual void ReturnVoid() = 0; + + /*! * \brief sends a response to the client with one kTVMOpaqueHandle in payload. */ + virtual void ReturnHandle(void* handle) = 0; + + /*! * \brief sends an exception response to the client with a kTVMStr in payload. */ + virtual void ReturnException(const char* msg) = 0; + + /*! * \brief sends a packed argument sequnce to the client. */ + virtual void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args) = 0; + + /*! * \brief sends a copy of the requested remote data to the client. */ + virtual void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) = 0; + + /*! * \brief sends an exception response to the client with the last TVM erros as the message. */ + virtual void ReturnLastTVMError() = 0; + + /*! * \brief internal error. */ + virtual void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) = 0; +}; + +/*! + * \brief Execute interface used in MinRPCServer to process different received commands + */ +class MinRPCExecInterface { + public: + virtual ~MinRPCExecInterface() {} + + /*! * \brief Execute an Initilize server command. */ + virtual void InitServer(int num_args) = 0; + + /*! * \brief calls a function specified by the call_handle. */ + virtual void NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, + int num_args) = 0; + + /*! * \brief Execute a copy from remote command by sending the data described in arr to the client + */ + virtual void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) = 0; + + /*! * \brief Execute a copy to remote command by receiving the data described in arr from the + * client */ + virtual int CopyToRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) = 0; + + /*! * \brief calls a system function specified by the code. */ + virtual void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args) = 0; + + /*! * \brief internal error. */ + virtual void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) = 0; + + /*! * \brief return the ReturnInterface pointer that is used to generate and send the responses. + */ + virtual MinRPCReturnInterface* GetReturnInterface() = 0; +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_MINRPC_MINRPC_INTERFACES_H_ diff --git a/src/runtime/minrpc/minrpc_logger.cc b/src/runtime/minrpc/minrpc_logger.cc new file mode 100644 index 000000000000..4f3b7e764c9b --- /dev/null +++ b/src/runtime/minrpc/minrpc_logger.cc @@ -0,0 +1,291 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "minrpc_logger.h" + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "minrpc_interfaces.h" +#include "rpc_reference.h" + +namespace tvm { +namespace runtime { + +void Logger::LogTVMValue(int tcode, TVMValue value) { + switch (tcode) { + case kDLInt: { + LogValue("(int64)", value.v_int64); + break; + } + case kDLUInt: { + LogValue("(uint64)", value.v_int64); + break; + } + case kDLFloat: { + LogValue("(float)", value.v_float64); + break; + } + case kTVMDataType: { + LogDLData("DLDataType(code,bits,lane)", &value.v_type); + break; + } + case kDLDevice: { + LogDLDevice("DLDevice(type,id)", &value.v_device); + break; + } + case kTVMPackedFuncHandle: { + LogValue("(PackedFuncHandle)", value.v_handle); + break; + } + case kTVMModuleHandle: { + LogValue("(ModuleHandle)", value.v_handle); + break; + } + case kTVMOpaqueHandle: { + LogValue("(OpaqueHandle)", value.v_handle); + break; + } + case kTVMDLTensorHandle: { + LogValue("(TensorHandle)", value.v_handle); + break; + } + case kTVMNDArrayHandle: { + LogValue("kTVMNDArrayHandle", value.v_handle); + break; + } + case kTVMNullptr: { + Log("Nullptr"); + break; + } + case kTVMStr: { + Log("\""); + Log(value.v_str); + Log("\""); + break; + } + case kTVMBytes: { + TVMByteArray* bytes = static_cast(value.v_handle); + int len = bytes->size; + LogValue("(Bytes) [size]: ", len); + if (PRINT_BYTES) { + Log(", [Values]:"); + Log(" { "); + if (len > 0) { + LogValue("", (uint8_t)bytes->data[0]); + } + for (int j = 1; j < len; j++) LogValue(" - ", (uint8_t)bytes->data[j]); + Log(" } "); + } + break; + } + default: { + Log("ERROR-kUnknownTypeCode)"); + break; + } + } + Log("; "); +} + +void Logger::OutputLog() { + LOG(INFO) << os_.str(); + os_.str(std::string()); +} + +void MinRPCReturnsWithLog::ReturnVoid() { + next_->ReturnVoid(); + logger_->Log("-> ReturnVoid"); + logger_->OutputLog(); +} + +void MinRPCReturnsWithLog::ReturnHandle(void* handle) { + next_->ReturnHandle(handle); + if (code_ == RPCCode::kGetGlobalFunc) { + RegisterHandleName(handle); + } + logger_->LogValue("-> ReturnHandle: ", handle); + logger_->OutputLog(); +} + +void MinRPCReturnsWithLog::ReturnException(const char* msg) { + next_->ReturnException(msg); + logger_->Log("-> Exception: "); + logger_->Log(msg); + logger_->OutputLog(); +} + +void MinRPCReturnsWithLog::ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, + int num_args) { + next_->ReturnPackedSeq(arg_values, type_codes, num_args); + ProcessValues(arg_values, type_codes, num_args); + logger_->OutputLog(); +} + +void MinRPCReturnsWithLog::ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) { + next_->ReturnCopyFromRemote(data_ptr, num_bytes); + logger_->LogValue("-> CopyFromRemote: ", num_bytes); + logger_->LogValue(", ", static_cast(data_ptr)); + logger_->OutputLog(); +} + +void MinRPCReturnsWithLog::ReturnLastTVMError() { + const char* err = TVMGetLastError(); + ReturnException(err); +} + +void MinRPCReturnsWithLog::ThrowError(RPCServerStatus code, RPCCode info) { + next_->ThrowError(code, info); + logger_->Log("-> ERROR: "); + logger_->Log(RPCServerStatusToString(code)); + logger_->OutputLog(); +} + +void MinRPCReturnsWithLog::ProcessValues(const TVMValue* values, const int* tcodes, int num_args) { + if (tcodes != nullptr) { + logger_->Log("-> ["); + for (int i = 0; i < num_args; ++i) { + logger_->LogTVMValue(tcodes[i], values[i]); + + if (tcodes[i] == kTVMOpaqueHandle) { + RegisterHandleName(values[i].v_handle); + } + } + logger_->Log("]"); + } +} + +void MinRPCReturnsWithLog::ResetHandleName(RPCCode code) { + code_ = code; + handle_name_.clear(); +} + +void MinRPCReturnsWithLog::UpdateHandleName(const char* name) { + if (handle_name_.length() != 0) { + handle_name_.append("::"); + } + handle_name_.append(name); +} + +void MinRPCReturnsWithLog::GetHandleName(void* handle) { + if (handle_descriptions_.find(handle) != handle_descriptions_.end()) { + handle_name_.append(handle_descriptions_[handle]); + logger_->LogHandleName(handle_name_); + } +} + +void MinRPCReturnsWithLog::ReleaseHandleName(void* handle) { + if (handle_descriptions_.find(handle) != handle_descriptions_.end()) { + logger_->LogHandleName(handle_descriptions_[handle]); + handle_descriptions_.erase(handle); + } +} + +void MinRPCReturnsWithLog::RegisterHandleName(void* handle) { + handle_descriptions_[handle] = handle_name_; +} + +void MinRPCExecuteWithLog::InitServer(int num_args) { + SetRPCCode(RPCCode::kInitServer); + logger_->Log("Init Server"); + next_->InitServer(num_args); +} + +void MinRPCExecuteWithLog::NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, + int num_args) { + SetRPCCode(RPCCode::kCallFunc); + logger_->LogValue("call_handle: ", reinterpret_cast(call_handle)); + ret_handler_->GetHandleName(reinterpret_cast(call_handle)); + if (num_args > 0) { + logger_->Log(", "); + } + ProcessValues(values, tcodes, num_args); + next_->NormalCallFunc(call_handle, values, tcodes, num_args); +} + +void MinRPCExecuteWithLog::CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* temp_data) { + SetRPCCode(RPCCode::kCopyFromRemote); + logger_->LogValue("data_handle: ", static_cast(arr->data)); + logger_->LogDLDevice(", DLDevice(type,id):", &(arr->device)); + logger_->LogValue(", ndim: ", arr->ndim); + logger_->LogDLData(", DLDataType(code,bits,lane): ", &(arr->dtype)); + logger_->LogValue(", num_bytes:", num_bytes); + next_->CopyFromRemote(arr, num_bytes, temp_data); +} + +int MinRPCExecuteWithLog::CopyToRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) { + SetRPCCode(RPCCode::kCopyToRemote); + logger_->LogValue("data_handle: ", static_cast(arr->data)); + logger_->LogDLDevice(", DLDevice(type,id):", &(arr->device)); + logger_->LogValue(", ndim: ", arr->ndim); + logger_->LogDLData(", DLDataType(code,bits,lane): ", &(arr->dtype)); + logger_->LogValue(", byte_offset: ", arr->byte_offset); + return next_->CopyToRemote(arr, num_bytes, data_ptr); +} + +void MinRPCExecuteWithLog::SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args) { + SetRPCCode(code); + if ((code) == RPCCode::kFreeHandle) { + if ((num_args == 2) && (tcodes[0] == kTVMOpaqueHandle) && (tcodes[1] == kDLInt)) { + logger_->LogValue("handle: ", static_cast(values[0].v_handle)); + if (values[1].v_int64 == kTVMModuleHandle || values[1].v_int64 == kTVMPackedFuncHandle) { + ret_handler_->ReleaseHandleName(static_cast(values[0].v_handle)); + } + } + } else { + ProcessValues(values, tcodes, num_args); + } + next_->SysCallFunc(code, values, tcodes, num_args); +} + +void MinRPCExecuteWithLog::ThrowError(RPCServerStatus code, RPCCode info) { + logger_->Log("-> Error\n"); + next_->ThrowError(code, info); +} + +void MinRPCExecuteWithLog::ProcessValues(TVMValue* values, int* tcodes, int num_args) { + if (tcodes != nullptr) { + logger_->Log("["); + for (int i = 0; i < num_args; ++i) { + logger_->LogTVMValue(tcodes[i], values[i]); + + if (tcodes[i] == kTVMStr) { + if (strlen(values[i].v_str) > 0) { + ret_handler_->UpdateHandleName(values[i].v_str); + } + } + } + logger_->Log("]"); + } +} + +void MinRPCExecuteWithLog::SetRPCCode(RPCCode code) { + logger_->Log(RPCCodeToString(code)); + logger_->Log(", "); + ret_handler_->ResetHandleName(code); +} + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/minrpc/minrpc_logger.h b/src/runtime/minrpc/minrpc_logger.h new file mode 100644 index 000000000000..13d44c3cba9b --- /dev/null +++ b/src/runtime/minrpc/minrpc_logger.h @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_MINRPC_MINRPC_LOGGER_H_ +#define TVM_RUNTIME_MINRPC_MINRPC_LOGGER_H_ + +#include + +#include +#include +#include +#include + +#include "minrpc_interfaces.h" +#include "rpc_reference.h" + +namespace tvm { +namespace runtime { + +#define PRINT_BYTES false + +/*! + * \brief Generates a user readeable log on the console + */ +class Logger { + public: + Logger() {} + + /*! + * \brief this function logs a string + * + * \param s the string to be logged. + */ + void Log(const char* s) { os_ << s; } + void Log(std::string s) { os_ << s; } + + /*! + * \brief this function logs a numerical value + * + * \param desc adds any necessary description before the value. + * \param val is the value to be logged. + */ + template + void LogValue(const char* desc, T val) { + os_ << desc << val; + } + + /*! + * \brief this function logs the properties of a DLDevice + * + * \param desc adds any necessary description before the DLDevice. + * \param dev is the pointer to the DLDevice to be logged. + */ + void LogDLDevice(const char* desc, DLDevice* dev) { + os_ << desc << "(" << dev->device_type << "," << dev->device_id << ")"; + } + + /*! + * \brief this function logs the properties of a DLDataType + * + * \param desc adds any necessary description before the DLDataType. + * \param data is the pointer to the DLDataType to be logged. + */ + void LogDLData(const char* desc, DLDataType* data) { + os_ << desc << "(" << (uint16_t)data->code << "," << (uint16_t)data->bits << "," << data->lanes + << ")"; + } + + /*! + * \brief this function logs a handle name. + * + * \param name is the name to be logged. + */ + void LogHandleName(std::string name) { + if (name.length() > 0) { + os_ << " <" << name.c_str() << ">"; + } + } + + /*! + * \brief this function logs a TVMValue based on its type. + * + * \param tcode the type_code of the value stored in TVMValue. + * \param value is the TVMValue to be logged. + */ + void LogTVMValue(int tcode, TVMValue value); + + /*! + * \brief this function output the log to the console. + */ + void OutputLog(); + + private: + std::stringstream os_; +}; + +/*! + * \brief A wrapper for a MinRPCReturns object, that also logs the responses. + * + * \param next underlying MinRPCReturns that generates the responses. + */ +class MinRPCReturnsWithLog : public MinRPCReturnInterface { + public: + /*! + * \brief Constructor. + * \param io The IO handler. + */ + MinRPCReturnsWithLog(MinRPCReturnInterface* next, Logger* logger) + : next_(next), logger_(logger) {} + + ~MinRPCReturnsWithLog() {} + + void ReturnVoid(); + + void ReturnHandle(void* handle); + + void ReturnException(const char* msg); + + void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args); + + void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes); + + void ReturnLastTVMError(); + + void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone); + + /*! + * \brief this function logs a list of TVMValues, and registers handle_name when needed. + * + * \param values is the list of TVMValues. + * \param tcodes is the list type_code of the TVMValues. + * \param num_args is the number of items in the list. + */ + void ProcessValues(const TVMValue* values, const int* tcodes, int num_args); + + /*! + * \brief this function is called when a new command is executed. + * It clears the handle_name_ and records the command code. + * + * \param code the RPC command code. + */ + void ResetHandleName(RPCCode code); + + /*! + * \brief appends name to the handle_name_. + * + * \param name handle name. + */ + void UpdateHandleName(const char* name); + + /*! + * \brief get the stored handle description. + * + * \param handle the handle to get the description for. + */ + void GetHandleName(void* handle); + + /*! + * \brief remove the handle description from handle_descriptions_. + * + * \param handle the handle to remove the description for. + */ + void ReleaseHandleName(void* handle); + + private: + /*! + * \brief add the handle description to handle_descriptions_. + * + * \param handle the handle to add the description for. + */ + void RegisterHandleName(void* handle); + + MinRPCReturnInterface* next_; + std::string handle_name_; + std::unordered_map handle_descriptions_; + RPCCode code_; + Logger* logger_; +}; + +/*! + * \brief A wrapper for a MinRPCExecute object, that also logs the responses. + * + * \param next: underlying MinRPCExecute that processes the packets. + */ +class MinRPCExecuteWithLog : public MinRPCExecInterface { + public: + MinRPCExecuteWithLog(MinRPCExecInterface* next, Logger* logger) : next_(next), logger_(logger) { + ret_handler_ = reinterpret_cast(next_->GetReturnInterface()); + } + + ~MinRPCExecuteWithLog() {} + + void InitServer(int num_args); + + void NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, int num_args); + + void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* temp_data); + + int CopyToRemote(DLTensor* arr, uint64_t _num_bytes, uint8_t* _data_ptr); + + void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args); + + void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone); + + MinRPCReturnInterface* GetReturnInterface() { return next_->GetReturnInterface(); } + + private: + /*! + * \brief this function logs a list of TVMValues, and updates handle_name when needed. + * + * \param values is the list of TVMValues. + * \param tcodes is the list type_code of the TVMValues. + * \param num_args is the number of items in the list. + */ + void ProcessValues(TVMValue* values, int* tcodes, int num_args); + + /*! + * \brief this function is called when a new command is executed. + * + * \param code the RPC command code. + */ + void SetRPCCode(RPCCode code); + + MinRPCExecInterface* next_; + MinRPCReturnsWithLog* ret_handler_; + Logger* logger_; +}; + +/*! + * \brief A No-operation MinRPCReturns used within the MinRPCSniffer + * + * \tparam TIOHandler* IO provider to provide io handling. + */ +template +class MinRPCReturnsNoOp : public MinRPCReturnInterface { + public: + /*! + * \brief Constructor. + * \param io The IO handler. + */ + explicit MinRPCReturnsNoOp(TIOHandler* io) : io_(io) {} + ~MinRPCReturnsNoOp() {} + void ReturnVoid() {} + void ReturnHandle(void* handle) {} + void ReturnException(const char* msg) {} + void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args) {} + void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) {} + void ReturnLastTVMError() {} + void ThrowError(RPCServerStatus code, RPCCode info) {} + + private: + TIOHandler* io_; +}; + +/*! + * \brief A No-operation MinRPCExecute used within the MinRPCSniffer + * + * \tparam ReturnInterface* ReturnInterface pointer to generate and send the responses. + + */ +class MinRPCExecuteNoOp : public MinRPCExecInterface { + public: + explicit MinRPCExecuteNoOp(MinRPCReturnInterface* ret_handler) : ret_handler_(ret_handler) {} + ~MinRPCExecuteNoOp() {} + void InitServer(int _num_args) {} + void NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, int num_args) {} + void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* temp_data) {} + int CopyToRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) { return 1; } + void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args) {} + void ThrowError(RPCServerStatus code, RPCCode info) {} + MinRPCReturnInterface* GetReturnInterface() { return ret_handler_; } + + private: + MinRPCReturnInterface* ret_handler_; +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_MINRPC_MINRPC_LOGGER_H_" diff --git a/src/runtime/minrpc/minrpc_server.h b/src/runtime/minrpc/minrpc_server.h index 92cb2e819f22..4684aa0e1616 100644 --- a/src/runtime/minrpc/minrpc_server.h +++ b/src/runtime/minrpc/minrpc_server.h @@ -28,27 +28,25 @@ #ifndef TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_ #define TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_ +#ifndef DMLC_LITTLE_ENDIAN #define DMLC_LITTLE_ENDIAN 1 +#endif + #include #include +#include +#include + #include "../../support/generic_arena.h" +#include "minrpc_interfaces.h" #include "rpc_reference.h" -/*! \brief Whether or not to enable glog style DLOG */ -#ifndef TVM_MINRPC_ENABLE_LOGGING -#define TVM_MINRPC_ENABLE_LOGGING 0 -#endif - #ifndef MINRPC_CHECK #define MINRPC_CHECK(cond) \ if (!(cond)) this->ThrowError(RPCServerStatus::kCheckError); #endif -#if TVM_MINRPC_ENABLE_LOGGING -#include -#endif - namespace tvm { namespace runtime { @@ -58,95 +56,133 @@ class PageAllocator; } /*! - * \brief A minimum RPC server that only depends on the tvm C runtime.. - * - * All the dependencies are provided by the io arguments. + * \brief Responses to a minimum RPC command. * * \tparam TIOHandler IO provider to provide io handling. - * An IOHandler needs to provide the following functions: - * - PosixWrite, PosixRead, Close: posix style, read, write, close API. - * - MessageStart(num_bytes), MessageDone(): framing APIs. - * - Exit: exit with status code. */ -template class Allocator = detail::PageAllocator> -class MinRPCServer { +template +class MinRPCReturns : public MinRPCReturnInterface { public: - using PageAllocator = Allocator; - /*! * \brief Constructor. * \param io The IO handler. */ - explicit MinRPCServer(TIOHandler* io) : io_(io), arena_(PageAllocator(io)) {} + explicit MinRPCReturns(TIOHandler* io) : io_(io) {} - /*! \brief Process a single request. - * - * \return true when the server should continue processing requests. false when it should be - * shutdown. - */ - bool ProcessOnePacket() { - RPCCode code; - uint64_t packet_len; + void ReturnVoid() { + int32_t num_args = 1; + int32_t tcode = kTVMNullptr; + RPCCode code = RPCCode::kReturn; - arena_.RecycleAll(); - allow_clean_shutdown_ = true; + uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode); - this->Read(&packet_len); - if (packet_len == 0) return true; - this->Read(&code); + io_->MessageStart(packet_nbytes); + Write(packet_nbytes); + Write(code); + Write(num_args); + Write(tcode); + io_->MessageDone(); + } - allow_clean_shutdown_ = false; + void ReturnHandle(void* handle) { + int32_t num_args = 1; + int32_t tcode = kTVMOpaqueHandle; + RPCCode code = RPCCode::kReturn; + uint64_t encode_handle = reinterpret_cast(handle); + uint64_t packet_nbytes = + sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(encode_handle); - if (code >= RPCCode::kSyscallCodeStart) { - this->HandleSyscallFunc(code); - } else { - switch (code) { - case RPCCode::kCallFunc: { - HandleNormalCallFunc(); - break; - } - case RPCCode::kInitServer: { - HandleInitServer(); - break; - } - case RPCCode::kCopyFromRemote: { - HandleCopyFromRemote(); - break; - } - case RPCCode::kCopyToRemote: { - HandleCopyToRemote(); - break; - } - case RPCCode::kShutdown: { - this->Shutdown(); - return false; - } - default: { - this->ThrowError(RPCServerStatus::kUnknownRPCCode); - break; - } + io_->MessageStart(packet_nbytes); + Write(packet_nbytes); + Write(code); + Write(num_args); + Write(tcode); + Write(encode_handle); + io_->MessageDone(); + } + + void ReturnException(const char* msg) { RPCReference::ReturnException(msg, this); } + + void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args) { + RPCReference::ReturnPackedSeq(arg_values, type_codes, num_args, this); + } + + void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) { + RPCCode code = RPCCode::kCopyAck; + uint64_t packet_nbytes = sizeof(code) + num_bytes; + + io_->MessageStart(packet_nbytes); + Write(packet_nbytes); + Write(code); + WriteArray(data_ptr, num_bytes); + io_->MessageDone(); + } + + void ReturnLastTVMError() { + const char* err = TVMGetLastError(); + ReturnException(err); + } + + void MessageStart(uint64_t packet_nbytes) { io_->MessageStart(packet_nbytes); } + + void MessageDone() { io_->MessageDone(); } + + void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { + io_->Exit(static_cast(code)); + } + + template + void Write(const T& data) { + static_assert(std::is_trivial::value && std::is_standard_layout::value, + "need to be trival"); + return WriteRawBytes(&data, sizeof(T)); + } + + template + void WriteArray(T* data, size_t count) { + static_assert(std::is_trivial::value && std::is_standard_layout::value, + "need to be trival"); + return WriteRawBytes(data, sizeof(T) * count); + } + + private: + void WriteRawBytes(const void* data, size_t size) { + const uint8_t* buf = static_cast(data); + size_t ndone = 0; + while (ndone < size) { + ssize_t ret = io_->PosixWrite(buf, size - ndone); + if (ret <= 0) { + this->ThrowError(RPCServerStatus::kWriteError); } + buf += ret; + ndone += ret; } - - return true; } - void Shutdown() { - arena_.FreeAll(); - io_->Close(); + TIOHandler* io_; +}; + +/*! + * \brief Executing a minimum RPC command. + * + * \tparam TIOHandler IO provider to provide io handling. + * \tparam MinRPCReturnInterface* handles response generatation and transmission. + */ +template +class MinRPCExecute : public MinRPCExecInterface { + public: + MinRPCExecute(TIOHandler* io, MinRPCReturnInterface* ret_handler) + : io_(io), ret_handler_(ret_handler) {} + + void InitServer(int num_args) { + MINRPC_CHECK(num_args == 0); + ret_handler_->ReturnVoid(); } - void HandleNormalCallFunc() { - uint64_t call_handle; - TVMValue* values; - int* tcodes; - int num_args; + void NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, int num_args) { TVMValue ret_value[3]; int ret_tcode[3]; - this->Read(&call_handle); - RecvPackedSeq(&values, &tcodes, &num_args); - int call_ecode = TVMFuncCall(reinterpret_cast(call_handle), values, tcodes, num_args, &(ret_value[1]), &(ret_tcode[1])); @@ -159,46 +195,27 @@ class MinRPCServer { ret_tcode[1] = kTVMDLTensorHandle; ret_value[2].v_handle = ret_value[1].v_handle; ret_tcode[2] = kTVMOpaqueHandle; - this->ReturnPackedSeq(ret_value, ret_tcode, 3); + ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 3); } else if (rv_tcode == kTVMBytes) { ret_tcode[1] = kTVMBytes; - this->ReturnPackedSeq(ret_value, ret_tcode, 2); + ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2); TVMByteArrayFree(reinterpret_cast(ret_value[1].v_handle)); // NOLINT(*) } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) { ret_tcode[1] = kTVMOpaqueHandle; - this->ReturnPackedSeq(ret_value, ret_tcode, 2); + ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2); } else { - this->ReturnPackedSeq(ret_value, ret_tcode, 2); + ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2); } } else { - this->ReturnLastTVMError(); + ret_handler_->ReturnLastTVMError(); } } - void HandleCopyFromRemote() { - DLTensor* arr = this->ArenaAlloc(1); - uint64_t data_handle; - this->Read(&data_handle); - arr->data = reinterpret_cast(data_handle); - this->Read(&(arr->device)); - this->Read(&(arr->ndim)); - this->Read(&(arr->dtype)); - arr->shape = this->ArenaAlloc(arr->ndim); - this->ReadArray(arr->shape, arr->ndim); - arr->strides = nullptr; - this->Read(&(arr->byte_offset)); - - uint64_t num_bytes; - this->Read(&num_bytes); - - uint8_t* data_ptr; + void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) { int call_ecode = 0; - if (arr->device.device_type == kDLCPU) { - data_ptr = reinterpret_cast(data_handle) + arr->byte_offset; - } else { - data_ptr = this->ArenaAlloc(num_bytes); + if (arr->device.device_type != kDLCPU) { DLTensor temp; - temp.data = reinterpret_cast(data_ptr); + temp.data = static_cast(data_ptr); temp.device = DLDevice{kDLCPU, 0}; temp.ndim = arr->ndim; temp.dtype = arr->dtype; @@ -213,43 +230,21 @@ class MinRPCServer { } if (call_ecode == 0) { - RPCCode code = RPCCode::kCopyAck; - uint64_t packet_nbytes = sizeof(code) + num_bytes; - - io_->MessageStart(packet_nbytes); - this->Write(packet_nbytes); - this->Write(code); - this->WriteArray(data_ptr, num_bytes); - io_->MessageDone(); + ret_handler_->ReturnCopyFromRemote(data_ptr, num_bytes); } else { - this->ReturnLastTVMError(); + ret_handler_->ReturnLastTVMError(); } } - void HandleCopyToRemote() { - DLTensor* arr = this->ArenaAlloc(1); - uint64_t data_handle; - this->Read(&data_handle); - arr->data = reinterpret_cast(data_handle); - this->Read(&(arr->device)); - this->Read(&(arr->ndim)); - this->Read(&(arr->dtype)); - arr->shape = this->ArenaAlloc(arr->ndim); - this->ReadArray(arr->shape, arr->ndim); - arr->strides = nullptr; - this->Read(&(arr->byte_offset)); - uint64_t num_bytes; - this->Read(&num_bytes); - + int CopyToRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) { int call_ecode = 0; - if (arr->device.device_type == kDLCPU) { - uint8_t* dptr = reinterpret_cast(data_handle) + arr->byte_offset; - this->ReadArray(dptr, num_bytes); - } else { - uint8_t* temp_data = this->ArenaAlloc(num_bytes); - this->ReadArray(temp_data, num_bytes); + + int ret = ReadArray(data_ptr, num_bytes); + if (ret <= 0) return ret; + + if (arr->device.device_type != kDLCPU) { DLTensor temp; - temp.data = temp_data; + temp.data = data_ptr; temp.device = DLDevice{kDLCPU, 0}; temp.ndim = arr->ndim; temp.dtype = arr->dtype; @@ -264,87 +259,71 @@ class MinRPCServer { } if (call_ecode == 0) { - this->ReturnVoid(); + ret_handler_->ReturnVoid(); } else { - this->ReturnLastTVMError(); + ret_handler_->ReturnLastTVMError(); } + + return 1; } - void HandleSyscallFunc(RPCCode code) { - TVMValue* values; - int* tcodes; - int num_args; - RecvPackedSeq(&values, &tcodes, &num_args); + void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args) { switch (code) { case RPCCode::kFreeHandle: { - this->SyscallFreeHandle(values, tcodes, num_args); + SyscallFreeHandle(values, tcodes, num_args); break; } case RPCCode::kGetGlobalFunc: { - this->SyscallGetGlobalFunc(values, tcodes, num_args); + SyscallGetGlobalFunc(values, tcodes, num_args); break; } case RPCCode::kDevSetDevice: { - this->ReturnException("SetDevice not supported"); + ret_handler_->ReturnException("SetDevice not supported"); break; } case RPCCode::kDevGetAttr: { - this->ReturnException("GetAttr not supported"); + ret_handler_->ReturnException("GetAttr not supported"); break; } case RPCCode::kDevAllocData: { - this->SyscallDevAllocData(values, tcodes, num_args); + SyscallDevAllocData(values, tcodes, num_args); break; } case RPCCode::kDevAllocDataWithScope: { - this->SyscallDevAllocDataWithScope(values, tcodes, num_args); + SyscallDevAllocDataWithScope(values, tcodes, num_args); break; } case RPCCode::kDevFreeData: { - this->SyscallDevFreeData(values, tcodes, num_args); + SyscallDevFreeData(values, tcodes, num_args); break; } case RPCCode::kDevCreateStream: { - this->SyscallDevCreateStream(values, tcodes, num_args); + SyscallDevCreateStream(values, tcodes, num_args); break; } case RPCCode::kDevFreeStream: { - this->SyscallDevFreeStream(values, tcodes, num_args); + SyscallDevFreeStream(values, tcodes, num_args); break; } case RPCCode::kDevStreamSync: { - this->SyscallDevStreamSync(values, tcodes, num_args); + SyscallDevStreamSync(values, tcodes, num_args); break; } case RPCCode::kDevSetStream: { - this->SyscallDevSetStream(values, tcodes, num_args); + SyscallDevSetStream(values, tcodes, num_args); break; } case RPCCode::kCopyAmongRemote: { - this->SyscallCopyAmongRemote(values, tcodes, num_args); + SyscallCopyAmongRemote(values, tcodes, num_args); break; } default: { - this->ReturnException("Syscall not recognized"); + ret_handler_->ReturnException("Syscall not recognized"); break; } } } - void HandleInitServer() { - uint64_t len; - this->Read(&len); - char* proto_ver = this->ArenaAlloc(len + 1); - this->ReadArray(proto_ver, len); - - TVMValue* values; - int* tcodes; - int num_args; - RecvPackedSeq(&values, &tcodes, &num_args); - MINRPC_CHECK(num_args == 0); - this->ReturnVoid(); - } - void SyscallFreeHandle(TVMValue* values, int* tcodes, int num_args) { MINRPC_CHECK(num_args == 2); MINRPC_CHECK(tcodes[0] == kTVMOpaqueHandle); @@ -364,23 +343,22 @@ class MinRPCServer { } if (call_ecode == 0) { - this->ReturnVoid(); + ret_handler_->ReturnVoid(); } else { - this->ReturnLastTVMError(); + ret_handler_->ReturnLastTVMError(); } } void SyscallGetGlobalFunc(TVMValue* values, int* tcodes, int num_args) { MINRPC_CHECK(num_args == 1); MINRPC_CHECK(tcodes[0] == kTVMStr); - void* handle; int call_ecode = TVMFuncGetGlobal(values[0].v_str, &handle); if (call_ecode == 0) { - this->ReturnHandle(handle); + ret_handler_->ReturnHandle(handle); } else { - this->ReturnLastTVMError(); + ret_handler_->ReturnLastTVMError(); } } @@ -401,9 +379,9 @@ class MinRPCServer { reinterpret_cast(to), stream); if (call_ecode == 0) { - this->ReturnVoid(); + ret_handler_->ReturnVoid(); } else { - this->ReturnLastTVMError(); + ret_handler_->ReturnLastTVMError(); } } @@ -423,9 +401,9 @@ class MinRPCServer { int call_ecode = TVMDeviceAllocDataSpace(dev, nbytes, alignment, type_hint, &handle); if (call_ecode == 0) { - this->ReturnHandle(handle); + ret_handler_->ReturnHandle(handle); } else { - this->ReturnLastTVMError(); + ret_handler_->ReturnLastTVMError(); } } @@ -434,15 +412,15 @@ class MinRPCServer { MINRPC_CHECK(tcodes[0] == kTVMDLTensorHandle); MINRPC_CHECK(tcodes[1] == kTVMNullptr || tcodes[1] == kTVMStr); - DLTensor* arr = reinterpret_cast(values[0].v_handle); + DLTensor* arr = static_cast(values[0].v_handle); const char* mem_scope = (tcodes[1] == kTVMNullptr ? nullptr : values[1].v_str); void* handle; int call_ecode = TVMDeviceAllocDataSpaceWithScope(arr->device, arr->ndim, arr->shape, arr->dtype, mem_scope, &handle); if (call_ecode == 0) { - this->ReturnHandle(handle); + ret_handler_->ReturnHandle(handle); } else { - this->ReturnLastTVMError(); + ret_handler_->ReturnLastTVMError(); } } @@ -457,9 +435,9 @@ class MinRPCServer { int call_ecode = TVMDeviceFreeDataSpace(dev, handle); if (call_ecode == 0) { - this->ReturnVoid(); + ret_handler_->ReturnVoid(); } else { - this->ReturnLastTVMError(); + ret_handler_->ReturnLastTVMError(); } } @@ -473,9 +451,9 @@ class MinRPCServer { int call_ecode = TVMStreamCreate(dev.device_type, dev.device_id, &handle); if (call_ecode == 0) { - this->ReturnHandle(handle); + ret_handler_->ReturnHandle(handle); } else { - this->ReturnLastTVMError(); + ret_handler_->ReturnLastTVMError(); } } @@ -490,9 +468,9 @@ class MinRPCServer { int call_ecode = TVMStreamFree(dev.device_type, dev.device_id, handle); if (call_ecode == 0) { - this->ReturnVoid(); + ret_handler_->ReturnVoid(); } else { - this->ReturnLastTVMError(); + ret_handler_->ReturnLastTVMError(); } } @@ -507,9 +485,9 @@ class MinRPCServer { int call_ecode = TVMSynchronize(dev.device_type, dev.device_id, handle); if (call_ecode == 0) { - this->ReturnVoid(); + ret_handler_->ReturnVoid(); } else { - this->ReturnLastTVMError(); + ret_handler_->ReturnLastTVMError(); } } @@ -524,103 +502,265 @@ class MinRPCServer { int call_ecode = TVMSetStream(dev.device_type, dev.device_id, handle); if (call_ecode == 0) { - this->ReturnVoid(); + ret_handler_->ReturnVoid(); } else { - this->ReturnLastTVMError(); + ret_handler_->ReturnLastTVMError(); } } void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { - io_->Exit(static_cast(code)); + ret_handler_->ThrowError(code, info); } + MinRPCReturnInterface* GetReturnInterface() { return ret_handler_; } + + private: template - T* ArenaAlloc(int count) { - static_assert(std::is_pod::value, "need to be trival"); - return arena_.template allocate_(count); + int ReadArray(T* data, size_t count) { + static_assert(std::is_trivial::value && std::is_standard_layout::value, + "need to be trival"); + return ReadRawBytes(data, sizeof(T) * count); } - template - void Read(T* data) { - static_assert(std::is_pod::value, "need to be trival"); - this->ReadRawBytes(data, sizeof(T)); + int ReadRawBytes(void* data, size_t size) { + uint8_t* buf = static_cast(data); + size_t ndone = 0; + while (ndone < size) { + ssize_t ret = io_->PosixRead(buf, size - ndone); + if (ret <= 0) return ret; + ndone += ret; + buf += ret; + } + return 1; } - template - void ReadArray(T* data, size_t count) { - static_assert(std::is_pod::value, "need to be trival"); - return this->ReadRawBytes(data, sizeof(T) * count); + TIOHandler* io_; + MinRPCReturnInterface* ret_handler_; +}; + +/*! + * \brief A minimum RPC server that only depends on the tvm C runtime.. + * + * All the dependencies are provided by the io arguments. + * + * \tparam TIOHandler IO provider to provide io handling. + * An IOHandler needs to provide the following functions: + * - PosixWrite, PosixRead, Close: posix style, read, write, close API. + * - MessageStart(num_bytes), MessageDone(): framing APIs. + * - Exit: exit with status code. + */ +template class Allocator = detail::PageAllocator> +class MinRPCServer { + public: + using PageAllocator = Allocator; + + /*! + * \brief Constructor. + * \param io The IO handler. + */ + MinRPCServer(TIOHandler* io, std::unique_ptr&& exec_handler) + : io_(io), arena_(PageAllocator(io_)), exec_handler_(std::move(exec_handler)) {} + + explicit MinRPCServer(TIOHandler* io) + : io_(io), + arena_(PageAllocator(io)), + ret_handler_(new MinRPCReturns(io_)), + exec_handler_(std::unique_ptr( + new MinRPCExecute(io_, ret_handler_))) {} + + ~MinRPCServer() { + if (ret_handler_ != nullptr) { + delete ret_handler_; + } } - template - void Write(const T& data) { - static_assert(std::is_pod::value, "need to be trival"); - return this->WriteRawBytes(&data, sizeof(T)); + /*! \brief Process a single request. + * + * \return true when the server should continue processing requests. false when it should be + * shutdown. + */ + bool ProcessOnePacket() { + RPCCode code; + uint64_t packet_len; + + arena_.RecycleAll(); + allow_clean_shutdown_ = true; + + Read(&packet_len); + if (packet_len == 0) return true; + Read(&code); + allow_clean_shutdown_ = false; + + if (code >= RPCCode::kSyscallCodeStart) { + HandleSyscallFunc(code); + } else { + switch (code) { + case RPCCode::kCallFunc: { + HandleNormalCallFunc(); + break; + } + case RPCCode::kInitServer: { + HandleInitServer(); + break; + } + case RPCCode::kCopyFromRemote: { + HandleCopyFromRemote(); + break; + } + case RPCCode::kCopyToRemote: { + HandleCopyToRemote(); + break; + } + case RPCCode::kShutdown: { + Shutdown(); + return false; + } + default: { + this->ThrowError(RPCServerStatus::kUnknownRPCCode); + break; + } + } + } + + return true; } - template - void WriteArray(T* data, size_t count) { - static_assert(std::is_pod::value, "need to be trival"); - return this->WriteRawBytes(data, sizeof(T) * count); + void HandleInitServer() { + uint64_t len; + Read(&len); + char* proto_ver = ArenaAlloc(len + 1); + ReadArray(proto_ver, len); + TVMValue* values; + int* tcodes; + int num_args; + RecvPackedSeq(&values, &tcodes, &num_args); + exec_handler_->InitServer(num_args); } - void MessageStart(uint64_t packet_nbytes) { io_->MessageStart(packet_nbytes); } + void Shutdown() { + arena_.FreeAll(); + io_->Close(); + } - void MessageDone() { io_->MessageDone(); } + void HandleNormalCallFunc() { + uint64_t call_handle; + TVMValue* values; + int* tcodes; + int num_args; - private: - void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args) { - RPCReference::RecvPackedSeq(out_values, out_tcodes, out_num_args, this); + Read(&call_handle); + RecvPackedSeq(&values, &tcodes, &num_args); + exec_handler_->NormalCallFunc(call_handle, values, tcodes, num_args); } - void ReturnVoid() { - int32_t num_args = 1; - int32_t tcode = kTVMNullptr; - RPCCode code = RPCCode::kReturn; + void HandleCopyFromRemote() { + DLTensor* arr = ArenaAlloc(1); + uint64_t data_handle; + Read(&data_handle); + arr->data = reinterpret_cast(data_handle); + Read(&(arr->device)); + Read(&(arr->ndim)); + Read(&(arr->dtype)); + arr->shape = ArenaAlloc(arr->ndim); + ReadArray(arr->shape, arr->ndim); + arr->strides = nullptr; + Read(&(arr->byte_offset)); - uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode); + uint64_t num_bytes; + Read(&num_bytes); - io_->MessageStart(packet_nbytes); - this->Write(packet_nbytes); - this->Write(code); - this->Write(num_args); - this->Write(tcode); - io_->MessageDone(); + uint8_t* data_ptr; + if (arr->device.device_type == kDLCPU) { + data_ptr = reinterpret_cast(data_handle) + arr->byte_offset; + } else { + data_ptr = ArenaAlloc(num_bytes); + } + + exec_handler_->CopyFromRemote(arr, num_bytes, data_ptr); } - void ReturnHandle(void* handle) { - int32_t num_args = 1; - int32_t tcode = kTVMOpaqueHandle; - RPCCode code = RPCCode::kReturn; - uint64_t encode_handle = reinterpret_cast(handle); - uint64_t packet_nbytes = - sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(encode_handle); + void HandleCopyToRemote() { + DLTensor* arr = ArenaAlloc(1); + uint64_t data_handle; + Read(&data_handle); + arr->data = reinterpret_cast(data_handle); + Read(&(arr->device)); + Read(&(arr->ndim)); + Read(&(arr->dtype)); + arr->shape = ArenaAlloc(arr->ndim); + ReadArray(arr->shape, arr->ndim); + arr->strides = nullptr; + Read(&(arr->byte_offset)); + uint64_t num_bytes; + Read(&num_bytes); + int ret; + if (arr->device.device_type == kDLCPU) { + uint8_t* dptr = reinterpret_cast(data_handle) + arr->byte_offset; + ret = exec_handler_->CopyToRemote(arr, num_bytes, dptr); + } else { + uint8_t* temp_data = ArenaAlloc(num_bytes); + ret = exec_handler_->CopyToRemote(arr, num_bytes, temp_data); + } + if (ret == 0) { + if (allow_clean_shutdown_) { + Shutdown(); + io_->Exit(0); + } else { + this->ThrowError(RPCServerStatus::kReadError); + } + } + if (ret == -1) { + this->ThrowError(RPCServerStatus::kReadError); + } + } - io_->MessageStart(packet_nbytes); - this->Write(packet_nbytes); - this->Write(code); - this->Write(num_args); - this->Write(tcode); - this->Write(encode_handle); - io_->MessageDone(); + void HandleSyscallFunc(RPCCode code) { + TVMValue* values; + int* tcodes; + int num_args; + RecvPackedSeq(&values, &tcodes, &num_args); + + exec_handler_->SysCallFunc(code, values, tcodes, num_args); } - void ReturnException(const char* msg) { RPCReference::ReturnException(msg, this); } + void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { + io_->Exit(static_cast(code)); + } - void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args) { - RPCReference::ReturnPackedSeq(arg_values, type_codes, num_args, this); + template + T* ArenaAlloc(int count) { + static_assert(std::is_trivial::value && std::is_standard_layout::value, + "need to be trival"); + return arena_.template allocate_(count); } - void ReturnLastTVMError() { this->ReturnException(TVMGetLastError()); } + template + void Read(T* data) { + static_assert(std::is_trivial::value && std::is_standard_layout::value, + "need to be trival"); + ReadRawBytes(data, sizeof(T)); + } + + template + void ReadArray(T* data, size_t count) { + static_assert(std::is_trivial::value && std::is_standard_layout::value, + "need to be trival"); + return ReadRawBytes(data, sizeof(T) * count); + } + + private: + void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args) { + RPCReference::RecvPackedSeq(out_values, out_tcodes, out_num_args, this); + } void ReadRawBytes(void* data, size_t size) { - uint8_t* buf = reinterpret_cast(data); + uint8_t* buf = static_cast(data); size_t ndone = 0; while (ndone < size) { ssize_t ret = io_->PosixRead(buf, size - ndone); if (ret == 0) { if (allow_clean_shutdown_) { - this->Shutdown(); + Shutdown(); io_->Exit(0); } else { this->ThrowError(RPCServerStatus::kReadError); @@ -634,26 +774,15 @@ class MinRPCServer { } } - void WriteRawBytes(const void* data, size_t size) { - const uint8_t* buf = reinterpret_cast(data); - size_t ndone = 0; - while (ndone < size) { - ssize_t ret = io_->PosixWrite(buf, size - ndone); - if (ret == 0 || ret == -1) { - this->ThrowError(RPCServerStatus::kWriteError); - } - buf += ret; - ndone += ret; - } - } - /*! \brief IO handler. */ TIOHandler* io_; /*! \brief internal arena. */ support::GenericArena arena_; + MinRPCReturns* ret_handler_ = nullptr; + std::unique_ptr exec_handler_; /*! \brief Whether we are in a state that allows clean shutdown. */ bool allow_clean_shutdown_{true}; - static_assert(DMLC_LITTLE_ENDIAN, "MinRPC only works on little endian."); + static_assert(DMLC_LITTLE_ENDIAN == 1, "MinRPC only works on little endian."); }; namespace detail { diff --git a/src/runtime/minrpc/minrpc_server_logging.h b/src/runtime/minrpc/minrpc_server_logging.h new file mode 100644 index 000000000000..deca2156ce62 --- /dev/null +++ b/src/runtime/minrpc/minrpc_server_logging.h @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_ +#define TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_ + +#include +#include + +#include "minrpc_logger.h" +#include "minrpc_server.h" + +namespace tvm { +namespace runtime { + +/*! + * \brief A minimum RPC server that logs the received commands. + * + * \tparam TIOHandler IO provider to provide io handling. + */ +template +class MinRPCServerWithLog { + public: + explicit MinRPCServerWithLog(TIOHandler* io) + : ret_handler_(io), + ret_handler_wlog_(&ret_handler_, &logger_), + exec_handler_(io, &ret_handler_wlog_), + exec_handler_ptr_(new MinRPCExecuteWithLog(&exec_handler_, &logger_)), + next_(io, std::move(exec_handler_ptr_)) {} + + bool ProcessOnePacket() { return next_.ProcessOnePacket(); } + + private: + Logger logger_; + MinRPCReturns ret_handler_; + MinRPCExecute exec_handler_; + MinRPCReturnsWithLog ret_handler_wlog_; + std::unique_ptr exec_handler_ptr_; + MinRPCServer next_; +}; + +/*! + * \brief A minimum RPC server that only logs the outgoing commands and received responses. + * (Does not process the packets or respond to them.) + * + * \tparam TIOHandler IO provider to provide io handling. + */ +template class Allocator = detail::PageAllocator> +class MinRPCSniffer { + public: + using PageAllocator = Allocator; + explicit MinRPCSniffer(TIOHandler* io) + : io_(io), + arena_(PageAllocator(io_)), + ret_handler_(io_), + ret_handler_wlog_(&ret_handler_, &logger_), + exec_handler_(&ret_handler_wlog_), + exec_handler_ptr_(new MinRPCExecuteWithLog(&exec_handler_, &logger_)), + next_(io_, std::move(exec_handler_ptr_)) {} + + bool ProcessOnePacket() { return next_.ProcessOnePacket(); } + + void ProcessOneResponse() { + RPCCode code; + uint64_t packet_len = 0; + + if (!Read(&packet_len)) return; + if (packet_len == 0) { + OutputLog(); + return; + } + if (!Read(&code)) return; + switch (code) { + case RPCCode::kReturn: { + int32_t num_args; + int* type_codes; + TVMValue* values; + RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this); + ret_handler_wlog_.ReturnPackedSeq(values, type_codes, num_args); + break; + } + case RPCCode::kException: { + ret_handler_wlog_.ReturnException(""); + break; + } + default: { + OutputLog(); + break; + } + } + } + + void OutputLog() { logger_.OutputLog(); } + + void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { + logger_.Log("-> "); + logger_.Log(RPCServerStatusToString(code)); + OutputLog(); + } + + template + T* ArenaAlloc(int count) { + static_assert(std::is_trivial::value && std::is_standard_layout::value, + "need to be trival"); + return arena_.template allocate_(count); + } + + template + bool Read(T* data) { + static_assert(std::is_trivial::value && std::is_standard_layout::value, + "need to be trival"); + return ReadRawBytes(data, sizeof(T)); + } + + template + bool ReadArray(T* data, size_t count) { + static_assert(std::is_trivial::value && std::is_standard_layout::value, + "need to be trival"); + return ReadRawBytes(data, sizeof(T) * count); + } + + private: + bool ReadRawBytes(void* data, size_t size) { + uint8_t* buf = reinterpret_cast(data); + size_t ndone = 0; + while (ndone < size) { + ssize_t ret = io_->PosixRead(buf, size - ndone); + if (ret <= 0) { + this->ThrowError(RPCServerStatus::kReadError); + return false; + } + ndone += ret; + buf += ret; + } + return true; + } + + Logger logger_; + TIOHandler* io_; + support::GenericArena arena_; + MinRPCReturnsNoOp ret_handler_; + MinRPCReturnsWithLog ret_handler_wlog_; + MinRPCExecuteNoOp exec_handler_; + std::unique_ptr exec_handler_ptr_; + MinRPCServer next_; +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_ diff --git a/src/runtime/rpc/rpc_channel_logger.h b/src/runtime/rpc/rpc_channel_logger.h new file mode 100644 index 000000000000..53144956eb80 --- /dev/null +++ b/src/runtime/rpc/rpc_channel_logger.h @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file rpc_channel_logger.h + * \brief A wrapper for RPCChannel with a NanoRPCListener for logging the commands. + */ +#ifndef TVM_RUNTIME_RPC_RPC_CHANNEL_LOGGER_H_ +#define TVM_RUNTIME_RPC_RPC_CHANNEL_LOGGER_H_ + +#include +#include + +#include "../minrpc/minrpc_server_logging.h" +#include "rpc_channel.h" + +#define RX_BUFFER_SIZE 65536 + +namespace tvm { +namespace runtime { + +class Buffer { + public: + Buffer(uint8_t* data, size_t data_size_bytes) + : data_{data}, capacity_{data_size_bytes}, num_valid_bytes_{0}, read_cursor_{0} {} + + size_t Write(const uint8_t* data, size_t data_size_bytes) { + size_t num_bytes_available = capacity_ - num_valid_bytes_; + size_t num_bytes_to_copy = data_size_bytes; + if (num_bytes_available < num_bytes_to_copy) { + num_bytes_to_copy = num_bytes_available; + } + + memcpy(&data_[num_valid_bytes_], data, num_bytes_to_copy); + num_valid_bytes_ += num_bytes_to_copy; + return num_bytes_to_copy; + } + + size_t Read(uint8_t* data, size_t data_size_bytes) { + size_t num_bytes_to_copy = data_size_bytes; + size_t num_bytes_available = num_valid_bytes_ - read_cursor_; + if (num_bytes_available < num_bytes_to_copy) { + num_bytes_to_copy = num_bytes_available; + } + + memcpy(data, &data_[read_cursor_], num_bytes_to_copy); + read_cursor_ += num_bytes_to_copy; + return num_bytes_to_copy; + } + + void Clear() { + num_valid_bytes_ = 0; + read_cursor_ = 0; + } + + size_t Size() const { return num_valid_bytes_; } + + private: + /*! \brief pointer to data buffer. */ + uint8_t* data_; + + /*! \brief The total number of bytes available in data_.*/ + size_t capacity_; + + /*! \brief number of valid bytes in the buffer. */ + size_t num_valid_bytes_; + + /*! \brief Read cursor position. */ + size_t read_cursor_; +}; + +/*! + * \brief A simple IO handler for MinRPCSniffer. + * + * \tparam Buffer* buffer to store received data. + */ +class SnifferIOHandler { + public: + explicit SnifferIOHandler(Buffer* receive_buffer) : receive_buffer_(receive_buffer) {} + + void MessageStart(size_t message_size_bytes) {} + + ssize_t PosixWrite(const uint8_t* buf, size_t buf_size_bytes) { return 0; } + + void MessageDone() {} + + ssize_t PosixRead(uint8_t* buf, size_t buf_size_bytes) { + return receive_buffer_->Read(buf, buf_size_bytes); + } + + void Close() {} + + void Exit(int code) {} + + private: + Buffer* receive_buffer_; +}; + +/*! + * \brief A simple rpc session that logs the received commands. + */ +class NanoRPCListener { + public: + NanoRPCListener() + : receive_buffer_(receive_storage_, receive_storage_size_bytes_), + io_(&receive_buffer_), + rpc_server_(&io_) {} + + void Listen(const uint8_t* data, size_t size) { receive_buffer_.Write(data, size); } + + void ProcessTxPacket() { + rpc_server_.ProcessOnePacket(); + ClearBuffer(); + } + + void ProcessRxPacket() { + rpc_server_.ProcessOneResponse(); + ClearBuffer(); + } + + private: + void ClearBuffer() { receive_buffer_.Clear(); } + + private: + size_t receive_storage_size_bytes_ = RX_BUFFER_SIZE; + uint8_t receive_storage_[RX_BUFFER_SIZE]; + Buffer receive_buffer_; + SnifferIOHandler io_; + MinRPCSniffer rpc_server_; + + void HandleCompleteMessage() { rpc_server_.ProcessOnePacket(); } + + static void HandleCompleteMessageCb(void* context) { + static_cast(context)->HandleCompleteMessage(); + } +}; + +/*! + * \brief A wrapper for RPCChannel, that also logs the commands sent. + * + * \tparam std::unique_ptr&& underlying RPCChannel unique_ptr. + */ +class RPCChannelLogging : public RPCChannel { + public: + explicit RPCChannelLogging(std::unique_ptr&& next) { next_ = std::move(next); } + + size_t Send(const void* data, size_t size) { + listener_.ProcessRxPacket(); + listener_.Listen((const uint8_t*)data, size); + listener_.ProcessTxPacket(); + return next_->Send(data, size); + } + + size_t Recv(void* data, size_t size) { + size_t ret = next_->Recv(data, size); + listener_.Listen((const uint8_t*)data, size); + return ret; + } + + private: + std::unique_ptr next_; + NanoRPCListener listener_; +}; + +} // namespace runtime +} // namespace tvm +#endif // TVM_RUNTIME_RPC_RPC_CHANNEL_LOGGER_H_ diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h index ed19a3f59e58..d8e2dece73c5 100644 --- a/src/runtime/rpc/rpc_endpoint.h +++ b/src/runtime/rpc/rpc_endpoint.h @@ -34,6 +34,7 @@ #include "../../support/ring_buffer.h" #include "../minrpc/rpc_reference.h" #include "rpc_channel.h" +#include "rpc_channel_logger.h" #include "rpc_session.h" namespace tvm { @@ -180,6 +181,7 @@ class RPCEndpoint { void Shutdown(); // Internal channel. std::unique_ptr channel_; + // Internal mutex std::mutex mutex_; // Internal ring buffer. diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc index 1456fc719113..bc274ff88812 100644 --- a/src/runtime/rpc/rpc_socket_impl.cc +++ b/src/runtime/rpc/rpc_socket_impl.cc @@ -65,7 +65,7 @@ class SockChannel final : public RPCChannel { }; std::shared_ptr RPCConnect(std::string url, int port, std::string key, - TVMArgs init_seq) { + bool enable_logging, TVMArgs init_seq) { support::TCPSocket sock; support::SockAddr addr(url.c_str(), port); sock.Create(addr.ss_family()); @@ -96,14 +96,20 @@ std::shared_ptr RPCConnect(std::string url, int port, std::string k remote_key.resize(keylen); ICHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen); } - auto endpt = - RPCEndpoint::Create(std::unique_ptr(new SockChannel(sock)), key, remote_key); + + std::unique_ptr channel{new SockChannel(sock)}; + if (enable_logging) { + channel.reset(new RPCChannelLogging(std::move(channel))); + } + auto endpt = RPCEndpoint::Create(std::move(channel), key, remote_key); + endpt->InitRemoteSession(init_seq); return endpt; } -Module RPCClientConnect(std::string url, int port, std::string key, TVMArgs init_seq) { - auto endpt = RPCConnect(url, port, "client:" + key, init_seq); +Module RPCClientConnect(std::string url, int port, std::string key, bool enable_logging, + TVMArgs init_seq) { + auto endpt = RPCConnect(url, port, "client:" + key, enable_logging, init_seq); return CreateRPCSessionModule(CreateClientSession(endpt)); } @@ -124,8 +130,9 @@ TVM_REGISTER_GLOBAL("rpc.Connect").set_body([](TVMArgs args, TVMRetValue* rv) { std::string url = args[0]; int port = args[1]; std::string key = args[2]; - *rv = RPCClientConnect(url, port, key, - TVMArgs(args.values + 3, args.type_codes + 3, args.size() - 3)); + bool enable_logging = args[3]; + *rv = RPCClientConnect(url, port, key, enable_logging, + TVMArgs(args.values + 4, args.type_codes + 4, args.size() - 4)); }); TVM_REGISTER_GLOBAL("rpc.ServerLoop").set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py index f0ddcb60a1fd..63be742fdbb9 100644 --- a/tests/python/unittest/test_runtime_rpc.py +++ b/tests/python/unittest/test_runtime_rpc.py @@ -109,6 +109,25 @@ def check_remote(): check_remote() +@tvm.testing.requires_rpc +def test_rpc_simple_wlog(): + server = rpc.Server(key="x1") + client = rpc.connect("127.0.0.1", server.port, key="x1", enable_logging=True) + + def check_remote(): + f1 = client.get_function("rpc.test.addone") + assert f1(10) == 11 + f3 = client.get_function("rpc.test.except") + + with pytest.raises(tvm._ffi.base.TVMError): + f3("abc") + + f2 = client.get_function("rpc.test.strcat") + assert f2("abc", 11) == "abc:11" + + check_remote() + + @tvm.testing.requires_rpc def test_rpc_runtime_string(): server = rpc.Server(key="x1") @@ -231,7 +250,7 @@ def test_rpc_remote_module(): "127.0.0.1", server0.port, key="x0", - session_constructor_args=["rpc.Connect", "127.0.0.1", server1.port, "x1"], + session_constructor_args=["rpc.Connect", "127.0.0.1", server1.port, "x1", False], ) def check_remote(remote): @@ -366,7 +385,7 @@ def check_multi_hop(): "127.0.0.1", server0.port, key="x0", - session_constructor_args=["rpc.Connect", "127.0.0.1", server1.port, "x1"], + session_constructor_args=["rpc.Connect", "127.0.0.1", server1.port, "x1", False], ) fecho = client.get_function("testing.echo")