implement tensorflow tensor handlers (#421)
* init tf_tensor_handler test

* implement tensorflow tensor http_handler

* implement tensorflow tensor cli_handler and lambda_handler

* move input tensor auto transforming into artifact

* fix tf_tensor_handler test

* add new arg 'method' to TensorflowTensorHandler

* add NestedDecoder; decode tensor handler output

* add warning about importing tf 1.x SavedModel

* style
bojiang authored and parano committed Dec 19, 2019
1 parent 23bac65 commit db783a0
Showing 5 changed files with 410 additions and 6 deletions.
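For context, a minimal sketch of how the new handler could be wired into a service. The service name, artifact name, and decorator usage are illustrative for this era of the BentoML API, and `TensorflowTensorHandler` is assumed to be exported from `bentoml.handlers` like the other handlers:

```python
from bentoml import api, artifacts, env, BentoService
from bentoml.artifact import TensorflowSavedModelArtifact
from bentoml.handlers import TensorflowTensorHandler  # assumed export


@env(pip_dependencies=['tensorflow'])
@artifacts([TensorflowSavedModelArtifact('model')])
class TfExampleService(BentoService):  # hypothetical service
    @api(TensorflowTensorHandler)
    def predict(self, tensor):
        # `tensor` is a tf.Tensor built from the request's "instances"
        # payload; the handler serializes the returned value.
        return self.artifacts.model(tensor)
```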
98 changes: 97 additions & 1 deletion bentoml/artifact/tf_savedmodel_artifact.py
@@ -32,9 +32,81 @@ def _is_path_like(p):
return isinstance(p, (str, bytes, pathlib.PurePath, os.PathLike))


class _TensorflowFunctionWrapper:
    '''
    Wraps a tf.function and transforms its inputs to match the function's
    input signature, casting dtypes and reshaping as needed.
    '''

def __init__(self, origin_func, fullargspec):
self.origin_func = origin_func
self.fullargspec = fullargspec
self._args_to_indices = {arg: i for i, arg in enumerate(fullargspec.args)}

def __call__(self, *args, **kwargs):
signatures = self.origin_func.input_signature
for k in kwargs:
if k not in self._args_to_indices:
raise TypeError(f"Function got an unexpected keyword argument {k}")
signatures_by_kw = {
k: signatures[self._args_to_indices[k]]
for k in kwargs
}
        # INFO:
        # how signatures work with kwargs:
        # https://github.com/tensorflow/tensorflow/blob/v2.0.0/tensorflow/python/eager/function.py#L1519

transformed_args = tuple(
self._transform_input_by_tensorspec(arg, signatures[i])
for i, arg in enumerate(args))
transformed_kwargs = {
k: self._transform_input_by_tensorspec(arg, signatures_by_kw[k])
for k, arg in kwargs.items()
}
return self.origin_func(*transformed_args, **transformed_kwargs)

def __getattr__(self, k):
return getattr(self.origin_func, k)

@staticmethod
def _transform_input_by_tensorspec(_input, tensorspec):
        '''
        Cast dtype and reshape `_input` to match the given TensorSpec.
        '''
try:
import tensorflow as tf
except ImportError:
raise MissingDependencyException(
"Tensorflow package is required to use TfSavedModelArtifact")

if _input.dtype != tensorspec.dtype:
# may raise TypeError
_input = tf.dtypes.cast(_input, tensorspec.dtype)
if not tensorspec.is_compatible_with(_input):
            _input = tf.reshape(_input, tuple(
                -1 if i is None else i for i in tensorspec.shape))
return _input

@classmethod
def hook_loaded_model(cls, loaded_model):
try:
from tensorflow.python.util import tf_inspect
from tensorflow.python.eager import def_function
except ImportError:
raise MissingDependencyException(
"Tensorflow package is required to use TfSavedModelArtifact")

for k in dir(loaded_model):
v = getattr(loaded_model, k, None)
if isinstance(v, def_function.Function):
fullargspec = tf_inspect.getfullargspec(v)
setattr(loaded_model, k, cls(v, fullargspec))


def _load_tf_saved_model(path):
try:
import tensorflow as tf
from tensorflow.python.training.tracking.tracking import AutoTrackable

TF2 = tf.__version__.startswith('2')
except ImportError:
@@ -45,7 +117,30 @@ def _load_tf_saved_model(path):
if TF2:
return tf.saved_model.load(path)
else:
loaded = tf.compat.v2.saved_model.load(path)
if (isinstance(loaded, AutoTrackable)
and not hasattr(loaded, "__call__")):
logger.warning(
                '''Importing SavedModel from TensorFlow 1.x.
`outputs = imported(inputs)` is not supported in a BentoService due to
the TensorFlow 1.x API.
Recommended usage:
```python
from tensorflow.python.saved_model import signature_constants
imported = tf.saved_model.load(path_to_v1_saved_model)
wrapped_function = imported.signatures[
signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
wrapped_function(tf.ones([]))
```
See https://www.tensorflow.org/api_docs/python/tf/saved_model/load for
details.
'''
)
return loaded


class TensorflowSavedModelArtifact(BentoServiceArtifact):
@@ -128,6 +223,7 @@ def pack(
def load(self, path):
saved_model_path = self._saved_model_path(path)
loaded_model = _load_tf_saved_model(saved_model_path)
_TensorflowFunctionWrapper.hook_loaded_model(loaded_model)
return self.pack(loaded_model)


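To illustrate the new hook (applied automatically in `TensorflowSavedModelArtifact.load`, so calling it by hand is for demonstration only), a minimal sketch assuming TensorFlow 2.x; the module, signature, and save path are made up:

```python
import tensorflow as tf
from bentoml.artifact.tf_savedmodel_artifact import _TensorflowFunctionWrapper


class Adder(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32)])
    def add_one(self, x):
        return x + 1


tf.saved_model.save(Adder(), "/tmp/adder")  # hypothetical path
loaded = tf.saved_model.load("/tmp/adder")
_TensorflowFunctionWrapper.hook_loaded_model(loaded)

# Without the hook, an int32 input would fail to match the concrete
# function; the wrapper casts it to float32 per the declared TensorSpec.
print(loaded.add_one(tf.constant([1, 2, 3])))  # -> tf.Tensor([2. 3. 4.], ...)
```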
133 changes: 128 additions & 5 deletions bentoml/handlers/tensorflow_tensor_handler.py
@@ -16,19 +16,142 @@
from __future__ import division
from __future__ import print_function

import json
import argparse
from flask import Response
from bentoml.handlers.utils import (
    NestedConverter, tf_b64_2_bytes, tf_tensor_2_serializable)
from bentoml.handlers.base_handlers import BentoHandler, get_output_str
from bentoml.exceptions import BentoMLException, BadInput


decode_b64_if_needed = NestedConverter(tf_b64_2_bytes)
decode_tf_if_needed = NestedConverter(tf_tensor_2_serializable)


class TensorflowTensorHandler(BentoHandler):
"""
Tensor handlers for Tensorflow models
Tensor handlers for Tensorflow models.
Transform incoming tf tensor data from http request, cli or lambda event into
tf tensor.
The behaviour should be compatible with tensorflow serving REST API:
* https://www.tensorflow.org/tfx/serving/api_rest#classify_and_regress_api
* https://www.tensorflow.org/tfx/serving/api_rest#predict_api
Args:
* method: equivalence of serving API methods: (predict, classify, regress)
Raises:
BentoMLException: BentoML currently doesn't support Content-Type
"""
METHODS = (
PREDICT,
CLASSIFY,
REGRESS,
) = (
"predict",
"classify",
"regress",
)

def __init__(self, method=PREDICT):
self.method = method

@property
def request_schema(self):
if self.method == self.PREDICT:
return {
"application/json": {
"schema": {
"type": "object",
"properties": {
"signature_name": {
"type": "string",
"default": None,
},
"instances": {
"type": "array",
"items": {
"type": "object",
},
"default": None,
},
"inputs": {
"type": "object",
"default": None,
}
},
}
}
}
else:
raise NotImplementedError(f"method {self.method} is not implemented")

    def _handle_raw_str(self, raw_str, output_format, func):
        import tensorflow as tf
        parsed_json = json.loads(raw_str)
        if parsed_json.get("instances") is not None:
            instances = parsed_json.get("instances")
            instances = decode_b64_if_needed(instances)
            parsed_tensor = tf.constant(instances)
            result = func(parsed_tensor)
            result = decode_tf_if_needed(result)
        elif parsed_json.get("inputs") is not None:
            raise NotImplementedError("column format 'inputs' is not implemented")
        else:
            raise BadInput("Request JSON must contain an 'instances' field")

        if output_format == "json":
            result_object = {"predictions": result}
            result_str = json.dumps(result_object)
        elif output_format == "str":
            result_str = get_output_str(result, output_format)

        return result_str

def handle_request(self, request, func):
"""Handle http request that has jsonlized tensorflow tensor. It will convert it
into a tf tensor for the function to consume.
Args:
request: incoming request object.
func: function that will take ndarray as its arg.
Return:
response object
"""
        output_format = request.headers.get("output", "json")
        if output_format not in {"json", "str"}:
            raise BadInput(
                "Request output must be 'json' or 'str' for this BentoService API")
        if request.content_type == "application/json":
            input_str = request.data.decode("utf-8")
            result_str = self._handle_raw_str(input_str, output_format, func)
return Response(
response=result_str, status=200, mimetype="application/json")
else:
raise BadInput(
"Request content-type must be 'application/json'"
" for this BentoService API")

def handle_cli(self, args, func):
parser = argparse.ArgumentParser()
parser.add_argument("--input", required=True)
parser.add_argument(
"-o", "--output", default="str", choices=["str", "json"]
)
parsed_args = parser.parse_args(args)

result = self._handle_raw_str(parsed_args.input, parsed_args.output, func)
print(result)

def handle_aws_lambda_event(self, event, func):
if event["headers"].get("Content-Type", "") == "application/json":
result = self._handle_raw_str(
event["body"], event["headers"].get("output", "json"), func)
else:
raise BentoMLException(
"BentoML currently doesn't support Content-Type: {content_type} for "
"AWS Lambda".format(content_type=event["headers"]["Content-Type"])
)

return {"statusCode": 200, "body": result}
84 changes: 84 additions & 0 deletions bentoml/handlers/utils.py
@@ -0,0 +1,84 @@
TF_B64_KEY = 'b64'


def tf_b64_2_bytes(obj):
import base64
if isinstance(obj, dict) and TF_B64_KEY in obj:
return base64.b64decode(obj[TF_B64_KEY])
else:
return obj


def bytes_2_tf_b64(obj):
import base64
if isinstance(obj, bytes):
return {
TF_B64_KEY: base64.b64encode(obj).decode('utf-8')
}
else:
return obj


def tf_tensor_2_serializable(obj):
    '''
    Convert:
        tf.Tensor -> JSON-serializable object
        np.ndarray -> JSON-serializable object
        bytes -> {'b64': <b64_str>}
        others -> themselves
    '''
import tensorflow as tf
import numpy as np

# Tensor -> ndarray or object
if isinstance(obj, tf.Tensor):
if tf.__version__.startswith("1."):
with tf.compat.v1.Session():
obj = obj.numpy()
else:
obj = obj.numpy()

# ndarray -> serializable python object
TYPES = (int, float, str)
if isinstance(obj, np.ndarray):
for _type in TYPES:
# dtype of string/bytes ndarrays returned by tensor.numpy()
# are both np.dtype(object), which are not json serializable
try:
obj = obj.astype(_type)
except (UnicodeDecodeError, ValueError, OverflowError):
continue
break
else:
obj = np.vectorize(bytes_2_tf_b64)(obj)
obj = obj.tolist()
elif isinstance(obj, bytes):
# tensor.numpy() will return single value directly
try:
obj = obj.decode("utf8")
except UnicodeDecodeError:
obj = bytes_2_tf_b64(obj)

return obj


class NestedConverter:
    '''
    Wrap a single-value converter so it applies recursively to objects
    nested inside lists, tuples, and dicts.
    '''

def __init__(self, converter):
self.converter = converter

    def __call__(self, obj):
        converted = self.converter(obj)
        if converted is not obj:
            return converted

if isinstance(obj, dict):
return {k: self(v) for k, v in obj.items()}
elif isinstance(obj, (list, tuple)):
return [self(v) for v in obj]
else:
return obj
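A quick illustration of the new converters (the values are made up; `aGVsbG8=` is base64 for `hello`):

```python
from bentoml.handlers.utils import NestedConverter, tf_b64_2_bytes

decode_b64_if_needed = NestedConverter(tf_b64_2_bytes)

# Dicts holding a 'b64' key are decoded; other values pass through,
# with the converter applied recursively inside lists and dicts.
nested = {"instances": [{"b64": "aGVsbG8="}, [1, 2], "plain"]}
print(decode_b64_if_needed(nested))
# -> {'instances': [b'hello', [1, 2], 'plain']}
```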
1 change: 1 addition & 0 deletions setup.py
@@ -65,6 +65,7 @@
"coverage>=4.4",
"codecov",
"moto",
"numpy",
]
+ imageio
+ aws_sam_cli