-
Notifications
You must be signed in to change notification settings - Fork 356
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: data parallel inference sample
- Loading branch information
Showing
5 changed files
with
145 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Torch-TensorRT parallelism for distributed inference | ||
|
||
The examples in this folder demonstrate distributed inference on multiple devices with the Torch-TensorRT backend. | ||
|
||
1. Data parallel distributed inference based on [Accelerate](https://huggingface.co/docs/accelerate/usage_guides/distributed_inference) | ||
|
||
Using Accelerate, users can achieve data parallel distributed inference with the Torch-TensorRT backend. In this case, the entire model | ||
will be loaded onto each GPU, and different chunks of the batch input are processed on each device. | ||
|
||
See the examples whose names start with `data_parallel` for more details. | ||
|
||
2. Tensor parallel distributed inference | ||
|
||
In development. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
""" | ||
.. _data_parallel_gpt2: | ||
Torch-TensorRT Distributed Inference | ||
====================================================== | ||
This interactive script is intended as a sample of data parallel distributed | ||
inference, using the Accelerate | ||
library with the Torch-TensorRT workflow on the GPT2 model. | ||
""" | ||
|
||
# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
from accelerate import PartialState
from transformers import AutoTokenizer, GPT2LMHeadModel

# NOTE(review): torch_tensorrt is not referenced by name below; it is presumably
# imported so that the "torch_tensorrt" torch.compile backend is available.
import torch_tensorrt

tokenizer = AutoTokenizer.from_pretrained("gpt2")

# Set input prompts for different devices
prompt1 = "GPT2 is a model developed by."
prompt2 = "Llama is a model developed by "

# Tokenize each prompt separately so every process can work on its own tensor
input_id1 = tokenizer(prompt1, return_tensors="pt").input_ids
input_id2 = tokenizer(prompt2, return_tensors="pt").input_ids

# One PartialState per launched process; it exposes the device assigned to
# this process and the helper that splits inputs between processes.
distributed_state = PartialState()

# Import GPT2 model and load to distributed devices
model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(distributed_state.device)

# Instantiate model with Torch-TensorRT backend.
# Only ``forward`` is compiled; ``model.generate`` below keeps running as
# regular Python and calls the compiled ``forward``.
model.forward = torch.compile(
    model.forward,
    backend="torch_tensorrt",
    options={
        "truncate_long_and_double": True,
        "enabled_precisions": {torch.float16},
        "debug": True,
    },
    dynamic=False,
)

# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# The sample is written for 2 processes (2 devices), one prompt each.
# Iterating over the chunk (instead of indexing ``[0]``) processes every input
# handed to this process, so no prompt is silently dropped when the script is
# launched with a different number of processes.
with distributed_state.split_between_processes([input_id1, input_id2]) as prompts:
    for input_ids in prompts:
        cur_input = torch.clone(input_ids).to(distributed_state.device)

        gen_tokens = model.generate(
            cur_input,
            do_sample=True,
            temperature=0.9,
            max_length=100,
        )
        gen_text = tokenizer.batch_decode(gen_tokens)[0]

        # Show the result, tagged with the process that produced it
        print(f"[process {distributed_state.process_index}] {gen_text}")
61 changes: 61 additions & 0 deletions
61
examples/distributed_inference/data_parallel_stable_diffusion.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
""" | ||
.. _data_parallel_stable_diffusion: | ||
Torch-TensorRT Distributed Inference | ||
====================================================== | ||
This interactive script is intended as a sample of data parallel distributed | ||
inference, using the Accelerate | ||
library with the Torch-TensorRT workflow on the Stable Diffusion model. | ||
""" | ||
|
||
# %%
# Imports and Model Definition
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
import torch
from accelerate import PartialState
from diffusers import DiffusionPipeline

import torch_tensorrt

model_id = "CompVis/stable-diffusion-v1-4"

# Instantiate Stable Diffusion Pipeline with FP16 weights
pipe = DiffusionPipeline.from_pretrained(
    model_id, revision="fp16", torch_dtype=torch.float16
)

# One PartialState per launched process; move the whole pipeline onto the
# device assigned to this process.
distributed_state = PartialState()
pipe = pipe.to(distributed_state.device)

backend = "torch_tensorrt"

# Optimize the UNet portion with Torch-TensorRT
pipe.unet = torch.compile(
    pipe.unet,
    backend=backend,
    options={
        "truncate_long_and_double": True,
        # Same option spelling as the GPT2 data parallel sample
        "enabled_precisions": {torch.float16},
        "debug": True,
        "use_python_runtime": True,
    },
    dynamic=False,
)
# NOTE(review): enables Torch-TensorRT's multi-device safe mode, since each
# process runs the engine on its own GPU -- confirm exact semantics in the
# Torch-TensorRT runtime documentation.
torch_tensorrt.runtime.set_multi_device_safe_mode(True)

# %%
# Inference
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# Assume there are 2 processes (2 devices)
with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt:
    # ``prompt`` is this process's share of the prompt list; only the first
    # generated image is kept.
    result = pipe(prompt).images[0]
    # One output file per process, named after the process index
    result.save(f"result_{distributed_state.process_index}.png")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
accelerate | ||
transformers | ||
diffusers |