# Standard library
import base64
import copy
import io
import json
import os
import re
import shutil
import tempfile
from io import StringIO

# Third-party
import pandas as pd
import PyPDF2
import torch
from diffusers import StableDiffusionPipeline, DPMSolverSinglestepScheduler, UNet2DConditionModel
from docx import Document
from huggingface_hub import InferenceClient
from langchain_core.prompts import PromptTemplate
from peft import LoraConfig, get_peft_model, PeftModel
from PIL import Image
from transformers import AutoTokenizer

def image_to_base64(img):
    """Encode a PIL image as a base64 PNG string."""
    buffered = io.BytesIO()
    img.save(buffered, format="PNG")  # adjust format as needed
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

def base64_to_image(base64_str):
    """Decode a base64 string back into a PIL image."""
    img_data = base64.b64decode(base64_str)
    return Image.open(io.BytesIO(img_data))

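# Round-trip sketch (assumes a PIL image `img` is already in scope):
#   restored = base64_to_image(image_to_base64(img))
#   assert restored.size == img.size
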
def serialize_data(obj):
    """JSON `default` hook: encode PIL images as base64, reject anything else."""
    if isinstance(obj, Image.Image):
        return {"data": image_to_base64(obj)}
    # json expects the default hook to raise for unsupported types;
    # returning None here would silently write null into the log.
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

def save_log(messages, filename='log_dump.txt'):
    """Dump chat messages to a JSON text file."""
    with open(filename, 'w') as txt:
        json.dump(messages, txt)

def read_log(filename='log_dump.txt'):
    """Load chat messages back from the JSON text file."""
    with open(filename, 'r') as log:
        messages = json.load(log)
    return messages

def update_session_log(session_log: dict, log_path: str):
    """Persist the session log; PIL images are handled by serialize_data."""
    with open(log_path, 'w') as log:
        json.dump(session_log, log, default=serialize_data)  # custom serialization

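# Usage sketch (hypothetical session dict; any PIL image values are encoded
# to base64 via serialize_data):
#   update_session_log({"history": ["hi"], "last_image": some_pil_image}, log_path)
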
def send_img_request(prompt, negative_prompt,
                     model_path='../stable-diffusion-webui/models/Stable-diffusion',
                     custom_weights='realisticVisionV60B1_v51HyperVAE.safetensors',
                     inference_steps=6,
                     cfg_scale=2,
                     use_karras=True,
                     seed=0, use_lora=True,
                     width=512, height=512):
    """Generate an image with a local Stable Diffusion checkpoint, optionally stacking LoRAs."""
    custom_weights = os.path.join(model_path, custom_weights)
    # Clear VRAM before loading the pipeline
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        print(f'Running on GPU {torch.cuda.get_device_name(0)}')
        device = "cuda"
    else:
        print('No cuda found, running on CPU')
        device = "cpu"
    # Load the custom checkpoint (float16 weights assume a GPU; use float32 on CPU)
    pipe = StableDiffusionPipeline.from_single_file(custom_weights, torch_dtype=torch.float16).to(device)
    if use_lora:
        lora_dir = '../stable-diffusion-webui/models/Lora'
        lora_files = ['add_detail.safetensors', 'neon_palette_offset.safetensors', 'more_details.safetensors']
        lora_weights = [1.0, 1.0, 0.5]
        lora_adapters = [os.path.splitext(f)[0] for f in lora_files]
        # Load each LoRA file and register it under its adapter name
        for lora_file, adapter in zip(lora_files, lora_adapters):
            print(lora_file, adapter)
            pipe.load_lora_weights(lora_dir, weight_name=lora_file, adapter_name=adapter)
        pipe.set_adapters(lora_adapters, adapter_weights=lora_weights)
    # Scheduler: DPM++ single-step, optionally with Karras sigmas
    print("scheduling")
    pipe.scheduler = DPMSolverSinglestepScheduler.from_config(pipe.scheduler.config,
                                                              use_karras_sigmas=use_karras)
    # Seed assignment for reproducibility
    torch.manual_seed(seed)
    # Generate the image
    image = pipe(prompt, negative_prompt=negative_prompt, num_inference_steps=inference_steps,
                 guidance_scale=cfg_scale, width=width, height=height).images[0]
    print(image.size)
    return image

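# Usage sketch (hypothetical prompt; assumes the default checkpoint and LoRA
# files above exist on disk):
#   img = send_img_request(
#       prompt="portrait photo of an astronaut, neon palette, highly detailed",
#       negative_prompt="blurry, low quality, deformed",
#       seed=42,
#   )
#   img.save("astronaut.png")
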
def ask_img(prompt):
    # Heuristic: a creation verb followed within five words by an image-like
    # noun. Common misspellings ('potrait', 'potrayal') are matched on purpose.
    pattern = r'\b(craft|design|generate|create|draw|illustrate|render|make)\b(?:\s+\w+){0,5}\s+\b(potrait|portrait|image|picture|pic|pics|img|photo|art|painting|paintings|design|potrayal|portrayal)\b'
    return bool(re.search(pattern, prompt, re.IGNORECASE))

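# Examples:
#   ask_img("please draw me a picture of a cat")   # True
#   ask_img("summarize this document for me")      # False
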
def update_data(dataframe: pd.DataFrame, dir: str) -> pd.DataFrame:
    """Rebuild the file table from the directory listing, one row per file."""
    try:
        filenames = os.listdir(dir)
        dataframe = pd.DataFrame({"Filename": filenames, "checkbox": [True] * len(filenames)})
        return dataframe
    except Exception as e:
        print(f"Encountered an error while trying to rebuild the dataframe. {e}")

def create_session_dir():
    # The caller must keep a reference to the returned TemporaryDirectory
    # (and use its .name attribute as the path); it is removed once collected.
    return tempfile.TemporaryDirectory()

def create_session_log(dir: str):
    """Create an empty JSON log file inside the session directory."""
    log = os.path.join(dir, 'log.txt')
    with open(log, 'w') as file:
        file.write(json.dumps({}))  # start from an empty dictionary
    return log

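# Typical session setup (the directory object must stay referenced for the
# lifetime of the session):
#   session_dir = create_session_dir()
#   log_path = create_session_log(session_dir.name)
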
def save_session_file(file, name: str, dir_name: str):
    """Write raw uploaded bytes into the session directory."""
    try:
        with open(os.path.join(dir_name, name), 'wb') as uf:
            uf.write(file)
        print("Session file saved..")
    except Exception as e:
        print(f"Error saving session file.. {e}")

def save_to_database(filename: str, filedir: str):
    """Copy a session file into the persistent ./data directory."""
    try:
        file = os.path.join(filedir, filename)
        shutil.copy(file, './data')
        return "Requested file successfully saved to database. Table will be updated when you refresh the browser 🤖"
    except FileNotFoundError:
        return "Error saving file to database. File not found."
    except Exception as e:
        # Return any other error message to the caller
        return f"An error occurred: {e}"

def read_session_file(filename: str, dir: str) -> str:
    """Read a session file (csv/xlsx/txt/pdf/docx) and return its contents."""
    try:
        filepath = os.path.join(dir, filename)  # directory first, then filename
        content = None
        if filepath.endswith('.csv'):
            content = pd.read_csv(filepath)
        elif filepath.endswith('.xlsx'):
            content = pd.read_excel(filepath)  # requires openpyxl for .xlsx
        elif filepath.endswith('.txt'):
            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read()
        elif filepath.endswith('.pdf'):
            content = ""
            pdf_reader = PyPDF2.PdfReader(filepath)
            for page in pdf_reader.pages:
                content += page.extract_text()
        elif filepath.endswith('.docx'):
            content = ""
            doc = Document(filepath)
            for para in doc.paragraphs:
                content += para.text + "\n"
        return content
    except Exception as e:
        print(f"Error has occurred.. {e}")

# def tokenizer(model_name: str, prompt: str) -> int:
#     # model_name = "meta-llama/Meta-Llama-3-70B-Instruct"
#     tokenizer = AutoTokenizer.from_pretrained(model_name)
#     return tokenizer
#     # tokens = tokenizer.encode(prompt, add_special_tokens=True)
#     # return len(tokens)