Skip to content

Commit

Permalink
fix: kservev2 batching issue and missing parameters
Browse files Browse the repository at this point in the history
Fix the KServeV2 envelope picking up only the first item from the request list,
so it can now dynamically batch requests.
If the model returns a dictionary, its keys are now used as the output names.
Update the envelope to handle request-level and input-level parameters.
  • Loading branch information
pkluska committed Dec 3, 2024
1 parent 3182443 commit da465e9
Showing 1 changed file with 69 additions and 42 deletions.
111 changes: 69 additions & 42 deletions ts/torch_handler/request_envelope/kservev2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""
import json
import logging
from typing import Optional

import numpy as np

Expand Down Expand Up @@ -104,24 +105,43 @@ def _from_json(self, body_list):
logger.debug("Bytes array is %s", body_list)

input_names = []
for index, input in enumerate(body_list[0]["inputs"]):
if input["datatype"] == "BYTES":
body_list[0]["inputs"][index]["data"] = input["data"][0]
else:
body_list[0]["inputs"][index]["data"] = (
np.array(input["data"]).reshape(tuple(input["shape"])).tolist()
)
input_names.append(input["name"])
parameters = []
ids = []
input_parameters = []
data_list = []

for body in body_list:
id = body.get("id")
ids.append(id)
params = body.get("parameters")
if params:
parameters.append(params)
inp_names = []
inp_params = []
for i, input in enumerate(body["inputs"]):
params = input.get("parameters")
if params:
inp_params.append(params)
if input["datatype"] == "BYTES":
body["inputs"][i]["data"] = input["data"][0]
else:
body["inputs"][i]["data"] = (
np.array(input["data"]).reshape(tuple(input["shape"])).tolist()
)
inp_names.append(input["name"])
data = body["inputs"] if len(body["inputs"]) > 1 else body["inputs"][0]
data_list.append(data)

input_parameters.append(inp_params)
input_names.append(inp_names)

setattr(self.context, "input_request_id", ids)
setattr(self.context, "input_names", input_names)
logger.debug("Bytes array is %s", body_list)
id = body_list[0].get("id")
if id and id.strip():
setattr(self.context, "input_request_id", body_list[0]["id"])
# TODO: Add parameters support
# parameters = body_list[0].get("parameters")
# if parameters:
# setattr(self.context, "input_parameters", body_list[0]["parameters"])
data_list = [inputs_list.get("inputs") for inputs_list in body_list][0]
setattr(self.context, "request_parameters", parameters)
setattr(self.context, "input_parameters", input_parameters)
logger.debug("Data array is %s", data_list)
logger.debug("Request paraemeters array is %s", parameters)
logger.debug("Input parameters is %s", input_parameters)
return data_list

def format_output(self, data):
    """
    Formats the per-request model outputs into KServe v2 response
    envelopes by delegating to ``_batch_to_json``, which emits one
    response object per batched request.

    :param data: list of model outputs, one entry per batched request
    :return: list of KServe v2 response dicts, one per request
    """
    logger.debug("The Response of KServe v2 format %s", data)
    return self._batch_to_json(data)

def _batch_to_json(self, batch: dict):
"""
Splits batch output to json objects
"""
output = []
input_names = getattr(self.context, "input_names")
parameters = getattr(self.context, "request_parameters")
ids = getattr(self.context, "input_request_id")
input_parameters = getattr(self.context, "input_parameters")
responses = []
for index, data in enumerate(batch):
response = {}
response["id"] = ids[index] or self.context.get_request_id(index)
if parameters and parameters[index]:
response["parameters"] = parameters[index]
response["model_name"] = self.context.manifest.get("model").get("modelName")
response["model_version"] = self.context.manifest.get("model").get(
"modelVersion"
)
outputs = []
if isinstance(data, dict):
for key, item in data.items():
outputs.append(self._to_json(item, key, input_parameters))
else:
outputs.append(self._to_json(data, "predictions", input_parameters))
response["outputs"] = outputs
responses.append(response)
delattr(self.context, "input_names")
for index, item in enumerate(data):
output.append(self._to_json(item, input_names[index]))
return output
delattr(self.context, "input_request_id")
delattr(self.context, "input_parameters")
delattr(self.context, "request_parameters")
return responses

def _to_json(self, data, input_name):
def _to_json(self, data, output_name, parameters: Optional[list] = None):
"""
Constructs JSON object from data
"""
output_data = {}
data_ndarray = np.array(data).flatten()
output_data["name"] = input_name
output_data["name"] = output_name
if parameters:
output_data["parameters"] = parameters
output_data["datatype"] = _to_datatype(data_ndarray.dtype)
output_data["data"] = data_ndarray.tolist()
output_data["shape"] = data_ndarray.flatten().shape
Expand Down

0 comments on commit da465e9

Please sign in to comment.