Skip to content

Commit

Permalink
Refactor to use dataclasses instead of namedtuple (#27)
Browse files Browse the repository at this point in the history
* Refactor to use dataclasses instead of namedtuple

* Fix slots missing in python 3.9
  • Loading branch information
r0x0d authored Nov 7, 2024
1 parent 1f8d598 commit f411a78
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 66 deletions.
129 changes: 74 additions & 55 deletions command_line_assistant/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import dataclasses
import json
import logging
from collections import namedtuple
from pathlib import Path
from typing import Optional
from typing import ClassVar, Optional

# tomllib is available in the stdlib after Python3.11. Before that, we import
# from tomli.
Expand Down Expand Up @@ -36,55 +38,76 @@
"""


class OutputSchema(
namedtuple("Output", ["enforce_script", "file", "prompt_separator"])
):
def dataclass(cls, slots=True):
"""Custom dataclass decorator to mimic the behavior of dataclass for Python 3.9"""
try:
return dataclasses.dataclass(cls, slots=slots)
except TypeError:

def wrap(cls):
# Create a new dict for our new class.
cls_dict = dict(cls.__dict__)
field_names = tuple(name for name in cls_dict.keys())
# The slots for our class
cls_dict["__slots__"] = field_names
return dataclasses.dataclass(cls)

return wrap(cls)


@dataclass
class LoggingSchema:
"""This class represents the [logging] section of our config.toml file."""

_logging_types: ClassVar[tuple[str, str]] = (
"verbose",
"minimal",
)
type: str = "minimal"
file: str | Path = "~/.cache/command-line-assistant/command-line-assistant.log"

def _validate_logging_type(self, type: str):
if type not in self._logging_types:
raise TypeError(
f"Logging type {type} is not available. Please, choose from {(',').join(self._logging_types)}"
)

def __post_init__(self):
self.file = Path(self.file).expanduser()


@dataclass
class OutputSchema:
"""This class represents the [output] section of our config.toml file."""

# Locking down against extra fields at runtime
__slots__ = ()
enforce_script: bool = False
file: str | Path = "/tmp/command-line-assistant_output.txt"
prompt_separator: str = "$"

# We are overriding __new__ here because namedtuple only offers default values to fields from Python 3.7+
def __new__(
cls,
enforce_script: bool = False,
file: str = "/tmp/command-line-assistant_output.txt",
prompt_separator: str = "$",
):
file = Path(file).expanduser()
return super(OutputSchema, cls).__new__(
cls, enforce_script, file, prompt_separator
)
def __post_init__(self):
self.file = Path(self.file).expanduser()


class HistorySchema(namedtuple("History", ["enabled", "file", "max_size"])):
@dataclass
class HistorySchema:
"""This class represents the [history] section of our config.toml file."""

# Locking down against extra fields at runtime
__slots__ = ()
enabled: bool = True
file: str | Path = (
"~/.local/share/command-line-assistant/command-line-assistant_history.json"
)
max_size: int = 100

# We are overriding __new__ here because namedtuple only offers default values to fields from Python 3.7+
def __new__(
cls,
enabled: bool = True,
file: str = "~/.local/share/command-line-assistant/command-line-assistant_history.json",
max_size: int = 100,
):
file = Path(file).expanduser()
return super(HistorySchema, cls).__new__(cls, enabled, file, max_size)
def __post_init__(self):
self.file = Path(self.file).expanduser()


class BackendSchema(namedtuple("Backend", ["endpoint", "verify_ssl"])):
@dataclass
class BackendSchema:
"""This class represents the [backend] section of our config.toml file."""

# Locking down against extra fields at runtime
__slots__ = ()

# We are overriding __new__ here because namedtuple only offers default values to fields from Python 3.7+
def __new__(
cls, endpoint: str = "http://0.0.0.0:8080/v1/query/", verify_ssl: bool = True
):
return super(BackendSchema, cls).__new__(cls, endpoint, verify_ssl)
endpoint: str = "http://0.0.0.0:8080/v1/query"
verify_ssl: bool = True


class Config:
Expand All @@ -96,34 +119,30 @@ class Config:
>>> config.output.enforce_script
The currently available top-level fields are:
* output = Match the `py:Output` class and their fields
* history = Match the `py:History` class and their fields
* backend = Match the `py:backend` class and their fields
* output = Match the `py:OutputSchema` class and their fields
* history = Match the `py:HistorySchema` class and their fields
* backend = Match the `py:backendSchema` class and their fields
"""

def __init__(
self,
output: Optional[dict] = None,
history: Optional[dict] = None,
backend: Optional[dict] = None,
output: Optional[OutputSchema] = None,
history: Optional[HistorySchema] = None,
backend: Optional[BackendSchema] = None,
logging: Optional[LoggingSchema] = None,
) -> None:
self.output: OutputSchema = OutputSchema(**output) if output else OutputSchema()
self.history: HistorySchema = (
HistorySchema(**history) if history else HistorySchema()
)
self.backend: BackendSchema = (
BackendSchema(**backend) if backend else BackendSchema()
)
self.output: OutputSchema = output if output else OutputSchema()
self.history: HistorySchema = history if history else HistorySchema()
self.backend: BackendSchema = backend if backend else BackendSchema()
self.logging: LoggingSchema = logging if logging else LoggingSchema()


def _create_config_file(config_file: Path) -> None:
"""Create a new configuration file with default values."""

logging.info(f"Creating new config file at {config_file.parent}")
config_file.parent.mkdir(mode=0o755)
base_config = Config(
OutputSchema()._asdict(), HistorySchema()._asdict(), BackendSchema()._asdict()
)
base_config = Config()

mapping = {
"enforce_script": json.dumps(base_config.output.enforce_script),
Expand Down
22 changes: 11 additions & 11 deletions tests/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from command_line_assistant import history
from command_line_assistant.config import Config
from command_line_assistant.config import Config, HistorySchema

#: Mock history conversation for testing
MOCK_HISTORY_CONVERSATION: list[dict] = [
Expand All @@ -20,28 +20,28 @@ class TestHistoryRead:
"""Holds the testing functions for reading the history."""

def test_not_enabled(self):
config = Config(history={"enabled": False})
config = Config(history=HistorySchema(enabled=False))
assert not history.handle_history_read(config)

def test_history_file_missing(self, tmpdir, caplog):
history_file = tmpdir.join("non-existing-file.json")
config = Config(history={"file": history_file})
config = Config(history=HistorySchema(file=history_file))

assert not history.handle_history_read(config)
assert "File will be created with first response." in caplog.records[-1].message

def test_history_failed_to_decode_json(self, tmpdir, caplog):
history_file = tmpdir.join("non-existing-file.json")
history_file.write("not a json")
config = Config(history={"file": history_file})
config = Config(history=HistorySchema(file=history_file))

assert not history.handle_history_read(config)
assert "Failed to read history file" in caplog.records[-1].message

def test_history_read(self, tmpdir, caplog):
def test_history_read(self, tmpdir):
history_file = tmpdir.join("history.json")
history_file.write(json.dumps(MOCK_HISTORY_CONVERSATION))
config = Config(history={"file": history_file})
config = Config(history=HistorySchema(file=history_file))

assert history.handle_history_read(config) == MOCK_HISTORY_CONVERSATION

Expand All @@ -58,7 +58,7 @@ def test_history_over_max_size(self, tmpdir, multiply, max_size):
total_mock_data = MOCK_HISTORY_CONVERSATION * multiply
history_file = tmpdir.join("history.json")
history_file.write(json.dumps(total_mock_data))
config = Config(history={"file": history_file, "max_size": max_size})
config = Config(history=HistorySchema(file=history_file, max_size=max_size))

history_result = history.handle_history_read(config)
assert len(history_result) == max_size
Expand All @@ -69,20 +69,20 @@ def test_history_over_max_size(self, tmpdir, multiply, max_size):

class TestHistoryWrite:
def test_not_enabled(self):
config = Config(history={"enabled": False})
config = Config(history=HistorySchema(enabled=False))
assert not history.handle_history_write(config, [], "")

def test_history_file_missing(self, tmpdir):
history_file = tmpdir.join("history").join("non-existing-file.json")
config = Config(history={"file": history_file})
config = Config(history=HistorySchema(file=history_file))

history.handle_history_write(config, [], "test")
assert Path(history_file).exists()

def test_history_write(self, tmpdir):
expected = [{"role": "assistant", "content": "test"}]
history_file = tmpdir.join("history").join("non-existing-file.json")
config = Config(history={"file": history_file})
config = Config(history=HistorySchema(file=history_file))

history.handle_history_write(config, [], "test")

Expand All @@ -94,7 +94,7 @@ def test_history_append(self, tmpdir):
expected.append({"role": "assistant", "content": "test"})

history_file = tmpdir.join("history").join("non-existing-file.json")
config = Config(history={"file": history_file})
config = Config(history=HistorySchema(file=history_file))

history.handle_history_write(config, MOCK_HISTORY_CONVERSATION, "test")

Expand Down

0 comments on commit f411a78

Please sign in to comment.