Skip to content

Commit

Permalink
Fixes after review
Browse files Browse the repository at this point in the history
  • Loading branch information
r0x0d committed Oct 22, 2024
1 parent b6e6573 commit 16a7e0e
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 62 deletions.
2 changes: 1 addition & 1 deletion packaging/shellai.spec
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ BuildRequires: python3-tomli

Requires: python3-requests
%if 0%{?rhel} && 0%{?rhel} < 10
BuildRequires: python3-tomli
Requires: python3-tomli
%endif

%description
Expand Down
4 changes: 3 additions & 1 deletion shellai/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import sys

from shellai import utils
from shellai.config import (
CONFIG_DEFAULT_PATH,
load_config_file,
Expand Down Expand Up @@ -61,7 +62,8 @@ def get_args():
def main():
parser, args = get_args()

config = load_config_file(args.config)
config_file = utils.expand_user_path(args.config)
config = load_config_file(config_file)

enforce_script_session = config.output.enforce_script
output_file = config.output.enforce_script
Expand Down
128 changes: 76 additions & 52 deletions shellai/config.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import json
import logging
import os
import sys
from collections import namedtuple
from pathlib import Path

if sys.version_info >= (3, 11):
from shellai import utils

# tomllib is available in the stdlib after Python3.11. Before that, we import
# from tomli.
try:
import tomllib
else:
except ImportError:
import tomli as tomllib

CONFIG_DEFAULT_PATH: str = "~/.config/shellai/config.toml"
CONFIG_DEFAULT_PATH: Path = Path("~/.config/shellai/config.toml")

# tomllib does not support writting files, so we will create our own.
CONFIG_TEMPLATE = """
CONFIG_TEMPLATE = """\
[output]
# otherwise recording via script session will be enforced
enforce_script = {enforce_script}
Expand All @@ -29,104 +32,125 @@
[backend]
endpoint = "{endpoint}"
"""


class Output(namedtuple("Output", ["enforce_script", "file", "prompt_separator"])):
class OutputSchema(
namedtuple("Output", ["enforce_script", "file", "prompt_separator"])
):
"""This class represents the [output] section of our config.toml file."""

# Locking down against extra fields at runtime
__slots__ = ()

# We are overriding __new__ here because namedtuple only offers default values to fields from Python 3.7+
def __new__(
cls,
enforce_script: bool = False,
file: str = "/tmp/shellai_output.txt",
prompt_separator: str = "$",
):
return super(Output, cls).__new__(cls, enforce_script, file, prompt_separator)
file = utils.expand_user_path(file)
return super(OutputSchema, cls).__new__(
cls, enforce_script, file, prompt_separator
)


class History(namedtuple("History", ["enabled", "file", "max_size"])):
class HistorySchema(namedtuple("History", ["enabled", "file", "max_size"])):
"""This class represents the [history] section of our config.toml file."""

# Locking down against extra fields at runtime
__slots__ = ()

# We are overriding __new__ here because namedtuple only offers default values to fields from Python 3.7+
def __new__(
cls,
enabled: bool = True,
file: str = "/tmp/shellai_output.txt",
file: str = "~/.local/share/shellai/shellai_history.json",
max_size: int = 100,
):
return super(History, cls).__new__(cls, enabled, file, max_size)
file = utils.expand_user_path(file)
return super(HistorySchema, cls).__new__(cls, enabled, file, max_size)


class Backend(namedtuple("Backend", ["endpoint"])):
endpoint: str = "http://0.0.0.0:8080/v1/query/"
class BackendSchema(namedtuple("Backend", ["endpoint"])):
"""This class represents the [backend] section of our config.toml file."""

# Locking down against extra fields at runtime
__slots__ = ()

# We are overriding __new__ here because namedtuple only offers default values to fields from Python 3.7+
def __new__(
cls,
endpoint: str = "http://0.0.0.0:8080/v1/query/",
):
return super(Backend, cls).__new__(cls, endpoint)
return super(BackendSchema, cls).__new__(cls, endpoint)


class Config:
"""Class that holds our configuration file representation.
With this class, after being initialized, one can access their fields like:
>>> config = Config()
>>> config.output.enforce_script
The currently available top-level fields are:
* output = Match the `py:Output` class and their fields
* history = Match the `py:History` class and their fields
* backend = Match the `py:backend` class and their fields
"""

def __init__(self, output: dict, history: dict, backend: dict) -> None:
self.output: Output = Output(**output)
self.history: History = History(**history)
self.backend: Backend = Backend(**backend)
self.output: OutputSchema = OutputSchema(**output)
self.history: HistorySchema = HistorySchema(**history)
self.backend: BackendSchema = BackendSchema(**backend)


def _create_config_file(config_path: str) -> None:
def _create_config_file(config_file: Path) -> None:
"""Create a new configuration file with default values."""
config_dir = os.path.dirname(config_path)
logging.info(f"Creating new config file at {config_path}")
os.makedirs(config_dir, mode=0o755, exist_ok=True)
base_config = Config(Output()._asdict(), History()._asdict(), Backend()._asdict())

with open(config_path, mode="w") as handler:
mapping = {
"enforce_script": json.dumps(base_config.output.enforce_script),
"output_file": base_config.output.file,
"prompt_separator": base_config.output.prompt_separator,
"enabled": json.dumps(base_config.history.enabled),
"history_file": base_config.history.file,
"max_size": base_config.history.max_size,
"endpoint": base_config.backend.endpoint,
}
config_formatted = CONFIG_TEMPLATE.format_map(mapping)
handler.write(config_formatted)


def _read_config_file(config_path: str) -> Config:

logging.info(f"Creating new config file at {config_file.parent}")
config_file.parent.mkdir(mode=0o755)
base_config = Config(
OutputSchema()._asdict(), HistorySchema()._asdict(), BackendSchema()._asdict()
)

mapping = {
"enforce_script": json.dumps(base_config.output.enforce_script),
"output_file": base_config.output.file,
"prompt_separator": base_config.output.prompt_separator,
"enabled": json.dumps(base_config.history.enabled),
"history_file": base_config.history.file,
"max_size": base_config.history.max_size,
"endpoint": base_config.backend.endpoint,
}
config_formatted = CONFIG_TEMPLATE.format_map(mapping)
config_file.write_text(config_formatted)


def _read_config_file(config_file: Path) -> Config:
"""Read configuration file."""
config_dict = {}
try:
with open(config_path, mode="rb") as handler:
config_dict = tomllib.load(handler)
data = config_file.read_text()
config_dict = tomllib.loads(data)
except FileNotFoundError as ex:
logging.error(ex)

# Normalize filepaths
config_dict["history"]["file"] = os.path.expanduser(config_dict["history"]["file"])
config_dict["output"]["file"] = os.path.expanduser(config_dict["output"]["file"])

return Config(
output=config_dict["output"],
history=config_dict["history"],
backend=config_dict["backend"],
)


def load_config_file(config_path: str) -> Config:
def load_config_file(config_file: Path) -> Config:
"""Load configuration file for shellai.
If the user specifies a path where no config file is located, we will create one with default values.
"""
config_file = os.path.expanduser(config_path)
# Handle case where the user initiates a config file in current dir.
if not os.path.dirname(config_file):
config_file = os.path.join(os.path.curdir, config_file)

if not os.path.exists(config_file):
if not config_file.exists():
_create_config_file(config_file)

return _read_config_file(config_file)
2 changes: 1 addition & 1 deletion shellai/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def handle_query(query: str, config: Config) -> None:
query_endpoint,
headers={"Content-Type": "application/json"},
data=json.dumps(payload),
timeout=320, # waiting for more than 30 seconds does not make sense
timeout=30, # waiting for more than 30 seconds does not make sense
)
response.raise_for_status()
completion = response.json()
Expand Down
16 changes: 9 additions & 7 deletions shellai/history.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import json
import logging
import os

from shellai.config import Config

Expand All @@ -13,16 +12,16 @@ def handle_history_read(config: Config) -> dict:
return []

filepath = config.history.file
if not filepath or not os.path.exists(filepath):
if not filepath or not filepath.exists():
logging.warning(f"History file {filepath} does not exist.")
logging.warning("File will be created with first response.")
return []

max_size = config.history.max_size
history = []
try:
with open(filepath, "r") as f:
history = json.load(f)
data = filepath.read_text()
history = json.loads(data)
except json.JSONDecodeError as e:
logging.error(f"Failed to read history file {filepath}: {e}")
return []
Expand All @@ -37,12 +36,15 @@ def handle_history_write(config: Config, history: list, response: str) -> None:
"""
if not config.history.enabled:
return

filepath = config.history.file
os.makedirs(os.path.dirname(filepath), mode=0o755, exist_ok=True)
filepath.makedirs(mode=0o755)

if response:
history.append({"role": "assistant", "content": response})

try:
with open(filepath, "w") as f:
json.dump(history, f)
data = json.dumps(history)
filepath.write_text(data)
except json.JSONDecodeError as e:
logging.error(f"Failed to write history file {filepath}: {e}")
10 changes: 10 additions & 0 deletions shellai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import select
import sys
from pathlib import Path

import yaml

Expand Down Expand Up @@ -37,3 +38,12 @@ def get_payload(query: str) -> dict:
# {"role": "user", "content": "how do I enable selinux?"},
payload = {"query": query}
return payload


def expand_user_path(file_path: str) -> Path:
"""Helper method to expand user provided path."""
path = Path(file_path)
if not path.exists():
raise FileNotFoundError(f"Current file does not exist or was not found: {path}")

return Path(path).expanduser()

0 comments on commit 16a7e0e

Please sign in to comment.