diff --git a/lora_diffusion/xformers_utils.py b/lora_diffusion/xformers_utils.py
new file mode 100644
index 0000000..fdabf66
--- /dev/null
+++ b/lora_diffusion/xformers_utils.py
@@ -0,0 +1,70 @@
+import functools
+
+import torch
+from diffusers.models.attention import BasicTransformerBlock
+from diffusers.utils.import_utils import is_xformers_available
+
+from .lora import LoraInjectedLinear
+
+if is_xformers_available():
+    import xformers
+    import xformers.ops
+else:
+    xformers = None
+
+
+@functools.cache
+def test_xformers_backwards(size):
+    @torch.enable_grad()
+    def _grad(size):
+        q = torch.randn((1, 4, size), device="cuda")
+        k = torch.randn((1, 4, size), device="cuda")
+        v = torch.randn((1, 4, size), device="cuda")
+
+        q = q.detach().requires_grad_()
+        k = k.detach().requires_grad_()
+        v = v.detach().requires_grad_()
+
+        out = xformers.ops.memory_efficient_attention(q, k, v)
+        loss = out.sum(2).mean(0).sum()
+
+        return torch.autograd.grad(loss, v)
+
+    try:
+        _grad(size)
+        print(size, "pass")
+        return True
+    except Exception as e:
+        print(size, "fail")
+        return False
+
+
+def set_use_memory_efficient_attention_xformers(
+    module: torch.nn.Module, valid: bool
+) -> None:
+    def fn_test_dim_head(module: torch.nn.Module):
+        if isinstance(module, BasicTransformerBlock):
+            # dim_head isn't stored anywhere, so back-calculate
+            source = module.attn1.to_v
+            if isinstance(source, LoraInjectedLinear):
+                source = source.linear
+
+            dim_head = source.out_features // module.attn1.heads
+
+            result = test_xformers_backwards(dim_head)
+
+            # If dim_head > dim_head_max, turn xformers off
+            if not result:
+                module.set_use_memory_efficient_attention_xformers(False)
+
+        for child in module.children():
+            fn_test_dim_head(child)
+
+    if not is_xformers_available() and valid:
+        print("XFormers is not available. Skipping.")
+        return
+
+    module.set_use_memory_efficient_attention_xformers(valid)
+
+    if valid:
+        fn_test_dim_head(module)
diff --git a/train_lora_dreambooth.py b/train_lora_dreambooth.py
index 362fcb3..6a29a5c 100644
--- a/train_lora_dreambooth.py
+++ b/train_lora_dreambooth.py
@@ -36,9 +36,9 @@
     save_lora_weight,
     save_safeloras,
 )
-
-from torch.utils.data import Dataset
+from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers
 from PIL import Image
+from torch.utils.data import Dataset
 from torchvision import transforms
 
 from pathlib import Path
@@ -450,6 +450,9 @@ def parse_args(input_args=None):
         required=False,
         help="Should images be resized to --resolution before training?",
     )
+    parser.add_argument(
+        "--use_xformers", action="store_true", help="Whether or not to use xformers"
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -615,6 +618,10 @@ def main(args):
         )
         break
 
+    if args.use_xformers:
+        set_use_memory_efficient_attention_xformers(unet, True)
+        set_use_memory_efficient_attention_xformers(vae, True)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         if args.train_text_encoder:
diff --git a/train_lora_w_ti.py b/train_lora_w_ti.py
index 868dcff..7e9eb81 100644
--- a/train_lora_w_ti.py
+++ b/train_lora_w_ti.py
@@ -36,9 +36,9 @@
     save_lora_weight,
     extract_lora_ups_down,
 )
-
-from torch.utils.data import Dataset
+from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers
 from PIL import Image
+from torch.utils.data import Dataset
 from torchvision import transforms
 
 from pathlib import Path
@@ -575,6 +575,9 @@ def parse_args(input_args=None):
         required=False,
         help="Should images be resized to --resolution before training?",
     )
+    parser.add_argument(
+        "--use_xformers", action="store_true", help="Whether or not to use xformers"
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
@@ -774,6 +777,10 @@ def main(args):
         print("Before training: text encoder First Layer lora down", _down.weight.data)
         break
 
+    if args.use_xformers:
+        set_use_memory_efficient_attention_xformers(unet, True)
+        set_use_memory_efficient_attention_xformers(vae, True)
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         if args.train_text_encoder:
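
Usage note (not part of the patch above): the sketch below shows one way the new helper could be applied outside the training scripts, e.g. to a diffusers pipeline loaded for inference or a custom training loop. It is a minimal, hypothetical example; the model id "runwayml/stable-diffusion-v1-5", the fp16 dtype, and the CUDA device are placeholder assumptions, not values taken from this diff.

    import torch
    from diffusers import StableDiffusionPipeline

    # New helper introduced by this patch.
    from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers

    # Placeholder model id and device, assumed for illustration only.
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Enable xformers memory-efficient attention; any BasicTransformerBlock whose
    # back-calculated head dimension fails the cached backward probe
    # (test_xformers_backwards) is switched back to standard attention.
    set_use_memory_efficient_attention_xformers(pipe.unet, True)
    set_use_memory_efficient_attention_xformers(pipe.vae, True)

Within the training scripts themselves, the same behaviour is reached by passing the new --use_xformers flag, which calls the helper on the unet and vae before gradient checkpointing is configured.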