@@ -21,9 +21,7 @@ import java.util.*
21
21
import javax.net.ssl.TrustManagerFactory
22
22
import javax.net.ssl.X509TrustManager
23
23
24
- class JvmTcpSocket (val socket : Socket , val loggerFactory : LoggerFactory ) : TcpSocket {
25
-
26
- private val logger = loggerFactory.newLogger(this ::class )
24
+ class JvmTcpSocket (val socket : Socket ) : TcpSocket {
27
25
28
26
private val connection = socket.connection()
29
27
@@ -73,39 +71,13 @@ class JvmTcpSocket(val socket: Socket, val loggerFactory: LoggerFactory) : TcpSo
73
71
.takeUnless { it == - 1 } ? : throw TcpSocket .IOException .ConnectionClosed ()
74
72
}
75
73
76
- override suspend fun startTls (tls : TcpSocket .TLS ): TcpSocket = try {
77
- when (tls) {
78
- is TcpSocket .TLS .TRUSTED_CERTIFICATES -> JvmTcpSocket (connection.tls(tlsContext(logger)), loggerFactory)
79
- TcpSocket .TLS .UNSAFE_CERTIFICATES -> JvmTcpSocket (connection.tls(tlsContext(logger)) {
80
- logger.warning { " using unsafe TLS!" }
81
- trustManager = unsafeX509TrustManager()
82
- }, loggerFactory)
83
- is TcpSocket .TLS .PINNED_PUBLIC_KEY -> {
84
- JvmTcpSocket (connection.tls(tlsContext(logger), tlsConfigForPinnedCert(tls.pubKey, logger)), loggerFactory)
85
- }
86
- TcpSocket .TLS .DISABLED -> this
87
- }
88
- } catch (e: Exception ) {
89
- throw when (e) {
90
- is ConnectException -> TcpSocket .IOException .ConnectionRefused (e)
91
- is SocketException -> TcpSocket .IOException .Unknown (e.message, e)
92
- else -> e
93
- }
94
- }
95
-
96
74
override fun close () {
97
75
// NB: this safely calls close(), wrapping it into a try/catch.
98
76
socket.dispose()
99
77
}
100
78
101
79
companion object {
102
80
103
- fun unsafeX509TrustManager () = object : X509TrustManager {
104
- override fun checkClientTrusted (p0 : Array <out X509Certificate >? , p1 : String? ) {}
105
- override fun checkServerTrusted (p0 : Array <out X509Certificate >? , p1 : String? ) {}
106
- override fun getAcceptedIssuers (): Array <X509Certificate >? = null
107
- }
108
-
109
81
fun buildPublicKey (encodedKey : ByteArray , logger : Logger ): java.security.PublicKey {
110
82
val spec = X509EncodedKeySpec (encodedKey)
111
83
val algorithms = listOf (" RSA" , " EC" , " DiffieHellman" , " DSA" , " RSASSA-PSS" , " XDH" , " X25519" , " X448" )
@@ -119,7 +91,7 @@ class JvmTcpSocket(val socket: Socket, val loggerFactory: LoggerFactory) : TcpSo
119
91
throw IllegalArgumentException (" unsupported key's algorithm, only $algorithms " )
120
92
}
121
93
122
- fun tlsConfigForPinnedCert (pinnedPubkey : String , logger : Logger ): TLSConfig = TLSConfigBuilder ().apply {
94
+ private fun tlsConfigForPinnedCert (pinnedPubkey : String , logger : Logger ): TLSConfig = TLSConfigBuilder ().apply {
123
95
// build a default X509 trust manager.
124
96
val factory = TrustManagerFactory .getInstance(TrustManagerFactory .getDefaultAlgorithm())!!
125
97
factory.init (null as KeyStore ? )
@@ -157,30 +129,49 @@ class JvmTcpSocket(val socket: Socket, val loggerFactory: LoggerFactory) : TcpSo
157
129
override fun getAcceptedIssuers (): Array <X509Certificate > = defaultX509TrustManager.acceptedIssuers
158
130
}
159
131
}.build()
132
+
133
+ private fun tlsConfigForUnsafeCertificates (): TLSConfig = TLSConfigBuilder ().apply {
134
+ trustManager = object : X509TrustManager {
135
+ override fun checkClientTrusted (p0 : Array <out X509Certificate >? , p1 : String? ) {}
136
+ override fun checkServerTrusted (p0 : Array <out X509Certificate >? , p1 : String? ) {}
137
+ override fun getAcceptedIssuers (): Array <X509Certificate >? = null
138
+ }
139
+ }.build()
140
+
141
+ fun buildTlsConfigFor (host : String , tls : TcpSocket .TLS .ENABLED , logger : Logger ) = when (tls) {
142
+ is TcpSocket .TLS .TRUSTED_CERTIFICATES -> TLSConfigBuilder ().build()
143
+ is TcpSocket .TLS .PINNED_PUBLIC_KEY -> {
144
+ logger.warning { " using unsafe TLS for connection with $host " }
145
+ tlsConfigForPinnedCert(tls.pubKey, logger)
146
+ }
147
+ is TcpSocket .TLS .UNSAFE_CERTIFICATES -> {
148
+ logger.info { " using certificate pinning for connections with $host " }
149
+ tlsConfigForUnsafeCertificates()
150
+ }
151
+ }
152
+
160
153
}
161
154
}
162
155
163
156
internal actual object PlatformSocketBuilder : TcpSocket.Builder {
157
+
158
+ private val Selector = SelectorManager (Dispatchers .IO )
159
+
164
160
actual override suspend fun connect (host : String , port : Int , tls : TcpSocket .TLS , loggerFactory : LoggerFactory ): TcpSocket {
165
161
val logger = loggerFactory.newLogger(this ::class )
166
162
return withContext(Dispatchers .IO ) {
163
+ var socket: Socket ? = null
167
164
try {
168
- val socket = aSocket(SelectorManager (Dispatchers .IO )).tcp().connect(host, port).let { socket ->
169
- when (tls) {
170
- is TcpSocket .TLS .TRUSTED_CERTIFICATES -> socket.tls(tlsContext(logger))
171
- TcpSocket .TLS .UNSAFE_CERTIFICATES -> socket.tls(tlsContext(logger)) {
172
- logger.warning { " using unsafe TLS!" }
173
- trustManager = JvmTcpSocket .unsafeX509TrustManager()
165
+ socket = aSocket(Selector ).tcp().connect(host, port)
166
+ .let {
167
+ when (tls) {
168
+ is TcpSocket .TLS .DISABLED -> it
169
+ is TcpSocket .TLS .ENABLED -> it.tls(tlsContext(logger), JvmTcpSocket .buildTlsConfigFor(host, tls, logger))
174
170
}
175
- is TcpSocket .TLS .PINNED_PUBLIC_KEY -> {
176
- logger.info { " using certificate pinning for connections with $host " }
177
- socket.tls(tlsContext(logger), JvmTcpSocket .tlsConfigForPinnedCert(tls.pubKey, logger))
178
- }
179
- else -> socket
180
171
}
181
- }
182
- JvmTcpSocket (socket, loggerFactory)
172
+ JvmTcpSocket (socket)
183
173
} catch (e: Exception ) {
174
+ socket?.dispose()
184
175
throw when (e) {
185
176
is ConnectException -> TcpSocket .IOException .ConnectionRefused (e)
186
177
is SocketException -> TcpSocket .IOException .Unknown (e.message, e)
0 commit comments