Skip to content

Commit

Permalink
Refactor code for custom data support in tfserving_proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
Sebastian Stawicki committed May 18, 2021
1 parent 46ebe81 commit e600275
Showing 1 changed file with 52 additions and 33 deletions.
85 changes: 52 additions & 33 deletions servers/tfserving_proxy/TfServingProxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,48 +62,67 @@ def predict_grpc(self, request):
if request_data_type not in ["data", "customData"]:
raise Exception("strData, binData and jsonData not supported.")

if request_data_type == "data":
result = self._predict_grpc_data(request, default_data_type)
else:
result = self._predict_grpc_custom_data(request)

return result

def _predict_grpc_data(self, request, default_data_type):
    """Serve a gRPC request whose payload is in the SeldonMessage ``data`` field.

    The payload is copied into the single request input named
    ``self.model_input``, the request is sent through
    ``self._handle_grpc_prediction``, and the response output named
    ``self.model_output`` is wrapped as a ``tftensor`` result.

    Args:
        request: Incoming message; only ``request.data`` is read here.
        default_data_type: Selector computed by the caller (presumably the
            name of the populated member of ``request.data`` - confirm
            against ``predict_grpc``). ``"tftensor"`` takes the direct-copy
            path; any other value goes through array conversion.

    Returns:
        A ``prediction_pb2.SeldonMessage`` whose ``data.tftensor`` holds the
        selected model output.
    """
    tfrequest = predict_pb2.PredictRequest()

    # handle input
    if default_data_type == "tftensor":
        # For GRPC case, if we have a TFTensor message we can pass it directly
        tfrequest.inputs[self.model_input].CopyFrom(request.data.tftensor)
    else:
        # Any other representation is converted to an array first and then
        # re-encoded as a tensor proto with the array's shape.
        data_arr = grpc_datadef_to_array(request.data)
        tfrequest.inputs[self.model_input].CopyFrom(
            tf.make_tensor_proto(
                data_arr.tolist(),
                shape=data_arr.shape))

    # handle prediction
    tfresponse = self._handle_grpc_prediction(tfrequest)

    # handle result
    datadef = prediction_pb2.DefaultData(
        tftensor=tfresponse.outputs[self.model_output]
    )

    return prediction_pb2.SeldonMessage(data=datadef)

def _predict_grpc_custom_data(self, request):
    """Serve a gRPC request whose payload is in the SeldonMessage ``customData`` field.

    ``request.customData`` is unpacked straight into a
    ``predict_pb2.PredictRequest`` and the whole prediction response is packed
    back into the reply's ``customData``.

    Args:
        request: Incoming message; only ``request.customData`` is read here.

    Returns:
        A ``prediction_pb2.SeldonMessage`` whose ``customData`` carries the
        packed prediction response.

    Raises:
        ValueError: If ``customData`` cannot be unpacked into a
            ``PredictRequest``.
    """
    tfrequest = predict_pb2.PredictRequest()

    # handle input
    #
    # Unpack custom data into tfrequest - taking raw inputs prepared by the user.
    # This allows the use case when the model's input is not a single tftensor
    # but a map of tensors like defined in predict.proto:
    # PredictRequest.inputs: map<string, TensorProto>
    #
    # Unpack() reports a type mismatch through its return value instead of
    # raising; without this check an empty request would be sent on.
    if not request.customData.Unpack(tfrequest):
        raise ValueError(
            "customData could not be unpacked into a PredictRequest "
            "(type_url: {!r}).".format(request.customData.type_url)
        )

    # handle prediction
    tfresponse = self._handle_grpc_prediction(tfrequest)

    # handle result
    #
    # Pack tfresponse into the SeldonMessage's custom data - letting user handle
    # raw outputs. This allows the case when the model's output is not a single tftensor
    # but a map of tensors like defined in predict.proto:
    # PredictResponse: map<string, TensorProto>
    custom_data = Any()
    custom_data.Pack(tfresponse)
    return prediction_pb2.SeldonMessage(customData=custom_data)

def _handle_grpc_prediction(self, tfrequest):
    """Send a prepared prediction request through the gRPC stub.

    The request is modified in place: its ``model_spec.name`` and
    ``model_spec.signature_name`` are overwritten with ``self.model_name``
    and ``self.signature_name`` before the call, so values supplied by the
    caller (for example via unpacked custom data) do not survive.

    Args:
        tfrequest: Request object with its inputs already filled in.

    Returns:
        Whatever ``self.stub.Predict`` returns for the request.
    """
    model_spec = tfrequest.model_spec
    model_spec.name = self.model_name
    model_spec.signature_name = self.signature_name
    return self.stub.Predict(tfrequest)

def predict(self, X, features_names=[]):
"""
Expand Down

0 comments on commit e600275

Please sign in to comment.