Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Added Setting to Toggle Data Source Management Code Paths #2811

Merged
merged 2 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.io.IOException;
import java.util.List;
import java.util.Locale;
import lombok.RequiredArgsConstructor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
Expand All @@ -26,11 +27,14 @@
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.datasources.exceptions.DataSourceClientException;
import org.opensearch.sql.datasources.exceptions.ErrorMessage;
import org.opensearch.sql.datasources.utils.Scheduler;
import org.opensearch.sql.legacy.metrics.MetricName;
import org.opensearch.sql.legacy.utils.MetricUtils;
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
import org.opensearch.sql.opensearch.util.RestRequestUtil;
import org.opensearch.sql.spark.asyncquery.exceptions.AsyncQueryNotFoundException;
import org.opensearch.sql.spark.leasemanager.ConcurrencyLimitExceededException;
import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest;
Expand All @@ -44,13 +48,16 @@
import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionRequest;
import org.opensearch.sql.spark.transport.model.GetAsyncQueryResultActionResponse;

@RequiredArgsConstructor
public class RestAsyncQueryManagementAction extends BaseRestHandler {

public static final String ASYNC_QUERY_ACTIONS = "async_query_actions";
public static final String BASE_ASYNC_QUERY_ACTION_URL = "/_plugins/_async_query";

private static final Logger LOG = LogManager.getLogger(RestAsyncQueryManagementAction.class);

private final OpenSearchSettings settings;

@Override
public String getName() {
return ASYNC_QUERY_ACTIONS;
Expand Down Expand Up @@ -99,6 +106,9 @@ public List<Route> routes() {
@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient nodeClient)
throws IOException {
if (!dataSourcesEnabled()) {
return dataSourcesDisabledError(restRequest);
}
switch (restRequest.method()) {
case POST:
return executePostRequest(restRequest, nodeClient);
Expand Down Expand Up @@ -271,4 +281,21 @@ private void addCustomerErrorMetric(RestRequest.Method requestMethod) {
break;
}
}

private boolean dataSourcesEnabled() {
return settings.getSettingValue(Settings.Key.DATASOURCES_ENABLED);
}

private RestChannelConsumer dataSourcesDisabledError(RestRequest request) {

RestRequestUtil.consumeAllRequestParameters(request);

return channel -> {
reportError(
channel,
new IllegalAccessException(
String.format("%s setting is false", Settings.Key.DATASOURCES_ENABLED.getKeyValue())),
BAD_REQUEST);
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,10 @@ private DataSourceServiceImpl createDataSourceService() {
String masterKey = "a57d991d9b573f75b9bba1df";
DataSourceMetadataStorage dataSourceMetadataStorage =
new OpenSearchDataSourceMetadataStorage(
client, clusterService, new EncryptorImpl(masterKey));
client,
clusterService,
new EncryptorImpl(masterKey),
(OpenSearchSettings) pluginSettings);
return new DataSourceServiceImpl(
new ImmutableSet.Builder<DataSourceFactory>()
.add(new GlueDataSourceFactory(pluginSettings))
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package org.opensearch.sql.spark.rest;

import com.google.gson.Gson;
import com.google.gson.JsonObject;
import lombok.SneakyThrows;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.opensearch.client.node.NodeClient;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.rest.RestResponse;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
import org.opensearch.threadpool.ThreadPool;

public class RestAsyncQueryManagementActionTest {

private OpenSearchSettings settings;
private RestRequest request;
private RestChannel channel;
private NodeClient nodeClient;
private ThreadPool threadPool;
private RestAsyncQueryManagementAction unit;

@BeforeEach
public void setup() {
settings = Mockito.mock(OpenSearchSettings.class);
request = Mockito.mock(RestRequest.class);
channel = Mockito.mock(RestChannel.class);
nodeClient = Mockito.mock(NodeClient.class);
threadPool = Mockito.mock(ThreadPool.class);

Mockito.when(nodeClient.threadPool()).thenReturn(threadPool);

unit = new RestAsyncQueryManagementAction(settings);
}

@Test
@SneakyThrows
public void testWhenDataSourcesAreDisabled() {
setDataSourcesEnabled(false);
unit.handleRequest(request, channel, nodeClient);
Mockito.verifyNoInteractions(nodeClient);
ArgumentCaptor<RestResponse> response = ArgumentCaptor.forClass(RestResponse.class);
Mockito.verify(channel, Mockito.times(1)).sendResponse(response.capture());
Assertions.assertEquals(400, response.getValue().status().getStatus());
JsonObject actualResponseJson =
new Gson().fromJson(response.getValue().content().utf8ToString(), JsonObject.class);
JsonObject expectedResponseJson = new JsonObject();
expectedResponseJson.addProperty("status", 400);
expectedResponseJson.add("error", new JsonObject());
expectedResponseJson.getAsJsonObject("error").addProperty("type", "IllegalAccessException");
expectedResponseJson.getAsJsonObject("error").addProperty("reason", "Invalid Request");
expectedResponseJson
.getAsJsonObject("error")
.addProperty("details", "plugins.query.datasources.enabled setting is false");
Assertions.assertEquals(expectedResponseJson, actualResponseJson);
}

@Test
@SneakyThrows
public void testWhenDataSourcesAreEnabled() {
setDataSourcesEnabled(true);
Mockito.when(request.method()).thenReturn(RestRequest.Method.GET);
unit.handleRequest(request, channel, nodeClient);
Mockito.verify(threadPool, Mockito.times(1))
.schedule(ArgumentMatchers.any(), ArgumentMatchers.any(), ArgumentMatchers.any());
Mockito.verifyNoInteractions(channel);
}

@Test
public void testGetName() {
Assertions.assertEquals("async_query_actions", unit.getName());
}

private void setDataSourcesEnabled(boolean value) {
Mockito.when(settings.getSettingValue(Settings.Key.DATASOURCES_ENABLED)).thenReturn(value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public enum Key {
ENCYRPTION_MASTER_KEY("plugins.query.datasources.encryption.masterkey"),
DATASOURCES_URI_HOSTS_DENY_LIST("plugins.query.datasources.uri.hosts.denylist"),
DATASOURCES_LIMIT("plugins.query.datasources.limit"),
DATASOURCES_ENABLED("plugins.query.datasources.enabled"),

METRICS_ROLLING_WINDOW("plugins.query.metrics.rolling_window"),
METRICS_ROLLING_INTERVAL("plugins.query.metrics.rolling_interval"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,20 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import lombok.RequiredArgsConstructor;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.OpenSearchSecurityException;
import org.opensearch.OpenSearchStatusException;
import org.opensearch.client.node.NodeClient;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.rest.BaseRestHandler;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestChannel;
import org.opensearch.rest.RestRequest;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException;
import org.opensearch.sql.datasources.exceptions.ErrorMessage;
Expand All @@ -37,14 +40,19 @@
import org.opensearch.sql.datasources.utils.XContentParserUtils;
import org.opensearch.sql.legacy.metrics.MetricName;
import org.opensearch.sql.legacy.utils.MetricUtils;
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;
import org.opensearch.sql.opensearch.util.RestRequestUtil;

@RequiredArgsConstructor
public class RestDataSourceQueryAction extends BaseRestHandler {

public static final String DATASOURCE_ACTIONS = "datasource_actions";
public static final String BASE_DATASOURCE_ACTION_URL = "/_plugins/_query/_datasources";

private static final Logger LOG = LogManager.getLogger(RestDataSourceQueryAction.class);

private final OpenSearchSettings settings;

@Override
public String getName() {
return DATASOURCE_ACTIONS;
Expand Down Expand Up @@ -115,6 +123,9 @@ public List<Route> routes() {
@Override
protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient nodeClient)
throws IOException {
if (!enabled()) {
return disabledError(restRequest);
}
switch (restRequest.method()) {
case POST:
return executePostRequest(restRequest, nodeClient);
Expand Down Expand Up @@ -314,4 +325,22 @@ private static boolean isClientError(Exception e) {
|| e instanceof IllegalArgumentException
|| e instanceof IllegalStateException;
}

private boolean enabled() {
return settings.getSettingValue(Settings.Key.DATASOURCES_ENABLED);
}

private RestChannelConsumer disabledError(RestRequest request) {

RestRequestUtil.consumeAllRequestParameters(request);

return channel -> {
reportError(
channel,
new OpenSearchStatusException(
String.format("%s setting is false", Settings.Key.DATASOURCES_ENABLED.getKeyValue()),
BAD_REQUEST),
BAD_REQUEST);
};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,13 @@
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.search.SearchHit;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.sql.common.setting.Settings;
import org.opensearch.sql.datasource.model.DataSourceMetadata;
import org.opensearch.sql.datasources.encryptor.Encryptor;
import org.opensearch.sql.datasources.exceptions.DataSourceNotFoundException;
import org.opensearch.sql.datasources.service.DataSourceMetadataStorage;
import org.opensearch.sql.datasources.utils.XContentParserUtils;
import org.opensearch.sql.opensearch.setting.OpenSearchSettings;

public class OpenSearchDataSourceMetadataStorage implements DataSourceMetadataStorage {

Expand All @@ -61,6 +63,7 @@ public class OpenSearchDataSourceMetadataStorage implements DataSourceMetadataSt
private final ClusterService clusterService;

private final Encryptor encryptor;
private final OpenSearchSettings settings;

/**
* This class implements DataSourceMetadataStorage interface using OpenSearch as underlying
Expand All @@ -71,14 +74,21 @@ public class OpenSearchDataSourceMetadataStorage implements DataSourceMetadataSt
* @param encryptor Encryptor.
*/
public OpenSearchDataSourceMetadataStorage(
Client client, ClusterService clusterService, Encryptor encryptor) {
Client client,
ClusterService clusterService,
Encryptor encryptor,
OpenSearchSettings settings) {
this.client = client;
this.clusterService = clusterService;
this.encryptor = encryptor;
this.settings = settings;
}

@Override
public List<DataSourceMetadata> getDataSourceMetadata() {
if (!isEnabled()) {
return Collections.emptyList();
}
if (!this.clusterService.state().routingTable().hasIndex(DATASOURCE_INDEX_NAME)) {
createDataSourcesIndex();
return Collections.emptyList();
Expand All @@ -88,6 +98,9 @@ public List<DataSourceMetadata> getDataSourceMetadata() {

@Override
public Optional<DataSourceMetadata> getDataSourceMetadata(String datasourceName) {
if (!isEnabled()) {
return Optional.empty();
}
if (!this.clusterService.state().routingTable().hasIndex(DATASOURCE_INDEX_NAME)) {
createDataSourcesIndex();
return Optional.empty();
Expand All @@ -101,6 +114,9 @@ public Optional<DataSourceMetadata> getDataSourceMetadata(String datasourceName)

@Override
public void createDataSourceMetadata(DataSourceMetadata dataSourceMetadata) {
if (!isEnabled()) {
throw new IllegalStateException("Data source management is disabled");
}
encryptDecryptAuthenticationData(dataSourceMetadata, true);
if (!this.clusterService.state().routingTable().hasIndex(DATASOURCE_INDEX_NAME)) {
createDataSourcesIndex();
Expand Down Expand Up @@ -134,6 +150,9 @@ public void createDataSourceMetadata(DataSourceMetadata dataSourceMetadata) {

@Override
public void updateDataSourceMetadata(DataSourceMetadata dataSourceMetadata) {
if (!isEnabled()) {
throw new IllegalStateException("Data source management is disabled");
}
encryptDecryptAuthenticationData(dataSourceMetadata, true);
UpdateRequest updateRequest =
new UpdateRequest(DATASOURCE_INDEX_NAME, dataSourceMetadata.getName());
Expand Down Expand Up @@ -163,6 +182,9 @@ public void updateDataSourceMetadata(DataSourceMetadata dataSourceMetadata) {

@Override
public void deleteDataSourceMetadata(String datasourceName) {
if (!isEnabled()) {
throw new IllegalStateException("Data source management is disabled");
}
DeleteRequest deleteRequest = new DeleteRequest(DATASOURCE_INDEX_NAME);
deleteRequest.id(datasourceName);
deleteRequest.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);
Expand Down Expand Up @@ -302,4 +324,8 @@ private void handleSigV4PropertiesEncryptionDecryption(
.ifPresent(list::add);
encryptOrDecrypt(propertiesMap, isEncryption, list);
}

private boolean isEnabled() {
return settings.getSettingValue(Settings.Key.DATASOURCES_ENABLED);
}
}
Loading
Loading