From 2e0a71184a4cce907296575683b2b36cdc46d283 Mon Sep 17 00:00:00 2001 From: Leandro Nunes Date: Thu, 9 Sep 2021 06:32:26 +0100 Subject: [PATCH] Set tvm.micro.project_api as a Python Module (#8963) * Add missing tvm.micro.project_api module file. The missing __init__.py makes it impossible to import this module with `import tvm.micro.project_api`. * This uncover 30-ish linting errors, which are also fixed here. --- python/tvm/micro/project_api/__init__.py | 17 +++++++ python/tvm/micro/project_api/client.py | 16 +++++-- python/tvm/micro/project_api/server.py | 60 ++++++++++++++---------- 3 files changed, 64 insertions(+), 29 deletions(-) create mode 100644 python/tvm/micro/project_api/__init__.py diff --git a/python/tvm/micro/project_api/__init__.py b/python/tvm/micro/project_api/__init__.py new file mode 100644 index 000000000000..9915040a922c --- /dev/null +++ b/python/tvm/micro/project_api/__init__.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""MicroTVM Project API Client and Server""" diff --git a/python/tvm/micro/project_api/client.py b/python/tvm/micro/project_api/client.py index f650ad946d87..ac8ff629a718 100644 --- a/python/tvm/micro/project_api/client.py +++ b/python/tvm/micro/project_api/client.py @@ -14,11 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - +""" +Project API client. +""" import base64 import io import json import logging +import platform import os import pathlib import subprocess @@ -56,6 +59,7 @@ class UnsupportedProtocolVersionError(ProjectAPIErrorBase): class RPCError(ProjectAPIErrorBase): def __init__(self, request, error): + ProjectAPIErrorBase.__init__() self.request = request self.error = error @@ -129,7 +133,8 @@ def _request_reply(self, method, params): if "error" in reply: raise server.JSONRPCError.from_json(f"calling method {method}", reply["error"]) - elif "result" not in reply: + + if "result" not in reply: raise MalformedReplyError(f"Expected 'result' key in server reply, got {reply!r}") return reply["result"] @@ -189,7 +194,7 @@ def write_transport(self, data, timeout_sec): # NOTE: windows support untested SERVER_LAUNCH_SCRIPT_FILENAME = ( - f"launch_microtvm_api_server.{'sh' if os.system != 'win32' else '.bat'}" + f"launch_microtvm_api_server.{'sh' if platform.system() != 'Windows' else '.bat'}" ) @@ -197,7 +202,8 @@ def write_transport(self, data, timeout_sec): def instantiate_from_dir(project_dir: typing.Union[pathlib.Path, str], debug: bool = False): - """Launch server located in project_dir, and instantiate a Project API Client connected to it.""" + """Launch server located in project_dir, and instantiate a Project API Client + connected to it.""" args = None project_dir = pathlib.Path(project_dir) @@ -224,7 +230,7 @@ def instantiate_from_dir(project_dir: typing.Union[pathlib.Path, str], debug: bo if debug: args.append("--debug") - api_server_proc = subprocess.Popen( + api_server_proc = subprocess.Popen( # pylint: disable=unused-variable args, bufsize=0, pass_fds=(api_server_read_fd, api_server_write_fd), cwd=project_dir ) os.close(api_server_read_fd) diff --git a/python/tvm/micro/project_api/server.py b/python/tvm/micro/project_api/server.py index 144f0cb6dee1..cee0205303f0 100644 --- a/python/tvm/micro/project_api/server.py +++ b/python/tvm/micro/project_api/server.py @@ -34,7 +34,6 @@ import re import select import sys -import textwrap import time import traceback import typing @@ -100,6 +99,7 @@ class JSONRPCError(Exception): """An error class with properties that meet the JSON-RPC error spec.""" def __init__(self, code, message, data, client_context=None): + Exception.__init__(self) self.code = code self.message = message self.data = data @@ -123,9 +123,7 @@ def __str__(self): @classmethod def from_json(cls, client_context, json_error): - # Subclasses of ServerError capture exceptions that occur in the Handler, and thus return a - # traceback. The encoding in `json_error` is also slightly different to allow the specific subclass - # to be identified. + """Convert an encapsulated ServerError into JSON-RPC compliant format.""" found_server_error = False try: if ErrorCode(json_error["code"]) == ErrorCode.SERVER_ERROR: @@ -145,6 +143,8 @@ def from_json(cls, client_context, json_error): class ServerError(JSONRPCError): + """Superclass for JSON-RPC errors which occur while processing valid requests.""" + @classmethod def from_exception(cls, exc, **kw): to_return = cls(**kw) @@ -168,21 +168,25 @@ def __str__(self): super_str = super(ServerError, self).__str__() return context_str + super_str - def set_traceback(self, traceback): + def set_traceback(self, traceback): # pylint: disable=redefined-outer-name + """Format a traceback to be embedded in the JSON-RPC format.""" + if self.data is None: self.data = {} if "traceback" not in self.data: # NOTE: TVM's FFI layer reorders Python stack traces several times and strips # intermediary lines that start with "Traceback". This logic adds a comment to the first - # stack frame to explicitly identify the first stack frame line that occurs on the server. + # stack frame to explicitly identify the first stack frame line that occurs on the + # server. traceback_list = list(traceback) - # The traceback list contains one entry per stack frame, and each entry contains 1-2 lines: + # The traceback list contains one entry per stack frame, and each entry contains 1-2 + # lines: # File "path/to/file", line 123, in : # - # We want to place a comment on the first line of the outermost frame to indicate this is the - # server-side stack frame. + # We want to place a comment on the first line of the outermost frame to indicate this + # is the server-side stack frame. first_frame_list = traceback_list[1].split("\n") self.data["traceback"] = ( traceback_list[0] @@ -307,7 +311,8 @@ def flash(self, options: dict): def open_transport(self, options: dict) -> TransportTimeouts: """Open resources needed for the transport layer. - This function might e.g. open files or serial ports needed in write_transport or read_transport. + This function might e.g. open files or serial ports needed in write_transport or + read_transport. Calling this function enables the write_transport and read_transport calls. If the transport is not open, this method is a no-op. @@ -323,7 +328,8 @@ def open_transport(self, options: dict) -> TransportTimeouts: def close_transport(self): """Close resources needed to operate the transport layer. - This function might e.g. close files or serial ports needed in write_transport or read_transport. + This function might e.g. close files or serial ports needed in write_transport or + read_transport. Calling this function disables the write_transport and read_transport calls. If the transport is not open, this method is a no-op. @@ -331,6 +337,7 @@ def close_transport(self): raise NotImplementedError() @abc.abstractmethod + # pylint: disable=unidiomatic-typecheck def read_transport(self, n: int, timeout_sec: typing.Union[float, type(None)]) -> bytes: """Read data from the transport. @@ -389,7 +396,8 @@ def write_transport(self, data: bytes, timeout_sec: float): class ProjectAPIServer: """Base class for Project API Servers. - This API server implements communication using JSON-RPC 2.0: https://www.jsonrpc.org/specification + This API server implements communication using JSON-RPC 2.0: + https://www.jsonrpc.org/specification Suggested use of this class is to import this module or copy this file into Project Generator implementations, then instantiate it with server.start(). @@ -451,7 +459,7 @@ def serve_one_request(self): _LOG.error("EOF") return False - except Exception as exc: + except Exception as exc: # pylint: disable=broad-except _LOG.error("Caught error reading request", exc_info=1) return False @@ -466,7 +474,7 @@ def serve_one_request(self): request_id = None if not did_validate else request.get("id") self._reply_error(request_id, exc) return did_validate - except Exception as exc: + except Exception as exc: # pylint: disable=broad-except message = "validating request" if did_validate: message = f"calling method {request['method']}" @@ -481,7 +489,7 @@ def serve_one_request(self): VALID_METHOD_RE = re.compile("^[a-zA-Z0-9_]+$") def _validate_request(self, request): - if type(request) is not dict: + if not isinstance(request, dict): raise JSONRPCError( ErrorCode.INVALID_REQUEST, f"request: want dict; got {request!r}", None ) @@ -493,7 +501,7 @@ def _validate_request(self, request): ) method = request.get("method") - if type(method) != str: + if not isinstance(method, str): raise JSONRPCError( ErrorCode.INVALID_REQUEST, f'request["method"]: want str; got {method!r}', None ) @@ -501,18 +509,20 @@ def _validate_request(self, request): if not self.VALID_METHOD_RE.match(method): raise JSONRPCError( ErrorCode.INVALID_REQUEST, - f'request["method"]: should match regex {self.VALID_METHOD_RE.pattern}; got {method!r}', + f'request["method"]: should match regex {self.VALID_METHOD_RE.pattern}; ' + f"got {method!r}", None, ) params = request.get("params") - if type(params) != dict: + if not isinstance(params, dict): raise JSONRPCError( ErrorCode.INVALID_REQUEST, f'request["params"]: want dict; got {type(params)}', None ) request_id = request.get("id") - if type(request_id) not in (str, int, type(None)): + # pylint: disable=unidiomatic-typecheck + if not isinstance(request_id, (str, int, type(None))): raise JSONRPCError( ErrorCode.INVALID_REQUEST, f'request["id"]: want str, number, null; got {request_id!r}', @@ -538,10 +548,11 @@ def _dispatch_request(self, request): params = {} for var_name, var_type in typing.get_type_hints(interface_method).items(): - if var_name == "self" or var_name == "return": + if var_name in ("self", "return"): continue - # NOTE: types can only be JSON-compatible types, so var_type is expected to be of type 'type'. + # NOTE: types can only be JSON-compatible types, so var_type is expected to be of type + # 'type'. if var_name not in request_params: raise JSONRPCError( ErrorCode.INVALID_PARAMS, @@ -553,7 +564,8 @@ def _dispatch_request(self, request): if not has_preprocessing and not isinstance(param, var_type): raise JSONRPCError( ErrorCode.INVALID_PARAMS, - f'method {request["method"]}: parameter {var_name}: want {var_type!r}, got {type(param)!r}', + f'method {request["method"]}: parameter {var_name}: want {var_type!r}, ' + f"got {type(param)!r}", None, ) @@ -636,7 +648,7 @@ def _await_nonblocking_ready(rlist, wlist, timeout_sec=None, end_time=None): return True -def read_with_timeout(fd, n, timeout_sec): +def read_with_timeout(fd, n, timeout_sec): # pylint: disable=invalid-name """Read data from a file descriptor, with timeout. This function is intended as a helper function for implementations of ProjectAPIHandler @@ -683,7 +695,7 @@ def read_with_timeout(fd, n, timeout_sec): return to_return -def write_with_timeout(fd, data, timeout_sec): +def write_with_timeout(fd, data, timeout_sec): # pylint: disable=invalid-name """Write data to a file descriptor, with timeout. This function is intended as a helper function for implementations of ProjectAPIHandler