Skip to content

Commit

Permalink
refactor: move functions into classes
Browse files Browse the repository at this point in the history
- move functions into existing classes and files
- move AutoFP8 dialog out of a function and into __init__
  • Loading branch information
leafspark committed Sep 9, 2024
1 parent e46c626 commit be38e35
Show file tree
Hide file tree
Showing 6 changed files with 195 additions and 173 deletions.
236 changes: 64 additions & 172 deletions src/AutoGGUF.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
open_file_safe,
resource_path,
show_about,
load_dotenv,
)


Expand All @@ -48,7 +49,7 @@ def __init__(self, args: List[str]) -> None:
self.setGeometry(100, 100, width, height)
self.setWindowFlag(Qt.FramelessWindowHint)

self.load_dotenv() # Loads the .env file
load_dotenv(self) # Loads the .env file

# Configuration
self.model_dir_name = os.environ.get("AUTOGGUF_MODEL_DIR_NAME", "models")
Expand Down Expand Up @@ -117,6 +118,18 @@ def __init__(self, args: List[str]) -> None:
self.delete_lora_adapter_item = partial(
lora_conversion.delete_lora_adapter_item, self
)
self.lora_conversion_finished = partial(
lora_conversion.lora_conversion_finished, self
)
self.parse_progress = partial(QuantizationThread.parse_progress, self)
self.create_label = partial(ui_update.create_label, self)
self.browse_imatrix_datafile = ui_update.browse_imatrix_datafile.__get__(self)
self.browse_imatrix_model = ui_update.browse_imatrix_model.__get__(self)
self.browse_imatrix_output = ui_update.browse_imatrix_output.__get__(self)
self.restart_task = partial(TaskListItem.restart_task, self)
self.browse_hf_outfile = ui_update.browse_hf_outfile.__get__(self)
self.browse_hf_model_input = ui_update.browse_hf_model_input.__get__(self)
self.browse_base_model = ui_update.browse_base_model.__get__(self)

# Set up main widget and layout
main_widget = QWidget()
Expand Down Expand Up @@ -154,11 +167,56 @@ def __init__(self, args: List[str]) -> None:
about_action.triggered.connect(self.show_about)
help_menu.addAction(about_action)

# AutoFP8 Window
self.fp8_dialog = QDialog(self)
self.fp8_dialog.setWindowTitle(QUANTIZE_TO_FP8_DYNAMIC)
self.fp8_dialog.setFixedWidth(500)
self.fp8_layout = QVBoxLayout()

# Input path
input_layout = QHBoxLayout()
self.fp8_input = QLineEdit()
input_button = QPushButton(BROWSE)
input_button.clicked.connect(
lambda: self.fp8_input.setText(
QFileDialog.getExistingDirectory(self, OPEN_MODEL_FOLDER)
)
)
input_layout.addWidget(QLabel(INPUT_MODEL))
input_layout.addWidget(self.fp8_input)
input_layout.addWidget(input_button)
self.fp8_layout.addLayout(input_layout)

# Output path
output_layout = QHBoxLayout()
self.fp8_output = QLineEdit()
output_button = QPushButton(BROWSE)
output_button.clicked.connect(
lambda: self.fp8_output.setText(
QFileDialog.getExistingDirectory(self, OPEN_MODEL_FOLDER)
)
)
output_layout.addWidget(QLabel(OUTPUT))
output_layout.addWidget(self.fp8_output)
output_layout.addWidget(output_button)
self.fp8_layout.addLayout(output_layout)

# Quantize button
quantize_button = QPushButton(QUANTIZE)
quantize_button.clicked.connect(
lambda: self.quantize_to_fp8_dynamic(
self.fp8_input.text(), self.fp8_output.text()
)
)

self.fp8_layout.addWidget(quantize_button)
self.fp8_dialog.setLayout(self.fp8_layout)

# Tools menu
tools_menu = self.menubar.addMenu("&Tools")
autofp8_action = QAction("&AutoFP8", self)
autofp8_action.setShortcut(QKeySequence("Shift+Q"))
autofp8_action.triggered.connect(self.show_autofp8_window)
autofp8_action.triggered.connect(self.fp8_dialog.exec)
tools_menu.addAction(autofp8_action)

# Content widget
Expand Down Expand Up @@ -744,18 +802,16 @@ def __init__(self, args: List[str]) -> None:
self.hf_no_lazy = QCheckBox(NO_LAZY_EVALUATION)
hf_to_gguf_layout.addRow(self.hf_no_lazy)

self.hf_model_name = QLineEdit()
hf_to_gguf_layout.addRow(MODEL_NAME, self.hf_model_name)

self.hf_verbose = QCheckBox(VERBOSE)
hf_to_gguf_layout.addRow(self.hf_verbose)
self.hf_dry_run = QCheckBox(DRY_RUN)
hf_to_gguf_layout.addRow(self.hf_dry_run)
self.hf_model_name = QLineEdit()
hf_to_gguf_layout.addRow(MODEL_NAME, self.hf_model_name)

self.hf_split_max_size = QLineEdit()
hf_to_gguf_layout.addRow(SPLIT_MAX_SIZE, self.hf_split_max_size)

self.hf_dry_run = QCheckBox(DRY_RUN)
hf_to_gguf_layout.addRow(self.hf_dry_run)

hf_to_gguf_convert_button = QPushButton(CONVERT_HF_TO_GGUF)
hf_to_gguf_convert_button.clicked.connect(self.convert_hf_to_gguf)
hf_to_gguf_layout.addRow(hf_to_gguf_convert_button)
Expand Down Expand Up @@ -812,41 +868,6 @@ def __init__(self, args: List[str]) -> None:
self.logger.info(AUTOGGUF_INITIALIZATION_COMPLETE)
self.logger.info(STARTUP_ELASPED_TIME.format(init_timer.elapsed()))

def load_dotenv(self):
if not os.path.isfile(".env"):
self.logger.warning(DOTENV_FILE_NOT_FOUND)
return

try:
with open(".env") as f:
for line in f:
# Strip leading/trailing whitespace
line = line.strip()

# Ignore comments and empty lines
if not line or line.startswith("#"):
continue

# Match key-value pairs (unquoted and quoted values)
match = re.match(r"^([^=]+)=(.*)$", line)
if not match:
self.logger.warning(COULD_NOT_PARSE_LINE.format(line))
continue

key, value = match.groups()

# Remove any surrounding quotes from the value
if value.startswith(("'", '"')) and value.endswith(("'", '"')):
value = value[1:-1]

# Decode escape sequences
value = bytes(value, "utf-8").decode("unicode_escape")

# Set the environment variable
os.environ[key.strip()] = value.strip()
except Exception as e:
self.logger.error(ERROR_LOADING_DOTENV.format(e))

def load_plugins(self) -> Dict[str, Dict[str, Any]]:
plugins = {}
plugin_dir = "plugins"
Expand Down Expand Up @@ -1038,28 +1059,6 @@ def save_task_preset(self, task_item) -> None:
)
break

def browse_base_model(self) -> None:
self.logger.info(BROWSING_FOR_BASE_MODEL_FOLDER) # Updated log message
base_model_folder = QFileDialog.getExistingDirectory(
self, SELECT_BASE_MODEL_FOLDER
)
if base_model_folder:
self.base_model_path.setText(os.path.abspath(base_model_folder))

def browse_hf_model_input(self) -> None:
self.logger.info(BROWSE_FOR_HF_MODEL_DIRECTORY)
model_dir = QFileDialog.getExistingDirectory(self, SELECT_HF_MODEL_DIRECTORY)
if model_dir:
self.hf_model_input.setText(os.path.abspath(model_dir))

def browse_hf_outfile(self) -> None:
self.logger.info(BROWSE_FOR_HF_TO_GGUF_OUTPUT)
outfile, _ = QFileDialog.getSaveFileName(
self, SELECT_OUTPUT_FILE, "", GGUF_FILES
)
if outfile:
self.hf_outfile.setText(os.path.abspath(outfile))

def quantize_to_fp8_dynamic(self, model_dir: str, output_dir: str) -> None:
self.logger.info(
QUANTIZING_TO_WITH_AUTOFP8.format(os.path.basename(model_dir), output_dir)
Expand Down Expand Up @@ -1107,52 +1106,6 @@ def quantize_to_fp8_dynamic(self, model_dir: str, output_dir: str) -> None:
show_error(self.logger, f"{ERROR_STARTING_AUTOFP8_QUANTIZATION}: {e}")
self.logger.info(AUTOFP8_QUANTIZATION_TASK_STARTED)

def show_autofp8_window(self):
dialog = QDialog(self)
dialog.setWindowTitle(QUANTIZE_TO_FP8_DYNAMIC)
dialog.setFixedWidth(500)
layout = QVBoxLayout()

# Input path
input_layout = QHBoxLayout()
self.fp8_input = QLineEdit()
input_button = QPushButton(BROWSE)
input_button.clicked.connect(
lambda: self.fp8_input.setText(
QFileDialog.getExistingDirectory(self, OPEN_MODEL_FOLDER)
)
)
input_layout.addWidget(QLabel(INPUT_MODEL))
input_layout.addWidget(self.fp8_input)
input_layout.addWidget(input_button)
layout.addLayout(input_layout)

# Output path
output_layout = QHBoxLayout()
self.fp8_output = QLineEdit()
output_button = QPushButton(BROWSE)
output_button.clicked.connect(
lambda: self.fp8_output.setText(
QFileDialog.getExistingDirectory(self, OPEN_MODEL_FOLDER)
)
)
output_layout.addWidget(QLabel(OUTPUT))
output_layout.addWidget(self.fp8_output)
output_layout.addWidget(output_button)
layout.addLayout(output_layout)

# Quantize button
quantize_button = QPushButton(QUANTIZE)
quantize_button.clicked.connect(
lambda: self.quantize_to_fp8_dynamic(
self.fp8_input.text(), self.fp8_output.text()
)
)
layout.addWidget(quantize_button)

dialog.setLayout(layout)
dialog.exec()

def convert_hf_to_gguf(self) -> None:
self.logger.info(STARTING_HF_TO_GGUF_CONVERSION)
try:
Expand Down Expand Up @@ -1229,31 +1182,6 @@ def convert_hf_to_gguf(self) -> None:
show_error(self.logger, ERROR_STARTING_HF_TO_GGUF_CONVERSION.format(str(e)))
self.logger.info(HF_TO_GGUF_CONVERSION_TASK_STARTED)

def restart_task(self, task_item) -> None:
self.logger.info(RESTARTING_TASK.format(task_item.task_name))
for thread in self.quant_threads:
if thread.log_file == task_item.log_file:
new_thread = QuantizationThread(
thread.command, thread.cwd, thread.log_file
)
self.quant_threads.append(new_thread)
new_thread.status_signal.connect(task_item.update_status)
new_thread.finished_signal.connect(
lambda: self.task_finished(new_thread, task_item)
)
new_thread.error_signal.connect(
lambda err: handle_error(self.logger, err, task_item)
)
new_thread.model_info_signal.connect(self.update_model_info)
new_thread.start()
task_item.update_status(IN_PROGRESS)
break

def lora_conversion_finished(self, thread) -> None:
self.logger.info(LORA_CONVERSION_FINISHED)
if thread in self.quant_threads:
self.quant_threads.remove(thread)

def download_finished(self, extract_dir) -> None:
self.logger.info(DOWNLOAD_FINISHED_EXTRACTED_TO.format(extract_dir))
self.download_button.setEnabled(True)
Expand Down Expand Up @@ -1313,11 +1241,6 @@ def download_error(self, error_message) -> None:
if os.path.exists(partial_file):
os.remove(partial_file)

def create_label(self, text, tooltip) -> QLabel:
label = QLabel(text)
label.setToolTip(tooltip)
return label

def verify_gguf(self, file_path) -> bool:
try:
with open(file_path, "rb") as f:
Expand Down Expand Up @@ -1613,15 +1536,6 @@ def quantize_model(self) -> None:
except Exception as e:
show_error(self.logger, ERROR_STARTING_QUANTIZATION.format(str(e)))

def parse_progress(self, line, task_item) -> None:
# Parses the output line for progress information and updates the task item.
match = re.search(r"\[\s*(\d+)\s*/\s*(\d+)\s*].*", line)
if match:
current = int(match.group(1))
total = int(match.group(2))
progress = int((current / total) * 100)
task_item.update_progress(progress)

def task_finished(self, thread, task_item) -> None:
self.logger.info(TASK_FINISHED.format(thread.log_file))
if thread in self.quant_threads:
Expand Down Expand Up @@ -1681,28 +1595,6 @@ def import_model(self) -> None:
self.load_models()
self.logger.info(MODEL_IMPORTED_SUCCESSFULLY.format(file_name))

def browse_imatrix_datafile(self) -> None:
self.logger.info(BROWSING_FOR_IMATRIX_DATA_FILE)
datafile, _ = QFileDialog.getOpenFileName(self, SELECT_DATA_FILE, "", ALL_FILES)
if datafile:
self.imatrix_datafile.setText(os.path.abspath(datafile))

def browse_imatrix_model(self) -> None:
self.logger.info(BROWSING_FOR_IMATRIX_MODEL_FILE)
model_file, _ = QFileDialog.getOpenFileName(
self, SELECT_MODEL_FILE, "", GGUF_FILES
)
if model_file:
self.imatrix_model.setText(os.path.abspath(model_file))

def browse_imatrix_output(self) -> None:
self.logger.info(BROWSING_FOR_IMATRIX_OUTPUT_FILE)
output_file, _ = QFileDialog.getSaveFileName(
self, SELECT_OUTPUT_FILE, "", DAT_FILES
)
if output_file:
self.imatrix_output.setText(os.path.abspath(output_file))

def generate_imatrix(self) -> None:
self.logger.info(STARTING_IMATRIX_GENERATION)
try:
Expand Down
10 changes: 10 additions & 0 deletions src/QuantizationThread.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
import signal
import subprocess

Expand Down Expand Up @@ -78,6 +79,15 @@ def parse_model_info(self, line) -> None:
f"{quant_type}: {tensors} tensors"
)

def parse_progress(self, line, task_item) -> None:
# Parses the output line for progress information and updates the task item.
match = re.search(r"\[\s*(\d+)\s*/\s*(\d+)\s*].*", line)
if match:
current = int(match.group(1))
total = int(match.group(2))
progress = int((current / total) * 100)
task_item.update_progress(progress)

def terminate(self) -> None:
# Terminate the subprocess if it's still running
if self.process:
Expand Down
20 changes: 20 additions & 0 deletions src/TaskListItem.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,3 +162,23 @@ def update_progress(self, value=None) -> None:
else:
# Set progress bar to zero for indeterminate progress
self.progress_bar.setValue(0)

def restart_task(self, task_item) -> None:
self.logger.info(RESTARTING_TASK.format(task_item.task_name))
for thread in self.quant_threads:
if thread.log_file == task_item.log_file:
new_thread = QuantizationThread(
thread.command, thread.cwd, thread.log_file
)
self.quant_threads.append(new_thread)
new_thread.status_signal.connect(task_item.update_status)
new_thread.finished_signal.connect(
lambda: self.task_finished(new_thread, task_item)
)
new_thread.error_signal.connect(
lambda err: handle_error(self.logger, err, task_item)
)
new_thread.model_info_signal.connect(self.update_model_info)
new_thread.start()
task_item.update_status(IN_PROGRESS)
break
Loading

0 comments on commit be38e35

Please sign in to comment.