Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix agent responses for conversation #1668

Merged
merged 3 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions py/core/main/api/templates/log_viewer.html
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
<head>
<title>R2R Log Viewer</title>
<style>
body {
margin: 20px;
body {
margin: 20px;
font-family: monospace;
background: #f8f9fa;
}
#logs {
white-space: pre-wrap;
#logs {
white-space: pre-wrap;
background: white;
padding: 20px;
border-radius: 4px;
Expand All @@ -31,38 +31,38 @@
<body>
<h2>R2R Log Viewer</h2>
<div id="logs"><span class="status">Connecting to log stream...</span></div>

<!-- Include ansi_up via a CDN -->
<script src="https://cdn.jsdelivr.net/npm/ansi_up@5.0.0/ansi_up.min.js"></script>
<script>
let ws = null;
let ansi_up = new AnsiUp();

function connect() {
if (ws) {
ws.close();
}

ws = new WebSocket(`ws://${window.location.host}/v3/logs/stream`);

ws.onmessage = function(event) {
    // Each WebSocket message becomes one <div class="log-entry"> appended
    // to the #logs container.
    const container = document.getElementById("logs");

    // Convert the ANSI-coded payload to HTML before inserting it.
    // NOTE(review): the result is assigned to innerHTML, so this relies on
    // ansi_up escaping any markup contained in the log text — confirm.
    const rendered = ansi_up.ansi_to_html(event.data);

    const entry = document.createElement('div');
    entry.className = 'log-entry';
    entry.innerHTML = rendered;
    container.appendChild(entry);

    // Cap the view at the most recent 1000 child elements by dropping the
    // oldest nodes from the front.
    while (container.children.length > 1000) {
        container.removeChild(container.firstChild);
    }

    // Keep the newest entry in view.
    container.scrollTop = container.scrollHeight;
};

ws.onclose = function() {
const logsDiv = document.getElementById("logs");
const msg = document.createElement('div');
Expand All @@ -71,14 +71,14 @@ <h2>R2R Log Viewer</h2>
logsDiv.appendChild(msg);
setTimeout(connect, 1000);
};

// Report WebSocket errors to the browser console; this handler itself does
// nothing else with the error.
ws.onerror = function(errorEvent) {
    console.error('WebSocket error:', errorEvent);
};
}

connect();

window.onbeforeunload = function() {
if (ws) {
ws.close();
Expand Down
142 changes: 65 additions & 77 deletions py/core/main/services/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from core.providers.logger.r2r_logger import SqlitePersistentLoggingProvider
from core.telemetry.telemetry_decorator import telemetry_event

from shared.api.models.management.responses import MessageResponse

from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
from ..config import R2RConfig
from .base import Service
Expand Down Expand Up @@ -248,17 +250,13 @@ async def agent(
search_settings: SearchSettings = SearchSettings(),
task_prompt_override: Optional[str] = None,
include_title_if_available: Optional[bool] = False,
conversation_id: Optional[str] = None,
branch_id: Optional[str] = None,
conversation_id: Optional[UUID] = None,
branch_id: Optional[UUID] = None,
message: Optional[Message] = None,
messages: Optional[list[Message]] = None,
*args,
**kwargs,
):
async with manage_run(self.run_manager, RunType.RETRIEVAL) as run_id:
try:
t0 = time.time()

if message and messages:
raise R2RException(
status_code=400,
Expand All @@ -278,19 +276,35 @@ async def agent(
else:
raise R2RException(
status_code=400,
message="Invalid message format",
message="""
Invalid message format. The expected format contains:
role: MessageType | 'system' | 'user' | 'assistant' | 'function'
content: Optional[str]
name: Optional[str]
function_call: Optional[dict[str, Any]]
tool_calls: Optional[list[dict[str, Any]]]
""",
)

# Ensure 'messages' is a list of Message instances
if messages:
messages = [
(
msg
if isinstance(msg, Message)
else Message.from_dict(msg)
)
for msg in messages
]
processed_messages = []
for message in messages:
if isinstance(message, Message):
processed_messages.append(message)
elif hasattr(message, "dict"):
processed_messages.append(
Message.from_dict(message.dict())
)
elif isinstance(message, dict):
processed_messages.append(
Message.from_dict(message)
)
else:
processed_messages.append(
Message.from_dict(str(message))
)
messages = processed_messages
else:
messages = []

Expand All @@ -301,42 +315,36 @@ async def agent(

ids = []

if conversation_id:
# Fetch existing conversation
conversation = (
await self.logging_connection.get_conversation(
conversation_id, branch_id
)
)
if not conversation:
logger.error(
f"No conversation found for ID: {conversation_id}"
)
raise R2RException(
status_code=404,
message=f"Conversation not found: {conversation_id}",
)
# Assuming 'conversation' is a list of dicts with 'id' and 'message' keys
messages_from_conversation = []
for resp in conversation:
if isinstance(resp, dict):
msg = Message.from_dict(resp["message"])
messages_from_conversation.append(msg)
ids.append(resp["id"])
else:
logger.error(
f"Unexpected type in conversation: {type(resp)}"
if conversation_id: # Fetch the existing conversation
try:
conversation = (
await self.logging_connection.get_conversation(
conversation_id=conversation_id,
branch_id=branch_id,
)
messages = messages_from_conversation + messages
else:
# Create new conversation
conversation_id = (
)
except Exception as e:
logger.error(f"Error fetching conversation: {str(e)}")

if conversation is not None:
messages_from_conversation: list[Message] = []
for message_response in conversation:
if isinstance(message_response, MessageResponse):
messages_from_conversation.append(
message_response.message
)
ids.append(message_response.id)
else:
logger.warning(
f"Unexpected type in conversation found: {type(message_response)}\n{message_response}"
)
messages = messages_from_conversation + messages
else: # Create new conversation
conversation_response = (
await self.logging_connection.create_conversation()
)
ids = []
# messages already initialized earlier
conversation_id = conversation_response.id

# Append 'message' to 'messages' if provided
if message:
messages.append(message)

Expand All @@ -350,27 +358,19 @@ async def agent(

# Save the new message to the conversation
parent_id = ids[-1] if ids else None

message_response = await self.logging_connection.add_message(
conversation_id,
current_message,
conversation_id=conversation_id,
content=current_message,
parent_id=parent_id,
)

if message_response is not None:
message_id = message_response["id"]
else:
message_id = None
message_id = (
message_response.id
if message_response is not None
else None
)

if rag_generation_config.stream:
t1 = time.time()
latency = f"{t1 - t0:.2f}"

await self.logging_connection.log(
run_id=run_id,
key="rag_agent_generation_latency",
value=latency,
)

async def stream_response():
async with manage_run(self.run_manager, "rag_agent"):
Expand All @@ -386,8 +386,6 @@ async def stream_response():
search_settings=search_settings,
rag_generation_config=rag_generation_config,
include_title_if_available=include_title_if_available,
*args,
**kwargs,
):
yield chunk

Expand All @@ -399,8 +397,6 @@ async def stream_response():
search_settings=search_settings,
rag_generation_config=rag_generation_config,
include_title_if_available=include_title_if_available,
*args,
**kwargs,
)

# Save the assistant's reply to the conversation
Expand All @@ -419,31 +415,23 @@ async def stream_response():
parent_id=message_id,
)

t1 = time.time()
latency = f"{t1 - t0:.2f}"

await self.logging_connection.log(
run_id=run_id,
key="rag_agent_generation_latency",
value=latency,
)
return {
"messages": [msg.to_dict() for msg in results],
"messages": results,
"conversation_id": str(
conversation_id
), # Ensure it's a string
}

except Exception as e:
logger.error(f"Pipeline error: {str(e)}")
logger.error(f"Error in agent response: {str(e)}")
if "NoneType" in str(e):
raise HTTPException(
status_code=502,
detail="Server not reachable or returned an invalid response",
)
raise HTTPException(
status_code=500,
detail="Internal Server Error",
detail=f"Internal Server Error - {str(e)}",
)


Expand Down
Loading
Loading