Skip to content

Commit

Permalink
Merge pull request #391 from praveenag/1.4
Browse files Browse the repository at this point in the history
better exception messages when database is down.
  • Loading branch information
praveenag authored Jul 18, 2017
2 parents 8935542 + dd710ec commit 45c2930
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ public class TLSSocketChannel implements ByteChannel
{
private final ByteChannel channel; // The real channel the data is sent to and read from
private final Logger logger;
private final BoltServerAddress address;

private SSLEngine sslEngine;

Expand All @@ -65,17 +66,20 @@ public class TLSSocketChannel implements ByteChannel

private static final ByteBuffer DUMMY_BUFFER = ByteBuffer.allocate( 0 );

public static TLSSocketChannel create( BoltServerAddress address, SecurityPlan securityPlan, ByteChannel channel, Logger logger )
public static TLSSocketChannel create( BoltServerAddress address, SecurityPlan securityPlan,
ByteChannel channel, Logger logger )
throws IOException
{
SSLEngine sslEngine = securityPlan.sslContext().createSSLEngine( address.host(), address.port() );
sslEngine.setUseClientMode( true );
return create( channel, logger, sslEngine );
return create( channel, logger, sslEngine, address );
}

public static TLSSocketChannel create( ByteChannel channel, Logger logger, SSLEngine sslEngine ) throws IOException
public static TLSSocketChannel create( ByteChannel channel, Logger logger, SSLEngine sslEngine,
BoltServerAddress address ) throws IOException
{
TLSSocketChannel tlsChannel = new TLSSocketChannel( channel, logger, sslEngine );

TLSSocketChannel tlsChannel = new TLSSocketChannel( channel, logger, sslEngine, address );
try
{
tlsChannel.runHandshake();
Expand All @@ -87,9 +91,10 @@ public static TLSSocketChannel create( ByteChannel channel, Logger logger, SSLEn
return tlsChannel;
}

TLSSocketChannel( ByteChannel channel, Logger logger, SSLEngine sslEngine )
TLSSocketChannel( ByteChannel channel, Logger logger, SSLEngine sslEngine, BoltServerAddress address )
throws IOException
{
this.address = address;
this.logger = logger;
this.channel = channel;
this.sslEngine = sslEngine;
Expand Down Expand Up @@ -167,9 +172,10 @@ int channelRead( ByteBuffer toBuffer ) throws IOException
{
// best effort
}

throw new ServiceUnavailableException(
"SSL Connection terminated while receiving data. " +
"This can happen due to network instabilities, or due to restarts of the database." );
"Failed to receive any data from the connected address " + address + ". " +
"Please ensure a working connection to the database." );
}
return read;
}
Expand All @@ -193,8 +199,8 @@ int channelWrite( ByteBuffer fromBuffer ) throws IOException
// best effort
}
throw new ServiceUnavailableException(
"SSL Connection terminated while writing data. " +
"This can happen due to network instabilities, or due to restarts of the database." );
"Failed to send any data to the connected address " + address + ". " +
"Please ensure a working connection to the database." );
}
return written;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLSession;

import org.neo4j.driver.internal.net.BoltServerAddress;
import org.neo4j.driver.v1.exceptions.ServiceUnavailableException;
import org.neo4j.driver.v1.exceptions.SecurityException;

Expand All @@ -40,6 +41,7 @@
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.neo4j.driver.internal.logging.DevNullLogger.DEV_NULL_LOGGER;
import static org.neo4j.driver.internal.net.BoltServerAddress.LOCAL_DEFAULT;
import static org.neo4j.driver.internal.security.TLSSocketChannel.create;

public class TLSSocketChannelTest
Expand All @@ -59,7 +61,7 @@ public void shouldCloseConnectionIfFailedToRead() throws Throwable
when( mockedSslSession.getPacketBufferSize() ).thenReturn( 10 );

// When
TLSSocketChannel channel = new TLSSocketChannel( mockedChannel, DEV_NULL_LOGGER, mockedSslEngine );
TLSSocketChannel channel = new TLSSocketChannel( mockedChannel, DEV_NULL_LOGGER, mockedSslEngine, LOCAL_DEFAULT );

try
{
Expand All @@ -69,7 +71,8 @@ public void shouldCloseConnectionIfFailedToRead() throws Throwable
catch( Exception e )
{
assertThat( e, instanceOf( ServiceUnavailableException.class ) );
assertThat( e.getMessage(), startsWith( "SSL Connection terminated while receiving data. " ) );
assertThat( e.getMessage(), startsWith( "Failed to receive any data from the connected address " +
"localhost:7687. Please ensure a working connection to the database." ) );
}
// Then
verify( mockedChannel ).close();
Expand All @@ -89,7 +92,7 @@ public void shouldCloseConnectionIfFailedToWrite() throws Throwable
when( mockedSslSession.getPacketBufferSize() ).thenReturn( 10 );

// When
TLSSocketChannel channel = new TLSSocketChannel( mockedChannel, DEV_NULL_LOGGER, mockedSslEngine );
TLSSocketChannel channel = new TLSSocketChannel( mockedChannel, DEV_NULL_LOGGER, mockedSslEngine, LOCAL_DEFAULT );

try
{
Expand All @@ -99,7 +102,8 @@ public void shouldCloseConnectionIfFailedToWrite() throws Throwable
catch( Exception e )
{
assertThat( e, instanceOf( ServiceUnavailableException.class ) );
assertThat( e.getMessage(), startsWith( "SSL Connection terminated while writing data. " ) );
assertThat( e.getMessage(), startsWith( "Failed to send any data to the connected address localhost:7687. " +
"Please ensure a working connection to the database." ) );
}

// Then
Expand All @@ -123,7 +127,7 @@ public void shouldThrowUnauthorizedIfFailedToHandshake() throws Throwable
// When & Then
try
{
create( mockedChannel, DEV_NULL_LOGGER, mockedSslEngine );
create( mockedChannel, DEV_NULL_LOGGER, mockedSslEngine, LOCAL_DEFAULT );
fail( "Should fail to run handshake" );
}
catch( Exception e )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;

import org.neo4j.driver.internal.net.BoltServerAddress;
import org.neo4j.driver.internal.security.TLSSocketChannel;

import static org.hamcrest.core.IsEqual.equalTo;
import static org.junit.Assert.assertThat;
import static org.neo4j.driver.internal.logging.DevNullLogger.DEV_NULL_LOGGER;
import static org.neo4j.driver.internal.net.BoltServerAddress.LOCAL_DEFAULT;

/**
* This tests that the TLSSocketChannel handles every combination of network buffer sizes that we
Expand All @@ -58,7 +60,7 @@ protected void testForBufferSizes( byte[] blobOfData, int networkFrameSize, int
SocketAddress address = new InetSocketAddress( serverSocket.getInetAddress(), serverSocket.getLocalPort() );
ByteChannel ch = new LittleAtATimeChannel( SocketChannel.open( address ), networkFrameSize );

try ( TLSSocketChannel channel = TLSSocketChannel.create( ch, DEV_NULL_LOGGER, engine ) )
try ( TLSSocketChannel channel = TLSSocketChannel.create( ch, DEV_NULL_LOGGER, engine, LOCAL_DEFAULT ) )
{
ByteBuffer readBuffer = ByteBuffer.allocate( blobOfData.length );
while ( readBuffer.position() < readBuffer.capacity() )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,14 @@
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;

import org.neo4j.driver.internal.net.BoltServerAddress;
import org.neo4j.driver.internal.security.TLSSocketChannel;

import static org.hamcrest.core.IsEqual.equalTo;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.neo4j.driver.internal.logging.DevNullLogger.DEV_NULL_LOGGER;
import static org.neo4j.driver.internal.net.BoltServerAddress.LOCAL_DEFAULT;

/**
* This tests that the TLSSocketChannel handles every combination of network buffer sizes that we
Expand All @@ -60,7 +62,7 @@ protected void testForBufferSizes( byte[] blobOfData, int networkFrameSize, int
SocketAddress address = new InetSocketAddress( serverSocket.getInetAddress(), serverSocket.getLocalPort() );
ByteChannel ch = new LittleAtATimeChannel( SocketChannel.open( address ), networkFrameSize );

try ( TLSSocketChannel channel = TLSSocketChannel.create( ch, DEV_NULL_LOGGER, engine ) )
try ( TLSSocketChannel channel = TLSSocketChannel.create( ch, DEV_NULL_LOGGER, engine, LOCAL_DEFAULT ) )
{
ByteBuffer writeBuffer = ByteBuffer.wrap( blobOfData );
while ( writeBuffer.position() < writeBuffer.capacity() )
Expand Down

0 comments on commit 45c2930

Please sign in to comment.