Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add xformers to training scripts #103

Merged
merged 1 commit into from
Dec 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions lora_diffusion/xformers_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import functools

import torch
from diffusers.models.attention import BasicTransformerBlock
from diffusers.utils.import_utils import is_xformers_available

from .lora import LoraInjectedLinear

if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None


@functools.cache
def test_xformers_backwards(size):
@torch.enable_grad()
def _grad(size):
q = torch.randn((1, 4, size), device="cuda")
k = torch.randn((1, 4, size), device="cuda")
v = torch.randn((1, 4, size), device="cuda")

q = q.detach().requires_grad_()
k = k.detach().requires_grad_()
v = v.detach().requires_grad_()

out = xformers.ops.memory_efficient_attention(q, k, v)
loss = out.sum(2).mean(0).sum()

return torch.autograd.grad(loss, v)

try:
_grad(size)
print(size, "pass")
return True
except Exception as e:
print(size, "fail")
return False


def set_use_memory_efficient_attention_xformers(
module: torch.nn.Module, valid: bool
) -> None:
def fn_test_dim_head(module: torch.nn.Module):
if isinstance(module, BasicTransformerBlock):
# dim_head isn't stored anywhere, so back-calculate
source = module.attn1.to_v
if isinstance(source, LoraInjectedLinear):
source = source.linear

dim_head = source.out_features // module.attn1.heads

result = test_xformers_backwards(dim_head)

# If dim_head > dim_head_max, turn xformers off
if not result:
module.set_use_memory_efficient_attention_xformers(False)

for child in module.children():
fn_test_dim_head(child)

if not is_xformers_available() and valid:
print("XFormers is not available. Skipping.")
return

module.set_use_memory_efficient_attention_xformers(valid)

if valid:
fn_test_dim_head(module)
11 changes: 9 additions & 2 deletions train_lora_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
save_lora_weight,
save_safeloras,
)

from torch.utils.data import Dataset
from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from pathlib import Path
Expand Down Expand Up @@ -450,6 +450,9 @@ def parse_args(input_args=None):
required=False,
help="Should images be resized to --resolution before training?",
)
parser.add_argument(
"--use_xformers", action="store_true", help="Whether or not to use xformers"
)

if input_args is not None:
args = parser.parse_args(input_args)
Expand Down Expand Up @@ -615,6 +618,10 @@ def main(args):
)
break

if args.use_xformers:
set_use_memory_efficient_attention_xformers(unet, True)
set_use_memory_efficient_attention_xformers(vae, True)

if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.train_text_encoder:
Expand Down
11 changes: 9 additions & 2 deletions train_lora_w_ti.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
save_lora_weight,
extract_lora_ups_down,
)

from torch.utils.data import Dataset
from lora_diffusion.xformers_utils import set_use_memory_efficient_attention_xformers
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from pathlib import Path
Expand Down Expand Up @@ -575,6 +575,9 @@ def parse_args(input_args=None):
required=False,
help="Should images be resized to --resolution before training?",
)
parser.add_argument(
"--use_xformers", action="store_true", help="Whether or not to use xformers"
)

if input_args is not None:
args = parser.parse_args(input_args)
Expand Down Expand Up @@ -774,6 +777,10 @@ def main(args):
print("Before training: text encoder First Layer lora down", _down.weight.data)
break

if args.use_xformers:
set_use_memory_efficient_attention_xformers(unet, True)
set_use_memory_efficient_attention_xformers(vae, True)

if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.train_text_encoder:
Expand Down