Skip to content

Commit d4017a6

Browse files
authored
Fix memory leak in TCP connection (#765)
We instantiated a new `SelectorManager` at each connection attempt, leaking resources even if each socket was properly cleaned up on disconnection. The leak affects both jvm and native. Also removes the `startTls` function that is not used anymore since we moved Tor outside phoenix (ACINQ/phoenix#662).
1 parent d067b1f commit d4017a6

File tree

5 files changed

+50
-81
lines changed

5 files changed

+50
-81
lines changed

modules/core/src/commonMain/kotlin/fr/acinq/lightning/io/TcpSocket.kt

+5-5
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,25 @@ interface TcpSocket {
1616
suspend fun receiveFully(buffer: ByteArray, offset: Int, length: Int)
1717
suspend fun receiveAvailable(buffer: ByteArray, offset: Int, length: Int): Int
1818

19-
suspend fun startTls(tls: TLS): TcpSocket
20-
2119
fun close()
2220

2321
sealed class TLS {
2422
/** Used for Lightning connections */
2523
data object DISABLED : TLS()
2624

25+
sealed class ENABLED : TLS()
26+
2727
/** Used for Electrum servers when expecting a valid certificate */
2828
data class TRUSTED_CERTIFICATES(
2929
/**
3030
* Specify an expectedHostName when it's different than the `host` value you used
3131
* within TcpSocket.Builder.connect(). This may be the case when using Tor.
3232
*/
3333
val expectedHostName: String? = null
34-
) : TLS()
34+
) : ENABLED()
3535

3636
/** Only used in unit tests */
37-
data object UNSAFE_CERTIFICATES : TLS()
37+
data object UNSAFE_CERTIFICATES : ENABLED()
3838

3939
/**
4040
* Used for Electrum servers when expecting a specific public key
@@ -46,7 +46,7 @@ interface TcpSocket {
4646
* (I.e. same as PEM format, without BEGIN/END header/footer)
4747
*/
4848
val pubKey: String
49-
) : TLS() {
49+
) : ENABLED() {
5050
override fun toString(): String {
5151
return "PINNED_PUBLIC_KEY(pubKey=${pubKey.take(64)}...}"
5252
}

modules/core/src/commonTest/kotlin/fr/acinq/lightning/blockchain/electrum/ElectrumClientTest.kt

-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,6 @@ class ElectrumClientTest : LightningTestSuite() {
246246

247247
override suspend fun receiveFully(buffer: ByteArray, offset: Int, length: Int) = TODO("Not yet implemented")
248248
override suspend fun send(bytes: ByteArray?, offset: Int, length: Int, flush: Boolean) = TODO("Not yet implemented")
249-
override suspend fun startTls(tls: TcpSocket.TLS): TcpSocket = TODO("Not yet implemented")
250249
override fun close() = TODO("Not yet implemented")
251250
}
252251

modules/core/src/iosMain/kotlin/fr/acinq/lightning/io/IosTcpSocket.kt

-23
Original file line numberDiff line numberDiff line change
@@ -94,29 +94,6 @@ class IosTcpSocket @OptIn(ExperimentalForeignApi::class) constructor(private val
9494
)
9595
}
9696

97-
@OptIn(ExperimentalForeignApi::class)
98-
override suspend fun startTls(
99-
tls: TcpSocket.TLS
100-
): TcpSocket = suspendCancellableCoroutine { continuation ->
101-
102-
// @kotlinx.cinterop.ObjCMethod
103-
// public open external fun startTLSWithTls(
104-
// tls: swift.phoenix_crypto.NativeSocketTLS,
105-
// success: (swift.phoenix_crypto.NativeSocket?) -> kotlin.Unit,
106-
// failure: (swift.phoenix_crypto.NativeSocketError?) -> kotlin.Unit
107-
// ): kotlin.Unit { /* compiled code */ }
108-
109-
socket.startTLSWithTls(
110-
tls = tls.toNativeSocketTLS(),
111-
success = { newSocket ->
112-
continuation.resume(IosTcpSocket(newSocket!!))
113-
},
114-
failure = { error ->
115-
continuation.resumeWithException(error!!.toIOException())
116-
}
117-
)
118-
}
119-
12097
@OptIn(ExperimentalForeignApi::class)
12198
override fun close() {
12299

modules/core/src/jvmMain/kotlin/fr/acinq/lightning/io/JvmTcpSocket.kt

+34-43
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@ import java.util.*
2121
import javax.net.ssl.TrustManagerFactory
2222
import javax.net.ssl.X509TrustManager
2323

24-
class JvmTcpSocket(val socket: Socket, val loggerFactory: LoggerFactory) : TcpSocket {
25-
26-
private val logger = loggerFactory.newLogger(this::class)
24+
class JvmTcpSocket(val socket: Socket) : TcpSocket {
2725

2826
private val connection = socket.connection()
2927

@@ -73,39 +71,13 @@ class JvmTcpSocket(val socket: Socket, val loggerFactory: LoggerFactory) : TcpSo
7371
.takeUnless { it == -1 } ?: throw TcpSocket.IOException.ConnectionClosed()
7472
}
7573

76-
override suspend fun startTls(tls: TcpSocket.TLS): TcpSocket = try {
77-
when (tls) {
78-
is TcpSocket.TLS.TRUSTED_CERTIFICATES -> JvmTcpSocket(connection.tls(tlsContext(logger)), loggerFactory)
79-
TcpSocket.TLS.UNSAFE_CERTIFICATES -> JvmTcpSocket(connection.tls(tlsContext(logger)) {
80-
logger.warning { "using unsafe TLS!" }
81-
trustManager = unsafeX509TrustManager()
82-
}, loggerFactory)
83-
is TcpSocket.TLS.PINNED_PUBLIC_KEY -> {
84-
JvmTcpSocket(connection.tls(tlsContext(logger), tlsConfigForPinnedCert(tls.pubKey, logger)), loggerFactory)
85-
}
86-
TcpSocket.TLS.DISABLED -> this
87-
}
88-
} catch (e: Exception) {
89-
throw when (e) {
90-
is ConnectException -> TcpSocket.IOException.ConnectionRefused(e)
91-
is SocketException -> TcpSocket.IOException.Unknown(e.message, e)
92-
else -> e
93-
}
94-
}
95-
9674
override fun close() {
9775
// NB: this safely calls close(), wrapping it into a try/catch.
9876
socket.dispose()
9977
}
10078

10179
companion object {
10280

103-
fun unsafeX509TrustManager() = object : X509TrustManager {
104-
override fun checkClientTrusted(p0: Array<out X509Certificate>?, p1: String?) {}
105-
override fun checkServerTrusted(p0: Array<out X509Certificate>?, p1: String?) {}
106-
override fun getAcceptedIssuers(): Array<X509Certificate>? = null
107-
}
108-
10981
fun buildPublicKey(encodedKey: ByteArray, logger: Logger): java.security.PublicKey {
11082
val spec = X509EncodedKeySpec(encodedKey)
11183
val algorithms = listOf("RSA", "EC", "DiffieHellman", "DSA", "RSASSA-PSS", "XDH", "X25519", "X448")
@@ -119,7 +91,7 @@ class JvmTcpSocket(val socket: Socket, val loggerFactory: LoggerFactory) : TcpSo
11991
throw IllegalArgumentException("unsupported key's algorithm, only $algorithms")
12092
}
12193

122-
fun tlsConfigForPinnedCert(pinnedPubkey: String, logger: Logger): TLSConfig = TLSConfigBuilder().apply {
94+
private fun tlsConfigForPinnedCert(pinnedPubkey: String, logger: Logger): TLSConfig = TLSConfigBuilder().apply {
12395
// build a default X509 trust manager.
12496
val factory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())!!
12597
factory.init(null as KeyStore?)
@@ -157,30 +129,49 @@ class JvmTcpSocket(val socket: Socket, val loggerFactory: LoggerFactory) : TcpSo
157129
override fun getAcceptedIssuers(): Array<X509Certificate> = defaultX509TrustManager.acceptedIssuers
158130
}
159131
}.build()
132+
133+
private fun tlsConfigForUnsafeCertificates(): TLSConfig = TLSConfigBuilder().apply {
134+
trustManager = object : X509TrustManager {
135+
override fun checkClientTrusted(p0: Array<out X509Certificate>?, p1: String?) {}
136+
override fun checkServerTrusted(p0: Array<out X509Certificate>?, p1: String?) {}
137+
override fun getAcceptedIssuers(): Array<X509Certificate>? = null
138+
}
139+
}.build()
140+
141+
fun buildTlsConfigFor(host: String, tls: TcpSocket.TLS.ENABLED, logger: Logger) = when (tls) {
142+
is TcpSocket.TLS.TRUSTED_CERTIFICATES -> TLSConfigBuilder().build()
143+
is TcpSocket.TLS.PINNED_PUBLIC_KEY -> {
144+
logger.warning { "using unsafe TLS for connection with $host" }
145+
tlsConfigForPinnedCert(tls.pubKey, logger)
146+
}
147+
is TcpSocket.TLS.UNSAFE_CERTIFICATES -> {
148+
logger.info { "using certificate pinning for connections with $host" }
149+
tlsConfigForUnsafeCertificates()
150+
}
151+
}
152+
160153
}
161154
}
162155

163156
internal actual object PlatformSocketBuilder : TcpSocket.Builder {
157+
158+
private val Selector = SelectorManager(Dispatchers.IO)
159+
164160
actual override suspend fun connect(host: String, port: Int, tls: TcpSocket.TLS, loggerFactory: LoggerFactory): TcpSocket {
165161
val logger = loggerFactory.newLogger(this::class)
166162
return withContext(Dispatchers.IO) {
163+
var socket: Socket? = null
167164
try {
168-
val socket = aSocket(SelectorManager(Dispatchers.IO)).tcp().connect(host, port).let { socket ->
169-
when (tls) {
170-
is TcpSocket.TLS.TRUSTED_CERTIFICATES -> socket.tls(tlsContext(logger))
171-
TcpSocket.TLS.UNSAFE_CERTIFICATES -> socket.tls(tlsContext(logger)) {
172-
logger.warning { "using unsafe TLS!" }
173-
trustManager = JvmTcpSocket.unsafeX509TrustManager()
165+
socket = aSocket(Selector).tcp().connect(host, port)
166+
.let {
167+
when (tls) {
168+
is TcpSocket.TLS.DISABLED -> it
169+
is TcpSocket.TLS.ENABLED -> it.tls(tlsContext(logger), JvmTcpSocket.buildTlsConfigFor(host, tls, logger))
174170
}
175-
is TcpSocket.TLS.PINNED_PUBLIC_KEY -> {
176-
logger.info { "using certificate pinning for connections with $host" }
177-
socket.tls(tlsContext(logger), JvmTcpSocket.tlsConfigForPinnedCert(tls.pubKey, logger))
178-
}
179-
else -> socket
180171
}
181-
}
182-
JvmTcpSocket(socket, loggerFactory)
172+
JvmTcpSocket(socket)
183173
} catch (e: Exception) {
174+
socket?.dispose()
184175
throw when (e) {
185176
is ConnectException -> TcpSocket.IOException.ConnectionRefused(e)
186177
is SocketException -> TcpSocket.IOException.Unknown(e.message, e)

modules/core/src/nativeMain/kotlin/fr/acinq/lightning/io/KtorNoTlsTcpSocket.kt

+11-9
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,6 @@ class KtorNoTlsTcpSocket(private val socket: Socket) : TcpSocket {
6262
.takeUnless { it == -1 } ?: throw TcpSocket.IOException.ConnectionClosed()
6363
}
6464

65-
override suspend fun startTls(tls: TcpSocket.TLS): TcpSocket = TODO("TLS not supported")
66-
6765
override fun close() {
6866
// NB: this safely calls close(), wrapping it into a try/catch.
6967
socket.dispose()
@@ -72,22 +70,26 @@ class KtorNoTlsTcpSocket(private val socket: Socket) : TcpSocket {
7270
}
7371

7472
internal object KtorSocketBuilder : TcpSocket.Builder {
73+
74+
private val Selector = SelectorManager(Dispatchers.IO)
75+
7576
override suspend fun connect(host: String, port: Int, tls: TcpSocket.TLS, loggerFactory: LoggerFactory): TcpSocket {
7677
return withContext(Dispatchers.IO) {
78+
var socket: Socket? = null
7779
try {
78-
val socket = aSocket(SelectorManager(Dispatchers.IO)).tcp().connect(host, port,
80+
socket = aSocket(Selector).tcp().connect(
81+
host, port,
7982
configure = {
8083
keepAlive = true
8184
socketTimeout = 15.seconds.inWholeMilliseconds
8285
noDelay = true
83-
}).let { socket ->
84-
when (tls) {
85-
is TcpSocket.TLS.DISABLED -> socket
86-
else -> TODO("TLS not supported")
87-
}
86+
})
87+
when (tls) {
88+
is TcpSocket.TLS.DISABLED -> KtorNoTlsTcpSocket(socket)
89+
is TcpSocket.TLS.ENABLED -> TODO("TLS not supported")
8890
}
89-
KtorNoTlsTcpSocket(socket)
9091
} catch (e: Exception) {
92+
socket?.dispose()
9193
throw e
9294
}
9395
}

0 commit comments

Comments
 (0)