diff --git a/container-tests/src/test/java/okhttp3/containers/BasicMockServerTest.kt b/container-tests/src/test/java/okhttp3/containers/BasicMockServerTest.kt index 659b9714443b..14c0249d0904 100644 --- a/container-tests/src/test/java/okhttp3/containers/BasicMockServerTest.kt +++ b/container-tests/src/test/java/okhttp3/containers/BasicMockServerTest.kt @@ -17,6 +17,7 @@ package okhttp3.containers import assertk.assertThat import assertk.assertions.contains +import javax.net.ssl.SSLSocketFactory import javax.net.ssl.TrustManagerFactory import javax.net.ssl.X509TrustManager import okhttp3.HttpUrl.Companion.toHttpUrl @@ -84,15 +85,20 @@ class BasicMockServerTest { fun OkHttpClient.Builder.trustMockServer(): OkHttpClient.Builder = apply { - val keyStoreFactory = KeyStoreFactory(Configuration.configuration(), MockServerLogger()) - - val socketFactory = keyStoreFactory.sslContext().socketFactory - - val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) - trustManagerFactory.init(keyStoreFactory.loadOrCreateKeyStore()) - val trustManager = trustManagerFactory.trustManagers.first() as X509TrustManager + val (socketFactory, trustManager) = trustManagerPair() sslSocketFactory(socketFactory, trustManager) } + + fun trustManagerPair(): Pair { + val keyStoreFactory = KeyStoreFactory(Configuration.configuration(), MockServerLogger()) + + val socketFactory = keyStoreFactory.sslContext().socketFactory + + val trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()) + trustManagerFactory.init(keyStoreFactory.loadOrCreateKeyStore()) + val trustManager = trustManagerFactory.trustManagers.first() as X509TrustManager + return Pair(socketFactory, trustManager) + } } } diff --git a/container-tests/src/test/java/okhttp3/containers/BasicProxyTest.kt b/container-tests/src/test/java/okhttp3/containers/BasicProxyTest.kt index aab26a7ca184..e3521c2361f2 100644 --- a/container-tests/src/test/java/okhttp3/containers/BasicProxyTest.kt +++ b/container-tests/src/test/java/okhttp3/containers/BasicProxyTest.kt @@ -27,6 +27,7 @@ import okhttp3.OkHttpClient import okhttp3.Protocol import okhttp3.Request import okhttp3.containers.BasicMockServerTest.Companion.MOCKSERVER_IMAGE +import okhttp3.containers.BasicMockServerTest.Companion.trustManagerPair import okhttp3.containers.BasicMockServerTest.Companion.trustMockServer import okio.buffer import okio.source @@ -104,6 +105,13 @@ class BasicProxyTest { @Test fun testOkHttpSecureProxiedHttp1() { testRequest { + it.withProxyConfiguration( + ProxyConfiguration.proxyConfiguration( + ProxyConfiguration.Type.HTTPS, + it.remoteAddress(), + ), + ) + val client = OkHttpClient.Builder() .trustMockServer() @@ -121,6 +129,36 @@ class BasicProxyTest { } } + @Test + fun testOkHttpSecureProxiedHttp2() { + testRequest { + it.withProxyConfiguration( + ProxyConfiguration.proxyConfiguration( + ProxyConfiguration.Type.HTTPS, + it.remoteAddress(), + ), + ) + + val (socketFactory, trustManager) = trustManagerPair() + + val client = + OkHttpClient.Builder() + .sslSocketFactory(socketFactory, trustManager) + .proxy(Proxy(Proxy.Type.HTTP, it.remoteAddress())) + .protocols(listOf(Protocol.HTTP_2, Protocol.HTTP_1_1)) + .socketFactory(socketFactory) + .build() + + val response = + client.newCall( + Request((mockServer.secureEndpoint + "/person?name=peter").toHttpUrl()), + ).execute() + + assertThat(response.body.string()).contains("Peter the person") + assertThat(response.protocol).isEqualTo(Protocol.HTTP_2) + } + } + @Test fun testUrlConnectionDirect() { testRequest { @@ -169,6 +207,13 @@ class BasicProxyTest { HttpsURLConnection.setDefaultSSLSocketFactory(keyStoreFactory.sslContext().socketFactory) testRequest { + it.withProxyConfiguration( + ProxyConfiguration.proxyConfiguration( + ProxyConfiguration.Type.HTTPS, + it.remoteAddress(), + ), + ) + val proxy = Proxy( Proxy.Type.HTTP, diff --git a/okcurl/src/main/kotlin/okhttp3/curl/Main.kt b/okcurl/src/main/kotlin/okhttp3/curl/Main.kt index 126af4c31dbb..bb60d741e558 100644 --- a/okcurl/src/main/kotlin/okhttp3/curl/Main.kt +++ b/okcurl/src/main/kotlin/okhttp3/curl/Main.kt @@ -16,12 +16,15 @@ package okhttp3.curl import com.github.ajalt.clikt.core.CliktCommand +import com.github.ajalt.clikt.core.UsageError import com.github.ajalt.clikt.parameters.arguments.argument import com.github.ajalt.clikt.parameters.options.default import com.github.ajalt.clikt.parameters.options.flag import com.github.ajalt.clikt.parameters.options.multiple import com.github.ajalt.clikt.parameters.options.option import com.github.ajalt.clikt.parameters.types.int +import java.net.InetSocketAddress +import java.net.Proxy import java.security.cert.X509Certificate import java.util.Properties import java.util.concurrent.TimeUnit.SECONDS @@ -74,6 +77,8 @@ class Main : CliktCommand(name = NAME, help = "A curl for the next-generation we val sslDebug: Boolean by option(help = "Output SSL Debug").flag() + val proxy: String? by option(help = "Proxy config") + val url: String? by argument(name = "url", help = "Remote resource URL") var client: Call.Factory? = null @@ -98,9 +103,10 @@ class Main : CliktCommand(name = NAME, help = "A curl for the next-generation we if (callTimeout != DEFAULT_TIMEOUT) { builder.callTimeout(callTimeout.toLong(), SECONDS) } + var sslSocketFactory: SSLSocketFactory? = null if (allowInsecure) { val trustManager = createInsecureTrustManager() - val sslSocketFactory = createInsecureSslSocketFactory(trustManager) + sslSocketFactory = createInsecureSslSocketFactory(trustManager) builder.sslSocketFactory(sslSocketFactory, trustManager) builder.hostnameVerifier(createInsecureHostnameVerifier()) } @@ -108,6 +114,26 @@ class Main : CliktCommand(name = NAME, help = "A curl for the next-generation we val logger = HttpLoggingInterceptor.Logger(::println) builder.eventListenerFactory(LoggingEventListener.Factory(logger)) } + proxy?.let { + val (type, host, port) = it.split(':', limit = 3) + val address = InetSocketAddress.createUnresolved(host, port.toInt()) + when (type) { + "http" -> { + builder.proxy(Proxy(Proxy.Type.HTTP, address)) + } + + "https" -> { + builder.proxy(Proxy(Proxy.Type.HTTP, address)) + .socketFactory(sslSocketFactory ?: Platform.get().newSslSocketFactory(Platform.get().platformTrustManager())) + } + + "socks4" -> { + builder.proxy(Proxy(Proxy.Type.SOCKS, address)) + } + + else -> throw UsageError("Unknown proxy '$it'") + } + } return builder.build() } @@ -129,6 +155,7 @@ class Main : CliktCommand(name = NAME, help = "A curl for the next-generation we return prop.getProperty("version", "dev") } + @Suppress("TrustAllX509TrustManager", "CustomX509TrustManager") private fun createInsecureTrustManager(): X509TrustManager = object : X509TrustManager { override fun checkClientTrusted( diff --git a/okhttp/src/main/kotlin/okhttp3/OkHttpClient.kt b/okhttp/src/main/kotlin/okhttp3/OkHttpClient.kt index 7f499c765c73..357366a475c7 100644 --- a/okhttp/src/main/kotlin/okhttp3/OkHttpClient.kt +++ b/okhttp/src/main/kotlin/okhttp3/OkHttpClient.kt @@ -356,6 +356,10 @@ open class OkHttpClient internal constructor( checkNotNull(certificateChainCleaner) { "certificateChainCleaner == null" } checkNotNull(x509TrustManager) { "x509TrustManager == null" } } + + if ((proxy?.type() ?: Proxy.Type.DIRECT) == Proxy.Type.DIRECT && socketFactory is SSLSocketFactory) { + Platform.get().log("socketFactory is SSLSocketFactory without Proxy", Platform.WARN) + } } /** Prepares the [request] to be executed at some point in the future. */ @@ -890,8 +894,6 @@ open class OkHttpClient internal constructor( */ fun socketFactory(socketFactory: SocketFactory) = apply { - require(socketFactory !is SSLSocketFactory) { "socketFactory instanceof SSLSocketFactory" } - if (socketFactory != this.socketFactory) { this.routeDatabase = null } diff --git a/okhttp/src/test/java/okhttp3/OkHttpClientTest.kt b/okhttp/src/test/java/okhttp3/OkHttpClientTest.kt index c29f58b69197..d77b25d4e896 100644 --- a/okhttp/src/test/java/okhttp3/OkHttpClientTest.kt +++ b/okhttp/src/test/java/okhttp3/OkHttpClientTest.kt @@ -229,11 +229,11 @@ class OkHttpClientTest { assertThat(response.body.string()).isEqualTo("abc") } - @Test fun sslSocketFactorySetAsSocketFactory() { - val builder = OkHttpClient.Builder() - assertFailsWith { - builder.socketFactory(SSLSocketFactory.getDefault()) - } + @Test + fun sslSocketFactorySetAsSocketFactory() { + OkHttpClient.Builder() + .socketFactory(SSLSocketFactory.getDefault()) + .build() } @Test fun noSslSocketFactoryConfigured() {