Skip to content

Commit

Permalink
fix no permission to create model/task index bug;add security IT for …
Browse files Browse the repository at this point in the history
…train/predict API (#177)

* fix no permission to create model/task index bug;add security IT for train/predict API

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* add more security IT for readonly user

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* throw exception if delete model/task successfully for readonly user

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
(cherry picked from commit f12ca76)
  • Loading branch information
ylwu-amzn authored and github-actions[bot] committed Mar 9, 2022
1 parent 75c94a2 commit df01175
Show file tree
Hide file tree
Showing 4 changed files with 285 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@

import org.opensearch.action.ActionListener;
import org.opensearch.action.admin.indices.create.CreateIndexRequest;
import org.opensearch.action.admin.indices.create.CreateIndexResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentType;

@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
Expand Down Expand Up @@ -87,19 +89,24 @@ public void initMLTaskIndex(ActionListener<Boolean> listener) {

public void initMLIndexIfAbsent(String indexName, String mapping, ActionListener<Boolean> listener) {
if (!clusterService.state().metadata().hasIndex(indexName)) {
CreateIndexRequest request = new CreateIndexRequest(indexName).mapping("_doc", mapping, XContentType.JSON);

client.admin().indices().create(request, ActionListener.wrap(r -> {
if (r.isAcknowledged()) {
log.info("create index:{}", indexName);
listener.onResponse(true);
} else {
listener.onResponse(false);
}
}, e -> {
log.error("Failed to create index " + indexName, e);
try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) {
ActionListener<CreateIndexResponse> actionListener = ActionListener.wrap(r -> {
if (r.isAcknowledged()) {
log.info("create index:{}", indexName);
listener.onResponse(true);
} else {
listener.onResponse(false);
}
}, e -> {
log.error("Failed to create index " + indexName, e);
listener.onFailure(e);
});
CreateIndexRequest request = new CreateIndexRequest(indexName).mapping("_doc", mapping, XContentType.JSON);
client.admin().indices().create(request, ActionListener.runBefore(actionListener, () -> threadContext.restore()));
} catch (Exception e) {
log.error("Failed to init index " + indexName, e);
listener.onFailure(e);
}));
}
} else {
log.info("index:{} is already created", indexName);
listener.onResponse(true);
Expand Down
109 changes: 101 additions & 8 deletions plugin/src/test/java/org/opensearch/ml/rest/MLCommonsRestTestCase.java
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,10 @@ protected Response ingestIrisData(String indexName) throws IOException {
protected void validateStats(
FunctionName functionName,
ActionName actionName,
int expectedTotalFailureCount,
int expectedTotalAlgoFailureCount,
int expectedMinumnTotalRequestCount,
int expectedTotalAlgoRequestCount
int expectedMinimumTotalFailureCount,
int expectedMinimumTotalAlgoFailureCount,
int expectedMinimumTotalRequestCount,
int expectedMinimumTotalAlgoRequestCount
) throws IOException {
Response statsResponse = TestHelper.makeRequest(client(), "GET", "_plugins/_ml/stats", null, "", null);
HttpEntity entity = statsResponse.getEntity();
Expand Down Expand Up @@ -291,10 +291,10 @@ protected void validateStats(
totalAlgoRequestCount += (Double) nodeStatsMap.get(requestCountStat);
}
}
assertEquals(expectedTotalFailureCount, totalFailureCount);
assertEquals(expectedTotalAlgoFailureCount, totalAlgoFailureCount);
assertTrue(totalRequestCount >= expectedMinumnTotalRequestCount);
assertEquals(expectedTotalAlgoRequestCount, totalAlgoRequestCount);
assertTrue(totalFailureCount >= expectedMinimumTotalFailureCount);
assertTrue(totalAlgoFailureCount >= expectedMinimumTotalAlgoFailureCount);
assertTrue(totalRequestCount >= expectedMinimumTotalRequestCount);
assertTrue(totalAlgoRequestCount >= expectedMinimumTotalAlgoRequestCount);
}

protected Response ingestModelData() throws IOException {
Expand Down Expand Up @@ -464,4 +464,97 @@ public void trainAndPredict(
function.accept(predictionResult);
}
}

public void train(
RestClient client,
FunctionName functionName,
String indexName,
MLAlgoParams params,
SearchSourceBuilder searchSourceBuilder,
Consumer<Map<String, Object>> function,
boolean async
) throws IOException {
MLInputDataset inputData = SearchQueryInputDataset
.builder()
.indices(ImmutableList.of(indexName))
.searchSourceBuilder(searchSourceBuilder)
.build();
MLInput kmeansInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputData).build();
String endpoint = "/_plugins/_ml/_train/" + functionName.name().toLowerCase(Locale.ROOT);
if (async) {
endpoint += "?async=true";
}
Response response = TestHelper.makeRequest(client, "POST", endpoint, ImmutableMap.of(), TestHelper.toHttpEntity(kmeansInput), null);
verifyResponse(function, response);
}

public void predict(
RestClient client,
FunctionName functionName,
String modelId,
String indexName,
MLAlgoParams params,
SearchSourceBuilder searchSourceBuilder,
Consumer<Map<String, Object>> function
) throws IOException {
MLInputDataset inputData = SearchQueryInputDataset
.builder()
.indices(ImmutableList.of(indexName))
.searchSourceBuilder(searchSourceBuilder)
.build();
MLInput kmeansInput = MLInput.builder().algorithm(functionName).parameters(params).inputDataset(inputData).build();
String endpoint = "/_plugins/_ml/_predict/" + functionName.name().toLowerCase(Locale.ROOT) + "/" + modelId;
Response response = TestHelper.makeRequest(client, "POST", endpoint, ImmutableMap.of(), TestHelper.toHttpEntity(kmeansInput), null);
verifyResponse(function, response);
}

public void getModel(RestClient client, String modelId, Consumer<Map<String, Object>> function) throws IOException {
Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/models/" + modelId, null, "", null);
verifyResponse(function, response);
}

public void getTask(RestClient client, String taskId, Consumer<Map<String, Object>> function) throws IOException {
Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/tasks/" + taskId, null, "", null);
verifyResponse(function, response);
}

public void deleteModel(RestClient client, String modelId, Consumer<Map<String, Object>> function) throws IOException {
Response response = TestHelper.makeRequest(client, "DELETE", "/_plugins/_ml/models/" + modelId, null, "", null);
verifyResponse(function, response);
}

public void deleteTask(RestClient client, String taskId, Consumer<Map<String, Object>> function) throws IOException {
Response response = TestHelper.makeRequest(client, "DELETE", "/_plugins/_ml/tasks/" + taskId, null, "", null);
verifyResponse(function, response);
}

public void searchModelsWithAlgoName(RestClient client, String algoName, Consumer<Map<String, Object>> function) throws IOException {
String query = String.format(Locale.ROOT, "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"algorithm\":\"%s\"}}]}}}", algoName);
searchModels(client, query, function);
}

public void searchModels(RestClient client, String query, Consumer<Map<String, Object>> function) throws IOException {
Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/models/_search", null, query, null);
verifyResponse(function, response);
}

public void searchTasksWithAlgoName(RestClient client, String algoName, Consumer<Map<String, Object>> function) throws IOException {
String query = String.format(Locale.ROOT, "{\"query\":{\"bool\":{\"filter\":[{\"term\":{\"function_name\":\"%s\"}}]}}}", algoName);
searchTasks(client, query, function);
}

public void searchTasks(RestClient client, String query, Consumer<Map<String, Object>> function) throws IOException {
Response response = TestHelper.makeRequest(client, "GET", "/_plugins/_ml/tasks/_search", null, query, null);
verifyResponse(function, response);
}

private void verifyResponse(Consumer<Map<String, Object>> function, Response response) throws IOException {
HttpEntity entity = response.getEntity();
assertNotNull(response);
String entityString = TestHelper.httpEntityToString(entity);
Map<String, Object> map = gson.fromJson(entityString, Map.class);
if (function != null) {
function.accept(map);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.common.Strings;
import org.opensearch.rest.RestHandler;
Expand All @@ -27,20 +26,17 @@ public void setup() {
restMLGetModelAction = new RestMLGetModelAction();
}

@Test
public void testConstructor() {
RestMLGetModelAction mlGetModelAction = new RestMLGetModelAction();
assertNotNull(mlGetModelAction);
}

@Test
public void testGetName() {
String actionName = restMLGetModelAction.getName();
assertFalse(Strings.isNullOrEmpty(actionName));
assertEquals("ml_get_model_action", actionName);
}

@Test
public void testRoutes() {
List<RestHandler.Route> routes = restMLGetModelAction.routes();
assertNotNull(routes);
Expand Down
165 changes: 165 additions & 0 deletions plugin/src/test/java/org/opensearch/ml/rest/SecureMLRestIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Map;

import org.apache.http.HttpHost;
import org.junit.After;
Expand All @@ -20,8 +21,11 @@
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.ml.common.parameter.FunctionName;
import org.opensearch.ml.common.parameter.KMeansParams;
import org.opensearch.ml.common.parameter.MLTaskState;
import org.opensearch.search.builder.SearchSourceBuilder;

import com.google.common.base.Throwables;

public class SecureMLRestIT extends MLCommonsRestTestCase {
private String irisIndex = "iris_data_secure_ml_it";

Expand Down Expand Up @@ -129,6 +133,20 @@ public void testTrainAndPredictWithFullMLAccessNoIndexAccess() throws IOExceptio
);
}

public void testTrainWithReadOnlyMLAccess() throws IOException {
exceptionRule.expect(ResponseException.class);
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/train]");
KMeansParams kMeansParams = KMeansParams.builder().build();
train(mlReadOnlyClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, null, false);
}

public void testPredictWithReadOnlyMLAccess() throws IOException {
exceptionRule.expect(ResponseException.class);
exceptionRule.expectMessage("no permissions for [cluster:admin/opensearch/ml/predict]");
KMeansParams kMeansParams = KMeansParams.builder().build();
predict(mlReadOnlyClient, FunctionName.KMEANS, "modelId", irisIndex, kMeansParams, searchSourceBuilder, null);
}

public void testTrainAndPredictWithFullAccess() throws IOException {
trainAndPredict(
mlFullAccessClient,
Expand All @@ -142,4 +160,151 @@ public void testTrainAndPredictWithFullAccess() throws IOException {
}
);
}

public void testTrainModelWithFullAccessThenPredict() throws IOException {
KMeansParams kMeansParams = KMeansParams.builder().build();
// train model
train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
String modelId = (String) trainResult.get("model_id");
assertNotNull(modelId);
String status = (String) trainResult.get("status");
assertEquals(MLTaskState.COMPLETED.name(), status);
try {
getModel(mlFullAccessClient, modelId, model -> {
String algorithm = (String) model.get("algorithm");
assertEquals(FunctionName.KMEANS.name(), algorithm);
});
} catch (IOException e) {
assertNull(e);
}
try {
// predict with trained model
predict(mlFullAccessClient, FunctionName.KMEANS, modelId, irisIndex, kMeansParams, searchSourceBuilder, predictResult -> {
String predictStatus = (String) predictResult.get("status");
assertEquals(MLTaskState.COMPLETED.name(), predictStatus);
Map<String, Object> predictionResult = (Map<String, Object>) predictResult.get("prediction_result");
ArrayList rows = (ArrayList) predictionResult.get("rows");
assertTrue(rows.size() > 1);
});
} catch (IOException e) {
assertNull(e);
}
}, false);
}

public void testTrainModelInAsyncWayWithFullAccess() throws IOException {
train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, KMeansParams.builder().build(), searchSourceBuilder, trainResult -> {
assertFalse(trainResult.containsKey("model_id"));
String taskId = (String) trainResult.get("task_id");
assertNotNull(taskId);
String status = (String) trainResult.get("status");
assertEquals(MLTaskState.CREATED.name(), status);
try {
getTask(mlFullAccessClient, taskId, task -> {
String algorithm = (String) task.get("function_name");
assertEquals(FunctionName.KMEANS.name(), algorithm);
});
} catch (IOException e) {
assertNull(e);
}
}, true);
}

public void testReadOnlyUser_CanGetModel_CanNotDeleteModel() throws IOException {
KMeansParams kMeansParams = KMeansParams.builder().build();
// train model with full access client
train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
String modelId = (String) trainResult.get("model_id");
assertNotNull(modelId);
String status = (String) trainResult.get("status");
assertEquals(MLTaskState.COMPLETED.name(), status);
try {
// get model with readonly client
getModel(mlReadOnlyClient, modelId, model -> {
String algorithm = (String) model.get("algorithm");
assertEquals(FunctionName.KMEANS.name(), algorithm);
});
} catch (IOException e) {
assertNull(e);
}
try {
// Failed to delete model with read only client
deleteModel(mlReadOnlyClient, modelId, null);
throw new RuntimeException("Delete model for readonly user does not fail");
} catch (Exception e) {
assertEquals(ResponseException.class, e.getClass());
assertTrue(Throwables.getStackTraceAsString(e).contains("no permissions for [cluster:admin/opensearch/ml/models/delete]"));
}
}, false);
}

public void testReadOnlyUser_CanGetTask_CanNotDeleteTask() throws IOException {
KMeansParams kMeansParams = KMeansParams.builder().build();
// train model with full access client
train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
assertFalse(trainResult.containsKey("model_id"));
String taskId = (String) trainResult.get("task_id");
assertNotNull(taskId);
String status = (String) trainResult.get("status");
assertEquals(MLTaskState.CREATED.name(), status);
try {
// get task with readonly client
getTask(mlReadOnlyClient, taskId, task -> {
String algorithm = (String) task.get("function_name");
assertEquals(FunctionName.KMEANS.name(), algorithm);
});
} catch (IOException e) {
assertNull(e);
}
try {
// Failed to delete task with read only client
deleteTask(mlReadOnlyClient, taskId, null);
throw new RuntimeException("Delete task for readonly user does not fail");
} catch (Exception e) {
assertEquals(ResponseException.class, e.getClass());
assertTrue(Throwables.getStackTraceAsString(e).contains("no permissions for [cluster:admin/opensearch/ml/tasks/delete]"));
}
}, true);
}

public void testReadOnlyUser_CanSearchModels() throws IOException {
KMeansParams kMeansParams = KMeansParams.builder().build();
// train model with full access client
train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
String modelId = (String) trainResult.get("model_id");
assertNotNull(modelId);
String status = (String) trainResult.get("status");
assertEquals(MLTaskState.COMPLETED.name(), status);
try {
// search model with readonly client
searchModelsWithAlgoName(mlReadOnlyClient, FunctionName.KMEANS.name(), models -> {
ArrayList<Object> hits = (ArrayList) ((Map<String, Object>) models.get("hits")).get("hits");
assertTrue(hits.size() > 0);
});
} catch (IOException e) {
assertNull(e);
}
}, false);
}

public void testReadOnlyUser_CanSearchTasks() throws IOException {
KMeansParams kMeansParams = KMeansParams.builder().build();
// train model with full access client
train(mlFullAccessClient, FunctionName.KMEANS, irisIndex, kMeansParams, searchSourceBuilder, trainResult -> {
assertFalse(trainResult.containsKey("model_id"));
String taskId = (String) trainResult.get("task_id");
assertNotNull(taskId);
String status = (String) trainResult.get("status");
assertEquals(MLTaskState.CREATED.name(), status);
try {
// search tasks with readonly client
searchTasksWithAlgoName(mlReadOnlyClient, FunctionName.KMEANS.name(), tasks -> {
ArrayList<Object> hits = (ArrayList) ((Map<String, Object>) tasks.get("hits")).get("hits");
assertTrue(hits.size() > 0);
});
} catch (IOException e) {
assertNull(e);
}
}, true);
}
}

0 comments on commit df01175

Please sign in to comment.