
Optimize Stable Diffusion #371

Merged — 61 commits merged into huggingface:main on Sep 30, 2022

Conversation

@NouamaneTazi (Member) commented Sep 6, 2022

This PR aims to optimize the Stable Diffusion pipeline with the following changes:

  • Remove any CUDA sync operations
  • Use baddbmm instead of bmm + matmul in attention, resulting in a ~10% speedup (see the sketch below)
  • Enable full fp16 inference (before, it only worked with autocast)
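For context, here is a minimal sketch (not the PR's exact code; shapes and the scale value are illustrative, and a CUDA device is assumed) of how `baddbmm` can fuse the attention-score scaling into the batched matmul:

```python
import torch

# illustrative shapes: (batch * heads, seq_len, head_dim)
q = torch.randn(16, 64, 40, device="cuda", dtype=torch.float16)
k = torch.randn(16, 64, 40, device="cuda", dtype=torch.float16)
scale = q.shape[-1] ** -0.5

# two kernels: a batched matmul followed by an elementwise multiply
scores_bmm = torch.bmm(q, k.transpose(1, 2)) * scale

# one kernel: baddbmm computes beta * input + alpha * (q @ k^T); with beta=0
# the `input` tensor is only used for its shape and dtype
scores_baddbmm = torch.baddbmm(
    torch.empty(q.shape[0], q.shape[1], k.shape[1], device=q.device, dtype=q.dtype),
    q,
    k.transpose(1, 2),
    beta=0,
    alpha=scale,
)

assert torch.allclose(scores_bmm, scores_baddbmm, atol=1e-3)
```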

**Remove unnecessary synchronizations**

Before this PR:

[image]

After this PR:

[image]

**Full fp16 vs autocast**

|                 | Latency | Speedup |
|-----------------|---------|---------|
| original        | 10.51s  | x1      |
| refactoring     | 9.50s   | x1.1    |
| autocast (fp16) | 5.47s   | x1.91   |
| fp16            | 3.61s   | x2.91   |

Results obtained on an **NVIDIA TITAN RTX**.

Results were obtained using the following script:

```python
import time

import torch
from diffusers import StableDiffusionPipeline

torch.backends.cudnn.benchmark = True

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True,
    # revision="fp16",
    # torch_dtype=torch.float16,
).to("cuda")

prompt = "a photo of an astronaut riding a horse on mars"

# warmup
with torch.inference_mode():
    image = pipe([prompt] * 1, num_inference_steps=5)["sample"][0]

for _ in range(3):
    torch.cuda.synchronize()
    start_time = time.time()
    with torch.inference_mode():
        image = pipe([prompt] * 1, num_inference_steps=50)["sample"][0]
    torch.cuda.synchronize()
    print(f"Pipeline inference took {time.time() - start_time:.2f} seconds")
```
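For the `fp16` row of the table above, the same script is run with the two commented-out `from_pretrained` arguments enabled, so the whole pipeline stays in half precision without an autocast context (a sketch based on the script above, not a verbatim excerpt from the PR):

```python
import torch
from diffusers import StableDiffusionPipeline

# full fp16: load the fp16 weights and keep every module in half precision
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True,
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

# the "autocast (fp16)" row instead keeps fp32 weights and wraps the call:
# with torch.autocast("cuda"), torch.inference_mode():
#     image = pipe([prompt], num_inference_steps=50)["sample"][0]
```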

@HuggingFaceDocBuilderDev commented Sep 6, 2022

The documentation is not available anymore as the PR was closed or merged.

@NouamaneTazi (Member, Author) commented Sep 6, 2022

I added a memory_format kwarg to the DiffusionPipeline.from_pretrained method, which only affects UNets for now. It gives the same speed gain as before: 5.46s -> 5.14s.

```python
import torch
from diffusers import LMSDiscreteScheduler, StableDiffusionPipeline

scheduler = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
)

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-3",
    scheduler=scheduler,
    use_auth_token=True,
    memory_format=torch.channels_last,
    # revision="fp16",
).to("cuda")
```

wdyt @patil-suraj @anton-l @patrickvonplaten ?

@NouamaneTazi (Member, Author):

I think the channels-last format deserves a PR of its own, so I'm only keeping the changes to Stable Diffusion for now.

@patil-suraj (Contributor):

Hey @NouamaneTazi

I'm not sure we can have memory_format in DiffusionPipeline:

  • some pipelines might not have a unet if the model is transformer-based;
  • does this also work for 1D UNets? There are some audio models coming soon that use 1D conv layers.

So memory_format doesn't seem general enough to live in DiffusionPipeline.

But the speed-ups are interesting. Rather than supporting it in the pipeline directly, maybe we could add a doc page here https://github.com/huggingface/diffusers/tree/main/docs/source/optimization that shows how to put the unet in memory_format=torch.channels_last, and then in diffusers just modify the unet forward pass to put the inputs in the same memory_format before feeding them to the layers.

@NouamaneTazi (Member, Author):

Hello @patil-suraj. I agree that it's probably too early to add support for it in diffusers, but it's definitely worth considering in the future IMO.

The channels-last memory format is aimed at any model that handles 4D NCHW tensors. Since it's still a beta feature, not all operators support it yet (you can find the list here). So for a model with some operators that don't support this memory format, the switching between channels-last and the default format could result in worse performance, as explained here:

> However, not all operators fully converted to support channels last (usually returning contiguous output instead). In the example posted above, layers that does not support channels last will stop the memory format propagation. In spite of that, as we have converted the model to channels last format, that means each convolution layer, which has its 4 dimensional weight in channels last memory format, will restore channels last memory format and benefit from faster kernels.
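As a reference for the doc-page approach suggested above, here is a minimal sketch of converting just the UNet to channels-last outside the pipeline (illustrative only; this is not the API this PR ended up adding):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=True,
).to("cuda")

# convert the UNet's 4D conv weights to the channels-last (NHWC) memory format
pipe.unet.to(memory_format=torch.channels_last)

# sanity check: converting a contiguous NCHW tensor to channels_last keeps the
# logical shape but changes the strides
x = torch.randn(1, 4, 64, 64).to(memory_format=torch.channels_last)
print(x.shape, x.stride())  # torch.Size([1, 4, 64, 64]) (16384, 1, 256, 4)
```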

@patil-suraj (Contributor):

> Hello @patil-suraj. I agree that it's probably too early to add support for it in diffusers, but it's definitely worth considering in the future IMO.

Yes, agreed! For now we could definitely add this to the docs. Feel free to open a PR for that :)

Thanks for the information.

@yuananf commented Sep 29, 2022

Hello, I tested your branch with the code here https://github.com/NouamaneTazi/diffusers/blob/stable_diff_opti/docs/source/optimization/fp16.mdx#tracing, but it fails with the following error:

```
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [3], in <cell line: 2>()
      1 prompt = "A beautiful painting of a singular lighthouse, shining its light across a tumultuous sea of blood by greg rutkowski and thomas kinkade, Trending on artstation, yellow color scheme"
      2 with torch.inference_mode():
----> 3     image = pipe([prompt]*1, num_inference_steps=50).images[0]
      5     image

File ~/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/stable-diffusion/opt/diffusers/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py:197, in StableDiffusionPipeline.__call__(self, prompt, height, width, num_inference_steps, guidance_scale, eta, generator, latents, output_type, return_dict, **kwargs)
    192     logger.warning(
    193         "The following part of your input was truncated because CLIP can only handle sequences up to"
    194         f" {self.tokenizer.model_max_length} tokens: {removed_text}"
    195     )
    196     text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
--> 197 text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
    199 # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    200 # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
    201 # corresponds to doing no classifier free guidance.
    202 do_classifier_free_guidance = guidance_scale > 1.0

File ~/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/ldm/lib/python3.8/site-packages/transformers/models/clip/modeling_clip.py:722, in CLIPTextModel.forward(self, input_ids, attention_mask, position_ids, output_attentions, output_hidden_states, return_dict)
    694 @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
    695 @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=CLIPTextConfig)
    696 def forward(
   (...)
    703     return_dict: Optional[bool] = None,
    704 ) -> Union[Tuple, BaseModelOutputWithPooling]:
    705     r"""
    706     Returns:
    707 
   (...)
    720     >>> pooled_output = outputs.pooler_output  # pooled (EOS token) states
    721     ```"""
--> 722     return self.text_model(
    723         input_ids=input_ids,
    724         attention_mask=attention_mask,
    725         position_ids=position_ids,
    726         output_attentions=output_attentions,
    727         output_hidden_states=output_hidden_states,
    728         return_dict=return_dict,
    729     )

File ~/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/ldm/lib/python3.8/site-packages/transformers/models/clip/modeling_clip.py:643, in CLIPTextTransformer.forward(self, input_ids, attention_mask, position_ids, output_attentions, output_hidden_states, return_dict)
    639 if attention_mask is not None:
    640     # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
    641     attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
--> 643 encoder_outputs = self.encoder(
    644     inputs_embeds=hidden_states,
    645     attention_mask=attention_mask,
    646     causal_attention_mask=causal_attention_mask,
    647     output_attentions=output_attentions,
    648     output_hidden_states=output_hidden_states,
    649     return_dict=return_dict,
    650 )
    652 last_hidden_state = encoder_outputs[0]
    653 last_hidden_state = self.final_layer_norm(last_hidden_state)

File ~/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/ldm/lib/python3.8/site-packages/transformers/models/clip/modeling_clip.py:574, in CLIPEncoder.forward(self, inputs_embeds, attention_mask, causal_attention_mask, output_attentions, output_hidden_states, return_dict)
    567     layer_outputs = torch.utils.checkpoint.checkpoint(
    568         create_custom_forward(encoder_layer),
    569         hidden_states,
    570         attention_mask,
    571         causal_attention_mask,
    572     )
    573 else:
--> 574     layer_outputs = encoder_layer(
    575         hidden_states,
    576         attention_mask,
    577         causal_attention_mask,
    578         output_attentions=output_attentions,
    579     )
    581 hidden_states = layer_outputs[0]
    583 if output_attentions:

File ~/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/ldm/lib/python3.8/site-packages/transformers/models/clip/modeling_clip.py:317, in CLIPEncoderLayer.forward(self, hidden_states, attention_mask, causal_attention_mask, output_attentions)
    314 residual = hidden_states
    316 hidden_states = self.layer_norm1(hidden_states)
--> 317 hidden_states, attn_weights = self.self_attn(
    318     hidden_states=hidden_states,
    319     attention_mask=attention_mask,
    320     causal_attention_mask=causal_attention_mask,
    321     output_attentions=output_attentions,
    322 )
    323 hidden_states = residual + hidden_states
    325 residual = hidden_states

File ~/anaconda3/envs/ldm/lib/python3.8/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/ldm/lib/python3.8/site-packages/transformers/models/clip/modeling_clip.py:257, in CLIPAttention.forward(self, hidden_states, attention_mask, causal_attention_mask, output_attentions)
    253     attn_weights_reshaped = None
    255 attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
--> 257 attn_output = torch.bmm(attn_probs, value_states)
    259 if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
    260     raise ValueError(
    261         f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}"
    262     )

RuntimeError: expected scalar type Half but found Float
```

Comment on lines +150 to +151:

```python
self.sigmas = torch.from_numpy(sigmas).to(device=device)
self.timesteps = torch.from_numpy(timesteps).to(device=device)
```

Contributor:

Why is this only required for the LMS scheduler?

Contributor:

Also note that with LMS we don't pass timesteps to the model but rather the index of the timestep, cf.

```python
latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
```

@NouamaneTazi (Member, Author):

> Why is this only required for the LMS scheduler?

Because it's the only scheduler that has self.timesteps as a torch tensor.

@NouamaneTazi (Member, Author):

> Also note that with LMS we don't pass timesteps to the model but rather the index of the timestep, cf.

The issue I was tackling comes from the UNet:

```python
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
```

Contributor:

Ahh I see, my bad! Should we maybe add a .to() method to schedulers that puts all the scheduler state on the device, rather than handling it manually in multiple places? That would also avoid creating unused kwargs.
wdyt @patrickvonplaten @anton-l @pcuenca

Member:

@anton-l also suggested that, but since there were other tensor .to() operations in the same scheduler we postponed the decision: #534 (comment)

I think most schedulers don't need .to(), but if it saves us from a kwargs argument then I think that tips the scale.
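For illustration, a hypothetical sketch of such a scheduler .to() (the mixin name and structure are made up here; this is not the API that was actually adopted):

```python
import torch

class SchedulerDeviceMixin:
    """Hypothetical helper: move every tensor attribute of a scheduler to a device."""

    def to(self, device):
        for name, value in vars(self).items():
            if torch.is_tensor(value):
                setattr(self, name, value.to(device))
        return self

# usage sketch, assuming a scheduler that stores e.g. `sigmas` and `timesteps` as tensors:
# scheduler.to("cuda")  # afterwards the pipeline needs no manual .to() calls
```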

@patrickvonplaten (Contributor):

Let's try to get this in. Three final things:

    1. Rebase to main and revert the lowering of the test tolerance for the K-LMS test
    2. Remove the added **kwargs
    3. Re-run all slow tests (they should all pass now)

@NouamaneTazi (Member, Author):

All tests but `tests/test_pipelines.py::PipelineTesterMixin::test_stable_diffusion_memory_chunking` should be passing now :)

@NouamaneTazi (Member, Author):

> Hello, I tested your branch with the code here NouamaneTazi/diffusers@stable_diff_opti/docs/source/optimization/fp16.mdx#tracing, but it fails with the following error

@yuananf can you retry, it should be working now :-)

@NouamaneTazi (Member, Author) left a review comment:

Left some comments about other things we could optimize. Also let me know if I should add tests that run the SD pipelines in fp16.

@patil-suraj @patrickvonplaten @anton-l

Comment on lines +40 to +41:

```python
exponent = -math.log(max_period) * torch.arange(
    start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
```

@NouamaneTazi (Member, Author), Sep 29, 2022:

Do we need this operation to run in fp32 even when the pipeline runs in fp16?

Contributor:

Happy to try it out - let's maybe do it in a follow-up PR? :-)

```diff
 exponent = exponent / (half_dim - downscale_freq_shift)

-emb = torch.exp(exponent).to(device=timesteps.device)
+emb = torch.exp(exponent)
 emb = timesteps[:, None].float() * emb[None, :]
```

@NouamaneTazi (Member, Author):

Same as the previous comment.
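For anyone trying that out, here is a simplified sketch of the embedding above with the compute dtype as a parameter (an illustrative approximation of diffusers' get_timestep_embedding, not the exact function, and it assumes a CUDA device):

```python
import math

import torch

def timestep_embedding(timesteps, dim, dtype, max_period=10000, downscale_freq_shift=1):
    # simplified version of the snippet above, with the compute dtype exposed
    half_dim = dim // 2
    exponent = -math.log(max_period) * torch.arange(
        half_dim, dtype=torch.float32, device=timesteps.device
    ).to(dtype)
    exponent = exponent / (half_dim - downscale_freq_shift)
    emb = timesteps[:, None].to(dtype) * torch.exp(exponent)[None, :]
    return torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

t = torch.tensor([999.0], device="cuda")
fp32 = timestep_embedding(t, 320, torch.float32)
fp16 = timestep_embedding(t, 320, torch.float16)
print((fp32 - fp16.float()).abs().max())  # rough measure of the fp16 precision gap
```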

```
@@ -230,16 +230,16 @@ def forward(
        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
```
@NouamaneTazi (Member, Author):

Do we need int64 tensors for timesteps, no matter the pipeline's precision?

Contributor:

At least for all the Stable Diffusion applications I've seen so far, timesteps are ints in the range 0..1000.

Even if other diffusion models do several orders of magnitude more than that, you'd think torch.int would be plenty.

Unless there's some byte alignment optimization reason to specifically make them 64-bit?

Contributor:

I think we can leave timesteps as int32
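As an aside on the TODO in the hunk above (the CPU/GPU sync concern, separate from the dtype question), a minimal illustration of the difference, assuming a CUDA device:

```python
import torch

timesteps = torch.linspace(999, 0, 50, device="cuda")

# Passing Python ints: .tolist() copies the tensor back to the host (a sync),
# and the UNet then has to rebuild a GPU tensor from each int (a host-to-device copy).
for t in timesteps.tolist():
    pass  # noise_pred = self.unet(latent_model_input, t, ...)

# Iterating over the tensor itself yields 0-dim CUDA tensors, which the UNet
# can consume directly without a round trip through the host.
for t in timesteps:
    pass  # noise_pred = self.unet(latent_model_input, t, ...)
```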

```python
if torch.is_tensor(self.scheduler.timesteps):
    timesteps_tensor = self.scheduler.timesteps.to(self.device)
else:
    timesteps_tensor = torch.tensor(self.scheduler.timesteps.copy(), device=self.device)
```
@NouamaneTazi (Member, Author):

Should we make the dtype for timesteps int64?

@yuananf commented Sep 30, 2022:

> > Hello, I tested your branch with the code here NouamaneTazi/diffusers@stable_diff_opti/docs/source/optimization/fp16.mdx#tracing, but it fails with the following error
>
> @yuananf can you retry, it should be working now :-)

Unfortunately it failed with the same error.

@patrickvonplaten (Contributor):

Cool, merging! Let's monitor the slow tests for this one.

@patrickvonplaten merged commit 9ebaea5 into huggingface:main on Sep 30, 2022
@NouamaneTazi mentioned this pull request on Sep 30, 2022
prathikr pushed a commit to prathikr/diffusers that referenced this pull request Oct 26, 2022
* initial commit

* make UNet stream capturable

* try to fix noise_pred value

* remove cuda graph and keep NB

* non blocking unet with PNDMScheduler

* make timesteps np arrays for pndm scheduler
because lists don't get formatted to tensors in `self.set_format`

* make max async in pndm

* use channel last format in unet

* avoid moving timesteps device in each unet call

* avoid memcpy op in `get_timestep_embedding`

* add `channels_last` kwarg to `DiffusionPipeline.from_pretrained`

* update TODO

* replace `channels_last` kwarg with `memory_format` for more generality

* revert the channels_last changes to leave it for another PR

* remove non_blocking when moving input ids to device

* remove blocking from all .to() operations at beginning of pipeline

* fix merging

* fix merging

* model can run in other precisions without autocast

* attn refactoring

* Revert "attn refactoring"

This reverts commit 0c70c0e.

* remove restriction to run conv_norm in fp32

* use `baddbmm` instead of `matmul`for better in attention for better perf

* removing all reshapes to test perf

* Revert "removing all reshapes to test perf"

This reverts commit 006ccb8.

* add shapes comments

* hardcore whats needed for jitting

* Revert "hardcore whats needed for jitting"

This reverts commit 2fa9c69.

* Revert "remove restriction to run conv_norm in fp32"

This reverts commit cec5928.

* revert using baddmm in attention's forward

* cleanup comment

* remove restriction to run conv_norm in fp32. no quality loss was noticed

This reverts commit cc9bc13.

* add more optimizations techniques to docs

* Revert "add shapes comments"

This reverts commit 31c58ea.

* apply suggestions

* make quality

* apply suggestions

* styling

* `scheduler.timesteps` are now arrays so we dont need .to()

* remove useless .type()

* use mean instead of max in `test_stable_diffusion_inpaint_pipeline_k_lms`

* move scheduler timestamps to correct device if tensors

* add device to `set_timesteps` in LMSD scheduler

* `self.scheduler.set_timesteps` now uses device arg for schedulers that accept it

* quick fix

* styling

* remove kwargs from schedulers `set_timesteps`

* revert to using max in K-LMS inpaint pipeline test

* Revert "`self.scheduler.set_timesteps` now uses device arg for schedulers that accept it"

This reverts commit 00d5a51.

* move timesteps to correct device before loop in SD pipeline

* apply previous fix to other SD pipelines

* UNet now accepts tensor timesteps even on wrong device, to avoid errors
- it shouldnt affect performance if timesteps are alrdy on correct device
- it does slow down performance if they're on the wrong device

* fix pipeline when timesteps are arrays with strides