forked from ddPn08/Radiata
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodel_manager.py
91 lines (68 loc) · 2.53 KB
/
model_manager.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import os
from glob import glob
from typing import *
from huggingface_hub import HfApi, ModelFilter
from modules.logger import logger
from modules.shared import ROOT_DIR
from . import config
from .model import DiffusersModel
# Every model known to this module; entries are appended by init() and
# add_model().
sd_models: List[DiffusersModel] = []

# The currently selected model; None until a model has been selected.
sd_model: Optional[DiffusersModel] = None

# Backend mode, decided once at import time from the "deepfloyd_if" config
# value.
mode: Literal["stable-diffusion", "deepfloyd_if"]
if config.get("deepfloyd_if"):
    mode = "deepfloyd_if"
else:
    mode = "stable-diffusion"
def get_model(model_id: str):
    """Look up a registered model by its identifier.

    Args:
        model_id: Identifier compared against each entry's ``model_id``.

    Returns:
        The first entry of ``sd_models`` whose ``model_id`` equals
        *model_id*, or ``None`` when no entry matches.
    """
    return next(
        (candidate for candidate in sd_models if candidate.model_id == model_id),
        None,
    )
def add_model(model_id: str):
    """Register a model and persist the registered ids to the config.

    Re-adding an id that is already registered does not create a second
    entry: ``set_model`` rejects ids that match more than one entry, so a
    duplicate would make that model impossible to select.

    Args:
        model_id: Identifier passed to ``DiffusersModel``.
    """
    # The list is mutated in place, so no ``global`` declaration is needed.
    if all(entry.model_id != model_id for entry in sd_models):
        sd_models.append(DiffusersModel(model_id=model_id))
    # Persist the full list of ids under the "models" config key.
    config.set("models", [entry.model_id for entry in sd_models])
def set_model(model_id: str):
    """Make the model identified by *model_id* the active model.

    The currently selected model (if any) is torn down first. The requested
    model is then activated and its id is stored under the ``"model"``
    config key. Any failure while resolving or activating the model is
    logged and ``set_default_model()`` is invoked as the fallback.

    Args:
        model_id: Identifier of an entry in ``sd_models``.
    """
    global sd_model
    # ``sd_model`` starts out as None, so there may be nothing to tear down.
    if sd_model is not None:
        sd_model.teardown()
    try:
        # Resolve into a local so the global never holds the intermediate
        # list, even when the lookup fails.
        matches = [x for x in sd_models if x.model_id == model_id]
        if len(matches) != 1:
            raise ValueError("Model not found or multiple models with same ID.")
        sd_model = matches[0]
        logger.info(f"Loading {sd_model.model_id}...")
        sd_model.activate()
        config.set("model", sd_model.model_id)
        logger.info(f"Loaded {sd_model.model_id}...")
    except Exception as e:
        logger.error(f"Failed to load {model_id}...")
        logger.error(e)
        # NOTE(review): set_default_model() calls set_model() again; if the
        # default model itself cannot be loaded this re-enters without a
        # bound — confirm whether a retry limit is wanted.
        set_default_model()
def search_model(model_id: str):
    """Query the Hugging Face Hub for models matching a name.

    Args:
        model_id: Value passed as ``model_name`` to ``ModelFilter``.

    Returns:
        The result of ``HfApi.list_models`` for that filter.
    """
    return HfApi().list_models(filter=ModelFilter(model_name=model_id))
def set_default_model():
    """Select the fallback model and activate it through ``set_model``.

    The model whose id is stored under the ``"model"`` config key is
    preferred when exactly one registered model has that id; otherwise the
    first entry of ``sd_models`` is used.

    Raises:
        IndexError: If ``sd_models`` is empty, so there is nothing to
            fall back to.
    """
    global sd_model
    if not sd_models:
        raise IndexError("No models are registered; cannot select a default model.")
    prev = config.get("model")
    matches = [x for x in sd_models if x.model_id == prev]
    # Bind the global to a model (never to the intermediate list) before
    # delegating: set_model() calls teardown() on whatever sd_model holds.
    sd_model = matches[0] if len(matches) == 1 else sd_models[0]
    set_model(sd_model.model_id)
def init():
    """Populate ``sd_models`` and select the default model.

    Does nothing unless ``mode`` is ``"stable-diffusion"``. Otherwise the
    ids under the ``"models"`` config key (or ``config.DEFAULT_CONFIG``'s
    ``"models"`` when that is empty or unset) are registered, followed by
    every ``.safetensors`` / ``.ckpt`` path found recursively under
    ``<ROOT_DIR>/models/checkpoints`` whose id is not already in that list.
    Finally ``set_default_model()`` is called.
    """
    if mode != "stable-diffusion":
        return

    configured_ids = config.get("models") or []
    if len(configured_ids) == 0:
        configured_ids = config.DEFAULT_CONFIG["models"]
    sd_models.extend(
        DiffusersModel(model_id=configured) for configured in configured_ids
    )

    checkpoints_root = os.path.join(ROOT_DIR, "models", "checkpoints")
    pattern = os.path.join(checkpoints_root, "**", "*")
    for found in glob(pattern, recursive=True):
        if not found.endswith((".safetensors", ".ckpt")):
            continue
        # The id is the path relative to the checkpoints directory, always
        # written with forward slashes.
        checkpoint_id = os.path.relpath(found, checkpoints_root).replace(os.sep, "/")
        if checkpoint_id in configured_ids:
            continue
        sd_models.append(DiffusersModel(model_id=checkpoint_id))

    set_default_model()