Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix memory leak in TCP connection #765

Merged
merged 6 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,25 @@ interface TcpSocket {
suspend fun receiveFully(buffer: ByteArray, offset: Int, length: Int)
suspend fun receiveAvailable(buffer: ByteArray, offset: Int, length: Int): Int

suspend fun startTls(tls: TLS): TcpSocket

fun close()

sealed class TLS {
/** Used for Lightning connections */
data object DISABLED : TLS()

sealed class ENABLED : TLS()

/** Used for Electrum servers when expecting a valid certificate */
data class TRUSTED_CERTIFICATES(
/**
* Specify an expectedHostName when it's different than the `host` value you used
* within TcpSocket.Builder.connect(). This may be the case when using Tor.
*/
val expectedHostName: String? = null
) : TLS()
) : ENABLED()

/** Only used in unit tests */
data object UNSAFE_CERTIFICATES : TLS()
data object UNSAFE_CERTIFICATES : ENABLED()

/**
* Used for Electrum servers when expecting a specific public key
Expand All @@ -46,7 +46,7 @@ interface TcpSocket {
* (I.e. same as PEM format, without BEGIN/END header/footer)
*/
val pubKey: String
) : TLS() {
) : ENABLED() {
override fun toString(): String {
return "PINNED_PUBLIC_KEY(pubKey=${pubKey.take(64)}...}"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,6 @@ class ElectrumClientTest : LightningTestSuite() {

override suspend fun receiveFully(buffer: ByteArray, offset: Int, length: Int) = TODO("Not yet implemented")
override suspend fun send(bytes: ByteArray?, offset: Int, length: Int, flush: Boolean) = TODO("Not yet implemented")
override suspend fun startTls(tls: TcpSocket.TLS): TcpSocket = TODO("Not yet implemented")
override fun close() = TODO("Not yet implemented")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,29 +94,6 @@ class IosTcpSocket @OptIn(ExperimentalForeignApi::class) constructor(private val
)
}

@OptIn(ExperimentalForeignApi::class)
override suspend fun startTls(
tls: TcpSocket.TLS
): TcpSocket = suspendCancellableCoroutine { continuation ->

// @kotlinx.cinterop.ObjCMethod
// public open external fun startTLSWithTls(
// tls: swift.phoenix_crypto.NativeSocketTLS,
// success: (swift.phoenix_crypto.NativeSocket?) -> kotlin.Unit,
// failure: (swift.phoenix_crypto.NativeSocketError?) -> kotlin.Unit
// ): kotlin.Unit { /* compiled code */ }

socket.startTLSWithTls(
tls = tls.toNativeSocketTLS(),
success = { newSocket ->
continuation.resume(IosTcpSocket(newSocket!!))
},
failure = { error ->
continuation.resumeWithException(error!!.toIOException())
}
)
}

@OptIn(ExperimentalForeignApi::class)
override fun close() {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,7 @@ import java.util.*
import javax.net.ssl.TrustManagerFactory
import javax.net.ssl.X509TrustManager

class JvmTcpSocket(val socket: Socket, val loggerFactory: LoggerFactory) : TcpSocket {

private val logger = loggerFactory.newLogger(this::class)
class JvmTcpSocket(val socket: Socket) : TcpSocket {

private val connection = socket.connection()

Expand Down Expand Up @@ -73,39 +71,13 @@ class JvmTcpSocket(val socket: Socket, val loggerFactory: LoggerFactory) : TcpSo
.takeUnless { it == -1 } ?: throw TcpSocket.IOException.ConnectionClosed()
}

override suspend fun startTls(tls: TcpSocket.TLS): TcpSocket = try {
when (tls) {
is TcpSocket.TLS.TRUSTED_CERTIFICATES -> JvmTcpSocket(connection.tls(tlsContext(logger)), loggerFactory)
TcpSocket.TLS.UNSAFE_CERTIFICATES -> JvmTcpSocket(connection.tls(tlsContext(logger)) {
logger.warning { "using unsafe TLS!" }
trustManager = unsafeX509TrustManager()
}, loggerFactory)
is TcpSocket.TLS.PINNED_PUBLIC_KEY -> {
JvmTcpSocket(connection.tls(tlsContext(logger), tlsConfigForPinnedCert(tls.pubKey, logger)), loggerFactory)
}
TcpSocket.TLS.DISABLED -> this
}
} catch (e: Exception) {
throw when (e) {
is ConnectException -> TcpSocket.IOException.ConnectionRefused(e)
is SocketException -> TcpSocket.IOException.Unknown(e.message, e)
else -> e
}
}

override fun close() {
// NB: this safely calls close(), wrapping it into a try/catch.
socket.dispose()
}

companion object {

fun unsafeX509TrustManager() = object : X509TrustManager {
override fun checkClientTrusted(p0: Array<out X509Certificate>?, p1: String?) {}
override fun checkServerTrusted(p0: Array<out X509Certificate>?, p1: String?) {}
override fun getAcceptedIssuers(): Array<X509Certificate>? = null
}

fun buildPublicKey(encodedKey: ByteArray, logger: Logger): java.security.PublicKey {
val spec = X509EncodedKeySpec(encodedKey)
val algorithms = listOf("RSA", "EC", "DiffieHellman", "DSA", "RSASSA-PSS", "XDH", "X25519", "X448")
Expand All @@ -119,7 +91,7 @@ class JvmTcpSocket(val socket: Socket, val loggerFactory: LoggerFactory) : TcpSo
throw IllegalArgumentException("unsupported key's algorithm, only $algorithms")
}

fun tlsConfigForPinnedCert(pinnedPubkey: String, logger: Logger): TLSConfig = TLSConfigBuilder().apply {
private fun tlsConfigForPinnedCert(pinnedPubkey: String, logger: Logger): TLSConfig = TLSConfigBuilder().apply {
// build a default X509 trust manager.
val factory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm())!!
factory.init(null as KeyStore?)
Expand Down Expand Up @@ -157,30 +129,49 @@ class JvmTcpSocket(val socket: Socket, val loggerFactory: LoggerFactory) : TcpSo
override fun getAcceptedIssuers(): Array<X509Certificate> = defaultX509TrustManager.acceptedIssuers
}
}.build()

private fun tlsConfigForUnsafeCertificates(): TLSConfig = TLSConfigBuilder().apply {
trustManager = object : X509TrustManager {
override fun checkClientTrusted(p0: Array<out X509Certificate>?, p1: String?) {}
override fun checkServerTrusted(p0: Array<out X509Certificate>?, p1: String?) {}
override fun getAcceptedIssuers(): Array<X509Certificate>? = null
}
}.build()

fun buildTlsConfigFor(host: String, tls: TcpSocket.TLS.ENABLED, logger: Logger) = when (tls) {
is TcpSocket.TLS.TRUSTED_CERTIFICATES -> TLSConfigBuilder().build()
is TcpSocket.TLS.PINNED_PUBLIC_KEY -> {
logger.warning { "using unsafe TLS for connection with $host" }
tlsConfigForPinnedCert(tls.pubKey, logger)
}
is TcpSocket.TLS.UNSAFE_CERTIFICATES -> {
logger.info { "using certificate pinning for connections with $host" }
tlsConfigForUnsafeCertificates()
}
}

}
}

internal actual object PlatformSocketBuilder : TcpSocket.Builder {

private val Selector = SelectorManager(Dispatchers.IO)

actual override suspend fun connect(host: String, port: Int, tls: TcpSocket.TLS, loggerFactory: LoggerFactory): TcpSocket {
val logger = loggerFactory.newLogger(this::class)
return withContext(Dispatchers.IO) {
var socket: Socket? = null
try {
val socket = aSocket(SelectorManager(Dispatchers.IO)).tcp().connect(host, port).let { socket ->
when (tls) {
is TcpSocket.TLS.TRUSTED_CERTIFICATES -> socket.tls(tlsContext(logger))
TcpSocket.TLS.UNSAFE_CERTIFICATES -> socket.tls(tlsContext(logger)) {
logger.warning { "using unsafe TLS!" }
trustManager = JvmTcpSocket.unsafeX509TrustManager()
socket = aSocket(Selector).tcp().connect(host, port)
.let {
when (tls) {
is TcpSocket.TLS.DISABLED -> it
is TcpSocket.TLS.ENABLED -> it.tls(tlsContext(logger), JvmTcpSocket.buildTlsConfigFor(host, tls, logger))
}
is TcpSocket.TLS.PINNED_PUBLIC_KEY -> {
logger.info { "using certificate pinning for connections with $host" }
socket.tls(tlsContext(logger), JvmTcpSocket.tlsConfigForPinnedCert(tls.pubKey, logger))
}
else -> socket
}
}
JvmTcpSocket(socket, loggerFactory)
JvmTcpSocket(socket)
} catch (e: Exception) {
socket?.dispose()
throw when (e) {
is ConnectException -> TcpSocket.IOException.ConnectionRefused(e)
is SocketException -> TcpSocket.IOException.Unknown(e.message, e)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@ class KtorNoTlsTcpSocket(private val socket: Socket) : TcpSocket {
.takeUnless { it == -1 } ?: throw TcpSocket.IOException.ConnectionClosed()
}

override suspend fun startTls(tls: TcpSocket.TLS): TcpSocket = TODO("TLS not supported")

override fun close() {
// NB: this safely calls close(), wrapping it into a try/catch.
socket.dispose()
Expand All @@ -72,22 +70,26 @@ class KtorNoTlsTcpSocket(private val socket: Socket) : TcpSocket {
}

internal object KtorSocketBuilder : TcpSocket.Builder {

private val Selector = SelectorManager(Dispatchers.IO)

override suspend fun connect(host: String, port: Int, tls: TcpSocket.TLS, loggerFactory: LoggerFactory): TcpSocket {
return withContext(Dispatchers.IO) {
var socket: Socket? = null
try {
val socket = aSocket(SelectorManager(Dispatchers.IO)).tcp().connect(host, port,
socket = aSocket(Selector).tcp().connect(
host, port,
configure = {
keepAlive = true
socketTimeout = 15.seconds.inWholeMilliseconds
noDelay = true
}).let { socket ->
when (tls) {
is TcpSocket.TLS.DISABLED -> socket
else -> TODO("TLS not supported")
}
})
when (tls) {
is TcpSocket.TLS.DISABLED -> KtorNoTlsTcpSocket(socket)
is TcpSocket.TLS.ENABLED -> TODO("TLS not supported")
}
KtorNoTlsTcpSocket(socket)
} catch (e: Exception) {
socket?.dispose()
throw e
}
}
Expand Down