Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RSPEED-412] Add filter history option #101

Merged
merged 1 commit into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions command_line_assistant/commands/history.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Module to handle the history command."""

from argparse import Namespace
from typing import Optional

from command_line_assistant.dbus.constants import HISTORY_IDENTIFIER
from command_line_assistant.dbus.exceptions import (
Expand All @@ -25,7 +26,9 @@
class HistoryCommand(BaseCLICommand):
"""Class that represents the history command."""

def __init__(self, clear: bool, first: bool, last: bool) -> None:
def __init__(
self, clear: bool, first: bool, last: bool, filter: Optional[str] = None
) -> None:
"""Constructor of the class.

Note:
Expand All @@ -36,10 +39,12 @@
clear (bool): If the history should be cleared
first (bool): Retrieve only the first conversation from history
last (bool): Retrieve only last conversation from history
filter (Optional[str], optional): Keyword to filter in the user history
"""
self._clear = clear
self._first = first
self._last = last
self._filter = filter

self._proxy = HISTORY_IDENTIFIER.get_proxy()

Expand Down Expand Up @@ -67,14 +72,13 @@
try:
if self._clear:
self._clear_history()

if self._first:
elif self._first:
self._retrieve_first_conversation()

if self._last:
elif self._last:
self._retrieve_last_conversation()

if not self._last and not self._clear and not self._first:
elif self._filter:
self._retrieve_conversation_filtered(self._filter)

Check warning on line 80 in command_line_assistant/commands/history.py

View check run for this annotation

Codecov / codecov/patch

command_line_assistant/commands/history.py#L80

Added line #L80 was not covered by tests
else:
self._retrieve_all_conversations()

return 0
Expand All @@ -100,7 +104,22 @@
# Display the conversation
self._show_history(history.entries)

def _retrieve_last_conversation(self):
def _retrieve_conversation_filtered(self, filter: str) -> None:
"""Retrieve the user conversation with keyword filtering.

Args:
filter (str): Keyword to filter in the user history
"""
self._text_renderer.render("Filtering conversation history.")
response = self._proxy.GetFilteredConversation(filter)

# Handle and display the response
history = HistoryEntry.from_structure(response)

# Display the conversation
self._show_history(history.entries)

def _retrieve_last_conversation(self) -> None:
"""Retrieve the last conversation in the conversation cache."""
self._text_renderer.render("Getting last conversation from history.")
response = self._proxy.GetLastConversation()
Expand Down Expand Up @@ -165,6 +184,9 @@
action="store_true",
help="Get the last conversation from history.",
)
history_parser.add_argument(
"--filter", help="Search for a specific string of text in the history."
)
history_parser.set_defaults(func=_command_factory)


Expand All @@ -177,4 +199,4 @@
Returns:
HistoryCommand: Return an instance of class
"""
return HistoryCommand(args.clear, args.first, args.last)
return HistoryCommand(args.clear, args.first, args.last, args.filter)

Check warning on line 202 in command_line_assistant/commands/history.py

View check run for this annotation

Codecov / codecov/patch

command_line_assistant/commands/history.py#L202

Added line #L202 was not covered by tests
37 changes: 36 additions & 1 deletion command_line_assistant/dbus/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dasbus.server.interface import dbus_interface
from dasbus.server.property import emits_properties_changed
from dasbus.server.template import InterfaceTemplate
from dasbus.typing import Structure
from dasbus.typing import Str, Structure

from command_line_assistant.daemon.http.query import submit
from command_line_assistant.dbus.constants import HISTORY_IDENTIFIER, QUERY_IDENTIFIER
Expand All @@ -17,6 +17,7 @@
from command_line_assistant.history.plugins.local import LocalHistory

audit_logger = logging.getLogger("audit")
logger = logging.getLogger(__name__)


@dbus_interface(QUERY_IDENTIFIER.interface_name)
Expand Down Expand Up @@ -121,6 +122,40 @@ def GetLastConversation(self) -> Structure:

return HistoryEntry.to_structure(history_entry)

def GetFilteredConversation(self, filter: Str) -> Structure:
"""Get last conversation from history.

Args:
filter (str): The filter

Returns:
Structure: A single history entyr in a dbus structure format.
"""
manager = HistoryManager(self.implementation.config, LocalHistory)
history = manager.read()
history_entry = HistoryEntry()
found_entries = []

if history.history:
logger.info("Filtering the user history with keyword '%s'", filter)
# We ignore the type in the condition as pyright thinks that "Str" is not "str".
# Pyright is correct about this, but "Str" is a special type for dbus. It will be "str" in the end.
found_entries = [
entry
for entry in history.history
if (
filter in entry.interaction.query.text # type: ignore
or filter in entry.interaction.response.text # type: ignore
)
]

logger.info("Found %s entries in the history", len(found_entries))
# Normalize the entries to send over dbus
_ = [
history_entry.set_from_dict(entry.to_dict()) for entry in set(found_entries)
]
return HistoryEntry.to_structure(history_entry)

def ClearHistory(self) -> None:
"""Clear the user history."""
manager = HistoryManager(self.implementation.config, LocalHistory)
Expand Down
3 changes: 3 additions & 0 deletions command_line_assistant/history/plugins/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def read(self) -> History:

filepath = self._config.history.file

logger.info("Reading history at %s", filepath)
try:
data = filepath.read_text()
return History.from_json(data)
Expand Down Expand Up @@ -63,6 +64,7 @@ def write(self, current_history: History, query: str, response: str) -> None:

filepath = self._config.history.file
final_history = self._add_new_entry(current_history, query, response)
logger.info("Writting user history at %s", filepath)
try:
filepath.write_text(final_history.to_json())
except json.JSONDecodeError as e:
Expand All @@ -85,6 +87,7 @@ def clear(self) -> None:
# Write empty history
current_history = History()
filepath = self._config.history.file
logger.info("Clearing history at %s", filepath)
try:
filepath.write_text(current_history.to_json())
logger.info("History cleared successfully")
Expand Down
12 changes: 6 additions & 6 deletions command_line_assistant/history/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from command_line_assistant.constants import VERSION


@dataclass
@dataclass(frozen=True)
class QueryData:
"""Schema to represent a query emited by the user.

Expand All @@ -23,7 +23,7 @@ class QueryData:
role: str = "user"


@dataclass
@dataclass(frozen=True)
class ResponseData:
"""Schema to represent the LLM response.

Expand All @@ -38,7 +38,7 @@ class ResponseData:
role: str = "assistant"


@dataclass
@dataclass(frozen=True)
class InteractionData:
"""Schema to represent the interaction data between user and LLM.

Expand All @@ -51,7 +51,7 @@ class InteractionData:
response: ResponseData = field(default_factory=ResponseData)


@dataclass
@dataclass(frozen=True)
class OSInfo:
"""Schema to represent the system information

Expand All @@ -66,7 +66,7 @@ class OSInfo:
arch: str = platform.architecture()[0]


@dataclass
@dataclass(frozen=True)
class EntryMetadata:
"""Schema to represent the entry metadata information

Expand All @@ -79,7 +79,7 @@ class EntryMetadata:
os_info: OSInfo = field(default_factory=OSInfo)


@dataclass
@dataclass(frozen=True)
class HistoryEntry:
"""Schema to represent an entry of the history

Expand Down
2 changes: 1 addition & 1 deletion command_line_assistant/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
},
"audit_file": {
"class": "logging.FileHandler",
"filename": "/var/log/audit/command-line-assistant.log",
"filename": "/tmp/command-line-assistant.log",
"formatter": "audit",
"mode": "a",
},
Expand Down
33 changes: 33 additions & 0 deletions tests/commands/test_history.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,39 @@ def test_retrieve_all_conversations_empty(mock_proxy, capsys):
assert "No history found.\n" in captured.out


def test_retrieve_conversation_filtered_empty(mock_proxy, capsys):
"""Test retrieving first conversation when history is empty."""
empty_history = HistoryEntry()
mock_proxy.GetFilteredConversation.return_value = empty_history.to_structure(
empty_history
)

HistoryCommand(
clear=False, first=True, last=False, filter="missing"
)._retrieve_conversation_filtered(filter="missing")
captured = capsys.readouterr()
assert "No history found.\n" in captured.out


def test_retrieve_conversation_filtered_success(
mock_proxy, sample_history_entry, capsys
):
"""Test retrieving last conversation successfully."""
mock_proxy.GetFilteredConversation.return_value = sample_history_entry.to_structure(
sample_history_entry
)

HistoryCommand(
clear=False, first=False, last=True, filter="test"
)._retrieve_conversation_filtered(filter="missing")
captured = capsys.readouterr()
mock_proxy.GetFilteredConversation.assert_called_once()
assert (
"\x1b[92mQuery: test query\x1b[0m\n\x1b[94mAnswer: test response\x1b[0m\n"
in captured.out
)


def test_retrieve_first_conversation_success(mock_proxy, sample_history_entry, capsys):
"""Test retrieving first conversation successfully."""
mock_proxy.GetFirstConversation.return_value = sample_history_entry.to_structure(
Expand Down
35 changes: 35 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,38 @@ def mock_proxy():
@pytest.fixture
def mock_stream():
return MockStream()


@pytest.fixture
def sample_history_data():
"""Create sample history data for testing."""
return {
"history": [
{
"id": "test-id",
"timestamp": "2024-01-01T00:00:00Z",
"interaction": {
"query": {"text": "test query", "role": "user"},
"response": {
"text": "test response",
"tokens": 2,
"role": "assistant",
},
},
"metadata": {
"session_id": "test-session",
"os_info": {
"distribution": "RHEL",
"version": "test",
"arch": "x86_64",
},
},
}
],
"metadata": {
"last_updated": "2024-01-01T00:00:00Z",
"version": "0.1.0",
"entry_count": 1,
"size_bytes": 0,
},
}
Loading
Loading