From 98c42134a5615e1c26f2cca70ff9a4c142850f65 Mon Sep 17 00:00:00 2001
From: MatthieuTPHR <75613333+MatthieuTPHR@users.noreply.github.com>
Date: Wed, 2 Nov 2022 10:29:06 +0100
Subject: [PATCH] Up to 2x speedup on GPUs using memory efficient attention (#532)

* 2x speedup using memory efficient attention
* remove einops dependency
* Swap K, M in op instantiation
* Simplify code, remove unnecessary maybe_init call and function, remove unused self.scale parameter
* make xformers a soft dependency
* remove one-liner functions
* change one-letter variables to appropriate names
* Remove Env variable dependency, remove MemoryEfficientCrossAttention class and use enable_xformers_memory_efficient_attention method
* Add memory efficient attention toggle to img2img and inpaint pipelines
* Clearer management of xformers' availability
* update optimizations markdown to add info about memory efficient attention
* add benchmarks for TITAN RTX
* More detailed explanation of how the mem eff benchmarks were run
* Removing autocast from optimization markdown
* import_utils: import torch only if it is available

Co-authored-by: Nouamane Tazi
---
 docs/source/optimization/fp16.mdx         | 39 +++++++++++++
 src/diffusers/models/attention.py         | 55 +++++++++++++++++--
 src/diffusers/models/unet_2d_blocks.py    | 12 ++++
 src/diffusers/models/unet_2d_condition.py | 11 ++++
 .../pipeline_stable_diffusion.py          | 18 ++++++
 .../pipeline_stable_diffusion_img2img.py  | 18 ++++++
 .../pipeline_stable_diffusion_inpaint.py  | 18 ++++++
 src/diffusers/utils/import_utils.py       | 16 ++++++
 8 files changed, 183 insertions(+), 4 deletions(-)

diff --git a/docs/source/optimization/fp16.mdx b/docs/source/optimization/fp16.mdx
index f12c067ba5ee..4371daacc903 100644
--- a/docs/source/optimization/fp16.mdx
+++ b/docs/source/optimization/fp16.mdx
@@ -22,6 +22,7 @@ We present some techniques and ideas to optimize 🤗 Diffusers _inference_ for
 | fp16 | 3.61s | x2.63 |
 | channels last | 3.30s | x2.88 |
 | traced UNet | 3.21s | x2.96 |
+| memory efficient attention | 2.63s | x3.61 |
 
 obtained on NVIDIA TITAN RTX by generating a single image of size 512x512 from
@@ -290,3 +291,41 @@ pipe.unet = TracedUNet()
 with torch.inference_mode():
     image = pipe([prompt] * 1, num_inference_steps=50).images[0]
 ```
+
+
+## Memory Efficient Attention
+Recent work on optimizing bandwidth in the attention block has generated large speedups and significant savings in GPU memory usage. The most recent of these is Flash Attention (from @tridao: [code](https://github.com/HazyResearch/flash-attention), [paper](https://arxiv.org/pdf/2205.14135.pdf)).
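+
+As a rough illustration (a minimal sketch for intuition only; the shapes and the manual softmax are illustrative, not the code path used by the pipeline), the memory efficient kernel produces the same attention output as the standard computation without materializing the full attention matrix:
+
+```python
+import torch
+import xformers.ops
+
+# toy shapes: (batch * heads, sequence length, head dimension)
+q = torch.randn(2, 1024, 64, device="cuda", dtype=torch.float16)
+k = torch.randn(2, 1024, 64, device="cuda", dtype=torch.float16)
+v = torch.randn(2, 1024, 64, device="cuda", dtype=torch.float16)
+
+# standard attention materializes the full (1024 x 1024) score matrix ...
+scores = torch.softmax((q @ k.transpose(-1, -2)) / 64**0.5, dim=-1)
+out_standard = scores @ v
+
+# ... while the memory efficient kernel computes the same result without storing it
+out_memory_efficient = xformers.ops.memory_efficient_attention(q, k, v)
+```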
+
+Here are the speedups we obtain on a few NVIDIA GPUs when running inference at 512x512 with a batch size of 1 (one prompt):
+
+| GPU              | Base Attention FP16 | Memory Efficient Attention FP16 |
+|------------------ |--------------------- |--------------------------------- |
+| NVIDIA Tesla T4  | 3.5it/s             | 5.5it/s                         |
+| NVIDIA 3060 RTX  | 4.6it/s             | 7.8it/s                         |
+| NVIDIA A10G      | 8.88it/s            | 15.6it/s                        |
+| NVIDIA RTX A6000 | 11.7it/s            | 21.09it/s                       |
+| NVIDIA TITAN RTX | 12.51it/s           | 18.22it/s                       |
+| A100-SXM4-40GB   | 18.6it/s            | 29it/s                          |
+| A100-SXM-80GB    | 18.7it/s            | 29.5it/s                        |
+
+To leverage it, make sure you have:
+ - PyTorch >= 1.12
+ - CUDA available
+ - Installed the [xformers](https://github.com/facebookresearch/xformers) library
+```python
+from diffusers import StableDiffusionPipeline
+import torch
+
+pipe = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    revision="fp16",
+    torch_dtype=torch.float16,
+).to("cuda")
+
+pipe.enable_xformers_memory_efficient_attention()
+
+with torch.inference_mode():
+    sample = pipe("a small cat")
+
+# optional: You can disable it via
+# pipe.disable_xformers_memory_efficient_attention()
+```
\ No newline at end of file
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index af441ef86181..1f9cf641c32d 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -18,6 +18,15 @@
 import torch.nn.functional as F
 from torch import nn
 
+from diffusers.utils.import_utils import is_xformers_available
+
+
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+else:
+    xformers = None
+
 
 class AttentionBlock(nn.Module):
     """
@@ -150,6 +159,10 @@ def _set_attention_slice(self, slice_size):
         for block in self.transformer_blocks:
             block._set_attention_slice(slice_size)
 
+    def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+        for block in self.transformer_blocks:
+            block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
+
     def forward(self, hidden_states, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
         batch, channel, height, weight = hidden_states.shape
@@ -206,6 +219,32 @@ def _set_attention_slice(self, slice_size):
         self.attn1._slice_size = slice_size
         self.attn2._slice_size = slice_size
 
+    def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+        if not is_xformers_available():
+            raise ModuleNotFoundError(
+                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                " xformers",
+                name="xformers",
+            )
+        elif not torch.cuda.is_available():
+            raise ValueError(
+                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+                " available for GPU "
+            )
+        else:
+            try:
+                # Make sure we can run the memory efficient attention
+                _ = xformers.ops.memory_efficient_attention(
+                    torch.randn((1, 2, 40), device="cuda"),
+                    torch.randn((1, 2, 40), device="cuda"),
+                    torch.randn((1, 2, 40), device="cuda"),
+                )
+            except Exception as e:
+                raise e
+        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+        self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
     def forward(self, hidden_states, context=None):
         hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
         hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
@@ -239,6 +278,7 @@ def __init__(
         # is split across the batch axis to save memory
         # You can set slice_size with `set_attention_slice`
         self._slice_size = None
+        self._use_memory_efficient_attention_xformers = False
 
         self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
         self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
@@ -279,11 +319,13 @@ def forward(self, hidden_states, context=None, mask=None):
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
         # attention, what we cannot get enough of
-
-        if self._slice_size is None or query.shape[0] // self._slice_size == 1:
-            hidden_states = self._attention(query, key, value)
+        if self._use_memory_efficient_attention_xformers:
+            hidden_states = self._memory_efficient_attention_xformers(query, key, value)
         else:
-            hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
+            if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+                hidden_states = self._attention(query, key, value)
+            else:
+                hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
 
         # linear proj
         hidden_states = self.to_out[0](hidden_states)
@@ -341,6 +383,11 @@ def _sliced_attention(self, query, key, value, sequence_length, dim):
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
 
+    def _memory_efficient_attention_xformers(self, query, key, value):
+        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=None)
+        hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+        return hidden_states
+
 
 class FeedForward(nn.Module):
     r"""
diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py
index f4081c5c1cac..ae4fe2d8bba7 100644
--- a/src/diffusers/models/unet_2d_blocks.py
+++ b/src/diffusers/models/unet_2d_blocks.py
@@ -367,6 +367,10 @@ def set_attention_slice(self, slice_size):
         for attn in self.attentions:
             attn._set_attention_slice(slice_size)
 
+    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+        for attn in self.attentions:
+            attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
+
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
@@ -542,6 +546,10 @@ def set_attention_slice(self, slice_size):
         for attn in self.attentions:
             attn._set_attention_slice(slice_size)
 
+    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+        for attn in self.attentions:
+            attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
+
     def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
         output_states = ()
 
@@ -1117,6 +1125,10 @@ def set_attention_slice(self, slice_size):
 
         self.gradient_checkpointing = False
 
+    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+        for attn in self.attentions:
+            attn._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
+
     def forward(
         self,
         hidden_states,
diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py
index d271b78a6525..7f7f3ecd4435 100644
--- a/src/diffusers/models/unet_2d_condition.py
+++ b/src/diffusers/models/unet_2d_condition.py
@@ -225,6 +225,17 @@ def set_attention_slice(self, slice_size):
             if hasattr(block, "attentions") and block.attentions is not None:
                 block.set_attention_slice(slice_size)
 
+    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+        for block in self.down_blocks:
+            if hasattr(block, "attentions") and block.attentions is not None:
+                block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
+
+        self.mid_block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
+
+        for block in self.up_blocks:
+            if hasattr(block, "attentions") and block.attentions is not None:
+                block.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
             module.gradient_checkpointing = value
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 5927f36b12a1..3c1eb734a49d 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -113,6 +113,24 @@ def __init__(
             feature_extractor=feature_extractor,
         )
 
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(False)
+
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 00c364f8e5e3..e61fb27acc1e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -151,6 +151,24 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `set_attention_slice`
         self.enable_attention_slicing(None)
 
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(False)
+
     @torch.no_grad()
     def __call__(
         self,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 57f9b65716ee..bbe6ee60832c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -151,6 +151,24 @@ def disable_attention_slicing(self):
         # set slice_size = `None` to disable `attention slicing`
         self.enable_attention_slicing(None)
 
+    def enable_xformers_memory_efficient_attention(self):
+        r"""
+        Enable memory efficient attention as implemented in xformers.
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
+        time. Speed up at training time is not guaranteed.
+
+        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
+        is used.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(True)
+
+    def disable_xformers_memory_efficient_attention(self):
+        r"""
+        Disable memory efficient attention as implemented in xformers.
+        """
+        self.unet.set_use_memory_efficient_attention_xformers(False)
+
     @torch.no_grad()
     def __call__(
         self,
diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py
index 2a5f7f64dd07..4ea02dcc94da 100644
--- a/src/diffusers/utils/import_utils.py
+++ b/src/diffusers/utils/import_utils.py
@@ -168,6 +168,18 @@ except importlib_metadata.PackageNotFoundError:
     _accelerate_available = False
 
+_xformers_available = importlib.util.find_spec("xformers") is not None
+try:
+    _xformers_version = importlib_metadata.version("xformers")
+    if _torch_available:
+        import torch
+
+        if version.Version(torch.__version__) < version.Version("1.12"):
+            raise ValueError("PyTorch should be >= 1.12")
+    logger.debug(f"Successfully imported xformers version {_xformers_version}")
+except importlib_metadata.PackageNotFoundError:
+    _xformers_available = False
+
 
 def is_torch_available():
     return _torch_available
@@ -205,6 +217,10 @@ def is_scipy_available():
     return _scipy_available
 
 
+def is_xformers_available():
+    return _xformers_available
+
+
 def is_accelerate_available():
     return _accelerate_available
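
For reference, a minimal usage sketch of how the new `is_xformers_available()` helper and the pipeline toggle added above fit together (illustrative only; it assumes a CUDA GPU, an installed xformers package, and the same model id as the docs example):

```python
import torch

from diffusers import StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    revision="fp16",
    torch_dtype=torch.float16,
).to("cuda")

# Only switch to the xformers attention path when the soft dependency and a GPU are present.
if is_xformers_available() and torch.cuda.is_available():
    pipe.enable_xformers_memory_efficient_attention()

with torch.inference_mode():
    image = pipe("a small cat").images[0]
```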