-
Notifications
You must be signed in to change notification settings - Fork 486
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add xformers to training scripts (#103)
- Loading branch information
1 parent
7dd0467
commit 4936d0f
Showing
3 changed files
with
88 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
import functools | ||
|
||
import torch | ||
from diffusers.models.attention import BasicTransformerBlock | ||
from diffusers.utils.import_utils import is_xformers_available | ||
|
||
from .lora import LoraInjectedLinear | ||
|
||
if is_xformers_available(): | ||
import xformers | ||
import xformers.ops | ||
else: | ||
xformers = None | ||
|
||
|
||
@functools.cache | ||
def test_xformers_backwards(size): | ||
@torch.enable_grad() | ||
def _grad(size): | ||
q = torch.randn((1, 4, size), device="cuda") | ||
k = torch.randn((1, 4, size), device="cuda") | ||
v = torch.randn((1, 4, size), device="cuda") | ||
|
||
q = q.detach().requires_grad_() | ||
k = k.detach().requires_grad_() | ||
v = v.detach().requires_grad_() | ||
|
||
out = xformers.ops.memory_efficient_attention(q, k, v) | ||
loss = out.sum(2).mean(0).sum() | ||
|
||
return torch.autograd.grad(loss, v) | ||
|
||
try: | ||
_grad(size) | ||
print(size, "pass") | ||
return True | ||
except Exception as e: | ||
print(size, "fail") | ||
return False | ||
|
||
|
||
def set_use_memory_efficient_attention_xformers( | ||
module: torch.nn.Module, valid: bool | ||
) -> None: | ||
def fn_test_dim_head(module: torch.nn.Module): | ||
if isinstance(module, BasicTransformerBlock): | ||
# dim_head isn't stored anywhere, so back-calculate | ||
source = module.attn1.to_v | ||
if isinstance(source, LoraInjectedLinear): | ||
source = source.linear | ||
|
||
dim_head = source.out_features // module.attn1.heads | ||
|
||
result = test_xformers_backwards(dim_head) | ||
|
||
# If dim_head > dim_head_max, turn xformers off | ||
if not result: | ||
module.set_use_memory_efficient_attention_xformers(False) | ||
|
||
for child in module.children(): | ||
fn_test_dim_head(child) | ||
|
||
if not is_xformers_available() and valid: | ||
print("XFormers is not available. Skipping.") | ||
return | ||
|
||
module.set_use_memory_efficient_attention_xformers(valid) | ||
|
||
if valid: | ||
fn_test_dim_head(module) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
4936d0f
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm having trouble running with --use_xformers (Windows).
The error seems to be with how set_use_memory_efficient_attention_xformers is setup.
Traceback (most recent call last):
File "C:\Users<user>\lora\train_lora_dreambooth.py", line 1039, in
main(args)
File "C:\Users<user>\lora\train_lora_dreambooth.py", line 655, in main
set_use_memory_efficient_attention_xformers(unet, True)
File "C:\Users<user>\lora\lora_diffusion\xformers_utils.py", line 67, in set_use_memory_efficient_attention_xformers
module.set_use_memory_efficient_attention_xformers(valid)
File "F:\ANACONDA\envs\sd\lib\site-packages\torch\nn\modules\module.py", line 1207, in getattr
raise AttributeError("'{}' object has no attribute '{}'".format(
AttributeError: 'UNet2DConditionModel' object has no attribute 'set_use_memory_efficient_attention_xformers'