diff --git a/python/kserve/kserve/protocol/infer_type.py b/python/kserve/kserve/protocol/infer_type.py index b5da94df013..1298827ff27 100644 --- a/python/kserve/kserve/protocol/infer_type.py +++ b/python/kserve/kserve/protocol/infer_type.py @@ -621,12 +621,13 @@ def __eq__(self, other): class InferResponse: id: str model_name: str + model_version: Optional[str] parameters: Optional[Dict] outputs: List[InferOutput] from_grpc: bool def __init__(self, response_id: str, model_name: str, infer_outputs: List[InferOutput], - raw_outputs=None, from_grpc: Optional[bool] = False, + model_version: Optional[str] = None, raw_outputs=None, from_grpc: Optional[bool] = False, parameters: Optional[Union[Dict, MessageMap[str, InferParameter]]] = None): """The InferResponse Data Model @@ -634,6 +635,7 @@ def __init__(self, response_id: str, model_name: str, infer_outputs: List[InferO response_id: The id of the inference response. model_name: The name of the model. infer_outputs: The inference outputs of the inference response. + model_version: The version of the model. raw_outputs: The raw binary data of the inference outputs. from_grpc: Indicate if the InferResponse is constructed from a gRPC response. parameters: The additional inference parameters. @@ -641,6 +643,7 @@ def __init__(self, response_id: str, model_name: str, infer_outputs: List[InferO self.id = response_id self.model_name = model_name + self.model_version = model_version self.outputs = infer_outputs self.parameters = parameters self.from_grpc = from_grpc @@ -657,8 +660,9 @@ def from_grpc(cls, response: ModelInferResponse) -> 'InferResponse': data=get_content(output.datatype, output.contents), parameters=output.parameters) for output in response.outputs] - return cls(model_name=response.model_name, response_id=response.id, parameters=response.parameters, - infer_outputs=infer_outputs, raw_outputs=response.raw_output_contents, from_grpc=True) + return cls(model_name=response.model_name, model_version=response.model_version, response_id=response.id, + parameters=response.parameters, infer_outputs=infer_outputs, + raw_outputs=response.raw_output_contents, from_grpc=True) @classmethod def from_rest(cls, model_name: str, response: Dict) -> 'InferResponse': @@ -672,6 +676,7 @@ def from_rest(cls, model_name: str, response: Dict) -> 'InferResponse': parameters=output.get('parameters', None)) for output in response['outputs']] return cls(model_name=model_name, + model_version=response.get('model_version', None), response_id=response.get('id', None), parameters=response.get('parameters', None), infer_outputs=infer_outputs) @@ -702,6 +707,7 @@ def to_rest(self) -> Dict: res = { 'id': self.id, 'model_name': self.model_name, + 'model_version': self.model_version, 'outputs': infer_outputs } if self.parameters: @@ -742,8 +748,8 @@ def to_grpc(self) -> ModelInferResponse: raise InvalidInput("to_grpc: invalid output datatype") infer_outputs.append(infer_output_dict) - return ModelInferResponse(id=self.id, model_name=self.model_name, outputs=infer_outputs, - raw_output_contents=raw_output_contents, + return ModelInferResponse(id=self.id, model_name=self.model_name, model_version=self.model_version, + outputs=infer_outputs, raw_output_contents=raw_output_contents, parameters=to_grpc_parameters(self.parameters) if self.parameters else None) def __eq__(self, other): @@ -751,6 +757,8 @@ def __eq__(self, other): return False if self.model_name != other.model_name: return False + if self.model_version != other.model_version: + return False if self.id != other.id: return False if self.from_grpc != other.from_grpc: diff --git a/python/kserve/test/test_infer_type.py b/python/kserve/test/test_infer_type.py index 43a4010dfbe..ff7d2931dca 100644 --- a/python/kserve/test/test_infer_type.py +++ b/python/kserve/test/test_infer_type.py @@ -139,7 +139,7 @@ def test_from_grpc(self): class TestInferResponse: def test_to_rest(self): - infer_res = InferResponse(model_name="TestModel", response_id="123", + infer_res = InferResponse(model_name="TestModel", response_id="123", model_version="v1", parameters={ "test-str": InferParameter(string_param="dummy"), "test-bool": InferParameter(bool_param=True), @@ -156,6 +156,7 @@ def test_to_rest(self): expected = { "id": "123", "model_name": "TestModel", + "model_version": "v1", "outputs": [ { "name": "output-0", @@ -179,7 +180,7 @@ def test_to_rest(self): assert res == expected def test_to_grpc(self): - infer_res = InferResponse(model_name="TestModel", response_id="123", + infer_res = InferResponse(model_name="TestModel", response_id="123", model_version="v1", parameters={ "test-str": "dummy", "test-bool": True, @@ -193,7 +194,7 @@ def test_to_grpc(self): "test-int": 100 })] ) - expected = ModelInferResponse(model_name="TestModel", id="123", + expected = ModelInferResponse(model_name="TestModel", id="123", model_version="v1", parameters={ "test-str": InferParameter(string_param="dummy"), "test-bool": InferParameter(bool_param=True), @@ -218,7 +219,7 @@ def test_to_grpc(self): assert res == expected def test_from_grpc(self): - infer_res = ModelInferResponse(model_name="TestModel", id="123", + infer_res = ModelInferResponse(model_name="TestModel", id="123", model_version="v1", parameters={ "test-str": InferParameter(string_param="dummy"), "test-bool": InferParameter(bool_param=True), @@ -239,7 +240,7 @@ def test_from_grpc(self): }, }] ) - expected = InferResponse(model_name="TestModel", response_id="123", + expected = InferResponse(model_name="TestModel", response_id="123", model_version="v1", parameters={ "test-str": InferParameter(string_param="dummy"), "test-bool": InferParameter(bool_param=True),