Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RSPEED-175] Rework config file handler and history #10

Merged
merged 3 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions command_line_assistant/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,17 @@
import os
import sys

from command_line_assistant import utils
from command_line_assistant.config import (
CONFIG_DEFAULT_PATH,
load_config_file,
)
from command_line_assistant.handlers import (
handle_history_write,
handle_query,
handle_script_session,
)
from command_line_assistant.utils import read_stdin, read_yaml_config
from command_line_assistant.utils import read_stdin

logging.basicConfig(
level=logging.INFO,
Expand All @@ -32,7 +37,7 @@ def get_args():
)
parser.add_argument(
"--config",
default=os.getenv("COMMAND_LINE_ASSISTANT_CONFIG", "config.yaml"),
default=CONFIG_DEFAULT_PATH,
help="Path to the config file.",
)

Expand All @@ -57,17 +62,11 @@ def get_args():
def main():
parser, args = get_args()

config = read_yaml_config(args.config)
if not config:
logging.warning(
"Config file not found. Script will continue with default values."
)
config_file = utils.expand_user_path(args.config)
config = load_config_file(config_file)

output_capture_conf = config.get("output_capture", {})
enforce_script_session = output_capture_conf.get("enforce_script", False)
output_file = output_capture_conf.get(
"output_file", "/tmp/command-line-assistant_output.txt"
)
enforce_script_session = config.output.enforce_script
output_file = config.output.file

if enforce_script_session and (not args.record or not os.path.exists(output_file)):
parser.error(
Expand All @@ -81,7 +80,7 @@ def main():
exit(0)
if args.history_clear:
logging.info("Clearing history of conversation")
handle_history_write(config.get("history", {}), [], "")
handle_history_write(config, [], "")
if args.query_string:
handle_query(args.query_string, config)

Expand Down
156 changes: 156 additions & 0 deletions command_line_assistant/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
import json
import logging
from collections import namedtuple
from pathlib import Path

# tomllib is available in the stdlib after Python3.11. Before that, we import
# from tomli.
try:
import tomllib
except ImportError:
import tomli as tomllib

from command_line_assistant import utils

CONFIG_DEFAULT_PATH: Path = Path("~/.config/shellai/config.toml")

# tomllib does not support writting files, so we will create our own.
CONFIG_TEMPLATE = """\
[output]
# otherwise recording via script session will be enforced
enforce_script = {enforce_script}
# file with output(s) of regular commands (e.g. ls, echo, etc.)
file = "{output_file}"
# Keep non-empty if your file contains only output of commands (not prompt itself)
prompt_separator = "{prompt_separator}"

[history]
enabled = {enabled}
file = "{history_file}"
# max number of queries in history (including responses)
max_size = {max_size}

[backend]
endpoint = "{endpoint}"
"""


class OutputSchema(
namedtuple("Output", ["enforce_script", "file", "prompt_separator"])
):
"""This class represents the [output] section of our config.toml file."""

# Locking down against extra fields at runtime
__slots__ = ()

# We are overriding __new__ here because namedtuple only offers default values to fields from Python 3.7+
def __new__(
cls,
enforce_script: bool = False,
file: str = "/tmp/shellai_output.txt",
prompt_separator: str = "$",
):
file = utils.expand_user_path(file)
return super(OutputSchema, cls).__new__(
cls, enforce_script, file, prompt_separator
)


class HistorySchema(namedtuple("History", ["enabled", "file", "max_size"])):
"""This class represents the [history] section of our config.toml file."""

# Locking down against extra fields at runtime
__slots__ = ()

# We are overriding __new__ here because namedtuple only offers default values to fields from Python 3.7+
def __new__(
cls,
enabled: bool = True,
file: str = "~/.local/share/shellai/shellai_history.json",
max_size: int = 100,
):
file = utils.expand_user_path(file)
return super(HistorySchema, cls).__new__(cls, enabled, file, max_size)


class BackendSchema(namedtuple("Backend", ["endpoint"])):
"""This class represents the [backend] section of our config.toml file."""

# Locking down against extra fields at runtime
__slots__ = ()

# We are overriding __new__ here because namedtuple only offers default values to fields from Python 3.7+
def __new__(
cls,
endpoint: str = "http://0.0.0.0:8080/v1/query/",
):
return super(BackendSchema, cls).__new__(cls, endpoint)


class Config:
"""Class that holds our configuration file representation.

With this class, after being initialized, one can access their fields like:

>>> config = Config()
>>> config.output.enforce_script

The currently available top-level fields are:
* output = Match the `py:Output` class and their fields
* history = Match the `py:History` class and their fields
* backend = Match the `py:backend` class and their fields
"""

def __init__(self, output: dict, history: dict, backend: dict) -> None:
self.output: OutputSchema = OutputSchema(**output)
self.history: HistorySchema = HistorySchema(**history)
self.backend: BackendSchema = BackendSchema(**backend)


def _create_config_file(config_file: Path) -> None:
"""Create a new configuration file with default values."""

logging.info(f"Creating new config file at {config_file.parent}")
config_file.parent.mkdir(mode=0o755)
base_config = Config(
OutputSchema()._asdict(), HistorySchema()._asdict(), BackendSchema()._asdict()
)

mapping = {
"enforce_script": json.dumps(base_config.output.enforce_script),
"output_file": base_config.output.file,
"prompt_separator": base_config.output.prompt_separator,
"enabled": json.dumps(base_config.history.enabled),
"history_file": base_config.history.file,
"max_size": base_config.history.max_size,
"endpoint": base_config.backend.endpoint,
}
config_formatted = CONFIG_TEMPLATE.format_map(mapping)
config_file.write_text(config_formatted)


def _read_config_file(config_file: Path) -> Config:
"""Read configuration file."""
config_dict = {}
try:
data = config_file.read_text()
config_dict = tomllib.loads(data)
except FileNotFoundError as ex:
logging.error(ex)

return Config(
output=config_dict["output"],
history=config_dict["history"],
backend=config_dict["backend"],
)


def load_config_file(config_file: Path) -> Config:
"""Load configuration file for shellai.

If the user specifies a path where no config file is located, we will create one with default values.
"""
if not config_file.exists():
_create_config_file(config_file)

return _read_config_file(config_file)
66 changes: 10 additions & 56 deletions command_line_assistant/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,51 +4,11 @@

import requests

from command_line_assistant.config import Config
from command_line_assistant.history import handle_history_read, handle_history_write
from command_line_assistant.utils import get_payload


def _handle_history_read(config: dict) -> dict:
"""
Reads the history from a file and returns it as a list of dictionaries.
"""
if not config.get("enabled", False):
return []

filepath = config.get("filepath", "/tmp/command-line-assistant_history.json")
if not filepath or not os.path.exists(filepath):
logging.warning(f"History file {filepath} does not exist.")
logging.warning("File will be created with first response.")
return []

max_size = config.get("max_size", 100)
history = []
try:
with open(filepath, "r") as f:
history = json.load(f)
except json.JSONDecodeError as e:
logging.error(f"Failed to read history file {filepath}: {e}")
return []

logging.info(f"Taking maximum of {max_size} entries from history.")
return history[:max_size]


def handle_history_write(config: dict, history: list, response: str) -> None:
"""
Writes the history to a file.
"""
if not config.get("enabled", False):
return
filepath = config.get("filepath", "/tmp/command-line-assistant_history.json")
if response:
history.append({"role": "assistant", "content": response})
try:
with open(filepath, "w") as f:
json.dump(history, f)
except json.JSONDecodeError as e:
logging.error(f"Failed to write history file {filepath}: {e}")


def handle_script_session(command_line_assistant_tmp_file) -> None:
"""
Starts a 'script' session and writes the PID to a file, but leaves control of the terminal to the user.
Expand All @@ -65,47 +25,41 @@ def handle_script_session(command_line_assistant_tmp_file) -> None:
os.remove(command_line_assistant_tmp_file)


def _handle_caret(query: str, config: dict) -> str:
def _handle_caret(query: str, config: Config) -> str:
"""
Replaces caret (^) with command output specified in config file.
"""
if "^" not in query:
return query

output_capture_settings = config.get("output_capture_settings", {})
captured_output_file = output_capture_settings.get(
"captured_output_file", "/tmp/command-line-assistant_output.txt"
)
captured_output_file = config.output.file

if not os.path.exists(captured_output_file):
logging.error(
f"Output file {captured_output_file} does not exist, change location of file in config to use '^'."
)
exit(1)

prompt_separator = output_capture_settings.get("prompt_separator", "$")
prompt_separator = config.output.prompt_separator
with open(captured_output_file, "r") as f:
# NOTE: takes only last command + output from file
output = f.read().split(prompt_separator)[-1].strip()

query = query.replace("^", "")
query = f"Context data: {output}\nQuestion: " + query
return query


def handle_query(query: str, config: dict) -> None:
def handle_query(query: str, config: Config) -> None:
query = _handle_caret(query, config)
# NOTE: Add more query handling here

logging.info(f"Query: {query}")

backend_service = config.get("backend_service", {})
query_endpoint = backend_service.get(
"query_endpoint", "http://0.0.0.0:8080/v1/query/"
)
query_endpoint = config.backend.endpoint

try:
history_conf = config.get("history", {})
history = _handle_history_read(history_conf)
history = handle_history_read(config)
payload = get_payload(query)
logging.info("Waiting for response from AI...")
response = requests.post(
Expand All @@ -125,7 +79,7 @@ def handle_query(query: str, config: dict) -> None:
"\n\nReferences:\n" + "\n".join(references) if references else ""
)
handle_history_write(
history_conf,
config,
[
*history,
{"role": "user", "content": query},
Expand Down
10 changes: 10 additions & 0 deletions command_line_assistant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import select
import sys
from pathlib import Path

import yaml

Expand Down Expand Up @@ -37,3 +38,12 @@ def get_payload(query: str) -> dict:
# {"role": "user", "content": "how do I enable selinux?"},
payload = {"query": query}
return payload


def expand_user_path(file_path: str) -> Path:
"""Helper method to expand user provided path."""
path = Path(file_path)
if not path.exists():
raise FileNotFoundError(f"Current file does not exist or was not found: {path}")

return Path(path).expanduser()
16 changes: 16 additions & 0 deletions config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[output]
# otherwise recording via script session will be enforced
enforce_script = false
# file with output(s) of regular commands (e.g. ls, echo, etc.)
file = "/tmp/shellai_output.txt"
# Keep non-empty if your file contains only output of commands (not prompt itself)
prompt_separator = "$"

[history]
enabled = true
file = "~/.local/share/shellai/shellai_history.json"
# max number of queries in history (including responses)
max_size = 100

[backend]
endpoint = "http://0.0.0.0:8080/v1/query/"
Loading