Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

restore thread context before running action listener #1418

Merged
merged 1 commit into from
Oct 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,14 @@ void registerModel(ActionListener<MLRegisterModelResponse> listener) throws Inte
XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent());
modelGroup.toXContent(builder, ToXContent.EMPTY_PARAMS);
createModelGroupRequest.source(builder);
client.index(createModelGroupRequest, ActionListener.wrap(r -> {
client.index(createModelGroupRequest, ActionListener.runBefore(ActionListener.wrap(r -> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

May be we should add a similar comment in the code also what you have in the PR?

client.execute(MLRegisterModelAction.INSTANCE, registerRequest, ActionListener.wrap(listener::onResponse, e -> {
log.error("Failed to Register Model", e);
listener.onFailure(e);
}));
}, e-> {
listener.onFailure(e);
}));
}), () -> context.restore()));
} catch (IOException e) {
throw new MLException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ private void initMasterKey() {
if (clusterService.state().metadata().hasIndex(ML_CONFIG_INDEX)) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
GetRequest getRequest = new GetRequest(ML_CONFIG_INDEX).id(MASTER_KEY);
client.get(getRequest, new LatchedActionListener(ActionListener.<GetResponse>wrap(r -> {
client.get(getRequest, ActionListener.runBefore(new LatchedActionListener(ActionListener.<GetResponse>wrap(r -> {
if (r.isExists()) {
String masterKey = (String) r.getSourceAsMap().get(MASTER_KEY);
this.masterKey = masterKey;
Expand All @@ -120,7 +120,7 @@ private void initMasterKey() {
}, e -> {
log.error("Failed to get ML encryption master key", e);
exceptionRef.set(e);
}), latch));
}), latch), () -> context.restore()));
}
} else {
exceptionRef.set(new ResourceNotFoundException(MASTER_KEY_NOT_READY_ERROR));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.query(QueryBuilders.matchQuery(MLModel.CONNECTOR_ID_FIELD, connectorId));
searchRequest.source(sourceBuilder);
client.search(searchRequest, ActionListener.wrap(searchResponse -> {
client.search(searchRequest, ActionListener.runBefore(ActionListener.wrap(searchResponse -> {
SearchHit[] searchHits = searchResponse.getHits().getHits();
if (searchHits.length == 0) {
deleteConnector(deleteRequest, connectorId, actionListener);
Expand All @@ -92,7 +92,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete
}
log.error("Failed to delete ML connector: " + connectorId, e);
actionListener.onFailure(e);
}));
}), () -> context.restore()));
} catch (Exception e) {
log.error(e.getMessage(), e);
actionListener.onFailure(e);
Expand All @@ -108,7 +108,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<Delete

private void deleteConnector(DeleteRequest deleteRequest, String connectorId, ActionListener<DeleteResponse> actionListener) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
client.delete(deleteRequest, new ActionListener<>() {
client.delete(deleteRequest, ActionListener.runBefore(new ActionListener<>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
if (deleteResponse.getResult() == DocWriteResponse.Result.NOT_FOUND) {
Expand All @@ -125,7 +125,7 @@ public void onFailure(Exception e) {
log.error("Failed to delete ML connector: " + connectorId, e);
actionListener.onFailure(e);
}
});
}, () -> context.restore()));
} catch (Exception e) {
log.error("Failed to delete ML connector: " + connectorId, e);
actionListener.onFailure(e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ protected void doExecute(Task task, SearchRequest request, ActionListener<Search
private void search(SearchRequest request, ActionListener<SearchResponse> actionListener) {
User user = RestActionUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<SearchResponse> wrappedListener = ActionListener.runBefore(actionListener, () -> context.restore());
List<String> excludes = Optional
.ofNullable(request.source())
.map(SearchSourceBuilder::fetchSource)
Expand All @@ -78,11 +79,11 @@ private void search(SearchRequest request, ActionListener<SearchResponse> action
);
request.source().fetchSource(rebuiltFetchSourceContext);
if (connectorAccessControlHelper.skipConnectorAccessControl(user)) {
client.search(request, actionListener);
client.search(request, wrappedListener);
} else {
SearchSourceBuilder sourceBuilder = connectorAccessControlHelper.addUserBackendRolesFilter(user, request.source());
request.source(sourceBuilder);
client.search(request, actionListener);
client.search(request, wrappedListener);
}
} catch (Exception e) {
log.error(e.getMessage(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,15 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
String[] excludes = new String[] { MLModel.MODEL_CONTENT_FIELD, MLModel.OLD_MODEL_CONTENT_FIELD };

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLDeployModelResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
mlModelManager.getModel(modelId, null, excludes, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
if (functionName == FunctionName.REMOTE && !mlFeatureEnabledSetting.isRemoteInferenceEnabled()) {
throw new IllegalStateException(REMOTE_INFERENCE_DISABLED_ERR_MSG);
}
modelAccessControlHelper.validateModelGroupAccess(user, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
listener
wrappedListener
.onFailure(new MLValidationException("User Doesn't have privilege to perform this operation on this model"));
} else {
String[] targetNodeIds = deployModelRequest.getModelNodeIds();
Expand Down Expand Up @@ -172,7 +173,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
Set<String> difference = new HashSet<String>(Arrays.asList(workerNodes));
difference.removeAll(Arrays.asList(targetNodeIds));
if (difference.size() > 0) {
listener
wrappedListener
.onFailure(
new IllegalArgumentException(
"Model already deployed to these nodes: "
Expand All @@ -188,7 +189,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
eligibleNodes.addAll(Arrays.asList(allEligibleNodes));
}
if (nodeIds.size() == 0) {
listener.onFailure(new IllegalArgumentException("no eligible node found"));
wrappedListener.onFailure(new IllegalArgumentException("no eligible node found"));
return;
}

Expand All @@ -215,7 +216,7 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
mlTask.setTaskId(taskId);
try {
mlTaskManager.add(mlTask, nodeIds);
listener.onResponse(new MLDeployModelResponse(taskId, MLTaskState.CREATED.name()));
wrappedListener.onResponse(new MLDeployModelResponse(taskId, MLTaskState.CREATED.name()));
threadPool
.executor(DEPLOY_THREAD_POOL)
.execute(
Expand All @@ -238,20 +239,20 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLDepl
TASK_SEMAPHORE_TIMEOUT,
true
);
listener.onFailure(ex);
wrappedListener.onFailure(ex);
}
}, exception -> {
log.error("Failed to create deploy model task for " + modelId, exception);
listener.onFailure(exception);
wrappedListener.onFailure(exception);
}));
}
}, e -> {
log.error("Failed to Validate Access for ModelId " + modelId, e);
listener.onFailure(e);
wrappedListener.onFailure(e);
}));
}, e -> {
log.error("Failed to deploy model " + modelId, e);
listener.onFailure(e);
wrappedListener.onFailure(e);
}));
} catch (Exception e) {
log.error("Failed to get ML model " + modelId, e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,14 @@ protected void doExecute(Task task, SearchRequest request, ActionListener<Search

private void preProcessRoleAndPerformSearch(SearchRequest request, User user, ActionListener<SearchResponse> listener) {
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<SearchResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
if (modelAccessControlHelper.skipModelAccessControl(user)) {
client.search(request, listener);
client.search(request, wrappedListener);
} else {
// Security is enabled, filter is enabled and user isn't admin
modelAccessControlHelper.addUserBackendRolesFilter(user, request.source());
log.debug("Filtering result by " + user.getBackendRoles());
client.search(request, listener);
client.search(request, wrappedListener);
}
} catch (Exception e) {
log.error("Failed to search", e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,13 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
final User userInfo = user;

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<MLTaskResponse> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
mlModelManager.getModel(modelId, ActionListener.wrap(mlModel -> {
FunctionName functionName = mlModel.getAlgorithm();
modelAccessControlHelper
.validateModelGroupAccess(userInfo, mlModel.getModelGroupId(), client, ActionListener.wrap(access -> {
if (!access) {
listener
wrappedListener
.onFailure(
new MLValidationException("User Doesn't have privilege to perform this operation on this model")
);
Expand All @@ -100,20 +101,25 @@ protected void doExecute(Task task, ActionRequest request, ActionListener<MLTask
log.debug("receive predict request " + requestId + " for model " + mlPredictionTaskRequest.getModelId());
long startTime = System.nanoTime();
mlPredictTaskRunner
.run(functionName, mlPredictionTaskRequest, transportService, ActionListener.runAfter(listener, () -> {
long endTime = System.nanoTime();
double durationInMs = (endTime - startTime) / 1e6;
modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
log.debug("completed predict request " + requestId + " for model " + modelId);
}));
.run(
functionName,
mlPredictionTaskRequest,
transportService,
ActionListener.runAfter(wrappedListener, () -> {
long endTime = System.nanoTime();
double durationInMs = (endTime - startTime) / 1e6;
modelCacheHelper.addPredictRequestDuration(modelId, durationInMs);
log.debug("completed predict request " + requestId + " for model " + modelId);
})
);
}
}, e -> {
log.error("Failed to Validate Access for ModelId " + modelId, e);
listener.onFailure(e);
wrappedListener.onFailure(e);
}));
}, e -> {
log.error("Failed to find model " + modelId, e);
listener.onFailure(e);
wrappedListener.onFailure(e);
}));

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,15 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<Str
String modelName = input.getName();
User user = RestActionUtils.getUserContext(client);
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
ActionListener<String> wrappedListener = ActionListener.runBefore(listener, () -> context.restore());
validateUniqueModelGroupName(input.getName(), ActionListener.wrap(modelGroups -> {
if (modelGroups != null
&& modelGroups.getHits().getTotalHits() != null
&& modelGroups.getHits().getTotalHits().value != 0) {
Iterator<SearchHit> iterator = modelGroups.getHits().iterator();
while (iterator.hasNext()) {
String id = iterator.next().getId();
listener
wrappedListener
.onFailure(
new IllegalArgumentException(
"The name you provided is already being used by another model with ID: "
Expand Down Expand Up @@ -121,19 +122,19 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<Str

client.index(indexRequest, ActionListener.wrap(r -> {
log.debug("Indexed model group doc successfully {}", modelName);
listener.onResponse(r.getId());
wrappedListener.onResponse(r.getId());
}, e -> {
log.error("Failed to index model group doc", e);
listener.onFailure(e);
wrappedListener.onFailure(e);
}));
}, ex -> {
log.error("Failed to init model group index", ex);
listener.onFailure(ex);
wrappedListener.onFailure(ex);
}));
}
}, e -> {
log.error("Failed to search model group index", e);
listener.onFailure(e);
wrappedListener.onFailure(e);
}));
} catch (Exception e) {
log.error("Failed to create model group doc", e);
Expand Down
Loading