Skip to content

Commit

Permalink
Conversations Updates (#1716)
Browse files Browse the repository at this point in the history
* Conversations Updates

* Bump js package
  • Loading branch information
NolanTrem authored Dec 21, 2024
1 parent 6faec50 commit 8b38f08
Show file tree
Hide file tree
Showing 9 changed files with 135 additions and 98 deletions.
78 changes: 69 additions & 9 deletions js/sdk/__tests__/ConversationsIntegrationSuperUser.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,75 @@ describe("r2rClient V3 Collections Integration Tests", () => {
expect(response.results).toBeDefined();
});

// TODO: This is throwing a 405? Why?
// test("Update a message in a conversation", async () => {
// const response = await client.conversations.updateMessage({
// id: conversationId,
// message_id: messageId,
// content: "Hello, world! How are you?",
// });
// expect(response.results).toBeDefined();
// });
test("Update message content only", async () => {
const newContent = "Updated content";
const response = await client.conversations.updateMessage({
id: conversationId,
messageID: messageId,
content: newContent,
});
expect(response.results).toBeDefined();
expect(response.results.message.content).toBe(newContent);
expect(response.results.metadata.edited).toBe(true);
});

test("Update metadata only", async () => {
const newMetadata = { test: "value" };
const response = await client.conversations.updateMessage({
id: conversationId,
messageID: messageId,
metadata: newMetadata,
});
expect(response.results).toBeDefined();
expect(response.results.metadata.test).toBe("value");
expect(response.results.metadata.edited).toBe(true);
expect(response.results.message.content).toBe("Updated content");
});

test("Update both content and metadata", async () => {
const newContent = "Both updated";
const newMetadata = { key: "value" };
const response = await client.conversations.updateMessage({
id: conversationId,
messageID: messageId,
content: newContent,
metadata: newMetadata,
});
expect(response.results).toBeDefined();
expect(response.results.message.content).toBe(newContent);
expect(response.results.metadata.key).toBe("value");
expect(response.results.metadata.edited).toBe(true);
});

test("Handle empty message update", async () => {
const response = await client.conversations.updateMessage({
id: conversationId,
messageID: messageId,
});
expect(response.results).toBeDefined();
expect(response.results.message.content).toBe("Both updated");
expect(response.results.metadata.edited).toBe(true);
});

test("Reject update with invalid conversation ID", async () => {
await expect(
client.conversations.updateMessage({
id: "invalid-id",
messageID: messageId,
content: "test",
}),
).rejects.toThrow();
});

test("Reject update with invalid message ID", async () => {
await expect(
client.conversations.updateMessage({
id: conversationId,
messageID: "invalid-message-id",
content: "test",
}),
).rejects.toThrow();
});

test("Delete a conversation", async () => {
const response = await client.conversations.delete({ id: conversationId });
Expand Down
2 changes: 1 addition & 1 deletion js/sdk/package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion js/sdk/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "r2r-js",
"version": "0.4.6",
"version": "0.4.7",
"description": "",
"main": "dist/index.js",
"browser": "dist/index.browser.js",
Expand Down
11 changes: 3 additions & 8 deletions js/sdk/src/r2rClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1442,14 +1442,9 @@ export class r2rClient extends BaseClient {
* @returns A promise that resolves to the response from the server.
*/
@feature("getConversation")
async getConversation(
conversationId: string,
): Promise<Record<string, any>> {
async getConversation(conversationId: string): Promise<Record<string, any>> {
this._ensureAuthenticated();
return this._makeRequest(
"GET",
`get_conversation/${conversationId}`,
);
return this._makeRequest("GET", `get_conversation/${conversationId}`);
}

/**
Expand Down Expand Up @@ -1994,7 +1989,7 @@ export class r2rClient extends BaseClient {
rag_generation_config,
task_prompt_override,
include_title_if_available,
conversation_id
conversation_id,
};

Object.keys(json_data).forEach(
Expand Down
7 changes: 5 additions & 2 deletions js/sdk/src/v3/clients/conversations.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,19 @@ export class ConversationsClient {
* @param id The ID of the conversation containing the message
* @param messageID The ID of the message to update
* @param content The new content of the message
* @param metadata Additional metadata to attach to the message
* @returns
*/
@feature("conversations.updateMessage")
async updateMessage(options: {
id: string;
messageID: string;
content: string;
content?: string;
metadata?: Record<string, any>;
}): Promise<any> {
const data: Record<string, any> = {
content: options.content,
...(options.content && { content: options.content }),
...(options.metadata && { metadata: options.metadata }),
};

return this.client.makeRequest(
Expand Down
87 changes: 39 additions & 48 deletions py/core/database/conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,12 +229,12 @@ async def add_message(
async def edit_message(
self,
message_id: UUID,
new_content: str,
additional_metadata: dict = {},
new_content: str | None = None,
additional_metadata: dict = None,
) -> dict[str, Any]:
# Get the original message
query = f"""
SELECT conversation_id, parent_id, content, metadata
SELECT conversation_id, parent_id, content, metadata, created_at
FROM {self._get_table_name("messages")}
WHERE id = $1
"""
Expand All @@ -247,33 +247,46 @@ async def edit_message(
old_content = json.loads(row["content"])
old_metadata = json.loads(row["metadata"])

# Update the content
old_message = Message(**old_content)
edited_message = Message(
role=old_message.role,
content=new_content,
name=old_message.name,
function_call=old_message.function_call,
tool_calls=old_message.tool_calls,
)

# Merge metadata and mark edited
new_metadata = {**old_metadata, **additional_metadata, "edited": True}
if new_content is not None:
old_message = Message(**old_content)
edited_message = Message(
role=old_message.role,
content=new_content,
name=old_message.name,
function_call=old_message.function_call,
tool_calls=old_message.tool_calls,
)
content_to_save = edited_message.model_dump()
else:
content_to_save = old_content

additional_metadata = additional_metadata or {}

new_metadata = {
**old_metadata,
**additional_metadata,
"edited": (
True
if new_content is not None
else old_metadata.get("edited", False)
),
}

# Instead of branching, we'll simply replace the message content and metadata:
# NOTE: If you prefer versioning or forking behavior, you'd add a new message.
# For simplicity, we just edit the existing message.
# Update message without changing the timestamp
update_query = f"""
UPDATE {self._get_table_name("messages")}
SET content = $1::jsonb, metadata = $2::jsonb, created_at = NOW()
WHERE id = $3
SET content = $1::jsonb,
metadata = $2::jsonb,
created_at = $3
WHERE id = $4
RETURNING id
"""
updated = await self.connection_manager.fetchrow_query(
update_query,
[
json.dumps(edited_message.model_dump()),
json.dumps(content_to_save),
json.dumps(new_metadata),
row["created_at"],
message_id,
],
)
Expand All @@ -284,36 +297,14 @@ async def edit_message(

return {
"id": str(message_id),
"message": edited_message,
"message": (
Message(**content_to_save)
if isinstance(content_to_save, dict)
else content_to_save
),
"metadata": new_metadata,
}

async def update_message_metadata(
self, message_id: UUID, metadata: dict
) -> None:
# Fetch current metadata
query = f"""
SELECT metadata FROM {self._get_table_name("messages")}
WHERE id = $1
"""
row = await self.connection_manager.fetchrow_query(query, [message_id])
if not row:
raise R2RException(
status_code=404, message=f"Message {message_id} not found."
)

current_metadata = row["metadata"] or {}
updated_metadata = {**current_metadata, **metadata}

update_query = f"""
UPDATE {self._get_table_name("messages")}
SET metadata = $1::jsonb
WHERE id = $2
"""
await self.connection_manager.execute_query(
update_query, [updated_metadata, message_id]
)

async def get_conversation(
self, conversation_id: UUID
) -> list[MessageResponse]:
Expand Down
17 changes: 7 additions & 10 deletions py/core/main/api/v3/conversations_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
WrappedConversationsResponse,
WrappedMessageResponse,
)
from core.providers import (
HatchetOrchestrationProvider,
SimpleOrchestrationProvider,
)

from ...abstractions import R2RProviders, R2RServices
from .base_router import BaseRouterV3
Expand Down Expand Up @@ -504,11 +500,11 @@ async def update_message(
id: UUID = Path(
..., description="The unique identifier of the conversation"
),
message_id: str = Path(
message_id: UUID = Path(
..., description="The ID of the message to update"
),
content: str = Body(
..., description="The new content for the message"
content: Optional[str] = Body(
None, description="The new content for the message"
),
metadata: Optional[dict[str, str]] = Body(
None, description="Additional metadata for the message"
Expand All @@ -520,7 +516,8 @@ async def update_message(
This endpoint updates the content of an existing message in a conversation.
"""
messge_response = await self.services.management.edit_message(
message_id, content, metadata
return await self.services.management.edit_message(
message_id=message_id,
new_content=content,
additional_metadata=metadata,
)
return messge_response
23 changes: 7 additions & 16 deletions py/core/main/services/management_service.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import logging
import os
from collections import defaultdict
from copy import copy
from typing import Any, BinaryIO, Optional, Tuple
from uuid import UUID

import toml
from fastapi.responses import StreamingResponse

from core.base import (
CollectionResponse,
Expand All @@ -19,7 +17,6 @@
RunManager,
User,
)
from core.base.utils import validate_uuid
from core.telemetry.telemetry_decorator import telemetry_event

from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
Expand Down Expand Up @@ -862,25 +859,19 @@ async def add_message(
@telemetry_event("EditMessage")
async def edit_message(
self,
message_id: str,
new_content: str,
additional_metadata: dict,
message_id: UUID,
new_content: Optional[str] = None,
additional_metadata: Optional[dict] = None,
auth_user=None,
) -> Tuple[str, str]:
) -> dict[str, Any]:
return (
await self.providers.database.conversations_handler.edit_message(
message_id, new_content, additional_metadata
message_id=message_id,
new_content=new_content,
additional_metadata=additional_metadata or {},
)
)

@telemetry_event("updateMessageMetadata")
async def update_message_metadata(
self, message_id: str, metadata: dict, auth_user=None
):
await self.providers.database.conversations_handler.update_message_metadata(
message_id, metadata
)

@telemetry_event("DeleteConversation")
async def delete_conversation(self, conversation_id: str, auth_user=None):
await self.providers.database.conversations_handler.delete_conversation(
Expand Down
6 changes: 3 additions & 3 deletions py/sdk/v3/conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,24 +137,24 @@ async def update_message(
self,
id: str | UUID,
message_id: str,
content: str,
content: Optional[str] = None,
metadata: Optional[dict] = None,
) -> dict:
"""
Update an existing message in a conversation.
Args:
id (Union[str, UUID]): The ID of the conversation containing the message
id (str | UUID): The ID of the conversation containing the message
message_id (str): The ID of the message to update
content (str): The new content of the message
metadata (dict): Additional metadata to attach to the message
Returns:
dict: Result of the operation, including the new message ID and branch ID
"""
data = {"content": content}
if metadata:
data["metadata"] = metadata
# data = {"content": content}
return await self.client._make_request(
"POST",
f"conversations/{str(id)}/messages/{message_id}",
Expand Down

0 comments on commit 8b38f08

Please sign in to comment.