From 012e6427ac65f69830d1be176a7eb0fb1d545e26 Mon Sep 17 00:00:00 2001 From: Rodolfo Olivieri Date: Wed, 15 Jan 2025 09:42:49 -0400 Subject: [PATCH] Add user_id to history tables. (#112) The user_id is a special way for us to control which record is tied to which user. This will allow us to be prepared for a multi-user environment. --- command_line_assistant/commands/history.py | 13 +- command_line_assistant/commands/query.py | 17 ++- command_line_assistant/config/schemas.py | 29 ++-- .../daemon/database/models/history.py | 2 +- command_line_assistant/daemon/session.py | 20 +-- command_line_assistant/dbus/context.py | 53 +------ command_line_assistant/dbus/interfaces.py | 108 +++++++------- command_line_assistant/dbus/server.py | 17 +-- command_line_assistant/history/base.py | 7 +- command_line_assistant/history/manager.py | 17 ++- .../history/plugins/local.py | 18 ++- command_line_assistant/logger.py | 39 +++++- command_line_assistant/utils/cli.py | 25 ++++ .../config/command-line-assistant/config.toml | 7 +- tests/commands/test_query.py | 25 ++-- tests/conftest.py | 51 +++++-- tests/daemon/database/test_manager.py | 5 +- tests/daemon/test_session.py | 26 ++-- tests/dbus/test_context.py | 44 +----- tests/dbus/test_interfaces.py | 60 +++----- tests/dbus/test_server.py | 6 - tests/history/plugins/test_local.py | 132 ++++-------------- tests/history/test_manager.py | 20 +-- tests/test_logger.py | 24 +++- 24 files changed, 329 insertions(+), 436 deletions(-) diff --git a/command_line_assistant/commands/history.py b/command_line_assistant/commands/history.py index 447cfa1..b1340c1 100644 --- a/command_line_assistant/commands/history.py +++ b/command_line_assistant/commands/history.py @@ -69,6 +69,7 @@ def run(self) -> int: Returns: int: Status code of the execution. """ + try: if self._clear: self._clear_history() @@ -89,7 +90,7 @@ def run(self) -> int: def _retrieve_all_conversations(self) -> None: """Retrieve and display all conversations from history.""" self._text_renderer.render("Getting all conversations from history.") - response = self._proxy.GetHistory() + response = self._proxy.GetHistory(self._context.effective_user_id) history = HistoryEntry.from_structure(response) # Display the conversation @@ -98,7 +99,7 @@ def _retrieve_all_conversations(self) -> None: def _retrieve_first_conversation(self) -> None: """Retrieve the first conversation in the conversation cache.""" self._text_renderer.render("Getting first conversation from history.") - response = self._proxy.GetFirstConversation() + response = self._proxy.GetFirstConversation(self._context.effective_user_id) history = HistoryEntry.from_structure(response) # Display the conversation @@ -111,7 +112,9 @@ def _retrieve_conversation_filtered(self, filter: str) -> None: filter (str): Keyword to filter in the user history """ self._text_renderer.render("Filtering conversation history.") - response = self._proxy.GetFilteredConversation(filter) + response = self._proxy.GetFilteredConversation( + self._context.effective_user_id, filter + ) # Handle and display the response history = HistoryEntry.from_structure(response) @@ -122,7 +125,7 @@ def _retrieve_conversation_filtered(self, filter: str) -> None: def _retrieve_last_conversation(self) -> None: """Retrieve the last conversation in the conversation cache.""" self._text_renderer.render("Getting last conversation from history.") - response = self._proxy.GetLastConversation() + response = self._proxy.GetLastConversation(self._context.effective_user_id) # Handle and display the response history = HistoryEntry.from_structure(response) @@ -133,7 +136,7 @@ def _retrieve_last_conversation(self) -> None: def _clear_history(self) -> None: """Clear the user history""" self._text_renderer.render("Cleaning the history.") - self._proxy.ClearHistory() + self._proxy.ClearHistory(self._context.effective_user_id) def _show_history(self, entries: list[HistoryItem]) -> None: """Internal method to show the history in a standarized way diff --git a/command_line_assistant/commands/query.py b/command_line_assistant/commands/query.py index e5a1b45..44ea9d8 100644 --- a/command_line_assistant/commands/query.py +++ b/command_line_assistant/commands/query.py @@ -1,7 +1,6 @@ """Module to handle the query command.""" import argparse -import getpass from argparse import Namespace from io import TextIOWrapper from typing import Optional @@ -77,6 +76,9 @@ def __init__( ) self._error_renderer: TextRenderer = create_error_renderer() self._warning_renderer: TextRenderer = create_warning_renderer() + + self._proxy = QUERY_IDENTIFIER.get_proxy() + super().__init__() def _get_input_source(self) -> str: @@ -142,24 +144,21 @@ def run(self) -> int: Returns: int: Status code of the execution """ - proxy = QUERY_IDENTIFIER.get_proxy() try: - query = self._get_input_source() + question = self._get_input_source() except ValueError as e: self._error_renderer.render(str(e)) return 1 - input_query = Message() - input_query.message = query - # Get the current user - input_query.user = getpass.getuser() output = "Nothing to see here..." try: with self._spinner_renderer: - proxy.ProcessQuery(input_query.to_structure(input_query)) - output = Message.from_structure(proxy.RetrieveAnswer()).message + response = self._proxy.AskQuestion( + self._context.effective_user_id, question + ) + output = Message.from_structure(response).message except ( RequestFailedError, MissingHistoryFileError, diff --git a/command_line_assistant/config/schemas.py b/command_line_assistant/config/schemas.py index b95ad5c..410ffc3 100644 --- a/command_line_assistant/config/schemas.py +++ b/command_line_assistant/config/schemas.py @@ -1,6 +1,8 @@ """Module to hold the config schema and it's sub schemas.""" +import copy import dataclasses +import pwd from pathlib import Path from typing import Optional, Union @@ -96,22 +98,17 @@ def __post_init__(self) -> None: self.level = self.level.upper() - def should_log_for_user(self, username: str, log_type: str) -> bool: - """Check if logging should be enabled for a specific user and log type. - - Args: - username (str): The username to check - log_type (str): The type of log ('responses' or 'question') - - Returns: - bool: Whether logging should be enabled for this user and log type - """ - # If user has specific settings, use those - if username in self.users: - return self.users[username].get(log_type, False) - - # Otherwise fall back to global settings - return getattr(self, log_type, False) + if self.users: + # Turn any username to their effective_user_id + defined_users = copy.deepcopy(self.users) + for user in defined_users.keys(): + try: + effective_user_id = str(pwd.getpwnam(user).pw_uid) + self.users[effective_user_id] = self.users.pop(user) + except KeyError as e: + raise ValueError( + f"{user} is not present on the system. Remove it from the configuration." + ) from e @dataclasses.dataclass diff --git a/command_line_assistant/daemon/database/models/history.py b/command_line_assistant/daemon/database/models/history.py index a933eac..5f5859d 100644 --- a/command_line_assistant/daemon/database/models/history.py +++ b/command_line_assistant/daemon/database/models/history.py @@ -16,6 +16,7 @@ class HistoryModel(BaseModel): __tablename__ = "history" id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + user_id = Column(UUID(as_uuid=True), nullable=False) timestamp = Column(DateTime, default=datetime.utcnow()) deleted_at = Column(DateTime, nullable=True) @@ -37,7 +38,6 @@ class InteractionModel(BaseModel): response_text = Column(String) response_role = Column(String, default="assistant") response_tokens = Column(Integer, default=0) - session_id = Column(UUID(as_uuid=True), nullable=False, default=uuid.uuid4) os_distribution = Column(String, default="RHEL") os_version = Column(String, nullable=False) os_arch = Column(String, nullable=False) diff --git a/command_line_assistant/daemon/session.py b/command_line_assistant/daemon/session.py index 71b73dc..3a2b6d8 100644 --- a/command_line_assistant/daemon/session.py +++ b/command_line_assistant/daemon/session.py @@ -3,7 +3,7 @@ import logging import uuid from pathlib import Path -from typing import Optional +from typing import Optional, Union #: Path to the machine ID file MACHINE_ID_PATH: Path = Path("/etc/machine-id") @@ -14,15 +14,15 @@ class UserSessionManager: """Manage user session information.""" - def __init__(self, effective_user: str) -> None: + def __init__(self, effective_user_id: Union[int, str]) -> None: """Initialize the session manager. Args: - effective_user (str): The effective user id + effective_user_id (Union[int, str]): The effective user id """ + self._effective_user_id: Union[int, str] = effective_user_id self._machine_uuid: Optional[uuid.UUID] = None - self._effective_user: str = effective_user - self._session_uuid: Optional[uuid.UUID] = None + self._user_id: Optional[uuid.UUID] = None @property def machine_id(self) -> uuid.UUID: @@ -55,17 +55,17 @@ def machine_id(self) -> uuid.UUID: return self._machine_uuid @property - def session_id(self) -> uuid.UUID: - """Property that generates a unique session ID combining machine and user effective id. + def user_id(self) -> uuid.UUID: + """Property that generates a unique user ID combining machine and user effective id. Returns: uuid.UUID: A unique session identifier """ - if not self._session_uuid: + if not self._user_id: # Combine machine ID and effective user to create a unique namespace namespace = self.machine_id # Generate a UUID using the effective username as name in the namespace - self._session_uuid = uuid.uuid5(namespace, self._effective_user) + self._user_id = uuid.uuid5(namespace, str(self._effective_user_id)) - return self._session_uuid + return self._user_id diff --git a/command_line_assistant/dbus/context.py b/command_line_assistant/dbus/context.py index 1d332bc..bd77128 100644 --- a/command_line_assistant/dbus/context.py +++ b/command_line_assistant/dbus/context.py @@ -1,15 +1,10 @@ """D-Bus context classes for managing the commands""" -from typing import Optional - -from dasbus.signal import Signal - from command_line_assistant.config import Config -from command_line_assistant.dbus.structures import Message -class BaseContext: - """Base class for context that defines the structure of it.""" +class DaemonContext: + """Context class for context that defines the structure of it.""" def __init__(self, config: Config) -> None: """Constructor of the class. @@ -27,47 +22,3 @@ def config(self) -> Config: Config: Instance of the configuration class """ return self._config - - -class QueryContext(BaseContext): - """This is the process context that will handle anything query related""" - - def __init__(self, config: Config) -> None: - """Constructor of the class. - - Args: - config (Config): Instance of the configuration class - """ - self._input_query: Optional[Message] = None - self._query_changed = Signal() - super().__init__(config) - - @property - def query(self) -> Optional[Message]: - """Property for the internal query attribute. - - Returns: - Optional[Message]: The user query wrapped in a `py:Message` dbus structure. - """ - return self._input_query - - def process_query(self, input_query: Message) -> None: - """Emit the signal that the query has changed. - - Args: - input_query (Message): The user query - """ - self._input_query = input_query - self._query_changed.emit() - - -class HistoryContext(BaseContext): - """This is the process context that will handle anything query related""" - - def __init__(self, config: Config) -> None: - """Constructor of the class. - - Args: - config (Config): Instance of the configuration class. - """ - super().__init__(config) diff --git a/command_line_assistant/dbus/interfaces.py b/command_line_assistant/dbus/interfaces.py index 80b4257..89a5971 100644 --- a/command_line_assistant/dbus/interfaces.py +++ b/command_line_assistant/dbus/interfaces.py @@ -3,9 +3,8 @@ import logging from dasbus.server.interface import dbus_interface -from dasbus.server.property import emits_properties_changed from dasbus.server.template import InterfaceTemplate -from dasbus.typing import Str, Structure +from dasbus.typing import Int, Str, Structure from command_line_assistant.daemon.http.query import submit from command_line_assistant.dbus.constants import HISTORY_IDENTIFIER, QUERY_IDENTIFIER @@ -25,117 +24,94 @@ class QueryInterface(InterfaceTemplate): """The DBus interface of a query.""" - def RetrieveAnswer(self) -> Structure: + def AskQuestion(self, effective_user_id: Int, question: Str) -> Structure: """This method is mainly called by the client to retrieve it's answer. Returns: Structure: The message output in format of a d-bus structure. """ - query = self.implementation.query.message - user = self.implementation.query.user - # Submit query to backend - llm_response = submit(query, self.implementation.config) + llm_response = submit(question, self.implementation.config) # Create message object message = Message() message.message = llm_response # Deal with history management - manager = HistoryManager(self.implementation.config, LocalHistory) - manager.write(query, llm_response) + manager = HistoryManager( + self.implementation.config, effective_user_id, LocalHistory + ) + manager.write(question, llm_response) audit_logger.info( "Query executed successfully.", extra={ - "user": user, - "query": query, + "user": effective_user_id, + "query": question, "response": llm_response, }, ) # Return the data return Message.to_structure(message) - @emits_properties_changed - def ProcessQuery(self, query: Structure) -> None: - """Process a given query by the user - - Args: - query (Structure): The user query - """ - self.implementation.process_query(Message.from_structure(query)) - @dbus_interface(HISTORY_IDENTIFIER.interface_name) class HistoryInterface(InterfaceTemplate): """The DBus interface of a history""" - def _parse_history_entries(self, entries: list[dict[str, str]]) -> HistoryEntry: - """Parse the history entries in a common format for all methods - - Args: - entries (list[dict[str, str]]): List of entries in a dictionary format with only the necessary information. - - Returns: - HistoryEntry: An instance of HistoryEntry with all necessary information. - """ - history_entry = HistoryEntry() - for entry in entries: - history_item = HistoryItem() - history_item.query = entry["query"] - history_item.response = entry["response"] - history_item.timestamp = entry["timestamp"] - history_entry.entries.append(history_item) - - return history_entry - - def GetHistory(self) -> Structure: + def GetHistory(self, effective_user_id: Int) -> Structure: """Get all conversations from history. Returns: Structure: The history entries in a dbus structure format. """ - manager = HistoryManager(self.implementation.config, LocalHistory) + manager = HistoryManager( + self.implementation.config, effective_user_id, LocalHistory + ) history_entries = manager.read() history_entry = HistoryEntry() if history_entries: - history_entry = self._parse_history_entries(history_entries) + history_entry = _parse_history_entries(history_entries) return HistoryEntry.to_structure(history_entry) # Add new methods with parameters - def GetFirstConversation(self) -> Structure: + def GetFirstConversation(self, effective_user_id: Int) -> Structure: """Get first conversation from history. Returns: Structure: A single history entry in a dbus structure format. """ - manager = HistoryManager(self.implementation.config, LocalHistory) + manager = HistoryManager( + self.implementation.config, effective_user_id, LocalHistory + ) history_entries = manager.read() history_entry = HistoryEntry() if history_entries: - history_entry = self._parse_history_entries(history_entries[:1]) + history_entry = _parse_history_entries(history_entries[:1]) return HistoryEntry.to_structure(history_entry) - def GetLastConversation(self) -> Structure: + def GetLastConversation(self, effective_user_id: Int) -> Structure: """Get last conversation from history. Returns: Structure: A single history entyr in a dbus structure format. """ - manager = HistoryManager(self.implementation.config, LocalHistory) + manager = HistoryManager( + self.implementation.config, effective_user_id, LocalHistory + ) history_entries = manager.read() history_entry = HistoryEntry() if history_entries: - history_entry = self._parse_history_entries(history_entries[-1:]) + history_entry = _parse_history_entries(history_entries[-1:]) return HistoryEntry.to_structure(history_entry) - def GetFilteredConversation(self, filter: Str) -> Structure: + def GetFilteredConversation(self, effective_user_id: Int, filter: Str) -> Structure: """Get last conversation from history. Args: @@ -144,7 +120,9 @@ def GetFilteredConversation(self, filter: Str) -> Structure: Returns: Structure: A single history entyr in a dbus structure format. """ - manager = HistoryManager(self.implementation.config, LocalHistory) + manager = HistoryManager( + self.implementation.config, effective_user_id, LocalHistory + ) history_entries = manager.read() history_entry = HistoryEntry() @@ -157,11 +135,35 @@ def GetFilteredConversation(self, filter: Str) -> Structure: if (filter in entry["query"] or filter in entry["response"]) ] - history_entry = self._parse_history_entries(filtered_entries) + history_entry = _parse_history_entries(filtered_entries) return HistoryEntry.to_structure(history_entry) - def ClearHistory(self) -> None: + def ClearHistory(self, effective_user_id: Int) -> None: """Clear the user history.""" - manager = HistoryManager(self.implementation.config, LocalHistory) + manager = HistoryManager( + self.implementation.config, effective_user_id, LocalHistory + ) manager.clear() + + +def _parse_history_entries(entries: list[dict[str, str]]) -> HistoryEntry: + """Parse the history entries in a common format for all methods + + Args: + entries (list[dict[str, str]]): List of entries in a dictionary format + with only the necessary information. + + Returns: + HistoryEntry: An instance of HistoryEntry with all necessary + information. + """ + history_entry = HistoryEntry() + for entry in entries: + history_item = HistoryItem() + history_item.query = entry["query"] + history_item.response = entry["response"] + history_item.timestamp = entry["timestamp"] + history_entry.entries.append(history_item) + + return history_entry diff --git a/command_line_assistant/dbus/server.py b/command_line_assistant/dbus/server.py index c96a5e1..3be1a7f 100644 --- a/command_line_assistant/dbus/server.py +++ b/command_line_assistant/dbus/server.py @@ -2,7 +2,6 @@ import logging -from dasbus.constants import DBUS_NAME_FLAG_REPLACE_EXISTING from dasbus.loop import EventLoop from command_line_assistant.config import Config @@ -11,7 +10,9 @@ QUERY_IDENTIFIER, SYSTEM_BUS, ) -from command_line_assistant.dbus.context import HistoryContext, QueryContext +from command_line_assistant.dbus.context import ( + DaemonContext, +) from command_line_assistant.dbus.interfaces import HistoryInterface, QueryInterface logger = logging.getLogger(__name__) @@ -26,19 +27,15 @@ def serve(config: Config): logger.info("Starting clad!") try: SYSTEM_BUS.publish_object( - QUERY_IDENTIFIER.object_path, QueryInterface(QueryContext(config)) + QUERY_IDENTIFIER.object_path, QueryInterface(DaemonContext(config)) ) SYSTEM_BUS.publish_object( - HISTORY_IDENTIFIER.object_path, HistoryInterface(HistoryContext(config)) + HISTORY_IDENTIFIER.object_path, HistoryInterface(DaemonContext(config)) ) - # The flag DBUS_NAME_FLAG_REPLACE_EXISTING is needed during development - # so ew can replace the existing bus. - # TODO(r0x0d): See what to do with it later. SYSTEM_BUS.register_service(QUERY_IDENTIFIER.service_name) - SYSTEM_BUS.register_service( - HISTORY_IDENTIFIER.service_name, flags=DBUS_NAME_FLAG_REPLACE_EXISTING - ) + SYSTEM_BUS.register_service(HISTORY_IDENTIFIER.service_name) + loop = EventLoop() loop.run() finally: diff --git a/command_line_assistant/history/base.py b/command_line_assistant/history/base.py index 38df194..b6483c8 100644 --- a/command_line_assistant/history/base.py +++ b/command_line_assistant/history/base.py @@ -1,6 +1,7 @@ """Base module to track all the abstract classes for the history module.""" import logging +import uuid from abc import ABC, abstractmethod from command_line_assistant.config import Config @@ -20,7 +21,7 @@ def __init__(self, config: Config) -> None: self._config = config @abstractmethod - def read(self) -> list[dict[str, str]]: + def read(self, user_id: uuid.UUID) -> list[dict[str, str]]: """Abstract method to represent a read operation Returns: @@ -28,7 +29,7 @@ def read(self) -> list[dict[str, str]]: """ @abstractmethod - def write(self, query: str, response: str) -> None: + def write(self, user_id: uuid.UUID, query: str, response: str) -> None: """Abstract method to represent a write operation Args: @@ -37,7 +38,7 @@ def write(self, query: str, response: str) -> None: """ @abstractmethod - def clear(self) -> None: + def clear(self, user_id: uuid.UUID) -> None: """Abstract method to represent a clear operation""" def _check_if_history_is_enabled(self) -> bool: diff --git a/command_line_assistant/history/manager.py b/command_line_assistant/history/manager.py index 9c900c0..6385fa1 100644 --- a/command_line_assistant/history/manager.py +++ b/command_line_assistant/history/manager.py @@ -3,6 +3,7 @@ from typing import Optional, Type from command_line_assistant.config import Config +from command_line_assistant.daemon.session import UserSessionManager from command_line_assistant.history.base import BaseHistoryPlugin @@ -10,24 +11,30 @@ class HistoryManager: """Manages history operations by delegating to a specific history implementation. Example: - >>> manager = HistoryManager(config, plugin=LocalHistory) + >>> effective_user_id = 1000 + >>> manager = HistoryManager(config, effective_user_id, plugin=LocalHistory) >>> entries = manager.read() >>> manager.write("How do I check disk space?", "Use df -h command...") >>> manager.clear() """ def __init__( - self, config: Config, plugin: Optional[Type[BaseHistoryPlugin]] = None + self, + config: Config, + effective_user_id: int, + plugin: Optional[Type[BaseHistoryPlugin]] = None, ) -> None: """Initialize the history manager. Args: config (Config): Instance of configuration class + effective_user_id (int): The effective user id who asked for the history. plugin (Optional[Type[BaseHistory]], optional): Optional history implementation class """ self._config = config self._plugin: Optional[Type[BaseHistoryPlugin]] = None self._instance: Optional[BaseHistoryPlugin] = None + self._session_manager = UserSessionManager(effective_user_id) # Set initial plugin if provided if plugin: @@ -72,7 +79,7 @@ def read(self) -> list[dict[str, str]]: if not self._instance: raise RuntimeError("No history plugin set. Set plugin before operations.") - return self._instance.read() + return self._instance.read(self._session_manager.user_id) def write(self, query: str, response: str) -> None: """Write a new history entry using the current plugin. @@ -87,7 +94,7 @@ def write(self, query: str, response: str) -> None: if not self._instance: raise RuntimeError("No history plugin set. Set plugin before operations.") - self._instance.write(query, response) + self._instance.write(self._session_manager.user_id, query, response) def clear(self) -> None: """Clear all history entries. @@ -98,4 +105,4 @@ def clear(self) -> None: if not self._instance: raise RuntimeError("No history plugin set. Set plugin before operations.") - self._instance.clear() + self._instance.clear(self._session_manager.user_id) diff --git a/command_line_assistant/history/plugins/local.py b/command_line_assistant/history/plugins/local.py index d13338a..b2752fd 100644 --- a/command_line_assistant/history/plugins/local.py +++ b/command_line_assistant/history/plugins/local.py @@ -51,7 +51,7 @@ def _initialize_database(self) -> DatabaseManager: logger.error("Failed to initialize database: %s", e) raise MissingHistoryFileError(f"Could not initialize database: {e}") from e - def read(self) -> list[dict[str, str]]: + def read(self, user_id: uuid.UUID) -> list[dict[str, str]]: """Reads the history from the database. Returns: @@ -72,6 +72,7 @@ def read(self) -> list[dict[str, str]]: .join(InteractionModel) .filter(HistoryModel.deleted_at.is_(None)) .order_by(asc(HistoryModel.timestamp)) + .where(HistoryModel.user_id == user_id) .all() ) @@ -87,7 +88,7 @@ def read(self) -> list[dict[str, str]]: logger.error("Failed to read from database: %s", e) raise CorruptedHistoryError(f"Failed to read from database: {e}") from e - def write(self, query: str, response: str) -> None: + def write(self, user_id: uuid.UUID, query: str, response: str) -> None: """Write history to the database. Args: @@ -108,7 +109,6 @@ def write(self, query: str, response: str) -> None: query_text=query, response_text=response, response_tokens=len(response), - session_id=uuid.uuid4(), os_distribution="RHEL", # Default to RHEL for now os_version=platform.release(), os_arch=platform.machine(), @@ -116,15 +116,13 @@ def write(self, query: str, response: str) -> None: session.add(interaction) # Create History record - history = HistoryModel( - interaction=interaction, - ) + history = HistoryModel(interaction=interaction, user_id=user_id) session.add(history) except Exception as e: logger.error("Failed to write to database: %s", e) raise CorruptedHistoryError(f"Failed to write to database: {e}") from e - def clear(self) -> None: + def clear(self, user_id: uuid.UUID) -> None: """Clear the database by dropping and recreating tables. Raises: @@ -133,9 +131,9 @@ def clear(self) -> None: try: with self._db.session() as session: # Soft delete by setting deleted_at - session.query(HistoryModel).update( - {"deleted_at": datetime.utcnow()}, synchronize_session=False - ) + session.query(HistoryModel).where( + HistoryModel.user_id == user_id + ).update({"deleted_at": datetime.utcnow()}) logger.info("Database cleared successfully") except Exception as e: logger.error("Failed to clear database: %s", e) diff --git a/command_line_assistant/logger.py b/command_line_assistant/logger.py index ed8fc0a..2a6de3c 100644 --- a/command_line_assistant/logger.py +++ b/command_line_assistant/logger.py @@ -6,6 +6,7 @@ from typing import Optional from command_line_assistant.config import Config +from command_line_assistant.daemon.session import UserSessionManager #: Define the dictionary configuration for the logger instance LOGGING_CONFIG_DICTIONARY = { @@ -51,6 +52,30 @@ } +def _should_log_for_user(effective_user_id: int, config: Config, log_type: str) -> bool: + """Check if logging should be enabled for a specific user and log type. + + Args: + effective_user_id (int): The effective user id to check if logging is enabled. + log_type (str): The type of log ('responses' or 'question') + + Returns: + bool: Whether logging should be enabled for this user and log type + """ + logging_users = copy.deepcopy(config.logging.users) + for user in config.logging.users.keys(): + user_id = str(UserSessionManager(user).user_id) + logging_users[user_id] = logging_users.pop(user) + + user_id = str(UserSessionManager(effective_user_id).user_id) + # If user has specific settings, use those + if user_id in logging_users: + return logging_users[user_id].get(log_type, False) + + # Otherwise fall back to global settings + return getattr(config.logging, log_type, False) + + class AuditFormatter(logging.Formatter): """Custom formatter that handles user-specific logging configuration.""" @@ -94,16 +119,16 @@ def format(self, record: logging.LogRecord) -> str: "user": getattr(record, "user", "unknown"), "message": record.getMessage(), } - - is_query_enabled = hasattr( - record, "query" - ) and self._config.logging.should_log_for_user(data["user"], "question") + effective_user_id = data["user"] + is_query_enabled = hasattr(record, "query") and _should_log_for_user( + effective_user_id, self._config, "question" + ) # Add query if enabled for user data["query"] = record.query if is_query_enabled else None # type: ignore - is_response_enabled = hasattr( - record, "response" - ) and self._config.logging.should_log_for_user(data["user"], "responses") + is_response_enabled = hasattr(record, "response") and _should_log_for_user( + effective_user_id, self._config, "responses" + ) # Add response if enabled for user data["response"] = record.response if is_response_enabled else None # type: ignore diff --git a/command_line_assistant/utils/cli.py b/command_line_assistant/utils/cli.py index 8dbfbbd..394e928 100644 --- a/command_line_assistant/utils/cli.py +++ b/command_line_assistant/utils/cli.py @@ -3,6 +3,9 @@ that is reused across commands and other interactions. """ +import dataclasses +import getpass +import os import select import sys from abc import ABC, abstractmethod @@ -18,9 +21,31 @@ ARGS_WITH_VALUES: list[str] = ["--clear"] +@dataclasses.dataclass +class CommandContext: + """A context for all commands with useful information. + + Note: + This is meant to be initialized exclusively by the client. + + Attributes: + username (str): The username of the current user. + effective_user_id (int): The effective user id. + """ + + username: str = getpass.getuser() + effective_user_id: int = os.getegid() + + class BaseCLICommand(ABC): """Absctract class to define a CLI Command.""" + def __init__(self) -> None: + """Constructor for the base class.""" + self._context: CommandContext = CommandContext() + + super().__init__() + @abstractmethod def run(self) -> int: """Entrypoint method for all CLI commands.""" diff --git a/data/development/config/command-line-assistant/config.toml b/data/development/config/command-line-assistant/config.toml index e566b97..917231c 100644 --- a/data/development/config/command-line-assistant/config.toml +++ b/data/development/config/command-line-assistant/config.toml @@ -14,7 +14,8 @@ type = "sqlite" connection_string = "~/.local/share/command-line-assistant/history.db" [backend] -endpoint = "http://localhost:8080" +#endpoint = "http://localhost:8080" +endpoint = "https://rlsapi-rhel-lightspeed--runtime-int.apps.int.spoke.preprod.us-east-1.aws.paas.redhat.com" [backend.auth] cert_file = "data/development/certificate/fake-certificate.pem" @@ -24,7 +25,7 @@ verify_ssl = false [logging] level = "DEBUG" responses = false # Global setting - don't log responses by default -question = false # Global setting - don't log questions by default +question = false # Global setting - don't log questions by default # User-specific settings -# users.admin = { responses = true, question = true } +#users.rolivier = { responses = true, question = true } diff --git a/tests/commands/test_query.py b/tests/commands/test_query.py index 94a44c9..556ebf0 100644 --- a/tests/commands/test_query.py +++ b/tests/commands/test_query.py @@ -61,22 +61,13 @@ def test_query_command_run(mock_dbus_service, test_input, expected_output, capsy mock_output = Message() mock_output.message = expected_output mock_output.user = "mock" - mock_dbus_service.RetrieveAnswer = lambda: Message.to_structure(mock_output) - - with patch("command_line_assistant.commands.query.getpass.getuser") as mock_getuser: - mock_getuser.return_value = "mock" - # Create and run command - command = QueryCommand(test_input, None) - command.run() - - # Verify ProcessQuery was called with correct input - expected_input = Message() - expected_input.message = test_input - expected_input.user = "mock" - mock_dbus_service.ProcessQuery.assert_called_once_with( - Message.to_structure(expected_input) + mock_dbus_service.AskQuestion = lambda user_id, question: Message.to_structure( + mock_output ) + command = QueryCommand(test_input, None) + command.run() + # Verify output was printed captured = capsys.readouterr() assert expected_output in captured.out.strip() @@ -88,7 +79,9 @@ def test_query_command_empty_response(mock_dbus_service, capsys): mock_output = Message() mock_output.message = "" mock_output.user = "mock" - mock_dbus_service.RetrieveAnswer = lambda: Message.to_structure(mock_output) + mock_dbus_service.AskQuestion = lambda user_id, question: Message.to_structure( + mock_output + ) command = QueryCommand("test query", None) command.run() @@ -241,7 +234,7 @@ def test_get_input_source_value_error(): def test_dbus_error_handling(exception, expected, mock_dbus_service, capsys): """Test handling of DBus errors""" # Make ProcessQuery raise a DBus error - mock_dbus_service.ProcessQuery.side_effect = exception + mock_dbus_service.AskQuestion.side_effect = exception command = QueryCommand("test query", None) command.run() diff --git a/tests/conftest.py b/tests/conftest.py index 460b6ec..140a724 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,6 +14,7 @@ OutputSchema, ) from command_line_assistant.config.schemas import AuthSchema, DatabaseSchema +from command_line_assistant.dbus.context import DaemonContext from command_line_assistant.logger import LOGGING_CONFIG_DICTIONARY from tests.helpers import MockStream @@ -37,6 +38,15 @@ def setup_logger(tmp_path, request): root_logger.removeHandler(handler) +class MockPwnam: + def __init__(self, pw_uid="1000"): + self._pw_uid = pw_uid + + @property + def pw_uid(self): + return self._pw_uid + + @pytest.fixture def mock_config(tmp_path): """Fixture to create a mock configuration""" @@ -46,21 +56,27 @@ def mock_config(tmp_path): cert_file.write_text("cert") key_file.write_text("key") - return Config( - output=OutputSchema( - enforce_script=False, - file=Path("/tmp/test_output.txt"), - prompt_separator="$", - ), - backend=BackendSchema( - endpoint="http://test.endpoint/v1/query", - auth=AuthSchema(cert_file=cert_file, key_file=key_file, verify_ssl=False), - ), - history=HistorySchema( - enabled=True, database=DatabaseSchema(connection_string=history_db) - ), - logging=LoggingSchema(level="debug"), - ) + with patch("pwd.getpwnam", return_value=MockPwnam()): + return Config( + output=OutputSchema( + enforce_script=False, + file=Path("/tmp/test_output.txt"), + prompt_separator="$", + ), + backend=BackendSchema( + endpoint="http://test.endpoint/v1/query", + auth=AuthSchema( + cert_file=cert_file, key_file=key_file, verify_ssl=False + ), + ), + history=HistorySchema( + enabled=True, database=DatabaseSchema(connection_string=history_db) + ), + logging=LoggingSchema( + level="debug", + users={"testuser": {"question": True, "responses": False}}, + ), + ) @pytest.fixture @@ -79,3 +95,8 @@ def mock_proxy(): @pytest.fixture def mock_stream(): return MockStream() + + +@pytest.fixture +def mock_context(mock_config): + return DaemonContext(mock_config) diff --git a/tests/daemon/database/test_manager.py b/tests/daemon/database/test_manager.py index cea43c7..f92143c 100644 --- a/tests/daemon/database/test_manager.py +++ b/tests/daemon/database/test_manager.py @@ -94,7 +94,7 @@ def test_add_success(database_manager): os_arch="x86_64", ) - history = HistoryModel(interaction=interaction) + history = HistoryModel(interaction=interaction, user_id=uuid.uuid4()) database_manager.add(history) result = database_manager.get(InteractionModel, uid) @@ -122,7 +122,7 @@ def test_query_success(database_manager): os_arch="x86_64", ) uid = uuid.uuid4() - history = HistoryModel(id=uid, interaction=interaction) + history = HistoryModel(id=uid, user_id=uuid.uuid4(), interaction=interaction) database_manager.add(history) @@ -157,6 +157,7 @@ def test_get_success(database_manager): uid = uuid.uuid4() history = HistoryModel( id=uid, + user_id=uuid.uuid4(), interaction=interaction, ) database_manager.add(history) diff --git a/tests/daemon/test_session.py b/tests/daemon/test_session.py index 38787b4..abfa467 100644 --- a/tests/daemon/test_session.py +++ b/tests/daemon/test_session.py @@ -7,17 +7,17 @@ def test_initialize_user_session_manager(): - session = UserSessionManager("1000") - assert session._effective_user == "1000" + session = UserSessionManager(1000) + assert session._effective_user_id == 1000 assert not session._machine_uuid - assert not session._session_uuid + assert not session._user_id def test_read_machine_id(tmp_path): machine_id = tmp_path / "machine-id" machine_id.write_text("09e28913cb074ed995a239c93b07fd8a") with patch("command_line_assistant.daemon.session.MACHINE_ID_PATH", machine_id): - session = UserSessionManager("1000") + session = UserSessionManager(1000) assert session.machine_id == uuid.UUID("09e28913cb074ed995a239c93b07fd8a") @@ -25,8 +25,8 @@ def test_generate_session_id(tmp_path): machine_id = tmp_path / "machine-id" machine_id.write_text("09e28913cb074ed995a239c93b07fd8a") with patch("command_line_assistant.daemon.session.MACHINE_ID_PATH", machine_id): - session = UserSessionManager("1000") - assert session.session_id == uuid.UUID("4d465f1c-0507-5dfa-9ea0-e2de1a9e90a5") + session = UserSessionManager(1000) + assert session.user_id == uuid.UUID("4d465f1c-0507-5dfa-9ea0-e2de1a9e90a5") def test_generate_session_id_twice(tmp_path): @@ -34,11 +34,11 @@ def test_generate_session_id_twice(tmp_path): machine_id = tmp_path / "machine-id" machine_id.write_text("09e28913cb074ed995a239c93b07fd8a") with patch("command_line_assistant.daemon.session.MACHINE_ID_PATH", machine_id): - session = UserSessionManager("1000") - assert session.session_id == uuid.UUID("4d465f1c-0507-5dfa-9ea0-e2de1a9e90a5") + session = UserSessionManager(1000) + assert session.user_id == uuid.UUID("4d465f1c-0507-5dfa-9ea0-e2de1a9e90a5") - session = UserSessionManager("1000") - assert session.session_id == uuid.UUID("4d465f1c-0507-5dfa-9ea0-e2de1a9e90a5") + session = UserSessionManager(1000) + assert session.user_id == uuid.UUID("4d465f1c-0507-5dfa-9ea0-e2de1a9e90a5") @pytest.mark.parametrize( @@ -72,7 +72,7 @@ def test_generate_session_id_different_users( "command_line_assistant.daemon.session.MACHINE_ID_PATH", machine_id_file ): session = UserSessionManager(effective_user_id) - assert session.session_id == uuid.UUID(expected) + assert session.user_id == uuid.UUID(expected) def test_empty_machine_id_file(tmp_path): @@ -81,7 +81,7 @@ def test_empty_machine_id_file(tmp_path): with patch( "command_line_assistant.daemon.session.MACHINE_ID_PATH", machine_id_file ): - session = UserSessionManager("1000") + session = UserSessionManager(1000) with pytest.raises(ValueError, match="Machine ID at .* is empty"): assert session.machine_id @@ -91,6 +91,6 @@ def test_machine_id_file_not_found(tmp_path): with patch( "command_line_assistant.daemon.session.MACHINE_ID_PATH", machine_id_file ): - session = UserSessionManager("1000") + session = UserSessionManager(1000) with pytest.raises(FileNotFoundError, match="Machine ID file not found at .*"): assert session.machine_id diff --git a/tests/dbus/test_context.py b/tests/dbus/test_context.py index 007b4b1..e4e486a 100644 --- a/tests/dbus/test_context.py +++ b/tests/dbus/test_context.py @@ -1,46 +1,8 @@ -import pytest - from command_line_assistant.dbus.context import ( - BaseContext, - HistoryContext, - QueryContext, + DaemonContext, ) -from command_line_assistant.dbus.structures import Message - - -@pytest.fixture -def base_context(mock_config): - return BaseContext(mock_config) - -@pytest.fixture -def query_context(mock_config): - return QueryContext(mock_config) - -@pytest.fixture -def history_context(mock_config): - return HistoryContext(mock_config) - - -def test_base_context_config_property(mock_config): - context = BaseContext(mock_config) +def test_daemon_context_config_property(mock_config): + context = DaemonContext(mock_config) assert context.config == mock_config - - -def test_query_context_initial_state(query_context): - assert query_context.query is None - - -def test_query_context_process_query(query_context): - message_obj = Message() - message_obj.message = "test query" - - query_context.process_query(message_obj) - - assert query_context.query == message_obj - - -def test_history_context_inheritance(history_context, mock_config): - assert isinstance(history_context, BaseContext) - assert history_context.config == mock_config diff --git a/tests/dbus/test_interfaces.py b/tests/dbus/test_interfaces.py index 1290bd3..b636822 100644 --- a/tests/dbus/test_interfaces.py +++ b/tests/dbus/test_interfaces.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock, patch +from unittest.mock import patch import pytest from dasbus.server.template import InterfaceTemplate @@ -14,73 +14,49 @@ @pytest.fixture def mock_history_entry(mock_config): - manager = HistoryManager(mock_config, LocalHistory) + manager = HistoryManager(mock_config, 1000, LocalHistory) return manager @pytest.fixture -def mock_implementation(mock_config): - """Create a mock implementation with configuration.""" - impl = Mock() - impl.config = mock_config - mock_query = Message() - mock_query.message = "test query" - impl.query = mock_query - return impl - - -@pytest.fixture -def query_interface(mock_implementation): +def query_interface(mock_context): """Create a QueryInterface instance with mock implementation.""" - interface = QueryInterface(mock_implementation) + interface = QueryInterface(mock_context) assert isinstance(interface, InterfaceTemplate) return interface @pytest.fixture -def history_interface(mock_implementation): +def history_interface(mock_context): """Create a HistoryInterface instance with mock implementation.""" - interface = HistoryInterface(mock_implementation) + interface = HistoryInterface(mock_context) assert isinstance(interface, InterfaceTemplate) return interface -def test_query_interface_retrieve_answer(query_interface, mock_implementation): +def test_query_interface_ask_question(query_interface, mock_config): """Test retrieving answer from query interface.""" expected_response = "test response" with patch( "command_line_assistant.dbus.interfaces.submit", return_value=expected_response ) as mock_submit: - response = query_interface.RetrieveAnswer() - - mock_submit.assert_called_once_with( - mock_implementation.query.message, mock_implementation.config + response = query_interface.AskQuestion( + "b7e95c2c-d2a8-11ef-a6bf-52b437312584", "test?" ) + mock_submit.assert_called_once_with("test?", mock_config) + reconstructed = Message.from_structure(response) assert reconstructed.message == expected_response -def test_query_interface_process_query(query_interface, mock_implementation): - """Test processing query through query interface.""" - test_query = Message() - test_query.message = "test query" - - query_interface.ProcessQuery(Message.to_structure(test_query)) - - mock_implementation.process_query.assert_called_once() - processed_query = mock_implementation.process_query.call_args[0][0] - assert isinstance(processed_query, Message) - assert processed_query.message == test_query.message - - def test_history_interface_get_history(history_interface, mock_history_entry): """Test getting all history through history interface.""" with patch( "command_line_assistant.history.manager.HistoryManager", mock_history_entry ) as manager: manager.write("test query", "test response") - response = history_interface.GetHistory() + response = history_interface.GetHistory(1000) reconstructed = HistoryEntry.from_structure(response) assert len(reconstructed.entries) == 1 @@ -99,7 +75,7 @@ def test_history_interface_get_first_conversation( manager.write("test query", "test response") manager.write("test query2", "test response2") manager.write("test query3", "test response3") - response = history_interface.GetFirstConversation() + response = history_interface.GetFirstConversation(1000) reconstructed = HistoryEntry.from_structure(response) assert len(reconstructed.entries) == 1 @@ -115,7 +91,7 @@ def test_history_interface_get_last_conversation(history_interface, mock_history manager.write("test query", "test response") manager.write("test query2", "test response2") manager.write("test query3", "test response3") - response = history_interface.GetLastConversation() + response = history_interface.GetLastConversation(1000) reconstructed = HistoryEntry.from_structure(response) assert len(reconstructed.entries) == 1 @@ -132,7 +108,7 @@ def test_history_interface_get_filtered_conversation( ) as manager: manager.write("test query", "test response") manager.write("not a query", "not a response") - response = history_interface.GetFilteredConversation(filter="test") + response = history_interface.GetFilteredConversation(1000, filter="test") reconstructed = HistoryEntry.from_structure(response) assert len(reconstructed.entries) == 1 @@ -152,7 +128,7 @@ def test_history_interface_get_filtered_conversation_duplicate_entries_not_match ) as manager: manager.write("test query", "test response") manager.write("test query", "test response") - response = history_interface.GetFilteredConversation(filter="test") + response = history_interface.GetFilteredConversation(1000, filter="test") reconstructed = HistoryEntry.from_structure(response) assert len(reconstructed.entries) == 2 @@ -163,7 +139,7 @@ def test_history_interface_get_filtered_conversation_duplicate_entries_not_match def test_history_interface_clear_history(history_interface): """Test clearing history through history interface.""" with patch("command_line_assistant.dbus.interfaces.HistoryManager") as mock_manager: - history_interface.ClearHistory() + history_interface.ClearHistory(1000) mock_manager.return_value.clear.assert_called_once() @@ -178,6 +154,6 @@ def test_history_interface_empty_history(history_interface): history_interface.GetFirstConversation, history_interface.GetLastConversation, ]: - response = method() + response = method(1000) reconstructed = HistoryEntry.from_structure(response) assert len(reconstructed.entries) == 0 diff --git a/tests/dbus/test_server.py b/tests/dbus/test_server.py index fab76dc..84a2de6 100644 --- a/tests/dbus/test_server.py +++ b/tests/dbus/test_server.py @@ -1,7 +1,5 @@ from unittest import mock -from dasbus.constants import DBUS_NAME_FLAG_REPLACE_EXISTING - from command_line_assistant.config import Config from command_line_assistant.dbus import server @@ -29,10 +27,6 @@ def test_serve_registers_services(monkeypatch): assert system_bus_mock.publish_object.call_count == 2 assert system_bus_mock.register_service.call_count == 2 - assert ( - system_bus_mock.register_service.call_args_list[1][1]["flags"] - == DBUS_NAME_FLAG_REPLACE_EXISTING - ) def test_serve_cleanup_on_exception(monkeypatch): diff --git a/tests/history/plugins/test_local.py b/tests/history/plugins/test_local.py index 24d915f..55bef1f 100644 --- a/tests/history/plugins/test_local.py +++ b/tests/history/plugins/test_local.py @@ -1,15 +1,9 @@ import uuid -from datetime import datetime from unittest.mock import Mock, create_autospec, patch import pytest -from sqlalchemy.orm import Session from command_line_assistant.daemon.database.manager import DatabaseManager -from command_line_assistant.daemon.database.models.history import ( - HistoryModel, - InteractionModel, -) from command_line_assistant.dbus.exceptions import ( CorruptedHistoryError, MissingHistoryFileError, @@ -18,29 +12,10 @@ @pytest.fixture -def mock_db_session() -> Mock: - """Fixture for database session.""" - return create_autospec(Session, instance=True) - - -@pytest.fixture -def mock_db_manager(mock_db_session: Mock) -> Mock: - """Fixture for DatabaseManager with mocked session.""" - db_manager = create_autospec(DatabaseManager, instance=True) - db_manager.session.return_value.__enter__.return_value = mock_db_session - db_manager.session.return_value.__exit__.return_value = None - return db_manager - - -@pytest.fixture -def local_history(mock_config: Mock, mock_db_manager: Mock) -> LocalHistory: +def local_history(mock_config: Mock) -> LocalHistory: """Fixture for LocalHistory instance with mocked dependencies.""" - with patch( - "command_line_assistant.history.plugins.local.DatabaseManager", - return_value=mock_db_manager, - ): - history = LocalHistory(mock_config) - return history + history = LocalHistory(mock_config) + return history class TestLocalHistoryInitialization: @@ -75,36 +50,24 @@ def test_read_disabled_history( ): """Should return empty list when history is disabled.""" mock_config.history.enabled = False - assert local_history.read() == [] + assert local_history.read(uuid.uuid4()) == [] - def test_read_success(self, local_history: LocalHistory, mock_db_session: Mock): + def test_read_success(self, local_history: LocalHistory): """Should successfully read and format history entries.""" # Create mock history entries - mock_interaction = Mock(spec=InteractionModel) - mock_interaction.query_text = "test query" - mock_interaction.response_text = "test response" - - mock_history = Mock(spec=HistoryModel) - mock_history.interaction = mock_interaction - mock_history.timestamp = datetime.utcnow() - - mock_db_session.query.return_value.join.return_value.filter.return_value.order_by.return_value.all.return_value = [ - mock_history - ] - - result = local_history.read() + uid = uuid.uuid4() + local_history.write(uid, "test query", "test response") + result = local_history.read(uid) assert len(result) == 1 assert result[0]["query"] == "test query" assert result[0]["response"] == "test response" assert "timestamp" in result[0] - def test_read_failure(self, local_history: LocalHistory, mock_db_session: Mock): + def test_read_failure(self, local_history: LocalHistory): """Should raise CorruptedHistoryError on read failure.""" - mock_db_session.query.side_effect = Exception("DB Read Error") - with pytest.raises(CorruptedHistoryError, match="Failed to read from database"): - local_history.read() + local_history.read("1") # type: ignore class TestLocalHistoryWrite: @@ -115,8 +78,9 @@ def test_write_disabled_history( ): """Should not write when history is disabled.""" mock_config.history.enabled = False - local_history.write("query", "response") - assert local_history._db.session.call_count == 0 # type: ignore + uid = uuid.uuid4() + local_history.write(uid, "query", "response") + assert not local_history.read(uid) @pytest.mark.parametrize( "query,response", @@ -129,77 +93,33 @@ def test_write_disabled_history( def test_write_success( self, local_history: LocalHistory, - mock_db_session: Mock, query: str, response: str, ): """Should successfully write history entries.""" - with patch( - "uuid.uuid4", return_value=uuid.UUID("12345678-1234-5678-1234-567812345678") - ): - local_history.write(query, response) - - # Verify interaction was created with correct attributes - mock_db_session.add.assert_called() - calls = mock_db_session.add.call_args_list - - # First call should be InteractionModel - interaction = calls[0][0][0] - assert isinstance(interaction, InteractionModel) - assert interaction.query_text == query # type: ignore - assert interaction.response_text == response # type: ignore - assert interaction.session_id is not None - - # Second call should be HistoryModel - history = calls[1][0][0] - assert isinstance(history, HistoryModel) - assert history.interaction == interaction - - def test_write_failure(self, local_history: LocalHistory, mock_db_session: Mock): - """Should raise CorruptedHistoryError on write failure.""" - mock_db_session.add.side_effect = Exception("DB Write Error") + uid = uuid.uuid4() + local_history.write(uid, query, response) + assert len(local_history.read(uid)) == 1 + def test_write_failure(self, local_history: LocalHistory): + """Should raise CorruptedHistoryError on write failure.""" with pytest.raises(CorruptedHistoryError, match="Failed to write to database"): - local_history.write("query", "response") + local_history.write("1", "query", "response") # type: ignore class TestLocalHistoryClear: """Test cases for clearing history.""" - def test_clear_success(self, local_history: LocalHistory, mock_db_session: Mock): + def test_clear_success(self, local_history: LocalHistory): """Should successfully clear history.""" - local_history.clear() + uid = uuid.uuid4() + local_history.write(uid, "test", "test") + local_history.clear(uid) # Verify soft delete was performed - mock_db_session.query.return_value.update.assert_called_once() - update_args = mock_db_session.query.return_value.update.call_args[0][0] - assert "deleted_at" in update_args - assert isinstance(update_args["deleted_at"], datetime) + assert not local_history.read(uid) - def test_clear_failure(self, local_history: LocalHistory, mock_db_session: Mock): + def test_clear_failure(self, local_history: LocalHistory): """Should raise MissingHistoryFileError on clear failure.""" - mock_db_session.query.return_value.update.side_effect = Exception( - "DB Clear Error" - ) - with pytest.raises(MissingHistoryFileError, match="Failed to clear database"): - local_history.clear() - - -def test_integration_workflow(local_history: LocalHistory, mock_db_session: Mock): - """Integration test for full local history workflow.""" - # Setup mock responses - mock_db_session.query.return_value.join.return_value.filter.return_value.order_by.return_value.all.return_value = [] - - # Test read (empty) - assert local_history.read() == [] - - # Test write - local_history.write("test query", "test response") - assert ( - mock_db_session.add.call_count == 2 - ) # One for InteractionModel, one for HistoryModel - - # Test clear - local_history.clear() - mock_db_session.query.return_value.update.assert_called_once() + local_history.clear("bb5b3a3e-d2a7-11ef-a682-52b437312584i9090") # type: ignore diff --git a/tests/history/test_manager.py b/tests/history/test_manager.py index 26598c8..a81bb26 100644 --- a/tests/history/test_manager.py +++ b/tests/history/test_manager.py @@ -12,25 +12,25 @@ def __init__(self, config): self.write_called = False self.clear_called = False - def read(self): + def read(self, user_id): self.read_called = True return [] - def write(self, query: str, response: str) -> None: + def write(self, user_id, query: str, response: str) -> None: self.write_called = True - def clear(self) -> None: + def clear(self, user_id) -> None: self.clear_called = True @pytest.fixture def history_manager(mock_config): - return HistoryManager(mock_config, plugin=LocalHistory) + return HistoryManager(mock_config, 1000, plugin=LocalHistory) def test_history_manager_initialization(mock_config): """Test that HistoryManager initializes correctly""" - manager = HistoryManager(mock_config) + manager = HistoryManager(mock_config, 1000) assert manager._config == mock_config assert manager._plugin is None assert manager._instance is None @@ -38,7 +38,7 @@ def test_history_manager_initialization(mock_config): def test_history_manager_plugin_setter(mock_config): """Test setting a valid plugin""" - manager = HistoryManager(mock_config) + manager = HistoryManager(mock_config, 1000) manager.plugin = MockHistoryPlugin assert manager._plugin == MockHistoryPlugin assert isinstance(manager._instance, MockHistoryPlugin) @@ -46,7 +46,7 @@ def test_history_manager_plugin_setter(mock_config): def test_history_manager_invalid_plugin(mock_config): """Test setting an invalid plugin""" - manager = HistoryManager(mock_config) + manager = HistoryManager(mock_config, 1000) class InvalidPlugin(BaseHistoryPlugin): pass @@ -57,21 +57,21 @@ class InvalidPlugin(BaseHistoryPlugin): def test_history_manager_read_without_plugin(mock_config): """Test reading history without setting a plugin first""" - manager = HistoryManager(mock_config) + manager = HistoryManager(mock_config, 1000) with pytest.raises(RuntimeError): manager.read() def test_history_manager_write_without_plugin(mock_config): """Test writing history without setting a plugin first""" - manager = HistoryManager(mock_config) + manager = HistoryManager(mock_config, 1000) with pytest.raises(RuntimeError): manager.write("test query", "test response") def test_history_manager_clear_without_plugin(mock_config): """Test clearing history without setting a plugin first""" - manager = HistoryManager(mock_config) + manager = HistoryManager(mock_config, 1000) with pytest.raises(RuntimeError): manager.clear() diff --git a/tests/test_logger.py b/tests/test_logger.py index 621db74..66aa663 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -7,6 +7,7 @@ from command_line_assistant.logger import ( AuditFormatter, _create_audit_formatter, + _should_log_for_user, setup_logging, ) @@ -79,7 +80,6 @@ def test_setup_logging(mock_dict_config, mock_config): def test_audit_formatter_user_specific_logging(mock_config): """Test user-specific logging configuration.""" # Configure mock for user-specific settings - mock_config.logging.users = {"testuser": {"question": True, "responses": False}} formatter = AuditFormatter(config=mock_config) record = logging.LogRecord( @@ -91,7 +91,7 @@ def test_audit_formatter_user_specific_logging(mock_config): args=(), exc_info=None, ) - record.user = "testuser" + record.user = "1000" record.query = "test query" record.response = "test response" @@ -108,3 +108,23 @@ def test_setup_logging_invalid_level(mock_config): mock_config.logging.level = "INVALID_LEVEL" with pytest.raises(ValueError): setup_logging(mock_config) + + +@pytest.mark.parametrize( + ("users", "effective_user_id", "log_type", "expected"), + ( + ({"1000": {"response": True, "question": False}}, 1000, "response", True), + ({"1000": {"response": False, "question": False}}, 1000, "response", False), + ({"1000": {"response": False, "question": True}}, 1000, "question", True), + ({"1000": {"response": False, "question": False}}, 1000, "question", False), + # User is defined in the config, but nothing is specified + ({"1000": {}}, 1000, "question", False), + ({"1000": {}}, 1000, "response", False), + ({"1000": {}, "1001": {"response": True}}, 1001, "response", True), + ({"1000": {}, "1001": {"question": True}}, 1001, "question", True), + ({"1000": {}, "1001": {}}, 1001, "question", False), + ), +) +def test_should_log_for_user(users, effective_user_id, log_type, expected, mock_config): + mock_config.logging.users = users + assert _should_log_for_user(effective_user_id, mock_config, log_type) == expected