From 771f2b03a9de6d654d4fa98c4c54e2a228c9268e Mon Sep 17 00:00:00 2001 From: Jagadeesh J Date: Sat, 20 May 2023 05:07:52 +0530 Subject: [PATCH] fix: kserve fastapi migration issues (#2175) * fix: kserve fastapi migration issues --- .../kserve/kf_request_json/v2/mnist/README.md | 25 +++--- .../v2/mnist/mnist_v2_bytes.json | 8 +- .../kf_request_json/v2/mnist/tobytes.py | 17 +++- .../kserve/kserve_wrapper/TorchserveModel.py | 81 ++----------------- kubernetes/kserve/kserve_wrapper/__main__.py | 8 +- ts/torch_handler/request_envelope/kservev2.py | 35 +++++--- 6 files changed, 61 insertions(+), 113 deletions(-) diff --git a/kubernetes/kserve/kf_request_json/v2/mnist/README.md b/kubernetes/kserve/kf_request_json/v2/mnist/README.md index f8d41eb552..dcfcd1bd2b 100644 --- a/kubernetes/kserve/kf_request_json/v2/mnist/README.md +++ b/kubernetes/kserve/kf_request_json/v2/mnist/README.md @@ -19,13 +19,13 @@ The command will create `mnist.mar` file in current directory Move the mar file to model-store -``` +```bash sudo mv mnist.mar /mnt/models/model-store ``` and use the following config properties (`/mnt/models/config`) -``` +```conf inference_address=http://0.0.0.0:8085 management_address=http://0.0.0.0:8085 metrics_address=http://0.0.0.0:8082 @@ -51,13 +51,13 @@ Move to `kubernetes/kserve/kf_request_json/v2/mnist` For bytes input, use [tobytes](tobytes.py) utility. -``` +```bash python tobytes.py 0.png ``` For tensor input, use [totensor](totensor.py) utility -``` +```bash python totensor.py 0.png ``` @@ -66,7 +66,7 @@ python totensor.py 0.png Start TorchServe -``` +```bash torchserve --start --ts-config /mnt/models/config/config.properties --ncs ``` @@ -74,7 +74,7 @@ To test locally, clone TorchServe and move to the following folder `kubernetes/k Start Kserve -``` +```bash python __main__.py ``` @@ -85,12 +85,12 @@ Navigate to `kubernetes/kserve/kf_request_json/v2/mnist` Run the following command ```bash -curl -v -H "ContentType: application/json" http://localhost:8080/v2/models/mnist/infer -d @./mnist_v2_bytes.json +curl -v -H "Content-Type: application/json" http://localhost:8080/v2/models/mnist/infer -d @./mnist_v2_bytes.json ``` Expected Output -```bash +```json {"id": "d3b15cad-50a2-4eaf-80ce-8b0a428bd298", "model_name": "mnist", "model_version": "1.0", "outputs": [{"name": "predict", "shape": [1], "datatype": "INT64", "data": [0]}]} ``` @@ -100,8 +100,8 @@ Expected Output Run the following command -``` -curl -v -H "ContentType: application/json" http://localhost:8080/v2/models/mnist/infer -d @./mnist_v2_tensor.json +```bash +curl -v -H "Content-Type: application/json" http://localhost:8080/v2/models/mnist/infer -d @./mnist_v2_tensor.json ``` Expected output @@ -115,10 +115,11 @@ Expected output Run the following command ```bash -curl -v -H "ContentType: application/json" http://localhost:8080/v2/models/mnist/explain -d @./mnist_v2_bytes.json +curl -v -H "Content-Type: application/json" http://localhost:8080/v2/models/mnist/explain -d @./mnist_v2_bytes.json ``` Expected output -```bash + +```json {"id": "d3b15cad-50a2-4eaf-80ce-8b0a428bd298", "model_name": "mnist", "model_version": "1.0", "outputs": [{"name": "explain", "shape": [1, 28, 28], "datatype": "FP64", "data": [-0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 0.0, -0.0, -0.0, 0.0, -0.0, 0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 0.0, -0.0, 0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, -0.0, 0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, 0.0, -0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, 0.0, -0.0, 0.0, -0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0040547529196303285, -0.000226128774499257, -0.00012734138382422276, 0.005648369544853077, 0.0089047843954152, 0.002638536593970295, 0.002680245911942565, -0.0026578015819202173, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.00024465771891337887, 0.0008218450954311162, 0.01528591767842519, 0.007512832335428859, 0.00709498458333515, 0.0034056686436576803, -0.002091925041823873, -0.0007800293875604465, 0.02299587827540853, 0.019004329367380418, -0.0012529559050418735, -0.0014666116646934577, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.005298396405518712, -0.0007901605729004231, 0.0039060659926479398, 0.023174082126728335, 0.01723791770922474, 0.010867034167828598, 0.003001563229273835, 0.00622421771715703, 0.006120712207087491, 0.01673632965122119, 0.005674718948781803, 0.004344134599735745, -0.0012328422311881568, -0.0, -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, 0.0006867353833785289, 0.009772899792600862, -0.0038754932221901437, 0.001798693579973005, 0.001307544047675232, -0.0024510981010352315, -0.0008806773488194292, -0.0, -0.0, -0.00014277890760828639, -0.009322313235257151, 0.020608317727589167, 0.004351394518148479, -0.0007875566214137449, -0.0009075897508410689, -0.0, -0.0, 0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.00022247238084657642, -0.0007829029819622099, 0.0026663695200516055, 0.0009733366691924418, 0.0, -0.0, 0.0, 0.0, 0.0, 0.0, -0.0, 0.0004323207980879993, 0.023657171939959983, 0.01069484496100618, -0.0023759529165659743, -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.002074797197335781, -0.002320101263777886, -0.001289920656543141, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 0.007629679763806616, 0.01044862710854819, 0.00025032875474040415, -0.0, -0.0, 0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0003770836745884539, -0.005156369309364184, 0.0012477582083019567, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, 0.0, -0.0, -4.442513564501309e-05, 0.010248046436803096, 0.0009971133914441863, -0.0, -0.0, 0.0, 0.0, 0.0, -0.0, -0.0, 0.0, -0.0, 0.0, 0.0, -0.0, 0.0004501048922351147, -0.00196305355861066, -0.0006664792277975681, 0.0020157403871024866, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 0.0, -0.0, -0.002214456978582924, 0.008361583668963536, 0.0031401942747203444, -0.0, -0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0028943545250037983, -0.0031301382844878753, 0.002113252994616467, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, 0.0, -0.0, -0.0, -0.0010321050071136991, 0.008905753948020954, 0.0028464383724280478, -0.0, -0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 0.0, -0.0053052889804602885, -0.0019271100770928186, 0.0012090042664300153, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, -0.0, -0.0011945155805738324, 0.005654442809865844, 0.0020132075147173286, -0.0, -0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 0.0, -0.0014689358119857122, 0.0010743412654248086, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, 0.0, -0.0, -0.0, -0.0017047980433136346, 0.0029066051664685937, -0.0007805868937027288, -0.0, -0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, 0.0, -0.0, -0.0, 5.541726090138969e-05, 0.0014516115182299915, 0.0002827700518397855, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.001440140782635336, 0.002381249982038837, 0.002146825452068144, -0.0, -0.0, 0.0, -0.0, 0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 0.001150052970321427, 0.0002865015237050364, 0.0029798150346815985, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.001775029606380323, 0.000833985914685474, -0.003770739075457816, -0.0, -0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, 0.0, 0.0, -0.0006093176893524411, -0.00046905781658387527, 0.0034053217440919658, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, 0.0, -0.0, -0.0, -0.0007450012183962096, 0.001298767353118675, -0.008499247802184222, -6.145165255574976e-05, -0.0, -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, -0.0, 0.0, -0.0, 0.0, 0.0011809726462884672, -0.0018384763902449712, 0.005411106715800028, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, 0.0, -0.0021392341817010304, 0.0003259163122540385, -0.005276118905978749, -0.0019509840184772497, -9.545685077687876e-07, 0.0, -0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0007772404694664217, -0.0001517954537059768, 0.006481484678129392, -0.0, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 8.098064554131295e-05, -0.0024904264199929506, -0.0020718618328775897, -5.3411287747038166e-05, -0.0004556472202791715, 0.0, -0.0, -0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0022750984867578, 0.001716405971437602, 0.0003221344811922982, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0015560282437342534, 9.107229584202956e-05, 0.0008772841867241755, 0.0006502979194500701, -0.004128780661881036, 0.0006030386196211547, 0.0, -0.0, 0.0, -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, 0.0, -0.0, -0.0, 0.0, 0.0, 0.0013959959731925453, 0.0026791526421029673, 0.002399500793142178, -0.00044960969955281656, 0.003101832495190209, 0.007494535809079955, 0.002864118744003058, -0.003052590549800204, 0.003420222341277871, 0.0014924017873988514, -0.0009357389226494119, 0.0007856229438140384, -0.001843397373255761, 1.6031851430693252e-05, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 0.0, -0.000699901824825285, 0.0043822508549258565, -0.003541931476855951, -0.0028896746311921715, -0.0004873454583246359, -0.006087345141728267, 0.000388224886755815, 0.002533641621974457, -0.004352836429303485, -0.0006079421449756437, -0.003810133409713042, -0.0008284413779488711, 0.0, -0.0, 0.0, 0.0, -0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, 0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 0.0010901530854686326, -0.013135007707490608, 0.0004734520308098294, 0.0020504232707536456, -0.006609452262924153, 0.0023647861306777536, 0.004678920703192049, -0.0018122526857900652, 0.0021375383049022263, 0.0, -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, -0.0, -0.0, 0.0, -0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, 0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, 0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, 0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, 0.0, 0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 0.0, 0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.0, 0.0, -0.0, -0.0, -0.0, -0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}]} ``` diff --git a/kubernetes/kserve/kf_request_json/v2/mnist/mnist_v2_bytes.json b/kubernetes/kserve/kf_request_json/v2/mnist/mnist_v2_bytes.json index 0c07866dba..683ada7b73 100644 --- a/kubernetes/kserve/kf_request_json/v2/mnist/mnist_v2_bytes.json +++ b/kubernetes/kserve/kf_request_json/v2/mnist/mnist_v2_bytes.json @@ -1,10 +1,10 @@ { "inputs": [ { - "data": "iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA10lEQVR4nGNgGFhgy6xVdrCszBaLFN/mr28+/QOCr69DMCSnA8WvHti0acu/fx/10OS0X/975CDDw8DA1PDn/1pBVEmLf3+zocy2X/+8USXt/82Ds+/+m4sqeehfOpw97d9VFDmlO++t4JwQNMm6f6sZcEpee2+DR/I4A05J7tt4JJP+IUsu+ncRp6TxO9RAQJY0XvrvMAuypNNHuCTz8n+PzVEcy3DtqgiY1ptx6t8/ewY0yX9ntoDA63//Xs3hQpMMPPsPAv68qmDAAFKXwHIzMzCl6AoAxXp0QujtP+8AAAAASUVORK5CYII=", + "data": ["iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA10lEQVR4nGNgGFhgy6xVdrCszBaLFN/mr28+/QOCr69DMCSnA8WvHti0acu/fx/10OS0X/975CDDw8DA1PDn/1pBVEmLf3+zocy2X/+8USXt/82Ds+/+m4sqeehfOpw97d9VFDmlO++t4JwQNMm6f6sZcEpee2+DR/I4A05J7tt4JJP+IUsu+ncRp6TxO9RAQJY0XvrvMAuypNNHuCTz8n+PzVEcy3DtqgiY1ptx6t8/ewY0yX9ntoDA63//Xs3hQpMMPPsPAv68qmDAAFKXwHIzMzCl6AoAxXp0QujtP+8AAAAASUVORK5CYII="], "datatype": "BYTES", - "name": "312a4eb0-0ca7-4803-a101-a6d2c18486fe", - "shape": -1 + "name": "e8d5afed-0a56-4deb-ac9c-352663f51b93", + "shape": [-1] } ] -} \ No newline at end of file +} diff --git a/kubernetes/kserve/kf_request_json/v2/mnist/tobytes.py b/kubernetes/kserve/kf_request_json/v2/mnist/tobytes.py index f065acd31f..71ef7d3b62 100644 --- a/kubernetes/kserve/kf_request_json/v2/mnist/tobytes.py +++ b/kubernetes/kserve/kf_request_json/v2/mnist/tobytes.py @@ -1,6 +1,6 @@ +import argparse import base64 import json -import argparse import uuid parser = argparse.ArgumentParser() @@ -10,11 +10,20 @@ image = open(args.filename, "rb") # open binary file in read mode image_read = image.read() image_64_encode = base64.b64encode(image_read) -bytes_array = image_64_encode.decode("utf-8") +bytes_array = list(image_64_encode.decode("utf-8")) request = { - "inputs": [{"name": str(uuid.uuid4()), "shape": -1, "datatype": "BYTES", "data": bytes_array}] + "inputs": [ + { + "name": str(uuid.uuid4()), + "shape": [-1], + "datatype": "BYTES", + "data": bytes_array, + } + ] } -result_file = "{filename}.{ext}".format(filename=str(args.filename).split(".")[0], ext="json") +result_file = "{filename}.{ext}".format( + filename=str(args.filename).split(".")[0], ext="json" +) with open(result_file, "w") as outfile: json.dump(request, outfile, indent=4, sort_keys=True) diff --git a/kubernetes/kserve/kserve_wrapper/TorchserveModel.py b/kubernetes/kserve/kserve_wrapper/TorchserveModel.py index abf47959ed..aa28a50aa7 100644 --- a/kubernetes/kserve/kserve_wrapper/TorchserveModel.py +++ b/kubernetes/kserve/kserve_wrapper/TorchserveModel.py @@ -1,23 +1,19 @@ """ The torchserve side inference end-points request are handled to return a KServe side response """ -import json import logging import pathlib -from typing import Dict import kserve -import tornado.web +from kserve.errors import ModelMissingError from kserve.model import Model as Model -from kserve.model import ModelMissingError logging.basicConfig(level=kserve.constants.KSERVE_LOGLEVEL) +PREDICTOR_URL_FORMAT = PREDICTOR_V2_URL_FORMAT = "http://{0}/predictions/{1}" +EXPLAINER_URL_FORMAT = EXPLAINER_V2_URL_FORMAT = "http://{0}/explanations/{1}" REGISTER_URL_FORMAT = "{0}/models?initial_workers=1&url={1}" UNREGISTER_URL_FORMAT = "{0}/models/{1}" -PREDICTOR_URL_FORMAT = "http://{0}/v1/models/{1}:predict" -EXPLAINER_URL_FORMAT = "http://{0}/v1/models/{1}:explain" - class TorchserveModel(Model): """The torchserve side inference and explain end-points requests are handled to @@ -49,76 +45,9 @@ def __init__(self, name, inference_address, management_address, model_dir): self.management_address = management_address self.model_dir = model_dir - logging.info("kfmodel Predict URL set to %s", self.predictor_host) + logging.info("Predict URL set to %s", self.predictor_host) self.explainer_host = self.predictor_host - logging.info("kfmodel Explain URL set to %s", self.explainer_host) - - async def predict(self, request: Dict) -> Dict: - """The predict method is called when we hit the inference endpoint and handles - the inference request and response from the Torchserve side and passes it on - to the KServe side. - - Args: - request (Dict): Input request from the http client side. - - Raises: - NotImplementedError: If the predictor host on the KServe side is not - available. - - tornado.web.HTTPError: If there is a bad response from the http client. - - Returns: - Dict: The Response from the input from the inference endpoint. - """ - if not self.predictor_host: - raise NotImplementedError - logging.debug("kfmodel predict request is %s", json.dumps(request)) - logging.info("PREDICTOR_HOST : %s", self.predictor_host) - headers = {"Content-Type": "application/json; charset=UTF-8"} - response = await self._http_client.fetch( - PREDICTOR_URL_FORMAT.format(self.predictor_host, self.name), - method="POST", - request_timeout=self.timeout, - headers=headers, - body=json.dumps(request), - ) - - if response.code != 200: - raise tornado.web.HTTPError(status_code=response.code, reason=response.body) - return json.loads(response.body) - - async def explain(self, request: Dict) -> Dict: - """The predict method is called when we hit the explain endpoint and handles the - explain request and response from the Torchserve side and passes it on to the - KServe side. - - Args: - request (Dict): Input request from the http client side. - - Raises: - NotImplementedError: If the predictor host on the KServe side is not - available. - - tornado.web.HTTPError: If there is a bad response from the http client. - - Returns: - Dict: The Response from the input from the explain endpoint. - """ - if self.explainer_host is None: - raise NotImplementedError - logging.info("kfmodel explain request is %s", json.dumps(request)) - logging.info("EXPLAINER_HOST : %s", self.explainer_host) - headers = {"Content-Type": "application/json; charset=UTF-8"} - response = await self._http_client.fetch( - EXPLAINER_URL_FORMAT.format(self.explainer_host, self.name), - method="POST", - request_timeout=self.timeout, - headers=headers, - body=json.dumps(request), - ) - if response.code != 200: - raise tornado.web.HTTPError(status_code=response.code, reason=response.body) - return json.loads(response.body) + logging.info("Explain URL set to %s", self.explainer_host) def load(self) -> bool: """This method validates model availabilty in the model directory diff --git a/kubernetes/kserve/kserve_wrapper/__main__.py b/kubernetes/kserve/kserve_wrapper/__main__.py index e8063426fe..b31e3df375 100644 --- a/kubernetes/kserve/kserve_wrapper/__main__.py +++ b/kubernetes/kserve/kserve_wrapper/__main__.py @@ -12,7 +12,7 @@ DEFAULT_MODEL_NAME = "model" DEFAULT_INFERENCE_ADDRESS = "http://127.0.0.1:8085" INFERENCE_PORT = "8085" -DEFAULT_MANAGEMENT_ADDRESS = "http://127.0.0.1:8081" +DEFAULT_MANAGEMENT_ADDRESS = "http://127.0.0.1:8085" DEFAULT_MODEL_STORE = "/mnt/models/model-store" CONFIG_PATH = "/mnt/models/config/config.properties" @@ -31,10 +31,8 @@ def parse_config(): keys = {} with open(CONFIG_PATH) as f: - for line in f: if separator in line: - # Find the name and value by splitting the string name, value = line.split(separator, 1) @@ -79,13 +77,11 @@ def parse_config(): if __name__ == "__main__": - model_names, inference_address, management_address, model_dir = parse_config() models = [] for model_name in model_names: - model = TorchserveModel( model_name, inference_address, management_address, model_dir ) @@ -100,5 +96,5 @@ def parse_config(): ModelServer( registered_models=registeredModels, http_port=8080, - grpc_port=7070, + grpc_port=8081, ).start(models) diff --git a/ts/torch_handler/request_envelope/kservev2.py b/ts/torch_handler/request_envelope/kservev2.py index 33e573cfb9..5a88e9497d 100644 --- a/ts/torch_handler/request_envelope/kservev2.py +++ b/ts/torch_handler/request_envelope/kservev2.py @@ -4,7 +4,9 @@ """ import json import logging + import numpy as np + from .base import BaseEnvelope logger = logging.getLogger(__name__) @@ -87,7 +89,9 @@ def _batch_from_json(self, rows): Joins the instances of a batch of JSON objects """ logger.debug("Parse input data %s", rows) - body_list = [body_list.get("data") or body_list.get("body") for body_list in rows] + body_list = [ + body_list.get("data") or body_list.get("body") for body_list in rows + ] data_list = self._from_json(body_list) return data_list @@ -99,7 +103,15 @@ def _from_json(self, body_list): if isinstance(body_list[0], (bytes, bytearray)): body_list = [json.loads(body.decode()) for body in body_list] logger.debug("Bytes array is %s", body_list) - if "id" in body_list[0]: + + input_names = [] + for index, input in enumerate(body_list[0]["inputs"]): + if input["datatype"] == "BYTES": + body_list[0]["inputs"][index]["data"] = input["data"][0] + input_names.append(input["name"]) + setattr(self.context, "input_names", input_names) + logger.debug("Bytes array is %s", body_list) + if body_list[0].get("id") is not None: setattr(self.context, "input_request_id", body_list[0]["id"]) data_list = [inputs_list.get("inputs") for inputs_list in body_list][0] return data_list @@ -116,7 +128,7 @@ def format_output(self, data): "model_name": "bert", "model_version": "1", "outputs": [{ - "name": "predict", + "name": "input-0", "shape": [1], "datatype": "INT64", "data": [2] @@ -131,10 +143,10 @@ def format_output(self, data): delattr(self.context, "input_request_id") else: response["id"] = self.context.get_request_id(0) - response["model_name"] = self.context.manifest.get("model").get( - "modelName") + response["model_name"] = self.context.manifest.get("model").get("modelName") response["model_version"] = self.context.manifest.get("model").get( - "modelVersion") + "modelVersion" + ) response["outputs"] = self._batch_to_json(data) return [response] @@ -143,18 +155,19 @@ def _batch_to_json(self, data): Splits batch output to json objects """ output = [] - for item in data: - output.append(self._to_json(item)) + input_names = getattr(self.context, "input_names") + delattr(self.context, "input_names") + for index, item in enumerate(data): + output.append(self._to_json(item, input_names[index])) return output - def _to_json(self, data): + def _to_json(self, data, input_name): """ Constructs JSON object from data """ output_data = {} data_ndarray = np.array(data) - output_data["name"] = ("explain" if self.context.get_request_header( - 0, "explain") == "True" else "predict") + output_data["name"] = input_name output_data["shape"] = list(data_ndarray.shape) output_data["datatype"] = _to_datatype(data_ndarray.dtype) output_data["data"] = data_ndarray.flatten().tolist()