From 70afe5fbd7d1b5cf57c6b83a947b771b13f6f798 Mon Sep 17 00:00:00 2001 From: Peter Vlugter <59895+pvlugter@users.noreply.github.com> Date: Tue, 11 Jul 2023 23:44:36 +1200 Subject: [PATCH] Handle early disconnects before SSL handshake. [resolves #595][#596] --- .../postgresql/client/ReactorNettyClient.java | 39 +++++---- .../client/SSLSessionHandlerAdapter.java | 59 ++++++++----- .../client/SSLTunnelHandlerAdapter.java | 4 +- .../client/DowntimeIntegrationTests.java | 82 +++++++++++++++++++ 4 files changed, 140 insertions(+), 44 deletions(-) create mode 100644 src/test/java/io/r2dbc/postgresql/client/DowntimeIntegrationTests.java diff --git a/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java b/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java index c7abfb40a..cb40ede43 100644 --- a/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java +++ b/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java @@ -144,7 +144,7 @@ private ReactorNettyClient(Connection connection, ConnectionSettings settings) { Assert.requireNonNull(connection, "Connection must not be null"); this.settings = Assert.requireNonNull(settings, "ConnectionSettings must not be null"); - connection.addHandlerFirst(new EnsureSubscribersCompleteChannelHandler(this.requestSink)); + connection.addHandlerLast(new EnsureSubscribersCompleteChannelHandler(this.requestSink)); connection.addHandlerLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE - 5, 1, 4, -4, 0)); this.connection = connection; this.byteBufAllocator = connection.outbound().alloc(); @@ -392,9 +392,8 @@ public static Mono connect(SocketAddress socketAddress, Conn tcpClient = tcpClient.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, settings.getConnectTimeoutMs()); } - return tcpClient.connect().flatMap(it -> { - - ChannelPipeline pipeline = it.channel().pipeline(); + return tcpClient.doOnChannelInit((observer, channel, remoteAddress) -> { + ChannelPipeline pipeline = channel.pipeline(); InternalLogger logger = InternalLoggerFactory.getInstance(ReactorNettyClient.class); if (logger.isTraceEnabled()) { @@ -402,33 +401,33 @@ public static Mono connect(SocketAddress socketAddress, Conn new LoggingHandler(ReactorNettyClient.class, LogLevel.TRACE)); } - return registerSslHandler(settings.getSslConfig(), it).thenReturn(new ReactorNettyClient(it, settings)); - }); + registerSslHandler(settings.getSslConfig(), channel); + }).connect().flatMap(it -> + getSslHandshake(it.channel()).thenReturn(new ReactorNettyClient(it, settings)) + ); } - private static Mono registerSslHandler(SSLConfig sslConfig, Connection it) { - + private static void registerSslHandler(SSLConfig sslConfig, Channel channel) { try { if (sslConfig.getSslMode().startSsl()) { - return Mono.defer(() -> { - AbstractPostgresSSLHandlerAdapter sslAdapter; - if (sslConfig.getSslMode() == SSLMode.TUNNEL) { - sslAdapter = new SSLTunnelHandlerAdapter(it.outbound().alloc(), sslConfig); - } else { - sslAdapter = new SSLSessionHandlerAdapter(it.outbound().alloc(), sslConfig); - } - - it.addHandlerFirst(sslAdapter); - return sslAdapter.getHandshake(); + AbstractPostgresSSLHandlerAdapter sslAdapter; + if (sslConfig.getSslMode() == SSLMode.TUNNEL) { + sslAdapter = new SSLTunnelHandlerAdapter(channel.alloc(), sslConfig); + } else { + sslAdapter = new SSLSessionHandlerAdapter(channel.alloc(), sslConfig); + } - }).subscribeOn(Schedulers.boundedElastic()); + channel.pipeline().addFirst(sslAdapter); } } catch (Throwable e) { throw new RuntimeException(e); } + } - return Mono.empty(); + private static Mono getSslHandshake(Channel channel) { + AbstractPostgresSSLHandlerAdapter sslAdapter = channel.pipeline().get(AbstractPostgresSSLHandlerAdapter.class); + return (sslAdapter != null) ? sslAdapter.getHandshake() : Mono.empty(); } @Override diff --git a/src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java b/src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java index da2422262..4e64df8fa 100644 --- a/src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java +++ b/src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java @@ -33,6 +33,8 @@ final class SSLSessionHandlerAdapter extends AbstractPostgresSSLHandlerAdapter { private final SSLConfig sslConfig; + private boolean negotiating = true; + SSLSessionHandlerAdapter(ByteBufAllocator alloc, SSLConfig sslConfig) { super(alloc, sslConfig); this.alloc = alloc; @@ -40,28 +42,45 @@ final class SSLSessionHandlerAdapter extends AbstractPostgresSSLHandlerAdapter { } @Override - public void handlerAdded(ChannelHandlerContext ctx) { - Mono.from(SSLRequest.INSTANCE.encode(this.alloc)).subscribe(ctx::writeAndFlush); + public void channelActive(ChannelHandlerContext ctx) throws Exception { + if (negotiating) { + Mono.from(SSLRequest.INSTANCE.encode(this.alloc)).subscribe(ctx::writeAndFlush); + } + super.channelActive(ctx); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + if (negotiating) { + // If we receive channel inactive before negotiated, then the inbound has closed early. + PostgresqlSslException e = new PostgresqlSslException("Connection closed during SSL negotiation"); + completeHandshakeExceptionally(e); + } + super.channelInactive(ctx); } @Override - public void channelRead(ChannelHandlerContext ctx, Object msg) { - ByteBuf buf = (ByteBuf) msg; - char response = (char) buf.readByte(); - try { - switch (response) { - case 'S': - processSslEnabled(ctx, buf); - break; - case 'N': - processSslDisabled(); - break; - default: - buf.release(); - throw new IllegalStateException("Unknown SSLResponse from server: '" + response + "'"); + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (negotiating) { + ByteBuf buf = (ByteBuf) msg; + char response = (char) buf.readByte(); + try { + switch (response) { + case 'S': + processSslEnabled(ctx, buf); + break; + case 'N': + processSslDisabled(); + break; + default: + throw new IllegalStateException("Unknown SSLResponse from server: '" + response + "'"); + } + } finally { + buf.release(); + negotiating = false; } - } finally { - buf.release(); + } else { + super.channelRead(ctx, msg); } } @@ -82,9 +101,7 @@ private void processSslEnabled(ChannelHandlerContext ctx, ByteBuf msg) { completeHandshakeExceptionally(e); return; } - ctx.channel().pipeline() - .addFirst(this.getSslHandler()) - .remove(this); + ctx.channel().pipeline().addFirst(this.getSslHandler()); ctx.fireChannelRead(msg.retain()); } diff --git a/src/main/java/io/r2dbc/postgresql/client/SSLTunnelHandlerAdapter.java b/src/main/java/io/r2dbc/postgresql/client/SSLTunnelHandlerAdapter.java index 3301c7630..059d66b35 100644 --- a/src/main/java/io/r2dbc/postgresql/client/SSLTunnelHandlerAdapter.java +++ b/src/main/java/io/r2dbc/postgresql/client/SSLTunnelHandlerAdapter.java @@ -40,9 +40,7 @@ public void handlerAdded(ChannelHandlerContext ctx) { completeHandshakeExceptionally(e); return; } - ctx.channel().pipeline() - .addFirst(this.getSslHandler()) - .remove(this); + ctx.channel().pipeline().addFirst(this.getSslHandler()); } } diff --git a/src/test/java/io/r2dbc/postgresql/client/DowntimeIntegrationTests.java b/src/test/java/io/r2dbc/postgresql/client/DowntimeIntegrationTests.java new file mode 100644 index 000000000..d00b52a1d --- /dev/null +++ b/src/test/java/io/r2dbc/postgresql/client/DowntimeIntegrationTests.java @@ -0,0 +1,82 @@ +/* + * Copyright 2022 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.r2dbc.postgresql.client; + +import io.r2dbc.postgresql.PostgresqlConnectionConfiguration; +import io.r2dbc.postgresql.PostgresqlConnectionFactory; +import io.r2dbc.postgresql.api.PostgresqlException; +import org.junit.jupiter.api.Test; +import reactor.netty.DisposableChannel; +import reactor.netty.DisposableServer; +import reactor.netty.tcp.TcpServer; +import reactor.test.StepVerifier; + +import java.nio.channels.ClosedChannelException; +import java.util.function.Consumer; + +import static org.assertj.core.api.Assertions.assertThat; + +public class DowntimeIntegrationTests { + + // Simulate server downtime, where connections are accepted and then closed immediately + static DisposableServer newServer() { + return TcpServer.create() + .doOnConnection(DisposableChannel::dispose) + .bindNow(); + } + + static PostgresqlConnectionFactory newConnectionFactory(DisposableServer server, SSLMode sslMode) { + return new PostgresqlConnectionFactory( + PostgresqlConnectionConfiguration.builder() + .host(server.host()) + .port(server.port()) + .username("test") + .sslMode(sslMode) + .build()); + } + + static void verifyError(SSLMode sslMode, Consumer assertions) { + DisposableServer server = newServer(); + PostgresqlConnectionFactory connectionFactory = newConnectionFactory(server, sslMode); + connectionFactory.create().as(StepVerifier::create).verifyErrorSatisfies(assertions); + server.disposeNow(); + } + + @Test + void failSslHandshakeIfInboundClosed() { + verifyError(SSLMode.REQUIRE, error -> + assertThat(error) + .isInstanceOf(AbstractPostgresSSLHandlerAdapter.PostgresqlSslException.class) + .hasMessage("Connection closed during SSL negotiation")); + } + + @Test + void failSslTunnelIfInboundClosed() { + verifyError(SSLMode.TUNNEL, error -> { + assertThat(error) + .isInstanceOf(PostgresqlException.class) + .cause() + .isInstanceOf(ClosedChannelException.class); + + assertThat(error.getCause().getSuppressed().length).isOne(); + + assertThat(error.getCause().getSuppressed()[0]) + .hasMessage("Connection closed while SSL/TLS handshake was in progress"); + }); + } + +}