Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update deprecated openAI mode in integration test and fix flaky tests #1846

Merged
merged 11 commits into from
Jan 12, 2024
11 changes: 8 additions & 3 deletions .github/workflows/CI-workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -141,24 +141,29 @@ jobs:
else
echo "imagePresent=false" >> $GITHUB_ENV
fi
- name: Generate Password For Admin
id: genpass
run: |
PASSWORD=$(openssl rand -base64 20 | tr -dc 'A-Za-z0-9!@#$%^&*()_+=-')
echo "password={$PASSWORD}" >> $GITHUB_OUTPUT
- name: Run Docker Image
if: env.imagePresent == 'true'
run: |
cd ..
docker run -p 9200:9200 -d -p 9600:9600 -e "discovery.type=single-node" opensearch-ml:test
docker run -p 9200:9200 -d -p 9600:9600 -e "discovery.type=single-node" -e OPENSEARCH_INITIAL_ADMIN_PASSWORD=${{ steps.genpass.outputs.password }} opensearch-ml:test
sleep 90
- name: Run MLCommons Test
if: env.imagePresent == 'true'
run: |
security=`curl -XGET https://localhost:9200/_cat/plugins?v -u admin:admin --insecure |grep opensearch-security|wc -l`
security=`curl -XGET https://localhost:9200/_cat/plugins?v -u admin:${{ steps.genpass.outputs.password }} --insecure |grep opensearch-security|wc -l`
export OPENAI_KEY=$(aws secretsmanager get-secret-value --secret-id github_openai_key --query SecretString --output text)
export COHERE_KEY=$(aws secretsmanager get-secret-value --secret-id github_cohere_key --query SecretString --output text)
echo "::add-mask::$OPENAI_KEY"
echo "::add-mask::$COHERE_KEY"
if [ $security -gt 0 ]
then
echo "Security plugin is available"
./gradlew integTest -Dtests.rest.cluster=localhost:9200 -Dtests.cluster=localhost:9200 -Dtests.clustername="docker-cluster" -Dhttps=true -Duser=admin -Dpassword=admin
./gradlew integTest -Dtests.rest.cluster=localhost:9200 -Dtests.cluster=localhost:9200 -Dtests.clustername="docker-cluster" -Dhttps=true -Duser=admin -Dpassword=${{ steps.genpass.outputs.password }}
else
echo "Security plugin is NOT available"
./gradlew integTest -Dtests.rest.cluster=localhost:9200 -Dtests.cluster=localhost:9200 -Dtests.clustername="docker-cluster"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ POST /_plugins/_ml/connectors/_create
"endpoint": "api.openai.com",
"max_tokens": 7,
"temperature": 0,
"model": "text-davinci-003"
"model": "gpt-3.5-turbo-instruct"
},
"credential": {
"openAI_key": "<PLEASE ADD YOUR OPENAI API KEY HERE>"
Expand Down Expand Up @@ -62,7 +62,7 @@ POST /_plugins/_ml/models/<ENTER MODEL ID HERE>/_predict
"id": "cmpl-7g0NPOJd8IvXTdhecdlR0VGfrLMWE",
"object": "text_completion",
"created": 1690245579,
"model": "text-davinci-003",
"model": "gpt-3.5-turbo-instruct",
"choices": [
{
"text": """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ public void createConversation(String name, ActionListener<String> listener) {
public void getConversations(int from, int maxResults, ActionListener<List<ConversationMeta>> listener) {
if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) {
listener.onResponse(List.of());
return;
}
SearchRequest request = Requests.searchRequest(META_INDEX_NAME);
String userstr = getUserStrFromThreadContext();
Expand Down Expand Up @@ -250,6 +251,7 @@ public void getConversations(int maxResults, ActionListener<List<ConversationMet
public void deleteConversation(String conversationId, ActionListener<Boolean> listener) {
if (!clusterService.state().metadata().hasIndex(META_INDEX_NAME)) {
listener.onResponse(true);
return;
}
DeleteRequest delRequest = Requests.deleteRequest(META_INDEX_NAME).id(conversationId);
String userstr = getUserStrFromThreadContext();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ public String encrypt(String plainText) {
initMasterKey();
final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build();
byte[] bytes = Base64.getDecoder().decode(masterKey);
JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NoPadding");
// https://github.com/aws/aws-encryption-sdk-java/issues/1879
JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NOPADDING");

final CryptoResult<byte[], JceMasterKey> encryptResult = crypto
.encryptData(jceMasterKey, plainText.getBytes(StandardCharsets.UTF_8));
Expand All @@ -81,7 +82,7 @@ public String decrypt(String encryptedText) {
final AwsCrypto crypto = AwsCrypto.builder().withCommitmentPolicy(CommitmentPolicy.RequireEncryptRequireDecrypt).build();

byte[] bytes = Base64.getDecoder().decode(masterKey);
JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NoPadding");
JceMasterKey jceMasterKey = JceMasterKey.getInstance(new SecretKeySpec(bytes, "AES"), "Custom", "", "AES/GCM/NOPADDING");

final CryptoResult<byte[], JceMasterKey> decryptedResult = crypto
.decryptData(jceMasterKey, Base64.getDecoder().decode(encryptedText));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1222,21 +1222,21 @@ public void test_get_modelGroup() throws IOException {
getModelGroup(
user2Client,
modelGroupId1,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup1"); }
);

// Admin successfully gets model group
getModelGroup(
client(),
modelGroupId1,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup1"); }
);
} catch (IOException e) {
assertNull(e);
}
// User2 fails to get model group
try {
getModelGroup(user3Client, modelGroupId, null);
getModelGroup(user3Client, modelGroupId1, null);
} catch (Exception e) {
assertEquals(ResponseException.class, e.getClass());
assertTrue(
Expand All @@ -1256,21 +1256,21 @@ public void test_get_modelGroup() throws IOException {
getModelGroup(
user1Client,
modelGroupId2,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup2"); }
);

// User3 successfully gets model group
getModelGroup(
user3Client,
modelGroupId2,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup2"); }
);

// User4 successfully gets model group
getModelGroup(
user4Client,
modelGroupId2,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup2"); }
);
} catch (IOException e) {
assertNull(e);
Expand All @@ -1286,14 +1286,14 @@ public void test_get_modelGroup() throws IOException {
getModelGroup(
user3Client,
modelGroupId3,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup3"); }
);

// Admin successfully gets model group
getModelGroup(
client(),
modelGroupId3,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup3"); }
);
} catch (IOException e) {
assertNull(e);
Expand All @@ -1320,7 +1320,7 @@ public void test_get_modelGroup() throws IOException {
getModelGroup(
client(),
modelGroupId4,
getModelGroupResult -> { assertTrue(getModelGroupResult.containsKey("model_group_id")); }
getModelGroupResult -> { assertEquals(getModelGroupResult.get("name"), "testModelGroup4"); }
);
} catch (IOException e) {
assertNull(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public class RestMLRemoteInferenceIT extends MLCommonsRestTestCase {
+ " \"content_type\": \"application/json\",\n"
+ " \"max_tokens\": 7,\n"
+ " \"temperature\": 0,\n"
+ " \"model\": \"text-davinci-003\"\n"
+ " \"model\": \"gpt-3.5-turbo-instruct\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"openAI_key\": \""
Expand Down Expand Up @@ -265,7 +265,7 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException {
+ " \"endpoint\": \"api.openai.com\",\n"
+ " \"auth\": \"API_Key\",\n"
+ " \"content_type\": \"application/json\",\n"
+ " \"model\": \"text-davinci-edit-001\"\n"
+ " \"model\": \"gpt-4\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"openAI_key\": \""
Expand All @@ -276,18 +276,18 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException {
+ " {\n"
+ " \"action_type\": \"predict\",\n"
+ " \"method\": \"POST\",\n"
+ " \"url\": \"https://api.openai.com/v1/edits\",\n"
+ " \"url\": \"https://api.openai.com/v1/chat/completions\",\n"
+ " \"headers\": { \n"
+ " \"Authorization\": \"Bearer ${credential.openAI_key}\"\n"
+ " },\n"
+ " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"input\\\": \\\"${parameters.input}\\\", \\\"instruction\\\": \\\"${parameters.instruction}\\\" }\"\n"
+ " \"request_body\": \"{ \\\"model\\\": \\\"${parameters.model}\\\", \\\"messages\\\": [{\\\"role\\\": \\\"user\\\", \\\"content\\\": \\\"${parameters.input}\\\"}]}\"\n"
+ " }\n"
+ " ]\n"
+ "}";
Response response = createConnector(entity);
Map responseMap = parseResponseToMap(response);
String connectorId = (String) responseMap.get("connector_id");
response = registerRemoteModel("openAI-GPT-3.5 edit model", connectorId);
response = registerRemoteModel("openAI-GPT-4 edit model", connectorId);
responseMap = parseResponseToMap(response);
String taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
Expand All @@ -298,12 +298,7 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException {
responseMap = parseResponseToMap(response);
taskId = (String) responseMap.get("task_id");
waitForTask(taskId, MLTaskState.COMPLETED);
String predictInput = "{\n"
+ " \"parameters\": {\n"
+ " \"input\": \"What day of the wek is it?\",\n"
+ " \"instruction\": \"Fix the spelling mistakes\"\n"
+ " }\n"
+ "}";
String predictInput = "{\"parameters\":{\"input\":\"What day of the wek is it?\"}}";
response = predictRemoteModel(modelId, predictInput);
responseMap = parseResponseToMap(response);
List responseList = (List) responseMap.get("inference_results");
Expand All @@ -317,7 +312,9 @@ public void testOpenAIEditsModel() throws IOException, InterruptedException {
return;
}
responseMap = (Map) responseList.get(0);
assertFalse(((String) responseMap.get("text")).isEmpty());
responseMap = (Map) responseMap.get("message");

assertFalse(((String) responseMap.get("content")).isEmpty());
}

public void testOpenAIModerationsModel() throws IOException, InterruptedException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import org.apache.hc.core5.http.HttpEntity;
import org.apache.hc.core5.http.HttpHeaders;
import org.apache.hc.core5.http.message.BasicHeader;
import org.junit.Assert;
import org.junit.Before;
import org.opensearch.client.Response;
import org.opensearch.core.rest.RestStatus;
Expand Down Expand Up @@ -120,16 +122,20 @@ public void testConversations_MorePages() throws IOException {
assert (((Double) map.get("next_token")).intValue() == 1);
}

public void testGetConversations_nextPage() throws IOException {
public void testGetConversations_nextPage() throws IOException, InterruptedException {
Response ccresponse1 = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null);
assert (ccresponse1 != null);
assert (TestHelper.restStatus(ccresponse1) == RestStatus.OK);
HttpEntity cchttpEntity1 = ccresponse1.getEntity();
String ccentityString1 = TestHelper.httpEntityToString(cchttpEntity1);
Map ccmap1 = gson.fromJson(ccentityString1, Map.class);
assert (ccmap1.containsKey("conversation_id"));
logger.info("ccentityString1={}", ccentityString1);
String id1 = (String) ccmap1.get("conversation_id");

// wait for 0.1s to make sure update time is different between conversation 1 and 2
TimeUnit.MICROSECONDS.sleep(100);

Response ccresponse2 = TestHelper.makeRequest(client(), "POST", ActionConstants.CREATE_CONVERSATION_REST_PATH, null, "", null);
assert (ccresponse2 != null);
assert (TestHelper.restStatus(ccresponse2) == RestStatus.OK);
Expand Down Expand Up @@ -159,7 +165,7 @@ public void testGetConversations_nextPage() throws IOException {
ArrayList<Map> conversations1 = (ArrayList<Map>) map1.get("conversations");
assert (conversations1.size() == 1);
assert (conversations1.get(0).containsKey("conversation_id"));
assert (((String) conversations1.get(0).get("conversation_id")).equals(id2));
Assert.assertEquals(conversations1.get(0).get("conversation_id"), id2);
assert (((Double) map1.get("next_token")).intValue() == 1);

Response response = TestHelper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ public static RestRequest getCreateConnectorRestRequest() {
+ " \"content_type\": \"application/json\",\n"
+ " \"max_tokens\": 7,\n"
+ " \"temperature\": 0,\n"
+ " \"model\": \"text-davinci-003\"\n"
+ " \"model\": \"gpt-3.5-turbo-instruct\"\n"
+ " },\n"
+ " \"credential\": {\n"
+ " \"openAI_key\": \"xxxxxxxx\"\n"
Expand Down
Loading