diff --git a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java b/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java index 2280e6d2c2..9965734a31 100644 --- a/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java +++ b/plugin/src/main/java/org/opensearch/ml/indices/MLIndicesHandler.java @@ -12,8 +12,10 @@ import org.opensearch.action.ActionListener; import org.opensearch.action.admin.indices.create.CreateIndexRequest; +import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.XContentType; @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @@ -87,19 +89,24 @@ public void initMLTaskIndex(ActionListener listener) { public void initMLIndexIfAbsent(String indexName, String mapping, ActionListener listener) { if (!clusterService.state().metadata().hasIndex(indexName)) { - CreateIndexRequest request = new CreateIndexRequest(indexName).mapping("_doc", mapping, XContentType.JSON); - - client.admin().indices().create(request, ActionListener.wrap(r -> { - if (r.isAcknowledged()) { - log.info("create index:{}", indexName); - listener.onResponse(true); - } else { - listener.onResponse(false); - } - }, e -> { - log.error("Failed to create index " + indexName, e); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + ActionListener actionListener = ActionListener.wrap(r -> { + if (r.isAcknowledged()) { + log.info("create index:{}", indexName); + listener.onResponse(true); + } else { + listener.onResponse(false); + } + }, e -> { + log.error("Failed to create index " + indexName, e); + listener.onFailure(e); + }); + CreateIndexRequest request = new CreateIndexRequest(indexName).mapping("_doc", mapping, XContentType.JSON); + client.admin().indices().create(request, ActionListener.runBefore(actionListener, () -> threadContext.restore())); + } catch (Exception e) { + log.error("Failed to init index " + indexName, e); listener.onFailure(e); - })); + } } else { log.info("index:{} is already created", indexName); listener.onResponse(true); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java index 3cbea23167..5927c9fcef 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java @@ -260,10 +260,10 @@ protected Response ingestIrisData(String indexName) throws IOException { protected void validateStats( FunctionName functionName, ActionName actionName, - int expectedTotalFailureCount, - int expectedTotalAlgoFailureCount, - int expectedMinumnTotalRequestCount, - int expectedTotalAlgoRequestCount + int expectedMinimumTotalFailureCount, + int expectedMinimumTotalAlgoFailureCount, + int expectedMinimumTotalRequestCount, + int expectedMinimumTotalAlgoRequestCount ) throws IOException { Response statsResponse = TestHelper.makeRequest(client(), "GET", "_plugins/_ml/stats", null, "", null); HttpEntity entity = statsResponse.getEntity(); @@ -291,10 +291,10 @@ protected void validateStats( totalAlgoRequestCount += (Double) nodeStatsMap.get(requestCountStat); } } - assertEquals(expectedTotalFailureCount, totalFailureCount); - assertEquals(expectedTotalAlgoFailureCount, totalAlgoFailureCount); - assertTrue(totalRequestCount >= expectedMinumnTotalRequestCount); - assertEquals(expectedTotalAlgoRequestCount, totalAlgoRequestCount); + assertTrue(totalFailureCount >= expectedMinimumTotalFailureCount); + assertTrue(totalAlgoFailureCount >= expectedMinimumTotalAlgoFailureCount); + assertTrue(totalRequestCount >= expectedMinimumTotalRequestCount); + assertTrue(totalAlgoRequestCount >= expectedMinimumTotalAlgoRequestCount); } protected Response ingestModelData() throws IOException { @@ -464,4 +464,97 @@ public void trainAndPredict( function.accept(predictionResult); } } + + public void train( + RestClient client, + FunctionName functionName, + String indexName, + MLAlgoParams params, + SearchSourceBuilder searchSourceBuilder, + Consumer> function, + boolean async + ) throws IOException { + MLInputDataset inputData = SearchQueryInputDataset + .builder() + .indices(ImmutableList.of(indexName)) + .searchSourceBuilder(searchSourceBuilder) + .build(); + MLInput kmeansInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputData).build(); + String endpoint = "/_plugins/_ml/_train/" + functionName.name().toLowerCase(Locale.ROOT); + if (async) { + endpoint += "?async=true"; + } + Response response = TestHelper.makeRequest(client, "POST", endpoint, ImmutableMap.of(), TestHelper.toHttpEntity(kmeansInput), null); + verifyResponse(function, response); + } + + public void predict( + RestClient client, + FunctionName functionName, + String modelId, + String indexName, + MLAlgoParams params, + SearchSourceBuilder searchSourceBuilder, + Consumer> function + ) throws IOException { + MLInputDataset inputData = SearchQueryInputDataset + .builder() + .indices(ImmutableList.of(indexName)) + .searchSourceBuilder(searchSourceBuilder) + .build(); + MLInput kmeansInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputData).build(); + String endpoint = "/_plugins/_ml/_predict/" + functionName.name().toLowerCase(Locale.ROOT) + "/" + modelId; + Response response = TestHelper.makeRequest(client, "POST", endpoint, ImmutableMap.of(), TestHelper.toHttpEntity(kmeansInput), null); + verifyResponse(function, response); + } + + public void getModel(RestClient client, String modelId, Consumer> function) throws IOException { + Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/models/" + modelId, null, "", null); + verifyResponse(function, response); + } + + public void getTask(RestClient client, String taskId, Consumer> function) throws IOException { + Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/tasks/" + taskId, null, "", null); + verifyResponse(function, response); + } + + public void deleteModel(RestClient client, String modelId, Consumer> function) throws IOException { + Response response = TestHelper.makeRequest(client, "DELETE", "/_plugins/_ml/models/" + modelId, null, "", null); + verifyResponse(function, response); + } + + public void deleteTask(RestClient client, String taskId, Consumer> function) throws IOException { + Response response = TestHelper.makeRequest(client, "DELETE", "/_plugins/_ml/tasks/" + taskId, null, "", null); + verifyResponse(function, response); + } + + public void searchModelsWithAlgoName(RestClient client, String algoName, Consumer> function) throws IOException { + String query = String.format(Locale.ROOT, "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"algorithm\":\"%s\"}}]}}}", algoName); + searchModels(client, query, function); + } + + public void searchModels(RestClient client, String query, Consumer> function) throws IOException { + Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/models/_search", null, query, null); + verifyResponse(function, response); + } + + public void searchTasksWithAlgoName(RestClient client, String algoName, Consumer> function) throws IOException { + String query = String.format(Locale.ROOT, "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"function_name\":\"%s\"}}]}}}", algoName); + searchTasks(client, query, function); + } + + public void searchTasks(RestClient client, String query, Consumer> function) throws IOException { + Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/tasks/_search", null, query, null); + verifyResponse(function, response); + } + + private void verifyResponse(Consumer> function, Response response) throws IOException { + HttpEntity entity = response.getEntity(); + assertNotNull(response); + String entityString = TestHelper.httpEntityToString(entity); + Map map = gson.fromJson(entityString, Map.class); + if (function != null) { + function.accept(map); + } + } } diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelActionTests.java index e6d9fb4398..a784e814c6 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLGetModelActionTests.java @@ -9,7 +9,6 @@ import org.junit.Before; import org.junit.Rule; -import org.junit.Test; import org.junit.rules.ExpectedException; import org.opensearch.common.Strings; import org.opensearch.rest.RestHandler; @@ -27,20 +26,17 @@ public void setup() { restMLGetModelAction = new RestMLGetModelAction(); } - @Test public void testConstructor() { RestMLGetModelAction mlGetModelAction = new RestMLGetModelAction(); assertNotNull(mlGetModelAction); } - @Test public void testGetName() { String actionName = restMLGetModelAction.getName(); assertFalse(Strings.isNullOrEmpty(actionName)); assertEquals("ml_get_model_action", actionName); } - @Test public void testRoutes() { List routes = restMLGetModelAction.routes(); assertNotNull(routes); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java b/plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java index 0a566bbc7a..f2ee77f25c 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java @@ -8,6 +8,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Map; import org.apache.http.HttpHost; import org.junit.After; @@ -20,8 +21,11 @@ import org.opensearch.index.query.MatchAllQueryBuilder; import org.opensearch.ml.common.parameter.FunctionName; import org.opensearch.ml.common.parameter.KMeansParams; +import org.opensearch.ml.common.parameter.MLTaskState; import org.opensearch.search.builder.SearchSourceBuilder; +import com.google.common.base.Throwables; + public class SecureMLRestIT extends MLCommonsRestTestCase { private String irisIndex = "iris_data_secure_ml_it"; @@ -129,6 +133,20 @@ public void testTrainAndPredictWithFullMLAccessNoIndexAccess() throws IOExceptio ); } + public void testTrainWithReadOnlyMLAccess() throws IOException { + exceptionRule.expect(ResponseException.class); + exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/train]"); + KMeansParams kMeansParams = KMeansParams.builder().build(); + train(mlReadOnlyClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, null, false); + } + + public void testPredictWithReadOnlyMLAccess() throws IOException { + exceptionRule.expect(ResponseException.class); + exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/predict]"); + KMeansParams kMeansParams = KMeansParams.builder().build(); + predict(mlReadOnlyClient, FunctionName.KMEANS, "modelId", irisIndex, kMeansParams, searchSourceBuilder, null); + } + public void testTrainAndPredictWithFullAccess() throws IOException { trainAndPredict( mlFullAccessClient, @@ -142,4 +160,151 @@ public void testTrainAndPredictWithFullAccess() throws IOException { } ); } + + public void testTrainModelWithFullAccessThenPredict() throws IOException { + KMeansParams kMeansParams = KMeansParams.builder().build(); + // train model + train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> { + String modelId = (String) trainResult.get("model_id"); + assertNotNull(modelId); + String status = (String) trainResult.get("status"); + assertEquals(MLTaskState.COMPLETED.name(), status); + try { + getModel(mlFullAccessClient, modelId, model -> { + String algorithm = (String) model.get("algorithm"); + assertEquals(FunctionName.KMEANS.name(), algorithm); + }); + } catch (IOException e) { + assertNull(e); + } + try { + // predict with trained model + predict(mlFullAccessClient, FunctionName.KMEANS, modelId, irisIndex, kMeansParams, searchSourceBuilder, predictResult -> { + String predictStatus = (String) predictResult.get("status"); + assertEquals(MLTaskState.COMPLETED.name(), predictStatus); + Map predictionResult = (Map) predictResult.get("prediction_result"); + ArrayList rows = (ArrayList) predictionResult.get("rows"); + assertTrue(rows.size() > 1); + }); + } catch (IOException e) { + assertNull(e); + } + }, false); + } + + public void testTrainModelInAsyncWayWithFullAccess() throws IOException { + train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, KMeansParams.builder().build(), searchSourceBuilder, trainResult -> { + assertFalse(trainResult.containsKey("model_id")); + String taskId = (String) trainResult.get("task_id"); + assertNotNull(taskId); + String status = (String) trainResult.get("status"); + assertEquals(MLTaskState.CREATED.name(), status); + try { + getTask(mlFullAccessClient, taskId, task -> { + String algorithm = (String) task.get("function_name"); + assertEquals(FunctionName.KMEANS.name(), algorithm); + }); + } catch (IOException e) { + assertNull(e); + } + }, true); + } + + public void testReadOnlyUser_CanGetModel_CanNotDeleteModel() throws IOException { + KMeansParams kMeansParams = KMeansParams.builder().build(); + // train model with full access client + train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> { + String modelId = (String) trainResult.get("model_id"); + assertNotNull(modelId); + String status = (String) trainResult.get("status"); + assertEquals(MLTaskState.COMPLETED.name(), status); + try { + // get model with readonly client + getModel(mlReadOnlyClient, modelId, model -> { + String algorithm = (String) model.get("algorithm"); + assertEquals(FunctionName.KMEANS.name(), algorithm); + }); + } catch (IOException e) { + assertNull(e); + } + try { + // Failed to delete model with read only client + deleteModel(mlReadOnlyClient, modelId, null); + throw new RuntimeException("Delete model for readonly user does not fail"); + } catch (Exception e) { + assertEquals(ResponseException.class, e.getClass()); + assertTrue(Throwables.getStackTraceAsString(e).contains("no permissions for [cluster:admin/opensearch/ml/models/delete]")); + } + }, false); + } + + public void testReadOnlyUser_CanGetTask_CanNotDeleteTask() throws IOException { + KMeansParams kMeansParams = KMeansParams.builder().build(); + // train model with full access client + train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> { + assertFalse(trainResult.containsKey("model_id")); + String taskId = (String) trainResult.get("task_id"); + assertNotNull(taskId); + String status = (String) trainResult.get("status"); + assertEquals(MLTaskState.CREATED.name(), status); + try { + // get task with readonly client + getTask(mlReadOnlyClient, taskId, task -> { + String algorithm = (String) task.get("function_name"); + assertEquals(FunctionName.KMEANS.name(), algorithm); + }); + } catch (IOException e) { + assertNull(e); + } + try { + // Failed to delete task with read only client + deleteTask(mlReadOnlyClient, taskId, null); + throw new RuntimeException("Delete task for readonly user does not fail"); + } catch (Exception e) { + assertEquals(ResponseException.class, e.getClass()); + assertTrue(Throwables.getStackTraceAsString(e).contains("no permissions for [cluster:admin/opensearch/ml/tasks/delete]")); + } + }, true); + } + + public void testReadOnlyUser_CanSearchModels() throws IOException { + KMeansParams kMeansParams = KMeansParams.builder().build(); + // train model with full access client + train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> { + String modelId = (String) trainResult.get("model_id"); + assertNotNull(modelId); + String status = (String) trainResult.get("status"); + assertEquals(MLTaskState.COMPLETED.name(), status); + try { + // search model with readonly client + searchModelsWithAlgoName(mlReadOnlyClient, FunctionName.KMEANS.name(), models -> { + ArrayList hits = (ArrayList) ((Map) models.get("hits")).get("hits"); + assertTrue(hits.size() > 0); + }); + } catch (IOException e) { + assertNull(e); + } + }, false); + } + + public void testReadOnlyUser_CanSearchTasks() throws IOException { + KMeansParams kMeansParams = KMeansParams.builder().build(); + // train model with full access client + train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> { + assertFalse(trainResult.containsKey("model_id")); + String taskId = (String) trainResult.get("task_id"); + assertNotNull(taskId); + String status = (String) trainResult.get("status"); + assertEquals(MLTaskState.CREATED.name(), status); + try { + // search tasks with readonly client + searchTasksWithAlgoName(mlReadOnlyClient, FunctionName.KMEANS.name(), tasks -> { + ArrayList hits = (ArrayList) ((Map) tasks.get("hits")).get("hits"); + assertTrue(hits.size() > 0); + }); + } catch (IOException e) { + assertNull(e); + } + }, true); + } }