diff --git a/src/main/java/io/vertx/core/http/impl/HttpChannelConnector.java b/src/main/java/io/vertx/core/http/impl/HttpChannelConnector.java index 244f14cd89f..060749d2d8b 100644 --- a/src/main/java/io/vertx/core/http/impl/HttpChannelConnector.java +++ b/src/main/java/io/vertx/core/http/impl/HttpChannelConnector.java @@ -85,7 +85,7 @@ public SocketAddress server() { } private void connect(EventLoopContext context, Promise promise) { - netClient.connectInternal(proxyOptions, server, peerAddress, this.options.isForceSni() ? peerAddress.host() : null, ssl, useAlpn, false, promise, context, 0); + netClient.connect(proxyOptions, server, peerAddress, this.options.isForceSni() ? peerAddress.host() : null, ssl, useAlpn, false, promise, context, 0); } public Future wrap(EventLoopContext context, NetSocket so_) { diff --git a/src/main/java/io/vertx/core/net/impl/NetClientImpl.java b/src/main/java/io/vertx/core/net/impl/NetClientImpl.java index 9c1b07cf77c..1c18b0c9c67 100644 --- a/src/main/java/io/vertx/core/net/impl/NetClientImpl.java +++ b/src/main/java/io/vertx/core/net/impl/NetClientImpl.java @@ -31,8 +31,8 @@ import io.vertx.core.buffer.impl.PartialPooledByteBufAllocator; import io.vertx.core.impl.CloseFuture; import io.vertx.core.impl.ContextInternal; -import io.vertx.core.impl.future.PromiseInternal; import io.vertx.core.impl.VertxInternal; +import io.vertx.core.impl.future.PromiseInternal; import io.vertx.core.impl.logging.Logger; import io.vertx.core.impl.logging.LoggerFactory; import io.vertx.core.net.NetClient; @@ -91,8 +91,6 @@ public NetClientImpl(VertxInternal vertx, TCPMetrics metrics, NetClientOptions o this.idleTimeoutUnit = options.getIdleTimeoutUnit(); this.closeFuture = closeFuture; this.proxyFilter = options.getNonProxyHosts() != null ? ProxyFilter.nonProxyHosts(options.getNonProxyHosts()) : ProxyFilter.DEFAULT_PROXY_FILTER; - - sslHelper.validate(vertx); } protected void initChannel(ChannelPipeline pipeline) { @@ -222,19 +220,39 @@ private void connect(SocketAddress remoteAddress, String serverName, Promise connectHandler, + ContextInternal context, + int remainingAttempts) { + sslHelper.validate(vertx) + .onComplete(validateResult -> { + if (validateResult.succeeded()) { + connectInternal(proxyOptions, remoteAddress, peerAddress, serverName, ssl, useAlpn, true, connectHandler, context, remainingAttempts); + } else { + failed(context, null, validateResult.cause(), connectHandler); + } + }); } - public void connectInternal(ProxyOptions proxyOptions, - SocketAddress remoteAddress, - SocketAddress peerAddress, - String serverName, - boolean ssl, - boolean useAlpn, - boolean registerWriteHandlers, - Promise connectHandler, - ContextInternal context, - int remainingAttempts) { + private void connectInternal(ProxyOptions proxyOptions, + SocketAddress remoteAddress, + SocketAddress peerAddress, + String serverName, + boolean ssl, + boolean useAlpn, + boolean registerWriteHandlers, + Promise connectHandler, + ContextInternal context, + int remainingAttempts) { checkClosed(); EventLoop eventLoop = context.nettyEventLoop(); diff --git a/src/main/java/io/vertx/core/net/impl/SSLHelper.java b/src/main/java/io/vertx/core/net/impl/SSLHelper.java index bdb8f6cd433..d605251075b 100755 --- a/src/main/java/io/vertx/core/net/impl/SSLHelper.java +++ b/src/main/java/io/vertx/core/net/impl/SSLHelper.java @@ -14,10 +14,13 @@ import io.netty.buffer.ByteBufAllocator; import io.netty.handler.ssl.*; import io.netty.util.Mapping; +import io.vertx.core.Future; +import io.vertx.core.Promise; import io.vertx.core.VertxException; import io.vertx.core.buffer.Buffer; import io.vertx.core.http.ClientAuth; import io.vertx.core.http.HttpClientOptions; +import io.vertx.core.impl.ContextInternal; import io.vertx.core.impl.VertxInternal; import io.vertx.core.impl.logging.Logger; import io.vertx.core.impl.logging.LoggerFactory; @@ -99,6 +102,8 @@ public static SSLEngineOptions resolveEngineOptions(TCPSSLOptions options) { private static final Logger log = LoggerFactory.getLogger(SSLHelper.class); private boolean ssl; + private volatile boolean validated = false; + private volatile Throwable validationError = null; private boolean sni; private long sslHandshakeTimeout; private TimeUnit sslHandshakeTimeoutUnit; @@ -502,10 +507,37 @@ public SslContext getContext(VertxInternal vertx, String serverName, boolean use } // This is called to validate some of the SSL params as that only happens when the context is created - public synchronized void validate(VertxInternal vertx) { + public synchronized Future validate(VertxInternal vertx) { + if (validated) { + if (validationError != null) { + return Future.failedFuture(validationError); + } + return Future.succeededFuture(); + } + + validated = true; + if (ssl) { - getContext(vertx, null); + ContextInternal validateContext = vertx.getOrCreateContext(); + Promise promise = validateContext.promise(); + validateContext.executeBlockingInternal(future -> { + try { + getContext(vertx, null); + future.complete(); + } catch (Exception e) { + future.fail(e); + } + }) + .onSuccess(nothing -> promise.complete()) + .onFailure(error -> { + validationError = error; + promise.fail(error); + }); + + return promise.future(); } + + return Future.succeededFuture(); } public SSLEngine createEngine(SslContext sslContext) { diff --git a/src/main/java/io/vertx/core/net/impl/TCPServerBase.java b/src/main/java/io/vertx/core/net/impl/TCPServerBase.java index 44dac3704dc..d5b41272636 100644 --- a/src/main/java/io/vertx/core/net/impl/TCPServerBase.java +++ b/src/main/java/io/vertx/core/net/impl/TCPServerBase.java @@ -18,17 +18,15 @@ import io.netty.channel.ChannelOption; import io.netty.channel.EventLoop; import io.netty.util.concurrent.GenericFutureListener; -import io.vertx.core.AsyncResult; import io.vertx.core.Closeable; -import io.vertx.core.CompositeFuture; import io.vertx.core.Context; import io.vertx.core.Future; import io.vertx.core.Handler; import io.vertx.core.Promise; import io.vertx.core.buffer.impl.PartialPooledByteBufAllocator; import io.vertx.core.impl.ContextInternal; -import io.vertx.core.impl.future.PromiseInternal; import io.vertx.core.impl.VertxInternal; +import io.vertx.core.impl.future.PromiseInternal; import io.vertx.core.impl.logging.Logger; import io.vertx.core.impl.logging.LoggerFactory; import io.vertx.core.net.NetServerOptions; @@ -37,12 +35,10 @@ import io.vertx.core.spi.metrics.TCPMetrics; import java.net.InetSocketAddress; -import java.util.ArrayList; import java.util.HashSet; -import java.util.List; import java.util.Map; import java.util.Set; -import java.util.stream.Collectors; +import java.util.concurrent.CountDownLatch; /** * Base class for TCP servers @@ -61,6 +57,7 @@ public abstract class TCPServerBase implements Closeable, MetricsProvider { // Per server private EventLoop eventLoop; private Handler worker; + private volatile CountDownLatch initialization = new CountDownLatch(1); private volatile boolean listening; private ContextInternal listenContext; private TCPServerBase actualServer; @@ -120,102 +117,159 @@ private synchronized io.netty.util.concurrent.Future listen(SocketAddre this.eventLoop = context.nettyEventLoop(); SocketAddress bindAddress; + actualPort = localAddress.port(); + String hostOrPath = localAddress.isInetSocket() ? localAddress.host() : localAddress.path(); + boolean shared; + ServerID id; + + if (actualPort > 0 || localAddress.isDomainSocket()) { + id = new ServerID(actualPort, hostOrPath); + shared = true; + bindAddress = localAddress; + } else { + if (actualPort < 0) { + id = new ServerID(actualPort, hostOrPath + "/" + -actualPort); + shared = true; + bindAddress = SocketAddress.inetSocketAddress(0, localAddress.host()); + } else { + id = new ServerID(actualPort, hostOrPath); + shared = false; + bindAddress = localAddress; + } + } + Map sharedNetServers = vertx.sharedTCPServers((Class) getClass()); - synchronized (sharedNetServers) { - actualPort = localAddress.port(); - String hostOrPath = localAddress.isInetSocket() ? localAddress.host() : localAddress.path(); + + if (shared) { TCPServerBase main; - boolean shared; - ServerID id; - if (actualPort > 0 || localAddress.isDomainSocket()) { - id = new ServerID(actualPort, hostOrPath); + + synchronized (sharedNetServers) { main = sharedNetServers.get(id); - shared = true; - bindAddress = localAddress; - } else { - if (actualPort < 0) { - id = new ServerID(actualPort, hostOrPath + "/" + -actualPort); - main = sharedNetServers.get(id); - shared = true; - bindAddress = SocketAddress.inetSocketAddress(0, localAddress.host()); - } else { - id = new ServerID(actualPort, hostOrPath); - main = null; - shared = false; - bindAddress = localAddress; + + if (main != null) { + try { + main.initialization.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return vertx.getAcceptorEventLoopGroup().next().newFailedFuture(e); + } + + if (main.isListening()) { + // Server already exists with that host/port - we will use that + actualServer = main; + metrics = main.metrics; + sslHelper = main.sslHelper; + worker = childHandler(listenContext, localAddress, sslHelper); + actualServer.servers.add(this); + actualServer.channelBalancer.addWorker(eventLoop, worker); + listenContext.addCloseHook(this); + return actualServer.bindFuture; + } } + + sharedNetServers.put(id, this); } - if (main == null) { - try { - sslHelper = createSSLHelper(); - sslHelper.validate(vertx); - worker = childHandler(listenContext, localAddress, sslHelper); - servers = new HashSet<>(); - servers.add(this); - channelBalancer = new ServerChannelLoadBalancer(vertx.getAcceptorEventLoopGroup().next()); - channelBalancer.addWorker(eventLoop, worker); - - ServerBootstrap bootstrap = new ServerBootstrap(); - bootstrap.group(vertx.getAcceptorEventLoopGroup(), channelBalancer.workers()); - if (sslHelper.isSSL()) { - bootstrap.childOption(ChannelOption.ALLOCATOR, PartialPooledByteBufAllocator.INSTANCE); - } else { - bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT); - } + } - bootstrap.childHandler(channelBalancer); - applyConnectionOptions(localAddress.isDomainSocket(), bootstrap); - - bindFuture = AsyncResolveConnectHelper.doBind(vertx, bindAddress, bootstrap); - bindFuture.addListener((GenericFutureListener>) res -> { - if (res.isSuccess()) { - Channel ch = res.getNow(); - log.trace("Net server listening on " + hostOrPath + ":" + ch.localAddress()); - if (shared) { - ch.closeFuture().addListener((ChannelFutureListener) channelFuture -> { - synchronized (sharedNetServers) { - sharedNetServers.remove(id); - } - }); - } - // Update port to actual port when it is not a domain socket as wildcard port 0 might have been used - if (bindAddress.isInetSocket()) { - actualPort = ((InetSocketAddress)ch.localAddress()).getPort(); - } - listenContext.addCloseHook(this); - metrics = createMetrics(localAddress); - } else { - if (shared) { + try { + sslHelper = createSSLHelper(); + } catch (Throwable t) { + cancelInitialization(); + return vertx.getAcceptorEventLoopGroup().next().newFailedFuture(t); + } + + io.netty.util.concurrent.Promise promise = vertx.getAcceptorEventLoopGroup().next().newPromise(); + + sslHelper.validate(vertx) + .onComplete(validateResult -> { + if (validateResult.succeeded()) { + listenValidated(localAddress, bindAddress, shared, hostOrPath, id) + .addListener((io.netty.util.concurrent.Future listenResult) -> { + if (listenResult.isSuccess()) { + promise.setSuccess(listenResult.get()); + initialization.countDown(); + } else { synchronized (sharedNetServers) { sharedNetServers.remove(id); } + cancelInitialization(); + promise.setFailure(listenResult.cause()); } - listening = false; - } - }); - } catch (Throwable t) { - listening = false; - return vertx.getAcceptorEventLoopGroup().next().newFailedFuture(t); - } - if (shared) { - sharedNetServers.put(id, this); + }); + } else { + synchronized (sharedNetServers) { + sharedNetServers.remove(id); + } + cancelInitialization(); + promise.setFailure(validateResult.cause()); } - actualServer = this; + }); + + return promise; + } + + private synchronized io.netty.util.concurrent.Future listenValidated(SocketAddress localAddress, SocketAddress bindAddress, boolean shared, String hostOrPath, ServerID id) { + Map sharedNetServers = vertx.sharedTCPServers((Class) getClass()); + + try { + worker = childHandler(listenContext, localAddress, sslHelper); + servers = new HashSet<>(); + servers.add(this); + channelBalancer = new ServerChannelLoadBalancer(vertx.getAcceptorEventLoopGroup().next()); + channelBalancer.addWorker(eventLoop, worker); + + ServerBootstrap bootstrap = new ServerBootstrap(); + bootstrap.group(vertx.getAcceptorEventLoopGroup(), channelBalancer.workers()); + if (sslHelper.isSSL()) { + bootstrap.childOption(ChannelOption.ALLOCATOR, PartialPooledByteBufAllocator.INSTANCE); } else { - // Server already exists with that host/port - we will use that - actualServer = main; - metrics = main.metrics; - sslHelper = main.sslHelper; - worker = childHandler(listenContext, localAddress, sslHelper); - actualServer.servers.add(this); - actualServer.channelBalancer.addWorker(eventLoop, worker); - listenContext.addCloseHook(this); + bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT); } + + bootstrap.childHandler(channelBalancer); + applyConnectionOptions(localAddress.isDomainSocket(), bootstrap); + + bindFuture = AsyncResolveConnectHelper.doBind(vertx, bindAddress, bootstrap); + bindFuture.addListener((GenericFutureListener>) res -> { + if (res.isSuccess()) { + Channel ch = res.getNow(); + log.trace("Net server listening on " + hostOrPath + ":" + ch.localAddress()); + if (shared) { + ch.closeFuture().addListener((ChannelFutureListener) channelFuture -> { + synchronized (sharedNetServers) { + sharedNetServers.remove(id); + } + }); + } + // Update port to actual port when it is not a domain socket as wildcard port 0 might have been used + if (bindAddress.isInetSocket()) { + actualPort = ((InetSocketAddress)ch.localAddress()).getPort(); + } + listenContext.addCloseHook(this); + metrics = createMetrics(localAddress); + } else { + if (shared) { + synchronized (sharedNetServers) { + sharedNetServers.remove(id); + } + } + listening = false; + } + }); + } catch (Throwable t) { + return vertx.getAcceptorEventLoopGroup().next().newFailedFuture(t); } + actualServer = this; return actualServer.bindFuture; } + private void cancelInitialization() { + listening = false; + initialization.countDown(); + initialization = new CountDownLatch(1); + } + public boolean isListening() { return listening; } diff --git a/src/test/java/io/vertx/core/http/HttpTLSTest.java b/src/test/java/io/vertx/core/http/HttpTLSTest.java index c524477ae2f..387e778dd91 100755 --- a/src/test/java/io/vertx/core/http/HttpTLSTest.java +++ b/src/test/java/io/vertx/core/http/HttpTLSTest.java @@ -1515,13 +1515,19 @@ public void testCrlInvalidPath() throws Exception { clientOptions.setTrustOptions(Trust.SERVER_PEM_ROOT_CA.get()); clientOptions.setSsl(true); clientOptions.addCrlPath("/invalid.pem"); - try { - vertx.createHttpClient(clientOptions); - fail("Was expecting a failure"); - } catch (VertxException e) { - assertNotNull(e.getCause()); - assertEquals(NoSuchFileException.class, e.getCause().getCause().getClass()); - } + vertx.createHttpClient(clientOptions) + .request(HttpMethod.GET, "/index") + .onComplete(result -> { + if (result.failed()) { + Throwable e = result.cause(); + assertNotNull(e.getCause()); + assertEquals(NoSuchFileException.class, e.getCause().getCause().getClass()); + testComplete(); + } else { + fail("Was expecting a failure"); + } + }); + await(); } // Proxy tests