Update deprecated openAI model in integration test and fix flaky tests (#1846)

* Fix flaky test

Signed-off-by: Hailong Cui <ihailong@amazon.com>

* remove unused import

Signed-off-by: Hailong Cui <ihailong@amazon.com>

* replace deprecated openAI model in test

Signed-off-by: Hailong Cui <ihailong@amazon.com>

* ignore testOpenAIEditsModel as it's deprecated

Signed-off-by: Hailong Cui <ihailong@amazon.com>

* typo fix

Signed-off-by: Hailong Cui <ihailong@amazon.com>

* update admin password

Signed-off-by: Hailong Cui <ihailong@amazon.com>

* fix getConversations when index does not exist

Signed-off-by: Hailong Cui <ihailong@amazon.com>

* update blueprint doc

Signed-off-by: Hailong Cui <ihailong@amazon.com>

* update edit model to gpt-4

Signed-off-by: Hailong Cui <ihailong@amazon.com>

* remove hardcoded password

Signed-off-by: Hailong Cui <ihailong@amazon.com>

* replace deprecated set-output

Signed-off-by: Hailong Cui <ihailong@amazon.com>

---------

Signed-off-by: Hailong Cui <ihailong@amazon.com>
Hailong-am authored Jan 12, 2024
1 parent 19c93b1 commit 6584529
Showing 8 changed files with 42 additions and 31 deletions.
11 changes: 8 additions & 3 deletions .github/workflows/CI-workflow.yml
@@ -141,24 +141,29 @@ jobs:
else
echo "imagePresent=false" >> $GITHUB_ENV
fi
- name: Generate Password For Admin
id: genpass
run: |
PASSWORD=$(openssl rand -base64 20 | tr -dc 'A-Za-z0-9!@#$%^&*()_+=-')
echo "password={$PASSWORD}" >> $GITHUB_OUTPUT
- name: Run Docker Image
if: env.imagePresent == 'true'
run: |
cd ..
docker run -p 9200:9200 -d -p 9600:9600 -e "discovery.type=single-node" opensearch-ml:test
docker run -p 9200:9200 -d -p 9600:9600 -e "discovery.type=single-node" -e OPENSEARCH_INITIAL_ADMIN_PASSWORD=${{ steps.genpass.outputs.password }} opensearch-ml:test
sleep 90
- name: Run MLCommons Test
if: env.imagePresent == 'true'
run: |
security=`curl -XGET https://localhost:9200/_cat/plugins?v -u admin:admin --insecure |grep opensearch-security|wc -l`
security=`curl -XGET https://localhost:9200/_cat/plugins?v -u admin:${{ steps.genpass.outputs.password }} --insecure |grep opensearch-security|wc -l`
export OPENAI_KEY=$(aws secretsmanager get-secret-value --secret-id github_openai_key --query SecretString --output text)
export COHERE_KEY=$(aws secretsmanager get-secret-value --secret-id github_cohere_key --query SecretString --output text)
echo "::add-mask::$OPENAI_KEY"
echo "::add-mask::$COHERE_KEY"
if [ $security -gt 0 ]
then
echo "Security plugin is available"
./gradlew integTest -Dtests.rest.cluster=localhost:9200 -Dtests.cluster=localhost:9200 -Dtests.clustername="docker-cluster" -Dhttps=true -Duser=admin -Dpassword=admin
./gradlew integTest -Dtests.rest.cluster=localhost:9200 -Dtests.cluster=localhost:9200 -Dtests.clustername="docker-cluster" -Dhttps=true -Duser=admin -Dpassword=${{ steps.genpass.outputs.password }}
else
echo "Security plugin is NOT available"
./gradlew integTest -Dtests.rest.cluster=localhost:9200 -Dtests.cluster=localhost:9200 -Dtests.clustername="docker-cluster"
@@ -13,7 +13,7 @@ POST /_plugins/_ml/connectors/_create
"endpoint": "api.openai.com",
"max_tokens": 7,
"temperature": 0,
"model": "text-davinci-003"
"model": "gpt-3.5-turbo-instruct"
},
"credential": {
"openAI_key": "<PLEASE ADD YOUR OPENAI API KEY HERE>"
@@ -62,7 +62,7 @@ POST /_plugins/_ml/models/<ENTER MODEL ID HERE>/_predict
"id": "cmpl-7g0NPOJd8IvXTdhecdlR0VGfrLMWE",
"object": "text_completion",
"created": 1690245579,
"model": "text-davinci-003",
"model": "gpt-3.5-turbo-instruct",
"choices": [
{
"text": """
@@ -198,6 +198,7 @@ public void createConversation(String name, ActionListener<String> listener) {
public void getConversations(int from, int maxResults, ActionListener<List<ConversationMeta>> listener) {
if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) {
listener.onResponse(List.of());
return;
}
SearchRequest request = Requests.searchRequest(META_INDEX_NAME);
String userstr = getUserStrFromThreadContext();
@@ -250,6 +251,7 @@ public void getConversations(int maxResults, ActionListener<List<ConversationMet
public void deleteConversation(String conversationId, ActionListener<Boolean> listener) {
if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) {
listener.onResponse(true);
return;
}
DeleteRequest delRequest = Requests.deleteRequest(META_INDEX_NAME).id(conversationId);
String userstr = getUserStrFromThreadContext();
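The change in this file adds the missing return after the early listener.onResponse(...) calls, so getConversations and deleteConversation stop falling through to a search or delete against a conversation meta index that does not exist. Below is a minimal, self-contained sketch of the fixed guard-clause pattern; the ActionListener interface and the index-exists flag are simplified, hypothetical stand-ins for the plugin's real types, not the actual implementation.

import java.util.List;

// Hypothetical stand-in for OpenSearch's ActionListener callback interface.
interface ActionListener<T> {
    void onResponse(T response);
    void onFailure(Exception e);
}

class ConversationStoreSketch {
    // Stand-in for clusterService.state().metadata().hasIndex(META_INDEX_NAME).
    private final boolean metaIndexExists;

    ConversationStoreSketch(boolean metaIndexExists) {
        this.metaIndexExists = metaIndexExists;
    }

    // Mirrors the fixed getConversations: answer with an empty list AND return,
    // so the search below never runs against a missing index.
    void getConversations(ActionListener<List<String>> listener) {
        if (!metaIndexExists) {
            listener.onResponse(List.of());
            return; // the missing return was the bug: execution fell through to the search
        }
        // ... build and run the search against the conversation meta index here ...
        listener.onResponse(List.of("conversation-1"));
    }

    public static void main(String[] args) {
        new ConversationStoreSketch(false).getConversations(new ActionListener<>() {
            public void onResponse(List<String> conversations) {
                System.out.println("conversations: " + conversations); // prints an empty list, exactly once
            }

            public void onFailure(Exception e) {
                e.printStackTrace();
            }
        });
    }
}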
@@ -68,7 +68,8 @@ public String encrypt(String plainText) {
initMasterKey();
final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build();
byte[] bytes = Base64.getDecoder().decode(masterKey);
JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NoPadding");
// https://github.com/aws/aws-encryption-sdk-java/issues/1879
JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NOPADDING");

final CryptoResult<byte[], JceMasterKey> encryptResult = crypto
.encryptData(jceMasterKey, plainText.getBytes(StandardCharsets.UTF_8));
@@ -81,7 +82,7 @@ public String decrypt(String encryptedText) {
final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build();

byte[] bytes = Base64.getDecoder().decode(masterKey);
JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NoPadding");
JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NOPADDING");

final CryptoResult<byte[], JceMasterKey> decryptedResult = crypto
.decryptData(jceMasterKey, Base64.getDecoder().decode(encryptedText));
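The functional change in this file is the wrapping-algorithm string passed to JceMasterKey, switched to the upper-case "AES/GCM/NOPADDING" per the linked aws-encryption-sdk-java issue; the encrypt and decrypt paths must pass identical provider info. Below is a standalone roundtrip sketch assembled from the calls visible in this hunk, assuming the standard AWS Encryption SDK for Java packages (com.amazonaws.encryptionsdk.*); the base64 master key is a throwaway generated only for this example, not the plugin's stored key.

import java.nio.charset.StandardCharsets;
import java.util.Base64;

import javax.crypto.KeyGenerator;
import javax.crypto.spec.SecretKeySpec;

import com.amazonaws.encryptionsdk.AwsCrypto;
import com.amazonaws.encryptionsdk.CommitmentPolicy;
import com.amazonaws.encryptionsdk.CryptoResult;
import com.amazonaws.encryptionsdk.jce.JceMasterKey;

public class EncryptorRoundTripSketch {
    public static void main(String[] args) throws Exception {
        // Throwaway 256-bit AES master key, base64 encoded like the plugin's stored key.
        KeyGenerator keyGen = KeyGenerator.getInstance("AES");
        keyGen.init(256);
        String masterKey = Base64.getEncoder().encodeToString(keyGen.generateKey().getEncoded());

        AwsCrypto crypto = AwsCrypto.builder()
            .withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt)
            .build();

        byte[] bytes = Base64.getDecoder().decode(masterKey);
        // Same provider info on the encrypt and decrypt side; the commit switches the
        // wrapping algorithm string to "AES/GCM/NOPADDING".
        JceMasterKey jceMasterKey =
            JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NOPADDING");

        CryptoResult<byte[], JceMasterKey> encrypted =
            crypto.encryptData(jceMasterKey, "my-secret-credential".getBytes(StandardCharsets.UTF_8));
        String cipherText = Base64.getEncoder().encodeToString(encrypted.getResult());

        CryptoResult<byte[], JceMasterKey> decrypted =
            crypto.decryptData(jceMasterKey, Base64.getDecoder().decode(cipherText));
        System.out.println(new String(decrypted.getResult(), StandardCharsets.UTF_8)); // my-secret-credential
    }
}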
@@ -1239,21 +1239,21 @@ public void test_get_modelGroup() throws IOException {
getModelGroup(
user2Client,
modelGroupId1,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup1"); }
);

// Admin successfully gets model group
getModelGroup(
client(),
modelGroupId1,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup1"); }
);
} catch (IOException e) {
assertNull(e);
}
// User2 fails to get model group
try {
getModelGroup(user3Client, modelGroupId, null);
getModelGroup(user3Client, modelGroupId1, null);
} catch (Exception e) {
assertEquals(ResponseException.class, e.getClass());
assertTrue(
@@ -1273,21 +1273,21 @@ public void test_get_modelGroup() throws IOException {
getModelGroup(
user1Client,
modelGroupId2,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup2"); }
);

// User3 successfully gets model group
getModelGroup(
user3Client,
modelGroupId2,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup2"); }
);

// User4 successfully gets model group
getModelGroup(
user4Client,
modelGroupId2,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup2"); }
);
} catch (IOException e) {
assertNull(e);
@@ -1303,14 +1303,14 @@ public void test_get_modelGroup() throws IOException {
getModelGroup(
user3Client,
modelGroupId3,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup3"); }
);

// Admin successfully gets model group
getModelGroup(
client(),
modelGroupId3,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup3"); }
);
} catch (IOException e) {
assertNull(e);
@@ -1337,7 +1337,7 @@ public void test_get_modelGroup() throws IOException {
getModelGroup(
client(),
modelGroupId4,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup4"); }
);
} catch (IOException e) {
assertNull(e);
@@ -39,7 +39,7 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase {
+ " \"content_type\": \"application/json\",\n"
+ " \"max_tokens\": 7,\n"
+ " \"temperature\": 0,\n"
+ " \"model\": \"text-davinci-003\"\n"
+ " \"model\": \"gpt-3.5-turbo-instruct\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"openAI_key\": \""
@@ -265,7 +265,7 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException {
+ " \"endpoint\": \"api.openai.com\",\n"
+ " \"auth\": \"API_Key\",\n"
+ " \"content_type\": \"application/json\",\n"
+ " \"model\": \"text-davinci-edit-001\"\n"
+ " \"model\": \"gpt-4\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"openAI_key\": \""
@@ -276,18 +276,18 @@
+ " {\n"
+ " \"action_type\": \"predict\",\n"
+ " \"method\": \"POST\",\n"
+ " \"url\": \"https://api.openai.com/v1/edits\",\n"
+ " \"url\": \"https://api.openai.com/v1/chat/completions\",\n"
+ " \"headers\": { \n"
+ " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+ " },\n"
+ " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\", \\\"instruction\\\": \\\"${parameters.instruction}\\\" }\"\n"
+ " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": [{\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"${parameters.input}\\\"}]}\"\n"
+ " }\n"
+ " ]\n"
+ "}";
Response response = createConnector(entity);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
response = registerRemoteModel("openAI-GPT-3.5 edit model", connectorId);
response = registerRemoteModel("openAI-GPT-4 edit model", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
@@ -298,12 +298,7 @@
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
String predictInput = "{\n"
+ " \"parameters\": {\n"
+ " \"input\": \"What day of the wek is it?\",\n"
+ " \"instruction\": \"Fix the spelling mistakes\"\n"
+ " }\n"
+ "}";
String predictInput = "{\"parameters\":{\"input\":\"What day of the wek is it?\"}}";
response = predictRemoteModel(modelId, predictInput);
responseMap = parseResponseToMap(response);
List responseList = (List) responseMap.get("inference_results");
@@ -317,7 +312,9 @@
return;
}
responseMap = (Map) responseList.get(0);
assertFalse(((String) responseMap.get("text")).isEmpty());
responseMap = (Map) responseMap.get("message");

assertFalse(((String) responseMap.get("content")).isEmpty());
}

public void testOpenAIModerationsModel() throws IOException, InterruptedException {
@@ -20,10 +20,12 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import org.apache.hc.core5.http.HttpEntity;
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.Assert;
import org.junit.Before;
import org.opensearch.client.Response;
import org.opensearch.core.rest.RestStatus;
@@ -120,16 +122,20 @@ public void testConversations_MorePages() throws IOException {
assert (((Double) map.get("next_token")).intValue() == 1);
}

public void testGetConversations_nextPage() throws IOException {
public void testGetConversations_nextPage() throws IOException, InterruptedException {
Response ccresponse1 = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null);
assert (ccresponse1 != null);
assert (TestHelper.restStatus(ccresponse1) == RestStatus.OK);
HttpEntity cchttpEntity1 = ccresponse1.getEntity();
String ccentityString1 = TestHelper.httpEntityToString(cchttpEntity1);
Map ccmap1 = gson.fromJson(ccentityString1, Map.class);
assert (ccmap1.containsKey("conversation_id"));
logger.info("ccentityString1={}", ccentityString1);
String id1 = (String) ccmap1.get("conversation_id");

// wait for 0.1s to make sure update time is different between conversation 1 and 2
TimeUnit.MICROSECONDS.sleep(100);

Response ccresponse2 = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null);
assert (ccresponse2 != null);
assert (TestHelper.restStatus(ccresponse2) == RestStatus.OK);
@@ -159,7 +165,7 @@ public void testGetConversations_nextPage() throws IOException {
ArrayList<Map> conversations1 = (ArrayList<Map>) map1.get("conversations");
assert (conversations1.size() == 1);
assert (conversations1.get(0).containsKey("conversation_id"));
assert (((String) conversations1.get(0).get("conversation_id")).equals(id2));
Assert.assertEquals(conversations1.get(0).get("conversation_id"), id2);
assert (((Double) map1.get("next_token")).intValue() == 1);

Response response = TestHelper
@@ -227,7 +227,7 @@ public static RestRequest getCreateConnectorRestRequest() {
+ " \"content_type\": \"application/json\",\n"
+ " \"max_tokens\": 7,\n"
+ " \"temperature\": 0,\n"
+ " \"model\": \"text-davinci-003\"\n"
+ " \"model\": \"gpt-3.5-turbo-instruct\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"openAI_key\": \"xxxxxxxx\"\n"