Skip to content

Commit

Permalink
rename memory field names in responses (#2017)
Browse files Browse the repository at this point in the history
* rename memory field names in responses

Signed-off-by: Xun Zhang <xunzh@amazon.com>

* fix the IT test

Signed-off-by: Xun Zhang <xunzh@amazon.com>

---------

Signed-off-by: Xun Zhang <xunzh@amazon.com>
  • Loading branch information
Zhangxunmt authored Feb 6, 2024
1 parent b62b0de commit a07dd14
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ public class ActionConstants {
public final static String CONVERSATION_ID_FIELD = "memory_id";

/** name of list of conversations in all responses */
public final static String RESPONSE_CONVERSATION_LIST_FIELD = "conversations";
public final static String RESPONSE_CONVERSATION_LIST_FIELD = "memories";
/** name of list on interactions in all responses */
public final static String RESPONSE_INTERACTION_LIST_FIELD = "interactions";
public final static String RESPONSE_INTERACTION_LIST_FIELD = "messages";
/** name of list on traces in all responses */
public final static String RESPONSE_TRACES_LIST_FIELD = "traces";
/** name of interaction Id field in all responses */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public void testToXContent_MoreTokens() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
String result = BytesReference.bytes(builder).utf8ToString();
String expected = "{\"conversations\":[{\"memory_id\":\"0\",\"create_time\":\""
String expected = "{\"memories\":[{\"memory_id\":\"0\",\"create_time\":\""
+ conversation.getCreatedTime()
+ "\"updated_time\":\""
+ conversation.getUpdatedTime()
Expand All @@ -93,7 +93,7 @@ public void testToXContent_NoMoreTokens() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
String result = BytesReference.bytes(builder).utf8ToString();
String expected = "{\"conversations\":[{\"memory_id\":\"0\",\"create_time\":\""
String expected = "{\"memories\":[{\"memory_id\":\"0\",\"create_time\":\""
+ conversation.getCreatedTime()
+ "\"updated_time\":\""
+ conversation.getUpdatedTime()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,11 @@ public void testToXContent_MoreTokens() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
String result = BytesReference.bytes(builder).utf8ToString();
String expected = "{\"interactions\":[{\"memory_id\":\"cid\",\"message_id\":\"id0\",\"create_time\":\""
String expected = "{\"messages\":[{\"memory_id\":\"cid\",\"message_id\":\"id0\",\"create_time\":\""
+ interaction.getCreateTime()
+ "\",\"input\":\"input\",\"prompt_template\":\"pt\",\"response\":\"response\",\"origin\":\"origin\",\"additional_info\":{\"metadata\":\"some meta\"}}],\"next_token\":2}";
log.info(result);
log.info(expected);
// Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness
LevenshteinDistance ld = new LevenshteinDistance();
log.info(ld.getDistance(result, expected));
assert (ld.getDistance(result, expected) > 0.95);
}

Expand All @@ -117,14 +114,11 @@ public void testToXContent_NoMoreTokens() throws IOException {
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
String result = BytesReference.bytes(builder).utf8ToString();
String expected = "{\"interactions\":[{\"memory_id\":\"cid\",\"message_id\":\"id0\",\"create_time\":\""
String expected = "{\"messages\":[{\"memory_id\":\"cid\",\"message_id\":\"id0\",\"create_time\":\""
+ interaction.getCreateTime()
+ "\",\"input\":\"input\",\"prompt_template\":\"pt\",\"response\":\"response\",\"origin\":\"origin\",\"additional_info\":{\"metadata\":\"some meta\"}}]}";
log.info(result);
log.info(expected);
// Sometimes there's an extra trailing 0 in the time stringification, so just assert closeness
LevenshteinDistance ld = new LevenshteinDistance();
log.info(ld.getDistance(result, expected));
assert (ld.getDistance(result, expected) > 0.95);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@ public void testNoConversations_EmptyList() throws IOException {
HttpEntity httpEntity = response.getEntity();
String entityString = TestHelper.httpEntityToString(httpEntity);
Map map = gson.fromJson(entityString, Map.class);
assert (map.containsKey("conversations"));
assert (map.containsKey(RESPONSE_CONVERSATION_LIST_FIELD));
assert (!map.containsKey("next_token"));
assert (((ArrayList) map.get("conversations")).size() == 0);
assert (((ArrayList) map.get(RESPONSE_CONVERSATION_LIST_FIELD)).size() == 0);
}

public void testGetConversations_LastPage() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ public void testGetInteractions_NoConversation() throws IOException {
HttpEntity httpEntity = response.getEntity();
String entityString = TestHelper.httpEntityToString(httpEntity);
Map map = gson.fromJson(entityString, Map.class);
assert (map.containsKey("interactions"));
assert (map.containsKey(RESPONSE_INTERACTION_LIST_FIELD));
assert (!map.containsKey("next_token"));
assert (((ArrayList) map.get("interactions")).size() == 0);
assert (((ArrayList) map.get(RESPONSE_INTERACTION_LIST_FIELD)).size() == 0);
}

public void testGetInteractions_NoInteractions() throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {

// Optional parameter; if provided, conversational memory will be used for RAG
// and the current interaction will be saved in the conversation referenced by this id.
private static final ParseField CONVERSATION_ID = new ParseField("conversation_id");
private static final ParseField CONVERSATION_ID = new ParseField("memory_id");

// Optional parameter; if an LLM model is not set at the search pipeline level, one must be
// provided at the search request level.
Expand All @@ -64,7 +64,7 @@ public class GenerativeQAParameters implements Writeable, ToXContentObject {

// Optional parameter; this parameter controls the number of the interactions to include
// in the user prompt.
private static final ParseField INTERACTION_SIZE = new ParseField("interaction_size");
private static final ParseField INTERACTION_SIZE = new ParseField("message_size");

// Optional parameter; this parameter controls how long the search pipeline waits for a response
// from a remote inference endpoint before timing out the request.
Expand Down

0 comments on commit a07dd14

Please sign in to comment.