From 4bc05b765e6ad30f1b788ff71c5f48629177fa33 Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Tue, 7 May 2024 21:16:43 +0000
Subject: [PATCH] add requirement.txt, annotate the scripts and add references
 to index.rst

---
 docsrc/index.rst                              |  3 +++
 .../data_parallel_gpt2.py                     | 25 +++++++++++++++++++
 .../data_parallel_stable_diffusion.py         | 21 ++++++++++++++
 .../distributed_inference/requirement.txt     |  3 +++
 4 files changed, 52 insertions(+)
 create mode 100644 examples/distributed_inference/requirement.txt

diff --git a/docsrc/index.rst b/docsrc/index.rst
index 455aeab8b3..06393907ed 100644
--- a/docsrc/index.rst
+++ b/docsrc/index.rst
@@ -111,6 +111,9 @@ Tutorials
    tutorials/_rendered_examples/dynamo/torch_compile_transformers_example
    tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage
    tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
+   tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
+   tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
+
 
 Python API Documenation
 ------------------------
diff --git a/examples/distributed_inference/data_parallel_gpt2.py b/examples/distributed_inference/data_parallel_gpt2.py
index 400d91251b..c6e3d9d3c8 100644
--- a/examples/distributed_inference/data_parallel_gpt2.py
+++ b/examples/distributed_inference/data_parallel_gpt2.py
@@ -1,3 +1,19 @@
+"""
+.. _data_parallel_gpt2:
+
+Torch-TensorRT Distributed Inference
+======================================================
+
+This interactive script is intended as a sample of distributed inference
+using data parallelism with the Accelerate library
+and the Torch-TensorRT workflow on the GPT-2 model.
+
+"""
+
+# %%
+# Imports and Model Definition
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
 import torch
 from accelerate import PartialState
 from transformers import AutoTokenizer, GPT2LMHeadModel
@@ -6,6 +22,7 @@
 
 tokenizer = AutoTokenizer.from_pretrained("gpt2")
 
+# Set input prompts for different devices
 prompt1 = "GPT2 is a model developed by."
 prompt2 = "Llama is a model developed by "
 
@@ -14,8 +31,11 @@
 
 distributed_state = PartialState()
 
+# Load the GPT2 model and move it to this process's device
 model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(distributed_state.device)
 
+
+# Compile the forward pass with the Torch-TensorRT backend
 model.forward = torch.compile(
     model.forward,
     backend="torch_tensorrt",
@@ -27,6 +47,11 @@
     dynamic=False,
 )
 
+# %%
+# Inference
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+# Assume there are 2 processes (2 devices)
 with distributed_state.split_between_processes([input_id1, input_id2]) as prompt:
     cur_input = torch.clone(prompt[0]).to(distributed_state.device)
 
diff --git a/examples/distributed_inference/data_parallel_stable_diffusion.py b/examples/distributed_inference/data_parallel_stable_diffusion.py
index 09a7f59dce..158201caf3 100644
--- a/examples/distributed_inference/data_parallel_stable_diffusion.py
+++ b/examples/distributed_inference/data_parallel_stable_diffusion.py
@@ -1,3 +1,18 @@
+"""
+.. _data_parallel_stable_diffusion:
+
+Torch-TensorRT Distributed Inference
+======================================================
+
+This interactive script is intended as a sample of distributed inference
+using data parallelism with the Accelerate library
+and the Torch-TensorRT workflow on the Stable Diffusion model.
+ +""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ import torch from accelerate import PartialState from diffusers import DiffusionPipeline @@ -17,7 +32,10 @@ backend = "torch_tensorrt" # Optimize the UNet portion with Torch-TensorRT -pipe.unet = torch.compile( +pipe.unet = torch.compile( # %% + # Inference + # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + # Assume there are 2 processes (2 devices) pipe.unet, backend=backend, options={ @@ -30,6 +48,12 @@ ) torch_tensorrt.runtime.set_multi_device_safe_mode(True) + +# %% +# Inference +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Assume there are 2 processes (2 devices) with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt: print("before \n") result = pipe(prompt).images[0] diff --git a/examples/distributed_inference/requirement.txt b/examples/distributed_inference/requirement.txt new file mode 100644 index 0000000000..6d8e0aa9f2 --- /dev/null +++ b/examples/distributed_inference/requirement.txt @@ -0,0 +1,3 @@ +accelerate +transformers +diffusers \ No newline at end of file