Skip to content

Commit

Permalink
Fix NullPointerException being thrown by outbound encoder. (#2800)
Browse files Browse the repository at this point in the history
The NPE was due to null remote host which can happen with TLS 1.3 client
certification auth failure. Client cert auth failure can happen after a
successful handshake, causing an inbound SSLException. Changed the inbound
handler to detect and store the SSLException in the channel attributes and
changed the outbound encoder to detect if remote address is not available
and there was an SSLException. Changed the integration test to log the
stack trace of the caught exception for easier troubleshooting as these
SSL tests have been flaky.
  • Loading branch information
andreachild authored Oct 8, 2024
1 parent 8d9cc28 commit 2baa71d
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.AttributeKey;
import javax.net.ssl.SSLException;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.apache.tinkerpop.gremlin.driver.Result;
import org.apache.tinkerpop.gremlin.driver.ResultQueue;
import org.apache.tinkerpop.gremlin.driver.exception.ResponseException;
import org.apache.tinkerpop.gremlin.util.ExceptionHelper;
import org.apache.tinkerpop.gremlin.util.iterator.IteratorUtils;
import org.apache.tinkerpop.gremlin.util.message.ResponseMessage;
import org.apache.tinkerpop.gremlin.util.ser.SerializationException;
Expand All @@ -42,6 +44,7 @@
* as the {@link ResponseMessage} objects are deserialized.
*/
public class GremlinResponseHandler extends SimpleChannelInboundHandler<ResponseMessage> {
public static final AttributeKey<Throwable> INBOUND_SSL_EXCEPTION = AttributeKey.valueOf("inboundSslException");
private static final Logger logger = LoggerFactory.getLogger(GremlinResponseHandler.class);
private static final AttributeKey<ResponseException> CAUGHT_EXCEPTION = AttributeKey.valueOf("caughtException");
private final AtomicReference<ResultQueue> pending;
Expand Down Expand Up @@ -106,6 +109,12 @@ public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cau
final ResultQueue pendingQueue = pending.getAndSet(null);
if (pendingQueue != null) pendingQueue.markError(cause);

if (ExceptionHelper.getRootCause(cause) instanceof SSLException) {
// inbound ssl error can happen with tls 1.3 because client certification auth can fail after the handshake completes
// store the inbound ssl error so that outbound can retrieve it
ctx.channel().attr(INBOUND_SSL_EXCEPTION).set(cause);
}

// serialization exceptions should not close the channel - that's worth a retry
if (!IteratorUtils.anyMatch(ExceptionUtils.getThrowableList(cause).iterator(), t -> t instanceof SerializationException))
if (ctx.channel().isActive()) ctx.close();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
package org.apache.tinkerpop.gremlin.driver.handler;

import io.netty.buffer.ByteBuf;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageEncoder;
Expand Down Expand Up @@ -68,14 +69,15 @@ protected void encode(final ChannelHandlerContext channelHandlerContext, final R
requestMessage));
}

final InetSocketAddress remoteAddress = getRemoteAddress(channelHandlerContext.channel());
try {
final ByteBuf buffer = serializer.serializeRequestAsBinary(requestMessage, channelHandlerContext.alloc());
FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/", buffer);
request.headers().add(HttpHeaderNames.CONTENT_TYPE, mimeType);
request.headers().add(HttpHeaderNames.CONTENT_LENGTH, buffer.readableBytes());
request.headers().add(HttpHeaderNames.ACCEPT, mimeType);
request.headers().add(HttpHeaderNames.ACCEPT_ENCODING, HttpHeaderValues.DEFLATE);
request.headers().add(HttpHeaderNames.HOST, ((InetSocketAddress) channelHandlerContext.channel().remoteAddress()).getAddress().getHostAddress());
request.headers().add(HttpHeaderNames.HOST, remoteAddress.getAddress().getHostAddress());
if (userAgentEnabled) {
request.headers().add(HttpHeaderNames.USER_AGENT, UserAgent.USER_AGENT);
}
Expand All @@ -95,4 +97,16 @@ protected void encode(final ChannelHandlerContext channelHandlerContext, final R
requestMessage, ex));
}
}

private static InetSocketAddress getRemoteAddress(Channel channel) {
final InetSocketAddress remoteAddress = (InetSocketAddress) channel.remoteAddress();
if (remoteAddress == null) {
final Throwable sslException = channel.attr(GremlinResponseHandler.INBOUND_SSL_EXCEPTION).get();
if (sslException != null) {
throw new RuntimeException("Request cannot be serialized because the channel is not connected due to an ssl error.", sslException);
}
throw new RuntimeException("Request cannot be serialized because the channel is not connected");
}
return remoteAddress;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.hamcrest.CoreMatchers.containsString;
import static org.hamcrest.CoreMatchers.is;
Expand All @@ -50,6 +52,7 @@
import static org.junit.Assert.fail;

public class GremlinServerSslIntegrateTest extends AbstractGremlinServerIntegrationTest {
private static final Logger logger = LoggerFactory.getLogger(GremlinServerSslIntegrateTest.class);

/**
* Configure specific Gremlin Server settings for specific tests.
Expand Down Expand Up @@ -257,9 +260,7 @@ public void shouldEnableSslAndClientCertificateAuthAndFailWithoutCert() {
client.submit("'test'").one();
fail("Should throw exception because ssl client auth is enabled on the server but client does not have a cert");
} catch (Exception x) {
final Throwable root = ExceptionHelper.getRootCause(x);
assertThat(root, instanceOf(SSLException.class));
assertThat(root.getMessage(), containsString("bad_certificate"));
assertSslException(x, "bad_certificate");
} finally {
cluster.close();
}
Expand All @@ -275,9 +276,7 @@ public void shouldEnableSslAndClientCertificateAuthAndFailWithoutTrustedClientCe
client.submit("'test'").one();
fail("Should throw exception because ssl client auth is enabled on the server but does not trust client's cert");
} catch (Exception x) {
final Throwable root = ExceptionHelper.getRootCause(x);
assertThat(root, instanceOf(SSLException.class));
assertThat(root.getMessage(), containsString("bad_certificate"));
assertSslException(x, "bad_certificate");
} finally {
cluster.close();
}
Expand All @@ -293,9 +292,7 @@ public void shouldEnableSslAndFailIfProtocolsDontMatch() {
client.submit("'test'").one();
fail("Should throw exception because ssl client requires TLSv1.2 whereas server supports only TLSv1.1");
} catch (Exception x) {
final Throwable root = ExceptionHelper.getRootCause(x);
assertThat(root, instanceOf(SSLException.class));
assertThat(root.getMessage(), containsString("protocol_version"));
assertSslException(x ,"protocol_version");
} finally {
cluster.close();
}
Expand Down Expand Up @@ -344,4 +341,11 @@ public void shouldEnableSslAndClientCertificateAuthWithDifferentStoreType() {
cluster2.close();
}
}

private static void assertSslException(Exception x, String expectedSubstring) {
logger.warn("Exception caught: {}", x.getMessage(), x);
final Throwable root = ExceptionHelper.getRootCause(x);
assertThat(root, instanceOf(SSLException.class));
assertThat(root.getMessage(), containsString(expectedSubstring));
}
}

0 comments on commit 2baa71d

Please sign in to comment.