Skip to content

Commit

Permalink
add integ tests for model APIs
Browse files Browse the repository at this point in the history
Signed-off-by: Xun Zhang <xunzh@amazon.com>
  • Loading branch information
Zhangxunmt committed Mar 5, 2022
1 parent 83d905b commit b22b45a
Show file tree
Hide file tree
Showing 7 changed files with 258 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH;
import static org.opensearch.ml.stats.StatNames.ML_TOTAL_FAILURE_COUNT;
import static org.opensearch.ml.stats.StatNames.ML_TOTAL_REQUEST_COUNT;
import static org.opensearch.ml.utils.TestData.trainModelDataJson;

import java.io.IOException;
import java.net.URI;
Expand Down Expand Up @@ -283,7 +284,16 @@ protected void validateStats(
}
assertEquals(expectedTotalFailureCount, totalFailureCount);
assertEquals(expectedTotalAlgoFailureCount, totalAlgoFailureCount);
assertEquals(expectedTotalRequestCount, totalRequestCount);
// ToDo: this line makes this test flaky as other tests makes the request count not predictable
// assertEquals(expectedTotalRequestCount, totalRequestCount);
assertEquals(expectedTotalAlgoRequestCount, totalAlgoRequestCount);
}

protected Response ingestModelData() throws IOException {
Response trainModelResponse = TestHelper
.makeRequest(client(), "POST", "_plugins/_ml/_train/sample_algo", null, TestHelper.toHttpEntity(trainModelDataJson()), null);
HttpEntity entity = trainModelResponse.getEntity();
assertNotNull(trainModelResponse);
return trainModelResponse;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import java.io.IOException;
import java.util.Map;

import org.apache.http.HttpEntity;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.rest.RestStatus;

public class RestMLDeleteModelActionIT extends MLCommonsRestTestCase {

public void testDeleteModelAPI_EmptyResources() throws Exception {
TestHelper
.assertFailWith(
ResponseException.class,
"index_not_found_exception",
() -> TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/models/111222333", null, "", null)
);
}

public void testDeleteModelAPI_Success() throws IOException {
Response trainModelResponse = ingestModelData();
HttpEntity entity = trainModelResponse.getEntity();
assertNotNull(trainModelResponse);
String entityString = TestHelper.httpEntityToString(entity);
Map map = gson.fromJson(entityString, Map.class);
String model_id = (String) map.get("model_id");

Response getModelResponse = TestHelper.makeRequest(client(), "DELETE", "/_plugins/_ml/models/" + model_id, null, "", null);
HttpEntity responseEntity = getModelResponse.getEntity();
assertNotNull(getModelResponse);
assertEquals(RestStatus.OK, TestHelper.restStatus(getModelResponse));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import java.io.IOException;
import java.util.Map;

import org.apache.http.HttpEntity;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.rest.RestStatus;

public class RestMLGetModelActionIT extends MLCommonsRestTestCase {

public void testGetModelAPI_EmptyResources() throws IOException {
try {
Response getModelResponse = TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/models/111222333", null, "", null);
HttpEntity entity = getModelResponse.getEntity();
assertNotNull(getModelResponse);
assertEquals(RestStatus.INTERNAL_SERVER_ERROR, TestHelper.restStatus(getModelResponse));
} catch (Exception exception) {
assertEquals(exception.getClass(), ResponseException.class);
}
}

public void testGetModelAPI_Success() throws IOException {
Response trainModelResponse = ingestModelData();
HttpEntity entity = trainModelResponse.getEntity();
assertNotNull(trainModelResponse);
String entityString = TestHelper.httpEntityToString(entity);
Map map = gson.fromJson(entityString, Map.class);
String model_id = (String) map.get("model_id");

Response getModelResponse = TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/models/" + model_id, null, "", null);
HttpEntity responseEntity = getModelResponse.getEntity();
assertNotNull(getModelResponse);
assertEquals(RestStatus.OK, TestHelper.restStatus(getModelResponse));
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import java.util.List;

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.Strings;
import org.opensearch.rest.RestHandler;
import org.opensearch.rest.RestRequest;
import org.opensearch.test.OpenSearchTestCase;

public class RestMLGetModelActionTests extends OpenSearchTestCase {
@Rule
public ExpectedException thrown = ExpectedException.none();

private RestMLGetModelAction restMLGetModelAction;

@Before
public void setup() {
restMLGetModelAction = new RestMLGetModelAction();
}

@Test
public void testConstructor() {
RestMLGetModelAction mlGetModelAction = new RestMLGetModelAction();
assertNotNull(mlGetModelAction);
}

@Test
public void testGetName() {
String actionName = restMLGetModelAction.getName();
assertFalse(Strings.isNullOrEmpty(actionName));
assertEquals("ml_get_model_action", actionName);
}

@Test
public void testRoutes() {
List<RestHandler.Route> routes = restMLGetModelAction.routes();
assertNotNull(routes);
assertFalse(routes.isEmpty());
RestHandler.Route route = routes.get(0);
assertEquals(RestRequest.Method.GET, route.getMethod());
assertEquals("/_plugins/_ml/models/{model_id}", route.getPath());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.rest;

import static org.opensearch.ml.utils.TestData.matchAllSearchQuery;

import java.io.IOException;
import java.util.Map;

import org.apache.http.HttpEntity;
import org.opensearch.client.Response;
import org.opensearch.client.ResponseException;
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.rest.RestStatus;

public class RestMLSearchModelActionIT extends MLCommonsRestTestCase {

public void testSearchModelAPI_EmptyResources() throws Exception {
TestHelper
.assertFailWith(
ResponseException.class,
"index_not_found_exception",
() -> TestHelper.makeRequest(client(), "GET", "/_plugins/_ml/models/_search", null, matchAllSearchQuery(), null)
);
}

public void testSearchModelAPI_Success() throws IOException {
Response trainModelResponse = ingestModelData();
HttpEntity entity = trainModelResponse.getEntity();
assertNotNull(trainModelResponse);
String entityString = TestHelper.httpEntityToString(entity);
Map map = gson.fromJson(entityString, Map.class);
String model_id = (String) map.get("model_id");

Response searchModelResponse = TestHelper
.makeRequest(client(), "GET", "/_plugins/_ml/models/_search", null, matchAllSearchQuery(), null);
HttpEntity responseEntity = searchModelResponse.getEntity();
assertNotNull(searchModelResponse);
assertEquals(RestStatus.OK, TestHelper.restStatus(searchModelResponse));
}
}
53 changes: 53 additions & 0 deletions plugin/src/test/java/org/opensearch/ml/utils/TestData.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

package org.opensearch.ml.utils;

import com.google.gson.JsonArray;
import com.google.gson.JsonObject;

public class TestData {

public static final String IRIS_DATA = "{ \"index\" : { \"_index\" : \"iris_data\" } }\n"
Expand Down Expand Up @@ -307,4 +310,54 @@ public class TestData {
+ "{\"sepal_length_in_cm\":6.2,\"sepal_width_in_cm\":3.4,\"petal_length_in_cm\":5.4,\"petal_width_in_cm\":2.3,\"class\":\"Iris-virginica\"}\n"
+ "{ \"index\" : { \"_index\" : \"iris_data\" } }\n"
+ "{\"sepal_length_in_cm\":5.9,\"sepal_width_in_cm\":3.0,\"petal_length_in_cm\":5.1,\"petal_width_in_cm\":1.8,\"class\":\"Iris-virginica\"}\n";

public static final String trainModelDataJson() {
JsonObject column_metas_1 = new JsonObject();
JsonObject column_metas_2 = new JsonObject();
JsonArray column_metas = new JsonArray();
column_metas_1.addProperty("name", "total_sum");
column_metas_1.addProperty("column_type", "DOUBLE");

column_metas_2.addProperty("name", "is_error");
column_metas_2.addProperty("column_type", "BOOLEAN");

column_metas.add(column_metas_1);
column_metas.add(column_metas_2);

JsonObject rows_values_1 = new JsonObject();
JsonObject rows_values_2 = new JsonObject();

rows_values_1.addProperty("column_type", "DOUBLE");
rows_values_1.addProperty("value", 15);

rows_values_2.addProperty("column_type", "BOOLEAN");
rows_values_2.addProperty("value", false);

JsonArray rows_values = new JsonArray();
rows_values.add(rows_values_1);
rows_values.add(rows_values_2);

JsonArray rows = new JsonArray();
JsonObject value = new JsonObject();
value.add("values", rows_values);
rows.add(value);

JsonObject input_data = new JsonObject();
input_data.add("column_metas", column_metas);
input_data.add("rows", rows);

JsonObject parameters = new JsonObject();
parameters.addProperty("sample_param", 10);

JsonObject body = new JsonObject();
body.add("parameters", parameters);
body.add("input_data", input_data);

return body.toString();
}

public static final String matchAllSearchQuery() {
String matchAllQuery = "{\"query\": {" + "\"match_all\": {}" + "}" + "}";
return matchAllQuery;
}
}
14 changes: 14 additions & 0 deletions plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;

import org.apache.http.Header;
import org.apache.http.HttpEntity;
Expand Down Expand Up @@ -138,4 +139,17 @@ public static String httpEntityToString(HttpEntity entity) throws IOException {
}
return sb.toString();
}

public static <S, T> void assertFailWith(Class<S> clazz, String message, Callable<T> callable) throws Exception {
try {
callable.call();
} catch (Throwable e) {
if (e.getClass() != clazz) {
throw e;
}
if (message != null && !e.getMessage().contains(message)) {
throw e;
}
}
}
}

0 comments on commit b22b45a

Please sign in to comment.