Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stateful inference #2513

Merged
merged 59 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
ed5239e
stateful inference-core layer
lxning Aug 1, 2023
0794f54
add grpc layer
lxning Aug 5, 2023
4d55643
add google rpc submodule
lxning Aug 7, 2023
a857307
fmt
lxning Aug 15, 2023
4ae7404
update sequence batch img
lxning Aug 15, 2023
5c0dd97
update sequence batch img
lxning Aug 15, 2023
0651806
fmt
lxning Aug 15, 2023
3e16993
delete used file
lxning Aug 15, 2023
6aee437
fmt
lxning Aug 15, 2023
91b9f99
fmt
lxning Aug 15, 2023
a3f84eb
fix log and update doc
lxning Aug 16, 2023
c60f390
update log
lxning Aug 16, 2023
f5c7707
fmt
lxning Aug 16, 2023
dd23216
merge master and fix conflict
lxning Sep 27, 2023
1ea33cf
make BatchAggregator as base
lxning Sep 27, 2023
fdb03c9
fix conflict
lxning Sep 27, 2023
c3a2cca
fix conflict
lxning Sep 27, 2023
ba1bc45
add SequenceBatchAggregator
lxning Sep 29, 2023
f6c888d
update ci for submodule
lxning Sep 29, 2023
d723754
merge master
lxning Oct 4, 2023
077bf27
refactor
lxning Oct 5, 2023
bd3296a
fmt
lxning Oct 5, 2023
d3777e5
merge master
lxning Oct 5, 2023
d1e5e8d
fmt
lxning Oct 5, 2023
c24f1b8
fix lint
lxning Oct 6, 2023
3831ff5
code refactor
lxning Oct 8, 2023
8840a1c
fix conflict merging master
lxning Oct 11, 2023
79f2e66
update readme
lxning Oct 11, 2023
91a201e
update readme
lxning Oct 11, 2023
a431e78
fmt
lxning Oct 11, 2023
7c81022
Merge branch 'master' into feat/stateful
lxning Oct 11, 2023
fef7780
fmt
lxning Oct 11, 2023
cb3d232
Merge branch 'master' into feat/stateful
lxning Oct 12, 2023
67436d3
test workflow
lxning Oct 12, 2023
fe663fb
revert test
lxning Oct 12, 2023
421f31e
revert test response
lxning Oct 12, 2023
7f7bb69
fmt
lxning Oct 25, 2023
838a896
fmt
lxning Oct 11, 2023
0698bab
fix conflict
lxning Oct 25, 2023
4749b74
update readme
lxning Oct 27, 2023
80053ca
allow number ofjobGroup is larger than batchsize
lxning Oct 28, 2023
44d3986
fmt
lxning Oct 11, 2023
8879393
Merge branch 'master' into feat/stateful
lxning Oct 28, 2023
b05e653
fix typo
lxning Oct 28, 2023
5f7125e
add stateful test data
lxning Oct 28, 2023
9bb9245
fmt
lxning Oct 28, 2023
2f83255
Merge branch 'master' into feat/stateful
lxning Oct 28, 2023
a592f10
fmt
lxning Oct 28, 2023
4b9145b
fmt
lxning Oct 30, 2023
5fe05cd
fmt
lxning Oct 11, 2023
8e7ce9e
Merge branch 'master' into feat/stateful
lxning Oct 30, 2023
0a90a87
set default maxNumSequence
lxning Nov 1, 2023
fb9cdb5
fmt
lxning Nov 3, 2023
627a31e
Merge branch 'master' into feat/stateful
lxning Nov 3, 2023
876d83d
fmt
lxning Oct 11, 2023
4b19885
Merge branch 'master' into feat/stateful
lxning Nov 3, 2023
c5a0708
revert back config.properties
lxning Nov 3, 2023
6dc374a
fmt
lxning Nov 5, 2023
d4ea03d
Merge branch 'master' into feat/stateful
lxning Nov 7, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "third_party/google/rpc"]
path = third_party/google/rpc
url = https://github.com/googleapis/googleapis.git
6 changes: 3 additions & 3 deletions docs/grpc_api.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ cd serve
- Install gRPC python dependencies

```bash
pip install -U grpcio protobuf grpcio-tools
pip install -U grpcio protobuf grpcio-tools googleapis-common-protos
```

- Start torchServe
Expand All @@ -51,7 +51,7 @@ torchserve --start --model-store models/
- Generate python gRPC client stub using the proto files

```bash
python -m grpc_tools.protoc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto
python -m grpc_tools.protoc -I third_party/google/rpc --proto_path=frontend/server/src/main/resources/proto/ --python_out=ts_scripts --grpc_python_out=ts_scripts frontend/server/src/main/resources/proto/inference.proto frontend/server/src/main/resources/proto/management.proto
```

- Register densenet161 model
Expand Down Expand Up @@ -95,4 +95,4 @@ def handle(data, context):
for i in range (3):
send_intermediate_predict_response(["intermediate_response"], context.request_ids, "Intermediate Prediction success", 200, context)
return ["hello world "]
```
```
Binary file added docs/images/stateful_batch.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
115 changes: 115 additions & 0 deletions examples/stateful/Readme.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Stateful Inference

A stateful model possesses the ability to discern interdependencies between successive inference requests. This type of model maintains a persistent state across inference requests, thereby establishing a linkage between the outcomes of prior inquiries and those that follow. Notable illustrations of stateful models encompass online speech recognition systems, such as the Long Short-Term Memory (LSTM) model. Employing stateful inference mandates that the model server adheres to the sequential order of inference requests, ensuring predictions build upon the previous outcomes.

Within this context, TorchServe offers a mechanism known as sequence batching. This approach involves the retrieval of an individual inference request from a particular sequence, followed by the amalgamation of multiple requests originating from diverse sequences into a unified batch. Each request is associated with a unique sequence ID, which can be extracted using the "get_sequence_id" function of context.py. This sequence ID serves as a key employed by custom handlers to store and retrieve values within the backend cache store, fostering efficient management of stateful inference processes.

This example serves as a practical showcase of employing stateful inference. Underneath the surface, the backend leverages an [LRU dictionary](https://github.com/amitdev/lru-dict), functioning as a caching layer.



### Step 1: Implement handler

stateful_handler.py is an example of stateful handler. It creates a cache `self.cache` by calling `[LRU](https://github.com/amitdev/lru-dict)`.

```python
def initialize(self, ctx: Context):
"""
Loads the model and Initializes the necessary artifacts
"""

super().initialize(ctx)
self.context = ctx
if self.context.model_yaml_config["handler"] is not None:
lxning marked this conversation as resolved.
Show resolved Hide resolved
self.cache = LRU(int(self.context.model_yaml_config["handler"]["cache"]["capacity"]))
```

Handler uses sequenceId (ie., `sequence_id = self.context.get_sequence_id(idx)`) as key to store and fetch values from `self.cache`.

```python
def preprocess(self, data):
"""
Preprocess function to convert the request input to a tensor(Torchserve supported format).
The user needs to override to customize the pre-processing

Args :
data (list): List of the data from the request input.

Returns:
tensor: Returns the tensor data of the input
"""

self.sequence_ids = {}
results = []
for idx, row in enumerate(data):
sequence_id = self.context.get_sequence_id(idx)

prev = None
lxning marked this conversation as resolved.
Show resolved Hide resolved
if self.cache.has_key(sequence_id):
prev = int(self.cache[sequence_id])
else:
prev = int(0)

request = row.get("data") or row.get("body")
if isinstance(request, (bytes, bytearray)):
request = request.decode("utf-8")

val = prev + int(request)
self.cache[sequence_id] = val
results.append(val)

return results
```

### Step 2: Model configuration

Stateful inference has three parameters.
* sequenceMaxIdleMSec: the max idle in milliseconds of a sequence inference request of this stateful model. The default value is 0 (ie. this is not a stateful model.)
lxning marked this conversation as resolved.
Show resolved Hide resolved
* maxNumSequence: the max number of sequence inference requests of this stateful model. The default value is minWorkers * batchSize.
lxning marked this conversation as resolved.
Show resolved Hide resolved
* maxSequenceJobQueueSize: the job queue size of an inference sequence of this stateful model. The default value is 1.
lxning marked this conversation as resolved.
Show resolved Hide resolved


```yaml
#cat model-config.yaml

minWorkers: 2
maxWorkers: 2
batchSize: 4
sequenceMaxIdleMSec: 60000
maxNumSequence: 4
maxSequenceJobQueueSize: 10

handler:
cache:
capacity: 4
```

### Step 3: Generate mar or tgz file

```bash
torch-model-archiver --model-name stateful --version 1.0 --model-file model.py --serialized-file model_cnn.pt --handler stateful_handler.py -r requirements.txt --config-file model-config.yaml
```

### Step 4: Start torchserve

```bash
torchserve --start --ncs --model-store model_store --models stateful.mar
```

### Step 6: Build GRPC Client
The details can be found at [here](https://github.com/pytorch/serve/blob/master/docs/grpc_api.md).
lxning marked this conversation as resolved.
Show resolved Hide resolved
* Install gRPC python dependencies
* Generate python gRPC client stub using the proto files

### Step 7: Run inference
* Start TorchServe

```bash
torchserve --ncs --start --model-store models --model stateful.mar --ts-config config.properties
```

* Run sequence infernce
```bash
cd ../../
python ts_scripts/torchserve_grpc_client.py infer_stream2 stateful seq_0 examples/stateful/sample/sample1.txt,examples/stateful/sample/sample2.txt,examples/stateful/sample/sample3.txt
```
13 changes: 13 additions & 0 deletions examples/stateful/config.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081

number_of_netty_threads=32
job_queue_size=1000

vmargs=-Xmx4g -XX:+ExitOnOutOfMemoryError -XX:+HeapDumpOnOutOfMemoryError
prefer_direct_buffer=True

default_response_timeout=300
unregister_model_timeout=300
install_py_dep_per_model=true
enable_envvars_config=true
10 changes: 10 additions & 0 deletions examples/stateful/model-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
minWorkers: 2
maxWorkers: 2
batchSize: 4
sequenceMaxIdleMSec: 1000
maxNumSequence: 4
maxSequenceJobQueueSize: 10

handler:
cache:
capacity: 4
28 changes: 28 additions & 0 deletions examples/stateful/model.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're not using the network in the handler, so no need for any layers. Just return x in forward.

Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch
import torch.nn.functional as F
from torch import nn


class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = nn.Conv2d(32, 64, 3, 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(9216, 128)
self.fc2 = nn.Linear(128, 10)

def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
output = F.log_softmax(x, dim=1)
return output
Binary file added examples/stateful/model_cnn.pt
Binary file not shown.
1 change: 1 addition & 0 deletions examples/stateful/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
lru-dict
1 change: 1 addition & 0 deletions examples/stateful/sample/sample1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
1
1 change: 1 addition & 0 deletions examples/stateful/sample/sample2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
2
1 change: 1 addition & 0 deletions examples/stateful/sample/sample3.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3
79 changes: 79 additions & 0 deletions examples/stateful/stateful_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import logging
from abc import ABC
from typing import Dict

from lru import LRU

from ts.context import Context
from ts.torch_handler.base_handler import BaseHandler

logger = logging.getLogger(__name__)


class StatefulHandler(BaseHandler, ABC):
def __init__(self):
super().__init__()
self.cache: LRU = None
self.sequence_ids: Dict = None
self.context = None

def initialize(self, ctx: Context):
"""
Loads the model and Initializes the necessary artifacts
"""

super().initialize(ctx)
self.context = ctx
if self.context.model_yaml_config["handler"] is not None:
lxning marked this conversation as resolved.
Show resolved Hide resolved
self.cache = LRU(
int(self.context.model_yaml_config["handler"]["cache"]["capacity"])
)

self.initialized = True

def preprocess(self, data):
"""
Preprocess function to convert the request input to a tensor(Torchserve supported format).
The user needs to override to customize the pre-processing

Args :
data (list): List of the data from the request input.

Returns:
tensor: Returns the tensor data of the input
"""

self.sequence_ids = {}
results = []
for idx, row in enumerate(data):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To confirm, is it the case that batchSize is the least upper bound of len(data), i.e.len(data) <= batchSize and for all l such that len(data) <= l, batchSize <= l?

Is it possible for two separate requests to get batched to this worker? If so, suppose there are two separate streaming requests that are batched to this worker. What happens if one client is much much faster than the other? Do we throttle the faster client to match the speed of the slower one by buffering the faster client's messages?

Copy link
Collaborator Author

@lxning lxning Aug 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Q1: yes, len(data) <= batchSize. data is a batch of requests received at realtime.

  • Q2: Yes, a batch of requests comes from different sequences. eg. len(data) = 4, it means there are 4 sequences. Each sequence has its own dedicated jobQ. Only the parameter "maxBatchDelay" decides the msec of batching a group of requests from different sequences. In other words, the different traffic volume of different sequences has no impact on batching latency.

Copy link

@calebho calebho Aug 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok but if two streams produce data at drastically different rates, how do you keep the batch index coherent? For instance, fix a stateful worker. At time t_0, the worker receives data d_0_0 and d_1_0 from two streams. So then len(data) == 2 and data[0] is the payload for stream 0 and data[1] is the payload for stream 1.

At t_1, stream 0 does not produce any data because it took longer than maxBatchDelay, but stream 1 produces data d_1_1. So then len(data) == 1 and data[0] is the payload for stream 1. In the line below, idx == 0, so then you fetch the sequence ID for index 0. It seems like this would fetch the sequence ID for stream 0,

sequence_id = self.context.get_sequence_id(idx)

but you actually want the sequence ID for stream 1. Am I understanding the API semantics correctly? Perhaps I am misunderstanding how context.get_sequence_id works. Does it keep track of which stream corresponds to the elements of the data list passed to the handler?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

each request's sequence id is added into its header with key = "ts_request_sequence_id". Backend can get a request's sequence id via its header. This can guarantee we can always get the sequence id regardless the real batch size is changed or the request of a sequence enters into a different batch slot.

sequence_id = self.context.get_sequence_id(idx)

prev = None
if self.cache.has_key(sequence_id):
prev = int(self.cache[sequence_id])
else:
prev = int(0)

request = row.get("data") or row.get("body")
if isinstance(request, (bytes, bytearray)):
request = request.decode("utf-8")

val = prev + int(request)
self.cache[sequence_id] = val
results.append(val)

return results

def inference(self, data, *args, **kwargs):
return data

def postprocess(self, data):
"""
The post process function makes use of the output from the inference and converts into a
Torchserve supported response output.

Returns:
List: The post process function returns a list of the predicted output.
"""

return data
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,20 @@ public class ModelConfig {
* available workers.
*/
private boolean useJobTicket;
/**
* the max idle in milliseconds of a sequence inference request of this stateful model. The
* default value is 0 (ie. this is not a stateful model.)
*/
private long sequenceMaxIdleMSec;
/**
* the max number of sequence inference requests of this stateful model. The default value is
* minWorkers * batchSize.
*/
private int maxNumSequence;
/**
* the job queue size of an inference sequence of this stateful model. The default value is 1.
*/
private int maxSequenceJobQueueSize = 1;

public static ModelConfig build(Map<String, Object> yamlMap) {
ModelConfig modelConfig = new ModelConfig();
Expand Down Expand Up @@ -158,6 +172,32 @@ public static ModelConfig build(Map<String, Object> yamlMap) {
logger.warn("Invalid useJobTicket: {}, should be true or false", v);
}
break;
case "sequenceMaxIdleMSec":
if (v instanceof Integer) {
modelConfig.setSequenceMaxIdleMSec(((Integer) v).longValue());
} else {
logger.warn(
"Invalid sequenceMaxIdleMSec: {}, should be positive int",
v);
}
break;
case "maxNumSequence":
if (v instanceof Integer) {
modelConfig.setMaxNumSequence((int) v);
} else {
logger.warn(
"Invalid maxNumSequence: {}, should be positive int", v);
}
break;
case "maxSequenceJobQueueSize":
if (v instanceof Integer) {
modelConfig.setMaxSequenceJobQueueSize((int) v);
} else {
logger.warn(
"Invalid maxSequenceJobQueueSize: {}, should be positive int",
v);
}
break;
default:
break;
}
Expand Down Expand Up @@ -313,6 +353,36 @@ public void setUseJobTicket(boolean useJobTicket) {
this.useJobTicket = useJobTicket;
}

public long getSequenceMaxIdleMSec() {
return sequenceMaxIdleMSec;
}

public void setSequenceMaxIdleMSec(long sequenceMaxIdleMSec) {
if (sequenceMaxIdleMSec > 0) {
this.sequenceMaxIdleMSec = sequenceMaxIdleMSec;
lxning marked this conversation as resolved.
Show resolved Hide resolved
}
}

public int getMaxNumSequence() {
return maxNumSequence;
}

public void setMaxNumSequence(int maxNumSequence) {
if (maxNumSequence > 0) {
this.maxNumSequence = maxNumSequence;
lxning marked this conversation as resolved.
Show resolved Hide resolved
}
}

public int getMaxSequenceJobQueueSize() {
return maxSequenceJobQueueSize;
}

public void setMaxSequenceJobQueueSize(int maxsequenceJobQueueSize) {
if (maxsequenceJobQueueSize > 0) {
this.maxSequenceJobQueueSize = maxsequenceJobQueueSize;
lxning marked this conversation as resolved.
Show resolved Hide resolved
}
}

public enum ParallelType {
NONE(""),
PP("pp"),
Expand Down
Loading