
Mito AI: Restore chat on refresh #1488

Open

wants to merge 19 commits into base: dev

Changes from all commits (19 commits)
185 changes: 140 additions & 45 deletions mito-ai/mito_ai/handlers.py
@@ -1,9 +1,11 @@
import os
import json
import logging
import time
from dataclasses import asdict
from http import HTTPStatus
from typing import Any, Awaitable, Dict, Optional, Literal
from typing import Any, Awaitable, Dict, Optional, List
from threading import Lock

import tornado
import tornado.ioloop
@@ -26,12 +28,77 @@
SmartDebugMessageMetadata,
CodeExplainMessageMetadata,
InlineCompletionMessageMetadata,
HistoryReply
)
from .prompt_builders import remove_inner_thoughts_from_message
from .providers import OpenAIProvider
from .utils.create import initialize_user
from .utils.schema import MITO_FOLDER

__all__ = ["CompletionHandler"]

# Global message history with thread-safe access
Member:
Can we just add some more documentation here about the thread-safe piece. Just explain why we need this.

Member:
Let's also move the GlobalMessageHistory to its own file, just to keep things more isolated.
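
For illustration, a minimal sketch of the kind of documentation being asked for, assuming the lock exists because a single GlobalMessageHistory instance is shared by every handler in the server process (the wording below is hypothetical, not part of this PR):

```python
class GlobalMessageHistory:
    """Chat history shared by all CompletionHandler instances in the server process.

    A single module-level instance is used by every WebSocket connection, so
    appends, truncations, and the disk writes that follow them can interleave
    when several clients talk to the server at once. The threading.Lock
    serializes those mutations so the in-memory lists and the on-disk
    message_history.json stay consistent with each other.
    """
```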

class GlobalMessageHistory:
def __init__(self, save_file: str = os.path.join(MITO_FOLDER, "message_history.json")):
self._lock = Lock()
self._llm_history: List[Dict[str, str]] = []
self._display_history: List[Dict[str, str]] = []
Member (on lines +44 to +45):
Just to keep our language the same across the code base, can we call these _ai_optimized_history and _display_optimized_history

self._save_file = save_file

# Load from disk on startup
self._load_from_disk()

def _load_from_disk(self):
"""Load existing history from disk, if it exists."""
if os.path.exists(self._save_file):
try:
with open(self._save_file, "r", encoding="utf-8") as f:
data = json.load(f)
self._llm_history = data.get("llm_history", [])
self._display_history = data.get("display_history", [])
except Exception as e:
print(f"Error loading history file: {e}")

def _save_to_disk(self):
"""Save current history to disk."""
Member:
Just to make our lives easier when we want to update the format of this file, can we create a field in the json called something like chat_history_version: 1
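
A minimal sketch of what the versioned payload could look like; the field name and value follow the reviewer's suggestion and are not implemented in this PR:

```python
import json

# Hypothetical versioned payload; "chat_history_version" is an assumed field name.
data = {
    "chat_history_version": 1,
    "llm_history": [],       # would be self._llm_history in _save_to_disk
    "display_history": [],   # would be self._display_history in _save_to_disk
}

# On load, older files without the field can default to version 0 and be migrated.
version = data.get("chat_history_version", 0)
print(json.dumps(data, indent=2), version)
```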

data = {
"llm_history": self._llm_history,
"display_history": self._display_history,
}
Member (on lines +64 to +67):
Another option here is to just store the metadata that we send from the frontend to the backend in order to create the prompt in the .json file. Then, we can easily recreate both the llm_history and the display_history from them. There's no reason for the backend to care about the display_messages, for example.

What do you think about refactoring this? Is there an advantage to storing everything in the .json file itself?

Collaborator (Author):
The reason for the backend to care about display_messages is to be able to restore them in the frontend when the extension loads.

I chose the format of what is stored in JSON to match what is stored in memory, so it is easy to serialize/deserialize. That being said, we don't have to store full text representations of either and can recalculate them on the fly from context objects.

As I see it, there is no compelling reason to store that metadata, because the only thing we do with it is calculate LLM prompts and display messages.

Member:
Awesome. Let's just dump all of this into a comment explaining the chat history JSON we're creating so we all have this context going forward.
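
For context, a sketch of the on-disk format this thread is discussing, reconstructed from _save_to_disk in this diff; the concrete message contents are invented for illustration:

```python
# Shape of the message_history.json file saved under MITO_FOLDER after one chat round-trip.
example_history_file = {
    "llm_history": [
        {"role": "user", "content": "<full prompt, including variables and active cell code>"},
        {"role": "assistant", "content": "<raw model response>"},
    ],
    "display_history": [
        {"role": "user", "content": "<what the chat pane shows for the user turn>"},
        {"role": "assistant", "content": "<model response with inner thoughts removed>"},
    ],
}
```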

# Using a temporary file and rename for safer "atomic" writes
tmp_file = f"{self._save_file}.tmp"
try:
with open(tmp_file, "w", encoding="utf-8") as f:
json.dump(data, f, indent=2)
os.replace(tmp_file, self._save_file)
except Exception as e:
# log or handle error
print(f"Error saving history file: {e}")

def get_histories(self) -> tuple[List[Dict[str, str]], List[Dict[str, str]]]:
with self._lock:
return self._llm_history[:], self._display_history[:]

def clear_histories(self) -> None:
with self._lock:
self._llm_history = []
self._display_history = []
self._save_to_disk()

def append_message(self, llm_message: Dict[str, str], display_message: Dict[str, str]) -> None:
with self._lock:
self._llm_history.append(llm_message)
self._display_history.append(display_message)
self._save_to_disk()

def truncate_histories(self, index: int) -> None:
Member:
Two quick things here:

  1. It looks like we are not removing the original message that the user edits. Maybe just an off-by-one bug here? (See the slicing sketch after this comment.)
  2. In the UI, when the user edits the message and presses submit, we are not clearing all future messages until the AI responds. That looks a bit whacky. Can we clear the messages right away? If this isn't a regression introduced by this PR, then feel free to just open an issue and we can fix it independently.

Attachment: Screen.Recording.2025-01-28.at.5.11.05.PM.mov
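
For reference, a small sketch of the list-slicing semantics the off-by-one question turns on; this is plain Python, independent of the handler code, and the index convention shown is an assumption:

```python
history = ["m0", "m1", "m2", "m3"]
index = 2                 # one possible convention: index of the message being edited

history[:index]           # ["m0", "m1"] - drops the edited message and everything after it
history[:index + 1]       # ["m0", "m1", "m2"] - keeps the edited message, drops what follows
```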

with self._lock:
self._llm_history = self._llm_history[:index]
self._display_history = self._display_history[:index]
self._save_to_disk()

# Global history instance
message_history = GlobalMessageHistory()

# This handler is responsible for the mito-ai/completions endpoint.
# It takes a message from the user, sends it to the OpenAI API, and returns the response.
@@ -44,7 +111,6 @@ def initialize(self, llm: OpenAIProvider) -> None:
super().initialize()
self.log.debug("Initializing websocket connection %s", self.request.path)
self._llm = llm
self.full_message_history = []

@property
def log(self) -> logging.Logger:
@@ -92,9 +158,6 @@ def on_close(self) -> None:
# Stop observing the provider error
self._llm.unobserve(self._send_error, "last_error")

# Clear the message history
self.full_message_history = []

async def on_message(self, message: str) -> None:
"""Handle incoming messages on the WebSocket.

@@ -114,43 +177,57 @@ async def on_message(self, message: str) -> None:

# Clear history if the type is "clear_history"
if type == "clear_history":
self.full_message_history = []
message_history.clear_histories()
return

if type == "fetch_history":
Member:
I think we need to update the type to include fetch_history. Eventually we should do this #1461 to make sure we're catching these things.
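
A sketch of what updating the type could look like, assuming the incoming message type is modelled as a Literal union; the alias name and member list below are assumptions, not the project's actual definition:

```python
from typing import Literal

# Hypothetical alias; the real definition may live elsewhere or use different members.
IncomingMessageType = Literal[
    "chat",
    "inline_completion",
    "codeExplain",
    "smartDebug",
    "clear_history",
    "fetch_history",
]
```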

_, display_history = message_history.get_histories()
reply = HistoryReply(
parent_id=parsed_message.get('message_id'),
items=display_history
)

self.reply(reply)
return

messages = []

# Generate new message based on message type
if type == "inline_completion":
prompt = InlineCompletionMessageMetadata(**metadata_dict).prompt
elif type == "chat":
metadata = ChatMessageMetadata(**metadata_dict)
prompt = metadata.prompt

if metadata.index is not None:
# Clear the chat history after the specified index (inclusive)
self.full_message_history = self.full_message_history[:metadata.index]

elif type == "codeExplain":
prompt = CodeExplainMessageMetadata(**metadata_dict).prompt
elif type == "smartDebug":
prompt = SmartDebugMessageMetadata(**metadata_dict).prompt

new_message = {
"role": "user",
"content": prompt
}

# Inline completion uses its own websocket
# so we can reuse the full_message_history variable
if type == "inline_completion":
self.full_message_history = [new_message]
new_llm_message = {"role": "user", "content": prompt}
llm_history = [new_llm_message] # Inline completion uses its own history
else:
self.full_message_history.append(new_message)
llm_history, display_history = message_history.get_histories()

if type == "chat":
metadata = ChatMessageMetadata(**metadata_dict)
if metadata.index is not None:
message_history.truncate_histories(metadata.index)
llm_history, display_history = message_history.get_histories()
prompt = metadata.prompt
display_message = metadata.display_message
elif type == "codeExplain":
metadata = CodeExplainMessageMetadata(**metadata_dict)
prompt = metadata.prompt
display_message = metadata.display_message
elif type == "smartDebug":
metadata = SmartDebugMessageMetadata(**metadata_dict)
prompt = metadata.prompt
display_message = metadata.display_message

new_llm_message = {"role": "user", "content": prompt}
new_display_message = {"role": "user", "content": display_message}
message_history.append_message(new_llm_message, new_display_message)
llm_history, display_history = message_history.get_histories()

self.log.info(f"LLM message history: {json.dumps(llm_history, indent=2)}")
self.log.info(f"Display message history: {json.dumps(display_history, indent=2)}")

request = CompletionRequest(
type=type,
message_id=parsed_message.get('message_id'),
messages=self.full_message_history,
messages=llm_history,
stream=parsed_message.get('stream', False)
)

@@ -220,27 +297,45 @@ async def _handle_request(self, request: CompletionRequest, prompt_type: str) ->
"""
start = time.time()
reply = await self._llm.request_completions(request, prompt_type)
self.reply(reply)

# Save to the message history
# Inline completion is ephemeral and does not need to be saved
if request.type != "inline_completion":
self.full_message_history.append(
{
"role": "assistant",
"content": reply.items[0].content
}
)
response = reply.items[0].content if reply.items else ""

llm_message = {
"role": "assistant",
"content": response
}
display_message = {
"role": "assistant",
"content": response
}

if request.type == "smartDebug":
# Remove inner thoughts from the response
response = remove_inner_thoughts_from_message(response)
display_message["content"] = response

# Modify the reply so the display message in the frontend also has inner thoughts removed
reply.items[0] = CompletionItem(content=response, isIncomplete=reply.items[0].isIncomplete)

message_history.append_message(llm_message, display_message)


self.reply(reply)

latency_ms = round((time.time() - start) * 1000)
self.log.info(f"Completion handler resolved in {latency_ms} ms.")

self.log.info(f"LLM message history: {json.dumps(message_history.get_histories()[0], indent=2)}")
self.log.info(f"Display message history: {json.dumps(message_history.get_histories()[1], indent=2)}")
async def _handle_stream_request(self, request: CompletionRequest, prompt_type: str) -> None:
"""Handle stream completion request."""
start = time.time()

# Use a string buffer to accumulate the full response from streaming chunks.
# We need to accumulate the response on the backend so that we can save it to
# the full_message_history
# the message history after the streaming is complete.
accumulated_response = ""
async for reply in self._llm.stream_completions(request, prompt_type):
if isinstance(reply, CompletionStreamChunk):
Expand All @@ -249,12 +344,11 @@ async def _handle_stream_request(self, request: CompletionRequest, prompt_type:
self.reply(reply)

if request.type != "inline_completion":
self.full_message_history.append(
{
"role": "assistant",
"content": reply.items[0].content
}
)
message = {
"role": "assistant",
"content": accumulated_response
}
message_history.append_message(message, message)
latency_ms = round((time.time() - start) * 1000)
self.log.info(f"Completion streaming completed in {latency_ms} ms.")

@@ -266,6 +360,7 @@ def reply(self, reply: Any) -> None:
It must be a dataclass instance.
"""
message = asdict(reply)
self.log.info(f"Replying with: {json.dumps(message)}")
super().write_message(message)

def _send_error(self, change: Dict[str, Optional[CompletionError]]) -> None:
40 changes: 40 additions & 0 deletions mito-ai/mito_ai/models.py
@@ -37,6 +37,16 @@ class ChatMessageMetadata:
@property
def prompt(self) -> str:
return create_chat_prompt(self.variables or [], self.activeCellCode or '', self.input or '')

@property
def display_message(self) -> str:
cell_code_block = f"""```python
{self.activeCellCode}
```

"""

return f"{cell_code_block if self.activeCellCode else ''}{self.input}"

@dataclass(frozen=True)
class SmartDebugMessageMetadata:
@@ -47,6 +57,15 @@ class SmartDebugMessageMetadata:
@property
def prompt(self) -> str:
return create_error_prompt(self.errorMessage or '', self.activeCellCode or '', self.variables or [])

@property
def display_message(self) -> str:
cell_code_block = f"""```python
{self.activeCellCode}
```

"""
return f"{cell_code_block if self.activeCellCode else ''}{self.errorMessage}"

@dataclass(frozen=True)
class CodeExplainMessageMetadata:
@@ -56,6 +75,16 @@ class CodeExplainMessageMetadata:
@property
def prompt(self) -> str:
return create_explain_code_prompt(self.activeCellCode or '')

@property
def display_message(self) -> str:
cell_code_block = f"""```python
{self.activeCellCode}
```

"""

return f"{cell_code_block if self.activeCellCode else ''}Explain this code"

@dataclass(frozen=True)
class InlineCompletionMessageMetadata:
@@ -169,3 +198,14 @@ class CompletionStreamChunk:
"""Message type."""
error: Optional[CompletionError] = None
"""Completion error."""

@dataclass(frozen=True)
class HistoryReply:
Member:
Can we just update this to FetchHistoryReply to make it clear that this matches specifically with that type of request and not all events that affect the history.

"""Message sent from model to client with the chat history."""

parent_id: str
"""Message UID."""
items: List[ChatCompletionMessageParam]
"""List of chat messages."""
type: Literal["reply"] = "reply"
"""Message type."""
15 changes: 14 additions & 1 deletion mito-ai/mito_ai/prompt_builders.py
@@ -378,4 +378,17 @@ def parse_date(date_str):
SOLUTION:
"""
return prompt
return prompt

def remove_inner_thoughts_from_message(message: str) -> str:
# The smart debug prompt thinks to itself before returning the solution. We don't need to save the inner thoughts.
# We remove them before saving the message in the chat history
if message == "":
return message

SOLUTION_STRING = "SOLUTION:"

if SOLUTION_STRING in message:
message = message.split(SOLUTION_STRING)[1].strip()

return message
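
A quick illustration of the intended behaviour of this helper, given the definition above; the inputs are invented:

```python
remove_inner_thoughts_from_message(
    "Let me reason about the traceback first...\nSOLUTION:\ndf = df.dropna()"
)
# -> "df = df.dropna()"

remove_inner_thoughts_from_message("No solution marker in this message")
# -> returned unchanged
```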