model.py (forked from ddPn08/Radiata)
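# Module docstring added for orientation; it summarizes only what the code
# below actually does.
"""Model wrapper for switching between diffusers and TensorRT inference.

DiffusersModel lazily builds either a Hugging Face diffusers pipeline or a
TensorRT-accelerated pipeline for the same model id, and exposes a
generator-based __call__ that streams (step, partial results) progress
updates before yielding the final results.
"""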
import gc
import os
import random
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from typing import Dict, List, Literal, Optional

import torch

from api.models.diffusion import ImageGenerationOptions
from lib.diffusers.scheduler import SCHEDULERS

from . import config, utils
from .images import save_image
from .shared import hf_diffusers_cache_dir

ModelMode = Literal["diffusers", "tensorrt"]
class DiffusersModel:
    def __init__(self, model_id: str):
        self.model_id: str = model_id
        self.mode: ModelMode = "diffusers"
        self.activated: bool = False
        self.pipe = None

    def available_modes(self):
        modes = ["diffusers"]
        if self.trt_available():
            modes.append("tensorrt")
        return modes
    def get_model_dir(self):
        return os.path.join(
            config.get("model_dir"), self.model_id.replace("/", os.sep)
        )

    def get_trt_path(self):
        return os.path.join(
            config.get("model_dir"),
            "accelerate",
            "tensorrt",
            self.model_id.replace("/", os.sep),
        )
    def trt_available(self):
        trt_path = self.get_trt_path()
        necessary_files = [
            "engine/clip.plan",
            "engine/unet.plan",
            "engine/vae.plan",
            "engine/vae_encoder.plan",
            "onnx/clip.opt.onnx",
            "onnx/unet.opt.onnx",
            "onnx/vae.opt.onnx",
            "onnx/vae_encoder.opt.onnx",
        ]
        # All engine and ONNX artifacts must exist before TensorRT mode is offered.
        for file in necessary_files:
            filepath = os.path.join(trt_path, *file.split("/"))
            if not os.path.exists(filepath):
                return False
        trt_module_status, trt_version_status = utils.tensorrt_is_available()
        if config.get("tensorrt"):
            if not trt_module_status or not trt_version_status:
                return False
        return True
    def activate(self):
        if self.activated:
            return
        if self.mode == "diffusers":
            from .diffusion.pipelines.diffusers import DiffusersPipeline

            self.pipe = DiffusersPipeline.from_pretrained(
                self.model_id,
                use_auth_token=config.get("hf_token"),
                torch_dtype=torch.float16,
                cache_dir=hf_diffusers_cache_dir(),
            ).to(device=torch.device("cuda"))
            self.pipe.enable_attention_slicing()
            if utils.is_installed("xformers") and config.get("xformers"):
                self.pipe.enable_xformers_memory_efficient_attention()
        elif self.mode == "tensorrt":
            from .diffusion.pipelines.tensorrt import TensorRTStableDiffusionPipeline

            model_dir = self.get_trt_path()
            self.pipe = TensorRTStableDiffusionPipeline.from_pretrained(
                model_id=self.model_id,
                engine_dir=os.path.join(model_dir, "engine"),
                use_auth_token=config.get("hf_token"),
                device=torch.device("cuda"),
                max_batch_size=1,
                hf_cache_dir=hf_diffusers_cache_dir(),
            )
        self.activated = True
    def teardown(self):
        if not self.activated:
            return
        self.pipe = None
        gc.collect()
        torch.cuda.empty_cache()
        self.activated = False

    def change_mode(self, mode: ModelMode):
        if mode == self.mode:
            return
        self.teardown()
        self.mode = mode
        self.activate()

    def swap_scheduler(self, scheduler_id: str):
        if not self.activated:
            raise RuntimeError("Model not activated")
        self.pipe.scheduler = SCHEDULERS[scheduler_id].from_config(
            self.pipe.scheduler.config
        )
    def __call__(
        self,
        opts: ImageGenerationOptions,
        plugin_data: Optional[Dict[str, List]] = None,
    ):
        if not self.activated:
            raise RuntimeError("Model not activated")
        if plugin_data is None:  # avoid a shared mutable default argument
            plugin_data = {}

        if opts.seed is None or opts.seed == -1:
            opts.seed = random.randrange(0, 4294967294)

        self.swap_scheduler(opts.scheduler_id)

        queue = Queue()
        done = object()  # sentinel marking the end of one batch
        total_steps = 0
        results = []

        def callback(*args, **kwargs):
            nonlocal total_steps
            total_steps += 1
            queue.put((total_steps, results))

        def on_done(future):
            queue.put(done)

        for i in range(opts.batch_count):
            manual_seed = int(opts.seed + i)
            generator = torch.Generator(device=self.pipe.device).manual_seed(
                manual_seed
            )

            with ThreadPoolExecutor() as executor:
                future = executor.submit(
                    self.pipe,
                    opts=opts,
                    generator=generator,
                    callback=callback,
                    plugin_data=plugin_data,
                )
                future.add_done_callback(on_done)

                # Stream (step, partial results) tuples until this batch finishes.
                while True:
                    item = queue.get()
                    if item is done:
                        break
                    yield item

                images = future.result().images

            results.append(
                (
                    images,
                    ImageGenerationOptions.parse_obj(
                        {"seed": manual_seed, **opts.dict()}
                    ),
                )
            )

            for img in images:
                save_image(img, opts)

        yield results
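

if __name__ == "__main__":
    # Minimal usage sketch of the generator protocol above, not part of the
    # original module. It assumes a CUDA device, downloaded weights, and that
    # ImageGenerationOptions accepts the fields referenced in __call__
    # (seed, batch_count, scheduler_id); the `prompt` field, the model id, and
    # the "euler_a" scheduler key are illustrative assumptions.
    model = DiffusersModel("runwayml/stable-diffusion-v1-5")
    model.activate()
    opts = ImageGenerationOptions(
        prompt="a photo of an astronaut riding a horse",  # hypothetical field
        seed=-1,  # -1 requests a random seed, per __call__
        batch_count=1,
        scheduler_id="euler_a",  # assumed key in SCHEDULERS
    )
    for item in model(opts):
        if isinstance(item, tuple):  # (step, partial results) progress update
            step, _ = item
            print(f"step {step}")
        else:  # final list of (images, options) pairs
            for images, used_opts in item:
                print(f"seed={used_opts.seed}, images={len(images)}")
    model.teardown()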