Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix agent responses for conversation #1668

Merged
merged 3 commits into from
Dec 6, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions py/core/main/api/templates/log_viewer.html
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
<head>
<title>R2R Log Viewer</title>
<style>
body {
margin: 20px;
body {
margin: 20px;
font-family: monospace;
background: #f8f9fa;
}
#logs {
white-space: pre-wrap;
#logs {
white-space: pre-wrap;
background: white;
padding: 20px;
border-radius: 4px;
Expand All @@ -31,38 +31,38 @@
<body>
<h2>R2R Log Viewer</h2>
<div id="logs"><span class="status">Connecting to log stream...</span></div>

<!-- Include ansi_up via a CDN -->
<script src="https://cdn.jsdelivr.net/npm/ansi_up@5.0.0/ansi_up.min.js"></script>
<script>
let ws = null;
let ansi_up = new AnsiUp();

function connect() {
if (ws) {
ws.close();
}

ws = new WebSocket(`ws://${window.location.host}/v3/logs/stream`);

ws.onmessage = function(event) {
const logsDiv = document.getElementById("logs");
const newEntry = document.createElement('div');
newEntry.className = 'log-entry';

// Convert ANSI to HTML
const htmlContent = ansi_up.ansi_to_html(event.data);
newEntry.innerHTML = htmlContent;
logsDiv.appendChild(newEntry);

// Keep only the last 1000 entries
while (logsDiv.children.length > 1000) {
logsDiv.removeChild(logsDiv.firstChild);
}

logsDiv.scrollTop = logsDiv.scrollHeight;
};

ws.onclose = function() {
const logsDiv = document.getElementById("logs");
const msg = document.createElement('div');
Expand All @@ -71,14 +71,14 @@ <h2>R2R Log Viewer</h2>
logsDiv.appendChild(msg);
setTimeout(connect, 1000);
};

ws.onerror = function(err) {
console.error('WebSocket error:', err);
};
}

connect();

window.onbeforeunload = function() {
if (ws) {
ws.close();
Expand Down
53 changes: 25 additions & 28 deletions py/core/main/services/retrieval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,32 +302,31 @@ async def agent(
ids = []

if conversation_id:
# Fetch existing conversation
conversation = (
await self.logging_connection.get_conversation(
conversation_id, branch_id
)
)
if not conversation:
logger.error(
f"No conversation found for ID: {conversation_id}"
)
raise R2RException(
status_code=404,
message=f"Conversation not found: {conversation_id}",
try:
# Fetch existing conversation
conversation = (
await self.logging_connection.get_conversation(
conversation_id=conversation_id,
branch_id=branch_id,
)
)
except Exception as e:
logger.error(f"Error logging conversation: {str(e)}")
# Assuming 'conversation' is a list of dicts with 'id' and 'message' keys
messages_from_conversation = []
for resp in conversation:
if isinstance(resp, dict):
msg = Message.from_dict(resp["message"])
messages_from_conversation.append(msg)
ids.append(resp["id"])
else:
logger.error(
f"Unexpected type in conversation: {type(resp)}"
)
messages = messages_from_conversation + messages

if conversation is not None:
print("Gets into messages_from_conversation")
messages_from_conversation: list[Message] = []
for resp in conversation:
if isinstance(resp, dict):
msg = Message.from_dict(resp["message"])
messages_from_conversation.append(msg)
ids.append(resp["id"])
else:
logger.warning(
f"Unexpected type in conversation: {type(resp)}\n{resp}"
)
messages = messages_from_conversation + messages
else:
# Create new conversation
conversation_id = (
Expand Down Expand Up @@ -399,8 +398,6 @@ async def stream_response():
search_settings=search_settings,
rag_generation_config=rag_generation_config,
include_title_if_available=include_title_if_available,
*args,
**kwargs,
)

# Save the assistant's reply to the conversation
Expand Down Expand Up @@ -428,7 +425,7 @@ async def stream_response():
value=latency,
)
return {
"messages": [msg.to_dict() for msg in results],
"messages": results,
"conversation_id": str(
conversation_id
), # Ensure it's a string
Expand All @@ -443,7 +440,7 @@ async def stream_response():
)
raise HTTPException(
status_code=500,
detail="Internal Server Error",
detail=f"Internal Server Error - {str(e)}",
)


Expand Down
14 changes: 10 additions & 4 deletions py/core/providers/logger/r2r_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,7 @@ async def get_conversation(
)
conversation_created_at = row[0]

print(f"Getting a branch_id: {branch_id}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avoid using print statements for debugging. Use logging instead for better control and to avoid cluttering the output in production environments.

if branch_id is None:
# Get the most recent branch by created_at timestamp
async with self.conn.execute(
Expand All @@ -691,14 +692,19 @@ async def get_conversation(
(conversation_id,),
) as cursor:
row = await cursor.fetchone()
print(f"Row: {row}")
branch_id = row[0] if row else None

# If no branch exists, return empty results but with required fields
if branch_id is None:
return {
"id": conversation_id,
"created_at": conversation_created_at,
}
logger.warning(
f"No branches found for conversation ID {conversation_id}"
)
return None
# return {
# "id": conversation_id,
# "created_at": conversation_created_at,
# }

# Get all messages for this branch
async with self.conn.execute(
Expand Down
Loading