Skip to content

Commit

Permalink
Make RestHighLevelClient's Request class public (#26627)
Browse files Browse the repository at this point in the history
Request class is currently package protected, making it difficult for
the users to extend the RestHighLevelClient and to use its protected
methods to execute requests. This commit makes the Request class public
and changes few methods of RestHighLevelClient to be protected.
  • Loading branch information
tlrx committed Sep 20, 2017
1 parent e498141 commit 943cbd7
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -63,30 +63,47 @@
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.StringJoiner;

final class Request {
public final class Request {

static final XContentType REQUEST_BODY_CONTENT_TYPE = XContentType.JSON;

final String method;
final String endpoint;
final Map<String, String> params;
final HttpEntity entity;
private final String method;
private final String endpoint;
private final Map<String, String> parameters;
private final HttpEntity entity;

Request(String method, String endpoint, Map<String, String> params, HttpEntity entity) {
this.method = method;
this.endpoint = endpoint;
this.params = params;
public Request(String method, String endpoint, Map<String, String> parameters, HttpEntity entity) {
this.method = Objects.requireNonNull(method, "method cannot be null");
this.endpoint = Objects.requireNonNull(endpoint, "endpoint cannot be null");
this.parameters = Objects.requireNonNull(parameters, "parameters cannot be null");
this.entity = entity;
}

public String getMethod() {
return method;
}

public String getEndpoint() {
return endpoint;
}

public Map<String, String> getParameters() {
return parameters;
}

public HttpEntity getEntity() {
return entity;
}

@Override
public String toString() {
return "Request{" +
"method='" + method + '\'' +
", endpoint='" + endpoint + '\'' +
", params=" + params +
", params=" + parameters +
", hasBody=" + (entity != null) +
'}';
}
Expand Down Expand Up @@ -233,7 +250,7 @@ static Request bulk(BulkRequest bulkRequest) throws IOException {

static Request exists(GetRequest getRequest) {
Request request = get(getRequest);
return new Request(HttpHead.METHOD_NAME, request.endpoint, request.params, null);
return new Request(HttpHead.METHOD_NAME, request.endpoint, request.parameters, null);
}

static Request get(GetRequest getRequest) {
Expand Down Expand Up @@ -381,7 +398,7 @@ static String endpoint(String... parts) {
* @return the {@link ContentType}
*/
@SuppressForbidden(reason = "Only allowed place to convert a XContentType to a ContentType")
static ContentType createContentType(final XContentType xContentType) {
public static ContentType createContentType(final XContentType xContentType) {
return ContentType.create(xContentType.mediaTypeWithoutParameters(), (Charset) null);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ protected <Req extends ActionRequest, Resp> Resp performRequest(Req request,
Request req = requestConverter.apply(request);
Response response;
try {
response = client.performRequest(req.method, req.endpoint, req.params, req.entity, headers);
response = client.performRequest(req.getMethod(), req.getEndpoint(), req.getParameters(), req.getEntity(), headers);
} catch (ResponseException e) {
if (ignores.contains(e.getResponse().getStatusLine().getStatusCode())) {
try {
Expand Down Expand Up @@ -474,7 +474,7 @@ protected <Req extends ActionRequest, Resp> void performRequestAsync(Req request
}

ResponseListener responseListener = wrapResponseListener(responseConverter, listener, ignores);
client.performRequestAsync(req.method, req.endpoint, req.params, req.entity, responseListener, headers);
client.performRequestAsync(req.getMethod(), req.getEndpoint(), req.getParameters(), req.getEntity(), responseListener, headers);
}

<Resp> ResponseListener wrapResponseListener(CheckedFunction<Response, Resp, IOException> responseConverter,
Expand Down Expand Up @@ -522,7 +522,7 @@ public void onFailure(Exception exception) {
* that wraps the original {@link ResponseException}. The potential exception obtained while parsing is added to the returned
* exception as a suppressed exception. This method is guaranteed to not throw any exception eventually thrown while parsing.
*/
ElasticsearchStatusException parseResponseException(ResponseException responseException) {
protected ElasticsearchStatusException parseResponseException(ResponseException responseException) {
Response response = responseException.getResponse();
HttpEntity entity = response.getEntity();
ElasticsearchStatusException elasticsearchException;
Expand All @@ -542,8 +542,8 @@ ElasticsearchStatusException parseResponseException(ResponseException responseEx
return elasticsearchException;
}

<Resp> Resp parseEntity(
HttpEntity entity, CheckedFunction<XContentParser, Resp, IOException> entityParser) throws IOException {
protected <Resp> Resp parseEntity(final HttpEntity entity,
final CheckedFunction<XContentParser, Resp, IOException> entityParser) throws IOException {
if (entity == null) {
throw new IllegalStateException("Response body expected but not returned");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,29 @@

package org.elasticsearch.client;

import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Build;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.main.MainRequest;
import org.elasticsearch.action.main.MainResponse;
import org.apache.http.Header;
import org.apache.http.HttpEntity;
import org.apache.http.HttpHost;
import org.apache.http.HttpResponse;
import org.apache.http.ProtocolVersion;
import org.apache.http.RequestLine;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.entity.ByteArrayEntity;
import org.apache.http.entity.ContentType;
import org.apache.http.message.BasicHeader;
import org.apache.http.message.BasicHttpResponse;
import org.apache.http.message.BasicRequestLine;
import org.apache.http.message.BasicStatusLine;
import org.apache.lucene.util.BytesRef;
import org.elasticsearch.Build;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.main.MainRequest;
import org.elasticsearch.action.main.MainResponse;
import org.elasticsearch.action.support.PlainActionFuture;
import org.elasticsearch.client.Request;
import org.elasticsearch.client.Response;
import org.elasticsearch.client.ResponseListener;
import org.elasticsearch.client.RestClient;
import org.elasticsearch.client.RestHighLevelClient;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.common.SuppressForbidden;
import org.elasticsearch.common.xcontent.XContentHelper;
Expand All @@ -48,18 +52,22 @@
import java.io.IOException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

import static java.util.Collections.emptyMap;
import static java.util.Collections.emptySet;
import static org.elasticsearch.client.ESRestHighLevelClientTestCase.execute;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.anyMapOf;
import static org.mockito.Matchers.anyObject;
import static org.mockito.Matchers.anyVararg;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
* Test and demonstrates how {@link RestHighLevelClient} can be extended to support custom endpoints.
Expand Down Expand Up @@ -92,31 +100,45 @@ public void testCustomEndpoint() throws IOException {
final MainRequest request = new MainRequest();
final Header header = new BasicHeader("node_name", randomAlphaOfLengthBetween(1, 10));

MainResponse response = execute(request, restHighLevelClient::custom, restHighLevelClient::customAsync, header);
MainResponse response = restHighLevelClient.custom(request, header);
assertEquals(header.getValue(), response.getNodeName());

response = execute(request, restHighLevelClient::customAndParse, restHighLevelClient::customAndParseAsync, header);
response = restHighLevelClient.customAndParse(request, header);
assertEquals(header.getValue(), response.getNodeName());
}

public void testCustomEndpointAsync() throws Exception {
final MainRequest request = new MainRequest();
final Header header = new BasicHeader("node_name", randomAlphaOfLengthBetween(1, 10));

PlainActionFuture<MainResponse> future = PlainActionFuture.newFuture();
restHighLevelClient.customAsync(request, future, header);
assertEquals(header.getValue(), future.get().getNodeName());

future = PlainActionFuture.newFuture();
restHighLevelClient.customAndParseAsync(request, future, header);
assertEquals(header.getValue(), future.get().getNodeName());
}

/**
* The {@link RestHighLevelClient} must declare the following execution methods using the <code>protected</code> modifier
* so that they can be used by subclasses to implement custom logic.
*/
@SuppressForbidden(reason = "We're forced to uses Class#getDeclaredMethods() here because this test checks protected methods")
public void testMethodsVisibility() throws ClassNotFoundException {
String[] methodNames = new String[]{"performRequest", "performRequestAndParseEntity", "performRequestAsync",
"performRequestAsyncAndParseEntity"};
for (String methodName : methodNames) {
boolean found = false;
for (Method method : RestHighLevelClient.class.getDeclaredMethods()) {
if (method.getName().equals(methodName)) {
assertTrue("Method " + methodName + " must be protected", Modifier.isProtected(method.getModifiers()));
found = true;
}
}
assertTrue("Failed to find method " + methodName, found);
}
final String[] methodNames = new String[]{"performRequest",
"performRequestAsync",
"performRequestAndParseEntity",
"performRequestAsyncAndParseEntity",
"parseEntity",
"parseResponseException"};

final List<String> protectedMethods = Arrays.stream(RestHighLevelClient.class.getDeclaredMethods())
.filter(method -> Modifier.isProtected(method.getModifiers()))
.map(Method::getName)
.collect(Collectors.toList());

assertThat(protectedMethods, containsInAnyOrder(methodNames));
}

/**
Expand All @@ -135,15 +157,20 @@ private Void mockPerformRequestAsync(Header httpHeader, ResponseListener respons
* Mocks the synchronous request execution like if it was executed by Elasticsearch.
*/
private Response mockPerformRequest(Header httpHeader) throws IOException {
final Response mockResponse = mock(Response.class);
when(mockResponse.getHost()).thenReturn(new HttpHost("localhost", 9200));

ProtocolVersion protocol = new ProtocolVersion("HTTP", 1, 1);
HttpResponse httpResponse = new BasicHttpResponse(new BasicStatusLine(protocol, 200, "OK"));
when(mockResponse.getStatusLine()).thenReturn(new BasicStatusLine(protocol, 200, "OK"));

MainResponse response = new MainResponse(httpHeader.getValue(), Version.CURRENT, ClusterName.DEFAULT, "_na", Build.CURRENT, true);
BytesRef bytesRef = XContentHelper.toXContent(response, XContentType.JSON, false).toBytesRef();
httpResponse.setEntity(new ByteArrayEntity(bytesRef.bytes, ContentType.APPLICATION_JSON));
when(mockResponse.getEntity()).thenReturn(new ByteArrayEntity(bytesRef.bytes, ContentType.APPLICATION_JSON));

RequestLine requestLine = new BasicRequestLine(HttpGet.METHOD_NAME, ENDPOINT, protocol);
return new Response(requestLine, new HttpHost("localhost", 9200), httpResponse);
when(mockResponse.getRequestLine()).thenReturn(requestLine);

return mockResponse;
}

/**
Expand Down
Loading

0 comments on commit 943cbd7

Please sign in to comment.