Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

HostHeaderHttpRequesterFilter does not work for HTTP/2 #944

Merged
merged 3 commits into from
Feb 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2018-2019 Apple Inc. and the ServiceTalk project authors
* Copyright © 2018-2020 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,7 +32,7 @@
import static io.netty.util.NetUtil.isValidIpV6Address;
import static io.servicetalk.http.api.CharSequences.newAsciiString;
import static io.servicetalk.http.api.HttpHeaderNames.HOST;
import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1;
import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_0;

/**
* A filter which will apply a fallback value for the {@link HttpHeaderNames#HOST} header if one is not present.
Expand Down Expand Up @@ -84,8 +84,9 @@ public HttpExecutionStrategy influenceStrategy(final HttpExecutionStrategy strat
private Single<StreamingHttpResponse> request(final StreamingHttpRequester delegate,
final HttpExecutionStrategy strategy,
final StreamingHttpRequest request) {
if (HTTP_1_1.equals(request.version()) && !request.headers().contains(HOST)) {
request.headers().set(HOST, fallbackHost);
// "Host" header is not required for HTTP/1.0
if (!HTTP_1_0.equals(request.version()) && !request.headers().contains(HOST)) {
request.setHeader(HOST, fallbackHost);
}
return delegate.request(strategy, request);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright © 2019 Apple Inc. and the ServiceTalk project authors
* Copyright © 2019-2020 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -15,25 +15,96 @@
*/
package io.servicetalk.http.netty;

import io.servicetalk.concurrent.internal.ServiceTalkTestTimeout;
import io.servicetalk.http.api.BlockingHttpClient;
import io.servicetalk.http.api.BlockingHttpRequester;
import io.servicetalk.http.api.HttpProtocolConfig;
import io.servicetalk.http.api.HttpProtocolVersion;
import io.servicetalk.http.api.HttpRequest;
import io.servicetalk.http.api.HttpResponse;
import io.servicetalk.http.api.ReservedBlockingHttpConnection;
import io.servicetalk.transport.api.ServerContext;

import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.Timeout;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;

import static io.servicetalk.concurrent.api.Publisher.from;
import static io.servicetalk.concurrent.api.Single.succeeded;
import javax.annotation.Nullable;

import static io.servicetalk.http.api.HttpExecutionStrategies.noOffloadsStrategy;
import static io.servicetalk.http.api.HttpHeaderNames.HOST;
import static io.servicetalk.http.api.HttpResponseStatus.OK;
import static io.servicetalk.http.api.HttpSerializationProviders.textDeserializer;
import static io.servicetalk.http.api.HttpSerializationProviders.textSerializer;
import static io.servicetalk.http.netty.HttpClients.forSingleAddress;
import static io.servicetalk.http.netty.HttpProtocolConfigs.h1Default;
import static io.servicetalk.http.netty.HttpProtocolConfigs.h2Default;
import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress;
import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort;
import static java.util.Objects.requireNonNull;
import static org.junit.Assert.assertEquals;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;

@RunWith(Parameterized.class)
@SuppressWarnings("PMD.AvoidUsingHardCodedIP")
public class HostHeaderHttpRequesterFilterTest {

private enum HttpVersionConfig {
HTTP_1_0 {
@Override
HttpProtocolVersion version() {
return HttpProtocolVersion.HTTP_1_0;
}

@Override
HttpProtocolConfig config() {
return h1Default();
}
},
HTTP_1_1 {
@Override
HttpProtocolVersion version() {
return HttpProtocolVersion.HTTP_1_1;
}

@Override
HttpProtocolConfig config() {
return h1Default();
}
},
HTTP_2_0 {
@Override
HttpProtocolVersion version() {
return H2ToStH1Utils.HTTP_2_0;
}

@Override
HttpProtocolConfig config() {
return h2Default();
}
};

abstract HttpProtocolVersion version();

abstract HttpProtocolConfig config();
}

@Rule
public final Timeout timeout = new ServiceTalkTestTimeout();

private final HttpVersionConfig httpVersionConfig;

public HostHeaderHttpRequesterFilterTest(HttpVersionConfig httpVersionConfig) {
this.httpVersionConfig = httpVersionConfig;
}

@Parameters(name = "httpVersion={0}")
public static HttpVersionConfig[] data() {
return HttpVersionConfig.values();
}

@Test
public void ipv4NotEscaped() throws Exception {
doHostHeaderTest("1.2.3.4", "1.2.3.4");
Expand All @@ -44,58 +115,88 @@ public void ipv6IsEscaped() throws Exception {
doHostHeaderTest("::1", "[::1]");
}

private static void doHostHeaderTest(String hostHeader, String expectedValue) throws Exception {
private void doHostHeaderTest(String hostHeader, String expectedValue) throws Exception {
try (ServerContext context = buildServer();
BlockingHttpClient client = forSingleAddress(serverHostAndPort(context))
.unresolvedAddressToHost(addr -> hostHeader)
.buildBlocking()) {
assertEquals(expectedValue,
client.request(client.get("/")).payloadBody(textDeserializer()));
.protocols(httpVersionConfig.config())
.unresolvedAddressToHost(addr -> hostHeader)
.buildBlocking()) {
assertResponse(client, null, expectedValue);
}
}

private static ServerContext buildServer() throws Exception {
private ServerContext buildServer() throws Exception {
return HttpServers.forAddress(localAddress(0))
.listenStreamingAndAwait((ctx, request, responseFactory) ->
succeeded(responseFactory.ok().payloadBody(
from(requireNonNull(request.headers().get(HOST)).toString()), textSerializer())));
.protocols(httpVersionConfig.config())
.listenBlockingAndAwait((ctx, request, responseFactory) -> {
assertThat(request.version(), equalTo(httpVersionConfig.version()));
final CharSequence host = request.headers().get(HOST);
return responseFactory.ok()
.version(httpVersionConfig.version())
.payloadBody(host != null ? host.toString() : "null", textSerializer());
});
}

@Test
public void clientBuilderAppendClientFilter() throws Exception {
try (ServerContext context = buildServer();
BlockingHttpClient client = forSingleAddress(serverHostAndPort(context))
.disableHostHeaderFallback() // turn off the default
.protocols(httpVersionConfig.config())
.disableHostHeaderFallback() // turn off the default
.appendClientFilter(new HostHeaderHttpRequesterFilter("foo.bar:-1"))
.buildBlocking()) {
assertEquals("foo.bar:-1",
client.request(client.get("/")).payloadBody(textDeserializer()));
.buildBlocking()) {
assertResponse(client, null, "foo.bar:-1");
}
}

@Test
public void clientBuilderAppendConnectionFilter() throws Exception {
try (ServerContext context = buildServer();
BlockingHttpClient client = forSingleAddress(serverHostAndPort(context))
.disableHostHeaderFallback() // turn off the default
.protocols(httpVersionConfig.config())
.disableHostHeaderFallback() // turn off the default
.appendConnectionFilter(new HostHeaderHttpRequesterFilter("foo.bar:-1"))
.buildBlocking()) {
assertEquals("foo.bar:-1",
client.request(client.get("/")).payloadBody(textDeserializer()));
.buildBlocking()) {
assertResponse(client, null, "foo.bar:-1");
}
}

@Test
public void reserveConnection() throws Exception {
try (ServerContext context = buildServer();
BlockingHttpClient client = HttpClients.forResolvedAddress(serverHostAndPort(context))
.protocols(httpVersionConfig.config())
.disableHostHeaderFallback() // turn off the default
.appendConnectionFilter(new HostHeaderHttpRequesterFilter("foo.bar:-1"))
.buildBlocking()) {
ReservedBlockingHttpConnection conn = client.reserveConnection(client.get("/"));
assertEquals("foo.bar:-1",
conn.request(conn.get("/")).payloadBody(textDeserializer()));
conn.close();
.buildBlocking();
ReservedBlockingHttpConnection conn = client.reserveConnection(client.get("/"))) {
assertResponse(conn, null, "foo.bar:-1");
}
}

@Test
public void clientBuilderAppendClientFilterExplicitHostHeader() throws Exception {
try (ServerContext context = buildServer();
BlockingHttpClient client = forSingleAddress(serverHostAndPort(context))
.protocols(httpVersionConfig.config())
.disableHostHeaderFallback() // turn off the default
.appendClientFilter(new HostHeaderHttpRequesterFilter("foo.bar:-1"))
.buildBlocking()) {
assertResponse(client, "bar.only:-1", "bar.only:-1");
}
}

private void assertResponse(BlockingHttpRequester requester, @Nullable String hostHeader, String expectedValue)
throws Exception {
final HttpRequest request = requester.get("/").version(httpVersionConfig.version());
if (hostHeader != null) {
request.setHeader(HOST, hostHeader);
}
HttpResponse response = requester.request(noOffloadsStrategy(), request);
assertThat(response.status(), equalTo(OK));
assertThat(response.version(), equalTo(httpVersionConfig.version()));
// "Host" header is not required for HTTP/1.0. Therefore, we may expect "null" here.
assertThat(response.payloadBody(textDeserializer()), equalTo(
httpVersionConfig == HttpVersionConfig.HTTP_1_0 && hostHeader == null ? "null" : expectedValue));
}
}