Skip to content

Commit

Permalink
feat(ui): support shift clicking to get quantization command
Browse files Browse the repository at this point in the history
- support shift clicking Quantize Model button to get quantize command
- clean up imports in AutoGGUF.py and add localization keys
- use str() for getting log_dir_name
- remove legacy validate_quantization_inputs() function
- add return_command parameter to quantize_model() function
  • Loading branch information
leafspark committed Nov 13, 2024
1 parent 6aaefb2 commit 749f321
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 22 deletions.
46 changes: 24 additions & 22 deletions src/AutoGGUF.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import urllib.request
from datetime import datetime
from functools import partial, wraps
from typing import List
from typing import Any, List, Union

from PySide6.QtCore import *
from PySide6.QtGui import *
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(self, args: List[str]) -> None:

self.parse_resolution = ui_update.parse_resolution.__get__(self)

self.log_dir_name = os.environ.get("AUTOGGUF_LOG_DIR_NAME", "logs")
self.log_dir_name = str(os.environ.get("AUTOGGUF_LOG_DIR_NAME", "logs"))

width, height = self.parse_resolution()
self.logger = Logger("AutoGGUF", self.log_dir_name)
Expand Down Expand Up @@ -775,7 +775,7 @@ def __init__(self, args: List[str]) -> None:
# Quantize button layout
quantize_layout = QHBoxLayout()
quantize_button = QPushButton(QUANTIZE_MODEL)
quantize_button.clicked.connect(self.quantize_model)
quantize_button.clicked[bool].connect(self.quantize_model_handler)
save_preset_button = QPushButton(SAVE_PRESET)
save_preset_button.clicked.connect(self.save_preset)
load_preset_button = QPushButton(LOAD_PRESET)
Expand Down Expand Up @@ -1101,6 +1101,20 @@ def __init__(self, args: List[str]) -> None:
self.logger.info(AUTOGGUF_INITIALIZATION_COMPLETE)
self.logger.info(STARTUP_ELASPED_TIME.format(init_timer.elapsed()))

def quantize_model_handler(self) -> None:
if QApplication.keyboardModifiers() == Qt.ShiftModifier and self.quantize_model(
return_command=True
):
QApplication.clipboard().setText(self.quantize_model(return_command=True))
QMessageBox.information(
None,
INFO,
f"{COPIED_COMMAND_TO_CLIPBOARD} "
+ f"<code style='font-family: monospace; white-space: pre;'>{self.quantize_model(return_command=True)}</code>",
)
else:
self.quantize_model()

def resizeEvent(self, event) -> None:
super().resizeEvent(event)
path = QPainterPath()
Expand Down Expand Up @@ -1254,23 +1268,6 @@ def download_finished(self, extract_dir) -> None:
if index >= 0:
self.backend_combo.setCurrentIndex(index)

def validate_quantization_inputs(self) -> None:
self.logger.debug(VALIDATING_QUANTIZATION_INPUTS)
errors = []
if not self.backend_combo.currentData():
errors.append(NO_BACKEND_SELECTED)
if not self.models_input.text():
errors.append(MODELS_PATH_REQUIRED)
if not self.output_input.text():
errors.append(OUTPUT_PATH_REQUIRED)
if not self.logs_input.text():
errors.append(LOGS_PATH_REQUIRED)
if not self.model_tree.currentItem():
errors.append(NO_MODEL_SELECTED)

if errors:
raise ValueError("\n".join(errors))

def load_models(self) -> None:
self.logger.info(LOADING_MODELS)
models_dir = self.models_input.text()
Expand Down Expand Up @@ -1698,10 +1695,9 @@ def merge_gguf(self, model_dir: str, output_dir: str) -> None:
show_error(self.logger, "Error starting merge GGUF task: {}".format(e))
self.logger.info("Split GGUF task finished.")

def quantize_model(self) -> None:
def quantize_model(self, return_command=False) -> str:
self.logger.info(STARTING_MODEL_QUANTIZATION)
try:
self.validate_quantization_inputs()
selected_item = self.model_tree.currentItem()
if not selected_item:
raise ValueError(NO_MODEL_SELECTED)
Expand Down Expand Up @@ -1822,6 +1818,12 @@ def quantize_model(self) -> None:
if self.extra_arguments.text():
command.extend(self.extra_arguments.text().split())

if return_command:
self.logger.info(
f"{QUANTIZATION_COMMAND}: {str(' '.join(command))}"
)
return str(" ".join(command))

logs_path = self.logs_input.text()
ensure_directory(logs_path)

Expand Down
3 changes: 3 additions & 0 deletions src/Localizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,9 @@ def __init__(self):

self.EXTRA_COMMAND_ARGUMENTS = "Additional command-line arguments"

self.INFO = "Info"
self.COPIED_COMMAND_TO_CLIPBOARD = "Copied command to clipboard:"


class _French(_Localization):
def __init__(self):
Expand Down

0 comments on commit 749f321

Please sign in to comment.