From 16a7e0e72d2a72890f11de7f05ec613a484c86a4 Mon Sep 17 00:00:00 2001 From: Rodolfo Olivieri Date: Mon, 21 Oct 2024 14:30:12 -0300 Subject: [PATCH] Fixes after review --- packaging/shellai.spec | 2 +- shellai/__main__.py | 4 +- shellai/config.py | 128 ++++++++++++++++++++++++----------------- shellai/handlers.py | 2 +- shellai/history.py | 16 +++--- shellai/utils.py | 10 ++++ 6 files changed, 100 insertions(+), 62 deletions(-) diff --git a/packaging/shellai.spec b/packaging/shellai.spec index 2e2c69a..f02aab5 100644 --- a/packaging/shellai.spec +++ b/packaging/shellai.spec @@ -18,7 +18,7 @@ BuildRequires: python3-tomli Requires: python3-requests %if 0%{?rhel} && 0%{?rhel} < 10 -BuildRequires: python3-tomli +Requires: python3-tomli %endif %description diff --git a/shellai/__main__.py b/shellai/__main__.py index 65a32dd..8265834 100644 --- a/shellai/__main__.py +++ b/shellai/__main__.py @@ -3,6 +3,7 @@ import os import sys +from shellai import utils from shellai.config import ( CONFIG_DEFAULT_PATH, load_config_file, @@ -61,7 +62,8 @@ def get_args(): def main(): parser, args = get_args() - config = load_config_file(args.config) + config_file = utils.expand_user_path(args.config) + config = load_config_file(config_file) enforce_script_session = config.output.enforce_script output_file = config.output.enforce_script diff --git a/shellai/config.py b/shellai/config.py index f3b1180..d95a049 100644 --- a/shellai/config.py +++ b/shellai/config.py @@ -1,18 +1,21 @@ import json import logging -import os -import sys from collections import namedtuple +from pathlib import Path -if sys.version_info >= (3, 11): +from shellai import utils + +# tomllib is available in the stdlib after Python3.11. Before that, we import +# from tomli. +try: import tomllib -else: +except ImportError: import tomli as tomllib -CONFIG_DEFAULT_PATH: str = "~/.config/shellai/config.toml" +CONFIG_DEFAULT_PATH: Path = Path("~/.config/shellai/config.toml") # tomllib does not support writting files, so we will create our own. -CONFIG_TEMPLATE = """ +CONFIG_TEMPLATE = """\ [output] # otherwise recording via script session will be enforced enforce_script = {enforce_script} @@ -29,86 +32,112 @@ [backend] endpoint = "{endpoint}" - """ -class Output(namedtuple("Output", ["enforce_script", "file", "prompt_separator"])): +class OutputSchema( + namedtuple("Output", ["enforce_script", "file", "prompt_separator"]) +): + """This class represents the [output] section of our config.toml file.""" + + # Locking down against extra fields at runtime __slots__ = () + # We are overriding __new__ here because namedtuple only offers default values to fields from Python 3.7+ def __new__( cls, enforce_script: bool = False, file: str = "/tmp/shellai_output.txt", prompt_separator: str = "$", ): - return super(Output, cls).__new__(cls, enforce_script, file, prompt_separator) + file = utils.expand_user_path(file) + return super(OutputSchema, cls).__new__( + cls, enforce_script, file, prompt_separator + ) -class History(namedtuple("History", ["enabled", "file", "max_size"])): +class HistorySchema(namedtuple("History", ["enabled", "file", "max_size"])): + """This class represents the [history] section of our config.toml file.""" + + # Locking down against extra fields at runtime __slots__ = () + # We are overriding __new__ here because namedtuple only offers default values to fields from Python 3.7+ def __new__( cls, enabled: bool = True, - file: str = "/tmp/shellai_output.txt", + file: str = "~/.local/share/shellai/shellai_history.json", max_size: int = 100, ): - return super(History, cls).__new__(cls, enabled, file, max_size) + file = utils.expand_user_path(file) + return super(HistorySchema, cls).__new__(cls, enabled, file, max_size) -class Backend(namedtuple("Backend", ["endpoint"])): - endpoint: str = "http://0.0.0.0:8080/v1/query/" +class BackendSchema(namedtuple("Backend", ["endpoint"])): + """This class represents the [backend] section of our config.toml file.""" + + # Locking down against extra fields at runtime __slots__ = () + # We are overriding __new__ here because namedtuple only offers default values to fields from Python 3.7+ def __new__( cls, endpoint: str = "http://0.0.0.0:8080/v1/query/", ): - return super(Backend, cls).__new__(cls, endpoint) + return super(BackendSchema, cls).__new__(cls, endpoint) class Config: + """Class that holds our configuration file representation. + + With this class, after being initialized, one can access their fields like: + + >>> config = Config() + >>> config.output.enforce_script + + The currently available top-level fields are: + * output = Match the `py:Output` class and their fields + * history = Match the `py:History` class and their fields + * backend = Match the `py:backend` class and their fields + """ + def __init__(self, output: dict, history: dict, backend: dict) -> None: - self.output: Output = Output(**output) - self.history: History = History(**history) - self.backend: Backend = Backend(**backend) + self.output: OutputSchema = OutputSchema(**output) + self.history: HistorySchema = HistorySchema(**history) + self.backend: BackendSchema = BackendSchema(**backend) -def _create_config_file(config_path: str) -> None: +def _create_config_file(config_file: Path) -> None: """Create a new configuration file with default values.""" - config_dir = os.path.dirname(config_path) - logging.info(f"Creating new config file at {config_path}") - os.makedirs(config_dir, mode=0o755, exist_ok=True) - base_config = Config(Output()._asdict(), History()._asdict(), Backend()._asdict()) - - with open(config_path, mode="w") as handler: - mapping = { - "enforce_script": json.dumps(base_config.output.enforce_script), - "output_file": base_config.output.file, - "prompt_separator": base_config.output.prompt_separator, - "enabled": json.dumps(base_config.history.enabled), - "history_file": base_config.history.file, - "max_size": base_config.history.max_size, - "endpoint": base_config.backend.endpoint, - } - config_formatted = CONFIG_TEMPLATE.format_map(mapping) - handler.write(config_formatted) - - -def _read_config_file(config_path: str) -> Config: + + logging.info(f"Creating new config file at {config_file.parent}") + config_file.parent.mkdir(mode=0o755) + base_config = Config( + OutputSchema()._asdict(), HistorySchema()._asdict(), BackendSchema()._asdict() + ) + + mapping = { + "enforce_script": json.dumps(base_config.output.enforce_script), + "output_file": base_config.output.file, + "prompt_separator": base_config.output.prompt_separator, + "enabled": json.dumps(base_config.history.enabled), + "history_file": base_config.history.file, + "max_size": base_config.history.max_size, + "endpoint": base_config.backend.endpoint, + } + config_formatted = CONFIG_TEMPLATE.format_map(mapping) + config_file.write_text(config_formatted) + + +def _read_config_file(config_file: Path) -> Config: """Read configuration file.""" config_dict = {} try: - with open(config_path, mode="rb") as handler: - config_dict = tomllib.load(handler) + data = config_file.read_text() + config_dict = tomllib.loads(data) except FileNotFoundError as ex: logging.error(ex) - # Normalize filepaths - config_dict["history"]["file"] = os.path.expanduser(config_dict["history"]["file"]) - config_dict["output"]["file"] = os.path.expanduser(config_dict["output"]["file"]) - return Config( output=config_dict["output"], history=config_dict["history"], @@ -116,17 +145,12 @@ def _read_config_file(config_path: str) -> Config: ) -def load_config_file(config_path: str) -> Config: +def load_config_file(config_file: Path) -> Config: """Load configuration file for shellai. If the user specifies a path where no config file is located, we will create one with default values. """ - config_file = os.path.expanduser(config_path) - # Handle case where the user initiates a config file in current dir. - if not os.path.dirname(config_file): - config_file = os.path.join(os.path.curdir, config_file) - - if not os.path.exists(config_file): + if not config_file.exists(): _create_config_file(config_file) return _read_config_file(config_file) diff --git a/shellai/handlers.py b/shellai/handlers.py index f660b6e..ba67482 100644 --- a/shellai/handlers.py +++ b/shellai/handlers.py @@ -66,7 +66,7 @@ def handle_query(query: str, config: Config) -> None: query_endpoint, headers={"Content-Type": "application/json"}, data=json.dumps(payload), - timeout=320, # waiting for more than 30 seconds does not make sense + timeout=30, # waiting for more than 30 seconds does not make sense ) response.raise_for_status() completion = response.json() diff --git a/shellai/history.py b/shellai/history.py index 6a20f13..afb0acb 100644 --- a/shellai/history.py +++ b/shellai/history.py @@ -1,6 +1,5 @@ import json import logging -import os from shellai.config import Config @@ -13,7 +12,7 @@ def handle_history_read(config: Config) -> dict: return [] filepath = config.history.file - if not filepath or not os.path.exists(filepath): + if not filepath or not filepath.exists(): logging.warning(f"History file {filepath} does not exist.") logging.warning("File will be created with first response.") return [] @@ -21,8 +20,8 @@ def handle_history_read(config: Config) -> dict: max_size = config.history.max_size history = [] try: - with open(filepath, "r") as f: - history = json.load(f) + data = filepath.read_text() + history = json.loads(data) except json.JSONDecodeError as e: logging.error(f"Failed to read history file {filepath}: {e}") return [] @@ -37,12 +36,15 @@ def handle_history_write(config: Config, history: list, response: str) -> None: """ if not config.history.enabled: return + filepath = config.history.file - os.makedirs(os.path.dirname(filepath), mode=0o755, exist_ok=True) + filepath.makedirs(mode=0o755) + if response: history.append({"role": "assistant", "content": response}) + try: - with open(filepath, "w") as f: - json.dump(history, f) + data = json.dumps(history) + filepath.write_text(data) except json.JSONDecodeError as e: logging.error(f"Failed to write history file {filepath}: {e}") diff --git a/shellai/utils.py b/shellai/utils.py index feb1d8b..4a84d4d 100644 --- a/shellai/utils.py +++ b/shellai/utils.py @@ -2,6 +2,7 @@ import os import select import sys +from pathlib import Path import yaml @@ -37,3 +38,12 @@ def get_payload(query: str) -> dict: # {"role": "user", "content": "how do I enable selinux?"}, payload = {"query": query} return payload + + +def expand_user_path(file_path: str) -> Path: + """Helper method to expand user provided path.""" + path = Path(file_path) + if not path.exists(): + raise FileNotFoundError(f"Current file does not exist or was not found: {path}") + + return Path(path).expanduser()