Support SDXL and its distributed inference #1514

Closed · wants to merge 2 commits

Conversation

@Zars19 (Contributor) commented Apr 28, 2024

The idea of patch parallelism comes from the CVPR 2024 paper DistriFusion. To keep the implementation simple, all communications in the example are synchronous.

This helps SDXL achieve better performance, especially at very high resolutions.

A100, 50 steps, 2048x2048, SDXL

| Framework | sync_mode | n_gpu | Latency (s) | Speedup | Memory (MiB) |
|---|---|---|---|---|---|
| Torch | - | 1 | 25.25 | 1x | 42147 |
| TRT | - | 1 | 21.98 | 1.15x | 42895 |
| DistriFusion (Torch) | split_batch | 2 | 13.33 | 1.89x | 40173 |
| Ours | split_batch | 2 | 11.69 | 2.16x | 42675 |
| DistriFusion (Torch) | corrected_async_gn | 4 | 8.27 | 3.05x | 49087 |
| DistriFusion (Torch) | full_sync | 4 | 8.64 | 2.92x | 51943 |
| Ours | full_sync | 4 | 7.73 | 3.27x | 43073 |
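To make the approach concrete, here is a minimal PyTorch sketch of the synchronous patch parallelism described above. It is an illustration only, not the TensorRT-LLM implementation in this PR: it assumes a diffusers-style UNet (returning .sample), an initialized torch.distributed process group with one rank per GPU, and it omits the halo/boundary handling and per-layer activation exchange a full DistriFusion implementation needs for convolutions and attention.

import torch
import torch.distributed as dist

def patch_parallel_unet_step(unet, latents, t, cond):
    # Split the latent along the height dimension, one patch per rank.
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    patches = latents.chunk(world_size, dim=2)

    # Each rank denoises only its own patch.
    local_out = unet(patches[rank], t, encoder_hidden_states=cond).sample

    # Synchronous communication: every rank gathers every patch before the
    # next denoising step, so no stale activations are used.
    gathered = [torch.empty_like(local_out) for _ in range(world_size)]
    dist.all_gather(gathered, local_out.contiguous())
    return torch.cat(gathered, dim=2)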

@Zars19 changed the title from "Add distributed inference for UNet models and SDXL examples" to "Support SDXL and its distributed inference" on Apr 30, 2024
@juney-nvidia requested a review from nv-guomingz on June 3, 2024
@juney-nvidia (Collaborator) commented:

@Zars19 thanks for the contribution to TensorRT-LLM!

@nv-guomingz can you help take care of this? :)

Thanks
June

@nv-guomingz (Collaborator) commented:

> @Zars19 thanks for the contribution to TensorRT-LLM!
> @nv-guomingz can you help take care of this? :)
> Thanks, June

Sure, I'll collaborate with @Zars19 on enabling SDXL with TRT-LLM.

@nv-guomingz (Collaborator) commented:

Hi @Zars19, could you please resolve the code conflicts first?

@Zars19 (Contributor, Author) commented Jun 12, 2024

> Hi @Zars19, could you please resolve the code conflicts first?

I have resolved the conflict :) @nv-guomingz

@nv-guomingz (Collaborator) commented:

Hi @Zars19, thanks for your patience.
Could you please update this MR by rebasing the two commits (including one merge commit) into a single commit? That would make it easier for us to integrate and test.

@Zars19 (Contributor, Author) commented Jul 12, 2024

@nv-guomingz I've completed the git rebase.

Review comment on tensorrt_llm/builder.py (outdated, resolved)
@lmxyy commented Aug 22, 2024

Any updates on the code review?

@Zars19 (Contributor, Author) commented Aug 23, 2024

> Any updates on the code review?

After rebasing the code, I haven't received any feedback for a while now.
@nv-guomingz @juney-nvidia

@hchings (Collaborator) commented Oct 27, 2024

Hi @Zars19, thanks for your patience. We recently resumed reviewing this with the latest TRT-LLM and noticed some issues so far:

  1. dtype is passed into get_timestep_embedding (line), but the time embedding is still fixed to fp32. This is a minor fix and we've added it (a hedged sketch of the intended behavior follows this list).
  2. The image output when n_gpu=2 is incorrect. It seems that DistriUNetPP, or perhaps the scatter/sparse op, has issues. Can you help take a look? Below are the output images of this MR on DistriFusion's default prompt; you can see the part dealing with the activation patch is not right. (The config is the same as this MR's, except that the random seed is set to 1234 for easier comparison with the original DistriFusion repo.)
  [image: broken output at n_gpu=2]
  3. n_gpu=4 with either TRT-LLM v0.13 or our internal latest main branch fails to build with the following error:
  [image: build error]
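For item 1, a hedged sketch of the intended dtype handling, in plain PyTorch for illustration only; the actual get_timestep_embedding in TRT-LLM/diffusers may differ in signature and internals.

import math
import torch

def get_timestep_embedding(timesteps, dim, dtype=torch.float32):
    # Standard sinusoidal timestep embedding (assumes even dim).
    # Frequencies are computed in fp32 for numerical stability, but the
    # result is cast to the caller's dtype instead of staying fixed at fp32.
    half = dim // 2
    freqs = torch.exp(
        -math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half
    ).to(timesteps.device)
    args = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    return emb.to(dtype)  # honor the dtype argument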

A few other questions:
  • In your benchmark, do the rows marked Ours / split_batch and Ours / full_sync both use batch splitting? I don't see a flag in your scripts to enable/disable batch splitting.

  • Could you share how you measured peak memory usage? I got a similar speedup to yours (other than that I cannot run 4 GPUs yet), but I observed much lower peak memory usage for DistriFusion (Torch).



@Zars19 (Contributor, Author) commented Nov 1, 2024

@hchings Thank you for your feedback! I've rebased onto the latest official code, debugged, and addressed the issues.

  1. I've fixed the fp32 constant type mismatch in the latest commit.
  2. I couldn't reproduce the image generation error with the default prompt. Using the following commands, the generated image is consistent with the original result:
    mpirun -n 2 python build_sdxl_unet.py --size 1024
    mpirun -n 2 python run_sdxl.py --size 1024 --prompt "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" --seed 1234

  [image: n_gpu=2 output matching the original]

  3. In the new commit, I manually fused pre-padding/post-padding into conv2d, resolving the compilation issue.

Responses to other questions:

  • Yes, splitting the batch improves the efficiency of distributed inference, so I prioritize processing the two classifier-free guidance batches (conditional and unconditional) separately. This is done by default, without a user option; see the sketch after this list.
  • The table shows performance and memory usage for generating images at 2048x2048, since the optimization yields a higher speedup at higher resolutions. Generating 1024x1024 images requires less than half of that memory. Could you confirm whether your comparison is based on memory measured at the 1024 resolution?
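For reference, a minimal sketch of that default classifier-free guidance split, assuming two ranks, an initialized torch.distributed process group, and a diffusers-style UNet; this illustrates the idea only and is not the PR's TensorRT-LLM code.

import torch
import torch.distributed as dist

def cfg_split_unet_step(unet, latents, t, uncond_emb, cond_emb, guidance_scale=7.5):
    # Instead of batching the unconditional and conditional passes on one GPU,
    # rank 0 runs the unconditional half and rank 1 the conditional half.
    rank = dist.get_rank()
    emb = uncond_emb if rank == 0 else cond_emb
    local = unet(latents, t, encoder_hidden_states=emb).sample

    # Gather both halves on every rank (synchronous).
    gathered = [torch.empty_like(local) for _ in range(2)]
    dist.all_gather(gathered, local.contiguous())
    noise_uncond, noise_cond = gathered

    # Standard classifier-free guidance combination.
    return noise_uncond + guidance_scale * (noise_cond - noise_uncond)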

@hchings (Collaborator) commented Nov 4, 2024

Hi @Zars19, thanks for the fixes!

  • Re item 2: I pulled your latest changes and can still reproduce it with --size=2048; another result is attached below. However, it seems to be a limitation of SDXL rather than an issue with the distributed implementation, as I can reproduce similarly broken images with the Distrifuser repo, and worse images with vanilla HuggingFace diffusers. If that is true, there seems to be little need to test speedup at 2048x2048 and 3840x3840 for this MR as was done in the paper, since the output images won't be practical. Please correct me if I'm missing something.

    # this MR
    mpirun -n 4 python build_sdxl_unet.py --size 2048
    mpirun -n 4 python run_sdxl.py --size 2048 --prompt "flowers, rabbit" --seed 1234
    
    # Distrifuser
    torchrun --nproc_per_node=4 scripts/run_sdxl.py --image_size 2048 --prompt "flowers, rabbit" --output_path xxx
    
    [image: 2048x2048 outputs from this MR and Distrifuser]

  • The latest commit of this MR introduces a shape mismatch at ResNet forward() (line), where input_tensors should have shape (1, 320, 128, 128) but is now (1, 320, 126, 126). Can you double-check your new modification in conv2d.py? (A quick shape check is included at the end of this comment.)

    Repro (after build):

    mpirun -n 2 python build_sdxl_unet.py --size 1024
    
    [image: shape mismatch error]
  • " Could you please confirm if your comparison is based on the memory data at the 1024 resolution size?" -- I was testing w/ 2048 as you did. I will retest at our end again once all model issues are fixed. Thanks!

@Zars19 (Contributor, Author) commented Nov 4, 2024

@hchings Thanks for your reply.

  • Yes, the 2048-resolution images generated by Distrifuser and by this PR are consistent with the original output of the SDXL model. I believe the paper compares speedup at higher resolutions to demonstrate a general trend: this distributed method gives diffusion models a larger performance gain at higher resolutions. For the SDXL model specifically, however, generating resolutions above the trained 1024 often results in pattern repetition and artifacts, so the outputs are not practically usable. The following script reproduces this with vanilla diffusers:
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    use_safetensors=True,
)
pipeline.to('cuda')

seed = 1234
size = 2048
#prompt = "flowers, rabbit"
prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

image = pipeline(
    prompt=prompt,
    generator=torch.Generator(device="cuda").manual_seed(seed),
    height=size, width=size).images[0]

image.save(f"output.png")

Picture generated by the original SDXL:

  • Previously, I had set the default values of the pre_padding and post_padding parameters in the conv2d function of functional.py to (0, 0). They took priority over the regular padding parameter, making it ineffective and causing the shape issue. This is now resolved.
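For illustration, a hedged sketch of the fixed priority, written against torch.nn.functional rather than TensorRT-LLM's actual functional.py (whose conv2d signature may differ): asymmetric pre-/post-padding only takes effect when explicitly provided, otherwise the regular symmetric padding applies.

import torch
import torch.nn.functional as F

def conv2d(x, weight, stride=(1, 1), padding=(0, 0),
           pre_padding=None, post_padding=None):
    if pre_padding is not None or post_padding is not None:
        # Fuse explicit asymmetric padding into the conv by padding the input
        # manually (left, right, top, bottom) and running the conv with padding=0.
        pre = pre_padding or (0, 0)
        post = post_padding or (0, 0)
        x = F.pad(x, (pre[1], post[1], pre[0], post[0]))
        return F.conv2d(x, weight, stride=stride, padding=0)
    # Default path: plain symmetric padding, unaffected by the new arguments.
    return F.conv2d(x, weight, stride=stride, padding=padding)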

@hchings (Collaborator) commented Nov 12, 2024

Hi @Zars19, FYI that we're doing some final wrap-ups of this and will merge it soon. Thanks!

@hchings added the "triaged" label (issue has been triaged by maintainers) and removed the "waiting for feedback" label on Nov 13, 2024
@hchings added the "Merged" label on Nov 22, 2024
@hchings (Collaborator) commented Nov 22, 2024

Hi @Zars19, we've merged this internally and it will show up in the upcoming public release under /example/sdxl.
Thank you for the contribution & patience.

@hchings closed this on Nov 22, 2024
@kaiyux mentioned this pull request on Nov 26, 2024