Skip to content

Commit

Permalink
Handle early disconnects before SSL handshake.
Browse files Browse the repository at this point in the history
[resolves #595][#596]
  • Loading branch information
pvlugter authored Jul 11, 2023
1 parent 4fa302c commit 70afe5f
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 44 deletions.
39 changes: 19 additions & 20 deletions src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ private ReactorNettyClient(Connection connection, ConnectionSettings settings) {
Assert.requireNonNull(connection, "Connection must not be null");
this.settings = Assert.requireNonNull(settings, "ConnectionSettings must not be null");

connection.addHandlerFirst(new EnsureSubscribersCompleteChannelHandler(this.requestSink));
connection.addHandlerLast(new EnsureSubscribersCompleteChannelHandler(this.requestSink));
connection.addHandlerLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE - 5, 1, 4, -4, 0));
this.connection = connection;
this.byteBufAllocator = connection.outbound().alloc();
Expand Down Expand Up @@ -392,43 +392,42 @@ public static Mono<ReactorNettyClient> connect(SocketAddress socketAddress, Conn
tcpClient = tcpClient.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, settings.getConnectTimeoutMs());
}

return tcpClient.connect().flatMap(it -> {

ChannelPipeline pipeline = it.channel().pipeline();
return tcpClient.doOnChannelInit((observer, channel, remoteAddress) -> {
ChannelPipeline pipeline = channel.pipeline();

InternalLogger logger = InternalLoggerFactory.getInstance(ReactorNettyClient.class);
if (logger.isTraceEnabled()) {
pipeline.addFirst(LoggingHandler.class.getSimpleName(),
new LoggingHandler(ReactorNettyClient.class, LogLevel.TRACE));
}

return registerSslHandler(settings.getSslConfig(), it).thenReturn(new ReactorNettyClient(it, settings));
});
registerSslHandler(settings.getSslConfig(), channel);
}).connect().flatMap(it ->
getSslHandshake(it.channel()).thenReturn(new ReactorNettyClient(it, settings))
);
}

private static Mono<? extends Void> registerSslHandler(SSLConfig sslConfig, Connection it) {

private static void registerSslHandler(SSLConfig sslConfig, Channel channel) {
try {
if (sslConfig.getSslMode().startSsl()) {

return Mono.defer(() -> {
AbstractPostgresSSLHandlerAdapter sslAdapter;
if (sslConfig.getSslMode() == SSLMode.TUNNEL) {
sslAdapter = new SSLTunnelHandlerAdapter(it.outbound().alloc(), sslConfig);
} else {
sslAdapter = new SSLSessionHandlerAdapter(it.outbound().alloc(), sslConfig);
}

it.addHandlerFirst(sslAdapter);
return sslAdapter.getHandshake();
AbstractPostgresSSLHandlerAdapter sslAdapter;
if (sslConfig.getSslMode() == SSLMode.TUNNEL) {
sslAdapter = new SSLTunnelHandlerAdapter(channel.alloc(), sslConfig);
} else {
sslAdapter = new SSLSessionHandlerAdapter(channel.alloc(), sslConfig);
}

}).subscribeOn(Schedulers.boundedElastic());
channel.pipeline().addFirst(sslAdapter);
}
} catch (Throwable e) {
throw new RuntimeException(e);
}
}

return Mono.empty();
private static Mono<Void> getSslHandshake(Channel channel) {
AbstractPostgresSSLHandlerAdapter sslAdapter = channel.pipeline().get(AbstractPostgresSSLHandlerAdapter.class);
return (sslAdapter != null) ? sslAdapter.getHandshake() : Mono.empty();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,35 +33,54 @@ final class SSLSessionHandlerAdapter extends AbstractPostgresSSLHandlerAdapter {

private final SSLConfig sslConfig;

private boolean negotiating = true;

SSLSessionHandlerAdapter(ByteBufAllocator alloc, SSLConfig sslConfig) {
super(alloc, sslConfig);
this.alloc = alloc;
this.sslConfig = sslConfig;
}

@Override
public void handlerAdded(ChannelHandlerContext ctx) {
Mono.from(SSLRequest.INSTANCE.encode(this.alloc)).subscribe(ctx::writeAndFlush);
public void channelActive(ChannelHandlerContext ctx) throws Exception {
if (negotiating) {
Mono.from(SSLRequest.INSTANCE.encode(this.alloc)).subscribe(ctx::writeAndFlush);
}
super.channelActive(ctx);
}

@Override
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
if (negotiating) {
// If we receive channel inactive before negotiated, then the inbound has closed early.
PostgresqlSslException e = new PostgresqlSslException("Connection closed during SSL negotiation");
completeHandshakeExceptionally(e);
}
super.channelInactive(ctx);
}

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
ByteBuf buf = (ByteBuf) msg;
char response = (char) buf.readByte();
try {
switch (response) {
case 'S':
processSslEnabled(ctx, buf);
break;
case 'N':
processSslDisabled();
break;
default:
buf.release();
throw new IllegalStateException("Unknown SSLResponse from server: '" + response + "'");
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if (negotiating) {
ByteBuf buf = (ByteBuf) msg;
char response = (char) buf.readByte();
try {
switch (response) {
case 'S':
processSslEnabled(ctx, buf);
break;
case 'N':
processSslDisabled();
break;
default:
throw new IllegalStateException("Unknown SSLResponse from server: '" + response + "'");
}
} finally {
buf.release();
negotiating = false;
}
} finally {
buf.release();
} else {
super.channelRead(ctx, msg);
}
}

Expand All @@ -82,9 +101,7 @@ private void processSslEnabled(ChannelHandlerContext ctx, ByteBuf msg) {
completeHandshakeExceptionally(e);
return;
}
ctx.channel().pipeline()
.addFirst(this.getSslHandler())
.remove(this);
ctx.channel().pipeline().addFirst(this.getSslHandler());
ctx.fireChannelRead(msg.retain());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,7 @@ public void handlerAdded(ChannelHandlerContext ctx) {
completeHandshakeExceptionally(e);
return;
}
ctx.channel().pipeline()
.addFirst(this.getSslHandler())
.remove(this);
ctx.channel().pipeline().addFirst(this.getSslHandler());
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
/*
* Copyright 2022 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.r2dbc.postgresql.client;

import io.r2dbc.postgresql.PostgresqlConnectionConfiguration;
import io.r2dbc.postgresql.PostgresqlConnectionFactory;
import io.r2dbc.postgresql.api.PostgresqlException;
import org.junit.jupiter.api.Test;
import reactor.netty.DisposableChannel;
import reactor.netty.DisposableServer;
import reactor.netty.tcp.TcpServer;
import reactor.test.StepVerifier;

import java.nio.channels.ClosedChannelException;
import java.util.function.Consumer;

import static org.assertj.core.api.Assertions.assertThat;

public class DowntimeIntegrationTests {

// Simulate server downtime, where connections are accepted and then closed immediately
static DisposableServer newServer() {
return TcpServer.create()
.doOnConnection(DisposableChannel::dispose)
.bindNow();
}

static PostgresqlConnectionFactory newConnectionFactory(DisposableServer server, SSLMode sslMode) {
return new PostgresqlConnectionFactory(
PostgresqlConnectionConfiguration.builder()
.host(server.host())
.port(server.port())
.username("test")
.sslMode(sslMode)
.build());
}

static void verifyError(SSLMode sslMode, Consumer<Throwable> assertions) {
DisposableServer server = newServer();
PostgresqlConnectionFactory connectionFactory = newConnectionFactory(server, sslMode);
connectionFactory.create().as(StepVerifier::create).verifyErrorSatisfies(assertions);
server.disposeNow();
}

@Test
void failSslHandshakeIfInboundClosed() {
verifyError(SSLMode.REQUIRE, error ->
assertThat(error)
.isInstanceOf(AbstractPostgresSSLHandlerAdapter.PostgresqlSslException.class)
.hasMessage("Connection closed during SSL negotiation"));
}

@Test
void failSslTunnelIfInboundClosed() {
verifyError(SSLMode.TUNNEL, error -> {
assertThat(error)
.isInstanceOf(PostgresqlException.class)
.cause()
.isInstanceOf(ClosedChannelException.class);

assertThat(error.getCause().getSuppressed().length).isOne();

assertThat(error.getCause().getSuppressed()[0])
.hasMessage("Connection closed while SSL/TLS handshake was in progress");
});
}

}

0 comments on commit 70afe5f

Please sign in to comment.