Skip to content

Commit

Permalink
add feature flag for offline batch ingestion (#2982)
Browse files Browse the repository at this point in the history
* add feature flag for offline batch ingestion

Signed-off-by: Xun Zhang <xunzh@amazon.com>

* add feature flag for offline batch inference

Signed-off-by: Xun Zhang <xunzh@amazon.com>

---------

Signed-off-by: Xun Zhang <xunzh@amazon.com>
(cherry picked from commit 107b916)
  • Loading branch information
Zhangxunmt authored and github-actions[bot] committed Sep 25, 2024
1 parent 18a15e8 commit 73cc356
Show file tree
Hide file tree
Showing 13 changed files with 171 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
import static org.opensearch.ml.utils.MLExceptionUtils.OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG;

import java.time.Instant;
import java.util.List;
Expand All @@ -35,6 +36,7 @@
import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse;
import org.opensearch.ml.engine.MLEngineClassLoader;
import org.opensearch.ml.engine.ingest.Ingestable;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.ml.utils.MLExceptionUtils;
import org.opensearch.tasks.Task;
Expand All @@ -55,27 +57,33 @@ public class TransportBatchIngestionAction extends HandledTransportAction<Action
MLTaskManager mlTaskManager;
private final Client client;
private ThreadPool threadPool;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public TransportBatchIngestionAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
MLTaskManager mlTaskManager,
ThreadPool threadPool
ThreadPool threadPool,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLBatchIngestionAction.NAME, transportService, actionFilters, MLBatchIngestionRequest::new);
this.transportService = transportService;
this.client = client;
this.mlTaskManager = mlTaskManager;
this.threadPool = threadPool;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<MLBatchIngestionResponse> listener) {
MLBatchIngestionRequest mlBatchIngestionRequest = MLBatchIngestionRequest.fromActionRequest(request);
MLBatchIngestionInput mlBatchIngestionInput = mlBatchIngestionRequest.getMlBatchIngestionInput();
try {
if (!mlFeatureEnabledSetting.isOfflineBatchIngestionEnabled()) {
throw new IllegalStateException(OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG);
}
validateBatchIngestInput(mlBatchIngestionInput);
MLTask mlTask = MLTask
.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import static org.opensearch.ml.common.CommonValue.ML_CONNECTOR_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;
import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.CANCEL_BATCH_PREDICT;
import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;

import java.util.HashMap;
Expand Down Expand Up @@ -51,8 +52,8 @@
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.script.ScriptService;
import org.opensearch.tasks.Task;
Expand All @@ -74,7 +75,7 @@ public class CancelBatchJobTransportAction extends HandledTransportAction<Action
MLModelManager mlModelManager;

MLTaskManager mlTaskManager;
MLModelCacheHelper modelCacheHelper;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

@Inject
public CancelBatchJobTransportAction(
Expand All @@ -87,7 +88,8 @@ public CancelBatchJobTransportAction(
ConnectorAccessControlHelper connectorAccessControlHelper,
EncryptorImpl encryptor,
MLTaskManager mlTaskManager,
MLModelManager mlModelManager
MLModelManager mlModelManager,
MLFeatureEnabledSetting mlFeatureEnabledSetting
) {
super(MLCancelBatchJobAction.NAME, transportService, actionFilters, MLCancelBatchJobRequest::new);
this.client = client;
Expand All @@ -98,6 +100,7 @@ public CancelBatchJobTransportAction(
this.encryptor = encryptor;
this.mlTaskManager = mlTaskManager;
this.mlModelManager = mlModelManager;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;
}

@Override
Expand All @@ -116,6 +119,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLCanc
MLTask mlTask = MLTask.parse(parser);

// reject batch prediction tasks while offline batch inference is disabled; otherwise check if function is remote and task is of type batch prediction
if (mlTask.getTaskType() == MLTaskType.BATCH_PREDICTION
&& !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG);
}
if (mlTask.getTaskType() == MLTaskType.BATCH_PREDICTION && mlTask.getFunctionName() == FunctionName.REMOTE) {
processRemoteBatchPrediction(mlTask, actionListener);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_COMPLETED_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_FIELD;
import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;
import static org.opensearch.ml.utils.MLNodeUtils.createXContentParserFromRegistry;

Expand Down Expand Up @@ -68,8 +69,8 @@
import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorExecutor;
import org.opensearch.ml.engine.encryptor.EncryptorImpl;
import org.opensearch.ml.helper.ConnectorAccessControlHelper;
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.script.ScriptService;
import org.opensearch.tasks.Task;
Expand All @@ -91,7 +92,7 @@ public class GetTaskTransportAction extends HandledTransportAction<ActionRequest
MLModelManager mlModelManager;

MLTaskManager mlTaskManager;
MLModelCacheHelper modelCacheHelper;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

volatile List<String> remoteJobStatusFields;
volatile Pattern remoteJobCompletedStatusRegexPattern;
Expand All @@ -111,6 +112,7 @@ public GetTaskTransportAction(
EncryptorImpl encryptor,
MLTaskManager mlTaskManager,
MLModelManager mlModelManager,
MLFeatureEnabledSetting mlFeatureEnabledSetting,
Settings settings
) {
super(MLTaskGetAction.NAME, transportService, actionFilters, MLTaskGetRequest::new);
Expand All @@ -122,6 +124,7 @@ public GetTaskTransportAction(
this.encryptor = encryptor;
this.mlTaskManager = mlTaskManager;
this.mlModelManager = mlModelManager;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;

remoteJobStatusFields = ML_COMMONS_REMOTE_JOB_STATUS_FIELD.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_REMOTE_JOB_STATUS_FIELD, it -> remoteJobStatusFields = it);
Expand Down Expand Up @@ -178,6 +181,10 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
MLTask mlTask = MLTask.parse(parser);

// reject batch prediction tasks while offline batch inference is disabled; otherwise check if function is remote and task is of type batch prediction
if (mlTask.getTaskType() == MLTaskType.BATCH_PREDICTION
&& !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG);
}
if (mlTask.getTaskType() == MLTaskType.BATCH_PREDICTION && mlTask.getFunctionName() == FunctionName.REMOTE) {
processRemoteBatchPrediction(mlTask, taskId, actionListener);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,9 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLED_REGEX,
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_CANCELLING_REGEX,
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX,
MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED
MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED,
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED,
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED
);
return settings;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import com.google.common.collect.ImmutableList;

public class RestMLGetTaskAction extends BaseRestHandler {
private static final String ML_GET_Task_ACTION = "ml_get_task_action";
private static final String ML_GET_TASK_ACTION = "ml_get_task_action";

/**
* Constructor
Expand All @@ -33,7 +33,7 @@ public RestMLGetTaskAction() {}

@Override
public String getName() {
return ML_GET_Task_ACTION;
return ML_GET_TASK_ACTION;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.MLExceptionUtils.BATCH_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.LOCAL_MODEL_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.MLExceptionUtils.REMOTE_INFERENCE_DISABLED_ERR_MSG;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
Expand Down Expand Up @@ -131,6 +132,8 @@ MLPredictionTaskRequest getRequest(String modelId, String algorithm, RestRequest
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
} else if (FunctionName.isDLModel(FunctionName.from(algorithm.toUpperCase())) && !mlFeatureEnabledSetting.isLocalModelEnabled()) {
throw new IllegalStateException(LOCAL_MODEL_DISABLED_ERR_MSG);
} else if (ActionType.BATCH_PREDICT == actionType && !mlFeatureEnabledSetting.isOfflineBatchInferenceEnabled()) {
throw new IllegalStateException(BATCH_INFERENCE_DISABLED_ERR_MSG);
} else if (!ActionType.isValidActionInModelPrediction(actionType)) {
throw new IllegalArgumentException("Wrong action type in the rest request path!");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ private MLCommonsSettings() {}
public static final Setting<Boolean> ML_COMMONS_CONNECTOR_ACCESS_CONTROL_ENABLED = Setting
.boolSetting("plugins.ml_commons.connector_access_control_enabled", false, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Boolean> ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED = Setting
.boolSetting("plugins.ml_commons.offline_batch_ingestion_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Boolean> ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED = Setting
.boolSetting("plugins.ml_commons.offline_batch_inference_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<List<String>> ML_COMMONS_TRUSTED_CONNECTOR_ENDPOINTS_REGEX = Setting
.listSetting(
"plugins.ml_commons.trusted_connector_endpoints_regex",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_LOCAL_MODEL_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_REMOTE_INFERENCE_ENABLED;

import java.util.concurrent.atomic.AtomicBoolean;
Expand All @@ -27,13 +29,17 @@ public class MLFeatureEnabledSetting {
private volatile AtomicBoolean isConnectorPrivateIpEnabled;

private volatile Boolean isControllerEnabled;
private volatile Boolean isBatchIngestionEnabled;
private volatile Boolean isBatchInferenceEnabled;

public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings) {
isRemoteInferenceEnabled = ML_COMMONS_REMOTE_INFERENCE_ENABLED.get(settings);
isAgentFrameworkEnabled = ML_COMMONS_AGENT_FRAMEWORK_ENABLED.get(settings);
isLocalModelEnabled = ML_COMMONS_LOCAL_MODEL_ENABLED.get(settings);
isConnectorPrivateIpEnabled = new AtomicBoolean(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED.get(settings));
isControllerEnabled = ML_COMMONS_CONTROLLER_ENABLED.get(settings);
isBatchIngestionEnabled = ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED.get(settings);
isBatchInferenceEnabled = ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED.get(settings);

clusterService
.getClusterSettings()
Expand All @@ -46,6 +52,12 @@ public MLFeatureEnabledSetting(ClusterService clusterService, Settings settings)
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_CONNECTOR_PRIVATE_IP_ENABLED, it -> isConnectorPrivateIpEnabled.set(it));
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_CONTROLLER_ENABLED, it -> isControllerEnabled = it);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED, it -> isBatchIngestionEnabled = it);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED, it -> isBatchInferenceEnabled = it);
}

/**
Expand Down Expand Up @@ -84,4 +96,19 @@ public Boolean isControllerEnabled() {
return isControllerEnabled;
}

/**
 * Reports the current state of the offline batch ingestion feature flag, which is read from
 * {@code plugins.ml_commons.offline_batch_ingestion_enabled} and refreshed whenever that setting is updated.
 * Callers use a {@code false} result to block offline batch ingestion requests.
 *
 * @return the latest value of the offline batch ingestion flag.
 */
public Boolean isOfflineBatchIngestionEnabled() {
    return this.isBatchIngestionEnabled;
}

/**
 * Reports the current state of the offline batch inference feature flag, which is read from
 * {@code plugins.ml_commons.offline_batch_inference_enabled} and refreshed whenever that setting is updated.
 * Callers use a {@code false} result to block offline batch inference requests.
 *
 * @return the latest value of the offline batch inference flag.
 */
public Boolean isOfflineBatchInferenceEnabled() {
    return this.isBatchInferenceEnabled;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@ public class MLExceptionUtils {
"Remote Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.remote_inference_enabled\" to true.";
public static final String LOCAL_MODEL_DISABLED_ERR_MSG =
"Local Model is currently disabled. To enable it, update the setting \"plugins.ml_commons.local_model.enabled\" to true.";
public static final String BATCH_INFERENCE_DISABLED_ERR_MSG =
"Offline Batch Inference is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_inference_enabled\" to true.";
public static final String AGENT_FRAMEWORK_DISABLED_ERR_MSG =
"Agent Framework is currently disabled. To enable it, update the setting \"plugins.ml_commons.agent_framework_enabled\" to true.";
public static final String CONTROLLER_DISABLED_ERR_MSG =
"Controller is currently disabled. To enable it, update the setting \"plugins.ml_commons.controller_enabled\" to true.";
public static final String OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG =
"Offline batch ingestion is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_ingestion_enabled\" to true.";

public static String getRootCauseMessage(final Throwable throwable) {
String message = ExceptionUtils.getRootCauseMessage(throwable);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.opensearch.ml.common.transport.batch.MLBatchIngestionInput;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionRequest;
import org.opensearch.ml.common.transport.batch.MLBatchIngestionResponse;
import org.opensearch.ml.settings.MLFeatureEnabledSetting;
import org.opensearch.ml.task.MLTaskManager;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
Expand Down Expand Up @@ -73,6 +74,8 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
ThreadPool threadPool;
@Mock
ExecutorService executorService;
@Mock
private MLFeatureEnabledSetting mlFeatureEnabledSetting;

private TransportBatchIngestionAction batchAction;
private MLBatchIngestionInput batchInput;
Expand All @@ -81,7 +84,14 @@ public class TransportBatchIngestionActionTests extends OpenSearchTestCase {
@Before
public void setup() {
MockitoAnnotations.openMocks(this);
batchAction = new TransportBatchIngestionAction(transportService, actionFilters, client, mlTaskManager, threadPool);
batchAction = new TransportBatchIngestionAction(
transportService,
actionFilters,
client,
mlTaskManager,
threadPool,
mlFeatureEnabledSetting
);

Map<String, Object> fieldMap = new HashMap<>();
fieldMap.put("chapter", "$.content[0]");
Expand All @@ -106,6 +116,8 @@ public void setup() {
.dataSources(dataSource)
.build();
when(mlBatchIngestionRequest.getMlBatchIngestionInput()).thenReturn(batchInput);

when(mlFeatureEnabledSetting.isOfflineBatchIngestionEnabled()).thenReturn(true);
}

public void test_doExecute_success() {
Expand Down Expand Up @@ -181,6 +193,18 @@ public void test_doExecute_handleSuccessRate0() {
);
}

// Verifies that doExecute reports an IllegalStateException through the listener when the
// offline batch ingestion feature flag is switched off.
public void test_doExecute_batchIngestionDisabled() {
    // setup() stubs the flag as enabled; override it for this scenario only.
    when(mlFeatureEnabledSetting.isOfflineBatchIngestionEnabled()).thenReturn(false);

    batchAction.doExecute(task, mlBatchIngestionRequest, actionListener);

    ArgumentCaptor<IllegalStateException> failureCaptor = ArgumentCaptor.forClass(IllegalStateException.class);
    verify(actionListener).onFailure(failureCaptor.capture());

    String expectedMessage =
        "Offline batch ingestion is currently disabled. To enable it, update the setting \"plugins.ml_commons.offline_batch_ingestion_enabled\" to true.";
    String actualMessage = failureCaptor.getValue().getMessage();
    assertEquals(expectedMessage, actualMessage);
}

public void test_doExecute_noDataSource() {
MLBatchIngestionInput batchInput = MLBatchIngestionInput
.builder()
Expand Down
Loading

0 comments on commit 73cc356

Please sign in to comment.