Skip to content

Commit

Permalink
Refactors ClientHelper to combine header logic (elastic#30620)
Browse files Browse the repository at this point in the history
* Refactors ClientHelper to combine header logic

This change removes all the `*ClientHelper` classes which were
repeating logic between plugins and instead adds
`ClientHelper.executeWithHeaders()` and
`ClientHelper.executeWithHeadersAsync()` methods to centralise the
logic for executing requests with stored security headers.

* Removes Watcher headers constant
  • Loading branch information
colings86 authored and ywelsch committed May 23, 2018
1 parent 56ba373 commit 579deef
Show file tree
Hide file tree
Showing 22 changed files with 324 additions and 593 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,28 @@
import org.elasticsearch.client.Client;
import org.elasticsearch.client.FilterClient;
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.common.util.set.Sets;
import org.elasticsearch.xpack.core.security.authc.AuthenticationField;
import org.elasticsearch.xpack.core.security.authc.AuthenticationServiceField;

import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
* Utility class to help with the execution of requests made using a {@link Client} such that they
* have the origin as a transient and listeners have the appropriate context upon invocation
*/
public final class ClientHelper {

/**
* List of headers that are related to security
*/
public static final Set<String> SECURITY_HEADER_FILTERS = Sets.newHashSet(AuthenticationServiceField.RUN_AS_USER_HEADER,
AuthenticationField.AUTHENTICATION_KEY);

public static final String ACTION_ORIGIN_TRANSIENT_NAME = "action.origin";
public static final String SECURITY_ORIGIN = "security";
public static final String WATCHER_ORIGIN = "watcher";
Expand Down Expand Up @@ -78,6 +90,82 @@ RequestBuilder extends ActionRequestBuilder<Request, Response, RequestBuilder>>
}
}

/**
* Execute a client operation and return the response, try to run an action
* with least privileges, when headers exist
*
* @param headers
* Request headers, ideally including security headers
* @param origin
* The origin to fall back to if there are no security headers
* @param client
* The client used to query
* @param supplier
* The action to run
* @return An instance of the response class
*/
public static <T extends ActionResponse> T executeWithHeaders(Map<String, String> headers, String origin, Client client,
Supplier<T> supplier) {
Map<String, String> filteredHeaders = headers.entrySet().stream().filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

// no security headers, we will have to use the xpack internal user for
// our execution by specifying the origin
if (filteredHeaders.isEmpty()) {
try (ThreadContext.StoredContext ignore = stashWithOrigin(client.threadPool().getThreadContext(), origin)) {
return supplier.get();
}
} else {
try (ThreadContext.StoredContext ignore = client.threadPool().getThreadContext().stashContext()) {
client.threadPool().getThreadContext().copyHeaders(filteredHeaders.entrySet());
return supplier.get();
}
}
}

/**
* Execute a client operation asynchronously, try to run an action with
* least privileges, when headers exist
*
* @param headers
* Request headers, ideally including security headers
* @param origin
* The origin to fall back to if there are no security headers
* @param action
* The action to execute
* @param request
* The request object for the action
* @param listener
* The listener to call when the action is complete
*/
public static <Request extends ActionRequest, Response extends ActionResponse,
RequestBuilder extends ActionRequestBuilder<Request, Response, RequestBuilder>> void executeWithHeadersAsync(
Map<String, String> headers, String origin, Client client, Action<Request, Response, RequestBuilder> action, Request request,
ActionListener<Response> listener) {

Map<String, String> filteredHeaders = headers.entrySet().stream().filter(e -> SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));

final ThreadContext threadContext = client.threadPool().getThreadContext();

// No headers (e.g. security not installed/in use) so execute as origin
if (filteredHeaders.isEmpty()) {
ClientHelper.executeAsyncWithOrigin(client, origin, action, request, listener);
} else {
// Otherwise stash the context and copy in the saved headers before executing
final Supplier<ThreadContext.StoredContext> supplier = threadContext.newRestorableContext(false);
try (ThreadContext.StoredContext ignore = stashWithHeaders(threadContext, filteredHeaders)) {
client.execute(action, request, new ContextPreservingActionListener<>(supplier, listener));
}
}
}

private static ThreadContext.StoredContext stashWithHeaders(ThreadContext threadContext, Map<String, String> headers) {
final ThreadContext.StoredContext storedContext = threadContext.stashContext();
threadContext.copyHeaders(headers.entrySet());
return storedContext;
}

private static final class ClientWithOrigin extends FilterClient {

private final String origin;
Expand All @@ -98,5 +186,4 @@ RequestBuilder extends ActionRequestBuilder<Request, Response, RequestBuilder>>
}
}
}

}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import org.elasticsearch.common.xcontent.ObjectParser;
import org.elasticsearch.common.xcontent.ToXContent;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
import org.elasticsearch.persistent.PersistentTasksCustomMetaData.PersistentTask;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedConfig;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedJobValidator;
import org.elasticsearch.xpack.core.ml.datafeed.DatafeedState;
Expand All @@ -35,8 +38,6 @@
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
import org.elasticsearch.xpack.core.ml.utils.NameResolver;
import org.elasticsearch.xpack.core.ml.utils.ToXContentParams;
import org.elasticsearch.persistent.PersistentTasksCustomMetaData;
import org.elasticsearch.persistent.PersistentTasksCustomMetaData.PersistentTask;

import java.io.IOException;
import java.util.Collection;
Expand Down Expand Up @@ -303,7 +304,7 @@ public Builder putDatafeed(DatafeedConfig datafeedConfig, ThreadContext threadCo
// Adjust the request, adding security headers from the current thread context
DatafeedConfig.Builder builder = new DatafeedConfig.Builder(datafeedConfig);
Map<String, String> headers = threadContext.getHeaders().entrySet().stream()
.filter(e -> MlClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
builder.setHeaders(headers);
datafeedConfig = builder.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.xpack.core.ml.MlClientHelper;
import org.elasticsearch.xpack.core.ClientHelper;
import org.elasticsearch.xpack.core.ml.datafeed.extractor.ExtractorUtils;
import org.elasticsearch.xpack.core.ml.job.config.Job;
import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
Expand Down Expand Up @@ -304,7 +304,7 @@ public DatafeedConfig apply(DatafeedConfig datafeedConfig, ThreadContext threadC
if (threadContext != null) {
// Adjust the request, adding security headers from the current thread context
Map<String, String> headers = threadContext.getHeaders().entrySet().stream()
.filter(e -> MlClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.filter(e -> ClientHelper.SECURITY_HEADER_FILTERS.contains(e.getKey()))
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
builder.setHeaders(headers);
}
Expand Down
Loading

0 comments on commit 579deef

Please sign in to comment.