From db24b3b3509a3a7d9c07bcd2450cacafdf665fdd Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Thu, 16 May 2024 23:04:28 +0000 Subject: [PATCH] feat: data parallel inference sample --- docsrc/index.rst | 6 ++ examples/distributed_inference/README.md | 14 ++++ .../data_parallel_gpt2.py | 64 +++++++++++++++++++ .../data_parallel_stable_diffusion.py | 61 ++++++++++++++++++ .../distributed_inference/requirement.txt | 3 + 5 files changed, 148 insertions(+) create mode 100644 examples/distributed_inference/README.md create mode 100644 examples/distributed_inference/data_parallel_gpt2.py create mode 100644 examples/distributed_inference/data_parallel_stable_diffusion.py create mode 100644 examples/distributed_inference/requirement.txt diff --git a/docsrc/index.rst b/docsrc/index.rst index 175ab7e8ab..b6acbaf075 100644 --- a/docsrc/index.rst +++ b/docsrc/index.rst @@ -111,7 +111,13 @@ Tutorials tutorials/_rendered_examples/dynamo/torch_compile_transformers_example tutorials/_rendered_examples/dynamo/torch_compile_advanced_usage tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion +<<<<<<< HEAD tutorials/_rendered_examples/dynamo/custom_kernel_plugins +======= + tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2 + tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion + +>>>>>>> dfbf6ea84 (feat: data parallel inference sample) Python API Documenation ------------------------ diff --git a/examples/distributed_inference/README.md b/examples/distributed_inference/README.md new file mode 100644 index 0000000000..f9608e8950 --- /dev/null +++ b/examples/distributed_inference/README.md @@ -0,0 +1,14 @@ +# Torch-TensorRT parallelism for distributed inference + +Examples in this folder demonstrates doing distributed inference on multiple devices with Torch-TensorRT backend. + +1. Data parallel distributed inference based on [Acclerate](https://huggingface.co/docs/accelerate/usage_guides/distributed_inference) + +Using Accelerate users can achieve data parallel distributed inference with Torch-TensorRt backend. In this case, the entire model +will be loaded onto each GPU and different chunks of batch input is processed on each device. + +See the examples started with `data_parallel` for more details. + +2. Tensor parallel distributed inference + +In development. diff --git a/examples/distributed_inference/data_parallel_gpt2.py b/examples/distributed_inference/data_parallel_gpt2.py new file mode 100644 index 0000000000..c6e3d9d3c8 --- /dev/null +++ b/examples/distributed_inference/data_parallel_gpt2.py @@ -0,0 +1,64 @@ +""" +.. _data_parallel_gpt2: + +Torch-TensorRT Distributed Inference +====================================================== + +This interactive script is intended as a sample of distributed inference using data +parallelism using Accelerate +library with the Torch-TensorRT workflow on GPT2 model. + +""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +import torch +from accelerate import PartialState +from transformers import AutoTokenizer, GPT2LMHeadModel + +import torch_tensorrt + +tokenizer = AutoTokenizer.from_pretrained("gpt2") + +# Set input prompts for different devices +prompt1 = "GPT2 is a model developed by." +prompt2 = "Llama is a model developed by " + +input_id1 = tokenizer(prompt1, return_tensors="pt").input_ids +input_id2 = tokenizer(prompt2, return_tensors="pt").input_ids + +distributed_state = PartialState() + +# Import GPT2 model and load to distributed devices +model = GPT2LMHeadModel.from_pretrained("gpt2").eval().to(distributed_state.device) + + +# Instantiate model with Torch-TensorRT backend +model.forward = torch.compile( + model.forward, + backend="torch_tensorrt", + options={ + "truncate_long_and_double": True, + "enabled_precisions": {torch.float16}, + "debug": True, + }, + dynamic=False, +) + +# %% +# Inference +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Assume there are 2 processes (2 devices) +with distributed_state.split_between_processes([input_id1, input_id2]) as prompt: + cur_input = torch.clone(prompt[0]).to(distributed_state.device) + + gen_tokens = model.generate( + cur_input, + do_sample=True, + temperature=0.9, + max_length=100, + ) + gen_text = tokenizer.batch_decode(gen_tokens)[0] diff --git a/examples/distributed_inference/data_parallel_stable_diffusion.py b/examples/distributed_inference/data_parallel_stable_diffusion.py new file mode 100644 index 0000000000..158201caf3 --- /dev/null +++ b/examples/distributed_inference/data_parallel_stable_diffusion.py @@ -0,0 +1,61 @@ +""" +.. _data_parallel_stable_diffusion: + +Torch-TensorRT Distributed Inference +====================================================== + +This interactive script is intended as a sample of distributed inference using data +parallelism using Accelerate +library with the Torch-TensorRT workflow on Stable Diffusion model. + +""" + +# %% +# Imports and Model Definition +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +import torch +from accelerate import PartialState +from diffusers import DiffusionPipeline + +import torch_tensorrt + +model_id = "CompVis/stable-diffusion-v1-4" + +# Instantiate Stable Diffusion Pipeline with FP16 weights +pipe = DiffusionPipeline.from_pretrained( + model_id, revision="fp16", torch_dtype=torch.float16 +) + +distributed_state = PartialState() +pipe = pipe.to(distributed_state.device) + +backend = "torch_tensorrt" + +# Optimize the UNet portion with Torch-TensorRT +pipe.unet = torch.compile( # %% + # Inference + # ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + # Assume there are 2 processes (2 devices) + pipe.unet, + backend=backend, + options={ + "truncate_long_and_double": True, + "precision": torch.float16, + "debug": True, + "use_python_runtime": True, + }, + dynamic=False, +) +torch_tensorrt.runtime.set_multi_device_safe_mode(True) + + +# %% +# Inference +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +# Assume there are 2 processes (2 devices) +with distributed_state.split_between_processes(["a dog", "a cat"]) as prompt: + print("before \n") + result = pipe(prompt).images[0] + print("after ") + result.save(f"result_{distributed_state.process_index}.png") diff --git a/examples/distributed_inference/requirement.txt b/examples/distributed_inference/requirement.txt new file mode 100644 index 0000000000..6d8e0aa9f2 --- /dev/null +++ b/examples/distributed_inference/requirement.txt @@ -0,0 +1,3 @@ +accelerate +transformers +diffusers \ No newline at end of file