From 499d011529ac46421c97adc6c8eace34f95eb7ac Mon Sep 17 00:00:00 2001 From: Idel Pivnitskiy Date: Sun, 17 Sep 2023 23:16:24 -0700 Subject: [PATCH] Implement HTTP proxy CONNECT with ALPN --- .../api/SingleAddressHttpClientBuilder.java | 12 +++ .../http/netty/AlpnChannelSingle.java | 28 ++++--- .../netty/AlpnLBHttpConnectionFactory.java | 8 +- ...DefaultSingleAddressHttpClientBuilder.java | 10 +-- .../netty/DeferredServerChannelBinder.java | 8 +- .../PipelinedLBHttpConnectionFactory.java | 3 +- .../ProxyConnectLBHttpConnectionFactory.java | 80 +++++++++++------- .../netty/StreamingConnectionFactory.java | 9 +- .../http/netty/HttpsProxyTest.java | 83 +++++++++++++------ ...oxyConnectLBHttpConnectionFactoryTest.java | 2 - .../CopyByteBufHandlerChannelInitializer.java | 11 +++ .../internal/DefaultNettyConnection.java | 12 +++ 12 files changed, 179 insertions(+), 87 deletions(-) diff --git a/servicetalk-http-api/src/main/java/io/servicetalk/http/api/SingleAddressHttpClientBuilder.java b/servicetalk-http-api/src/main/java/io/servicetalk/http/api/SingleAddressHttpClientBuilder.java index 0906349d6d..1e85f0c6e5 100644 --- a/servicetalk-http-api/src/main/java/io/servicetalk/http/api/SingleAddressHttpClientBuilder.java +++ b/servicetalk-http-api/src/main/java/io/servicetalk/http/api/SingleAddressHttpClientBuilder.java @@ -54,6 +54,12 @@ public interface SingleAddressHttpClientBuilder extends HttpClientBuilder< * If the client talks to a proxy over http (not https, {@link #sslConfig(ClientSslConfig) ClientSslConfig} is NOT * configured), it will rewrite the request-target to * absolute-form, as specified by the RFC. + *

+ * For secure proxy tunnels (when {@link #sslConfig(ClientSslConfig) ClientSslConfig} is configured) the tunnel is + * always initialized using + * HTTP/1.1 CONNECT request. The actual + * protocol will be negotiated via ALPN extension of TLS protocol, + * taking into account HTTP protocols configured via {@link #protocols(HttpProtocolConfig...)} method. * * @param proxyAddress Unresolved address of the proxy. When used with a builder created for a resolved address, * {@code proxyAddress} should also be already resolved – otherwise runtime exceptions may occur. @@ -70,6 +76,12 @@ default SingleAddressHttpClientBuilder proxyAddress(U proxyAddress) { // F * If the client talks to a proxy over http (not https, {@link #sslConfig(ClientSslConfig) ClientSslConfig} is NOT * configured), it will rewrite the request-target to * absolute-form, as specified by the RFC. + *

+ * For secure proxy tunnels (when {@link #sslConfig(ClientSslConfig) ClientSslConfig} is configured) the tunnel is + * always initialized using + * HTTP/1.1 CONNECT request. The actual + * protocol will be negotiated via ALPN extension of TLS protocol, + * taking into account HTTP protocols configured via {@link #protocols(HttpProtocolConfig...)} method. * * @param proxyAddress Unresolved address of the proxy. When used with a builder created for a resolved address, * {@code proxyAddress} should also be already resolved – otherwise runtime exceptions may occur. diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AlpnChannelSingle.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AlpnChannelSingle.java index 533492f16f..ceeb1a0de0 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AlpnChannelSingle.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AlpnChannelSingle.java @@ -28,27 +28,34 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.util.function.Consumer; import javax.annotation.Nullable; import static io.servicetalk.http.netty.AlpnIds.HTTP_1_1; import static io.servicetalk.transport.netty.internal.ChannelCloseUtils.assignConnectionError; +import static java.util.Objects.requireNonNull; /** * A {@link Single} that initializes ALPN handler and completes after protocol negotiation. */ final class AlpnChannelSingle extends ChannelInitSingle { - private final boolean forceChannelRead; + private final Consumer onHandlerAdded; + + AlpnChannelSingle(final Channel channel, + final ChannelInitializer channelInitializer) { + this(channel, channelInitializer, __ -> { }); + } AlpnChannelSingle(final Channel channel, final ChannelInitializer channelInitializer, - final boolean forceChannelRead) { + final Consumer onHandlerAdded) { super(channel, channelInitializer); - this.forceChannelRead = forceChannelRead; + this.onHandlerAdded = requireNonNull(onHandlerAdded); } @Override protected ChannelHandler newChannelHandler(final Subscriber subscriber) { - return new AlpnChannelHandler(subscriber, forceChannelRead); + return new AlpnChannelHandler(subscriber, onHandlerAdded); } /** @@ -61,24 +68,19 @@ private static final class AlpnChannelHandler extends ApplicationProtocolNegotia @Nullable private SingleSource.Subscriber subscriber; - private final boolean forceRead; + private final Consumer onHandlerAdded; AlpnChannelHandler(final SingleSource.Subscriber subscriber, - final boolean forceRead) { + final Consumer onHandlerAdded) { super(HTTP_1_1); this.subscriber = subscriber; - this.forceRead = forceRead; + this.onHandlerAdded = onHandlerAdded; } @Override public void handlerAdded(final ChannelHandlerContext ctx) throws Exception { super.handlerAdded(ctx); - if (forceRead) { - // Force a read to get the SSL handshake started. We initialize pipeline before - // SslHandshakeCompletionEvent will complete, therefore, no data will be propagated before we finish - // initialization. - ctx.read(); - } + onHandlerAdded.accept(ctx); } @Override diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AlpnLBHttpConnectionFactory.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AlpnLBHttpConnectionFactory.java index e9c4dc6930..a9e5afa244 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AlpnLBHttpConnectionFactory.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/AlpnLBHttpConnectionFactory.java @@ -70,7 +70,7 @@ private Single createConnection( final Channel channel, final ConnectionObserver connectionObserver, final ReadOnlyTcpClientConfig tcpConfig) { return new AlpnChannelSingle(channel, - new TcpClientChannelInitializer(tcpConfig, connectionObserver), false).flatMap(protocol -> { + new TcpClientChannelInitializer(tcpConfig, connectionObserver)).flatMap(protocol -> { switch (protocol) { case HTTP_1_1: final H1ProtocolConfig h1Config = config.h1Config(); @@ -89,8 +89,12 @@ private Single createConnection( new H2ClientParentChannelInitializer(h2Config), connectionObserver, config.allowDropTrailersReadFromTransport()); default: - return failed(new IllegalStateException("Unknown ALPN protocol negotiated: " + protocol)); + return unknownAlpnProtocol(protocol); } }); } + + static Single unknownAlpnProtocol(final String protocol) { + return failed(new IllegalStateException("Unknown ALPN protocol negotiated: " + protocol)); + } } diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/DefaultSingleAddressHttpClientBuilder.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/DefaultSingleAddressHttpClientBuilder.java index a9917b8f5d..6cd68d3472 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/DefaultSingleAddressHttpClientBuilder.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/DefaultSingleAddressHttpClientBuilder.java @@ -235,8 +235,9 @@ public HttpExecutionStrategy executionStrategy() { return computedStrategy; } }; - if (roConfig.h2Config() != null && roConfig.hasProxy()) { - throw new IllegalStateException("Proxying is not yet supported with HTTP/2"); + final SslContext sslContext = roConfig.tcpConfig().sslContext(); + if (roConfig.hasProxy() && sslContext == null && roConfig.h2Config() != null) { + throw new IllegalStateException("Proxying is not yet supported with plaintext HTTP/2"); } // Track resources that potentially need to be closed when an exception is thrown during buildStreaming @@ -250,7 +251,6 @@ public HttpExecutionStrategy executionStrategy() { final ExecutionStrategy connectionFactoryStrategy = ctx.builder.strategyComputation.buildForConnectionFactory(); - final SslContext sslContext = roConfig.tcpConfig().sslContext(); if (roConfig.hasProxy() && sslContext != null) { assert roConfig.connectAddress() != null; final ConnectionFactoryFilter proxy = @@ -269,14 +269,14 @@ public HttpExecutionStrategy executionStrategy() { ctx.builder.addIdleTimeoutConnectionFilter ? appendConnectionFilter(ctx.builder.connectionFilterFactory, DEFAULT_IDLE_TIMEOUT_FILTER) : ctx.builder.connectionFilterFactory; - if (roConfig.isH2PriorKnowledge()) { + if (!roConfig.hasProxy() && roConfig.isH2PriorKnowledge()) { H2ProtocolConfig h2Config = roConfig.h2Config(); assert h2Config != null; connectionFactory = new H2LBHttpConnectionFactory<>(roConfig, executionContext, connectionFilterFactory, reqRespFactory, connectionFactoryStrategy, connectionFactoryFilter, ctx.builder.loadBalancerFactory::toLoadBalancedConnection); - } else if (roConfig.tcpConfig().preferredAlpnProtocol() != null) { + } else if (!roConfig.hasProxy() && roConfig.tcpConfig().preferredAlpnProtocol() != null) { H1ProtocolConfig h1Config = roConfig.h1Config(); H2ProtocolConfig h2Config = roConfig.h2Config(); connectionFactory = new AlpnLBHttpConnectionFactory<>(roConfig, executionContext, diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/DeferredServerChannelBinder.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/DeferredServerChannelBinder.java index 2e27ced0ac..4a04b91974 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/DeferredServerChannelBinder.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/DeferredServerChannelBinder.java @@ -31,6 +31,7 @@ import io.servicetalk.transport.netty.internal.NettyConnectionContext; import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -90,8 +91,11 @@ private static Single alpnInitChannel(final SocketAddres final StreamingHttpService service, final boolean drainRequestPayloadBody, final ConnectionObserver observer) { - return new AlpnChannelSingle(channel, - new TcpServerChannelInitializer(config.tcpConfig(), observer), true).flatMap(protocol -> { + return new AlpnChannelSingle(channel, new TcpServerChannelInitializer(config.tcpConfig(), observer), + // Force a read to get the SSL handshake started. We initialize pipeline before + // SslHandshakeCompletionEvent will complete, therefore, no data will be propagated before we finish + // initialization. + ChannelHandlerContext::read).flatMap(protocol -> { switch (protocol) { case HTTP_1_1: return NettyHttpServer.initChannel(channel, httpExecutionContext, config, diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/PipelinedLBHttpConnectionFactory.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/PipelinedLBHttpConnectionFactory.java index b9b4d20edf..f3951d1e76 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/PipelinedLBHttpConnectionFactory.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/PipelinedLBHttpConnectionFactory.java @@ -47,7 +47,8 @@ final class PipelinedLBHttpConnectionFactory extends AbstractLB Single newFilterableConnection(final ResolvedAddress resolvedAddress, final TransportObserver observer) { assert config.h1Config() != null; - return buildStreaming(executionContext, resolvedAddress, config, observer) + return buildStreaming(executionContext, resolvedAddress, config.tcpConfig(), config.h1Config(), + config.hasProxy(), observer) .map(conn -> new PipelinedStreamingHttpConnection(conn, config.h1Config(), reqRespFactoryFunc.apply(HTTP_1_1), config.allowDropTrailersReadFromTransport())); } diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactory.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactory.java index 393c13b8ec..07e8d7c69e 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactory.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactory.java @@ -16,7 +16,6 @@ package io.servicetalk.http.netty; import io.servicetalk.client.api.ConnectionFactoryFilter; -import io.servicetalk.concurrent.SingleSource; import io.servicetalk.concurrent.api.Completable; import io.servicetalk.concurrent.api.Publisher; import io.servicetalk.concurrent.api.Single; @@ -27,31 +26,33 @@ import io.servicetalk.http.api.StreamingHttpRequest; import io.servicetalk.http.api.StreamingHttpRequestResponseFactory; import io.servicetalk.http.api.StreamingHttpResponse; +import io.servicetalk.http.netty.AlpnChannelSingle.NoopChannelInitializer; +import io.servicetalk.tcp.netty.internal.ReadOnlyTcpClientConfig; import io.servicetalk.transport.api.ExecutionStrategy; import io.servicetalk.transport.api.TransportObserver; +import io.servicetalk.transport.netty.internal.CopyByteBufHandlerChannelInitializer; +import io.servicetalk.transport.netty.internal.DefaultNettyConnection; import io.servicetalk.transport.netty.internal.DeferSslHandler; +import io.servicetalk.transport.netty.internal.NoopTransportObserver.NoopConnectionObserver; import io.servicetalk.transport.netty.internal.StacklessClosedChannelException; import io.netty.channel.Channel; -import io.netty.channel.ChannelHandlerContext; -import io.netty.channel.ChannelInboundHandlerAdapter; -import io.netty.handler.ssl.SslHandshakeCompletionEvent; import java.util.function.Consumer; import javax.annotation.Nullable; -import static io.servicetalk.concurrent.api.Processors.newSingleProcessor; -import static io.servicetalk.concurrent.api.SourceAdapters.fromSource; import static io.servicetalk.http.api.HttpApiConversions.isPayloadEmpty; import static io.servicetalk.http.api.HttpContextKeys.HTTP_EXECUTION_STRATEGY_KEY; import static io.servicetalk.http.api.HttpExecutionStrategies.customStrategyBuilder; import static io.servicetalk.http.api.HttpExecutionStrategies.offloadNone; import static io.servicetalk.http.api.HttpHeaderNames.HOST; import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1; +import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_2_0; import static io.servicetalk.http.api.HttpResponseStatus.StatusClass.SUCCESSFUL_2XX; +import static io.servicetalk.http.netty.AlpnLBHttpConnectionFactory.unknownAlpnProtocol; +import static io.servicetalk.http.netty.HttpProtocolConfigs.h1Default; import static io.servicetalk.http.netty.StreamingConnectionFactory.buildStreaming; import static io.servicetalk.utils.internal.ThrowableUtils.addSuppressed; -import static java.util.Objects.requireNonNull; final class ProxyConnectLBHttpConnectionFactory extends AbstractLBHttpConnectionFactory { @@ -71,7 +72,6 @@ final class ProxyConnectLBHttpConnectionFactory final Consumer connectRequestInitializer) { super(config, executionContext, version -> reqRespFactory, connectStrategy, connectionFactoryFilter, connectionFilterFunction, protocolBinding); - requireNonNull(config.h1Config(), "H1ProtocolConfig is required"); assert config.connectAddress() != null; this.connectAddress = config.connectAddress().toString(); this.connectRequestInitializer = connectRequestInitializer; @@ -80,9 +80,13 @@ final class ProxyConnectLBHttpConnectionFactory @Override Single newFilterableConnection(final ResolvedAddress resolvedAddress, final TransportObserver observer) { - assert config.h1Config() != null; - return buildStreaming(executionContext, resolvedAddress, config, observer) - .map(c -> new PipelinedStreamingHttpConnection(c, config.h1Config(), + final H1ProtocolConfig h1Config = config.h1Config() != null ? config.h1Config() : h1Default(); + return buildStreaming(executionContext, resolvedAddress, config.tcpConfig(), h1Config, config.hasProxy(), + observer) + // Always create PipelinedStreamingHttpConnection because: + // 1. buildStreaming creates a CloseHandler for pipelined request-response + // 2. in case ALPN negotiates HTTP/1.x we won't need to change the connection + .map(c -> new PipelinedStreamingHttpConnection(c, h1Config, reqRespFactoryFunc.apply(HTTP_1_1), config.allowDropTrailersReadFromTransport())) .flatMap(this::processConnect); } @@ -137,27 +141,10 @@ private static void configureOffloading(final StreamingHttpRequest request) { private Single handshake( final NettyFilterableStreamingHttpConnection connection) { return Single.defer(() -> { - final SingleSource.Processor - processor = newSingleProcessor(); final Channel channel = connection.nettyChannel(); assert channel.eventLoop().inEventLoop(); - channel.pipeline().addLast(new ChannelInboundHandlerAdapter() { - @Override - public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) { - if (evt instanceof SslHandshakeCompletionEvent) { - SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt; - if (event.isSuccess()) { - processor.onSuccess(connection); - } else { - processor.onError(event.cause()); - } - channel.pipeline().remove(this); - } - ctx.fireUserEventTriggered(evt); - } - }); - final Single result; + final Single result; final DeferSslHandler deferSslHandler = channel.pipeline().get(DeferSslHandler.class); if (deferSslHandler == null) { if (!channel.isActive()) { @@ -171,8 +158,39 @@ public void userEventTriggered(final ChannelHandlerContext ctx, final Object evt DeferSslHandler.class + " in the channel pipeline.")); } } else { - deferSslHandler.ready(); - result = fromSource(processor); + result = new AlpnChannelSingle(channel, NoopChannelInitializer.INSTANCE, __ -> deferSslHandler.ready()); + } + return result.shareContextOnSubscribe(); + }).flatMap(protocol -> { + final Single result; + switch (protocol) { + case AlpnIds.HTTP_1_1: + // Nothing to do, HTTP/1.1 pipeline is already initialized + result = Single.succeeded(connection); + break; + case AlpnIds.HTTP_2: + final Channel channel = connection.nettyChannel(); + assert channel.eventLoop().inEventLoop(); + // Remove HTTP/1.1 handlers: + channel.pipeline().remove(HttpRequestEncoder.class); + channel.pipeline().remove(HttpResponseDecoder.class); + channel.pipeline().remove(CopyByteBufHandlerChannelInitializer.handlerClass()); + channel.pipeline().remove(DefaultNettyConnection.handlerClass()); + // Initialize HTTP/2: + final H2ProtocolConfig h2Config = config.h2Config(); + assert h2Config != null; + final ReadOnlyTcpClientConfig tcpConfig = config.tcpConfig(); + result = H2ClientParentConnectionContext.initChannel(channel, executionContext, + h2Config, reqRespFactoryFunc.apply(HTTP_2_0), tcpConfig.flushStrategy(), + tcpConfig.idleTimeoutMs(), tcpConfig.sslConfig(), + new H2ClientParentChannelInitializer(h2Config), + // FIXME: propagate real observer + NoopConnectionObserver.INSTANCE, config.allowDropTrailersReadFromTransport()) + .cast(FilterableStreamingHttpConnection.class); + break; + default: + result = unknownAlpnProtocol(protocol); + break; } return result.shareContextOnSubscribe(); }); diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/StreamingConnectionFactory.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/StreamingConnectionFactory.java index 6a7da1f3d3..b4dd8df929 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/StreamingConnectionFactory.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/StreamingConnectionFactory.java @@ -50,14 +50,13 @@ private StreamingConnectionFactory() { static Single> buildStreaming( final HttpExecutionContext executionContext, final ResolvedAddress resolvedAddress, - final ReadOnlyHttpClientConfig roConfig, final TransportObserver observer) { - final ReadOnlyTcpClientConfig tcpConfig = withSslConfigPeerHost(resolvedAddress, roConfig.tcpConfig()); - final H1ProtocolConfig h1Config = roConfig.h1Config(); - assert h1Config != null; + final ReadOnlyTcpClientConfig originalTcpConfig, final H1ProtocolConfig h1Config, boolean hasProxy, + final TransportObserver observer) { + final ReadOnlyTcpClientConfig tcpConfig = withSslConfigPeerHost(resolvedAddress, originalTcpConfig); // We disable auto read so we can handle stuff in the ConnectionFilter before we accept any content. return TcpConnector.connect(null, resolvedAddress, tcpConfig, false, executionContext, (channel, connectionObserver) -> createConnection(channel, executionContext, h1Config, tcpConfig, - new TcpClientChannelInitializer(tcpConfig, connectionObserver, roConfig.hasProxy()), + new TcpClientChannelInitializer(tcpConfig, connectionObserver, hasProxy), connectionObserver), observer); } diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpsProxyTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpsProxyTest.java index 12f1b240d8..45d9d57eeb 100644 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpsProxyTest.java +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpsProxyTest.java @@ -22,6 +22,8 @@ import io.servicetalk.context.api.ContextMap; import io.servicetalk.http.api.BlockingHttpClient; import io.servicetalk.http.api.FilterableStreamingHttpConnection; +import io.servicetalk.http.api.HttpProtocolVersion; +import io.servicetalk.http.api.HttpRequest; import io.servicetalk.http.api.HttpResponse; import io.servicetalk.http.api.ReservedBlockingHttpConnection; import io.servicetalk.test.resources.DefaultTestCerts; @@ -33,14 +35,17 @@ import io.servicetalk.transport.netty.internal.ExecutionContextExtension; import org.junit.jupiter.api.AfterEach; -import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; import java.util.concurrent.atomic.AtomicReference; import javax.annotation.Nullable; @@ -48,14 +53,17 @@ import static io.servicetalk.http.api.HttpContextKeys.HTTP_TARGET_ADDRESS_BEHIND_PROXY; import static io.servicetalk.http.api.HttpHeaderNames.HOST; import static io.servicetalk.http.api.HttpHeaderNames.PROXY_AUTHORIZATION; -import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1; import static io.servicetalk.http.api.HttpResponseStatus.INTERNAL_SERVER_ERROR; import static io.servicetalk.http.api.HttpResponseStatus.OK; import static io.servicetalk.http.api.HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED; import static io.servicetalk.http.api.HttpSerializers.textSerializerUtf8; +import static io.servicetalk.http.netty.HttpProtocol.HTTP_1; +import static io.servicetalk.http.netty.HttpProtocol.HTTP_2; +import static io.servicetalk.http.netty.HttpProtocol.toConfigs; import static io.servicetalk.test.resources.DefaultTestCerts.serverPemHostname; import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort; import static java.nio.charset.StandardCharsets.US_ASCII; +import static java.util.Arrays.asList; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; @@ -66,6 +74,7 @@ class HttpsProxyTest { private static final Logger LOGGER = LoggerFactory.getLogger(HttpsProxyTest.class); private static final String AUTH_TOKEN = "aGVsbG86d29ybGQ="; + private static final Collection TRUE_FALSE = asList(true, false); @RegisterExtension static final ExecutionContextExtension SERVER_CTX = @@ -88,13 +97,27 @@ class HttpsProxyTest { @Nullable private BlockingHttpClient client; - void setUp(boolean withAuth) throws Exception { + private static List> protocols() { + return asList(asList(HTTP_1), asList(HTTP_2), asList(HTTP_2, HTTP_1), asList(HTTP_1, HTTP_2)); + } + + private static List protocolsWithAuth() { + List arguments = new ArrayList<>(); + for (List protocols : protocols()) { + for (boolean withAuth: TRUE_FALSE) { + arguments.add(Arguments.of(protocols, withAuth)); + } + } + return arguments; + } + + void setUp(List protocols, boolean withAuth) throws Exception { if (withAuth) { proxyTunnel.basicAuthToken(AUTH_TOKEN); } proxyAddress = proxyTunnel.startProxy(); - startServer(); - createClient(withAuth); + startServer(protocols); + createClient(protocols, withAuth); } @AfterEach @@ -114,16 +137,17 @@ static void safeClose(@Nullable AutoCloseable closeable) { } } - void startServer() throws Exception { + void startServer(List protocols) throws Exception { serverContext = BuilderUtils.newServerBuilder(SERVER_CTX) .sslConfig(new ServerSslConfigBuilder(DefaultTestCerts::loadServerPem, DefaultTestCerts::loadServerKey).build()) + .protocols(toConfigs(protocols)) .listenAndAwait((ctx, request, responseFactory) -> succeeded(responseFactory.ok() .payloadBody("host: " + request.headers().get(HOST), textSerializerUtf8()))); serverAddress = serverHostAndPort(serverContext); } - void createClient(boolean withAuth) { + void createClient(List protocols, boolean withAuth) { assert serverContext != null && proxyAddress != null; client = BuilderUtils.newClientBuilder(serverContext, CLIENT_CTX) .proxyAddress(proxyAddress, withAuth ? @@ -131,42 +155,48 @@ void createClient(boolean withAuth) { __ -> { }) .sslConfig(new ClientSslConfigBuilder(DefaultTestCerts::loadServerCAPem) .peerHost(serverPemHostname()).build()) + .protocols(toConfigs(protocols)) .appendConnectionFactoryFilter(new TargetAddressCheckConnectionFactoryFilter(targetAddress, true)) .buildBlocking(); } - @ParameterizedTest(name = "{displayName} [{index}] withAuth={0}") - @ValueSource(booleans = {false, true}) - void testClientRequest(boolean withAuth) throws Exception { - setUp(withAuth); + @ParameterizedTest(name = "{displayName} [{index}] protocols={0} withAuth={1}") + @MethodSource("protocolsWithAuth") + void testClientRequest(List protocols, boolean withAuth) throws Exception { + setUp(protocols, withAuth); assert client != null; - assertResponse(client.request(client.get("/path"))); + assertResponse(client.request(client.get("/path")), protocols.get(0).version); } - @ParameterizedTest(name = "{displayName} [{index}] withAuth={0}") - @ValueSource(booleans = {false, true}) - void testConnectionRequest(boolean withAuth) throws Exception { - setUp(withAuth); + @ParameterizedTest(name = "{displayName} [{index}] protocols={0} withAuth={1}") + @MethodSource("protocolsWithAuth") + void testConnectionRequest(List protocols, boolean withAuth) throws Exception { + setUp(protocols, withAuth); assert client != null; + HttpProtocolVersion expectedVersion = protocols.get(0).version; try (ReservedBlockingHttpConnection connection = client.reserveConnection(client.get("/"))) { - assertThat(connection.connectionContext().protocol(), is(HTTP_1_1)); + assertThat(connection.connectionContext().protocol(), is(expectedVersion)); assertThat(connection.connectionContext().sslConfig(), is(notNullValue())); assertThat(connection.connectionContext().sslSession(), is(notNullValue())); - assertResponse(connection.request(connection.get("/path"))); + HttpRequest request = connection.get("/path"); + assertThat(request.version(), is(expectedVersion)); + assertResponse(connection.request(request), expectedVersion); } } - private void assertResponse(HttpResponse httpResponse) { + private void assertResponse(HttpResponse httpResponse, HttpProtocolVersion expectedVersion) { assertThat(httpResponse.status(), is(OK)); + assertThat(httpResponse.version(), is(expectedVersion)); assertThat(proxyTunnel.connectCount(), is(1)); assertThat(httpResponse.payloadBody().toString(US_ASCII), is("host: " + serverAddress)); assertThat(targetAddress.get(), is(equalTo(serverAddress.toString()))); } - @Test - void testProxyAuthRequired() throws Exception { - setUp(false); + @ParameterizedTest(name = "{displayName} [{index}] protocols={0}") + @MethodSource("protocols") + void testProxyAuthRequired(List protocols) throws Exception { + setUp(protocols, false); proxyTunnel.basicAuthToken(AUTH_TOKEN); assert client != null; ProxyResponseException e = assertThrows(ProxyResponseException.class, @@ -175,9 +205,10 @@ void testProxyAuthRequired() throws Exception { assertThat(targetAddress.get(), is(equalTo(serverAddress.toString()))); } - @Test - void testBadProxyResponse() throws Exception { - setUp(false); + @ParameterizedTest(name = "{displayName} [{index}] protocols={0}") + @MethodSource("protocols") + void testBadProxyResponse(List protocols) throws Exception { + setUp(protocols, false); proxyTunnel.badResponseProxy(); assert client != null; ProxyResponseException e = assertThrows(ProxyResponseException.class, diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactoryTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactoryTest.java index 74db286684..fdf0bbd9c4 100644 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactoryTest.java +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ProxyConnectLBHttpConnectionFactoryTest.java @@ -64,7 +64,6 @@ import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1; import static io.servicetalk.http.api.HttpResponseStatus.INTERNAL_SERVER_ERROR; import static io.servicetalk.http.api.HttpResponseStatus.OK; -import static io.servicetalk.http.netty.HttpProtocolConfigs.h1Default; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsString; @@ -129,7 +128,6 @@ public void cancel() { connectRequestInitializer = mock(Consumer.class); HttpClientConfig config = new HttpClientConfig(); config.connectAddress(CONNECT_ADDRESS); - config.protocolConfigs().protocols(h1Default()); connectionFactory = new ProxyConnectLBHttpConnectionFactory<>(config.asReadOnly(), executionContext, null, REQ_RES_FACTORY, ConnectExecutionStrategy.offloadNone(), ConnectionFactoryFilter.identity(), mock(ProtocolBinding.class), connectRequestInitializer); diff --git a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/CopyByteBufHandlerChannelInitializer.java b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/CopyByteBufHandlerChannelInitializer.java index 24692d1b2d..2eb183a35e 100644 --- a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/CopyByteBufHandlerChannelInitializer.java +++ b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/CopyByteBufHandlerChannelInitializer.java @@ -55,6 +55,17 @@ public void init(final Channel channel) { channel.pipeline().addLast(copyHandler); } + /** + * Return {@link Class} of the {@link ChannelInboundHandler} in case there is a need to remove the handler from the + * {@link ChannelPipeline}. + * + * @return {@link Class} of the {@link ChannelInboundHandler} in case there is a need to remove the handler from the + * {@link ChannelPipeline}. + */ + public static Class handlerClass() { + return CopyByteBufHandler.class; + } + /** * This handler has to be added to the {@link ChannelPipeline} when {@link PooledByteBufAllocator} is used for * reading data from the socket. The allocated {@link ByteBuf}s must be copied and released before handed over to diff --git a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/DefaultNettyConnection.java b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/DefaultNettyConnection.java index d26d37e3fc..89b900de32 100644 --- a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/DefaultNettyConnection.java +++ b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/DefaultNettyConnection.java @@ -53,6 +53,7 @@ import io.netty.channel.ChannelConfig; import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandler; import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPromise; import io.netty.channel.EventLoop; @@ -524,6 +525,17 @@ protected void handleSubscribe( }; } + /** + * Return {@link Class} of the {@link ChannelInboundHandler} in case there is a need to remove the handler from the + * {@link ChannelPipeline}. + * + * @return {@link Class} of the {@link ChannelInboundHandler} in case there is a need to remove the handler from the + * {@link ChannelPipeline}. + */ + public static Class handlerClass() { + return NettyToStChannelHandler.class; + } + private static boolean shouldWaitForSslHandshake(@Nullable final SSLSession sslSession, @Nullable final SslConfig sslConfig, final ChannelPipeline pipeline) {