diff --git a/java/src/org/openqa/selenium/remote/http/ClientConfig.java b/java/src/org/openqa/selenium/remote/http/ClientConfig.java index fb4f69b191012..26552eeb24327 100644 --- a/java/src/org/openqa/selenium/remote/http/ClientConfig.java +++ b/java/src/org/openqa/selenium/remote/http/ClientConfig.java @@ -24,6 +24,7 @@ import java.net.URISyntaxException; import java.net.URL; import java.time.Duration; +import javax.net.ssl.SSLContext; import org.openqa.selenium.Credentials; import org.openqa.selenium.internal.Require; @@ -38,24 +39,28 @@ public class ClientConfig { private final Proxy proxy; private final Credentials credentials; + private final SSLContext sslContext; + protected ClientConfig( URI baseUri, Duration connectionTimeout, Duration readTimeout, Filter filters, Proxy proxy, - Credentials credentials) { + Credentials credentials, + SSLContext sslContext) { this.baseUri = baseUri; this.connectionTimeout = Require.nonNegative("Connection timeout", connectionTimeout); this.readTimeout = Require.nonNegative("Read timeout", readTimeout); this.filters = Require.nonNull("Filters", filters); this.proxy = proxy; this.credentials = credentials; + this.sslContext = sslContext; } public static ClientConfig defaultConfig() { return new ClientConfig( - null, Duration.ofSeconds(10), Duration.ofMinutes(3), DEFAULT_FILTER, null, null); + null, Duration.ofSeconds(10), Duration.ofMinutes(3), DEFAULT_FILTER, null, null, null); } public ClientConfig baseUri(URI baseUri) { @@ -65,7 +70,8 @@ public ClientConfig baseUri(URI baseUri) { readTimeout, filters, proxy, - credentials); + credentials, + sslContext); } public ClientConfig baseUrl(URL baseUrl) { @@ -95,7 +101,8 @@ public ClientConfig connectionTimeout(Duration timeout) { readTimeout, filters, proxy, - credentials); + credentials, + sslContext); } public Duration connectionTimeout() { @@ -109,7 +116,8 @@ public ClientConfig readTimeout(Duration timeout) { Require.nonNull("Read timeout", timeout), filters, proxy, - credentials); + credentials, + sslContext); } public Duration readTimeout() { @@ -124,12 +132,19 @@ public ClientConfig withFilter(Filter filter) { readTimeout, filter.andThen(DEFAULT_FILTER), proxy, - credentials); + credentials, + sslContext); } public ClientConfig withRetries() { return new ClientConfig( - baseUri, connectionTimeout, readTimeout, filters.andThen(RETRY_FILTER), proxy, credentials); + baseUri, + connectionTimeout, + readTimeout, + filters.andThen(RETRY_FILTER), + proxy, + credentials, + sslContext); } public Filter filter() { @@ -143,7 +158,8 @@ public ClientConfig proxy(Proxy proxy) { readTimeout, filters, Require.nonNull("Proxy", proxy), - credentials); + credentials, + sslContext); } public Proxy proxy() { @@ -157,13 +173,29 @@ public ClientConfig authenticateAs(Credentials credentials) { readTimeout, filters, proxy, - Require.nonNull("Credentials", credentials)); + Require.nonNull("Credentials", credentials), + sslContext); } public Credentials credentials() { return credentials; } + public ClientConfig sslContext(SSLContext sslContext) { + return new ClientConfig( + baseUri, + connectionTimeout, + readTimeout, + filters, + proxy, + credentials, + Require.nonNull("SSL Context", sslContext)); + } + + public SSLContext sslContext() { + return sslContext; + } + @Override public String toString() { return "ClientConfig{" @@ -179,6 +211,8 @@ public String toString() { + proxy + ", credentials=" + credentials + + ", sslcontext=" + + sslContext + '}'; } } diff --git a/java/src/org/openqa/selenium/remote/http/jdk/JdkHttpClient.java b/java/src/org/openqa/selenium/remote/http/jdk/JdkHttpClient.java index e8de858963311..019dd8ad06b70 100644 --- a/java/src/org/openqa/selenium/remote/http/jdk/JdkHttpClient.java +++ b/java/src/org/openqa/selenium/remote/http/jdk/JdkHttpClient.java @@ -47,6 +47,7 @@ import java.util.function.Supplier; import java.util.logging.Level; import java.util.logging.Logger; +import javax.net.ssl.SSLContext; import org.openqa.selenium.Credentials; import org.openqa.selenium.TimeoutException; import org.openqa.selenium.UsernameAndPassword; @@ -144,6 +145,11 @@ public void connectFailed(URI uri, SocketAddress sa, IOException ioe) { builder = builder.proxy(proxySelector); } + SSLContext sslContext = config.sslContext(); + if (sslContext != null) { + builder.sslContext(sslContext); + } + this.client = builder.build(); } diff --git a/java/test/org/openqa/selenium/grid/router/ProxyWebsocketTest.java b/java/test/org/openqa/selenium/grid/router/ProxyWebsocketTest.java index d723d008420ac..54ffd00c75873 100644 --- a/java/test/org/openqa/selenium/grid/router/ProxyWebsocketTest.java +++ b/java/test/org/openqa/selenium/grid/router/ProxyWebsocketTest.java @@ -23,7 +23,12 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; +import java.net.Socket; import java.net.URISyntaxException; +import java.security.KeyManagementException; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.security.cert.X509Certificate; import java.time.Instant; import java.util.Collections; import java.util.Optional; @@ -32,6 +37,10 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Supplier; import java.util.stream.Stream; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509ExtendedTrustManager; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.params.ParameterizedTest; @@ -49,6 +58,7 @@ import org.openqa.selenium.grid.sessionmap.local.LocalSessionMap; import org.openqa.selenium.netty.server.NettyServer; import org.openqa.selenium.remote.SessionId; +import org.openqa.selenium.remote.http.ClientConfig; import org.openqa.selenium.remote.http.HttpClient; import org.openqa.selenium.remote.http.HttpHandler; import org.openqa.selenium.remote.http.HttpRequest; @@ -181,7 +191,10 @@ public void onText(CharSequence data) { @ParameterizedTest @MethodSource("data") void shouldBeAbleToSendMessagesOverSecureWebSocket(Supplier values) - throws URISyntaxException, InterruptedException { + throws URISyntaxException, + InterruptedException, + NoSuchAlgorithmException, + KeyManagementException { setFields(values); Config secureConfig = new MapConfig(ImmutableMap.of("server", ImmutableMap.of("https-self-signed", true))); @@ -207,11 +220,46 @@ void shouldBeAbleToSendMessagesOverSecureWebSocket(Supplier values) new ImmutableCapabilities(), Instant.now())); + final TrustManager trustManager = + new X509ExtendedTrustManager() { + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) {} + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) {} + + @Override + public void checkClientTrusted( + X509Certificate[] chain, String authType, SSLEngine engine) {} + + @Override + public void checkServerTrusted( + X509Certificate[] chain, String authType, SSLEngine engine) {} + + @Override + public java.security.cert.X509Certificate[] getAcceptedIssuers() { + return new java.security.cert.X509Certificate[0]; + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) {} + + @Override + public void checkServerTrusted( + java.security.cert.X509Certificate[] chain, String authType) {} + }; + + SSLContext sslContext = SSLContext.getInstance("SSL"); + sslContext.init(null, new TrustManager[] {trustManager}, new SecureRandom()); + CountDownLatch latch = new CountDownLatch(1); AtomicReference text = new AtomicReference<>(); try (WebSocket socket = clientFactory - .createClient(secureProxyServer.getUrl()) + .createClient( + ClientConfig.defaultConfig() + .baseUrl(secureProxyServer.getUrl()) + .sslContext(sslContext)) .openSocket( new HttpRequest(GET, String.format("/session/%s/" + protocol, id)), new WebSocket.Listener() {