-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathblora_utils.py
46 lines (33 loc) · 1.37 KB
/
blora_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from typing import Optional
# Dotted module-name prefixes of the UNet attention blocks assigned to each role
# ('content' / 'style'). get_target_modules() drops the leading 'unet.' component
# before matching attention-processor names; the full 'unet.'-prefixed form is
# presumably what appears in LoRA state-dict keys (TODO confirm against callers).
BLOCKS = dict(
    content=['unet.up_blocks.0.attentions.0'],
    style=['unet.up_blocks.0.attentions.1'],
)
def is_belong_to_blocks(key, blocks) -> bool:
    """Return True if any entry of *blocks* occurs as a substring of *key*.

    Args:
        key: the name to test (in this module, a LoRA state-dict key or an
            attention-processor name).
        blocks: iterable of substrings to look for. An empty iterable yields
            False. Note that a plain ``str`` is iterated character by character.

    Returns:
        True on the first matching entry (short-circuits), otherwise False.

    Raises:
        The same exception type that occurred (e.g. ``TypeError`` when *blocks*
        is not iterable), re-raised with a message naming this function.
    """
    try:
        return any(block in key for block in blocks)
    except Exception as e:
        raise type(e)(f'failed to is_belong_to_blocks, due to: {e}')
def filter_lora(state_dict, blocks_):
    """Return a new dict holding only the entries of *state_dict* whose key matches *blocks_*.

    A key matches when ``is_belong_to_blocks(key, blocks_)`` is true, i.e. some
    entry of *blocks_* is a substring of the key. Surviving entries keep their
    original insertion order; *state_dict* itself is not modified.

    Raises:
        The same exception type that occurred, re-raised with a message naming
        this function.
    """
    try:
        selected = {}
        for name, tensor in state_dict.items():
            if is_belong_to_blocks(name, blocks_):
                selected[name] = tensor
        return selected
    except Exception as e:
        raise type(e)(f'failed to filter_lora, due to: {e}')
def scale_lora(state_dict, alpha):
    """Return a new dict in which every value of *state_dict* is multiplied by *alpha*.

    Keys and their insertion order are preserved; each value is computed as
    ``value * alpha``. *state_dict* itself is not modified.

    Raises:
        The same exception type that occurred (e.g. ``TypeError`` if a value
        does not support multiplication by *alpha*), re-raised with a message
        naming this function.
    """
    try:
        scaled = {}
        for name, tensor in state_dict.items():
            scaled[name] = tensor * alpha
        return scaled
    except Exception as e:
        raise type(e)(f'failed to scale_lora, due to: {e}')
def get_target_modules(unet, blocks=None):
    """Build the list of attention projection module names to target with LoRA.

    Args:
        unet: object exposing an ``attn_processors`` mapping. Each key is a
            processor name; its last dotted component is dropped to obtain the
            name of the owning attention module.
        blocks: iterable of name fragments. A processor is selected when any
            fragment is a substring of its name. If falsy (``None`` or empty),
            defaults to the ``BLOCKS['content'] + BLOCKS['style']`` entries with
            their first dotted component removed.

    Returns:
        A list of ``"<attention module>.<projection>"`` strings, grouped by
        projection in the order ``to_k``, ``to_q``, ``to_v``, ``to_out.0``.

    Raises:
        The same exception type that occurred, re-raised with a message naming
        this function.
    """
    try:
        if not blocks:
            # Strip the leading component ('unet.') from each BLOCKS entry,
            # presumably because attn_processors keys lack that prefix.
            blocks = [name.partition('.')[2] for name in BLOCKS['content'] + BLOCKS['style']]

        attention_modules = []
        for processor_name in unet.attn_processors:
            if is_belong_to_blocks(processor_name, blocks):
                attention_modules.append(processor_name.rsplit('.', 1)[0])

        projections = ("to_k", "to_q", "to_v", "to_out.0")
        return [f'{module}.{proj}' for proj in projections for module in attention_modules]
    except Exception as e:
        raise type(e)(f'failed to get_target_modules, due to: {e}')