feat: add KServe gRPC v2 support (#2176)
* feat: add KServe gRPC v2 support

Signed-off-by: jagadeesh <jagadeeshj@ideas2it.com>

* feat: add utils to convert kserve pb to ts pb

Signed-off-by: jagadeesh <jagadeeshj@ideas2it.com>

* add ts pb to kserve pb conversion method

Signed-off-by: jagadeesh <jagadeeshj@ideas2it.com>

* Add pb python file generation step at docker build

Signed-off-by: jagadeesh <jagadeeshj@ideas2it.com>

* fix: readme doc

 - add logs

Signed-off-by: jagadeesh <jagadeeshj@ideas2it.com>

* update readme

Signed-off-by: jagadeesh <jagadeeshj@ideas2it.com>

* fix lint errors

* fix kserve_v2 service envelop and test data

Signed-off-by: jagadeesh <jagadeeshj@ideas2it.com>

* re-test

Signed-off-by: jagadeesh <jagadeeshj@ideas2it.com>

* re-test

Signed-off-by: jagadeesh <jagadeeshj@ideas2it.com>

---------

Signed-off-by: jagadeesh <jagadeeshj@ideas2it.com>
Co-authored-by: Geeta Chauhan <4461127+chauhang@users.noreply.github.com>
Co-authored-by: Ankith Gunapal <agunapal@ischool.Berkeley.edu>
Co-authored-by: Mark Saroufim <marksaroufim@fb.com>
4 people authored Aug 24, 2023
1 parent 2ff5020 commit 39e715d
Showing 16 changed files with 326 additions and 58 deletions.
15 changes: 12 additions & 3 deletions kubernetes/kserve/Dockerfile
@@ -1,13 +1,13 @@
# syntax = docker/dockerfile:experimental
#
# Following comments have been shamelessly copied from https://github.com/pytorch/pytorch/blob/master/Dockerfile
#
# NOTE: To build this you will need a docker version > 18.06 with
# experimental enabled and DOCKER_BUILDKIT=1
#
# If you do not use buildkit you are not going to have a good time
#
# For reference:
# https://docs.docker.com/develop/develop-images/build_enhancements

ARG BASE_IMAGE=pytorch/torchserve:latest
@@ -24,9 +24,18 @@ RUN pip install -r requirements.txt
COPY dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh
RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh
COPY kserve_wrapper kserve_wrapper

COPY ./*.proto ./kserve_wrapper/

RUN python -m grpc_tools.protoc \
--proto_path=./kserve_wrapper \
--python_out=./kserve_wrapper \
--grpc_python_out=./kserve_wrapper \
./kserve_wrapper/inference.proto \
./kserve_wrapper/management.proto
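# Note: the protoc step above should emit inference_pb2.py, inference_pb2_grpc.py,
# management_pb2.py and management_pb2_grpc.py into ./kserve_wrapper (the standard
# grpc_tools.protoc output names), which the wrapper imports at runtime.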

COPY config.properties config.properties

USER model-server

ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"]

2 changes: 2 additions & 0 deletions kubernetes/kserve/Dockerfile.dev
@@ -69,6 +69,8 @@ RUN if [ "$MACHINE_TYPE" = "gpu" ]; then export USE_CUDA=1; fi \
&& chmod +x /usr/local/bin/dockerd-entrypoint.sh \
&& chown -R model-server /home/model-server \
&& cp -R kubernetes/kserve/kserve_wrapper /home/model-server/kserve_wrapper \
&& cp frontend/server/src/main/resources/proto/*.proto /home/model-server/kserve_wrapper \
&& python -m grpc_tools.protoc --proto_path=/home/model-server/kserve_wrapper --python_out=/home/model-server/kserve_wrapper --grpc_python_out=/home/model-server/kserve_wrapper /home/model-server/kserve_wrapper/inference.proto /home/model-server/kserve_wrapper/management.proto \
&& cp kubernetes/kserve/config.properties /home/model-server/config.properties \
&& mkdir /home/model-server/model-store && chown -R model-server /home/model-server/model-store

4 changes: 2 additions & 2 deletions kubernetes/kserve/README.md
@@ -30,10 +30,10 @@ Currently, KServe supports the Inference API for all the existing models but tex
./build_image.sh -g -t <repository>/<image>:<tag>
```

### Docker Image Dev Build
- To create dev image

```bash
-DOCKER_BUILDKIT=1 docker build -f Dockerfile.dev -t pytorch/torchserve-kfs:latest-dev .
+./build_image.sh -g -d -t <repository>/<image>:<tag>
```

## Running Torchserve inference service in KServe cluster
9 changes: 8 additions & 1 deletion kubernetes/kserve/build_image.sh
@@ -2,6 +2,7 @@

DOCKER_TAG="pytorch/torchserve-kfs:latest"
BASE_IMAGE="pytorch/torchserve:latest"
DOCKER_FILE="Dockerfile"

for arg in "$@"
do
@@ -18,6 +19,10 @@ do
BASE_IMAGE="pytorch/torchserve:latest-gpu"
shift
;;
-d|--dev)
DOCKER_FILE="Dockerfile.dev"
shift
;;
-t|--tag)
DOCKER_TAG="$2"
shift
@@ -26,4 +31,6 @@
esac
done

-DOCKER_BUILDKIT=1 docker build --file Dockerfile --build-arg BASE_IMAGE=$BASE_IMAGE -t "$DOCKER_TAG" .
+cp ../../frontend/server/src/main/resources/proto/*.proto .
+
+DOCKER_BUILDKIT=1 docker build --file "$DOCKER_FILE" --build-arg BASE_IMAGE=$BASE_IMAGE -t "$DOCKER_TAG" .
@@ -1,9 +1,10 @@
{
"id": "d3b15cad-50a2-4eaf-80ce-8b0a428bd298",
"inputs": [
{
"data": ["iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA10lEQVR4nGNgGFhgy6xVdrCszBaLFN/mr28+/QOCr69DMCSnA8WvHti0acu/fx/10OS0X/975CDDw8DA1PDn/1pBVEmLf3+zocy2X/+8USXt/82Ds+/+m4sqeehfOpw97d9VFDmlO++t4JwQNMm6f6sZcEpee2+DR/I4A05J7tt4JJP+IUsu+ncRp6TxO9RAQJY0XvrvMAuypNNHuCTz8n+PzVEcy3DtqgiY1ptx6t8/ewY0yX9ntoDA63//Xs3hQpMMPPsPAv68qmDAAFKXwHIzMzCl6AoAxXp0QujtP+8AAAAASUVORK5CYII="],
"datatype": "BYTES",
"name": "e8d5afed-0a56-4deb-ac9c-352663f51b93",
"name": "312a4eb0-0ca7-4803-a101-a6d2c18486fe",
"shape": [-1]
}
]
@@ -0,0 +1,11 @@
{
"model_name": "mnist",
"inputs": [{
"name": "312a4eb0-0ca7-4803-a101-a6d2c18486fe",
"shape": [-1],
"datatype": "BYTES",
"contents": {
"bytes_contents": ["iVBORw0KGgoAAAANSUhEUgAAABwAAAAcCAAAAABXZoBIAAAA10lEQVR4nGNgGFhgy6xVdrCszBaLFN/mr28+/QOCr69DMCSnA8WvHti0acu/fx/10OS0X/975CDDw8DA1PDn/1pBVEmLf3+zocy2X/+8USXt/82Ds+/+m4sqeehfOpw97d9VFDmlO++t4JwQNMm6f6sZcEpee2+DR/I4A05J7tt4JJP+IUsu+ncRp6TxO9RAQJY0XvrvMAuypNNHuCTz8n+PzVEcy3DtqgiY1ptx6t8/ewY0yX9ntoDA63//Xs3hQpMMPPsPAv68qmDAAFKXwHIzMzCl6AoAxXp0QujtP+8AAAAASUVORK5CYII="]
}
}]
}
@@ -0,0 +1,12 @@
{
"id": "d3b15cad-50a2-4eaf-80ce-8b0a428bd298",
"model_name": "mnist",
"inputs": [{
"name": "input-0",
"shape": [1, 28, 28],
"datatype": "FP32",
"contents": {
"fp32_contents": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.23919999599456787, 0.011800000444054604, 0.1647000014781952, 0.4627000093460083, 0.7569000124931335, 0.4627000093460083, 0.4627000093460083, 0.23919999599456787, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.05490000173449516, 0.7020000219345093, 0.9607999920845032, 0.9254999756813049, 0.9490000009536743, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9607999920845032, 0.9215999841690063, 0.3294000029563904, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.592199981212616, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.8353000283241272, 0.7529000043869019, 0.6980000138282776, 0.6980000138282776, 0.7059000134468079, 0.9961000084877014, 0.9961000084877014, 0.9451000094413757, 0.18039999902248383, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.16859999299049377, 0.9215999841690063, 0.9961000084877014, 0.8863000273704529, 0.25099998712539673, 0.10980000346899033, 0.0471000000834465, 0.0, 0.0, 0.007799999788403511, 0.5019999742507935, 0.9882000088691711, 1.0, 0.6783999800682068, 0.06669999659061432, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.21960000693798065, 0.9961000084877014, 0.9922000169754028, 0.4196000099182129, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5254999995231628, 0.980400025844574, 0.9961000084877014, 0.29409998655319214, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.24709999561309814, 0.9961000084877014, 0.6195999979972839, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8666999936103821, 0.9961000084877014, 0.6157000064849854, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.7608000040054321, 0.9961000084877014, 0.40389999747276306, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.5881999731063843, 0.9961000084877014, 0.8353000283241272, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.13330000638961792, 0.8626999855041504, 0.9373000264167786, 0.22750000655651093, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3294000029563904, 0.9961000084877014, 0.8353000283241272, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.49410000443458557, 0.9961000084877014, 0.6705999970436096, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3294000029563904, 0.9961000084877014, 0.8353000283241272, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8392000198364258, 0.9373000264167786, 0.2353000044822693, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3294000029563904, 0.9961000084877014, 0.8353000283241272, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8392000198364258, 0.7803999781608582, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3294000029563904, 0.9961000084877014, 0.8353000283241272, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04309999942779541, 0.8587999939918518, 
0.7803999781608582, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3294000029563904, 0.9961000084877014, 0.8353000283241272, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.38429999351501465, 0.9961000084877014, 0.7803999781608582, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6352999806404114, 0.9961000084877014, 0.819599986076355, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.38429999351501465, 0.9961000084877014, 0.7803999781608582, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.20000000298023224, 0.9333000183105469, 0.9961000084877014, 0.29409998655319214, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.38429999351501465, 0.9961000084877014, 0.7803999781608582, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.20000000298023224, 0.6470999717712402, 0.9961000084877014, 0.7646999955177307, 0.015699999406933784, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.2587999999523163, 0.9451000094413757, 0.7803999781608582, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.011800000444054604, 0.6549000144004822, 0.9961000084877014, 0.8902000188827515, 0.21570000052452087, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8392000198364258, 0.8353000283241272, 0.07840000092983246, 0.0, 0.0, 0.0, 0.0, 0.0, 0.18039999902248383, 0.5960999727249146, 0.7922000288963318, 0.9961000084877014, 0.9961000084877014, 0.24709999561309814, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.8392000198364258, 0.9961000084877014, 0.800000011920929, 0.7059000134468079, 0.7059000134468079, 0.7059000134468079, 0.7059000134468079, 0.7059000134468079, 0.9215999841690063, 0.9961000084877014, 0.9961000084877014, 0.9175999760627747, 0.6118000149726868, 0.03920000046491623, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.3176000118255615, 0.8039000034332275, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9882000088691711, 0.9175999760627747, 0.4706000089645386, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10199999809265137, 0.8234999775886536, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.9961000084877014, 0.6000000238418579, 0.40779998898506165, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
}
}]
}
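These gRPC request JSON files use the field names from KServe's `grpc_predict_v2.proto`, so they can be parsed directly into a `ModelInferRequest` message. A minimal sketch, assuming `grpc_predict_v2_pb2` has been generated from that proto and the file above is saved locally as `mnist_v2_tensor_grpc.json` (a hypothetical local copy):

```python
# Minimal sketch: verify the JSON above parses into a ModelInferRequest.
# Assumes grpc_predict_v2_pb2 was generated from KServe's grpc_predict_v2.proto.
from google.protobuf import json_format

import grpc_predict_v2_pb2

with open("mnist_v2_tensor_grpc.json") as f:  # hypothetical local copy
    msg = json_format.Parse(f.read(), grpc_predict_v2_pb2.ModelInferRequest())

print(msg.model_name, msg.inputs[0].datatype, list(msg.inputs[0].shape))
```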
59 changes: 35 additions & 24 deletions kubernetes/kserve/kserve_wrapper/README.md
@@ -26,7 +26,7 @@ Follow the below steps to serve the MNIST Model :
- Step 2 : Install KServe as below:

```bash
-pip install kserve>=0.9.0
+pip install kserve>=0.9.0 grpcio protobuf grpcio-tools
```

- Step 4 : Run the Install Dependencies script
@@ -59,11 +59,11 @@ sudo mkdir -p /mnt/models/model-store

For v1 protocol

-``export TS_SERVICE_ENVELOPE=kserve`
+`export TS_SERVICE_ENVELOPE=kserve`

For v2 protocol

-``export TS_SERVICE_ENVELOPE=kservev2`
+`export TS_SERVICE_ENVELOPE=kservev2`

- Step 10: Move the config.properties to /mnt/models/config/.
The config.properties file is as below:
@@ -93,6 +93,26 @@ torchserve --start --ts-config /mnt/models/config/config.properties

- Step 12: Run the below command to start the KFServer

- Step 13: Set protocol version

For v1 protocol

`export PROTOCOL_VERSION=v1`

For v2 protocol

`export PROTOCOL_VERSION=v2`

For the gRPC protocol with v2 format, set

`export PROTOCOL_VERSION=grpc-v2`
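The wrapper is expected to read this variable at startup; a minimal sketch of that lookup, assuming the variable name above and a `v1` default (the authoritative logic lives in `kserve_wrapper/__main__.py`):

```python
# Minimal sketch: read the protocol selection from the environment.
# PROTOCOL_VERSION is the name used in this README; the "v1" default is an assumption.
import os

protocol = os.environ.get("PROTOCOL_VERSION", "v1")
if protocol not in ("v1", "v2", "grpc-v2"):
    raise ValueError(f"Unsupported PROTOCOL_VERSION: {protocol}")
```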

- Generate the Python gRPC client stubs from the proto files

```bash
python -m grpc_tools.protoc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto
```
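As a quick smoke test, the generated stubs can call TorchServe's own gRPC API directly. A minimal sketch, assuming the stubs generated above are importable, TorchServe listens on its default gRPC inference port 7070, and a `mnist` model is registered (`mnist.png` is a hypothetical sample input):

```python
# Minimal sketch: call TorchServe's gRPC Predictions API with the generated stubs.
import grpc
import inference_pb2
import inference_pb2_grpc

channel = grpc.insecure_channel("localhost:7070")
stub = inference_pb2_grpc.InferenceAPIsServiceStub(channel)

with open("mnist.png", "rb") as f:  # hypothetical sample image
    request = inference_pb2.PredictionsRequest(
        model_name="mnist", input={"data": f.read()}
    )

print(stub.Predictions(request).prediction)
```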

```bash
python3 serve/kubernetes/kserve/kserve_wrapper/__main__.py
```
@@ -127,7 +147,7 @@ Output:

The curl request for explain is as below:

-```
+```bash
curl -H "Content-Type: application/json" --data @serve/kubernetes/kserve/kf_request_json/v1/mnist.json http://0.0.0.0:8080/v1/models/mnist:explain
```

@@ -146,7 +166,7 @@ For v2 protocol
The curl request for inference is as below:

```bash
curl -H "Content-Type: application/json" --data @serve/kubernetes/kserve/kf_request_json/mnist_v2.json http://0.0.0.0:8080/v2/models/mnist/infer
curl -H "Content-Type: application/json" --data @serve/kubernetes/kserve/kf_request_json/v2/mnist/mnist_v2_tensor.json http://0.0.0.0:8080/v2/models/mnist/infer
```

Response:
@@ -167,29 +187,20 @@ Response:
}
```

-The curl request for explain is as below:
+For grpc-v2 protocol

-```
-curl -H "Content-Type: application/json" --data @serve/kubernetes/kserve/kf_request_json/v1/mnist.json http://0.0.0.0:8080/v2/models/mnist/explain
-```
+- Download the proto file

-Response:
+```bash
+curl -O https://raw.githubusercontent.com/kserve/kserve/master/docs/predict-api/v2/grpc_predict_v2.proto
+```

-```json
-{
-  "id": "3482b766-0483-40e9-84b0-8ce8d4d1576e",
-  "model_name": "mnist",
-  "model_version": "1.0",
-  "outputs": [{
-    "name": "explain",
-    "shape": [1, 28, 28],
-    "datatype": "FP64",
-    "data": [-0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, -0.0, 0.0, -0.0, -0.0, 0.0, -0.0, 0.0
-    ...
-    ...
-    ]
-  }]
-}
-```
+- Download [grpcurl](https://github.com/fullstorydev/grpcurl)
+
+Make gRPC request
+
+```bash
+grpcurl -vv -plaintext -proto grpc_predict_v2.proto -d @ localhost:8081 inference.GRPCInferenceService.ModelInfer <<< $(cat "serve/kubernetes/kserve/kf_request_json/v2/mnist/mnist_v2_tensor_grpc.json")
+```
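The same request can be issued from Python instead of grpcurl; a minimal sketch, assuming client stubs were generated from the downloaded `grpc_predict_v2.proto` with `grpc_tools.protoc` and the wrapper's gRPC server listens on `localhost:8081`:

```python
# Minimal sketch: the grpcurl call above, issued from Python.
import grpc
import grpc_predict_v2_pb2
import grpc_predict_v2_pb2_grpc
from google.protobuf import json_format

channel = grpc.insecure_channel("localhost:8081")
stub = grpc_predict_v2_pb2_grpc.GRPCInferenceServiceStub(channel)

with open("serve/kubernetes/kserve/kf_request_json/v2/mnist/mnist_v2_tensor_grpc.json") as f:
    request = json_format.Parse(f.read(), grpc_predict_v2_pb2.ModelInferRequest())

print(json_format.MessageToJson(stub.ModelInfer(request)))
```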

## KServe Wrapper Testing in Local for BERT
7 changes: 6 additions & 1 deletion kubernetes/kserve/kserve_wrapper/TSModelRepository.py
@@ -14,7 +14,12 @@ class TSModelRepository(ModelRepository):
as inputs to the TSModel Repository.
"""

-    def __init__(self, inference_address: str, management_address: str, model_dir: str):
+    def __init__(
+        self,
+        inference_address: str,
+        management_address: str,
+        model_dir: str,
+    ):
"""The Inference Address, Management Address and the Model Directory from the kserve
side is initialized here.
95 changes: 92 additions & 3 deletions kubernetes/kserve/kserve_wrapper/TorchserveModel.py
@@ -2,20 +2,37 @@
return a KServe side response """
import logging
import pathlib
from enum import Enum
from typing import Dict, Union

import grpc
import inference_pb2_grpc
import kserve
from gprc_utils import from_ts_grpc, to_ts_grpc
from inference_pb2 import PredictionResponse
from kserve.errors import ModelMissingError
from kserve.model import Model as Model
from kserve.protocol.grpc.grpc_predict_v2_pb2 import (
ModelInferRequest,
ModelInferResponse,
)
from kserve.protocol.infer_type import InferRequest, InferResponse
from kserve.storage import Storage

logging.basicConfig(level=kserve.constants.KSERVE_LOGLEVEL)

PREDICTOR_URL_FORMAT = PREDICTOR_V2_URL_FORMAT = "http://{0}/predictions/{1}"
-EXPLAINER_URL_FORMAT = EXPLAINER_V2_URL_FORMAT = "http://{0}/explanations/{1}"
+EXPLAINER_URL_FORMAT = EXPLAINER_v2_URL_FORMAT = "http://{0}/explanations/{1}"
REGISTER_URL_FORMAT = "{0}/models?initial_workers=1&url={1}"
UNREGISTER_URL_FORMAT = "{0}/models/{1}"


class PredictorProtocol(Enum):
REST_V1 = "v1"
REST_V2 = "v2"
GRPC_V2 = "grpc-v2"


class TorchserveModel(Model):
"""The torchserve side inference and explain end-points requests are handled to
return a KServe side response
@@ -25,7 +42,15 @@ class TorchserveModel(Model):
side predict and explain http requests.
"""

-    def __init__(self, name, inference_address, management_address, model_dir):
+    def __init__(
+        self,
+        name,
+        inference_address,
+        management_address,
+        grpc_inference_address,
+        protocol,
+        model_dir,
+    ):
"""The Model Name, Inference Address, Management Address and the model directory
are specified.
@@ -45,10 +70,74 @@ def __init__(self, name, inference_address, management_address, model_dir):
self.inference_address = inference_address
self.management_address = management_address
self.model_dir = model_dir
self.protocol = protocol

if self.protocol == PredictorProtocol.GRPC_V2.value:
self.predictor_host = grpc_inference_address

logging.info("Predict URL set to %s", self.predictor_host)
self.explainer_host = self.predictor_host
logging.info("Explain URL set to %s", self.explainer_host)
logging.info("Protocol version is %s", self.protocol)

def grpc_client(self):
    if self._grpc_client_stub is None:
        self.channel = grpc.aio.insecure_channel(self.predictor_host)
        self._grpc_client_stub = inference_pb2_grpc.InferenceAPIsServiceStub(
            self.channel
        )
    return self._grpc_client_stub

async def _grpc_predict(
self,
payload: Union[ModelInferRequest, InferRequest],
headers: Dict[str, str] = None,
) -> ModelInferResponse:
"""Overrides the `_grpc_predict` method in Model class. The predict method calls
the `_grpc_predict` method if the self.protocol is "grpc_v2"
Args:
request (Dict|InferRequest|ModelInferRequest): The response passed from ``predict`` handler.
Returns:
Dict: Torchserve grpc response.
"""
payload = to_ts_grpc(payload)
grpc_stub = self.grpc_client()
async_result = await grpc_stub.Predictions(payload)
return async_result

def postprocess(
self,
response: Union[Dict, InferResponse, ModelInferResponse, PredictionResponse],
headers: Dict[str, str] = None,
) -> Union[Dict, ModelInferResponse]:
"""This method converts the v2 infer response types to gRPC or REST.
For gRPC request it converts InferResponse to gRPC message or directly returns ModelInferResponse from
predictor call or converts TS PredictionResponse to ModelInferResponse.
For REST request it converts ModelInferResponse to Dict or directly returns from predictor call.
Args:
response (Dict|InferResponse|ModelInferResponse|PredictionResponse): The response passed from ``predict`` handler.
headers (Dict): Request headers.
Returns:
Dict: post-processed response.
"""
if headers:
if "grpc" in headers.get("user-agent", ""):
if isinstance(response, ModelInferResponse):
return response
elif isinstance(response, InferResponse):
return response.to_grpc()
elif isinstance(response, PredictionResponse):
return from_ts_grpc(response)
if "application/json" in headers.get("content-type", ""):
# If the original request is REST, convert the gRPC predict response to dict
if isinstance(response, ModelInferResponse):
return InferResponse.from_grpc(response).to_rest()
elif isinstance(response, InferResponse):
return response.to_rest()
return response

def load(self) -> bool:
"""This method validates model availabilty in the model directory
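Taken together, the new constructor arguments and the protocol enum compose roughly as follows; a minimal usage sketch with values inferred from this diff (the addresses are assumptions; `kserve_wrapper/__main__.py` is the authoritative entrypoint):

```python
# Minimal sketch: constructing the wrapper model for the grpc-v2 protocol.
# The addresses below are assumptions; real values come from config.properties.
from TorchserveModel import PredictorProtocol, TorchserveModel

model = TorchserveModel(
    name="mnist",
    inference_address="http://127.0.0.1:8085",
    management_address="http://127.0.0.1:8085",
    grpc_inference_address="127.0.0.1:7070",
    protocol=PredictorProtocol.GRPC_V2.value,
    model_dir="/mnt/models/model-store",
)
model.load()
```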