Skip to content

Commit

Permalink
Add --input for query subcommand
Browse files Browse the repository at this point in the history
This is an alternative to include a file in your query alongside the
positional or stdin query.
  • Loading branch information
r0x0d committed Jan 9, 2025
1 parent c8e6eb3 commit 4b36142
Show file tree
Hide file tree
Showing 22 changed files with 316 additions and 87 deletions.
112 changes: 94 additions & 18 deletions command_line_assistant/commands/query.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Module to handle the query command."""

import argparse
import getpass
from argparse import Namespace
from io import TextIOWrapper
from typing import Optional

from command_line_assistant.dbus.constants import QUERY_IDENTIFIER
Expand All @@ -19,34 +21,43 @@
from command_line_assistant.rendering.renders.spinner import SpinnerRenderer
from command_line_assistant.rendering.renders.text import TextRenderer
from command_line_assistant.utils.cli import BaseCLICommand, SubParsersAction
from command_line_assistant.utils.files import is_content_in_binary_format
from command_line_assistant.utils.renderers import (
create_error_renderer,
create_spinner_renderer,
create_text_renderer,
create_warning_renderer,
)

#: Legal notice that we need to output once per user
LEGAL_NOTICE = (
"This feature uses AI technology. Do not include personal information or "
"other sensitive information in your input. Interactions may be used to "
"improve Red Hat's products or services."
)
#: Legal notice that we need to output once per user
ALWAYS_LEGAL_MESSAGE = "Always review AI generated content prior to use."
#: Always good to have legal message.
ALWAYS_LEGAL_MESSAGE = "Always review AI generated content prior to use."


class QueryCommand(BaseCLICommand):
"""Class that represents the query command."""

def __init__(self, query_string: str, stdin: Optional[str]) -> None:
def __init__(
self,
query_string: Optional[str] = None,
stdin: Optional[str] = None,
input: Optional[TextIOWrapper] = None,
) -> None:
"""Constructor of the class.
Args:
query_string (str): The query provided by the user.
stdin (Optional[str]): The user redirect input from stdin
query_string (Optional[str], optional): The query provided by the user.
stdin (Optional[str], optional): The user redirect input from stdin
input (Optional[TextIOWrapper], optional): The file input from the user
"""
self._query = query_string
self._stdin = stdin
self._query = query_string.strip() if query_string else None
self._stdin = stdin.strip() if stdin else None
self._input = input

self._spinner_renderer: SpinnerRenderer = create_spinner_renderer(
message="Requesting knowledge from AI",
Expand All @@ -61,13 +72,70 @@ def __init__(self, query_string: str, stdin: Optional[str]) -> None:
WriteOnceDecorator(state_filename="legal"),
]
)
self._warning_renderer: TextRenderer = create_text_renderer(
self._notice_renderer: TextRenderer = create_text_renderer(
decorators=[ColorDecorator(foreground="lightyellow")]
)
self._error_renderer: TextRenderer = create_error_renderer()

self._warning_renderer: TextRenderer = create_warning_renderer()
super().__init__()

def _get_input_source(self) -> str:
"""Determine and return the appropriate input source based on combination rules.
Rules:
1. Positional query only -> use positional query
2. Stdin query only -> use stdin query
3. File query only -> use file query
4. Stdin + positional query -> combine as "{positional_query} {stdin}"
5. Stdin + file query -> combine as "{stdin} {file_query}"
6. Positional + file query -> combine as "{positional_query} {file_query}"
7. All three sources -> use only positional and file as "{positional_query} {file_query}"
Raises:
ValueError: If no input source is provided
Returns:
str: The query string from the selected input source(s)
"""
file_content = None
if self._input:
file_content = self._input.read().strip()
if is_content_in_binary_format(file_content):
raise ValueError("File appears to be binary")

file_content = file_content.strip()

# Rule 7: All three present - positional and file take precedence
if all([self._query, self._stdin, file_content]):
self._warning_renderer.render(
"Using positional query and file input. Stdin will be ignored."
)
return f"{self._query} {file_content}"

# Rule 6: Positional + file
if self._query and file_content:
return f"{self._query} {file_content}"

# Rule 5: Stdin + file
if self._stdin and file_content:
return f"{self._stdin} {file_content}"

# Rule 4: Stdin + positional
if self._stdin and self._query:
return f"{self._query} {self._stdin}"

# Rules 1-3: Single source - return first non-empty source
source = next(
(src for src in [self._query, self._stdin, file_content] if src),
None,
)
if source:
return source

raise ValueError(
"No input provided. Please provide input via file, stdin, or direct query."
)

def run(self) -> int:
"""Main entrypoint for the command to run.
Expand All @@ -76,12 +144,11 @@ def run(self) -> int:
"""
proxy = QUERY_IDENTIFIER.get_proxy()

query = self._query
if self._stdin:
# If query is provided, the message becomes "{query} {stdin}",
# otherwise, to avoid submitting `None` as part of the query, let's
# default to submit only the stidn.
query = f"{query} {self._stdin}" if query else self._stdin
try:
query = self._get_input_source()
except ValueError as e:
self._error_renderer.render(str(e))
return 1

input_query = Message()
input_query.message = query
Expand All @@ -103,7 +170,7 @@ def run(self) -> int:

self._legal_renderer.render(LEGAL_NOTICE)
self._text_renderer.render(output)
self._warning_renderer.render(ALWAYS_LEGAL_MESSAGE)
self._notice_renderer.render(ALWAYS_LEGAL_MESSAGE)
return 0


Expand All @@ -122,6 +189,13 @@ def register_subcommand(parser: SubParsersAction) -> None:
query_parser.add_argument(
"query_string", nargs="?", help="Query string to be processed."
)
query_parser.add_argument(
"-i",
"--input",
nargs="?",
type=argparse.FileType("r"),
help="Read file from user system.",
)

query_parser.set_defaults(func=_command_factory)

Expand All @@ -135,8 +209,10 @@ def _command_factory(args: Namespace) -> QueryCommand:
Returns:
QueryCommand: Return an instance of class
"""
options = {"query_string": args.query_string, "stdin": None, "input": args.input}

# We may not always have the stdin argument in the namespace.
if "stdin" in args:
return QueryCommand(args.query_string, args.stdin)
options["stdin"] = args.stdin

return QueryCommand(args.query_string, None)
return QueryCommand(**options)
2 changes: 1 addition & 1 deletion command_line_assistant/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@
import tomli as tomllib # pyright: ignore[reportMissingImports]


#: Define the config file path.
CONFIG_FILE_DEFINITION: tuple[str, str] = (
"command-line-assistant",
"config.toml",
)
#: Define the config file path.

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion command_line_assistant/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Module to track constants for the project."""

VERSION = "0.1.0"
#: Define the version for the program
VERSION = "0.1.0"
2 changes: 1 addition & 1 deletion command_line_assistant/daemon/http/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from command_line_assistant.daemon.http.adapters import RetryAdapter, SSLAdapter
from command_line_assistant.dbus.exceptions import RequestFailedError

USER_AGENT = f"clad/{VERSION}"
#: Define the custom user agent for clad
USER_AGENT = f"clad/{VERSION}"

logger = logging.getLogger(__name__)

Expand Down
14 changes: 7 additions & 7 deletions command_line_assistant/dbus/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,27 @@
from dasbus.error import ErrorMapper
from dasbus.identifier import DBusServiceIdentifier

ERROR_MAPPER: ErrorMapper = ErrorMapper()
#: Instance of error mapping to allow d-bus to serialize exceptions.
ERROR_MAPPER: ErrorMapper = ErrorMapper()

SYSTEM_BUS: SystemMessageBus = SystemMessageBus(error_mapper=ERROR_MAPPER)
#: System bus with error mapping to serialize exceptions.
SYSTEM_BUS: SystemMessageBus = SystemMessageBus(error_mapper=ERROR_MAPPER)

SERVICE_NAMESPACE = ("com", "redhat", "lightspeed")
#: The base-level service namespace
SERVICE_NAMESPACE = ("com", "redhat", "lightspeed")

QUERY_NAMESAPCE = (*SERVICE_NAMESPACE, "query")
#: The query namespace
QUERY_NAMESAPCE = (*SERVICE_NAMESPACE, "query")

HISTORY_NAMESPACE = (*SERVICE_NAMESPACE, "history")
#: The history namespace
HISTORY_NAMESPACE = (*SERVICE_NAMESPACE, "history")

#: The query identifier that represents a dbus service
QUERY_IDENTIFIER = DBusServiceIdentifier(
namespace=QUERY_NAMESAPCE, message_bus=SYSTEM_BUS
)
#: The query identifier that represents a dbus service

#: The history identifier that represents a dbus service
HISTORY_IDENTIFIER = DBusServiceIdentifier(
namespace=HISTORY_NAMESPACE, message_bus=SYSTEM_BUS
)
#: The history identifier that represents a dbus service
2 changes: 1 addition & 1 deletion command_line_assistant/dbus/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
SERVICE_NAMESPACE,
)

dbus_error = get_error_decorator(ERROR_MAPPER)
#: Special decorator for mapping exceptions to dbus style exceptions
dbus_error = get_error_decorator(ERROR_MAPPER)


@dbus_error("NotAuthorizedUser", namespace=SERVICE_NAMESPACE)
Expand Down
2 changes: 1 addition & 1 deletion command_line_assistant/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from command_line_assistant.config import Config

#: Define the dictionary configuration for the logger instance
LOGGING_CONFIG_DICTIONARY = {
"version": 1,
"disable_existing_loggers": False,
Expand Down Expand Up @@ -48,7 +49,6 @@
},
},
}
#: Define the dictionary configuration for the logger instance


class AuditFormatter(logging.Formatter):
Expand Down
43 changes: 43 additions & 0 deletions command_line_assistant/utils/files.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Utilitary module to handle file operations"""

from typing import Union

#: Common binary signatures
BINARY_SIGNATURES = [
b"\x7fELF", # ELF files
b"%PDF", # PDF files
b"PK\x03\x04", # ZIP files
]


def is_content_in_binary_format(content: Union[str, bytes]) -> bool:
"""Check if a given content is in binary format.
Args:
content (str): The content to be checked for binary presence.
Raises:
ValueError: If the content is binary or contains invalid text encoding.
Returns:
bool: True if the content is binary, False otherwise.
"""
try:
# Try to decode as utf-8
if isinstance(content, bytes):
content = content.decode("utf-8")

# Check for null bytes which often indicate binary data
if "\0" in content:
return True

# Additional check for common binary file signatures
content_bytes = content.encode("utf-8") if isinstance(content, str) else content
if any(content_bytes.startswith(sig) for sig in BINARY_SIGNATURES):
return True
except UnicodeDecodeError as e:
raise ValueError(
"File appears to be binary or contains invalid text encoding"
) from e

return False
21 changes: 20 additions & 1 deletion command_line_assistant/utils/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)
from command_line_assistant.rendering.renders.spinner import SpinnerRenderer
from command_line_assistant.rendering.renders.text import TextRenderer
from command_line_assistant.rendering.stream import StdoutStream
from command_line_assistant.rendering.stream import StderrStream, StdoutStream


def create_error_renderer() -> TextRenderer:
Expand All @@ -32,6 +32,25 @@ def create_error_renderer() -> TextRenderer:
return renderer


def create_warning_renderer() -> TextRenderer:
"""Create a standarized instance of text rendering for error output
Returns:
TextRenderer: Instance of a TextRenderer with correct decorators for
error output.
"""
renderer = TextRenderer(StderrStream())
renderer.update(
[
EmojiDecorator(emoji="0x1f914"),
ColorDecorator(foreground="yellow"),
TextWrapDecorator(),
]
)

return renderer


def create_spinner_renderer(
message: str, decorators: list[BaseDecorator]
) -> SpinnerRenderer:
Expand Down
Loading

0 comments on commit 4b36142

Please sign in to comment.