Skip to content

Commit

Permalink
Add user_id to history tables. (#112)
Browse files Browse the repository at this point in the history
The user_id is a special way for us to control which record is tied to
which user. This will allow us to be prepared for a multi-user
environment.
  • Loading branch information
r0x0d authored Jan 15, 2025
1 parent 7ad8540 commit 012e642
Show file tree
Hide file tree
Showing 24 changed files with 329 additions and 436 deletions.
13 changes: 8 additions & 5 deletions command_line_assistant/commands/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def run(self) -> int:
Returns:
int: Status code of the execution.
"""

try:
if self._clear:
self._clear_history()
Expand All @@ -89,7 +90,7 @@ def run(self) -> int:
def _retrieve_all_conversations(self) -> None:
"""Retrieve and display all conversations from history."""
self._text_renderer.render("Getting all conversations from history.")
response = self._proxy.GetHistory()
response = self._proxy.GetHistory(self._context.effective_user_id)
history = HistoryEntry.from_structure(response)

# Display the conversation
Expand All @@ -98,7 +99,7 @@ def _retrieve_all_conversations(self) -> None:
def _retrieve_first_conversation(self) -> None:
"""Retrieve the first conversation in the conversation cache."""
self._text_renderer.render("Getting first conversation from history.")
response = self._proxy.GetFirstConversation()
response = self._proxy.GetFirstConversation(self._context.effective_user_id)
history = HistoryEntry.from_structure(response)

# Display the conversation
Expand All @@ -111,7 +112,9 @@ def _retrieve_conversation_filtered(self, filter: str) -> None:
filter (str): Keyword to filter in the user history
"""
self._text_renderer.render("Filtering conversation history.")
response = self._proxy.GetFilteredConversation(filter)
response = self._proxy.GetFilteredConversation(
self._context.effective_user_id, filter
)

# Handle and display the response
history = HistoryEntry.from_structure(response)
Expand All @@ -122,7 +125,7 @@ def _retrieve_conversation_filtered(self, filter: str) -> None:
def _retrieve_last_conversation(self) -> None:
"""Retrieve the last conversation in the conversation cache."""
self._text_renderer.render("Getting last conversation from history.")
response = self._proxy.GetLastConversation()
response = self._proxy.GetLastConversation(self._context.effective_user_id)

# Handle and display the response
history = HistoryEntry.from_structure(response)
Expand All @@ -133,7 +136,7 @@ def _retrieve_last_conversation(self) -> None:
def _clear_history(self) -> None:
"""Clear the user history"""
self._text_renderer.render("Cleaning the history.")
self._proxy.ClearHistory()
self._proxy.ClearHistory(self._context.effective_user_id)

def _show_history(self, entries: list[HistoryItem]) -> None:
"""Internal method to show the history in a standarized way
Expand Down
17 changes: 8 additions & 9 deletions command_line_assistant/commands/query.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Module to handle the query command."""

import argparse
import getpass
from argparse import Namespace
from io import TextIOWrapper
from typing import Optional
Expand Down Expand Up @@ -77,6 +76,9 @@ def __init__(
)
self._error_renderer: TextRenderer = create_error_renderer()
self._warning_renderer: TextRenderer = create_warning_renderer()

self._proxy = QUERY_IDENTIFIER.get_proxy()

super().__init__()

def _get_input_source(self) -> str:
Expand Down Expand Up @@ -142,24 +144,21 @@ def run(self) -> int:
Returns:
int: Status code of the execution
"""
proxy = QUERY_IDENTIFIER.get_proxy()

try:
query = self._get_input_source()
question = self._get_input_source()
except ValueError as e:
self._error_renderer.render(str(e))
return 1

input_query = Message()
input_query.message = query
# Get the current user
input_query.user = getpass.getuser()
output = "Nothing to see here..."

try:
with self._spinner_renderer:
proxy.ProcessQuery(input_query.to_structure(input_query))
output = Message.from_structure(proxy.RetrieveAnswer()).message
response = self._proxy.AskQuestion(
self._context.effective_user_id, question
)
output = Message.from_structure(response).message
except (
RequestFailedError,
MissingHistoryFileError,
Expand Down
29 changes: 13 additions & 16 deletions command_line_assistant/config/schemas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Module to hold the config schema and it's sub schemas."""

import copy
import dataclasses
import pwd
from pathlib import Path
from typing import Optional, Union

Expand Down Expand Up @@ -96,22 +98,17 @@ def __post_init__(self) -> None:

self.level = self.level.upper()

def should_log_for_user(self, username: str, log_type: str) -> bool:
"""Check if logging should be enabled for a specific user and log type.
Args:
username (str): The username to check
log_type (str): The type of log ('responses' or 'question')
Returns:
bool: Whether logging should be enabled for this user and log type
"""
# If user has specific settings, use those
if username in self.users:
return self.users[username].get(log_type, False)

# Otherwise fall back to global settings
return getattr(self, log_type, False)
if self.users:
# Turn any username to their effective_user_id
defined_users = copy.deepcopy(self.users)
for user in defined_users.keys():
try:
effective_user_id = str(pwd.getpwnam(user).pw_uid)
self.users[effective_user_id] = self.users.pop(user)
except KeyError as e:
raise ValueError(
f"{user} is not present on the system. Remove it from the configuration."
) from e


@dataclasses.dataclass
Expand Down
2 changes: 1 addition & 1 deletion command_line_assistant/daemon/database/models/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class HistoryModel(BaseModel):
__tablename__ = "history"

id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(UUID(as_uuid=True), nullable=False)
timestamp = Column(DateTime, default=datetime.utcnow())
deleted_at = Column(DateTime, nullable=True)

Expand All @@ -37,7 +38,6 @@ class InteractionModel(BaseModel):
response_text = Column(String)
response_role = Column(String, default="assistant")
response_tokens = Column(Integer, default=0)
session_id = Column(UUID(as_uuid=True), nullable=False, default=uuid.uuid4)
os_distribution = Column(String, default="RHEL")
os_version = Column(String, nullable=False)
os_arch = Column(String, nullable=False)
20 changes: 10 additions & 10 deletions command_line_assistant/daemon/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import uuid
from pathlib import Path
from typing import Optional
from typing import Optional, Union

#: Path to the machine ID file
MACHINE_ID_PATH: Path = Path("/etc/machine-id")
Expand All @@ -14,15 +14,15 @@
class UserSessionManager:
"""Manage user session information."""

def __init__(self, effective_user: str) -> None:
def __init__(self, effective_user_id: Union[int, str]) -> None:
"""Initialize the session manager.
Args:
effective_user (str): The effective user id
effective_user_id (Union[int, str]): The effective user id
"""
self._effective_user_id: Union[int, str] = effective_user_id
self._machine_uuid: Optional[uuid.UUID] = None
self._effective_user: str = effective_user
self._session_uuid: Optional[uuid.UUID] = None
self._user_id: Optional[uuid.UUID] = None

@property
def machine_id(self) -> uuid.UUID:
Expand Down Expand Up @@ -55,17 +55,17 @@ def machine_id(self) -> uuid.UUID:
return self._machine_uuid

@property
def session_id(self) -> uuid.UUID:
"""Property that generates a unique session ID combining machine and user effective id.
def user_id(self) -> uuid.UUID:
"""Property that generates a unique user ID combining machine and user effective id.
Returns:
uuid.UUID: A unique session identifier
"""
if not self._session_uuid:
if not self._user_id:
# Combine machine ID and effective user to create a unique namespace
namespace = self.machine_id

# Generate a UUID using the effective username as name in the namespace
self._session_uuid = uuid.uuid5(namespace, self._effective_user)
self._user_id = uuid.uuid5(namespace, str(self._effective_user_id))

return self._session_uuid
return self._user_id
53 changes: 2 additions & 51 deletions command_line_assistant/dbus/context.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,10 @@
"""D-Bus context classes for managing the commands"""

from typing import Optional

from dasbus.signal import Signal

from command_line_assistant.config import Config
from command_line_assistant.dbus.structures import Message


class BaseContext:
"""Base class for context that defines the structure of it."""
class DaemonContext:
"""Context class for context that defines the structure of it."""

def __init__(self, config: Config) -> None:
"""Constructor of the class.
Expand All @@ -27,47 +22,3 @@ def config(self) -> Config:
Config: Instance of the configuration class
"""
return self._config


class QueryContext(BaseContext):
"""This is the process context that will handle anything query related"""

def __init__(self, config: Config) -> None:
"""Constructor of the class.
Args:
config (Config): Instance of the configuration class
"""
self._input_query: Optional[Message] = None
self._query_changed = Signal()
super().__init__(config)

@property
def query(self) -> Optional[Message]:
"""Property for the internal query attribute.
Returns:
Optional[Message]: The user query wrapped in a `py:Message` dbus structure.
"""
return self._input_query

def process_query(self, input_query: Message) -> None:
"""Emit the signal that the query has changed.
Args:
input_query (Message): The user query
"""
self._input_query = input_query
self._query_changed.emit()


class HistoryContext(BaseContext):
"""This is the process context that will handle anything query related"""

def __init__(self, config: Config) -> None:
"""Constructor of the class.
Args:
config (Config): Instance of the configuration class.
"""
super().__init__(config)
Loading

0 comments on commit 012e642

Please sign in to comment.