Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add integration test for train and predict API #157

Merged
merged 2 commits into from
Mar 1, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DEVELOPER_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ This package uses the [Gradle](https://docs.gradle.org/current/userguide/usergui
2. `./gradlew :run` launches a single node cluster with ml-commons plugin installed
3. `./gradlew :integTest` launches a single node cluster with ml-commons plugin installed and runs all integration tests except security
4. ` ./gradlew :integTest --tests="**.test execute foo"` runs a single integration test class or method
5. `./gradlew spotlessApply` formats code. And/or import formatting rules in `.eclipseformat.xml` with IDE.
5. `./gradlew integTest -Dtests.rest.cluster=localhost:9200 -Dtests.cluster=localhost:9200 -Dtests.clustername="docker-cluster" -Dhttps=true -Duser=admin -Dpassword=admin` launches integration tests against a local cluster and run tests with security
6. `./gradlew spotlessApply` formats code. And/or import formatting rules in `.eclipseformat.xml` with IDE.

When launching a cluster using one of the above commands logs are placed in `/build/cluster/run node0/opensearch-<version>/logs`. Though the logs are teed to the console, in practices it's best to check the actual log file.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ public void predict() {
int anomalyCount = 0;
for (int i = 0 ;i<dataSize; i++) {
if (i % 100 == 0) {
System.out.println(predictions.getRow(i).getValue(1).doubleValue());
if (predictions.getRow(i).getValue(1).doubleValue() > 0.01) {
anomalyCount++;
}
Expand Down
50 changes: 34 additions & 16 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
* SPDX-License-Identifier: Apache-2.0
*/

import java.util.concurrent.Callable
import org.opensearch.gradle.test.RestIntegTestTask
import org.opensearch.gradle.testclusters.StandaloneRestIntegTestTask

plugins {
id 'java'
Expand Down Expand Up @@ -43,24 +45,12 @@ dependencies {
compile "org.opensearch:common-utils:${common_utils_version}"
compile("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}")
compile("com.fasterxml.jackson.core:jackson-databind:${versions.jackson}")
compile group: 'com.google.guava', name: 'guava', version:'29.0-jre'
implementation group: 'com.google.guava', name: 'guava', version: '31.0.1-jre'
implementation group: 'com.google.code.gson', name: 'gson', version: '2.9.0'

checkstyle "com.puppycrawl.tools:checkstyle:${project.checkstyle.toolVersion}"
}

test {
include '**/*Test.class'
systemProperty 'tests.security.manager', 'false'
finalizedBy jacocoTestReport
}

jacocoTestReport {
reports {
xml.enabled true
html.enabled true
csv.enabled true
}
}

compileJava {
options.compilerArgs.addAll(["-processor", 'lombok.launch.AnnotationProcessorHider$AnnotationProcessor'])
Expand All @@ -81,14 +71,16 @@ loggerUsageCheck.enabled = false

def _numNodes = findProperty('numNodes') as Integer ?: 1

def opensearch_tmp_dir = rootProject.file('build/private/opensearch_tmp').absoluteFile
opensearch_tmp_dir.mkdirs()

test {
include '**/*Tests.class'
systemProperty 'tests.security.manager', 'false'
}

def opensearch_tmp_dir = rootProject.file('build/private/opensearch_tmp').absoluteFile
opensearch_tmp_dir.mkdirs()


task integTest(type: RestIntegTestTask) {
description = "Run tests against a cluster"
testClassesDirs = sourceSets.test.output.classesDirs
Expand All @@ -105,6 +97,13 @@ integTest {
systemProperty "user", System.getProperty("user")
systemProperty "password", System.getProperty("password")

// Only rest case can run with remote cluster
if (System.getProperty("tests.rest.cluster") != null) {
filter {
includeTestsMatching "org.opensearch.ml.rest.*IT"
}
}

// The 'doFirst' delays till execution time.
doFirst {
// Tell the test JVM if the cluster JVM is running under a debugger so that tests can
Expand Down Expand Up @@ -150,6 +149,24 @@ testClusters.integTest {
}
}

task integTestRemote(type: RestIntegTestTask) {
testClassesDirs = sourceSets.test.output.classesDirs
classpath = sourceSets.test.runtimeClasspath
systemProperty 'tests.security.manager', 'false'
systemProperty 'java.io.tmpdir', opensearch_tmp_dir.absolutePath

systemProperty "https", System.getProperty("https")
systemProperty "user", System.getProperty("user")
systemProperty "password", System.getProperty("password")

// Only rest case can run with remote cluster
if (System.getProperty("tests.rest.cluster") != null) {
filter {
includeTestsMatching "org.opensearch.ml.rest.*IT"
}
}
}

run {
doFirst {
// There seems to be an issue when running multi node run or integ tasks with unicast_hosts
Expand All @@ -172,6 +189,7 @@ task release(type: Copy, group: 'build') {
jacocoTestReport {
reports {
xml.enabled true
html.enabled true
csv.enabled false
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
package org.opensearch.ml.rest;

import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_ENABLED;
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH;
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD;
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_PASSWORD;
import static org.opensearch.commons.ConfigConstants.OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH;

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.file.Path;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

import org.apache.http.Header;
import org.apache.http.HttpHeaders;
import org.apache.http.HttpHost;
import org.apache.http.auth.AuthScope;
import org.apache.http.auth.UsernamePasswordCredentials;
import org.apache.http.client.CredentialsProvider;
import org.apache.http.conn.ssl.NoopHostnameVerifier;
import org.apache.http.impl.client.BasicCredentialsProvider;
import org.apache.http.message.BasicHeader;
import org.apache.http.ssl.SSLContextBuilder;
import org.apache.http.util.EntityUtils;
import org.junit.After;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.RestClient;
import org.opensearch.client.RestClientBuilder;
import org.opensearch.common.io.PathUtils;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.DeprecationHandler;
import org.opensearch.common.xcontent.NamedXContentRegistry;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.commons.rest.SecureRestClientBuilder;
import org.opensearch.ml.utils.TestData;
import org.opensearch.ml.utils.TestHelper;
import org.opensearch.rest.RestStatus;
import org.opensearch.test.rest.OpenSearchRestTestCase;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

public abstract class MLCommonsRestTestCase extends OpenSearchRestTestCase {

protected boolean isHttps() {
boolean isHttps = Optional.ofNullable(System.getProperty("https")).map("true"::equalsIgnoreCase).orElse(false);
if (isHttps) {
// currently only external cluster is supported for security enabled testing
if (!Optional.ofNullable(System.getProperty("tests.rest.cluster")).isPresent()) {
throw new RuntimeException("cluster url should be provided for security enabled testing");
}
}

return isHttps;
}

@Override
protected String getProtocol() {
return isHttps() ? "https" : "http";
}

@Override
protected Settings restAdminSettings() {
return Settings
.builder()
// disable the warning exception for admin client since it's only used for cleanup.
.put("strictDeprecationMode", false)
.put("http.port", 9200)
.put(OPENSEARCH_SECURITY_SSL_HTTP_ENABLED, isHttps())
.put(OPENSEARCH_SECURITY_SSL_HTTP_PEMCERT_FILEPATH, "sample.pem")
.put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH, "test-kirk.jks")
.put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_PASSWORD, "changeit")
.put(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_KEYPASSWORD, "changeit")
.build();
}

// Utility fn for deleting indices. Should only be used when not allowed in a regular context
// (e.g., deleting system indices)
protected static void deleteIndexWithAdminClient(String name) throws IOException {
Request request = new Request("DELETE", "/" + name);
adminClient().performRequest(request);
}

// Utility fn for checking if an index exists. Should only be used when not allowed in a regular context
// (e.g., checking existence of system indices)
protected static boolean indexExistsWithAdminClient(String indexName) throws IOException {
Request request = new Request("HEAD", "/" + indexName);
Response response = adminClient().performRequest(request);
return RestStatus.OK.getStatus() == response.getStatusLine().getStatusCode();
}

@Override
protected RestClient buildClient(Settings settings, HttpHost[] hosts) throws IOException {
boolean strictDeprecationMode = settings.getAsBoolean("strictDeprecationMode", true);
RestClientBuilder builder = RestClient.builder(hosts);
if (isHttps()) {
String keystore = settings.get(OPENSEARCH_SECURITY_SSL_HTTP_KEYSTORE_FILEPATH);
if (Objects.nonNull(keystore)) {
URI uri = null;
try {
uri = this.getClass().getClassLoader().getResource("security/sample.pem").toURI();
} catch (URISyntaxException e) {
throw new RuntimeException(e);
}
Path configPath = PathUtils.get(uri).getParent().toAbsolutePath();
return new SecureRestClientBuilder(settings, configPath).build();
} else {
configureHttpsClient(builder, settings);
builder.setStrictDeprecationMode(strictDeprecationMode);
return builder.build();
}

} else {
configureClient(builder, settings);
builder.setStrictDeprecationMode(strictDeprecationMode);
return builder.build();
}

}

@SuppressWarnings("unchecked")
@After
protected void wipeAllODFEIndices() throws IOException {
Response response = adminClient().performRequest(new Request("GET", "/_cat/indices?format=json&expand_wildcards=all"));
XContentType xContentType = XContentType.fromMediaTypeOrFormat(response.getEntity().getContentType().getValue());
try (
XContentParser parser = xContentType
.xContent()
.createParser(
NamedXContentRegistry.EMPTY,
DeprecationHandler.THROW_UNSUPPORTED_OPERATION,
response.getEntity().getContent()
)
) {
XContentParser.Token token = parser.nextToken();
List<Map<String, Object>> parserList = null;
if (token == XContentParser.Token.START_ARRAY) {
parserList = parser.listOrderedMap().stream().map(obj -> (Map<String, Object>) obj).collect(Collectors.toList());
} else {
parserList = Collections.singletonList(parser.mapOrdered());
}

for (Map<String, Object> index : parserList) {
String indexName = (String) index.get("index");
if (indexName != null && !".opendistro_security".equals(indexName)) {
adminClient().performRequest(new Request("DELETE", "/" + indexName));
}
}
}
}

protected static void configureHttpsClient(RestClientBuilder builder, Settings settings) throws IOException {
Map<String, String> headers = ThreadContext.buildDefaultHeaders(settings);
Header[] defaultHeaders = new Header[headers.size()];
int i = 0;
for (Map.Entry<String, String> entry : headers.entrySet()) {
defaultHeaders[i++] = new BasicHeader(entry.getKey(), entry.getValue());
}
builder.setDefaultHeaders(defaultHeaders);
builder.setHttpClientConfigCallback(httpClientBuilder -> {
String userName = Optional
.ofNullable(System.getProperty("user"))
.orElseThrow(() -> new RuntimeException("user name is missing"));
String password = Optional
.ofNullable(System.getProperty("password"))
.orElseThrow(() -> new RuntimeException("password is missing"));
CredentialsProvider credentialsProvider = new BasicCredentialsProvider();
credentialsProvider.setCredentials(AuthScope.ANY, new UsernamePasswordCredentials(userName, password));
try {
return httpClientBuilder
.setDefaultCredentialsProvider(credentialsProvider)
// disable the certificate since our testing cluster just uses the default security configuration
.setSSLHostnameVerifier(NoopHostnameVerifier.INSTANCE)
.setSSLContext(SSLContextBuilder.create().loadTrustMaterial(null, (chains, authType) -> true).build());
} catch (Exception e) {
throw new RuntimeException(e);
}
});

final String socketTimeoutString = settings.get(CLIENT_SOCKET_TIMEOUT);
final TimeValue socketTimeout = TimeValue
.parseTimeValue(socketTimeoutString == null ? "60s" : socketTimeoutString, CLIENT_SOCKET_TIMEOUT);
builder.setRequestConfigCallback(conf -> conf.setSocketTimeout(Math.toIntExact(socketTimeout.getMillis())));
if (settings.hasValue(CLIENT_PATH_PREFIX)) {
builder.setPathPrefix(settings.get(CLIENT_PATH_PREFIX));
}
}

/**
* wipeAllIndices won't work since it cannot delete security index. Use wipeAllODFEIndices instead.
*/
@Override
protected boolean preserveIndicesUponCompletion() {
return true;
}

protected Response ingestIrisData(String indexName) throws IOException {
String irisDataIndexMapping = "";
TestHelper
.makeRequest(
client(),
"PUT",
indexName,
null,
TestHelper.toHttpEntity(irisDataIndexMapping),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
);

Response statsResponse = TestHelper.makeRequest(client(), "GET", indexName, ImmutableMap.of(), "", null);
assertEquals(RestStatus.OK, TestHelper.restStatus(statsResponse));
String result = EntityUtils.toString(statsResponse.getEntity());
assertTrue(result.contains(indexName));

Response bulkResponse = TestHelper
.makeRequest(
client(),
"POST",
"_bulk?refresh=true",
null,
TestHelper.toHttpEntity(TestData.IRIS_DATA),
ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, ""))
);
assertEquals(RestStatus.OK, TestHelper.restStatus(statsResponse));
return bulkResponse;
}
}
Loading