Skip to content

Commit

Permalink
Set tvm.micro.project_api as a Python Module (#8963)
Browse files Browse the repository at this point in the history
* Add missing tvm.micro.project_api module file. The missing
  __init__.py makes it impossible to import this module with
  `import tvm.micro.project_api`.

* This uncover 30-ish linting errors, which are also fixed here.
  • Loading branch information
leandron authored Sep 9, 2021
1 parent f8b1df4 commit 2e0a711
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 29 deletions.
17 changes: 17 additions & 0 deletions python/tvm/micro/project_api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""MicroTVM Project API Client and Server"""
16 changes: 11 additions & 5 deletions python/tvm/micro/project_api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Project API client.
"""
import base64
import io
import json
import logging
import platform
import os
import pathlib
import subprocess
Expand Down Expand Up @@ -56,6 +59,7 @@ class UnsupportedProtocolVersionError(ProjectAPIErrorBase):

class RPCError(ProjectAPIErrorBase):
def __init__(self, request, error):
ProjectAPIErrorBase.__init__()
self.request = request
self.error = error

Expand Down Expand Up @@ -129,7 +133,8 @@ def _request_reply(self, method, params):

if "error" in reply:
raise server.JSONRPCError.from_json(f"calling method {method}", reply["error"])
elif "result" not in reply:

if "result" not in reply:
raise MalformedReplyError(f"Expected 'result' key in server reply, got {reply!r}")

return reply["result"]
Expand Down Expand Up @@ -189,15 +194,16 @@ def write_transport(self, data, timeout_sec):

# NOTE: windows support untested
SERVER_LAUNCH_SCRIPT_FILENAME = (
f"launch_microtvm_api_server.{'sh' if os.system != 'win32' else '.bat'}"
f"launch_microtvm_api_server.{'sh' if platform.system() != 'Windows' else '.bat'}"
)


SERVER_PYTHON_FILENAME = "microtvm_api_server.py"


def instantiate_from_dir(project_dir: typing.Union[pathlib.Path, str], debug: bool = False):
"""Launch server located in project_dir, and instantiate a Project API Client connected to it."""
"""Launch server located in project_dir, and instantiate a Project API Client
connected to it."""
args = None

project_dir = pathlib.Path(project_dir)
Expand All @@ -224,7 +230,7 @@ def instantiate_from_dir(project_dir: typing.Union[pathlib.Path, str], debug: bo
if debug:
args.append("--debug")

api_server_proc = subprocess.Popen(
api_server_proc = subprocess.Popen( # pylint: disable=unused-variable
args, bufsize=0, pass_fds=(api_server_read_fd, api_server_write_fd), cwd=project_dir
)
os.close(api_server_read_fd)
Expand Down
60 changes: 36 additions & 24 deletions python/tvm/micro/project_api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import re
import select
import sys
import textwrap
import time
import traceback
import typing
Expand Down Expand Up @@ -100,6 +99,7 @@ class JSONRPCError(Exception):
"""An error class with properties that meet the JSON-RPC error spec."""

def __init__(self, code, message, data, client_context=None):
Exception.__init__(self)
self.code = code
self.message = message
self.data = data
Expand All @@ -123,9 +123,7 @@ def __str__(self):

@classmethod
def from_json(cls, client_context, json_error):
# Subclasses of ServerError capture exceptions that occur in the Handler, and thus return a
# traceback. The encoding in `json_error` is also slightly different to allow the specific subclass
# to be identified.
"""Convert an encapsulated ServerError into JSON-RPC compliant format."""
found_server_error = False
try:
if ErrorCode(json_error["code"]) == ErrorCode.SERVER_ERROR:
Expand All @@ -145,6 +143,8 @@ def from_json(cls, client_context, json_error):


class ServerError(JSONRPCError):
"""Superclass for JSON-RPC errors which occur while processing valid requests."""

@classmethod
def from_exception(cls, exc, **kw):
to_return = cls(**kw)
Expand All @@ -168,21 +168,25 @@ def __str__(self):
super_str = super(ServerError, self).__str__()
return context_str + super_str

def set_traceback(self, traceback):
def set_traceback(self, traceback): # pylint: disable=redefined-outer-name
"""Format a traceback to be embedded in the JSON-RPC format."""

if self.data is None:
self.data = {}

if "traceback" not in self.data:
# NOTE: TVM's FFI layer reorders Python stack traces several times and strips
# intermediary lines that start with "Traceback". This logic adds a comment to the first
# stack frame to explicitly identify the first stack frame line that occurs on the server.
# stack frame to explicitly identify the first stack frame line that occurs on the
# server.
traceback_list = list(traceback)

# The traceback list contains one entry per stack frame, and each entry contains 1-2 lines:
# The traceback list contains one entry per stack frame, and each entry contains 1-2
# lines:
# File "path/to/file", line 123, in <method>:
# <copy of the line>
# We want to place a comment on the first line of the outermost frame to indicate this is the
# server-side stack frame.
# We want to place a comment on the first line of the outermost frame to indicate this
# is the server-side stack frame.
first_frame_list = traceback_list[1].split("\n")
self.data["traceback"] = (
traceback_list[0]
Expand Down Expand Up @@ -307,7 +311,8 @@ def flash(self, options: dict):
def open_transport(self, options: dict) -> TransportTimeouts:
"""Open resources needed for the transport layer.
This function might e.g. open files or serial ports needed in write_transport or read_transport.
This function might e.g. open files or serial ports needed in write_transport or
read_transport.
Calling this function enables the write_transport and read_transport calls. If the
transport is not open, this method is a no-op.
Expand All @@ -323,14 +328,16 @@ def open_transport(self, options: dict) -> TransportTimeouts:
def close_transport(self):
"""Close resources needed to operate the transport layer.
This function might e.g. close files or serial ports needed in write_transport or read_transport.
This function might e.g. close files or serial ports needed in write_transport or
read_transport.
Calling this function disables the write_transport and read_transport calls. If the
transport is not open, this method is a no-op.
"""
raise NotImplementedError()

@abc.abstractmethod
# pylint: disable=unidiomatic-typecheck
def read_transport(self, n: int, timeout_sec: typing.Union[float, type(None)]) -> bytes:
"""Read data from the transport.
Expand Down Expand Up @@ -389,7 +396,8 @@ def write_transport(self, data: bytes, timeout_sec: float):
class ProjectAPIServer:
"""Base class for Project API Servers.
This API server implements communication using JSON-RPC 2.0: https://www.jsonrpc.org/specification
This API server implements communication using JSON-RPC 2.0:
https://www.jsonrpc.org/specification
Suggested use of this class is to import this module or copy this file into Project Generator
implementations, then instantiate it with server.start().
Expand Down Expand Up @@ -451,7 +459,7 @@ def serve_one_request(self):
_LOG.error("EOF")
return False

except Exception as exc:
except Exception as exc: # pylint: disable=broad-except
_LOG.error("Caught error reading request", exc_info=1)
return False

Expand All @@ -466,7 +474,7 @@ def serve_one_request(self):
request_id = None if not did_validate else request.get("id")
self._reply_error(request_id, exc)
return did_validate
except Exception as exc:
except Exception as exc: # pylint: disable=broad-except
message = "validating request"
if did_validate:
message = f"calling method {request['method']}"
Expand All @@ -481,7 +489,7 @@ def serve_one_request(self):
VALID_METHOD_RE = re.compile("^[a-zA-Z0-9_]+$")

def _validate_request(self, request):
if type(request) is not dict:
if not isinstance(request, dict):
raise JSONRPCError(
ErrorCode.INVALID_REQUEST, f"request: want dict; got {request!r}", None
)
Expand All @@ -493,26 +501,28 @@ def _validate_request(self, request):
)

method = request.get("method")
if type(method) != str:
if not isinstance(method, str):
raise JSONRPCError(
ErrorCode.INVALID_REQUEST, f'request["method"]: want str; got {method!r}', None
)

if not self.VALID_METHOD_RE.match(method):
raise JSONRPCError(
ErrorCode.INVALID_REQUEST,
f'request["method"]: should match regex {self.VALID_METHOD_RE.pattern}; got {method!r}',
f'request["method"]: should match regex {self.VALID_METHOD_RE.pattern}; '
f"got {method!r}",
None,
)

params = request.get("params")
if type(params) != dict:
if not isinstance(params, dict):
raise JSONRPCError(
ErrorCode.INVALID_REQUEST, f'request["params"]: want dict; got {type(params)}', None
)

request_id = request.get("id")
if type(request_id) not in (str, int, type(None)):
# pylint: disable=unidiomatic-typecheck
if not isinstance(request_id, (str, int, type(None))):
raise JSONRPCError(
ErrorCode.INVALID_REQUEST,
f'request["id"]: want str, number, null; got {request_id!r}',
Expand All @@ -538,10 +548,11 @@ def _dispatch_request(self, request):
params = {}

for var_name, var_type in typing.get_type_hints(interface_method).items():
if var_name == "self" or var_name == "return":
if var_name in ("self", "return"):
continue

# NOTE: types can only be JSON-compatible types, so var_type is expected to be of type 'type'.
# NOTE: types can only be JSON-compatible types, so var_type is expected to be of type
# 'type'.
if var_name not in request_params:
raise JSONRPCError(
ErrorCode.INVALID_PARAMS,
Expand All @@ -553,7 +564,8 @@ def _dispatch_request(self, request):
if not has_preprocessing and not isinstance(param, var_type):
raise JSONRPCError(
ErrorCode.INVALID_PARAMS,
f'method {request["method"]}: parameter {var_name}: want {var_type!r}, got {type(param)!r}',
f'method {request["method"]}: parameter {var_name}: want {var_type!r}, '
f"got {type(param)!r}",
None,
)

Expand Down Expand Up @@ -636,7 +648,7 @@ def _await_nonblocking_ready(rlist, wlist, timeout_sec=None, end_time=None):
return True


def read_with_timeout(fd, n, timeout_sec):
def read_with_timeout(fd, n, timeout_sec): # pylint: disable=invalid-name
"""Read data from a file descriptor, with timeout.
This function is intended as a helper function for implementations of ProjectAPIHandler
Expand Down Expand Up @@ -683,7 +695,7 @@ def read_with_timeout(fd, n, timeout_sec):
return to_return


def write_with_timeout(fd, data, timeout_sec):
def write_with_timeout(fd, data, timeout_sec): # pylint: disable=invalid-name
"""Write data to a file descriptor, with timeout.
This function is intended as a helper function for implementations of ProjectAPIHandler
Expand Down

0 comments on commit 2e0a711

Please sign in to comment.