Mito AI: Restore chat on refresh #1488
base: dev
Changes from all commits
841234d
ea1e976
20bf421
887f6bc
2075fed
d1424e6
6e80426
4d3dbcc
0d4ffb6
303f276
371a217
06ac4f2
d9fd663
c12f112
fc0fd9d
1c50027
cc9e1f9
bfe35c7
8400c46
@@ -1,9 +1,11 @@
import os
import json
import logging
import time
from dataclasses import asdict
from http import HTTPStatus
from typing import Any, Awaitable, Dict, Optional, Literal
from typing import Any, Awaitable, Dict, Optional, List
from threading import Lock

import tornado
import tornado.ioloop
@@ -26,12 +28,77 @@
    SmartDebugMessageMetadata,
    CodeExplainMessageMetadata,
    InlineCompletionMessageMetadata,
    HistoryReply
)
from .prompt_builders import remove_inner_thoughts_from_message
from .providers import OpenAIProvider
from .utils.create import initialize_user
from .utils.schema import MITO_FOLDER

__all__ = ["CompletionHandler"]

# Global message history with thread-safe access
class GlobalMessageHistory:
    def __init__(self, save_file: str = os.path.join(MITO_FOLDER, "message_history.json")):
        self._lock = Lock()
        self._llm_history: List[Dict[str, str]] = []
        self._display_history: List[Dict[str, str]] = []
Review comment on lines +44 to +45: Just to keep our language the same across the code base, can we call these
        self._save_file = save_file

        # Load from disk on startup
        self._load_from_disk()

    def _load_from_disk(self):
        """Load existing history from disk, if it exists."""
        if os.path.exists(self._save_file):
            try:
                with open(self._save_file, "r", encoding="utf-8") as f:
                    data = json.load(f)
                    self._llm_history = data.get("llm_history", [])
                    self._display_history = data.get("display_history", [])
            except Exception as e:
                print(f"Error loading history file: {e}")

    def _save_to_disk(self):
        """Save current history to disk."""
Review comment: Just to make our lives easier when we want to update the format of this file, can we create a field in the json called something like
        data = {
            "llm_history": self._llm_history,
            "display_history": self._display_history,
        }
Review thread on lines +64 to +67:
- Another option here is to just store the metadata that we send from the frontend to the backend in order to create the prompt in the .json file. Then, we can easily recreate both the
  What do you think about refactoring this? Is there an advantage to storing everything in the .json file itself?
- The reason for the backend to care about
  I chose the format of what is stored in JSON to match what is stored in memory, so it is easy to serialize/deserialize. That being said, we don't have to store full text representations of either and can recalculate on the fly from context objects. As I see it, there is no compelling reason to store that metadata, because the only thing we do with it is to calculate LLM prompts and display messages.
- Awesome. Let's just dump all of this into a comment explaining the chat history json we're creating so we all have this context going forward.

        # Using a temporary file and rename for safer "atomic" writes
        tmp_file = f"{self._save_file}.tmp"
        try:
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_file, self._save_file)
        except Exception as e:
            # log or handle error
            print(f"Error saving history file: {e}")
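Following the two suggestions above (a version field plus a comment documenting the format), here is a minimal sketch of what versioned save/load helpers could look like. The field name "version", the constant, and the function names are assumptions for illustration, not part of this PR:

```python
import json
import os

# Hypothetical sketch: version the on-disk chat history format so future
# schema changes can be detected and migrated. Names are illustrative.
HISTORY_FORMAT_VERSION = 1

def save_history(save_file: str, llm_history: list, display_history: list) -> None:
    """Write the chat history JSON.

    The file is a JSON object with:
      - "version":         integer format version, bumped on schema changes
      - "llm_history":     the full prompts/responses sent to the LLM
      - "display_history": the shorter messages rendered in the chat UI
    """
    data = {
        "version": HISTORY_FORMAT_VERSION,
        "llm_history": llm_history,
        "display_history": display_history,
    }
    tmp_file = f"{save_file}.tmp"
    with open(tmp_file, "w", encoding="utf-8") as f:
        json.dump(data, f, indent=2)
    os.replace(tmp_file, save_file)  # rename over the old file for an "atomic" write

def load_history(save_file: str) -> tuple:
    """Read the chat history JSON, ignoring files with an unknown version."""
    if not os.path.exists(save_file):
        return [], []
    with open(save_file, "r", encoding="utf-8") as f:
        data = json.load(f)
    if data.get("version") != HISTORY_FORMAT_VERSION:
        return [], []  # or migrate older formats here
    return data.get("llm_history", []), data.get("display_history", [])
```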

    def get_histories(self) -> tuple[List[Dict[str, str]], List[Dict[str, str]]]:
        with self._lock:
            return self._llm_history[:], self._display_history[:]

    def clear_histories(self) -> None:
        with self._lock:
            self._llm_history = []
            self._display_history = []
            self._save_to_disk()

    def append_message(self, llm_message: Dict[str, str], display_message: Dict[str, str]) -> None:
        with self._lock:
            self._llm_history.append(llm_message)
            self._display_history.append(display_message)
            self._save_to_disk()

    def truncate_histories(self, index: int) -> None:
Review comment: Two quick things here:
(attachment: Screen.Recording.2025-01-28.at.5.11.05.PM.mov)
        with self._lock:
            self._llm_history = self._llm_history[:index]
            self._display_history = self._display_history[:index]
            self._save_to_disk()


# Global history instance
message_history = GlobalMessageHistory()

# This handler is responsible for the mito-ai/completions endpoint.
# It takes a message from the user, sends it to the OpenAI API, and returns the response.

@@ -44,7 +111,6 @@ def initialize(self, llm: OpenAIProvider) -> None:
        super().initialize()
        self.log.debug("Initializing websocket connection %s", self.request.path)
        self._llm = llm
        self.full_message_history = []

    @property
    def log(self) -> logging.Logger:

@@ -92,9 +158,6 @@ def on_close(self) -> None:
        # Stop observing the provider error
        self._llm.unobserve(self._send_error, "last_error")

        # Clear the message history
        self.full_message_history = []

    async def on_message(self, message: str) -> None:
        """Handle incoming messages on the WebSocket.
@@ -114,43 +177,57 @@ async def on_message(self, message: str) -> None:

        # Clear history if the type is "clear_history"
        if type == "clear_history":
            self.full_message_history = []
            message_history.clear_histories()
            return
        if type == "fetch_history":
Review comment: I think we need to update the
            _, display_history = message_history.get_histories()
            reply = HistoryReply(
                parent_id=parsed_message.get('message_id'),
                items=display_history
            )

            self.reply(reply)
            return
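To make the protocol concrete, here is a hedged sketch of the fetch_history round trip implied by the handler above. The request shape is inferred from the parsed_message.get(...) calls; all values, and any fields beyond type and message_id, are made up:

```python
# Hypothetical example of the fetch_history round trip.
fetch_request = {
    "type": "fetch_history",
    "message_id": "abc-123",
}

# The handler answers with a HistoryReply serialized via dataclasses.asdict,
# which on the wire looks roughly like:
history_reply = {
    "parent_id": "abc-123",
    "items": [
        {"role": "user", "content": "Explain this code"},
        {"role": "assistant", "content": "This cell loads a CSV into a DataFrame."},
    ],
    "type": "reply",
}
```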

        messages = []

        # Generate new message based on message type
        if type == "inline_completion":
            prompt = InlineCompletionMessageMetadata(**metadata_dict).prompt
        elif type == "chat":
            metadata = ChatMessageMetadata(**metadata_dict)
            prompt = metadata.prompt

            if metadata.index is not None:
                # Clear the chat history after the specified index (inclusive)
                self.full_message_history = self.full_message_history[:metadata.index]

        elif type == "codeExplain":
            prompt = CodeExplainMessageMetadata(**metadata_dict).prompt
        elif type == "smartDebug":
            prompt = SmartDebugMessageMetadata(**metadata_dict).prompt

        new_message = {
            "role": "user",
            "content": prompt
        }

        # Inline completion uses its own websocket
        # so we can reuse the full_message_history variable
        if type == "inline_completion":
            self.full_message_history = [new_message]
            new_llm_message = {"role": "user", "content": prompt}
            llm_history = [new_llm_message]  # Inline completion uses its own history
        else:
            self.full_message_history.append(new_message)
            llm_history, display_history = message_history.get_histories()

            if type == "chat":
                metadata = ChatMessageMetadata(**metadata_dict)
                if metadata.index is not None:
                    message_history.truncate_histories(metadata.index)
                    llm_history, display_history = message_history.get_histories()
                prompt = metadata.prompt
                display_message = metadata.display_message
            elif type == "codeExplain":
                metadata = CodeExplainMessageMetadata(**metadata_dict)
                prompt = metadata.prompt
                display_message = metadata.display_message
            elif type == "smartDebug":
                metadata = SmartDebugMessageMetadata(**metadata_dict)
                prompt = metadata.prompt
                display_message = metadata.display_message

            new_llm_message = {"role": "user", "content": prompt}
            new_display_message = {"role": "user", "content": display_message}
            message_history.append_message(new_llm_message, new_display_message)
            llm_history, display_history = message_history.get_histories()

        self.log.info(f"LLM message history: {json.dumps(llm_history, indent=2)}")
        self.log.info(f"Display message history: {json.dumps(display_history, indent=2)}")

        request = CompletionRequest(
            type=type,
            message_id=parsed_message.get('message_id'),
            messages=self.full_message_history,
            messages=llm_history,
            stream=parsed_message.get('stream', False)
        )

@@ -220,27 +297,45 @@ async def _handle_request(self, request: CompletionRequest, prompt_type: str) -> None:
        """
        start = time.time()
        reply = await self._llm.request_completions(request, prompt_type)
        self.reply(reply)

        # Save to the message history
        # Inline completion is ephemeral and does not need to be saved
        if request.type != "inline_completion":
            self.full_message_history.append(
                {
                    "role": "assistant",
                    "content": reply.items[0].content
                }
            )
            response = reply.items[0].content if reply.items else ""

            llm_message = {
                "role": "assistant",
                "content": response
            }
            display_message = {
                "role": "assistant",
                "content": response
            }

            if request.type == "smartDebug":
                # Remove inner thoughts from the response
                response = remove_inner_thoughts_from_message(response)
                display_message["content"] = response

                # Modify reply so the display message in the frontend also has inner thoughts removed
                reply.items[0] = CompletionItem(content=response, isIncomplete=reply.items[0].isIncomplete)

            message_history.append_message(llm_message, display_message)

        self.reply(reply)

        latency_ms = round((time.time() - start) * 1000)
        self.log.info(f"Completion handler resolved in {latency_ms} ms.")

        self.log.info(f"LLM message history: {json.dumps(message_history.get_histories()[0], indent=2)}")
        self.log.info(f"Display message history: {json.dumps(message_history.get_histories()[1], indent=2)}")
    async def _handle_stream_request(self, request: CompletionRequest, prompt_type: str) -> None:
        """Handle stream completion request."""
        start = time.time()

        # Use a string buffer to accumulate the full response from streaming chunks.
        # We need to accumulate the response on the backend so that we can save it to
        # the full_message_history
        # the message history after the streaming is complete.
        accumulated_response = ""
        async for reply in self._llm.stream_completions(request, prompt_type):
            if isinstance(reply, CompletionStreamChunk):
@@ -249,12 +344,11 @@ async def _handle_stream_request(self, request: CompletionRequest, prompt_type: str) -> None:
            self.reply(reply)

        if request.type != "inline_completion":
            self.full_message_history.append(
                {
                    "role": "assistant",
                    "content": reply.items[0].content
                }
            )
            message = {
                "role": "assistant",
                "content": accumulated_response
            }
            message_history.append_message(message, message)
        latency_ms = round((time.time() - start) * 1000)
        self.log.info(f"Completion streaming completed in {latency_ms} ms.")
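The diff does not show where accumulated_response is built up inside the loop, so here is a standalone, hedged illustration of the accumulation pattern the comment above describes (the names and the fake stream are invented; only the pattern of concatenating streamed pieces and persisting the full message at the end reflects the PR):

```python
import asyncio
from typing import AsyncIterator

# Standalone illustration: gather the streamed pieces into one string,
# then persist that single complete message once streaming finishes.
async def consume_stream(chunks: AsyncIterator[str]) -> str:
    accumulated_response = ""
    async for piece in chunks:
        accumulated_response += piece  # forward each piece to the client here
    return accumulated_response        # save this one message to the history

async def _demo() -> None:
    async def fake_stream() -> AsyncIterator[str]:
        for piece in ["Hel", "lo ", "world"]:
            yield piece

    print(await consume_stream(fake_stream()))  # -> "Hello world"

asyncio.run(_demo())
```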

@@ -266,6 +360,7 @@ def reply(self, reply: Any) -> None:
        It must be a dataclass instance.
        """
        message = asdict(reply)
        self.log.info(f"Replying with: {json.dumps(message)}")
        super().write_message(message)

    def _send_error(self, change: Dict[str, Optional[CompletionError]]) -> None:
@@ -37,6 +37,16 @@ class ChatMessageMetadata:
    @property
    def prompt(self) -> str:
        return create_chat_prompt(self.variables or [], self.activeCellCode or '', self.input or '')

    @property
    def display_message(self) -> str:
        cell_code_block = f"""```python
{self.activeCellCode}
```

"""

        return f"{cell_code_block if self.activeCellCode else ''}{self.input}"
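For illustration, a hedged sketch of the string ChatMessageMetadata.display_message evaluates to; the cell code and user input values below are invented:

```python
# Hypothetical values, purely to illustrate the display_message format above.
active_cell_code = "df = pd.read_csv('data.csv')"
user_input = "Filter to rows where x > 5"

# display_message wraps the active cell code in a fenced python block (when
# present) and then appends the user's chat input:
display_message = (
    ("```python\n" + active_cell_code + "\n```\n\n") if active_cell_code else ""
) + user_input
```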

@dataclass(frozen=True)
class SmartDebugMessageMetadata:

@@ -47,6 +57,15 @@ class SmartDebugMessageMetadata:
    @property
    def prompt(self) -> str:
        return create_error_prompt(self.errorMessage or '', self.activeCellCode or '', self.variables or [])

    @property
    def display_message(self) -> str:
        cell_code_block = f"""```python
{self.activeCellCode}
```

"""
        return f"{cell_code_block if self.activeCellCode else ''}{self.errorMessage}"

@dataclass(frozen=True)
class CodeExplainMessageMetadata:

@@ -56,6 +75,16 @@ class CodeExplainMessageMetadata:
    @property
    def prompt(self) -> str:
        return create_explain_code_prompt(self.activeCellCode or '')

    @property
    def display_message(self) -> str:
        cell_code_block = f"""```python
{self.activeCellCode}
```

"""

        return f"{cell_code_block if self.activeCellCode else ''}Explain this code"

@dataclass(frozen=True)
class InlineCompletionMessageMetadata:

@@ -169,3 +198,14 @@ class CompletionStreamChunk:
    """Message type."""
    error: Optional[CompletionError] = None
    """Completion error."""

@dataclass(frozen=True)
class HistoryReply:
Review comment: Can we just update this to

    """Message sent from model to client with the chat history."""

    parent_id: str
    """Message UID."""
    items: List[ChatCompletionMessageParam]
    """List of chat messages."""
    type: Literal["reply"] = "reply"
    """Message type."""
Review comment: Can we just add some more documentation here about the thread-safe piece. Just explain why we need this.

Review comment: Let's also move the GlobalMessageHistory to its own file, just to keep things more isolated.
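Following those two comments, a hedged sketch of what a standalone module could look like; the module name message_history.py and the docstring wording are assumptions, not part of this PR:

```python
# message_history.py (hypothetical module name) -- sketch only.
"""Shared, persistent chat history for the mito-ai completion handler.

Why the lock: a single GlobalMessageHistory instance is shared by every
websocket connection the server handles. Reads and writes of the two
in-memory lists, plus the save-to-disk step, are done as one unit so
concurrent requests (or any background threads) cannot interleave and
leave the lists or the JSON file on disk in an inconsistent state.
"""
import json
import os
from threading import Lock
from typing import Dict, List


class GlobalMessageHistory:
    def __init__(self, save_file: str) -> None:
        self._lock = Lock()
        self._llm_history: List[Dict[str, str]] = []
        self._display_history: List[Dict[str, str]] = []
        self._save_file = save_file

    def append_message(self, llm_message: Dict[str, str], display_message: Dict[str, str]) -> None:
        # Append to both histories and persist while holding the lock, so the
        # two lists and the file on disk never drift apart.
        with self._lock:
            self._llm_history.append(llm_message)
            self._display_history.append(display_message)
            self._save_to_disk()

    def _save_to_disk(self) -> None:
        data = {"llm_history": self._llm_history, "display_history": self._display_history}
        tmp_file = f"{self._save_file}.tmp"
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2)
        os.replace(tmp_file, self._save_file)
```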