Skip to content

Commit

Permalink
Torch Tensor RT example (#2483)
Browse files Browse the repository at this point in the history
* Working example with Torch Tensor RT

* lint

* spellcheck
  • Loading branch information
agunapal authored Jul 21, 2023
1 parent 7e5857f commit b998f8c
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 0 deletions.
52 changes: 52 additions & 0 deletions examples/torch_tensorrt/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# TorchServe inference with torch tensorrt model

This example shows how to run TorchServe inference with [Torch-TensorRT](https://github.com/pytorch/TensorRT) model

### Pre-requisites

- Install CUDA and cuDNN. Verified with CUDA 11.7 and cuDNN 8.9.3.28
- Verified to be working with `tensorrt==8.5.3.1` and `torch-tensorrt==1.4.0`

Change directory to the root of `serve`
Ex: if `serve` is under `/home/ubuntu`, change directory to `/home/ubuntu/serve`


### Create a Torch Tensor RT model

We use `float16` precision
TorchServe's base handler supports loading Torch TensorRT model with `.pt` extension. Hence, the model is saved with `.pt` extension.

```
python examples/torch_tensorrt/resnet_tensorrt.py
```

### Create model archive

```
torch-model-archiver --model-name res50-trt-fp16 --handler image_classifier --version 1.0 --serialized-file res50_trt_fp16.pt --extra-files ./examples/image_classifier/index_to_name.json
mkdir model_store
mv res50-trt-fp16.mar model_store/.
```

#### Start TorchServe
```
torchserve --start --model-store model_store --models res50-trt-fp16=res50-trt-fp16.mar --ncs
```

#### Run Inference

```
curl http://127.0.0.1:8080/predictions/res50-trt-fp16 -T ./examples/image_classifier/kitten.jpg
```

produces the output

```
{
"tabby": 0.2723647356033325,
"tiger_cat": 0.13748960196971893,
"Egyptian_cat": 0.04659610986709595,
"lynx": 0.00318642589263618,
"lens_cap": 0.00224193069152534
}
```
2 changes: 2 additions & 0 deletions examples/torch_tensorrt/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
tensorrt==8.5.3.1
torch-tensorrt==1.4.0
23 changes: 23 additions & 0 deletions examples/torch_tensorrt/resnet_tensorrt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import torch
import torch_tensorrt
from torchvision.models import ResNet50_Weights, resnet50

model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

model.eval()


trt_model_fp16 = torch_tensorrt.compile(
model,
inputs=[
torch_tensorrt.Input(
min_shape=(1, 3, 224, 224),
opt_shape=(32, 3, 224, 224),
max_shape=(64, 3, 224, 224),
dtype=torch.float32,
)
],
enabled_precisions=torch.float16, # Run with FP32
workspace_size=1 << 22,
)
torch.jit.save(trt_model_fp16, "res50_trt_fp16.pt")
6 changes: 6 additions & 0 deletions ts/torch_handler/base_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@
logger.warning("proceeding without onnxruntime")
ONNX_AVAILABLE = False

try:
import torch_tensorrt
logger.info("Torch TensorRT enabled")
except ImportError:
logger.warning("Torch TensorRT not enabled")


def setup_ort_session(model_pt_path, map_location):
providers = (
Expand Down
1 change: 1 addition & 0 deletions ts_scripts/spellcheck_conf/wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -1064,3 +1064,4 @@ ActionSLAM
statins
ci
chatGPT
cuDNN

0 comments on commit b998f8c

Please sign in to comment.