forked from ddPn08/Radiata
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimages.py
76 lines (58 loc) · 2.14 KB
/
images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import glob
import json
import os
import re
from datetime import datetime
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from api.models.diffusion import ImageGenerationOptions
from modules import config
def get_category(opts: ImageGenerationOptions):
    """Return the category name for a set of generation options.

    The result is ``"txt2img"`` when ``opts.image`` is ``None`` and
    ``"img2img"`` otherwise.
    """
    if opts.image is None:
        return "txt2img"
    return "img2img"
def replace_invalid_chars(filepath, replace_with="_"):
    """Replace characters that are not allowed in file names.

    Every backslash, forward slash, colon, asterisk, question mark, double
    quote, angle bracket and pipe in *filepath* is substituted with
    *replace_with*.

    Args:
        filepath: The string to sanitize.
        replace_with: Replacement passed to ``re.sub`` for each match.

    Returns:
        The sanitized string.
    """
    # The pattern must be a raw string: in a normal literal "\\/" reaches the
    # regex engine as "\/", which is only an escaped slash, so a literal
    # backslash would slip through unreplaced.
    invalid_chars = r'[\\/:*?"<>|]'
    return re.sub(invalid_chars, replace_with, filepath)
def save_image(img: Image.Image, opts: ImageGenerationOptions):
    """Save *img* to the configured directory and return its file name.

    The target directory and the file-name template are read from the
    ``images/<category>/save_dir`` and ``images/<category>/save_name``
    config entries, where the category comes from ``get_category(opts)``.
    The template may use the ``seed``, ``index``, ``prompt`` and ``date``
    format fields. ``opts.json()`` is attached to the file as a
    ``"parameters"`` text entry through ``PngInfo``.

    Args:
        img: The image to write.
        opts: The generation options that produced the image.

    Returns:
        The base name of the written file.
    """
    category = get_category(opts)
    save_dir = config.get(f"images/{category}/save_dir")
    name_template: str = config.get(f"images/{category}/save_name")

    # Create the directory before counting its entries so the index is always
    # "existing entries + 1". Counting first made a missing directory yield 0
    # while an existing empty one yielded 1, so index 1 was skipped on a
    # fresh directory (0, 2, 3, ...).
    os.makedirs(save_dir, exist_ok=True)
    index = len(os.listdir(save_dir)) + 1

    filename = name_template.format(
        seed=opts.seed,
        index=index,
        # Only the first 20 characters of the prompt go into the name.
        prompt=opts.prompt[:20].replace(" ", "_"),
        date=datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
    )
    # Line breaks and tabs (e.g. from a multi-line prompt) become underscores.
    for control_char in ("\n", "\r", "\t"):
        filename = filename.replace(control_char, "_")
    filename = replace_invalid_chars(filename)

    metadata = PngInfo()
    metadata.add_text("parameters", opts.json())

    filepath = os.path.join(save_dir, filename)
    # NOTE(review): the output format is chosen by Pillow from the file
    # extension, so the save_name template presumably ends in ".png" - confirm.
    img.save(filepath, pnginfo=metadata)
    return os.path.basename(filepath)
def get_image_filepath(category: str, filename: str):
    """Build the path of *filename* inside the save directory of *category*.

    The directory is read from the ``images/<category>/save_dir`` config
    entry and joined with *filename*.
    """
    return os.path.join(config.get(f"images/{category}/save_dir"), filename)
def get_image(category: str, filename: str):
    """Open the image *filename* stored under *category*.

    The path is resolved with ``get_image_filepath`` and handed to
    ``Image.open``; the opened image object is returned.
    """
    path = get_image_filepath(category, filename)
    return Image.open(path)
def get_image_parameter(img: Image.Image):
    """Return the text metadata of *img* with ``"parameters"`` expanded.

    The ``"parameters"`` entry of ``img.text`` is parsed as JSON. When it
    decodes to an object, its keys are merged into the result in place of
    the ``"parameters"`` entry. Otherwise (entry missing, not valid JSON, or
    not a JSON object) the raw value is kept under ``"parameters"``; a
    missing entry therefore appears as ``None``.

    Args:
        img: An image exposing a ``text`` mapping.

    Returns:
        A new dict; ``img.text`` itself is not modified.
    """
    # Work on a copy: popping from and updating img.text directly would alter
    # the image's own metadata and make a second call return something else.
    text = dict(img.text)
    parameters = text.pop("parameters", None)
    try:
        decoded = json.loads(parameters)
    except (TypeError, ValueError):
        # TypeError: parameters is None or not str/bytes.
        # ValueError: malformed JSON (json.JSONDecodeError subclasses it).
        decoded = None
    if isinstance(decoded, dict):
        text.update(decoded)
    else:
        text["parameters"] = parameters
    return text
def get_all_image_files(category: str):
    """List the files in the save directory of *category*, newest first.

    The directory is read from the ``images/<category>/save_dir`` config
    entry. Only regular files matched by the ``*`` glob are included. They
    are ordered by descending modification time and returned as paths
    relative to that directory.
    """
    save_dir = config.get(f"images/{category}/save_dir")
    candidates = []
    for path in glob.glob(os.path.join(save_dir, "*")):
        if os.path.isfile(path):
            # Normalize the platform separator to forward slashes.
            candidates.append(path.replace(os.sep, "/"))
    # Ascending by mtime, then walked backwards so the newest comes first.
    candidates.sort(key=os.path.getmtime)
    return [os.path.relpath(path, save_dir) for path in reversed(candidates)]