diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java index bab7303e80..c6cfcdf862 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/metrics_correlation/MetricsCorrelation.java @@ -225,14 +225,14 @@ void registerModel(ActionListener listener) throws Inte XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); modelGroup.toXContent(builder, ToXContent.EMPTY_PARAMS); createModelGroupRequest.source(builder); - client.index(createModelGroupRequest, ActionListener.wrap(r -> { + client.index(createModelGroupRequest, ActionListener.runBefore(ActionListener.wrap(r -> { client.execute(MLRegisterModelAction.INSTANCE, registerRequest, ActionListener.wrap(listener::onResponse, e -> { log.error("Failed to Register Model", e); listener.onFailure(e); })); }, e-> { listener.onFailure(e); - })); + }), () -> context.restore())); } catch (IOException e) { throw new MLException(e); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java index 739447a3f3..0d51ec02e3 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/encryptor/EncryptorImpl.java @@ -110,7 +110,7 @@ private void initMasterKey() { if (clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX)) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY); - client.get(getRequest, new LatchedActionListener(ActionListener.wrap(r -> { + client.get(getRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.wrap(r -> { if (r.isExists()) { String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY); this.masterKey = masterKey; @@ -120,7 +120,7 @@ private void initMasterKey() { }, e -> { log.error("Failed to get ML encryption master key", e); exceptionRef.set(e); - }), latch)); + }), latch), () -> context.restore())); } } else { exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR)); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java index b1db55a3db..c4dcce439f 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/DeleteConnectorTransportAction.java @@ -68,7 +68,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { + client.search(searchRequest, ActionListener.runBefore(ActionListener.wrap(searchResponse -> { SearchHit[] searchHits = searchResponse.getHits().getHits(); if (searchHits.length == 0) { deleteConnector(deleteRequest, connectorId, actionListener); @@ -92,7 +92,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener context.restore())); } catch (Exception e) { log.error(e.getMessage(), e); actionListener.onFailure(e); @@ -108,7 +108,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener actionListener) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.delete(deleteRequest, new ActionListener<>() { + client.delete(deleteRequest, ActionListener.runBefore(new ActionListener<>() { @Override public void onResponse(DeleteResponse deleteResponse) { if (deleteResponse.getResult() == DocWriteResponse.Result.NOT_FOUND) { @@ -125,7 +125,7 @@ public void onFailure(Exception e) { log.error("Failed to delete ML connector: " + connectorId, e); actionListener.onFailure(e); } - }); + }, () -> context.restore())); } catch (Exception e) { log.error("Failed to delete ML connector: " + connectorId, e); actionListener.onFailure(e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java index 6d4f07a7d8..c655eaea21 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/connector/SearchConnectorTransportAction.java @@ -60,6 +60,7 @@ protected void doExecute(Task task, SearchRequest request, ActionListener actionListener) { User user = RestActionUtils.getUserContext(client); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore()); List excludes = Optional .ofNullable(request.source()) .map(SearchSourceBuilder::fetchSource) @@ -78,11 +79,11 @@ private void search(SearchRequest request, ActionListener action ); request.source().fetchSource(rebuiltFetchSourceContext); if (connectorAccessControlHelper.skipConnectorAccessControl(user)) { - client.search(request, actionListener); + client.search(request, wrappedListener); } else { SearchSourceBuilder sourceBuilder = connectorAccessControlHelper.addUserBackendRolesFilter(user, request.source()); request.source(sourceBuilder); - client.search(request, actionListener); + client.search(request, wrappedListener); } } catch (Exception e) { log.error(e.getMessage(), e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java index 72a9cfd889..a877ecfa90 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/deploy/TransportDeployModelAction.java @@ -132,6 +132,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> { FunctionName functionName = mlModel.getAlgorithm(); if (functionName == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) { @@ -139,7 +140,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { if (!access) { - listener + wrappedListener .onFailure(new MLValidationException("User Doesn't have privilege to perform this operation on this model")); } else { String[] targetNodeIds = deployModelRequest.getModelNodeIds(); @@ -172,7 +173,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener difference = new HashSet(Arrays.asList(workerNodes)); difference.removeAll(Arrays.asList(targetNodeIds)); if (difference.size() > 0) { - listener + wrappedListener .onFailure( new IllegalArgumentException( "Model already deployed to these nodes: " @@ -188,7 +189,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { log.error("Failed to create deploy model task for " + modelId, exception); - listener.onFailure(exception); + wrappedListener.onFailure(exception); })); } }, e -> { log.error("Failed to Validate Access for ModelId " + modelId, e); - listener.onFailure(e); + wrappedListener.onFailure(e); })); }, e -> { log.error("Failed to deploy model " + modelId, e); - listener.onFailure(e); + wrappedListener.onFailure(e); })); } catch (Exception e) { log.error("Failed to get ML model " + modelId, e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java index ae81c2097a..ab9836308e 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/model_group/SearchModelGroupTransportAction.java @@ -57,13 +57,14 @@ protected void doExecute(Task task, SearchRequest request, ActionListener listener) { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); if (modelAccessControlHelper.skipModelAccessControl(user)) { - client.search(request, listener); + client.search(request, wrappedListener); } else { // Security is enabled, filter is enabled and user isn't admin modelAccessControlHelper.addUserBackendRolesFilter(user, request.source()); log.debug("Filtering result by " + user.getBackendRoles()); - client.search(request, listener); + client.search(request, wrappedListener); } } catch (Exception e) { log.error("Failed to search", e); diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index e21d18fd91..ac506ac296 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -86,12 +86,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); mlModelManager.getModel(modelId, ActionListener.wrap(mlModel -> { FunctionName functionName = mlModel.getAlgorithm(); modelAccessControlHelper .validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> { if (!access) { - listener + wrappedListener .onFailure( new MLValidationException("User Doesn't have privilege to perform this operation on this model") ); @@ -100,20 +101,25 @@ protected void doExecute(Task task, ActionRequest request, ActionListener { - long endTime = System.nanoTime(); - double durationInMs = (endTime - startTime) / 1e6; - modelCacheHelper.addPredictRequestDuration(modelId, durationInMs); - log.debug("completed predict request " + requestId + " for model " + modelId); - })); + .run( + functionName, + mlPredictionTaskRequest, + transportService, + ActionListener.runAfter(wrappedListener, () -> { + long endTime = System.nanoTime(); + double durationInMs = (endTime - startTime) / 1e6; + modelCacheHelper.addPredictRequestDuration(modelId, durationInMs); + log.debug("completed predict request " + requestId + " for model " + modelId); + }) + ); } }, e -> { log.error("Failed to Validate Access for ModelId " + modelId, e); - listener.onFailure(e); + wrappedListener.onFailure(e); })); }, e -> { log.error("Failed to find model " + modelId, e); - listener.onFailure(e); + wrappedListener.onFailure(e); })); } diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java index ee08608abd..97e5eb2389 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelGroupManager.java @@ -65,6 +65,7 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); validateUniqueModelGroupName(input.getName(), ActionListener.wrap(modelGroups -> { if (modelGroups != null && modelGroups.getHits().getTotalHits() != null @@ -72,7 +73,7 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener iterator = modelGroups.getHits().iterator(); while (iterator.hasNext()) { String id = iterator.next().getId(); - listener + wrappedListener .onFailure( new IllegalArgumentException( "The name you provided is already being used by another model with ID: " @@ -121,19 +122,19 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener { log.debug("Indexed model group doc successfully {}", modelName); - listener.onResponse(r.getId()); + wrappedListener.onResponse(r.getId()); }, e -> { log.error("Failed to index model group doc", e); - listener.onFailure(e); + wrappedListener.onFailure(e); })); }, ex -> { log.error("Failed to init model group index", ex); - listener.onFailure(ex); + wrappedListener.onFailure(ex); })); } }, e -> { log.error("Failed to search model group index", e); - listener.onFailure(e); + wrappedListener.onFailure(e); })); } catch (Exception e) { log.error("Failed to create model group doc", e); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index bc7edde75b..99cfd18bc8 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -219,6 +219,7 @@ public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, uploadMLModelMeta(mlRegisterModelMetaInput, "1", listener); } else { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); GetRequest getModelGroupRequest = new GetRequest(ML_MODEL_GROUP_INDEX).id(modelGroupId); client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> { if (modelGroup.isExists()) { @@ -237,25 +238,22 @@ public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, .setIfPrimaryTerm(primaryTerm) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .doc(source); - client - .update( - updateModelGroupRequest, - ActionListener - .wrap(r -> { uploadMLModelMeta(mlRegisterModelMetaInput, newVersion + "", listener); }, e -> { - log.error("Failed to update model group", e); - listener.onFailure(e); - }) - ); + client.update(updateModelGroupRequest, ActionListener.wrap(r -> { + uploadMLModelMeta(mlRegisterModelMetaInput, newVersion + "", wrappedListener); + }, e -> { + log.error("Failed to update model group", e); + wrappedListener.onFailure(e); + })); } else { log.error("Model group not found"); - listener.onFailure(new MLResourceNotFoundException("Fail to find model group")); + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); } }, e -> { if (e instanceof IndexNotFoundException) { - listener.onFailure(new MLResourceNotFoundException("Fail to find model group")); + wrappedListener.onFailure(new MLResourceNotFoundException("Fail to find model group")); } else { log.error("Failed to get model group", e); - listener.onFailure(new MLValidationException("Failed to get model group")); + wrappedListener.onFailure(new MLValidationException("Failed to get model group")); } })); } catch (Exception e) { @@ -272,6 +270,7 @@ public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, String version, ActionListener listener) { FunctionName functionName = mlRegisterModelMetaInput.getFunctionName(); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); String modelName = mlRegisterModelMetaInput.getName(); mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> { Instant now = Instant.now(); @@ -297,14 +296,14 @@ private void uploadMLModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput client.index(indexRequest, ActionListener.wrap(response -> { log.debug("Index model meta doc successfully {}", modelName); - listener.onResponse(response.getId()); + wrappedListener.onResponse(response.getId()); }, e -> { log.error("Failed to index model meta doc", e); - listener.onFailure(e); + wrappedListener.onFailure(e); })); }, ex -> { log.error("Failed to init model index", ex); - listener.onFailure(ex); + wrappedListener.onFailure(ex); })); } catch (Exception e) { log.error("Failed to register model", e); @@ -332,7 +331,7 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa uploadModel(registerModelInput, mlTask, "1"); } try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { - client.get(getModelGroupRequest, ActionListener.wrap(modelGroup -> { + client.get(getModelGroupRequest, ActionListener.runBefore(ActionListener.wrap(modelGroup -> { if (modelGroup.isExists()) { Map source = modelGroup.getSourceAsMap(); int latestVersion = (int) source.get(MLModelGroup.LATEST_VERSION_FIELD); @@ -349,14 +348,16 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa .setIfPrimaryTerm(primaryTerm) .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE) .doc(source); - client - .update( - updateModelGroupRequest, - ActionListener.wrap(r -> { uploadModel(registerModelInput, mlTask, newVersion + ""); }, e -> { - log.error("Failed to update model group", e); - handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e); - }) - ); + try (ThreadContext.StoredContext threadContext = client.threadPool().getThreadContext().stashContext()) { + client + .update( + updateModelGroupRequest, + ActionListener.wrap(r -> { uploadModel(registerModelInput, mlTask, newVersion + ""); }, e -> { + log.error("Failed to update model group", e); + handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e); + }) + ); + } } else { log.error("Model group not found"); handleException( @@ -376,7 +377,7 @@ public void registerMLModel(MLRegisterModelInput registerModelInput, MLTask mlTa log.error("Failed to get model group", e); handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e); } - })); + }), () -> context.restore())); } catch (Exception e) { log.error("Failed to register model", e); handleException(registerModelInput.getFunctionName(), mlTask.getTaskId(), e); @@ -399,7 +400,7 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml if (registerModelInput.getConnector() != null) { registerModelInput.getConnector().encrypt(mlEngine::encrypt); } - mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> { + mlIndicesHandler.initModelIndexIfAbsent(ActionListener.runBefore(ActionListener.wrap(res -> { MLModel mlModelMeta = MLModel .builder() .name(modelName) @@ -436,7 +437,7 @@ private void indexRemoteModel(MLRegisterModelInput registerModelInput, MLTask ml }, e -> { log.error("Failed to init model index", e); handleException(functionName, taskId, e); - })); + }), () -> context.restore())); } catch (Exception e) { logException("Failed to upload model", e, log); handleException(functionName, taskId, e); @@ -461,7 +462,7 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas String version = modelVersion == null ? registerModelInput.getVersion() : modelVersion; String modelGroupId = registerModelInput.getModelGroupId(); Instant now = Instant.now(); - mlIndicesHandler.initModelIndexIfAbsent(ActionListener.wrap(res -> { + mlIndicesHandler.initModelIndexIfAbsent(ActionListener.runBefore(ActionListener.wrap(res -> { MLModel mlModelMeta = MLModel .builder() .name(modelName) @@ -497,7 +498,7 @@ private void registerModelFromUrl(MLRegisterModelInput registerModelInput, MLTas }, e -> { log.error("Failed to init model index", e); handleException(functionName, taskId, e); - })); + }), () -> context.restore())); } catch (Exception e) { logException("Failed to register model", e, log); handleException(functionName, taskId, e); @@ -726,6 +727,7 @@ public void deployModel( } modelCacheHelper.initModelState(modelId, MLModelState.DEPLOYING, functionName, workerNodes, deployToAllNodes); try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { + ActionListener wrappedListener = ActionListener.runBefore(listener, () -> context.restore()); checkAndAddRunningTask(mlTask, maxDeployTasksPerNode); this.getModel(modelId, threadedActionListener(DEPLOY_THREAD_POOL, ActionListener.wrap(mlModel -> { if (FunctionName.REMOTE == mlModel.getAlgorithm() @@ -747,7 +749,7 @@ public void deployModel( // deploy remote model or model trained by built-in algorithm like kmeans if (mlModel.getConnector() != null) { setupPredictable(modelId, mlModel, params); - listener.onResponse("successful"); + wrappedListener.onResponse("successful"); return; } log.info("Set connector {} for the model: {}", mlModel.getConnectorId(), modelId); @@ -766,11 +768,11 @@ public void deployModel( Connector connector = Connector.createConnector(parser); mlModel.setConnector(connector); setupPredictable(modelId, mlModel, params); - listener.onResponse("successful"); + wrappedListener.onResponse("successful"); log.info("Completed setting connector {} in the model {}", mlModel.getConnectorId(), modelId); } } - }, e -> { listener.onFailure(e); })); + }, e -> { wrappedListener.onFailure(e); })); return; } @@ -781,7 +783,7 @@ public void deployModel( if (modelContentHash != null && !modelContentHash.equals(hash)) { log.error("Model content hash can't match original hash value"); removeModel(modelId); - listener.onFailure(new IllegalArgumentException("model content changed")); + wrappedListener.onFailure(new IllegalArgumentException("model content changed")); return; } log.debug("Model content matches original hash value, continue deploying"); @@ -793,11 +795,11 @@ public void deployModel( modelCacheHelper.setMLExecutor(modelId, mlExecutable); mlStats.getStat(MLNodeLevelStat.ML_DEPLOYED_MODEL_COUNT).increment(); modelCacheHelper.setModelState(modelId, MLModelState.DEPLOYED); - listener.onResponse("successful"); + wrappedListener.onResponse("successful"); } catch (Exception e) { log.error("Failed to add predictor to cache", e); mlExecutable.close(); - listener.onFailure(e); + wrappedListener.onFailure(e); } } else { @@ -811,20 +813,20 @@ public void deployModel( ? mlModel.getTotalChunks() * CHUNK_SIZE : modelContentSizeInBytes; modelCacheHelper.setMemSizeEstimation(modelId, mlModel.getModelFormat(), contentSize); - listener.onResponse("successful"); + wrappedListener.onResponse("successful"); } catch (Exception e) { log.error("Failed to add predictor to cache", e); predictable.close(); - listener.onFailure(e); + wrappedListener.onFailure(e); } } }, e -> { log.error("Failed to retrieve model " + modelId, e); - handleDeployModelException(modelId, functionName, listener, e); + handleDeployModelException(modelId, functionName, wrappedListener, e); })); }, e -> { log.error("Failed to deploy model " + modelId, e); - handleDeployModelException(modelId, functionName, listener, e); + handleDeployModelException(modelId, functionName, wrappedListener, e); }))); } catch (Exception e) { handleDeployModelException(modelId, functionName, listener, e);