From 54ce7a704b0c34ead83ba41fa0f99db069496a26 Mon Sep 17 00:00:00 2001 From: Rodolfo Olivieri Date: Wed, 11 Dec 2024 21:43:04 -0300 Subject: [PATCH] Improve rendering libary (#65) --- command_line_assistant/commands/query.py | 77 +++++- command_line_assistant/config/__init__.py | 4 +- command_line_assistant/rendering/base.py | 67 +++++ .../rendering/decorators/base.py | 9 - .../rendering/decorators/colors.py | 2 +- .../rendering/decorators/style.py | 2 +- .../rendering/decorators/text.py | 44 ++- command_line_assistant/rendering/render.py | 22 -- .../rendering/renders/__init__.py | 0 .../rendering/renders/spinner.py | 72 +++++ .../rendering/renders/text.py | 18 ++ command_line_assistant/rendering/spinner.py | 49 ---- command_line_assistant/rendering/stream.py | 47 ++++ command_line_assistant/utils/cli.py | 2 +- command_line_assistant/utils/environment.py | 29 +- .../config/command_line_assistant/config.toml | 2 +- data/release/clad.service | 2 +- tests/commands/test_query.py | 10 +- tests/config/test_config.py | 6 +- tests/conftest.py | 9 +- tests/daemon/http/test_adapters.py | 0 tests/daemon/http/test_session.py | 75 ++++++ tests/rendering/decorators/test_colors.py | 56 ++++ tests/rendering/decorators/test_text.py | 251 ++++++++++++++++-- tests/rendering/renders/__init__.py | 0 tests/rendering/renders/test_spinner.py | 183 +++++++++++++ .../{test_render.py => renders/test_text.py} | 4 +- tests/rendering/test_spinner.py | 106 -------- tests/rendering/test_stream.py | 226 ++++++++++++++++ tests/test_initialize.py | 2 +- tests/utils/test_environment.py | 30 ++- 31 files changed, 1166 insertions(+), 240 deletions(-) create mode 100644 command_line_assistant/rendering/base.py delete mode 100644 command_line_assistant/rendering/decorators/base.py delete mode 100644 command_line_assistant/rendering/render.py create mode 100644 command_line_assistant/rendering/renders/__init__.py create mode 100644 command_line_assistant/rendering/renders/spinner.py create mode 100644 command_line_assistant/rendering/renders/text.py delete mode 100644 command_line_assistant/rendering/spinner.py create mode 100644 command_line_assistant/rendering/stream.py create mode 100644 tests/daemon/http/test_adapters.py create mode 100644 tests/daemon/http/test_session.py create mode 100644 tests/rendering/renders/__init__.py create mode 100644 tests/rendering/renders/test_spinner.py rename tests/rendering/{test_render.py => renders/test_text.py} (95%) delete mode 100644 tests/rendering/test_spinner.py create mode 100644 tests/rendering/test_stream.py diff --git a/command_line_assistant/commands/query.py b/command_line_assistant/commands/query.py index 7b82581..d03bbf5 100644 --- a/command_line_assistant/commands/query.py +++ b/command_line_assistant/commands/query.py @@ -1,13 +1,70 @@ from argparse import Namespace +from dasbus.error import DBusError + from command_line_assistant.dbus.constants import SERVICE_IDENTIFIER from command_line_assistant.dbus.definitions import MessageInput, MessageOutput +from command_line_assistant.rendering.decorators.colors import ColorDecorator +from command_line_assistant.rendering.decorators.text import ( + EmojiDecorator, + TextWrapDecorator, + WriteOnceDecorator, +) +from command_line_assistant.rendering.renders.spinner import SpinnerRenderer +from command_line_assistant.rendering.renders.text import TextRenderer +from command_line_assistant.rendering.stream import StderrStream, StdoutStream from command_line_assistant.utils.cli import BaseCLICommand, SubParsersAction +LEGAL_NOTICE = ( + "RHEL Lightspeed Command Line Assistant can answer questions related to RHEL." + " Do not include personal or business sensitive information in your input." + "Interactions with RHEL Lightspeed may be reviewed and used to improve our " + "products and service." +) +ALWAYS_LEGAL_MESSAGE = ( + "Always check AI/LLM-generated responses for accuracy prior to use." +) + + +def _initialize_spinner_renderer() -> SpinnerRenderer: + spinner = SpinnerRenderer( + message="Requesting knowledge from AI", stream=StdoutStream(end="") + ) + + spinner.update(EmojiDecorator(emoji="U+1F916")) # Robot emoji + spinner.update(TextWrapDecorator()) + + return spinner + + +def _initialize_text_renderer() -> TextRenderer: + text = TextRenderer(stream=StdoutStream(end="\n")) + text.update(ColorDecorator(foreground="green")) # Robot emoji + text.update(TextWrapDecorator()) + + return text + + +def _initialize_legal_renderer(write_once: bool = False) -> TextRenderer: + text = TextRenderer(stream=StderrStream()) + text.update(ColorDecorator(foreground="lightyellow")) + text.update(TextWrapDecorator()) + + if write_once: + text.update(WriteOnceDecorator(state_filename="legal")) + + return text + class QueryCommand(BaseCLICommand): def __init__(self, query_string: str) -> None: self._query = query_string + + self._spinner_renderer: SpinnerRenderer = _initialize_spinner_renderer() + self._text_renderer: TextRenderer = _initialize_text_renderer() + self._legal_renderer: TextRenderer = _initialize_legal_renderer(write_once=True) + self._warning_renderer: TextRenderer = _initialize_legal_renderer() + super().__init__() def run(self) -> None: @@ -16,13 +73,21 @@ def run(self) -> None: input_query = MessageInput() input_query.message = self._query - print("Requesting knowledge from the AI :robot:") - proxy.ProcessQuery(MessageInput.to_structure(input_query)) - - output = MessageOutput.from_structure(proxy.RetrieveAnswer).message + output = "Nothing to see here..." + try: + with self._spinner_renderer: + proxy.ProcessQuery(MessageInput.to_structure(input_query)) + output = MessageOutput.from_structure(proxy.RetrieveAnswer).message - if output: - print("\n", output) + self._legal_renderer.render(LEGAL_NOTICE) + self._text_renderer.render(output) + self._warning_renderer.render(ALWAYS_LEGAL_MESSAGE) + except DBusError: + self._text_renderer.update(ColorDecorator(foreground="red")) + self._text_renderer.update(EmojiDecorator(emoji="U+1F641")) + self._text_renderer.render( + "Uh oh... Something went wrong. Try again later." + ) def register_subcommand(parser: SubParsersAction) -> None: diff --git a/command_line_assistant/config/__init__.py b/command_line_assistant/config/__init__.py index 9cc4cc8..2af252a 100644 --- a/command_line_assistant/config/__init__.py +++ b/command_line_assistant/config/__init__.py @@ -10,7 +10,7 @@ LoggingSchema, OutputSchema, ) -from command_line_assistant.utils.environment import get_xdg_path +from command_line_assistant.utils.environment import get_xdg_config_path # tomllib is available in the stdlib after Python3.11. Before that, we import # from tomli. @@ -54,7 +54,7 @@ def load_config_file() -> Config: """Read configuration file.""" config_dict = {} - config_file_path = Path(get_xdg_path(), *CONFIG_FILE_DEFINITION) + config_file_path = Path(get_xdg_config_path(), *CONFIG_FILE_DEFINITION) try: print(f"Loading configuration file from {config_file_path}") diff --git a/command_line_assistant/rendering/base.py b/command_line_assistant/rendering/base.py new file mode 100644 index 0000000..a73777f --- /dev/null +++ b/command_line_assistant/rendering/base.py @@ -0,0 +1,67 @@ +from abc import ABC, abstractmethod +from typing import TextIO + + +class RenderDecorator(ABC): + """Abstract base class for render decorators""" + + @abstractmethod + def decorate(self, text: str) -> str: + pass + + +class OutputStreamWritter(ABC): + """Abstract base class for output stream decorators""" + + def __init__(self, stream: TextIO, end: str = "\n") -> None: + """ + Initialize the output stream decorator. + + Args: + stream: The output stream to use + end: The string to append after the text (defaults to newline) + """ + self._stream = stream + self._end = end + + @abstractmethod + def write(self, text: str) -> None: + """Write the text to the output stream""" + pass + + @abstractmethod + def flush(self) -> None: + """Flush the output stream""" + pass + + def execute(self, text: str) -> None: + """ + Write the text to the output stream and return the original text for chaining. + """ + if text: + self.write(text) + self.flush() + + +class BaseRenderer(ABC): + """Base class for renderers providing common functionality.""" + + def __init__(self, stream: OutputStreamWritter) -> None: + self._stream = stream + self._decorators: dict[type, RenderDecorator] = {} + + def update(self, decorator: RenderDecorator) -> None: + """Update or add a decorator of the same type.""" + self._decorators[type(decorator)] = decorator + + def _apply_decorators(self, text: str) -> str: + """Apply all decorators to the text.""" + decorated_text = text + for decorator in self._decorators.values(): + decorated_text = decorator.decorate(decorated_text) + return decorated_text + + @abstractmethod + def render(self, text: str) -> None: + """Render the text with all decorators applied.""" + pass diff --git a/command_line_assistant/rendering/decorators/base.py b/command_line_assistant/rendering/decorators/base.py deleted file mode 100644 index fdb653e..0000000 --- a/command_line_assistant/rendering/decorators/base.py +++ /dev/null @@ -1,9 +0,0 @@ -from abc import ABC, abstractmethod - - -class RenderDecorator(ABC): - """Abstract base class for render decorators""" - - @abstractmethod - def decorate(self, text: str) -> str: - pass diff --git a/command_line_assistant/rendering/decorators/colors.py b/command_line_assistant/rendering/decorators/colors.py index 5d12445..0bf5c1c 100644 --- a/command_line_assistant/rendering/decorators/colors.py +++ b/command_line_assistant/rendering/decorators/colors.py @@ -3,7 +3,7 @@ from colorama import Back, Fore, Style -from command_line_assistant.rendering.decorators.base import RenderDecorator +from command_line_assistant.rendering.base import RenderDecorator class ColorDecorator(RenderDecorator): diff --git a/command_line_assistant/rendering/decorators/style.py b/command_line_assistant/rendering/decorators/style.py index 96ed1de..3ec56ab 100644 --- a/command_line_assistant/rendering/decorators/style.py +++ b/command_line_assistant/rendering/decorators/style.py @@ -2,7 +2,7 @@ from colorama import Style -from command_line_assistant.rendering.decorators.base import RenderDecorator +from command_line_assistant.rendering.base import RenderDecorator class StyleDecorator(RenderDecorator): diff --git a/command_line_assistant/rendering/decorators/text.py b/command_line_assistant/rendering/decorators/text.py index 01a6f02..084488d 100644 --- a/command_line_assistant/rendering/decorators/text.py +++ b/command_line_assistant/rendering/decorators/text.py @@ -1,8 +1,10 @@ import shutil import textwrap +from pathlib import Path from typing import Optional, Union -from command_line_assistant.rendering.decorators.base import RenderDecorator +from command_line_assistant.rendering.base import RenderDecorator +from command_line_assistant.utils.environment import get_xdg_state_path class EmojiDecorator(RenderDecorator): @@ -40,3 +42,43 @@ def decorate(self, text: str) -> str: initial_indent=self._indent, subsequent_indent=self._indent, ) + + +class WriteOnceDecorator(RenderDecorator): + """Decorator that ensures content is written only once by checking a state file. + + The state file is created under $XDG_STATE_HOME/command-line-assistant/legal/ + """ + + def __init__(self, state_filename: str = "written") -> None: + """Initialize the write once decorator. + + Args: + state_filename: Name of the state file to create/check + """ + self._state_dir = Path(get_xdg_state_path(), "command-line-assistant") + self._state_file = self._state_dir / state_filename + + def _should_write(self) -> bool: + """Check if content should be written by verifying state file existence.""" + if self._state_file.exists(): + return False + + if not self._state_dir.exists(): + # Create directory if it doesn't exist + self._state_dir.mkdir(parents=True) + + # Write state file + self._state_file.write_text("1") + return True + + def decorate(self, text: str) -> str: + """Write the text only if it hasn't been written before. + + Args: + text: The text to potentially write + + Returns: + The text if it should be written, None otherwise + """ + return text if self._should_write() else "" diff --git a/command_line_assistant/rendering/render.py b/command_line_assistant/rendering/render.py deleted file mode 100644 index 973c61b..0000000 --- a/command_line_assistant/rendering/render.py +++ /dev/null @@ -1,22 +0,0 @@ -import shutil - -from command_line_assistant.rendering.decorators.base import RenderDecorator - - -class TextRenderer: - def __init__(self) -> None: - # Fetch the current terminal size on initialization - self.terminal_width = shutil.get_terminal_size().columns - self._decorators: dict[type, RenderDecorator] = {} - - def update(self, decorator: RenderDecorator) -> None: - """Update or add a decorator of the same type.""" - self._decorators[type(decorator)] = decorator - - def render(self, text: str): - decorated_text = text - # Apply all decorators except Spinner - for decorator in self._decorators.values(): - decorated_text = decorator.decorate(decorated_text) - - print(decorated_text) diff --git a/command_line_assistant/rendering/renders/__init__.py b/command_line_assistant/rendering/renders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/command_line_assistant/rendering/renders/spinner.py b/command_line_assistant/rendering/renders/spinner.py new file mode 100644 index 0000000..e76f1ae --- /dev/null +++ b/command_line_assistant/rendering/renders/spinner.py @@ -0,0 +1,72 @@ +import itertools +import threading +import time +from dataclasses import dataclass +from typing import Iterator, Optional + +from command_line_assistant.rendering.base import BaseRenderer, OutputStreamWritter +from command_line_assistant.rendering.stream import StdoutStream + + +@dataclass +class Frames: + default: Iterator[str] = itertools.cycle(["⠋", "⠙", "⠸", "⠴", "⠦", "⠇"]) + dash: Iterator[str] = itertools.cycle(["-", "\\", "|", "/"]) + circular: Iterator[str] = itertools.cycle(["◐", "◓", "◑", "◒"]) + dots: Iterator[str] = itertools.cycle([". ", ".. ", "...", " ..", " .", " "]) + arrows: Iterator[str] = itertools.cycle(["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"]) + moving: Iterator[str] = itertools.cycle( + ["[ ]", "[= ]", "[== ]", "[===]", "[ ==]", "[ =]", "[ ]"] + ) + + +class SpinnerRenderer(BaseRenderer): + def __init__( + self, + message: str, + stream: Optional[OutputStreamWritter] = None, + frames: Iterator[str] = Frames.default, + delay: float = 0.1, + clear_message: bool = False, + ) -> None: + super().__init__(stream or StdoutStream()) + self._message = message + self._frames = frames + self._delay = delay + self._clear_message = clear_message + self._done = threading.Event() + self._spinner_thread: Optional[threading.Thread] = None + + def render(self, text: str) -> None: + """Render text with all decorators applied.""" + decorated_text = self._apply_decorators(text) + self._stream.execute(decorated_text) + + def _animation(self) -> None: + while not self._done.is_set(): + frame = next(self._frames) + message = self._apply_decorators(f"{frame} {self._message}") + self._stream.execute(f"\r{message}") + time.sleep(self._delay) + + def start(self) -> None: + """Start the spinner animation""" + self._done.clear() + self._spinner_thread = threading.Thread(target=self._animation) + self._spinner_thread.start() + + def stop(self) -> None: + """Stop the spinner animation""" + if self._spinner_thread: + self._done.set() + self._spinner_thread.join() + self._stream.execute("\n") + if self._clear_message: + self._stream.execute(f"\r{' ' * (len(self._message) + 2)}\r") + + def __enter__(self) -> "SpinnerRenderer": + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.stop() diff --git a/command_line_assistant/rendering/renders/text.py b/command_line_assistant/rendering/renders/text.py new file mode 100644 index 0000000..5a64319 --- /dev/null +++ b/command_line_assistant/rendering/renders/text.py @@ -0,0 +1,18 @@ +import shutil +from typing import Optional + +from command_line_assistant.rendering.base import BaseRenderer, OutputStreamWritter +from command_line_assistant.rendering.stream import StdoutStream + + +class TextRenderer(BaseRenderer): + def __init__(self, stream: Optional[OutputStreamWritter] = None) -> None: + super().__init__(stream or StdoutStream()) + self.terminal_width = shutil.get_terminal_size().columns + + def render(self, text: str) -> None: + """Render text with all decorators applied.""" + lines = text.splitlines() + for line in lines: + decorated_text = self._apply_decorators(line) + self._stream.execute(decorated_text) diff --git a/command_line_assistant/rendering/spinner.py b/command_line_assistant/rendering/spinner.py deleted file mode 100644 index 2905f7a..0000000 --- a/command_line_assistant/rendering/spinner.py +++ /dev/null @@ -1,49 +0,0 @@ -import itertools -import sys -import threading -import time -from contextlib import contextmanager -from dataclasses import dataclass -from typing import Generator, Iterator - - -@dataclass -class Frames: - default: Iterator[str] = itertools.cycle(["-", "\\", "|", "/"]) - braille: Iterator[str] = itertools.cycle(["⠋", "⠙", "⠸", "⠴", "⠦", "⠇"]) - circular: Iterator[str] = itertools.cycle(["◐", "◓", "◑", "◒"]) - dots: Iterator[str] = itertools.cycle([". ", ".. ", "...", " ..", " .", " "]) - arrows: Iterator[str] = itertools.cycle(["←", "↖", "↑", "↗", "→", "↘", "↓", "↙"]) - moving: Iterator[str] = itertools.cycle( - ["[ ]", "[= ]", "[== ]", "[===]", "[ ==]", "[ =]", "[ ]"] - ) - - -@contextmanager -def ascii_spinner( - message: str, - clear_message: bool = False, - frames: Iterator[str] = Frames.default, - delay: float = 0.1, -) -> Generator: - done = threading.Event() - - def animation() -> None: - while not done.is_set(): - sys.stdout.write(f"\r{next(frames)} {message}") # Write the current frame - sys.stdout.flush() - time.sleep(delay) # Delay between frames - - spinner_thread = threading.Thread(target=animation) - spinner_thread.start() - - try: - yield - finally: - done.set() # Signal the spinner to stop - spinner_thread.join() # Wait for the spinner thread to finish - sys.stdout.write("\r\n") - if clear_message: - # Clear the message by overwriting it with spaces and resetting the cursor - sys.stdout.write("\r" + " " * (len(message) + 2) + "\r") # Clear the line - sys.stdout.flush() diff --git a/command_line_assistant/rendering/stream.py b/command_line_assistant/rendering/stream.py new file mode 100644 index 0000000..3e14b34 --- /dev/null +++ b/command_line_assistant/rendering/stream.py @@ -0,0 +1,47 @@ +import sys + +from command_line_assistant.rendering.base import ( + OutputStreamWritter, +) + + +class StderrStream(OutputStreamWritter): + """Decorator for outputting text to stderr""" + + def __init__(self, end: str = "\n") -> None: + """ + Initialize the stderr decorator. + + Args: + end: The string to append after the text (defaults to newline) + """ + super().__init__(stream=sys.stderr, end=end) + + def write(self, text: str) -> None: + """Write the text to stderr""" + self._stream.write(text + self._end) + + def flush(self) -> None: + """Flush stderr""" + self._stream.flush() + + +class StdoutStream(OutputStreamWritter): + """Decorator for outputting text to stdout""" + + def __init__(self, end: str = "\n") -> None: + """ + Initialize the stdout decorator. + + Args: + end: The string to append after the text (defaults to newline) + """ + super().__init__(stream=sys.stdout, end=end) + + def write(self, text: str) -> None: + """Write the text to stdout""" + self._stream.write(text + self._end) + + def flush(self) -> None: + """Flush stdout""" + self._stream.flush() diff --git a/command_line_assistant/utils/cli.py b/command_line_assistant/utils/cli.py index 5e8c2c8..9f5c726 100644 --- a/command_line_assistant/utils/cli.py +++ b/command_line_assistant/utils/cli.py @@ -16,7 +16,7 @@ class BaseCLICommand(ABC): @abstractmethod def run(self): - raise NotImplementedError("Not implemented in base class.") + pass def add_default_command(argv): diff --git a/command_line_assistant/utils/environment.py b/command_line_assistant/utils/environment.py index fde9659..76f30a6 100644 --- a/command_line_assistant/utils/environment.py +++ b/command_line_assistant/utils/environment.py @@ -1,10 +1,31 @@ import os +from pathlib import Path # The wanted xdg path where the configuration files will live. -WANTED_XDG_PATH = "/etc/xdg" +WANTED_XDG_PATH = Path("/etc/xdg") +# The wanted xdg state path in case $XDG_STATE_HOME is not defined. +WANTED_XDG_STATE_PATH = Path("~/.local/state").expanduser() -def get_xdg_path() -> str: + +def get_xdg_state_path() -> Path: + """Check for the existence of XDG_STATE_HOME environment variable. + + In case it is not present, this function will return the default path that + is `~/.local/state`, which is where we want to place temporary state files for + Command Line Assistant. + + See: https://specifications.freedesktop.org/basedir-spec/latest/ + """ + xdg_state_home = os.getenv("XDG_STATE_HOME", "") + + # We call expanduser() for the xdg_state_home in case someone do "~/" + return ( + Path(xdg_state_home).expanduser() if xdg_state_home else WANTED_XDG_STATE_PATH + ) + + +def get_xdg_config_path() -> Path: """Check for the existence of XDG_CONFIG_DIRS environment variable. In case it is not present, this function will return the default path that @@ -35,11 +56,11 @@ def get_xdg_path() -> str: # XDG_CONFIG_DIRS was overrided and pointed to a specific location. # We hope to find the config file there. if len(xdg_config_dirs) == 1: - return xdg_config_dirs[0] + return Path(xdg_config_dirs[0]) # Try to find the first occurence of the wanted_xdg_dir in the path, in # case it is not present, return the default value. xdg_dir_found = next( (dir for dir in xdg_config_dirs if dir == WANTED_XDG_PATH), WANTED_XDG_PATH ) - return xdg_dir_found + return Path(xdg_dir_found) diff --git a/data/development/config/command_line_assistant/config.toml b/data/development/config/command_line_assistant/config.toml index 901cce4..d26f06a 100644 --- a/data/development/config/command_line_assistant/config.toml +++ b/data/development/config/command_line_assistant/config.toml @@ -13,7 +13,7 @@ file = "~/.local/share/command-line-assistant/command-line-assistant_history.jso max_size = 100 [backend] -endpoint = "http://localhost.com" +endpoint = "https://rlsrag-rhel-lightspeed--runtime-int.apps.int.spoke.preprod.us-east-1.aws.paas.redhat.com" [backend.auth] cert_file = "data/development/certificate/fake-certificate.pem" diff --git a/data/release/clad.service b/data/release/clad.service index 0ddf19b..3fc1c8e 100644 --- a/data/release/clad.service +++ b/data/release/clad.service @@ -4,7 +4,7 @@ Documentation=https://github.com/rhel-lightspeed/command-line-assistant After=network.service [Service] -BusName=redhat.rhel.lightspeed +BusName=com.rhel.lightspeed PrivateTmp=yes RemainAfterExit=no ExecStart=$(exec_prefix}/sbin/clad diff --git a/tests/commands/test_query.py b/tests/commands/test_query.py index bdf11ab..6eb3024 100644 --- a/tests/commands/test_query.py +++ b/tests/commands/test_query.py @@ -77,7 +77,7 @@ def test_query_command_empty_response(mock_dbus_service, capsys): command.run() captured = capsys.readouterr() - assert captured.out == "Requesting knowledge from the AI :robot:\n" + assert "Requesting knowledge from AI" in captured.out.strip() @pytest.mark.parametrize( @@ -122,7 +122,7 @@ def test_command_factory(): assert command._query == "test query" -def test_dbus_error_handling(mock_dbus_service): +def test_dbus_error_handling(mock_dbus_service, capsys): """Test handling of DBus errors""" from dasbus.error import DBusError @@ -130,9 +130,11 @@ def test_dbus_error_handling(mock_dbus_service): mock_dbus_service.ProcessQuery.side_effect = DBusError("Test DBus Error") command = QueryCommand("test query") + command.run() - with pytest.raises(DBusError): - command.run() + # Verify error message in stdout + captured = capsys.readouterr() + assert "Uh oh... Something went wrong. Try again later." in captured.out.strip() def test_query_with_special_characters(mock_dbus_service, capsys): diff --git a/tests/config/test_config.py b/tests/config/test_config.py index 371ca59..c512bde 100644 --- a/tests/config/test_config.py +++ b/tests/config/test_config.py @@ -47,7 +47,7 @@ def test_load_config_file(tmp_path, monkeypatch, get_config_template): config_file.parent.mkdir() config_file.write_text(get_config_template) - monkeypatch.setattr(config, "get_xdg_path", lambda: config_file_path) + monkeypatch.setattr(config, "get_xdg_config_path", lambda: config_file_path) instance = config.load_config_file() assert isinstance(instance, config.Config) @@ -58,7 +58,7 @@ def test_load_config_file(tmp_path, monkeypatch, get_config_template): def test_load_config_file_not_found(tmp_path, monkeypatch): config_file = tmp_path / "whatever" - monkeypatch.setattr(config, "get_xdg_path", lambda: config_file) + monkeypatch.setattr(config, "get_xdg_config_path", lambda: config_file) with pytest.raises(FileNotFoundError): config.load_config_file() @@ -73,7 +73,7 @@ def test_load_config_file_decoded_error(tmp_path, monkeypatch): enforce_script = False """) - monkeypatch.setattr(config, "get_xdg_path", lambda: config_file_path) + monkeypatch.setattr(config, "get_xdg_config_path", lambda: config_file_path) with pytest.raises(tomllib.TOMLDecodeError): config.load_config_file() diff --git a/tests/conftest.py b/tests/conftest.py index bd12423..9d536b7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,8 +29,13 @@ def setup_logger(request): @pytest.fixture -def mock_config(): +def mock_config(tmp_path): """Fixture to create a mock configuration""" + cert_file = tmp_path / "cert.pem" + key_file = tmp_path / "key.pem" + + cert_file.write_text("cert") + key_file.write_text("key") return Config( output=OutputSchema( enforce_script=False, @@ -39,7 +44,7 @@ def mock_config(): ), backend=BackendSchema( endpoint="http://test.endpoint/v1/query", - auth=AuthSchema(cert_file=Path(""), key_file=Path(""), verify_ssl=True), + auth=AuthSchema(cert_file=cert_file, key_file=key_file, verify_ssl=False), ), history=HistorySchema( enabled=True, file=Path("/tmp/test_history.json"), max_size=100 diff --git a/tests/daemon/http/test_adapters.py b/tests/daemon/http/test_adapters.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/daemon/http/test_session.py b/tests/daemon/http/test_session.py new file mode 100644 index 0000000..2bd2c19 --- /dev/null +++ b/tests/daemon/http/test_session.py @@ -0,0 +1,75 @@ +from unittest.mock import MagicMock, patch + +import pytest +import urllib3 + +from command_line_assistant.constants import VERSION +from command_line_assistant.daemon.http.session import get_session + + +def test_session_headers(mock_config): + """Test that session headers are properly set""" + session = get_session(mock_config) + + assert session.headers["User-Agent"] == f"clad/{VERSION}" + assert session.headers["Content-Type"] == "application/json" + + +def test_ssl_verification_disabledmock_config(mock_config): + """Test that SSL verification is disabled when configured""" + + with patch("urllib3.disable_warnings") as mock_disable_warnings: + session = get_session(mock_config) + + mock_disable_warnings.assert_called_once_with( + urllib3.exceptions.InsecureRequestWarning + ) + assert session.verify is False + + +@patch("command_line_assistant.daemon.http.session.Session") +def test_session_creation(mock_session, mock_config): + """Test basic session creation""" + mock_session_instance = MagicMock() + mock_session.return_value = mock_session_instance + + session = get_session(mock_config) + + mock_session.assert_called_once() + assert session == mock_session_instance + + +def test_ssl_verification_disabled_logs_warning(caplog, mock_config): + """Test that disabling SSL verification logs a warning""" + with patch("urllib3.disable_warnings"): + get_session(mock_config) + + assert "Disabling SSL verification as per user requested." in caplog.text + + +def test_different_endpoint_configuration(mock_config): + """Test session creation with different endpoint configurations""" + custom_endpoint = "https://custom-endpoint:9090" + mock_config.backend.endpoint = custom_endpoint + + session = get_session(mock_config) + + # Verify that the custom endpoint is used for mounting adapters + assert any(pattern == custom_endpoint for pattern, _ in session.adapters.items()) + + +@pytest.mark.parametrize( + "endpoint", + [ + "http://localhost:8080", + "https://api.example.com", + "http://127.0.0.1:5000", + ], +) +def test_various_endpoints(mock_config, endpoint): + """Test session creation with various endpoint configurations""" + mock_config.backend.endpoint = endpoint + session = get_session(mock_config) + + # Verify that the endpoint is used for mounting adapters + assert any(pattern == endpoint for pattern, _ in session.adapters.items()) diff --git a/tests/rendering/decorators/test_colors.py b/tests/rendering/decorators/test_colors.py index 7ba7fa9..8754e1f 100644 --- a/tests/rendering/decorators/test_colors.py +++ b/tests/rendering/decorators/test_colors.py @@ -1,7 +1,11 @@ +import os +from unittest.mock import patch + import pytest from command_line_assistant.rendering.decorators.colors import ( ColorDecorator, + should_disable_color_output, ) @@ -30,3 +34,55 @@ def test_color_decorator_invalid_color(): with pytest.raises(ValueError): ColorDecorator(foreground="white", background="invalid") + + +@pytest.mark.parametrize( + ("env_value", "expected"), + [ + ("1", True), + ("true", True), + ("TRUE", True), + ("True", True), + ("yes", True), + ("YES", True), + ("anything", True), # Any non-empty value except "0" or "false" + ("0", False), + ("false", False), + ("FALSE", False), + ("False", False), + (None, False), # NO_COLOR not set + ], +) +def test_should_disable_color_output(env_value, expected): + """Test different NO_COLOR environment variable values""" + with patch.dict( + os.environ, {"NO_COLOR": env_value} if env_value is not None else {}, clear=True + ): + assert should_disable_color_output() == expected + + +def test_should_disable_color_output_no_env(): + """Test when NO_COLOR environment variable is not set""" + with patch.dict(os.environ, {}, clear=True): + assert should_disable_color_output() is False + + +def test_should_disable_color_output_empty_string(): + """Test when NO_COLOR is set to empty string""" + with patch.dict(os.environ, {"NO_COLOR": ""}, clear=True): + assert should_disable_color_output() is True + + +@pytest.mark.parametrize( + ("env_vars", "expected"), + [ + ({"NO_COLOR": "1", "TERM": "xterm"}, True), # NO_COLOR takes precedence + ({"NO_COLOR": "0", "TERM": "dumb"}, False), # NO_COLOR takes precedence + ({"TERM": "dumb"}, False), # Only TERM present + ({}, False), # No environment variables set + ], +) +def test_should_disable_color_output_with_other_env(env_vars, expected): + """Test interaction with other environment variables""" + with patch.dict(os.environ, env_vars, clear=True): + assert should_disable_color_output() == expected diff --git a/tests/rendering/decorators/test_text.py b/tests/rendering/decorators/test_text.py index 95b9161..de3a6f0 100644 --- a/tests/rendering/decorators/test_text.py +++ b/tests/rendering/decorators/test_text.py @@ -1,29 +1,246 @@ +import shutil +from typing import Iterator + +import pytest + from command_line_assistant.rendering.decorators.text import ( EmojiDecorator, TextWrapDecorator, + WriteOnceDecorator, ) -def test_text_wrap_decorator(): - decorator = TextWrapDecorator(width=10) - text = "This is a long text that should be wrapped" - decorated = decorator.decorate(text) - assert len(max(decorated.split("\n"), key=len)) <= 10 +class TestEmojiDecorator: + @pytest.mark.parametrize( + ("emoji_input", "expected"), + [ + ("👍", "👍 Test text"), # Direct emoji + ("🚀", "🚀 Test text"), # Direct emoji + ("⭐", "⭐ Test text"), # Direct emoji + (0x1F604, "😄 Test text"), # Unicode code point as int + ("U+1F604", "😄 Test text"), # Unicode code point as string + ("0x1F604", "😄 Test text"), # Hex code point as string + ], + ) + def test_emoji_decorator(self, emoji_input, expected): + """Test emoji decorator with various input formats""" + decorator = EmojiDecorator(emoji_input) + result = decorator.decorate("Test text") + assert result == expected + + def test_invalid_emoji_type(self): + """Test emoji decorator with invalid input type""" + with pytest.raises(TypeError, match="Emoji must be string or int"): + EmojiDecorator([1, 2, 3]) # type: ignore + + @pytest.mark.parametrize( + "emoji_code", + [ + 0x1F600, # Grinning face + 0x2B50, # Star + 0x1F680, # Rocket + 0x1F44D, # Thumbs up + 0x2764, # Heart + ], + ) + def test_numeric_emoji_codes(self, emoji_code): + """Test emoji decorator with different numeric Unicode code points""" + decorator = EmojiDecorator(emoji_code) + result = decorator.decorate("Test") + assert len(result.split()[0]) <= 2 # Emoji should be 1-2 characters + assert result.endswith("Test") + + def test_emoji_with_empty_text(self): + """Test emoji decorator with empty text""" + decorator = EmojiDecorator("🎉") + result = decorator.decorate("") + assert result == "🎉 " + + def test_emoji_with_multiline_text(self): + """Test emoji decorator with multiline text""" + decorator = EmojiDecorator("📝") + text = "Line 1\nLine 2\nLine 3" + result = decorator.decorate(text) + assert result == "📝 " + text + + +class TestTextWrapDecorator: + @pytest.fixture + def terminal_width(self) -> Iterator[int]: + """Fixture to get the terminal width""" + original_terminal_size = shutil.get_terminal_size() + yield original_terminal_size.columns + + def test_default_width(self, terminal_width): + """Test text wrap decorator with default width""" + decorator = TextWrapDecorator() + assert decorator._width == terminal_width + + def test_custom_width(self): + """Test text wrap decorator with custom width""" + decorator = TextWrapDecorator(width=40) + assert decorator._width == 40 + + def test_custom_indent(self): + """Test text wrap decorator with custom indent""" + decorator = TextWrapDecorator(indent=" ") + assert decorator._indent == " " + + @pytest.mark.parametrize( + ("width", "indent", "text", "expected"), + [ + (20, "", "Short text", "Short text"), + ( + 10, + "", + "This is a long text that should wrap", + "This is a\nlong text\nthat\nshould\nwrap", + ), + (10, " ", "Indented text wrap test", " Indented\n text\n wrap\n test"), + ( + 15, + "-> ", + "Multiple line indent test", + "-> Multiple\n-> line indent\n-> test", + ), + ], + ) + def test_text_wrapping(self, width, indent, text, expected): + """Test text wrapping with various configurations""" + decorator = TextWrapDecorator(width=width, indent=indent) + result = decorator.decorate(text) + assert result == expected + + def test_long_word_handling(self): + """Test handling of words longer than wrap width""" + decorator = TextWrapDecorator(width=10) + result = decorator.decorate("supercalifragilisticexpialidocious") + assert max(len(line) for line in result.split("\n")) >= 10 + + def test_empty_text(self): + """Test wrapping empty text""" + decorator = TextWrapDecorator(width=10) + result = decorator.decorate("") + assert result == "" + + def test_whitespace_handling(self): + """Test handling of various whitespace scenarios""" + decorator = TextWrapDecorator(width=10) + text = " Multiple spaces test " + result = decorator.decorate(text) + assert " " not in result # Should not contain multiple consecutive spaces + + +class TestWriteOnceDecorator: + @pytest.fixture + def temp_state_dir(self, tmp_path): + """Fixture to provide a temporary state directory""" + state_dir = tmp_path / "state" + state_dir.mkdir() + return state_dir + + @pytest.fixture + def decorator(self, temp_state_dir, monkeypatch): + """Fixture to provide a WriteOnceDecorator with mocked state directory""" + monkeypatch.setattr( + "command_line_assistant.rendering.decorators.text.get_xdg_state_path", + lambda: temp_state_dir, + ) + return WriteOnceDecorator("test_state") + + def test_first_write(self, decorator): + """Test first write with decorator""" + text = "First time text" + result = decorator.decorate(text) + assert result == text + assert decorator._state_file.exists() + + def test_subsequent_write(self, decorator): + """Test subsequent writes with decorator""" + first_result = decorator.decorate("First write") + second_result = decorator.decorate("Second write") + assert first_result == "First write" + assert not second_result + + def test_different_state_files(self, temp_state_dir, monkeypatch): + """Test different state files for different instances""" + monkeypatch.setattr( + "command_line_assistant.rendering.decorators.text.get_xdg_state_path", + lambda: temp_state_dir, + ) + + decorator1 = WriteOnceDecorator("state1") + decorator2 = WriteOnceDecorator("state2") + + result1 = decorator1.decorate("Text 1") + result2 = decorator2.decorate("Text 2") + + assert result1 == "Text 1" + assert result2 == "Text 2" + assert decorator1._state_file != decorator2._state_file + + def test_state_directory_creation(self, temp_state_dir, monkeypatch): + """Test state directory creation if it doesn't exist""" + non_existent_dir = temp_state_dir / "subdir" + monkeypatch.setattr( + "command_line_assistant.rendering.decorators.text.get_xdg_state_path", + lambda: non_existent_dir, + ) + + decorator = WriteOnceDecorator("test_state") + decorator.decorate("Test text") + + assert non_existent_dir.exists() + assert decorator._state_file.exists() + + def test_empty_text(self, decorator): + """Test decorator with empty text""" + result = decorator.decorate("") + assert result == "" + assert decorator._state_file.exists() + + def test_multiple_decorators_same_file(self, temp_state_dir, monkeypatch): + """Test multiple decorators using the same state file""" + monkeypatch.setattr( + "command_line_assistant.rendering.decorators.text.get_xdg_state_path", + lambda: temp_state_dir, + ) + + decorator1 = WriteOnceDecorator("same_state") + decorator2 = WriteOnceDecorator("same_state") + result1 = decorator1.decorate("First text") + result2 = decorator2.decorate("Second text") -def test_emoji_decorator(): - decorator = EmojiDecorator("⭐") - text = "Test text" - assert decorator.decorate(text) == "⭐ Test text" + assert result1 == "First text" + assert not result2 + def test_state_file_permissions(self, decorator): + """Test state file permissions""" + decorator.decorate("Test text") + assert decorator._state_file.exists() + assert oct(decorator._state_file.stat().st_mode)[-3:] == "644" -def test_emoji_decorator_with_hex(): - decorator = EmojiDecorator(0x2728) # Sparkles emoji - text = "Test text" - assert decorator.decorate(text) == "✨ Test text" + @pytest.mark.parametrize( + "filename", + [ + "test-state", + "test_state", + "test.state", + "TEST_STATE", + "123_state", + "state_123", + ], + ) + def test_various_filenames(self, temp_state_dir, monkeypatch, filename): + """Test decorator with various valid state filenames""" + monkeypatch.setattr( + "command_line_assistant.rendering.decorators.text.get_xdg_state_path", + lambda: temp_state_dir, + ) + decorator = WriteOnceDecorator(filename) + result = decorator.decorate("Test text") -def test_emoji_decorator_with_unicode(): - decorator = EmojiDecorator("U+1F4A5") # Collision emoji - text = "Test text" - assert decorator.decorate(text) == "💥 Test text" + assert result == "Test text" + assert decorator._state_file.exists() diff --git a/tests/rendering/renders/__init__.py b/tests/rendering/renders/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/rendering/renders/test_spinner.py b/tests/rendering/renders/test_spinner.py new file mode 100644 index 0000000..4ae75fc --- /dev/null +++ b/tests/rendering/renders/test_spinner.py @@ -0,0 +1,183 @@ +import threading +import time +from unittest.mock import MagicMock + +import pytest + +from command_line_assistant.rendering.base import OutputStreamWritter +from command_line_assistant.rendering.decorators.colors import ColorDecorator +from command_line_assistant.rendering.decorators.text import ( + EmojiDecorator, + TextWrapDecorator, +) +from command_line_assistant.rendering.renders.spinner import Frames, SpinnerRenderer + + +class MockStream(OutputStreamWritter): + """Mock stream class for testing""" + + def __init__(self): + self.written = [] + super().__init__(stream=MagicMock()) + + def write(self, text: str) -> None: + self.written.append(text) + + def flush(self) -> None: + pass + + +@pytest.fixture +def mock_stream(): + return MockStream() + + +@pytest.fixture +def spinner(mock_stream): + return SpinnerRenderer("Loading...", stream=mock_stream) + + +def test_spinner_initialization(): + """Test spinner initialization with default values""" + spinner = SpinnerRenderer("Loading...") + assert spinner._message == "Loading..." + assert spinner._delay == 0.1 + assert spinner._clear_message is False + assert spinner._done.is_set() is False + assert spinner._spinner_thread is None + + +def test_spinner_custom_initialization(): + """Test spinner initialization with custom values""" + spinner = SpinnerRenderer( + message="Custom loading", frames=Frames.dash, delay=0.2, clear_message=True + ) + assert spinner._message == "Custom loading" + assert spinner._delay == 0.2 + assert spinner._clear_message is True + + +@pytest.mark.parametrize( + "decorator,expected_pattern", + [ + (ColorDecorator(foreground="green"), "\x1b[32mTest message\x1b[0m"), + (EmojiDecorator("🚀"), "🚀 Test message"), + (TextWrapDecorator(width=20), "Test message"), + ], +) +def test_spinner_decorator_application(spinner, decorator, expected_pattern): + """Test that different decorators are properly applied""" + spinner.update(decorator) + spinner.render("Test message") + + assert any(expected_pattern in text for text in spinner._stream.written) + + +def test_multiple_decorators(spinner): + """Test applying multiple decorators""" + spinner.update(ColorDecorator(foreground="blue")) + spinner.update(EmojiDecorator("⭐")) + spinner.render("Test message") + + written = spinner._stream.written[-1] + assert "⭐" in written + assert "\x1b[34m" in written # Blue color code + assert "Test message" in written + + +def test_spinner_start_stop(spinner): + """Test starting and stopping the spinner""" + spinner.start() + assert isinstance(spinner._spinner_thread, threading.Thread) + assert spinner._spinner_thread.is_alive() + + spinner.stop() + assert not spinner._spinner_thread.is_alive() + assert spinner._done.is_set() + + +def test_spinner_context_manager(spinner): + """Test spinner as context manager""" + with spinner: + assert spinner._spinner_thread.is_alive() + time.sleep(0.2) # Allow some frames to be written + + assert not spinner._spinner_thread.is_alive() + assert spinner._done.is_set() + assert len(spinner._stream.written) > 0 + + +def test_spinner_with_colored_text(mock_stream): + """Test spinner with colored text""" + spinner = SpinnerRenderer("Loading...", stream=mock_stream) + spinner.update(ColorDecorator(foreground="cyan")) + + with spinner: + time.sleep(0.2) + + # Check that color codes are present in output + assert any("\x1b[36m" in text for text in mock_stream.written) # Cyan color code + assert any("\x1b[0m" in text for text in mock_stream.written) # Reset code + + +def test_spinner_with_emoji_and_color(mock_stream): + """Test spinner with both emoji and color decorators""" + spinner = SpinnerRenderer("Processing...", stream=mock_stream) + spinner.update(ColorDecorator(foreground="yellow")) + spinner.update(EmojiDecorator("⚡")) + + with spinner: + time.sleep(0.2) + + written_text = mock_stream.written + assert any("⚡" in text for text in written_text) + assert any("\x1b[33m" in text for text in written_text) # Yellow color code + + +def test_spinner_with_text_wrap(mock_stream): + """Test spinner with text wrapping""" + long_message = "This is a very long message that should be wrapped" + spinner = SpinnerRenderer(long_message, stream=mock_stream) + spinner.update(TextWrapDecorator(width=20)) + + with spinner: + time.sleep(0.2) + + # Verify that the text was wrapped + written_text = mock_stream.written + assert any(len(line.strip()) <= 20 for line in written_text) + + +@pytest.mark.parametrize( + "frames", + [ + Frames.default, + Frames.dash, + Frames.circular, + Frames.dots, + Frames.arrows, + Frames.moving, + ], +) +def test_different_frame_styles(mock_stream, frames): + """Test that all frame styles work correctly""" + spinner = SpinnerRenderer( + "Testing frames", stream=mock_stream, frames=frames, delay=0.1 + ) + + with spinner: + time.sleep(0.2) # Allow some frames to be written + + assert len(mock_stream.written) > 0 + + +def test_spinner_clear_message(mock_stream): + """Test that clear_message properly clears the spinner message""" + spinner = SpinnerRenderer("Clear me", stream=mock_stream, clear_message=True) + + with spinner: + time.sleep(0.2) + + # Verify any written message contains clear spaces + written_text = mock_stream.written + assert any(" " * (len("Clear me") + 2) + "\r" in text for text in written_text) diff --git a/tests/rendering/test_render.py b/tests/rendering/renders/test_text.py similarity index 95% rename from tests/rendering/test_render.py rename to tests/rendering/renders/test_text.py index 236d9b7..af13c20 100644 --- a/tests/rendering/test_render.py +++ b/tests/rendering/renders/test_text.py @@ -1,7 +1,7 @@ from command_line_assistant.rendering.decorators.colors import ColorDecorator from command_line_assistant.rendering.decorators.style import StyleDecorator from command_line_assistant.rendering.decorators.text import TextWrapDecorator -from command_line_assistant.rendering.render import TextRenderer +from command_line_assistant.rendering.renders.text import TextRenderer def test_text_renderer_multiple_decorators(): @@ -61,7 +61,7 @@ def test_text_renderer_render_empty_text(capsys): captured = capsys.readouterr() # TODO(r0x0d): right now, we are still applying the color and everything else. # Maybe in the future we want to get rid of the formatting if we don't have text... - assert captured.out.strip() == "\x1b[32m\x1b[0m" + assert captured.out.strip() == "" def test_text_renderer_render_multiline(capsys): diff --git a/tests/rendering/test_spinner.py b/tests/rendering/test_spinner.py deleted file mode 100644 index f5fd004..0000000 --- a/tests/rendering/test_spinner.py +++ /dev/null @@ -1,106 +0,0 @@ -import sys -import threading -import time -from contextlib import contextmanager -from io import StringIO - -import pytest - -from command_line_assistant.rendering.spinner import Frames, ascii_spinner - - -@contextmanager -def capture_stdout(): - """Helper context manager to capture stdout for testing""" - stdout = StringIO() - old_stdout = sys.stdout - sys.stdout = stdout - try: - yield stdout - finally: - sys.stdout = old_stdout - - -def test_frames_default_values(): - """Test that Frames class has all expected default values""" - frames = Frames() - - # Test that all frame sequences exist - assert hasattr(frames, "default") - assert hasattr(frames, "braille") - assert hasattr(frames, "circular") - assert hasattr(frames, "dots") - assert hasattr(frames, "arrows") - assert hasattr(frames, "moving") - - -def test_frames_iteration(): - """Test that frame sequences can be iterated""" - frames = Frames() - - # Test default frames iteration - default_iterator = frames.default - first_frame = next(default_iterator) - assert first_frame in ["-", "\\", "|", "/"] - - # Test that it cycles - for _ in range(5): # More than number of frames - frame = next(default_iterator) - assert frame in ["-", "\\", "|", "/"] - - -def test_ascii_spinner_basic(): - """Test basic spinner functionality""" - with capture_stdout() as output: - with ascii_spinner("Loading", delay=0.1): - time.sleep(0.2) # Allow spinner to make at least one iteration - - captured = output.getvalue() - assert "Loading" in captured - assert "\r" in captured # Should use carriage return - - -def test_ascii_spinner_clear_message(): - """Test spinner with clear_message option""" - with capture_stdout() as output: - with ascii_spinner("Loading", clear_message=True, delay=0.1): - time.sleep(0.2) - - final_output = output.getvalue().split("\r")[-1] - assert len(final_output.strip()) == 0 # Should end with empty line - - -def test_ascii_spinner_custom_frames(): - """Test spinner with custom frames""" - custom_frames = iter(["A", "B", "C"]) - with capture_stdout() as output: - with ascii_spinner("Loading", frames=custom_frames, delay=0.1): - time.sleep(0.2) - - captured = output.getvalue() - assert any(frame in captured for frame in ["A", "B", "C"]) - - -def test_spinner_thread_cleanup(): - """Test that spinner properly cleans up its thread""" - initial_threads = threading.active_count() - - with ascii_spinner("Loading", delay=0.1): - time.sleep(0.2) - during_threads = threading.active_count() - assert during_threads > initial_threads # Should have one more thread - - time.sleep(0.2) # Give time for cleanup - after_threads = threading.active_count() - assert after_threads == initial_threads # Thread should be cleaned up - - -@pytest.mark.parametrize("delay", [0.1, 0.2, 0.5]) -def test_spinner_different_delays(delay): - """Test spinner with different delay values""" - start_time = time.time() - with ascii_spinner("Loading", delay=delay): - time.sleep(delay * 2) # Wait for at least 2 iterations - duration = time.time() - start_time - - assert duration >= delay * 2 diff --git a/tests/rendering/test_stream.py b/tests/rendering/test_stream.py new file mode 100644 index 0000000..097da3c --- /dev/null +++ b/tests/rendering/test_stream.py @@ -0,0 +1,226 @@ +import sys + +import pytest + +from command_line_assistant.rendering.stream import StderrStream, StdoutStream + + +class TestStdoutStream: + def test_initialization_default(self): + """Test StdoutStream initialization with default end character""" + stream = StdoutStream() + assert stream._stream == sys.stdout + assert stream._end == "\n" + + def test_initialization_custom_end(self): + """Test StdoutStream initialization with custom end character""" + stream = StdoutStream(end=">>>") + assert stream._stream == sys.stdout + assert stream._end == ">>>" + + def test_write(self, capsys): + """Test writing to stdout""" + stream = StdoutStream() + test_message = "Hello, World!" + stream.write(test_message) + + captured = capsys.readouterr() + assert captured.out == f"{test_message}\n" + + def test_write_custom_end(self, capsys): + """Test writing to stdout with custom end character""" + stream = StdoutStream(end=">>>") + test_message = "Hello, World!" + stream.write(test_message) + + captured = capsys.readouterr() + assert captured.out == f"{test_message}>>>" + + def test_write_empty_string(self, capsys): + """Test writing empty string to stdout""" + stream = StdoutStream() + stream.write("") + + captured = capsys.readouterr() + assert captured.out == "\n" + + def test_write_multiple_lines(self, capsys): + """Test writing multiple lines to stdout""" + stream = StdoutStream() + messages = ["Line 1", "Line 2", "Line 3"] + + for message in messages: + stream.write(message) + + captured = capsys.readouterr() + expected = "".join(f"{msg}\n" for msg in messages) + assert captured.out == expected + + def test_flush(self, capsys): + """Test flushing stdout""" + stream = StdoutStream() + stream.write("Test message") + stream.flush() + + captured = capsys.readouterr() + assert captured.out == "Test message\n" + + def test_execute(self, capsys): + """Test execute method with stdout""" + stream = StdoutStream() + test_message = "Execute test" + stream.execute(test_message) + + captured = capsys.readouterr() + assert captured.out == f"{test_message}\n" + + def test_execute_empty_string(self, capsys): + """Test execute method with empty string""" + stream = StdoutStream() + stream.execute("") + + captured = capsys.readouterr() + assert captured.out == "" + + +class TestStderrStream: + def test_initialization_default(self): + """Test StderrStream initialization with default end character""" + stream = StderrStream() + assert stream._stream == sys.stderr + assert stream._end == "\n" + + def test_initialization_custom_end(self): + """Test StderrStream initialization with custom end character""" + stream = StderrStream(end="!!!") + assert stream._stream == sys.stderr + assert stream._end == "!!!" + + def test_write(self, capsys): + """Test writing to stderr""" + stream = StderrStream() + test_message = "Error message" + stream.write(test_message) + + captured = capsys.readouterr() + assert captured.err == f"{test_message}\n" + + def test_write_custom_end(self, capsys): + """Test writing to stderr with custom end character""" + stream = StderrStream(end="!!!") + test_message = "Error message" + stream.write(test_message) + + captured = capsys.readouterr() + assert captured.err == f"{test_message}!!!" + + def test_write_empty_string(self, capsys): + """Test writing empty string to stderr""" + stream = StderrStream() + stream.write("") + + captured = capsys.readouterr() + assert captured.err == "\n" + + def test_write_multiple_lines(self, capsys): + """Test writing multiple lines to stderr""" + stream = StderrStream() + messages = ["Error 1", "Error 2", "Error 3"] + + for message in messages: + stream.write(message) + + captured = capsys.readouterr() + expected = "".join(f"{msg}\n" for msg in messages) + assert captured.err == expected + + def test_flush(self, capsys): + """Test flushing stderr""" + stream = StderrStream() + stream.write("Error message") + stream.flush() + + captured = capsys.readouterr() + assert captured.err == "Error message\n" + + def test_execute(self, capsys): + """Test execute method with stderr""" + stream = StderrStream() + test_message = "Execute error test" + stream.execute(test_message) + + captured = capsys.readouterr() + assert captured.err == f"{test_message}\n" + + def test_execute_empty_string(self, capsys): + """Test execute method with empty string""" + stream = StderrStream() + stream.execute("") + + captured = capsys.readouterr() + assert captured.err == "" + + +@pytest.mark.parametrize( + "StreamClass,expected_stream", + [ + (StdoutStream, sys.stdout), + (StderrStream, sys.stderr), + ], +) +def test_stream_initialization(StreamClass, expected_stream): + """Test initialization of both stream classes""" + stream = StreamClass() + assert stream._stream == expected_stream + + +@pytest.mark.parametrize("StreamClass", [StdoutStream, StderrStream]) +def test_stream_unicode(StreamClass, capsys): + """Test handling of Unicode characters in both streams""" + stream = StreamClass() + + unicode_messages = [ + "Hello, 世界!", # Japanese + "¡Hola, món!", # Catalan with Spanish punctuation + "🌟 Stars ✨", # Emojis + "θ, π, φ", # Greek letters + ] + + for message in unicode_messages: + stream.write(message) + captured = capsys.readouterr() + output = captured.out if isinstance(stream, StdoutStream) else captured.err + assert message in output + + +@pytest.mark.parametrize("StreamClass", [StdoutStream, StderrStream]) +def test_stream_long_messages(StreamClass, capsys): + """Test handling of long messages in both streams""" + stream = StreamClass() + + long_message = "x" * 10000 # 10K characters + stream.write(long_message) + + captured = capsys.readouterr() + output = captured.out if isinstance(stream, StdoutStream) else captured.err + assert output.strip() == long_message + + +@pytest.mark.parametrize("StreamClass", [StdoutStream, StderrStream]) +def test_stream_special_characters(StreamClass, capsys): + """Test handling of special characters in both streams""" + stream = StreamClass() + + special_messages = [ + "Tab\there", + "New\nline", + "Return\rchar", + "Back\\slash", + '"Quotes"', + ] + + for message in special_messages: + stream.write(message) + captured = capsys.readouterr() + output = captured.out if isinstance(stream, StdoutStream) else captured.err + assert message in output diff --git a/tests/test_initialize.py b/tests/test_initialize.py index 6b51ecf..3500e19 100644 --- a/tests/test_initialize.py +++ b/tests/test_initialize.py @@ -7,7 +7,7 @@ class MockCommand(BaseCLICommand): - def run(self): + def run(self): # type: ignore return True diff --git a/tests/utils/test_environment.py b/tests/utils/test_environment.py index 044d413..9c9b965 100644 --- a/tests/utils/test_environment.py +++ b/tests/utils/test_environment.py @@ -1,3 +1,5 @@ +from pathlib import Path + import pytest from command_line_assistant.utils import environment @@ -6,13 +8,27 @@ @pytest.mark.parametrize( ("xdg_path_env", "expected"), ( - ("", "/etc/xdg"), - ("/etc/xdg", "/etc/xdg"), - ("/etc/xdg:some/other/path", "/etc/xdg"), - ("no-path-xdg:what-iam-doing", "/etc/xdg"), - ("/my-special-one-path", "/my-special-one-path"), + ("", Path("/etc/xdg")), + ("/etc/xdg", Path("/etc/xdg")), + ("/etc/xdg:some/other/path", Path("/etc/xdg")), + ("no-path-xdg:what-iam-doing", Path("/etc/xdg")), + ("/my-special-one-path", Path("/my-special-one-path")), ), ) -def test_get_xdg_path(xdg_path_env, expected, monkeypatch): +def test_get_xdg_config_path(xdg_path_env, expected, monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("XDG_CONFIG_DIRS", xdg_path_env) - assert environment.get_xdg_path() == expected + assert environment.get_xdg_config_path() == expected + + +@pytest.mark.parametrize( + ("xdg_path_env", "expected"), + ( + ("", Path("some/dir")), + ("/etc/xdg", Path("/etc/xdg")), + ("/my-special-one-path", Path("/my-special-one-path")), + ), +) +def test_get_xdg_state_path(xdg_path_env, expected, monkeypatch: pytest.MonkeyPatch): + monkeypatch.setattr(environment, "WANTED_XDG_STATE_PATH", Path("some/dir")) + monkeypatch.setenv("XDG_STATE_HOME", xdg_path_env) + assert environment.get_xdg_state_path() == expected