# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public
# License as published by the Free Software Foundation; either
# version 2.1 of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public
# License along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
# 02110-1301 USA
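"""ComfyUI custom node package that selects the attention implementation used by
the Flux attention path, including a Triton fused-attention option provided by
the bundled fused_attention module."""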
import torch

import comfy
import comfy.ldm.flux.math  # imported explicitly so the attribute chain patched below always resolves
from comfy.ldm.modules.attention import (
    attention_sub_quad,
    attention_pytorch,
    attention_split,
    attention_xformers,
)

from . import fused_attention
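
# Drop-in counterpart to ComfyUI's optimized_attention helpers that routes the
# computation through the Triton kernel in the bundled fused_attention module.
# Inputs are cast to float16 before the fused call and the result is cast back
# to the original dtype.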
def attention_triton(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False):
    if skip_reshape:
        b, _, _, dim_head = q.shape
    else:
        b, _, dim_head = q.shape
        dim_head //= heads
        q, k, v = map(
            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
            (q, k, v),
        )

    dtype = q.dtype
    if dtype != torch.float16:
        q, k, v = map(
            lambda t: t.to(torch.float16),
            (q, k, v),
        )

    out = fused_attention.attention(q, k, v, False, dim_head ** -0.5)
    out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
    return out.to(dtype)
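
# ComfyUI node that lets the user pick which attention implementation the Flux
# attention path uses. Selecting an entry monkey-patches comfy.ldm.flux.math at
# execution time; the model itself is passed through unchanged.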
class AttnSelectorWithTriton:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        available_attns = [
            "triton",
            "xformers",
            "pytorch",
            "split",
            "sub-quad",
        ]
        return {
            "required": {
                "attention": (available_attns,),
                "Model": ("MODEL",),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "test"
    OUTPUT_NODE = True
    CATEGORY = "_for_testing"

    def test(self, attention, Model):
        print("Select optimized attention:", attention)
        if attention == "xformers":
            attention_algorithm = attention_xformers
        elif attention == "pytorch":
            attention_algorithm = attention_pytorch
        elif attention == "split":
            attention_algorithm = attention_split
        elif attention == "sub-quad":
            attention_algorithm = attention_sub_quad
        elif attention == "triton":
            attention_algorithm = attention_triton
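        # comfy.ldm.flux.math imports optimized_attention into its own namespace,
        # so replacing that module attribute swaps the kernel used by the Flux
        # attention path on subsequent sampling runs.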
        comfy.ldm.flux.math.optimized_attention = attention_algorithm
        return (Model,)

# A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = {"AttnSelectorWithTriton": AttnSelectorWithTriton}

# A dictionary that contains the friendly/human-readable titles for the nodes
NODE_DISPLAY_NAME_MAPPINGS = {"AttnSelectorWithTriton": "Attention selector (with Triton)"}

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]