diff --git a/lightning-kmp-test-fixtures/src/commonMain/kotlin/fr/acinq/lightning/tests/TestConstants.kt b/lightning-kmp-test-fixtures/src/commonMain/kotlin/fr/acinq/lightning/tests/TestConstants.kt index 62324f21a..97c93ed73 100644 --- a/lightning-kmp-test-fixtures/src/commonMain/kotlin/fr/acinq/lightning/tests/TestConstants.kt +++ b/lightning-kmp-test-fixtures/src/commonMain/kotlin/fr/acinq/lightning/tests/TestConstants.kt @@ -54,6 +54,7 @@ object TestConstants { Feature.StaticRemoteKey to FeatureSupport.Mandatory, Feature.AnchorOutputs to FeatureSupport.Mandatory, Feature.ChannelType to FeatureSupport.Mandatory, + Feature.PaymentMetadata to FeatureSupport.Optional, Feature.TrampolinePayment to FeatureSupport.Optional, Feature.WakeUpNotificationProvider to FeatureSupport.Optional, Feature.PayToOpenProvider to FeatureSupport.Optional, @@ -129,6 +130,7 @@ object TestConstants { Feature.StaticRemoteKey to FeatureSupport.Mandatory, Feature.AnchorOutputs to FeatureSupport.Mandatory, Feature.ChannelType to FeatureSupport.Mandatory, + Feature.PaymentMetadata to FeatureSupport.Optional, Feature.TrampolinePayment to FeatureSupport.Optional, Feature.WakeUpNotificationClient to FeatureSupport.Optional, Feature.PayToOpenClient to FeatureSupport.Optional, diff --git a/lightning-kmp-test-fixtures/src/commonMain/kotlin/fr/acinq/lightning/tests/io/peer/builders.kt b/lightning-kmp-test-fixtures/src/commonMain/kotlin/fr/acinq/lightning/tests/io/peer/builders.kt index a548eb8e4..01ca286a3 100644 --- a/lightning-kmp-test-fixtures/src/commonMain/kotlin/fr/acinq/lightning/tests/io/peer/builders.kt +++ b/lightning-kmp-test-fixtures/src/commonMain/kotlin/fr/acinq/lightning/tests/io/peer/builders.kt @@ -65,9 +65,9 @@ public suspend fun newPeers( } // Initialize Bob with Alice's features - bob.send(BytesReceived(LightningMessage.encode(Init(features = nodeParams.first.features.toByteArray().toByteVector())))) + bob.send(BytesReceived(LightningMessage.encode(Init(features = nodeParams.first.features.initFeatures().toByteArray().toByteVector())))) // Initialize Alice with Bob's features - alice.send(BytesReceived(LightningMessage.encode(Init(features = nodeParams.second.features.toByteArray().toByteVector())))) + alice.send(BytesReceived(LightningMessage.encode(Init(features = nodeParams.second.features.initFeatures().toByteArray().toByteVector())))) // TODO update to depend on the initChannels size if (initChannels.isNotEmpty()) { @@ -124,7 +124,7 @@ public suspend fun CoroutineScope.newPeer( remotedNodeChannelState?.let { state -> // send Init from remote node - val theirInit = Init(features = state.staticParams.nodeParams.features.toByteArray().toByteVector()) + val theirInit = Init(features = state.staticParams.nodeParams.features.initFeatures().toByteArray().toByteVector()) val initMsg = LightningMessage.encode(theirInit) peer.send(BytesReceived(initMsg)) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/Features.kt b/src/commonMain/kotlin/fr/acinq/lightning/Features.kt index 213edc184..1a13ce546 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/Features.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/Features.kt @@ -6,6 +6,9 @@ import fr.acinq.lightning.utils.leftPaddedCopyOf import fr.acinq.lightning.utils.or import kotlinx.serialization.Serializable +/** Feature scope as defined in Bolt 9. */ +enum class FeatureScope { Init, Node, Invoice } + enum class FeatureSupport { Mandatory { override fun toString() = "mandatory" @@ -20,6 +23,7 @@ sealed class Feature { abstract val rfcName: String abstract val mandatory: Int + abstract val scopes: Set val optional: Int get() = mandatory + 1 fun supportBit(support: FeatureSupport): Int = when (support) { @@ -33,6 +37,7 @@ sealed class Feature { object OptionDataLossProtect : Feature() { override val rfcName get() = "option_data_loss_protect" override val mandatory get() = 0 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } @Serializable @@ -41,66 +46,84 @@ sealed class Feature { // reserved but not used as per lightningnetwork/lightning-rfc/pull/178 override val mandatory get() = 2 + override val scopes: Set get() = setOf(FeatureScope.Init) } @Serializable object ChannelRangeQueries : Feature() { override val rfcName get() = "gossip_queries" override val mandatory get() = 6 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } @Serializable object VariableLengthOnion : Feature() { override val rfcName get() = "var_onion_optin" override val mandatory get() = 8 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node, FeatureScope.Invoice) } @Serializable object ChannelRangeQueriesExtended : Feature() { override val rfcName get() = "gossip_queries_ex" override val mandatory get() = 10 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } @Serializable object StaticRemoteKey : Feature() { override val rfcName get() = "option_static_remotekey" override val mandatory get() = 12 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } @Serializable object PaymentSecret : Feature() { override val rfcName get() = "payment_secret" override val mandatory get() = 14 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node, FeatureScope.Invoice) } @Serializable object BasicMultiPartPayment : Feature() { override val rfcName get() = "basic_mpp" override val mandatory get() = 16 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node, FeatureScope.Invoice) } @Serializable object Wumbo : Feature() { override val rfcName get() = "option_support_large_channel" override val mandatory get() = 18 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } @Serializable object AnchorOutputs : Feature() { override val rfcName get() = "option_anchor_outputs" override val mandatory get() = 20 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } @Serializable object ShutdownAnySegwit : Feature() { override val rfcName get() = "option_shutdown_anysegwit" override val mandatory get() = 26 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } @Serializable object ChannelType : Feature() { override val rfcName get() = "option_channel_type" override val mandatory get() = 44 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) + } + + @Serializable + object PaymentMetadata : Feature() { + override val rfcName get() = "option_payment_metadata" + override val mandatory get() = 48 + override val scopes: Set get() = setOf(FeatureScope.Invoice) } // The following features have not been standardised, hence the high feature bits to avoid conflicts. @@ -109,6 +132,7 @@ sealed class Feature { object TrampolinePayment : Feature() { override val rfcName get() = "trampoline_payment" override val mandatory get() = 50 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node, FeatureScope.Invoice) } /** This feature bit should be activated when a node accepts having their channel reserve set to 0. */ @@ -116,6 +140,7 @@ sealed class Feature { object ZeroReserveChannels : Feature() { override val rfcName get() = "zero_reserve_channels" override val mandatory get() = 128 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } /** This feature bit should be activated when a node accepts unconfirmed channels (will set min_depth to 0 in accept_channel). */ @@ -123,6 +148,7 @@ sealed class Feature { object ZeroConfChannels : Feature() { override val rfcName get() = "zero_conf_channels" override val mandatory get() = 130 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } /** This feature bit should be activated when a mobile node supports waking up via push notifications. */ @@ -130,6 +156,7 @@ sealed class Feature { object WakeUpNotificationClient : Feature() { override val rfcName get() = "wake_up_notification_client" override val mandatory get() = 132 + override val scopes: Set get() = setOf(FeatureScope.Init) } /** This feature bit should be activated when a node supports waking up their peers via push notifications. */ @@ -137,6 +164,7 @@ sealed class Feature { object WakeUpNotificationProvider : Feature() { override val rfcName get() = "wake_up_notification_provider" override val mandatory get() = 134 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } /** This feature bit should be activated when a node accepts on-the-fly channel creation. */ @@ -144,6 +172,7 @@ sealed class Feature { object PayToOpenClient : Feature() { override val rfcName get() = "pay_to_open_client" override val mandatory get() = 136 + override val scopes: Set get() = setOf(FeatureScope.Init) } /** This feature bit should be activated when a node supports opening channels on-the-fly when liquidity is missing to receive a payment. */ @@ -151,6 +180,7 @@ sealed class Feature { object PayToOpenProvider : Feature() { override val rfcName get() = "pay_to_open_provider" override val mandatory get() = 138 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } /** This feature bit should be activated when a node accepts channel creation via trusted swaps-in. */ @@ -158,6 +188,7 @@ sealed class Feature { object TrustedSwapInClient : Feature() { override val rfcName get() = "trusted_swap_in_client" override val mandatory get() = 140 + override val scopes: Set get() = setOf(FeatureScope.Init) } /** This feature bit should be activated when a node supports opening channels in exchange for on-chain funds (swap-in). */ @@ -165,6 +196,7 @@ sealed class Feature { object TrustedSwapInProvider : Feature() { override val rfcName get() = "trusted_swap_in_provider" override val mandatory get() = 142 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } /** This feature bit should be activated when a node wants to send channel backups to their peers. */ @@ -172,6 +204,7 @@ sealed class Feature { object ChannelBackupClient : Feature() { override val rfcName get() = "channel_backup_client" override val mandatory get() = 144 + override val scopes: Set get() = setOf(FeatureScope.Init) } /** This feature bit should be activated when a node stores channel backups for their peers. */ @@ -179,6 +212,7 @@ sealed class Feature { object ChannelBackupProvider : Feature() { override val rfcName get() = "channel_backup_provider" override val mandatory get() = 146 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } } @@ -188,9 +222,16 @@ data class UnknownFeature(val bitIndex: Int) @Serializable data class Features(val activated: Map, val unknown: Set = emptySet()) { - fun hasFeature(feature: Feature, support: FeatureSupport? = null): Boolean = - if (support != null) activated[feature] == support - else activated.containsKey(feature) + fun hasFeature(feature: Feature, support: FeatureSupport? = null): Boolean = when (support) { + null -> activated.containsKey(feature) + else -> activated[feature] == support + } + + fun initFeatures(): Features = Features(activated.filter { it.key.scopes.contains(FeatureScope.Init) }, unknown) + + fun nodeAnnouncementFeatures(): Features = Features(activated.filter { it.key.scopes.contains(FeatureScope.Node) }, unknown) + + fun invoiceFeatures(): Features = Features(activated.filter { it.key.scopes.contains(FeatureScope.Invoice) }, unknown) /** NB: this method is not reflexive, see [[Features.areCompatible]] if you want symmetric validation. */ fun areSupported(remoteFeatures: Features): Boolean { @@ -236,6 +277,7 @@ data class Features(val activated: Map, val unknown: Se Feature.AnchorOutputs, Feature.ShutdownAnySegwit, Feature.ChannelType, + Feature.PaymentMetadata, Feature.TrampolinePayment, Feature.ZeroReserveChannels, Feature.ZeroConfChannels, diff --git a/src/commonMain/kotlin/fr/acinq/lightning/channel/Commitments.kt b/src/commonMain/kotlin/fr/acinq/lightning/channel/Commitments.kt index 703d25b81..e2605756e 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/channel/Commitments.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/channel/Commitments.kt @@ -11,7 +11,7 @@ import fr.acinq.lightning.blockchain.fee.FeerateTolerance import fr.acinq.lightning.crypto.Generators import fr.acinq.lightning.crypto.KeyManager import fr.acinq.lightning.crypto.ShaChain -import fr.acinq.lightning.payment.OutgoingPacket +import fr.acinq.lightning.payment.OutgoingPaymentPacket import fr.acinq.lightning.transactions.CommitmentSpec import fr.acinq.lightning.transactions.Transactions import fr.acinq.lightning.transactions.Transactions.TransactionWithInputInfo.CommitTx @@ -370,7 +370,7 @@ data class Commitments( // we have already sent a fail/fulfill for this htlc alreadyProposed(localChanges.proposed, htlc.id) -> Either.Left(UnknownHtlcId(channelId, cmd.id)) else -> { - when (val result = OutgoingPacket.buildHtlcFailure(nodeSecret, htlc.paymentHash, htlc.onionRoutingPacket, cmd.reason)) { + when (val result = OutgoingPaymentPacket.buildHtlcFailure(nodeSecret, htlc.paymentHash, htlc.onionRoutingPacket, cmd.reason)) { is Either.Right -> { val fail = UpdateFailHtlc(channelId, cmd.id, result.value) val commitments1 = addLocalProposal(fail) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt b/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt index 76f76817d..ff6df1c2a 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt @@ -115,7 +115,7 @@ class Peer( private val features = nodeParams.features - private val ourInit = Init(features.toByteArray().toByteVector()) + private val ourInit = Init(features.initFeatures().toByteArray().toByteVector()) private var theirInit: Init? = null public val currentTipFlow = MutableStateFlow?>(null) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandler.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandler.kt index 4e51d710d..1689a12c0 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandler.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandler.kt @@ -1,10 +1,15 @@ package fr.acinq.lightning.payment +import fr.acinq.bitcoin.ByteVector import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto import fr.acinq.bitcoin.PrivateKey -import fr.acinq.lightning.* +import fr.acinq.lightning.CltvExpiry +import fr.acinq.lightning.Lightning.randomBytes import fr.acinq.lightning.Lightning.randomBytes32 +import fr.acinq.lightning.MilliSatoshi +import fr.acinq.lightning.NodeParams +import fr.acinq.lightning.WalletParams import fr.acinq.lightning.channel.* import fr.acinq.lightning.db.IncomingPayment import fr.acinq.lightning.db.IncomingPaymentsDb @@ -18,16 +23,16 @@ sealed class PaymentPart { abstract val amount: MilliSatoshi abstract val totalAmount: MilliSatoshi abstract val paymentHash: ByteVector32 - abstract val finalPayload: FinalPayload + abstract val finalPayload: PaymentOnion.FinalPayload } -data class HtlcPart(val htlc: UpdateAddHtlc, override val finalPayload: FinalPayload) : PaymentPart() { +data class HtlcPart(val htlc: UpdateAddHtlc, override val finalPayload: PaymentOnion.FinalPayload) : PaymentPart() { override val amount: MilliSatoshi = htlc.amountMsat override val totalAmount: MilliSatoshi = finalPayload.totalAmount override val paymentHash: ByteVector32 = htlc.paymentHash } -data class PayToOpenPart(val payToOpenRequest: PayToOpenRequest, override val finalPayload: FinalPayload) : PaymentPart() { +data class PayToOpenPart(val payToOpenRequest: PayToOpenRequest, override val finalPayload: PaymentOnion.FinalPayload) : PaymentPart() { override val amount: MilliSatoshi = payToOpenRequest.amountMsat override val totalAmount: MilliSatoshi = finalPayload.totalAmount override val paymentHash: ByteVector32 = payToOpenRequest.paymentHash @@ -75,7 +80,6 @@ class IncomingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle timestampSeconds: Long = currentTimestampSeconds() ): PaymentRequest { val paymentHash = Crypto.sha256(paymentPreimage).toByteVector32() - val invoiceFeatures = PaymentRequest.invoiceFeatures(nodeParams.features) logger.debug { "h:$paymentHash using routing hints $extraHops" } val pr = PaymentRequest.create( nodeParams.chainHash, @@ -84,8 +88,10 @@ class IncomingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle nodeParams.nodePrivateKey, description, PaymentRequest.DEFAULT_MIN_FINAL_EXPIRY_DELTA, - invoiceFeatures, + nodeParams.features.invoiceFeatures(), randomBytes32(), + // We always include a payment metadata in our invoices, which lets us test whether senders support it + ByteVector("2a"), expirySeconds, extraHops, timestampSeconds @@ -246,7 +252,10 @@ class IncomingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle } else -> { // We have received all the payment parts. - logger.info { "h:${paymentPart.paymentHash} payment received (${payment.amountReceived})" } + when (val paymentMetadata = paymentPart.finalPayload.paymentMetadata) { + null -> logger.info { "h:${paymentPart.paymentHash} payment received (${payment.amountReceived}) without payment metadata" } + else -> logger.info { "h:${paymentPart.paymentHash} payment received (${payment.amountReceived}) with payment metadata ($paymentMetadata)" } + } val (actions, receivedWith) = payment.parts.map { part -> when (part) { is HtlcPart -> { @@ -352,7 +361,7 @@ class IncomingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle fun checkPaymentsTimeout(currentTimestampSeconds: Long): List { val actions = mutableListOf() - val keysToRemove = mutableListOf() + val keysToRemove = mutableSetOf() // BOLT 04: // - MUST fail all HTLCs in the HTLC set after some reasonable timeout. @@ -379,7 +388,7 @@ class IncomingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle /** Convert an incoming htlc to a payment part abstraction. Payment parts are then summed together to reach the full payment amount. */ private fun toPaymentPart(privateKey: PrivateKey, htlc: UpdateAddHtlc): Either { // NB: IncomingPacket.decrypt does additional validation on top of IncomingPacket.decryptOnion - return when (val decrypted = IncomingPacket.decrypt(htlc, privateKey)) { + return when (val decrypted = IncomingPaymentPacket.decrypt(htlc, privateKey)) { is Either.Left -> { // Unable to decrypt onion val failureMsg = decrypted.value val action = actionForFailureMessage(failureMsg, htlc) @@ -394,7 +403,7 @@ class IncomingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle * This is very similar to the processing of a htlc, except that we only have a packet, to decrypt into a final payload. */ private fun toPaymentPart(privateKey: PrivateKey, payToOpenRequest: PayToOpenRequest): Either { - return when (val decrypted = IncomingPacket.decryptOnion(payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, payToOpenRequest.finalPacket.payload.size(), privateKey)) { + return when (val decrypted = IncomingPaymentPacket.decryptOnion(payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, payToOpenRequest.finalPacket.payload.size(), privateKey)) { is Either.Left -> { val failureMsg = decrypted.value val action = actionForPayToOpenFailure(privateKey, failureMsg, payToOpenRequest) @@ -424,7 +433,7 @@ class IncomingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle private fun actionForPayToOpenFailure(privateKey: PrivateKey, failure: FailureMessage, payToOpenRequest: PayToOpenRequest): PayToOpenResponseEvent { val reason = CMD_FAIL_HTLC.Reason.Failure(failure) - val encryptedReason = when (val result = OutgoingPacket.buildHtlcFailure(privateKey, payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, reason)) { + val encryptedReason = when (val result = OutgoingPaymentPacket.buildHtlcFailure(privateKey, payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, reason)) { is Either.Right -> result.value is Either.Left -> null } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPacket.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPaymentPacket.kt similarity index 81% rename from src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPacket.kt rename to src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPaymentPacket.kt index 6232cfc4b..662e8490b 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPacket.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPaymentPacket.kt @@ -6,7 +6,7 @@ import fr.acinq.lightning.crypto.sphinx.Sphinx import fr.acinq.lightning.utils.Either import fr.acinq.lightning.wire.* -object IncomingPacket { +object IncomingPaymentPacket { /** * Decrypt the onion packet of a received htlc. We expect to be the final recipient, and we validate that the HTLC @@ -18,12 +18,12 @@ object IncomingPacket { * - a decrypted and valid onion final payload * - or a Bolt4 failure message that can be returned to the sender if the HTLC is invalid */ - fun decrypt(add: UpdateAddHtlc, privateKey: PrivateKey): Either { + fun decrypt(add: UpdateAddHtlc, privateKey: PrivateKey): Either { return when (val decrypted = decryptOnion(add.paymentHash, add.onionRoutingPacket, OnionRoutingPacket.PaymentPacketLength, privateKey)) { is Either.Left -> Either.Left(decrypted.value) is Either.Right -> { val outer = decrypted.value - when (val trampolineOnion = outer.records.get()) { + when (val trampolineOnion = outer.records.get()) { null -> validate(add, outer) else -> { when (val inner = decryptOnion(add.paymentHash, trampolineOnion.packet, OnionRoutingPacket.TrampolinePacketLength, privateKey)) { @@ -37,7 +37,7 @@ object IncomingPacket { } @OptIn(ExperimentalUnsignedTypes::class) - fun decryptOnion(paymentHash: ByteVector32, packet: OnionRoutingPacket, packetLength: Int, privateKey: PrivateKey): Either { + fun decryptOnion(paymentHash: ByteVector32, packet: OnionRoutingPacket, packetLength: Int, privateKey: PrivateKey): Either { return when (val decrypted = Sphinx.peel(privateKey, paymentHash, packet, packetLength)) { is Either.Left -> Either.Left(decrypted.value) is Either.Right -> run { @@ -45,7 +45,7 @@ object IncomingPacket { Either.Left(UnknownNextPeer) } else { try { - Either.Right(FinalPayload.read(decrypted.value.payload.toByteArray())) + Either.Right(PaymentOnion.FinalPayload.read(decrypted.value.payload.toByteArray())) } catch (_: Throwable) { Either.Left(InvalidOnionPayload(0U, 0)) } @@ -54,7 +54,7 @@ object IncomingPacket { } } - private fun validate(add: UpdateAddHtlc, payload: FinalPayload): Either { + private fun validate(add: UpdateAddHtlc, payload: PaymentOnion.FinalPayload): Either { return when { add.amountMsat != payload.amount -> Either.Left(FinalIncorrectHtlcAmount(add.amountMsat)) add.cltvExpiry != payload.expiry -> Either.Left(FinalIncorrectCltvExpiry(add.cltvExpiry)) @@ -63,7 +63,7 @@ object IncomingPacket { } @OptIn(ExperimentalUnsignedTypes::class) - private fun validate(add: UpdateAddHtlc, outerPayload: FinalPayload, innerPayload: FinalPayload): Either { + private fun validate(add: UpdateAddHtlc, outerPayload: PaymentOnion.FinalPayload, innerPayload: PaymentOnion.FinalPayload): Either { return when { add.amountMsat != outerPayload.amount -> Either.Left(FinalIncorrectHtlcAmount(add.amountMsat)) add.cltvExpiry != outerPayload.expiry -> Either.Left(FinalIncorrectCltvExpiry(add.cltvExpiry)) @@ -74,7 +74,7 @@ object IncomingPacket { else -> { // We merge contents from the outer and inner payloads. // We must use the inner payload's total amount and payment secret because the payment may be split between multiple trampoline payments (#reckless). - Either.Right(FinalPayload.createMultiPartPayload(outerPayload.amount, innerPayload.totalAmount, outerPayload.expiry, innerPayload.paymentSecret)) + Either.Right(PaymentOnion.FinalPayload.createMultiPartPayload(outerPayload.amount, innerPayload.totalAmount, outerPayload.expiry, innerPayload.paymentSecret, innerPayload.paymentMetadata)) } } } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt index f3a84af4f..f40265c3d 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt @@ -294,7 +294,7 @@ class OutgoingPaymentHandler(val nodeId: PublicKey, val walletParams: WalletPara OutgoingPayment.Part.Status.Pending ) val channelHops: List = listOf(ChannelHop(nodeId, route.channel.staticParams.remoteNodeId, route.channel.channelUpdate)) - val (add, secrets) = OutgoingPacket.buildCommand(childId, request.paymentHash, channelHops, trampolinePayload.createFinalPayload(route.amount)) + val (add, secrets) = OutgoingPaymentPacket.buildCommand(childId, request.paymentHash, channelHops, trampolinePayload.createFinalPayload(route.amount)) return Triple(outgoingPayment, secrets, WrappedChannelEvent(route.channel.channelId, ChannelEvent.ExecuteCommand(add))) } @@ -312,13 +312,13 @@ class OutgoingPaymentHandler(val nodeId: PublicKey, val walletParams: WalletPara val finalExpiryDelta = request.details.paymentRequest.minFinalExpiryDelta ?: Channel.MIN_CLTV_EXPIRY_DELTA val finalExpiry = finalExpiryDelta.toCltvExpiry(currentBlockHeight.toLong()) - val finalPayload = FinalPayload.createSinglePartPayload(request.amount, finalExpiry, request.details.paymentRequest.paymentSecret) + val finalPayload = PaymentOnion.FinalPayload.createSinglePartPayload(request.amount, finalExpiry, request.details.paymentRequest.paymentSecret, request.details.paymentRequest.paymentMetadata) val invoiceFeatures = Features(request.details.paymentRequest.features) val (trampolineAmount, trampolineExpiry, trampolineOnion) = if (invoiceFeatures.hasFeature(Feature.TrampolinePayment)) { - OutgoingPacket.buildPacket(request.paymentHash, trampolineRoute, finalPayload, OnionRoutingPacket.TrampolinePacketLength) + OutgoingPaymentPacket.buildPacket(request.paymentHash, trampolineRoute, finalPayload, OnionRoutingPacket.TrampolinePacketLength) } else { - OutgoingPacket.buildTrampolineToLegacyPacket(request.details.paymentRequest, trampolineRoute, finalPayload) + OutgoingPaymentPacket.buildTrampolineToLegacyPacket(request.details.paymentRequest, trampolineRoute, finalPayload) } return Triple(trampolineAmount, trampolineExpiry, trampolineOnion.packet) } @@ -337,7 +337,7 @@ class OutgoingPaymentHandler(val nodeId: PublicKey, val walletParams: WalletPara * @param packet trampoline onion packet. */ data class TrampolinePayload(val totalAmount: MilliSatoshi, val expiry: CltvExpiry, val paymentSecret: ByteVector32, val packet: OnionRoutingPacket) { - fun createFinalPayload(partialAmount: MilliSatoshi): FinalPayload = FinalPayload.createTrampolinePayload(partialAmount, totalAmount, expiry, paymentSecret, packet) + fun createFinalPayload(partialAmount: MilliSatoshi): PaymentOnion.FinalPayload = PaymentOnion.FinalPayload.createTrampolinePayload(partialAmount, totalAmount, expiry, paymentSecret, packet) } /** diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPacket.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt similarity index 75% rename from src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPacket.kt rename to src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt index bed2a8120..6a1e2f392 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPacket.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt @@ -9,7 +9,6 @@ import fr.acinq.lightning.Lightning import fr.acinq.lightning.MilliSatoshi import fr.acinq.lightning.channel.CMD_ADD_HTLC import fr.acinq.lightning.channel.CMD_FAIL_HTLC -import fr.acinq.lightning.channel.Channel import fr.acinq.lightning.crypto.sphinx.FailurePacket import fr.acinq.lightning.crypto.sphinx.PacketAndSecrets import fr.acinq.lightning.crypto.sphinx.SharedSecrets @@ -19,22 +18,24 @@ import fr.acinq.lightning.router.Hop import fr.acinq.lightning.router.NodeHop import fr.acinq.lightning.utils.Either import fr.acinq.lightning.utils.UUID -import fr.acinq.lightning.wire.* +import fr.acinq.lightning.wire.FailureMessage +import fr.acinq.lightning.wire.OnionRoutingPacket +import fr.acinq.lightning.wire.PaymentOnion -object OutgoingPacket { +object OutgoingPaymentPacket { /** * Build an encrypted onion packet from onion payloads and node public keys. */ - private fun buildOnion(nodes: List, payloads: List, associatedData: ByteVector32, payloadLength: Int): PacketAndSecrets { + private fun buildOnion(nodes: List, payloads: List, associatedData: ByteVector32, payloadLength: Int): PacketAndSecrets { require(nodes.size == payloads.size) val sessionKey = Lightning.randomKey() val payloadsBin = payloads .map { when (it) { - is ChannelRelayPayload -> it.write() - is NodeRelayPayload -> it.write() - is FinalPayload -> it.write() + is PaymentOnion.ChannelRelayPayload -> it.write() + is PaymentOnion.NodeRelayPayload -> it.write() + is PaymentOnion.FinalPayload -> it.write() } } return Sphinx.create(sessionKey, nodes, payloadsBin, associatedData, payloadLength) @@ -50,13 +51,13 @@ object OutgoingPacket { * - firstExpiry is the cltv expiry for the first htlc in the route * - a sequence of payloads that will be used to build the onion */ - private fun buildPayloads(hops: List, finalPayload: FinalPayload): Triple> { - return hops.reversed().fold(Triple(finalPayload.amount, finalPayload.expiry, listOf(finalPayload))) { triple, hop -> + private fun buildPayloads(hops: List, finalPayload: PaymentOnion.FinalPayload): Triple> { + return hops.reversed().fold(Triple(finalPayload.amount, finalPayload.expiry, listOf(finalPayload))) { triple, hop -> val (amount, expiry, payloads) = triple val payload = when (hop) { // Since we don't have any scenario where we add tlv data for intermediate hops, we use legacy payloads. - is ChannelHop -> ChannelRelayPayload.create(hop.lastUpdate.shortChannelId, amount, expiry) - is NodeHop -> NodeRelayPayload.create(amount, expiry, hop.nextNodeId) + is ChannelHop -> PaymentOnion.ChannelRelayPayload.create(hop.lastUpdate.shortChannelId, amount, expiry) + is NodeHop -> PaymentOnion.NodeRelayPayload.create(amount, expiry, hop.nextNodeId) } Triple(amount + hop.fee(amount), expiry + hop.cltvExpiryDelta, listOf(payload) + payloads) } @@ -74,13 +75,16 @@ object OutgoingPacket { * - firstExpiry is the cltv expiry for the first trampoline node in the route * - the trampoline onion to include in final payload of a normal onion */ - fun buildTrampolineToLegacyPacket(invoice: PaymentRequest, hops: List, finalPayload: FinalPayload): Triple { - val (firstAmount, firstExpiry, payloads) = hops.drop(1).reversed().fold(Triple(finalPayload.amount, finalPayload.expiry, listOf(finalPayload))) { triple, hop -> + fun buildTrampolineToLegacyPacket(invoice: PaymentRequest, hops: List, finalPayload: PaymentOnion.FinalPayload): Triple { + // NB: the final payload will never reach the recipient, since the next-to-last trampoline hop will convert that to a legacy payment + // We use the smallest final payload possible, otherwise we may overflow the trampoline onion size. + val dummyFinalPayload = PaymentOnion.FinalPayload.createSinglePartPayload(finalPayload.amount, finalPayload.expiry, finalPayload.paymentSecret, null) + val (firstAmount, firstExpiry, payloads) = hops.drop(1).reversed().fold(Triple(finalPayload.amount, finalPayload.expiry, listOf(dummyFinalPayload))) { triple, hop -> val (amount, expiry, payloads) = triple val payload = when (payloads.size) { // The next-to-last trampoline hop must include invoice data to indicate the conversion to a legacy payment. - 1 -> NodeRelayPayload.createNodeRelayToNonTrampolinePayload(finalPayload.amount, finalPayload.totalAmount, finalPayload.expiry, hop.nextNodeId, invoice) - else -> NodeRelayPayload.create(amount, expiry, hop.nextNodeId) + 1 -> PaymentOnion.NodeRelayPayload.createNodeRelayToNonTrampolinePayload(finalPayload.amount, finalPayload.totalAmount, finalPayload.expiry, hop.nextNodeId, invoice) + else -> PaymentOnion.NodeRelayPayload.create(amount, expiry, hop.nextNodeId) } Triple(amount + hop.fee(amount), expiry + hop.cltvExpiryDelta, listOf(payload) + payloads) } @@ -99,7 +103,7 @@ object OutgoingPacket { * - firstExpiry is the cltv expiry for the first htlc in the route * - the onion to include in the HTLC */ - fun buildPacket(paymentHash: ByteVector32, hops: List, finalPayload: FinalPayload, payloadLength: Int): Triple { + fun buildPacket(paymentHash: ByteVector32, hops: List, finalPayload: PaymentOnion.FinalPayload, payloadLength: Int): Triple { val (firstAmount, firstExpiry, payloads) = buildPayloads(hops.drop(1), finalPayload) val nodes = hops.map { it.nextNodeId } // BOLT 2 requires that associatedData == paymentHash @@ -112,7 +116,7 @@ object OutgoingPacket { * * @return the command and the onion shared secrets (used to decrypt the error in case of payment failure) */ - fun buildCommand(paymentId: UUID, paymentHash: ByteVector32, hops: List, finalPayload: FinalPayload): Pair { + fun buildCommand(paymentId: UUID, paymentHash: ByteVector32, hops: List, finalPayload: PaymentOnion.FinalPayload): Pair { val (firstAmount, firstExpiry, onion) = buildPacket(paymentHash, hops, finalPayload, OnionRoutingPacket.PaymentPacketLength) return Pair(CMD_ADD_HTLC(firstAmount, paymentHash, firstExpiry, onion.packet, paymentId, commit = true), onion.sharedSecrets) } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/PaymentRequest.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/PaymentRequest.kt index c32beadee..88062076f 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/PaymentRequest.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/PaymentRequest.kt @@ -28,6 +28,9 @@ data class PaymentRequest( @Transient val paymentSecret: ByteVector32 = tags.find { it is TaggedField.PaymentSecret }!!.run { (this as TaggedField.PaymentSecret).secret } + @Transient + val paymentMetadata: ByteVector? = tags.find { it is TaggedField.PaymentMetadata }?.run { (this as TaggedField.PaymentMetadata).data } + @Transient val description: String? = tags.find { it is TaggedField.Description }?.run { (this as TaggedField.Description).description } @@ -138,22 +141,6 @@ data class PaymentRequest( Block.LivenetGenesisBlock.hash to "lnbc" ) - // only some features are valid in invoices - // see 'Context' column in https://github.com/lightningnetwork/lightning-rfc/blob/master/09-features.md - private val bolt11Features = setOf( - Feature.VariableLengthOnion, - Feature.PaymentSecret, - Feature.BasicMultiPartPayment, - Feature.TrampolinePayment - ) - - /** - * This filters out all features unrelated to BOLT 11 - */ - fun invoiceFeatures(features: Features): Features { - return Features(activated = features.activated.filter { (f, _) -> bolt11Features.contains(f) }) - } - fun create( chainHash: ByteVector32, amount: MilliSatoshi?, @@ -163,6 +150,7 @@ data class PaymentRequest( minFinalCltvExpiryDelta: CltvExpiryDelta, features: Features, paymentSecret: ByteVector32 = randomBytes32(), + paymentMetadata: ByteVector? = null, expirySeconds: Long? = null, extraHops: List> = listOf(), timestampSeconds: Long = currentTimestampSeconds() @@ -173,11 +161,11 @@ data class PaymentRequest( TaggedField.Description(description), TaggedField.MinFinalCltvExpiry(minFinalCltvExpiryDelta.toLong()), TaggedField.PaymentSecret(paymentSecret), - TaggedField.Features(features.toByteArray().toByteVector()) + // We remove unknown features which could make the invoice too big. + TaggedField.Features(features.invoiceFeatures().copy(unknown = setOf()).toByteArray().toByteVector()) ) - if (expirySeconds != null) { - tags.add(TaggedField.Expiry(expirySeconds)) - } + paymentMetadata?.let { tags.add(TaggedField.PaymentMetadata(it)) } + expirySeconds?.let { tags.add(TaggedField.Expiry(it)) } if (extraHops.isNotEmpty()) { extraHops.forEach { tags.add(TaggedField.RoutingInfo(it)) } } @@ -226,6 +214,7 @@ data class PaymentRequest( when (tag) { TaggedField.PaymentHash.tag -> tags.add(kotlin.runCatching { TaggedField.PaymentHash.decode(value) }.getOrDefault(TaggedField.InvalidTag(tag, value))) TaggedField.PaymentSecret.tag -> tags.add(kotlin.runCatching { TaggedField.PaymentSecret.decode(value) }.getOrDefault(TaggedField.InvalidTag(tag, value))) + TaggedField.PaymentMetadata.tag -> tags.add(kotlin.runCatching { TaggedField.PaymentMetadata.decode(value) }.getOrDefault(TaggedField.InvalidTag(tag, value))) TaggedField.Description.tag -> tags.add(kotlin.runCatching { TaggedField.Description.decode(value) }.getOrDefault(TaggedField.InvalidTag(tag, value))) TaggedField.DescriptionHash.tag -> tags.add(kotlin.runCatching { TaggedField.DescriptionHash.decode(value) }.getOrDefault(TaggedField.InvalidTag(tag, value))) TaggedField.Expiry.tag -> tags.add(kotlin.runCatching { TaggedField.Expiry.decode(value) }.getOrDefault(TaggedField.InvalidTag(tag, value))) @@ -355,6 +344,17 @@ data class PaymentRequest( } } + @Serializable + data class PaymentMetadata(@Contextual val data: ByteVector) : TaggedField() { + override val tag: Int5 = PaymentMetadata.tag + override fun encode(): List = Bech32.eight2five(data.toByteArray()).toList() + + companion object { + const val tag: Int5 = 27 + fun decode(input: List): PaymentMetadata = PaymentMetadata(Bech32.five2eight(input.toTypedArray(), 0).toByteVector()) + } + } + /** @param expirySeconds payment expiry (in seconds) */ @Serializable data class Expiry(val expirySeconds: Long) : TaggedField() { diff --git a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v1/Serialization.kt b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v1/Serialization.kt index 410d95204..c6fe64a4e 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v1/Serialization.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v1/Serialization.kt @@ -45,14 +45,15 @@ object Serialization { subclass(ChannelTlv.ChannelVersionTlv.serializer()) subclass(ChannelTlv.ChannelOriginTlv.serializer()) subclass(InitTlv.Networks.serializer()) - subclass(OnionTlv.AmountToForward.serializer()) - subclass(OnionTlv.OutgoingCltv.serializer()) - subclass(OnionTlv.OutgoingChannelId.serializer()) - subclass(OnionTlv.PaymentData.serializer()) - subclass(OnionTlv.InvoiceFeatures.serializer()) - subclass(OnionTlv.OutgoingNodeId.serializer()) - subclass(OnionTlv.InvoiceRoutingInfo.serializer()) - subclass(OnionTlv.TrampolineOnion.serializer()) + subclass(OnionPaymentPayloadTlv.AmountToForward.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingCltv.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingChannelId.serializer()) + subclass(OnionPaymentPayloadTlv.PaymentData.serializer()) + subclass(OnionPaymentPayloadTlv.PaymentMetadata.serializer()) + subclass(OnionPaymentPayloadTlv.InvoiceFeatures.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingNodeId.serializer()) + subclass(OnionPaymentPayloadTlv.InvoiceRoutingInfo.serializer()) + subclass(OnionPaymentPayloadTlv.TrampolineOnion.serializer()) subclass(GenericTlv.serializer()) } } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v2/Serialization.kt b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v2/Serialization.kt index 71451e357..37f88ca70 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v2/Serialization.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v2/Serialization.kt @@ -55,14 +55,15 @@ object Serialization { subclass(ChannelTlv.ChannelVersionTlv.serializer()) subclass(ChannelTlv.ChannelOriginTlv.serializer()) subclass(InitTlv.Networks.serializer()) - subclass(OnionTlv.AmountToForward.serializer()) - subclass(OnionTlv.OutgoingCltv.serializer()) - subclass(OnionTlv.OutgoingChannelId.serializer()) - subclass(OnionTlv.PaymentData.serializer()) - subclass(OnionTlv.InvoiceFeatures.serializer()) - subclass(OnionTlv.OutgoingNodeId.serializer()) - subclass(OnionTlv.InvoiceRoutingInfo.serializer()) - subclass(OnionTlv.TrampolineOnion.serializer()) + subclass(OnionPaymentPayloadTlv.AmountToForward.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingCltv.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingChannelId.serializer()) + subclass(OnionPaymentPayloadTlv.PaymentData.serializer()) + subclass(OnionPaymentPayloadTlv.PaymentMetadata.serializer()) + subclass(OnionPaymentPayloadTlv.InvoiceFeatures.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingNodeId.serializer()) + subclass(OnionPaymentPayloadTlv.InvoiceRoutingInfo.serializer()) + subclass(OnionPaymentPayloadTlv.TrampolineOnion.serializer()) subclass(GenericTlv.serializer()) } } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v3/Serialization.kt b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v3/Serialization.kt index 119eb008d..1e3632fed 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v3/Serialization.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v3/Serialization.kt @@ -55,14 +55,15 @@ object Serialization { subclass(ChannelTlv.ChannelVersionTlv.serializer()) subclass(ChannelTlv.ChannelOriginTlv.serializer()) subclass(InitTlv.Networks.serializer()) - subclass(OnionTlv.AmountToForward.serializer()) - subclass(OnionTlv.OutgoingCltv.serializer()) - subclass(OnionTlv.OutgoingChannelId.serializer()) - subclass(OnionTlv.PaymentData.serializer()) - subclass(OnionTlv.InvoiceFeatures.serializer()) - subclass(OnionTlv.OutgoingNodeId.serializer()) - subclass(OnionTlv.InvoiceRoutingInfo.serializer()) - subclass(OnionTlv.TrampolineOnion.serializer()) + subclass(OnionPaymentPayloadTlv.AmountToForward.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingCltv.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingChannelId.serializer()) + subclass(OnionPaymentPayloadTlv.PaymentData.serializer()) + subclass(OnionPaymentPayloadTlv.PaymentMetadata.serializer()) + subclass(OnionPaymentPayloadTlv.InvoiceFeatures.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingNodeId.serializer()) + subclass(OnionPaymentPayloadTlv.InvoiceRoutingInfo.serializer()) + subclass(OnionPaymentPayloadTlv.TrampolineOnion.serializer()) subclass(GenericTlv.serializer()) } } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/Onion.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/Onion.kt deleted file mode 100644 index cd7933e12..000000000 --- a/src/commonMain/kotlin/fr/acinq/lightning/wire/Onion.kt +++ /dev/null @@ -1,333 +0,0 @@ -package fr.acinq.lightning.wire - -import fr.acinq.bitcoin.ByteVector -import fr.acinq.bitcoin.ByteVector32 -import fr.acinq.bitcoin.PublicKey -import fr.acinq.bitcoin.io.ByteArrayInput -import fr.acinq.bitcoin.io.ByteArrayOutput -import fr.acinq.bitcoin.io.Input -import fr.acinq.bitcoin.io.Output -import fr.acinq.lightning.CltvExpiry -import fr.acinq.lightning.CltvExpiryDelta -import fr.acinq.lightning.MilliSatoshi -import fr.acinq.lightning.ShortChannelId -import fr.acinq.lightning.payment.PaymentRequest -import fr.acinq.lightning.utils.msat -import fr.acinq.lightning.utils.toByteVector -import fr.acinq.lightning.utils.toByteVector32 -import kotlinx.serialization.Contextual -import kotlinx.serialization.Serializable - -@Serializable -data class OnionRoutingPacket( - val version: Int, - @Contextual val publicKey: ByteVector, - @Contextual val payload: ByteVector, - @Contextual val hmac: ByteVector32 -) { - companion object { - const val PaymentPacketLength = 1300 - const val TrampolinePacketLength = 400 - } -} - -/** - * @param payloadLength length of the onion-encrypted payload. - */ -@OptIn(ExperimentalUnsignedTypes::class) -class OnionRoutingPacketSerializer(private val payloadLength: Int) { - fun read(input: Input): OnionRoutingPacket { - return OnionRoutingPacket( - LightningCodecs.byte(input), - LightningCodecs.bytes(input, 33).toByteVector(), - LightningCodecs.bytes(input, payloadLength).toByteVector(), - LightningCodecs.bytes(input, 32).toByteVector32() - ) - } - - fun read(bytes: ByteArray): OnionRoutingPacket = read(ByteArrayInput(bytes)) - - fun write(message: OnionRoutingPacket, out: Output) { - LightningCodecs.writeByte(message.version, out) - LightningCodecs.writeBytes(message.publicKey, out) - LightningCodecs.writeBytes(message.payload, out) - LightningCodecs.writeBytes(message.hmac, out) - } - - fun write(message: OnionRoutingPacket): ByteArray { - val out = ByteArrayOutput() - write(message, out) - return out.toByteArray() - } -} - -@OptIn(ExperimentalUnsignedTypes::class) -@Serializable -sealed class OnionTlv : Tlv { - /** Amount to forward to the next node. */ - @Serializable - data class AmountToForward(val amount: MilliSatoshi) : OnionTlv() { - override val tag: Long get() = AmountToForward.tag - override fun write(out: Output) = LightningCodecs.writeTU64(amount.toLong(), out) - - companion object : TlvValueReader { - const val tag: Long = 2 - override fun read(input: Input): AmountToForward = AmountToForward(MilliSatoshi(LightningCodecs.tu64(input))) - } - } - - /** CLTV value to use for the HTLC offered to the next node. */ - @Serializable - data class OutgoingCltv(val cltv: CltvExpiry) : OnionTlv() { - override val tag: Long get() = OutgoingCltv.tag - override fun write(out: Output) = LightningCodecs.writeTU32(cltv.toLong().toInt(), out) - - companion object : TlvValueReader { - const val tag: Long = 4 - override fun read(input: Input): OutgoingCltv = OutgoingCltv(CltvExpiry(LightningCodecs.tu32(input).toLong())) - } - } - - /** Id of the channel to use to forward a payment to the next node. */ - @Serializable - data class OutgoingChannelId(val shortChannelId: ShortChannelId) : OnionTlv() { - override val tag: Long get() = OutgoingChannelId.tag - override fun write(out: Output) = LightningCodecs.writeU64(shortChannelId.toLong(), out) - - companion object : TlvValueReader { - const val tag: Long = 6 - override fun read(input: Input): OutgoingChannelId = OutgoingChannelId(ShortChannelId(LightningCodecs.u64(input))) - } - } - - /** - * Bolt 11 payment details (only included for the last node). - * - * @param secret payment secret specified in the Bolt 11 invoice. - * @param totalAmount total amount in multi-part payments. When missing, assumed to be equal to AmountToForward. - */ - @Serializable - data class PaymentData(@Contextual val secret: ByteVector32, val totalAmount: MilliSatoshi) : OnionTlv() { - override val tag: Long get() = PaymentData.tag - override fun write(out: Output) { - LightningCodecs.writeBytes(secret, out) - LightningCodecs.writeTU64(totalAmount.toLong(), out) - } - - companion object : TlvValueReader { - const val tag: Long = 8 - override fun read(input: Input): PaymentData = PaymentData(ByteVector32(LightningCodecs.bytes(input, 32)), MilliSatoshi(LightningCodecs.tu64(input))) - } - } - - /** - * Invoice feature bits. Only included for intermediate trampoline nodes when they should convert to a legacy payment - * because the final recipient doesn't support trampoline. - */ - @Serializable - data class InvoiceFeatures(@Contextual val features: ByteVector) : OnionTlv() { - override val tag: Long get() = InvoiceFeatures.tag - override fun write(out: Output) = LightningCodecs.writeBytes(features, out) - - companion object : TlvValueReader { - const val tag: Long = 66097 - override fun read(input: Input): InvoiceFeatures = InvoiceFeatures(ByteVector(LightningCodecs.bytes(input, input.availableBytes))) - } - } - - /** Id of the next node. */ - @Serializable - data class OutgoingNodeId(@Contextual val nodeId: PublicKey) : OnionTlv() { - override val tag: Long get() = OutgoingNodeId.tag - override fun write(out: Output) = LightningCodecs.writeBytes(nodeId.value, out) - - companion object : TlvValueReader { - const val tag: Long = 66098 - override fun read(input: Input): OutgoingNodeId = OutgoingNodeId(PublicKey(LightningCodecs.bytes(input, 33))) - } - } - - /** - * Invoice routing hints. Only included for intermediate trampoline nodes when they should convert to a legacy payment - * because the final recipient doesn't support trampoline. - */ - @Serializable - data class InvoiceRoutingInfo(val extraHops: List>) : OnionTlv() { - override val tag: Long get() = InvoiceRoutingInfo.tag - override fun write(out: Output) { - for (routeHint in extraHops) { - LightningCodecs.writeByte(routeHint.size, out) - routeHint.map { - LightningCodecs.writeBytes(it.nodeId.value, out) - LightningCodecs.writeU64(it.shortChannelId.toLong(), out) - LightningCodecs.writeU32(it.feeBase.toLong().toInt(), out) - LightningCodecs.writeU32(it.feeProportionalMillionths.toInt(), out) - LightningCodecs.writeU16(it.cltvExpiryDelta.toInt(), out) - } - } - } - - companion object : TlvValueReader { - const val tag: Long = 66099 - override fun read(input: Input): InvoiceRoutingInfo { - val extraHops = mutableListOf>() - while (input.availableBytes > 0) { - val hopCount = LightningCodecs.byte(input) - val extraHop = (0 until hopCount).map { - PaymentRequest.TaggedField.ExtraHop( - PublicKey(LightningCodecs.bytes(input, 33)), - ShortChannelId(LightningCodecs.u64(input)), - MilliSatoshi(LightningCodecs.u32(input).toLong()), - LightningCodecs.u32(input).toLong(), - CltvExpiryDelta(LightningCodecs.u16(input)) - ) - } - extraHops.add(extraHop) - } - return InvoiceRoutingInfo(extraHops) - } - } - } - - /** An encrypted trampoline onion packet. */ - @Serializable - data class TrampolineOnion(val packet: OnionRoutingPacket) : OnionTlv() { - override val tag: Long get() = TrampolineOnion.tag - override fun write(out: Output) = OnionRoutingPacketSerializer(OnionRoutingPacket.TrampolinePacketLength).write(packet, out) - - companion object : TlvValueReader { - const val tag: Long = 66100 - override fun read(input: Input): TrampolineOnion = TrampolineOnion(OnionRoutingPacketSerializer(OnionRoutingPacket.TrampolinePacketLength).read(input)) - } - } -} - -sealed class PerHopPayload { - - abstract fun write(out: Output) - - fun write(): ByteArray { - val out = ByteArrayOutput() - write(out) - return out.toByteArray() - } - - companion object { - val tlvSerializer = TlvStreamSerializer( - true, - @Suppress("UNCHECKED_CAST") - mapOf( - OnionTlv.AmountToForward.tag to OnionTlv.AmountToForward.Companion as TlvValueReader, - OnionTlv.OutgoingCltv.tag to OnionTlv.OutgoingCltv.Companion as TlvValueReader, - OnionTlv.OutgoingChannelId.tag to OnionTlv.OutgoingChannelId.Companion as TlvValueReader, - OnionTlv.PaymentData.tag to OnionTlv.PaymentData.Companion as TlvValueReader, - OnionTlv.InvoiceFeatures.tag to OnionTlv.InvoiceFeatures.Companion as TlvValueReader, - OnionTlv.OutgoingNodeId.tag to OnionTlv.OutgoingNodeId.Companion as TlvValueReader, - OnionTlv.InvoiceRoutingInfo.tag to OnionTlv.InvoiceRoutingInfo.Companion as TlvValueReader, - OnionTlv.TrampolineOnion.tag to OnionTlv.TrampolineOnion.Companion as TlvValueReader, - ) - ) - } -} - -interface PerHopPayloadReader { - fun read(input: Input): T - fun read(bytes: ByteArray): T = read(ByteArrayInput(bytes)) -} - -data class FinalPayload(val records: TlvStream) : PerHopPayload() { - val amount = records.get()!!.amount - val expiry = records.get()!!.cltv - val paymentSecret = records.get()!!.secret - val totalAmount = run { - val total = records.get()!!.totalAmount - if (total > 0.msat) total else amount - } - - override fun write(out: Output) = tlvSerializer.write(records, out) - - companion object : PerHopPayloadReader { - override fun read(input: Input): FinalPayload = FinalPayload(tlvSerializer.read(input)) - - /** Create a single-part payment (total amount sent at once). */ - fun createSinglePartPayload(amount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, userCustomTlvs: List = listOf()): FinalPayload = - FinalPayload(TlvStream(listOf(OnionTlv.AmountToForward(amount), OnionTlv.OutgoingCltv(expiry), OnionTlv.PaymentData(paymentSecret, amount)), userCustomTlvs)) - - /** Create a partial payment (total amount split between multiple payments). */ - fun createMultiPartPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, additionalTlvs: List = listOf(), userCustomTlvs: List = listOf()): FinalPayload = - FinalPayload(TlvStream(listOf(OnionTlv.AmountToForward(amount), OnionTlv.OutgoingCltv(expiry), OnionTlv.PaymentData(paymentSecret, totalAmount)) + additionalTlvs, userCustomTlvs)) - - /** Create a trampoline outer payload. */ - fun createTrampolinePayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, trampolinePacket: OnionRoutingPacket): FinalPayload = - FinalPayload(TlvStream(listOf(OnionTlv.AmountToForward(amount), OnionTlv.OutgoingCltv(expiry), OnionTlv.PaymentData(paymentSecret, totalAmount), OnionTlv.TrampolineOnion(trampolinePacket)))) - } -} - -data class ChannelRelayPayload(val records: TlvStream) : PerHopPayload() { - val amountToForward = records.get()!!.amount - val outgoingCltv = records.get()!!.cltv - val outgoingChannelId = records.get()!!.shortChannelId - - override fun write(out: Output) = tlvSerializer.write(records, out) - - companion object : PerHopPayloadReader { - override fun read(input: Input): ChannelRelayPayload = ChannelRelayPayload(tlvSerializer.read(input)) - - fun create(outgoingChannelId: ShortChannelId, amountToForward: MilliSatoshi, outgoingCltv: CltvExpiry): ChannelRelayPayload = - ChannelRelayPayload(TlvStream(listOf(OnionTlv.AmountToForward(amountToForward), OnionTlv.OutgoingCltv(outgoingCltv), OnionTlv.OutgoingChannelId(outgoingChannelId)))) - } -} - -data class NodeRelayPayload(val records: TlvStream) : PerHopPayload() { - val amountToForward = records.get()!!.amount - val outgoingCltv = records.get()!!.cltv - val outgoingNodeId = records.get()!!.nodeId - val totalAmount = run { - val paymentData = records.get() - when { - paymentData == null -> amountToForward - paymentData.totalAmount == MilliSatoshi(0) -> amountToForward - else -> paymentData.totalAmount - } - } - - // NB: the following fields are only included in the trampoline-to-legacy case. - val paymentSecret = records.get()?.secret - val invoiceFeatures = records.get()?.features - val invoiceRoutingInfo = records.get()?.extraHops - - override fun write(out: Output) = tlvSerializer.write(records, out) - - companion object : PerHopPayloadReader { - override fun read(input: Input): NodeRelayPayload = NodeRelayPayload(tlvSerializer.read(input)) - - fun create(amount: MilliSatoshi, expiry: CltvExpiry, nextNodeId: PublicKey) = NodeRelayPayload(TlvStream(listOf(OnionTlv.AmountToForward(amount), OnionTlv.OutgoingCltv(expiry), OnionTlv.OutgoingNodeId(nextNodeId)))) - - /** Create a trampoline inner payload instructing the trampoline node to relay via a non-trampoline payment. */ - fun createNodeRelayToNonTrampolinePayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, targetNodeId: PublicKey, invoice: PaymentRequest): NodeRelayPayload { - // NB: we limit the number of routing hints to ensure we don't overflow the onion. - // A better solution is to provide the routing hints outside the onion (in the `update_add_htlc` tlv stream). - val prunedRoutingHints = invoice.routingInfo.shuffled().fold(listOf()) { previous, current -> - if (previous.flatMap { it.hints }.size + current.hints.size <= 4) { - previous + current - } else { - previous - } - }.map { it.hints } - return NodeRelayPayload( - TlvStream( - listOf( - OnionTlv.AmountToForward(amount), - OnionTlv.OutgoingCltv(expiry), - OnionTlv.OutgoingNodeId(targetNodeId), - OnionTlv.PaymentData(invoice.paymentSecret, totalAmount), - OnionTlv.InvoiceFeatures(invoice.features), - OnionTlv.InvoiceRoutingInfo(prunedRoutingHints) - ) - ) - ) - } - - } - -} diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/OnionRouting.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/OnionRouting.kt new file mode 100644 index 000000000..8e3b15005 --- /dev/null +++ b/src/commonMain/kotlin/fr/acinq/lightning/wire/OnionRouting.kt @@ -0,0 +1,55 @@ +package fr.acinq.lightning.wire + +import fr.acinq.bitcoin.ByteVector +import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.bitcoin.io.ByteArrayInput +import fr.acinq.bitcoin.io.ByteArrayOutput +import fr.acinq.bitcoin.io.Input +import fr.acinq.bitcoin.io.Output +import fr.acinq.lightning.utils.toByteVector +import fr.acinq.lightning.utils.toByteVector32 +import kotlinx.serialization.Contextual +import kotlinx.serialization.Serializable + +@Serializable +data class OnionRoutingPacket( + val version: Int, + @Contextual val publicKey: ByteVector, + @Contextual val payload: ByteVector, + @Contextual val hmac: ByteVector32 +) { + companion object { + const val PaymentPacketLength = 1300 + const val TrampolinePacketLength = 400 + } +} + +/** + * @param payloadLength length of the onion-encrypted payload. + */ +@OptIn(ExperimentalUnsignedTypes::class) +class OnionRoutingPacketSerializer(private val payloadLength: Int) { + fun read(input: Input): OnionRoutingPacket { + return OnionRoutingPacket( + LightningCodecs.byte(input), + LightningCodecs.bytes(input, 33).toByteVector(), + LightningCodecs.bytes(input, payloadLength).toByteVector(), + LightningCodecs.bytes(input, 32).toByteVector32() + ) + } + + fun read(bytes: ByteArray): OnionRoutingPacket = read(ByteArrayInput(bytes)) + + fun write(message: OnionRoutingPacket, out: Output) { + LightningCodecs.writeByte(message.version, out) + LightningCodecs.writeBytes(message.publicKey, out) + LightningCodecs.writeBytes(message.payload, out) + LightningCodecs.writeBytes(message.hmac, out) + } + + fun write(message: OnionRoutingPacket): ByteArray { + val out = ByteArrayOutput() + write(message, out) + return out.toByteArray() + } +} \ No newline at end of file diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/PaymentOnion.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/PaymentOnion.kt new file mode 100644 index 000000000..2fe59723a --- /dev/null +++ b/src/commonMain/kotlin/fr/acinq/lightning/wire/PaymentOnion.kt @@ -0,0 +1,338 @@ +package fr.acinq.lightning.wire + +import fr.acinq.bitcoin.ByteVector +import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.bitcoin.PublicKey +import fr.acinq.bitcoin.io.ByteArrayInput +import fr.acinq.bitcoin.io.ByteArrayOutput +import fr.acinq.bitcoin.io.Input +import fr.acinq.bitcoin.io.Output +import fr.acinq.lightning.CltvExpiry +import fr.acinq.lightning.CltvExpiryDelta +import fr.acinq.lightning.MilliSatoshi +import fr.acinq.lightning.ShortChannelId +import fr.acinq.lightning.payment.PaymentRequest +import fr.acinq.lightning.utils.msat +import kotlinx.serialization.Contextual +import kotlinx.serialization.Serializable + +@OptIn(ExperimentalUnsignedTypes::class) +@Serializable +sealed class OnionPaymentPayloadTlv : Tlv { + /** Amount to forward to the next node. */ + @Serializable + data class AmountToForward(val amount: MilliSatoshi) : OnionPaymentPayloadTlv() { + override val tag: Long get() = AmountToForward.tag + override fun write(out: Output) = LightningCodecs.writeTU64(amount.toLong(), out) + + companion object : TlvValueReader { + const val tag: Long = 2 + override fun read(input: Input): AmountToForward = AmountToForward(MilliSatoshi(LightningCodecs.tu64(input))) + } + } + + /** CLTV value to use for the HTLC offered to the next node. */ + @Serializable + data class OutgoingCltv(val cltv: CltvExpiry) : OnionPaymentPayloadTlv() { + override val tag: Long get() = OutgoingCltv.tag + override fun write(out: Output) = LightningCodecs.writeTU32(cltv.toLong().toInt(), out) + + companion object : TlvValueReader { + const val tag: Long = 4 + override fun read(input: Input): OutgoingCltv = OutgoingCltv(CltvExpiry(LightningCodecs.tu32(input).toLong())) + } + } + + /** Id of the channel to use to forward a payment to the next node. */ + @Serializable + data class OutgoingChannelId(val shortChannelId: ShortChannelId) : OnionPaymentPayloadTlv() { + override val tag: Long get() = OutgoingChannelId.tag + override fun write(out: Output) = LightningCodecs.writeU64(shortChannelId.toLong(), out) + + companion object : TlvValueReader { + const val tag: Long = 6 + override fun read(input: Input): OutgoingChannelId = OutgoingChannelId(ShortChannelId(LightningCodecs.u64(input))) + } + } + + /** + * Bolt 11 payment details (only included for the last node). + * + * @param secret payment secret specified in the Bolt 11 invoice. + * @param totalAmount total amount in multi-part payments. When missing, assumed to be equal to AmountToForward. + */ + @Serializable + data class PaymentData(@Contextual val secret: ByteVector32, val totalAmount: MilliSatoshi) : OnionPaymentPayloadTlv() { + override val tag: Long get() = PaymentData.tag + override fun write(out: Output) { + LightningCodecs.writeBytes(secret, out) + LightningCodecs.writeTU64(totalAmount.toLong(), out) + } + + companion object : TlvValueReader { + const val tag: Long = 8 + override fun read(input: Input): PaymentData = PaymentData(ByteVector32(LightningCodecs.bytes(input, 32)), MilliSatoshi(LightningCodecs.tu64(input))) + } + } + + /** + * When payment metadata is included in a Bolt 9 invoice, we should send it as-is to the recipient. + * This lets recipients generate invoices without having to store anything on their side until the invoice is paid. + */ + @Serializable + data class PaymentMetadata(@Contextual val data: ByteVector) : OnionPaymentPayloadTlv() { + override val tag: Long get() = PaymentMetadata.tag + override fun write(out: Output) = LightningCodecs.writeBytes(data, out) + + companion object : TlvValueReader { + const val tag: Long = 16 + override fun read(input: Input): PaymentMetadata = PaymentMetadata(ByteVector(LightningCodecs.bytes(input, input.availableBytes))) + } + } + + /** + * Invoice feature bits. Only included for intermediate trampoline nodes when they should convert to a legacy payment + * because the final recipient doesn't support trampoline. + */ + @Serializable + data class InvoiceFeatures(@Contextual val features: ByteVector) : OnionPaymentPayloadTlv() { + override val tag: Long get() = InvoiceFeatures.tag + override fun write(out: Output) = LightningCodecs.writeBytes(features, out) + + companion object : TlvValueReader { + const val tag: Long = 66097 + override fun read(input: Input): InvoiceFeatures = InvoiceFeatures(ByteVector(LightningCodecs.bytes(input, input.availableBytes))) + } + } + + /** Id of the next node. */ + @Serializable + data class OutgoingNodeId(@Contextual val nodeId: PublicKey) : OnionPaymentPayloadTlv() { + override val tag: Long get() = OutgoingNodeId.tag + override fun write(out: Output) = LightningCodecs.writeBytes(nodeId.value, out) + + companion object : TlvValueReader { + const val tag: Long = 66098 + override fun read(input: Input): OutgoingNodeId = OutgoingNodeId(PublicKey(LightningCodecs.bytes(input, 33))) + } + } + + /** + * Invoice routing hints. Only included for intermediate trampoline nodes when they should convert to a legacy payment + * because the final recipient doesn't support trampoline. + */ + @Serializable + data class InvoiceRoutingInfo(val extraHops: List>) : OnionPaymentPayloadTlv() { + override val tag: Long get() = InvoiceRoutingInfo.tag + override fun write(out: Output) { + for (routeHint in extraHops) { + LightningCodecs.writeByte(routeHint.size, out) + routeHint.map { + LightningCodecs.writeBytes(it.nodeId.value, out) + LightningCodecs.writeU64(it.shortChannelId.toLong(), out) + LightningCodecs.writeU32(it.feeBase.toLong().toInt(), out) + LightningCodecs.writeU32(it.feeProportionalMillionths.toInt(), out) + LightningCodecs.writeU16(it.cltvExpiryDelta.toInt(), out) + } + } + } + + companion object : TlvValueReader { + const val tag: Long = 66099 + override fun read(input: Input): InvoiceRoutingInfo { + val extraHops = mutableListOf>() + while (input.availableBytes > 0) { + val hopCount = LightningCodecs.byte(input) + val extraHop = (0 until hopCount).map { + PaymentRequest.TaggedField.ExtraHop( + PublicKey(LightningCodecs.bytes(input, 33)), + ShortChannelId(LightningCodecs.u64(input)), + MilliSatoshi(LightningCodecs.u32(input).toLong()), + LightningCodecs.u32(input).toLong(), + CltvExpiryDelta(LightningCodecs.u16(input)) + ) + } + extraHops.add(extraHop) + } + return InvoiceRoutingInfo(extraHops) + } + } + } + + /** An encrypted trampoline onion packet. */ + @Serializable + data class TrampolineOnion(val packet: OnionRoutingPacket) : OnionPaymentPayloadTlv() { + override val tag: Long get() = TrampolineOnion.tag + override fun write(out: Output) = OnionRoutingPacketSerializer(OnionRoutingPacket.TrampolinePacketLength).write(packet, out) + + companion object : TlvValueReader { + const val tag: Long = 66100 + override fun read(input: Input): TrampolineOnion = TrampolineOnion(OnionRoutingPacketSerializer(OnionRoutingPacket.TrampolinePacketLength).read(input)) + } + } +} + +object PaymentOnion { + + sealed class PerHopPayload { + + abstract fun write(out: Output) + + fun write(): ByteArray { + val out = ByteArrayOutput() + write(out) + return out.toByteArray() + } + + companion object { + val tlvSerializer = TlvStreamSerializer( + true, @Suppress("UNCHECKED_CAST") mapOf( + OnionPaymentPayloadTlv.AmountToForward.tag to OnionPaymentPayloadTlv.AmountToForward.Companion as TlvValueReader, + OnionPaymentPayloadTlv.OutgoingCltv.tag to OnionPaymentPayloadTlv.OutgoingCltv.Companion as TlvValueReader, + OnionPaymentPayloadTlv.OutgoingChannelId.tag to OnionPaymentPayloadTlv.OutgoingChannelId.Companion as TlvValueReader, + OnionPaymentPayloadTlv.PaymentData.tag to OnionPaymentPayloadTlv.PaymentData.Companion as TlvValueReader, + OnionPaymentPayloadTlv.PaymentMetadata.tag to OnionPaymentPayloadTlv.PaymentMetadata.Companion as TlvValueReader, + OnionPaymentPayloadTlv.InvoiceFeatures.tag to OnionPaymentPayloadTlv.InvoiceFeatures.Companion as TlvValueReader, + OnionPaymentPayloadTlv.OutgoingNodeId.tag to OnionPaymentPayloadTlv.OutgoingNodeId.Companion as TlvValueReader, + OnionPaymentPayloadTlv.InvoiceRoutingInfo.tag to OnionPaymentPayloadTlv.InvoiceRoutingInfo.Companion as TlvValueReader, + OnionPaymentPayloadTlv.TrampolineOnion.tag to OnionPaymentPayloadTlv.TrampolineOnion.Companion as TlvValueReader, + ) + ) + } + } + + interface PerHopPayloadReader { + fun read(input: Input): T + fun read(bytes: ByteArray): T = read(ByteArrayInput(bytes)) + } + + data class FinalPayload(val records: TlvStream) : PerHopPayload() { + val amount = records.get()!!.amount + val expiry = records.get()!!.cltv + val paymentSecret = records.get()!!.secret + val totalAmount = run { + val total = records.get()!!.totalAmount + if (total > 0.msat) total else amount + } + val paymentMetadata = records.get()?.data + + override fun write(out: Output) = tlvSerializer.write(records, out) + + companion object : PerHopPayloadReader { + override fun read(input: Input): FinalPayload = FinalPayload(tlvSerializer.read(input)) + + /** Create a single-part payment (total amount sent at once). */ + fun createSinglePartPayload(amount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, paymentMetadata: ByteVector?, userCustomTlvs: List = listOf()): FinalPayload { + val tlvs = buildList { + add(OnionPaymentPayloadTlv.AmountToForward(amount)) + add(OnionPaymentPayloadTlv.OutgoingCltv(expiry)) + add(OnionPaymentPayloadTlv.PaymentData(paymentSecret, amount)) + paymentMetadata?.let { add(OnionPaymentPayloadTlv.PaymentMetadata(it)) } + } + return FinalPayload(TlvStream(tlvs, userCustomTlvs)) + } + + /** Create a partial payment (total amount split between multiple payments). */ + fun createMultiPartPayload( + amount: MilliSatoshi, + totalAmount: MilliSatoshi, + expiry: CltvExpiry, + paymentSecret: ByteVector32, + paymentMetadata: ByteVector?, + additionalTlvs: List = listOf(), + userCustomTlvs: List = listOf() + ): FinalPayload { + val tlvs = buildList { + add(OnionPaymentPayloadTlv.AmountToForward(amount)) + add(OnionPaymentPayloadTlv.OutgoingCltv(expiry)) + add(OnionPaymentPayloadTlv.PaymentData(paymentSecret, totalAmount)) + paymentMetadata?.let { add(OnionPaymentPayloadTlv.PaymentMetadata(it)) } + addAll(additionalTlvs) + } + return FinalPayload(TlvStream(tlvs, userCustomTlvs)) + } + + /** Create a trampoline outer payload. */ + fun createTrampolinePayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, trampolinePacket: OnionRoutingPacket): FinalPayload { + val tlvs = buildList { + add(OnionPaymentPayloadTlv.AmountToForward(amount)) + add(OnionPaymentPayloadTlv.OutgoingCltv(expiry)) + add(OnionPaymentPayloadTlv.PaymentData(paymentSecret, totalAmount)) + add(OnionPaymentPayloadTlv.TrampolineOnion(trampolinePacket)) + } + return FinalPayload(TlvStream(tlvs)) + } + } + } + + data class ChannelRelayPayload(val records: TlvStream) : PerHopPayload() { + val amountToForward = records.get()!!.amount + val outgoingCltv = records.get()!!.cltv + val outgoingChannelId = records.get()!!.shortChannelId + + override fun write(out: Output) = tlvSerializer.write(records, out) + + companion object : PerHopPayloadReader { + override fun read(input: Input): ChannelRelayPayload = ChannelRelayPayload(tlvSerializer.read(input)) + + fun create(outgoingChannelId: ShortChannelId, amountToForward: MilliSatoshi, outgoingCltv: CltvExpiry): ChannelRelayPayload = + ChannelRelayPayload(TlvStream(listOf(OnionPaymentPayloadTlv.AmountToForward(amountToForward), OnionPaymentPayloadTlv.OutgoingCltv(outgoingCltv), OnionPaymentPayloadTlv.OutgoingChannelId(outgoingChannelId)))) + } + } + + data class NodeRelayPayload(val records: TlvStream) : PerHopPayload() { + val amountToForward = records.get()!!.amount + val outgoingCltv = records.get()!!.cltv + val outgoingNodeId = records.get()!!.nodeId + val totalAmount = run { + val paymentData = records.get() + when { + paymentData == null -> amountToForward + paymentData.totalAmount == MilliSatoshi(0) -> amountToForward + else -> paymentData.totalAmount + } + } + + // NB: the following fields are only included in the trampoline-to-legacy case. + val paymentSecret = records.get()?.secret + val paymentMetadata = records.get()?.data + val invoiceFeatures = records.get()?.features + val invoiceRoutingInfo = records.get()?.extraHops + + override fun write(out: Output) = tlvSerializer.write(records, out) + + companion object : PerHopPayloadReader { + override fun read(input: Input): NodeRelayPayload = NodeRelayPayload(tlvSerializer.read(input)) + + fun create(amount: MilliSatoshi, expiry: CltvExpiry, nextNodeId: PublicKey) = + NodeRelayPayload(TlvStream(listOf(OnionPaymentPayloadTlv.AmountToForward(amount), OnionPaymentPayloadTlv.OutgoingCltv(expiry), OnionPaymentPayloadTlv.OutgoingNodeId(nextNodeId)))) + + /** Create a trampoline inner payload instructing the trampoline node to relay via a non-trampoline payment. */ + fun createNodeRelayToNonTrampolinePayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, targetNodeId: PublicKey, invoice: PaymentRequest): NodeRelayPayload { + // NB: we limit the number of routing hints to ensure we don't overflow the onion. + // A better solution is to provide the routing hints outside the onion (in the `update_add_htlc` tlv stream). + val prunedRoutingHints = invoice.routingInfo.shuffled().fold(listOf()) { previous, current -> + if (previous.flatMap { it.hints }.size + current.hints.size <= 4) { + previous + current + } else { + previous + } + }.map { it.hints } + return NodeRelayPayload( + TlvStream( + buildList { + add(OnionPaymentPayloadTlv.AmountToForward(amount)) + add(OnionPaymentPayloadTlv.OutgoingCltv(expiry)) + add(OnionPaymentPayloadTlv.OutgoingNodeId(targetNodeId)) + add(OnionPaymentPayloadTlv.PaymentData(invoice.paymentSecret, totalAmount)) + invoice.paymentMetadata?.let { add(OnionPaymentPayloadTlv.PaymentMetadata(it)) } + add(OnionPaymentPayloadTlv.InvoiceFeatures(invoice.features)) + add(OnionPaymentPayloadTlv.InvoiceRoutingInfo(prunedRoutingHints)) + } + ) + ) + } + } + } + +} diff --git a/src/commonTest/kotlin/fr/acinq/lightning/FeaturesTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/FeaturesTestsCommon.kt index 919027da2..4fcc960aa 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/FeaturesTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/FeaturesTestsCommon.kt @@ -191,6 +191,37 @@ class FeaturesTestsCommon : LightningTestSuite() { } } + @Test + fun `filter features based on their usage`() { + val features = Features( + mapOf( + OptionDataLossProtect to FeatureSupport.Optional, + InitialRoutingSync to FeatureSupport.Optional, + VariableLengthOnion to FeatureSupport.Mandatory, + PaymentMetadata to FeatureSupport.Optional, + ), + setOf(UnknownFeature(753), UnknownFeature(852)), + ) + assertEquals( + Features( + mapOf(OptionDataLossProtect to FeatureSupport.Optional, InitialRoutingSync to FeatureSupport.Optional, VariableLengthOnion to FeatureSupport.Mandatory), + setOf(UnknownFeature(753), UnknownFeature(852)), + ), features.initFeatures() + ) + assertEquals( + Features( + mapOf(OptionDataLossProtect to FeatureSupport.Optional, VariableLengthOnion to FeatureSupport.Mandatory), + setOf(UnknownFeature(753), UnknownFeature(852)), + ), features.nodeAnnouncementFeatures() + ) + assertEquals( + Features( + mapOf(VariableLengthOnion to FeatureSupport.Mandatory, PaymentMetadata to FeatureSupport.Optional), + setOf(UnknownFeature(753), UnknownFeature(852)), + ), features.invoiceFeatures() + ) + } + @Test fun `features to bytes`() { val testCases = mapOf( diff --git a/src/commonTest/kotlin/fr/acinq/lightning/channel/TestsHelper.kt b/src/commonTest/kotlin/fr/acinq/lightning/channel/TestsHelper.kt index 4c61d1060..fe891d18e 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/channel/TestsHelper.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/channel/TestsHelper.kt @@ -6,7 +6,7 @@ import fr.acinq.lightning.Lightning.randomBytes32 import fr.acinq.lightning.blockchain.* import fr.acinq.lightning.blockchain.fee.FeeratePerKw import fr.acinq.lightning.blockchain.fee.OnChainFeerates -import fr.acinq.lightning.payment.OutgoingPacket +import fr.acinq.lightning.payment.OutgoingPaymentPacket import fr.acinq.lightning.router.ChannelHop import fr.acinq.lightning.serialization.Serialization import fr.acinq.lightning.tests.TestConstants @@ -320,7 +320,7 @@ object TestsHelper { val expiry = CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight) val dummyKey = PrivateKey(ByteVector32("0101010101010101010101010101010101010101010101010101010101010101")).publicKey() val dummyUpdate = ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, ShortChannelId(144, 0, 0), 0, 0, 0, CltvExpiryDelta(1), 0.msat, 0.msat, 0, null) - val cmd = OutgoingPacket.buildCommand(paymentId, paymentHash, listOf(ChannelHop(dummyKey, destination, dummyUpdate)), FinalPayload.createSinglePartPayload(amount, expiry, randomBytes32())).first.copy(commit = false) + val cmd = OutgoingPaymentPacket.buildCommand(paymentId, paymentHash, listOf(ChannelHop(dummyKey, destination, dummyUpdate)), PaymentOnion.FinalPayload.createSinglePartPayload(amount, expiry, randomBytes32(), null)).first.copy(commit = false) return Pair(paymentPreimage, cmd) } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt index 6ea4321d4..eaf92d283 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt @@ -179,7 +179,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenFeeSatoshis = 100.sat, paymentHash = ByteVector32.One, // <-- not associated to a pending invoice expireAt = Long.MAX_VALUE, - finalPacket = OutgoingPacket.buildPacket( + finalPacket = OutgoingPaymentPacket.buildPacket( paymentHash = ByteVector32.One, // <-- has to be the same as the one above otherwise encryption fails hops = channelHops(paymentHandler.nodeParams.nodeId), finalPayload = makeMppPayload(defaultAmount, defaultAmount, randomBytes32()), @@ -196,7 +196,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenRequest.chainHash, payToOpenRequest.paymentHash, PayToOpenResponse.Result.Failure( - OutgoingPacket.buildHtlcFailure( + OutgoingPaymentPacket.buildHtlcFailure( paymentHandler.nodeParams.nodePrivateKey, payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, @@ -222,7 +222,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenRequest.chainHash, payToOpenRequest.paymentHash, PayToOpenResponse.Result.Failure( - OutgoingPacket.buildHtlcFailure( + OutgoingPaymentPacket.buildHtlcFailure( paymentHandler.nodeParams.nodePrivateKey, payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, @@ -248,7 +248,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenRequest.chainHash, payToOpenRequest.paymentHash, PayToOpenResponse.Result.Failure( - OutgoingPacket.buildHtlcFailure( + OutgoingPaymentPacket.buildHtlcFailure( paymentHandler.nodeParams.nodePrivateKey, payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, @@ -281,7 +281,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenRequest.chainHash, payToOpenRequest.paymentHash, PayToOpenResponse.Result.Failure( - OutgoingPacket.buildHtlcFailure( + OutgoingPaymentPacket.buildHtlcFailure( paymentHandler.nodeParams.nodePrivateKey, payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, @@ -307,7 +307,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenFeeSatoshis = 100.sat, paymentHash = incomingPayment.paymentHash, expireAt = Long.MAX_VALUE, - finalPacket = OutgoingPacket.buildPacket( + finalPacket = OutgoingPaymentPacket.buildPacket( paymentHash = incomingPayment.paymentHash, hops = trampolineHops, finalPayload = makeMppPayload(defaultAmount, defaultAmount, paymentSecret.reversed()), // <-- wrong secret @@ -324,7 +324,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenRequest.chainHash, payToOpenRequest.paymentHash, PayToOpenResponse.Result.Failure( - OutgoingPacket.buildHtlcFailure( + OutgoingPaymentPacket.buildHtlcFailure( paymentHandler.nodeParams.nodePrivateKey, payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, @@ -581,7 +581,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenRequest.chainHash, payToOpenRequest.paymentHash, PayToOpenResponse.Result.Failure( - OutgoingPacket.buildHtlcFailure( + OutgoingPaymentPacket.buildHtlcFailure( paymentHandler.nodeParams.nodePrivateKey, payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, @@ -640,7 +640,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenRequest1.chainHash, payToOpenRequest1.paymentHash, PayToOpenResponse.Result.Failure( - OutgoingPacket.buildHtlcFailure( + OutgoingPaymentPacket.buildHtlcFailure( paymentHandler.nodeParams.nodePrivateKey, payToOpenRequest1.paymentHash, payToOpenRequest1.finalPacket, @@ -654,7 +654,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenRequest2.chainHash, payToOpenRequest2.paymentHash, PayToOpenResponse.Result.Failure( - OutgoingPacket.buildHtlcFailure( + OutgoingPaymentPacket.buildHtlcFailure( paymentHandler.nodeParams.nodePrivateKey, payToOpenRequest2.paymentHash, payToOpenRequest2.finalPacket, @@ -1173,12 +1173,12 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { return listOf(channelHop) } - private fun makeCmdAddHtlc(destination: PublicKey, paymentHash: ByteVector32, finalPayload: FinalPayload): CMD_ADD_HTLC { - return OutgoingPacket.buildCommand(UUID.randomUUID(), paymentHash, channelHops(destination), finalPayload).first.copy(commit = true) + private fun makeCmdAddHtlc(destination: PublicKey, paymentHash: ByteVector32, finalPayload: PaymentOnion.FinalPayload): CMD_ADD_HTLC { + return OutgoingPaymentPacket.buildCommand(UUID.randomUUID(), paymentHash, channelHops(destination), finalPayload).first.copy(commit = true) } - private fun makeUpdateAddHtlc(id: Long, channelId: ByteVector32, destination: IncomingPaymentHandler, paymentHash: ByteVector32, finalPayload: FinalPayload): UpdateAddHtlc { - val (_, _, packetAndSecrets) = OutgoingPacket.buildPacket(paymentHash, channelHops(destination.nodeParams.nodeId), finalPayload, OnionRoutingPacket.PaymentPacketLength) + private fun makeUpdateAddHtlc(id: Long, channelId: ByteVector32, destination: IncomingPaymentHandler, paymentHash: ByteVector32, finalPayload: PaymentOnion.FinalPayload): UpdateAddHtlc { + val (_, _, packetAndSecrets) = OutgoingPaymentPacket.buildPacket(paymentHash, channelHops(destination.nodeParams.nodeId), finalPayload, OnionRoutingPacket.PaymentPacketLength) return UpdateAddHtlc(channelId, id, finalPayload.amount, paymentHash, finalPayload.expiry, packetAndSecrets.packet) } @@ -1188,12 +1188,12 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { paymentSecret: ByteVector32, cltvExpiryDelta: CltvExpiryDelta = CltvExpiryDelta(144), currentBlockHeight: Int = TestConstants.defaultBlockHeight - ): FinalPayload { + ): PaymentOnion.FinalPayload { val expiry = cltvExpiryDelta.toCltvExpiry(currentBlockHeight.toLong()) - return FinalPayload.createMultiPartPayload(amount, totalAmount, expiry, paymentSecret) + return PaymentOnion.FinalPayload.createMultiPartPayload(amount, totalAmount, expiry, paymentSecret, null) } - private fun makePayToOpenRequest(incomingPayment: IncomingPayment, finalPayload: FinalPayload, payToOpenMinAmount: MilliSatoshi = 10_000.msat): PayToOpenRequest { + private fun makePayToOpenRequest(incomingPayment: IncomingPayment, finalPayload: PaymentOnion.FinalPayload, payToOpenMinAmount: MilliSatoshi = 10_000.msat): PayToOpenRequest { return PayToOpenRequest( chainHash = ByteVector32.Zeroes, fundingSatoshis = 100_000.sat, @@ -1202,7 +1202,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenFeeSatoshis = finalPayload.amount.truncateToSatoshi() * 0.1, // 10% paymentHash = incomingPayment.paymentHash, expireAt = Long.MAX_VALUE, - finalPacket = OutgoingPacket.buildPacket( + finalPacket = OutgoingPaymentPacket.buildPacket( paymentHash = incomingPayment.paymentHash, hops = channelHops(TestConstants.Bob.nodeParams.nodeId), finalPayload = finalPayload, @@ -1213,6 +1213,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { private suspend fun makeIncomingPayment(payee: IncomingPaymentHandler, amount: MilliSatoshi?, expirySeconds: Long? = null, timestamp: Long = currentTimestampSeconds()): Pair { val paymentRequest = payee.createInvoice(defaultPreimage, amount, "unit test", listOf(), expirySeconds, timestamp) + assertNotNull(paymentRequest.paymentMetadata) return Pair(payee.db.getIncomingPayment(paymentRequest.paymentHash)!!, paymentRequest.paymentSecret) } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt index 27c09c24d..560738a8c 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt @@ -231,7 +231,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { // The recipient should receive the right amount and expiry. val payloadBytesC = Sphinx.peel(recipientKey, payment.paymentHash, packetC, OnionRoutingPacket.TrampolinePacketLength).right!! - val payloadC = FinalPayload.read(payloadBytesC.payload.toByteArray()) + val payloadC = PaymentOnion.FinalPayload.read(payloadBytesC.payload.toByteArray()) assertEquals(200_000.msat, payloadC.amount) assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + PaymentRequest.DEFAULT_MIN_FINAL_EXPIRY_DELTA, payloadC.expiry) assertEquals(payloadC.amount, payloadC.totalAmount) @@ -286,7 +286,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { // The recipient should receive the right amount and expiry. val payloadBytesC = Sphinx.peel(recipientKey, payment.paymentHash, packetC, OnionRoutingPacket.TrampolinePacketLength).right!! - val payloadC = FinalPayload.read(payloadBytesC.payload.toByteArray()) + val payloadC = PaymentOnion.FinalPayload.read(payloadBytesC.payload.toByteArray()) assertEquals(300_000.msat, payloadC.amount) assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + PaymentRequest.DEFAULT_MIN_FINAL_EXPIRY_DELTA, payloadC.expiry) assertEquals(payloadC.amount, payloadC.totalAmount) @@ -423,7 +423,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { adds.forEach { assertEquals(payment.paymentHash, it.second.paymentHash) } adds.forEach { (channelId, add) -> // Bob should receive the right final information. - val payloadB = IncomingPacket.decrypt(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey).right!! + val payloadB = IncomingPaymentPacket.decrypt(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey).right!! assertEquals(add.amount, payloadB.amount) assertEquals(300_000.msat, payloadB.totalAmount) assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + PaymentRequest.DEFAULT_MIN_FINAL_EXPIRY_DELTA, payloadB.expiry) @@ -498,7 +498,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { // The recipient should receive the right amount and expiry. val payloadBytesC = Sphinx.peel(recipientKey, payment.paymentHash, packetC, OnionRoutingPacket.TrampolinePacketLength).right!! - val payloadC = FinalPayload.read(payloadBytesC.payload.toByteArray()) + val payloadC = PaymentOnion.FinalPayload.read(payloadBytesC.payload.toByteArray()) assertEquals(300_000.msat, payloadC.amount) assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + PaymentRequest.DEFAULT_MIN_FINAL_EXPIRY_DELTA, payloadC.expiry) assertEquals(payloadC.amount, payloadC.totalAmount) diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentPacketTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentPacketTestsCommon.kt index 3baca2e40..0073857bc 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentPacketTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentPacketTestsCommon.kt @@ -1,9 +1,13 @@ package fr.acinq.lightning.payment -import fr.acinq.bitcoin.* +import fr.acinq.bitcoin.Block +import fr.acinq.bitcoin.ByteVector +import fr.acinq.bitcoin.Crypto +import fr.acinq.bitcoin.PrivateKey import fr.acinq.bitcoin.io.ByteArrayInput import fr.acinq.lightning.* import fr.acinq.lightning.Lightning.nodeFee +import fr.acinq.lightning.Lightning.randomBytes import fr.acinq.lightning.Lightning.randomBytes32 import fr.acinq.lightning.Lightning.randomBytes64 import fr.acinq.lightning.Lightning.randomKey @@ -50,6 +54,7 @@ class PaymentPacketTestsCommon : LightningTestSuite() { private val paymentPreimage = randomBytes32() private val paymentHash = Crypto.sha256(paymentPreimage).toByteVector32() private val paymentSecret = randomBytes32() + private val paymentMetadata = randomBytes(64).toByteVector() private val expiryDE = finalExpiry private val amountDE = finalAmount @@ -83,8 +88,9 @@ class PaymentPacketTestsCommon : LightningTestSuite() { ) private fun testBuildOnion() { - val finalPayload = FinalPayload(TlvStream(listOf(OnionTlv.AmountToForward(finalAmount), OnionTlv.OutgoingCltv(finalExpiry), OnionTlv.PaymentData(paymentSecret, finalAmount)))) - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket(paymentHash, hops, finalPayload, OnionRoutingPacket.PaymentPacketLength) + val finalPayload = + PaymentOnion.FinalPayload(TlvStream(listOf(OnionPaymentPayloadTlv.AmountToForward(finalAmount), OnionPaymentPayloadTlv.OutgoingCltv(finalExpiry), OnionPaymentPayloadTlv.PaymentData(paymentSecret, finalAmount)))) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket(paymentHash, hops, finalPayload, OnionRoutingPacket.PaymentPacketLength) assertEquals(amountAB, firstAmount) assertEquals(expiryAB, firstExpiry) assertEquals(OnionRoutingPacket.PaymentPacketLength, onion.packet.payload.size()) @@ -115,7 +121,7 @@ class PaymentPacketTestsCommon : LightningTestSuite() { assertEquals(channelUpdateDE.shortChannelId, payloadD.outgoingChannelId) val addE = UpdateAddHtlc(randomBytes32(), 2, amountDE, paymentHash, expiryDE, packetE) - val payloadE = IncomingPacket.decrypt(addE, privE).right!! + val payloadE = IncomingPaymentPacket.decrypt(addE, privE).right!! assertEquals(finalAmount, payloadE.amount) assertEquals(finalAmount, payloadE.totalAmount) assertEquals(finalExpiry, payloadE.expiry) @@ -123,22 +129,22 @@ class PaymentPacketTestsCommon : LightningTestSuite() { } // Wallets don't need to decrypt onions for intermediate nodes, but it's useful to test that encryption works correctly. - fun decryptChannelRelay(add: UpdateAddHtlc, privateKey: PrivateKey): Pair { + fun decryptChannelRelay(add: UpdateAddHtlc, privateKey: PrivateKey): Pair { val decrypted = Sphinx.peel(privateKey, add.paymentHash, add.onionRoutingPacket, OnionRoutingPacket.PaymentPacketLength).right!! assertFalse(decrypted.isLastPacket) - val decoded = ChannelRelayPayload.read(ByteArrayInput(decrypted.payload.toByteArray())) + val decoded = PaymentOnion.ChannelRelayPayload.read(ByteArrayInput(decrypted.payload.toByteArray())) return Pair(decoded, decrypted.nextPacket) } // Wallets don't need to decrypt onions for intermediate nodes, but it's useful to test that encryption works correctly. - fun decryptNodeRelay(add: UpdateAddHtlc, privateKey: PrivateKey): Triple { + fun decryptNodeRelay(add: UpdateAddHtlc, privateKey: PrivateKey): Triple { val decrypted = Sphinx.peel(privateKey, add.paymentHash, add.onionRoutingPacket, OnionRoutingPacket.PaymentPacketLength).right!! assertTrue(decrypted.isLastPacket) - val outerPayload = FinalPayload.read(ByteArrayInput(decrypted.payload.toByteArray())) - val trampolineOnion = outerPayload.records.get() + val outerPayload = PaymentOnion.FinalPayload.read(ByteArrayInput(decrypted.payload.toByteArray())) + val trampolineOnion = outerPayload.records.get() assertNotNull(trampolineOnion) val decryptedInner = Sphinx.peel(privateKey, add.paymentHash, trampolineOnion.packet, OnionRoutingPacket.TrampolinePacketLength).right!! - val innerPayload = NodeRelayPayload.read(ByteArrayInput(decryptedInner.payload.toByteArray())) + val innerPayload = PaymentOnion.NodeRelayPayload.read(ByteArrayInput(decryptedInner.payload.toByteArray())) return Triple(outerPayload, innerPayload, decryptedInner.nextPacket) } @@ -151,7 +157,7 @@ class PaymentPacketTestsCommon : LightningTestSuite() { @Test fun `build a command including the onion`() { - val (add, _) = OutgoingPacket.buildCommand(UUID.randomUUID(), paymentHash, hops, FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret)) + val (add, _) = OutgoingPaymentPacket.buildCommand(UUID.randomUUID(), paymentHash, hops, PaymentOnion.FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, null)) assertTrue(add.amount > finalAmount) assertEquals(add.cltvExpiry, finalExpiry + channelUpdateDE.cltvExpiryDelta + channelUpdateCD.cltvExpiryDelta + channelUpdateBC.cltvExpiryDelta) assertEquals(add.paymentHash, paymentHash) @@ -164,7 +170,7 @@ class PaymentPacketTestsCommon : LightningTestSuite() { @Test fun `build a command with no hops`() { val paymentSecret = randomBytes32() - val (add, _) = OutgoingPacket.buildCommand(UUID.randomUUID(), paymentHash, hops.take(1), FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret)) + val (add, _) = OutgoingPaymentPacket.buildCommand(UUID.randomUUID(), paymentHash, hops.take(1), PaymentOnion.FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, paymentMetadata)) assertEquals(add.amount, finalAmount) assertEquals(add.cltvExpiry, finalExpiry) assertEquals(add.paymentHash, paymentHash) @@ -172,11 +178,12 @@ class PaymentPacketTestsCommon : LightningTestSuite() { // let's peel the onion val addB = UpdateAddHtlc(randomBytes32(), 0, finalAmount, paymentHash, finalExpiry, add.onion) - val finalPayload = IncomingPacket.decrypt(addB, privB).right!! + val finalPayload = IncomingPaymentPacket.decrypt(addB, privB).right!! assertEquals(finalPayload.amount, finalAmount) assertEquals(finalPayload.totalAmount, finalAmount) assertEquals(finalPayload.expiry, finalExpiry) assertEquals(paymentSecret, finalPayload.paymentSecret) + assertEquals(paymentMetadata, finalPayload.paymentMetadata) } @Test @@ -186,19 +193,19 @@ class PaymentPacketTestsCommon : LightningTestSuite() { // / \ / \ // a -> b -> c d e - val (amountAC, expiryAC, trampolineOnion) = OutgoingPacket.buildPacket( + val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineHops, - FinalPayload.createMultiPartPayload(finalAmount, finalAmount * 3, finalExpiry, paymentSecret), + PaymentOnion.FinalPayload.createMultiPartPayload(finalAmount, finalAmount * 3, finalExpiry, paymentSecret, paymentMetadata), OnionRoutingPacket.TrampolinePacketLength ) assertEquals(amountBC, amountAC) assertEquals(expiryBC, expiryAC) - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket( + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineChannelHops, - FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), + PaymentOnion.FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), OnionRoutingPacket.PaymentPacketLength ) assertEquals(amountAB, firstAmount) @@ -206,7 +213,7 @@ class PaymentPacketTestsCommon : LightningTestSuite() { val addB = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) val (payloadB, packetC) = decryptChannelRelay(addB, privB) - assertEquals(ChannelRelayPayload.create(channelUpdateBC.shortChannelId, amountBC, expiryBC), payloadB) + assertEquals(PaymentOnion.ChannelRelayPayload.create(channelUpdateBC.shortChannelId, amountBC, expiryBC), payloadB) val addC = UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC) val (outerC, innerC, packetD) = decryptNodeRelay(addC, privC) @@ -221,10 +228,10 @@ class PaymentPacketTestsCommon : LightningTestSuite() { assertNull(innerC.paymentSecret) // c forwards the trampoline payment to d. - val (amountD, expiryD, onionD) = OutgoingPacket.buildPacket( + val (amountD, expiryD, onionD) = OutgoingPaymentPacket.buildPacket( paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), - FinalPayload.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), + PaymentOnion.FinalPayload.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), OnionRoutingPacket.PaymentPacketLength ) assertEquals(amountCD, amountD) @@ -242,17 +249,27 @@ class PaymentPacketTestsCommon : LightningTestSuite() { assertNull(innerD.paymentSecret) // d forwards the trampoline payment to e. - val (amountE, expiryE, onionE) = OutgoingPacket.buildPacket( + val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket( paymentHash, listOf(ChannelHop(d, e, channelUpdateDE)), - FinalPayload.createTrampolinePayload(amountDE, amountDE, expiryDE, randomBytes32(), packetE), + PaymentOnion.FinalPayload.createTrampolinePayload(amountDE, amountDE, expiryDE, randomBytes32(), packetE), OnionRoutingPacket.PaymentPacketLength ) assertEquals(amountDE, amountE) assertEquals(expiryDE, expiryE) val addE = UpdateAddHtlc(randomBytes32(), 4, amountE, paymentHash, expiryE, onionE.packet) - val payloadE = IncomingPacket.decrypt(addE, privE).right!! - assertEquals(payloadE, FinalPayload(TlvStream(listOf(OnionTlv.AmountToForward(finalAmount), OnionTlv.OutgoingCltv(finalExpiry), OnionTlv.PaymentData(paymentSecret, finalAmount * 3))))) + val payloadE = IncomingPaymentPacket.decrypt(addE, privE).right!! + val expectedFinalPayload = PaymentOnion.FinalPayload( + TlvStream( + listOf( + OnionPaymentPayloadTlv.AmountToForward(finalAmount), + OnionPaymentPayloadTlv.OutgoingCltv(finalExpiry), + OnionPaymentPayloadTlv.PaymentData(paymentSecret, finalAmount * 3), + OnionPaymentPayloadTlv.PaymentMetadata(paymentMetadata) + ) + ) + ) + assertEquals(payloadE, expectedFinalPayload) } @Test @@ -268,19 +285,20 @@ class PaymentPacketTestsCommon : LightningTestSuite() { "lnbcrt", finalAmount, currentTimestampSeconds(), e, listOf( PaymentRequest.TaggedField.PaymentHash(paymentHash), PaymentRequest.TaggedField.PaymentSecret(paymentSecret), + PaymentRequest.TaggedField.PaymentMetadata(paymentMetadata), PaymentRequest.TaggedField.DescriptionHash(randomBytes32()), PaymentRequest.TaggedField.Features(invoiceFeatures.toByteArray().toByteVector()), PaymentRequest.TaggedField.RoutingInfo(routingHints) ), ByteVector.empty ) - val (amountAC, expiryAC, trampolineOnion) = OutgoingPacket.buildTrampolineToLegacyPacket(invoice, trampolineHops, FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32())) + val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildTrampolineToLegacyPacket(invoice, trampolineHops, PaymentOnion.FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), null)) assertEquals(amountBC, amountAC) assertEquals(expiryBC, expiryAC) - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket( + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineChannelHops, - FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), + PaymentOnion.FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), OnionRoutingPacket.PaymentPacketLength ) assertEquals(amountAB, firstAmount) @@ -303,10 +321,10 @@ class PaymentPacketTestsCommon : LightningTestSuite() { assertNull(innerC.paymentSecret) // c forwards the trampoline payment to d. - val (amountD, expiryD, onionD) = OutgoingPacket.buildPacket( + val (amountD, expiryD, onionD) = OutgoingPaymentPacket.buildPacket( paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), - FinalPayload.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), + PaymentOnion.FinalPayload.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), OnionRoutingPacket.PaymentPacketLength ) assertEquals(amountCD, amountD) @@ -322,6 +340,7 @@ class PaymentPacketTestsCommon : LightningTestSuite() { assertEquals(e, innerD.outgoingNodeId) assertEquals(finalAmount, innerD.totalAmount) assertEquals(invoice.paymentSecret, innerD.paymentSecret) + assertEquals(invoice.paymentMetadata, innerD.paymentMetadata) assertEquals(ByteVector("024100"), innerD.invoiceFeatures) // var_onion_optin, payment_secret, basic_mpp assertEquals(listOf(routingHints), innerD.invoiceRoutingInfo) } @@ -340,134 +359,149 @@ class PaymentPacketTestsCommon : LightningTestSuite() { PaymentRequest.TaggedField.RoutingInfo(routingHintOverflow) ), ByteVector.empty ) - assertFails { OutgoingPacket.buildTrampolineToLegacyPacket(invoice, trampolineHops, FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32())) } + assertFails { OutgoingPaymentPacket.buildTrampolineToLegacyPacket(invoice, trampolineHops, PaymentOnion.FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), null)) } } @Test fun `fail to decrypt when the onion is invalid`() { - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket(paymentHash, hops, FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32()), OnionRoutingPacket.PaymentPacketLength) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket(paymentHash, hops, PaymentOnion.FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), null), OnionRoutingPacket.PaymentPacketLength) val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet.copy(payload = onion.packet.payload.reversed())) - val failure = IncomingPacket.decrypt(add, privB) + val failure = IncomingPaymentPacket.decrypt(add, privB) assertTrue(failure.isLeft) assertEquals(InvalidOnionHmac.code, failure.left!!.code) } @Test fun `fail to decrypt when the trampoline onion is invalid`() { - val (amountAC, expiryAC, trampolineOnion) = OutgoingPacket.buildPacket( + val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineHops, - FinalPayload.createMultiPartPayload(finalAmount, finalAmount * 2, finalExpiry, paymentSecret), + PaymentOnion.FinalPayload.createMultiPartPayload(finalAmount, finalAmount * 2, finalExpiry, paymentSecret, null), OnionRoutingPacket.TrampolinePacketLength ) - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket( + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineChannelHops, - FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet.copy(payload = trampolineOnion.packet.payload.reversed())), + PaymentOnion.FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet.copy(payload = trampolineOnion.packet.payload.reversed())), OnionRoutingPacket.PaymentPacketLength ) val addB = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) val (_, packetC) = decryptChannelRelay(addB, privB) val addC = UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC) - val failure = IncomingPacket.decrypt(addC, privC) + val failure = IncomingPaymentPacket.decrypt(addC, privC) assertTrue(failure.isLeft) assertEquals(InvalidOnionHmac.code, failure.left!!.code) } @Test fun `fail to decrypt when payment hash doesn't match associated data`() { - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket(paymentHash.reversed(), hops, FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32()), OnionRoutingPacket.PaymentPacketLength) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( + paymentHash.reversed(), + hops, + PaymentOnion.FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), paymentMetadata), + OnionRoutingPacket.PaymentPacketLength + ) val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) - val failure = IncomingPacket.decrypt(add, privB) + val failure = IncomingPaymentPacket.decrypt(add, privB) assertTrue(failure.isLeft) assertEquals(InvalidOnionHmac.code, failure.left!!.code) } @Test fun `fail to decrypt at the final node when amount has been modified by next-to-last node`() { - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket(paymentHash, hops.take(1), FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32()), OnionRoutingPacket.PaymentPacketLength) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( + paymentHash, + hops.take(1), + PaymentOnion.FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), paymentMetadata), + OnionRoutingPacket.PaymentPacketLength + ) val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount - 100.msat, paymentHash, firstExpiry, onion.packet) - val failure = IncomingPacket.decrypt(add, privB) + val failure = IncomingPaymentPacket.decrypt(add, privB) assertEquals(Either.Left(FinalIncorrectHtlcAmount(firstAmount - 100.msat)), failure) } @Test fun `fail to decrypt at the final node when expiry has been modified by next-to-last node`() { - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket(paymentHash, hops.take(1), FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32()), OnionRoutingPacket.PaymentPacketLength) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( + paymentHash, + hops.take(1), + PaymentOnion.FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), paymentMetadata), + OnionRoutingPacket.PaymentPacketLength + ) val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry - CltvExpiryDelta(12), onion.packet) - val failure = IncomingPacket.decrypt(add, privB) + val failure = IncomingPaymentPacket.decrypt(add, privB) assertEquals(Either.Left(FinalIncorrectCltvExpiry(firstExpiry - CltvExpiryDelta(12))), failure) } @Test fun `fail to decrypt at the final trampoline node when amount has been modified by next-to-last trampoline`() { - val (amountAC, expiryAC, trampolineOnion) = OutgoingPacket.buildPacket( + val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineHops, - FinalPayload.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret), + PaymentOnion.FinalPayload.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret, null), OnionRoutingPacket.TrampolinePacketLength ) - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket( + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineChannelHops, - FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), + PaymentOnion.FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), OnionRoutingPacket.PaymentPacketLength ) val (_, packetC) = decryptChannelRelay(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet), privB) val (_, _, packetD) = decryptNodeRelay(UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC), privC) // c forwards the trampoline payment to d. - val (amountD, expiryD, onionD) = OutgoingPacket.buildPacket( + val (amountD, expiryD, onionD) = OutgoingPaymentPacket.buildPacket( paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), - FinalPayload.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), + PaymentOnion.FinalPayload.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), OnionRoutingPacket.PaymentPacketLength ) val (_, _, packetE) = decryptNodeRelay(UpdateAddHtlc(randomBytes32(), 3, amountD, paymentHash, expiryD, onionD.packet), privD) // d forwards an invalid amount to e (the outer total amount doesn't match the inner amount). val invalidTotalAmount = amountDE + 100.msat - val (amountE, expiryE, onionE) = OutgoingPacket.buildPacket( + val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket( paymentHash, listOf(ChannelHop(d, e, channelUpdateDE)), - FinalPayload.createTrampolinePayload(amountDE, invalidTotalAmount, expiryDE, randomBytes32(), packetE), + PaymentOnion.FinalPayload.createTrampolinePayload(amountDE, invalidTotalAmount, expiryDE, randomBytes32(), packetE), OnionRoutingPacket.PaymentPacketLength ) - val failure = IncomingPacket.decrypt(UpdateAddHtlc(randomBytes32(), 4, amountE, paymentHash, expiryE, onionE.packet), privE) + val failure = IncomingPaymentPacket.decrypt(UpdateAddHtlc(randomBytes32(), 4, amountE, paymentHash, expiryE, onionE.packet), privE) assertEquals(Either.Left(FinalIncorrectHtlcAmount(invalidTotalAmount)), failure) } @Test fun `fail to decrypt at the final trampoline node when expiry has been modified by next-to-last trampoline`() { - val (amountAC, expiryAC, trampolineOnion) = OutgoingPacket.buildPacket( + val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineHops, - FinalPayload.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret), + PaymentOnion.FinalPayload.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret, null), OnionRoutingPacket.TrampolinePacketLength ) - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket( + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineChannelHops, - FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), + PaymentOnion.FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), OnionRoutingPacket.PaymentPacketLength ) val (_, packetC) = decryptChannelRelay(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet), privB) val (_, _, packetD) = decryptNodeRelay(UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC), privC) // c forwards the trampoline payment to d. - val (amountD, expiryD, onionD) = OutgoingPacket.buildPacket( + val (amountD, expiryD, onionD) = OutgoingPaymentPacket.buildPacket( paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), - FinalPayload.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), + PaymentOnion.FinalPayload.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), OnionRoutingPacket.PaymentPacketLength ) val (_, _, packetE) = decryptNodeRelay(UpdateAddHtlc(randomBytes32(), 3, amountD, paymentHash, expiryD, onionD.packet), privD) // d forwards an invalid expiry to e (the outer expiry doesn't match the inner expiry). val invalidExpiry = expiryDE - CltvExpiryDelta(12) - val (amountE, expiryE, onionE) = OutgoingPacket.buildPacket( + val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket( paymentHash, listOf(ChannelHop(d, e, channelUpdateDE)), - FinalPayload.createTrampolinePayload(amountDE, amountDE, invalidExpiry, randomBytes32(), packetE), + PaymentOnion.FinalPayload.createTrampolinePayload(amountDE, amountDE, invalidExpiry, randomBytes32(), packetE), OnionRoutingPacket.PaymentPacketLength ) - val failure = IncomingPacket.decrypt(UpdateAddHtlc(randomBytes32(), 4, amountE, paymentHash, expiryE, onionE.packet), privE) + val failure = IncomingPaymentPacket.decrypt(UpdateAddHtlc(randomBytes32(), 4, amountE, paymentHash, expiryE, onionE.packet), privE) assertEquals(Either.Left(FinalIncorrectCltvExpiry(invalidExpiry)), failure) } } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentRequestTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentRequestTestsCommon.kt index a300b8ff1..3c69d70d7 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentRequestTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentRequestTestsCommon.kt @@ -3,6 +3,7 @@ package fr.acinq.lightning.payment import fr.acinq.bitcoin.* import fr.acinq.lightning.* import fr.acinq.lightning.Lightning.randomBytes32 +import fr.acinq.lightning.Lightning.randomKey import fr.acinq.lightning.tests.utils.LightningTestSuite import fr.acinq.lightning.utils.* import fr.acinq.secp256k1.Hex @@ -362,6 +363,24 @@ class PaymentRequestTestsCommon : LightningTestSuite() { assertEquals(ref, check) } + @Test + fun `On mainnet, please send 0,01 BTC with payment metadata 0x01fafaf0`() { + val ref = "lnbc10m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdp9wpshjmt9de6zqmt9w3skgct5vysxjmnnd9jx2mq8q8a04uqsp5zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zygs9q2gqqqqqqsgq7hf8he7ecf7n4ffphs6awl9t6676rrclv9ckg3d3ncn7fct63p6s365duk5wrk202cfy3aj5xnnp5gs3vrdvruverwwq7yzhkf5a3xqpd05wjc" + val pr = PaymentRequest.read(ref) + assertEquals(pr.prefix, "lnbc") + assertEquals(pr.amount, MilliSatoshi(1000000000)) + assertEquals(pr.paymentHash, ByteVector32("0001020304050607080900010203040506070809000102030405060708090102")) + assertEquals(pr.features, Features(Feature.VariableLengthOnion to FeatureSupport.Mandatory, Feature.PaymentSecret to FeatureSupport.Mandatory, Feature.PaymentMetadata to FeatureSupport.Mandatory).toByteArray().toByteVector()) + assertEquals(pr.timestampSeconds, 1496314658L) + assertEquals(pr.nodeId, PublicKey.fromHex("03e7156ae33b0a208d0744199163177e909e80176e55d97a2f221ede0f934dd9ad")) + assertEquals(pr.paymentSecret, ByteVector32("1111111111111111111111111111111111111111111111111111111111111111")) + assertEquals(pr.description, "payment metadata inside") + assertEquals(pr.paymentMetadata, ByteVector("01fafaf0")) + assertEquals(pr.tags.size, 5) + val check = pr.sign(priv).write() + assertEquals(ref, check) + } + @Test fun `reject invalid invoices`() { val refs = listOf( @@ -408,6 +427,16 @@ class PaymentRequestTestsCommon : LightningTestSuite() { assertEquals(21, unknownTag!!.tag) } + @Test + fun `filter non-invoice features`() { + val nodeFeatures = Features( + mapOf(Feature.VariableLengthOnion to FeatureSupport.Mandatory, Feature.PaymentSecret to FeatureSupport.Mandatory, Feature.ShutdownAnySegwit to FeatureSupport.Optional), + setOf(UnknownFeature(103), UnknownFeature(256)) + ) + val pr = PaymentRequest.create(Block.LivenetGenesisBlock.hash, 500.msat, randomBytes32(), randomKey(), "non-invoice features", CltvExpiryDelta(6), nodeFeatures) + assertEquals(Features(Feature.VariableLengthOnion to FeatureSupport.Mandatory, Feature.PaymentSecret to FeatureSupport.Mandatory), Features(pr.features)) + } + @Test fun `feature bits to minimally-encoded feature bytes`() { val testCases = listOf( @@ -436,8 +465,7 @@ class PaymentRequestTestsCommon : LightningTestSuite() { @Test fun `payment secret`() { - val features = - Features(Feature.VariableLengthOnion to FeatureSupport.Mandatory, Feature.PaymentSecret to FeatureSupport.Mandatory, Feature.BasicMultiPartPayment to FeatureSupport.Optional) + val features = Features(Feature.VariableLengthOnion to FeatureSupport.Mandatory, Feature.PaymentSecret to FeatureSupport.Mandatory, Feature.BasicMultiPartPayment to FeatureSupport.Optional) val pr = PaymentRequest.create(Block.LivenetGenesisBlock.hash, 123.msat, ByteVector32.One, priv, "Some invoice", CltvExpiryDelta(18), features) assertNotNull(pr.paymentSecret) assertEquals(ByteVector("024100"), pr.features) @@ -465,28 +493,4 @@ class PaymentRequestTestsCommon : LightningTestSuite() { } } - @Test - fun filterFeatures() { - assertEquals( - expected = PaymentRequest.invoiceFeatures( - Features( - activated = mapOf( - Feature.InitialRoutingSync to FeatureSupport.Optional, - Feature.StaticRemoteKey to FeatureSupport.Mandatory, - Feature.PaymentSecret to FeatureSupport.Mandatory, - Feature.TrampolinePayment to FeatureSupport.Optional, - ), - unknown = setOf( - UnknownFeature(47) - ) - ) - ), - actual = Features( - activated = mapOf( - Feature.PaymentSecret to FeatureSupport.Mandatory, - Feature.TrampolinePayment to FeatureSupport.Optional - ) - ) - ) - } } \ No newline at end of file diff --git a/src/commonTest/kotlin/fr/acinq/lightning/wire/OnionTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/wire/PaymentOnionTestsCommon.kt similarity index 76% rename from src/commonTest/kotlin/fr/acinq/lightning/wire/OnionTestsCommon.kt rename to src/commonTest/kotlin/fr/acinq/lightning/wire/PaymentOnionTestsCommon.kt index 5e95bf964..d73185a1e 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/wire/OnionTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/wire/PaymentOnionTestsCommon.kt @@ -17,7 +17,7 @@ import kotlin.test.assertEquals import kotlin.test.assertFails import kotlin.test.assertNull -class OnionTestsCommon : LightningTestSuite() { +class PaymentOnionTestsCommon : LightningTestSuite() { @Test fun `encode - decode onion packet`() { val bin = Hex.decode( @@ -40,14 +40,14 @@ class OnionTestsCommon : LightningTestSuite() { @Test fun `encode - decode channel relay per-hop payload`() { val testCases = mapOf( - ChannelRelayPayload.create(ShortChannelId(0), MilliSatoshi(0), CltvExpiry(0)) to Hex.decode("0e 0200 0400 06080000000000000000"), - ChannelRelayPayload.create(ShortChannelId(42), MilliSatoshi(142000), CltvExpiry(500000)) to Hex.decode("14 0203022ab0 040307a120 0608000000000000002a"), - ChannelRelayPayload.create(ShortChannelId(561), MilliSatoshi(1105), CltvExpiry(1729)) to Hex.decode("12 02020451 040206c1 06080000000000000231") + PaymentOnion.ChannelRelayPayload.create(ShortChannelId(0), MilliSatoshi(0), CltvExpiry(0)) to Hex.decode("0e 0200 0400 06080000000000000000"), + PaymentOnion.ChannelRelayPayload.create(ShortChannelId(42), MilliSatoshi(142000), CltvExpiry(500000)) to Hex.decode("14 0203022ab0 040307a120 0608000000000000002a"), + PaymentOnion.ChannelRelayPayload.create(ShortChannelId(561), MilliSatoshi(1105), CltvExpiry(1729)) to Hex.decode("12 02020451 040206c1 06080000000000000231") ) testCases.forEach { val expected = it.key - val decoded = ChannelRelayPayload.read(it.value) + val decoded = PaymentOnion.ChannelRelayPayload.read(it.value) assertEquals(expected, decoded) val encoded = decoded.write() assertArrayEquals(it.value, encoded) @@ -57,10 +57,10 @@ class OnionTestsCommon : LightningTestSuite() { @Test fun `encode - decode variable-length (tlv) node relay per-hop payload`() { val nodeId = PublicKey(Hex.decode("02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619")) - val expected = NodeRelayPayload(TlvStream(listOf(OnionTlv.AmountToForward(561.msat), OnionTlv.OutgoingCltv(CltvExpiry(42)), OnionTlv.OutgoingNodeId(nodeId)))) + val expected = PaymentOnion.NodeRelayPayload(TlvStream(listOf(OnionPaymentPayloadTlv.AmountToForward(561.msat), OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), OnionPaymentPayloadTlv.OutgoingNodeId(nodeId)))) val bin = Hex.decode("2e 02020231 04012a fe000102322102eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619") - val decoded = NodeRelayPayload.read(bin) + val decoded = PaymentOnion.NodeRelayPayload.read(bin) assertEquals(expected, decoded) assertEquals(decoded.amountToForward, 561.msat) assertEquals(decoded.totalAmount, 561.msat) @@ -85,22 +85,22 @@ class OnionTestsCommon : LightningTestSuite() { listOf(PaymentRequest.TaggedField.ExtraHop(node1, ShortChannelId(1), 10.msat, 100, CltvExpiryDelta(144))), listOf(PaymentRequest.TaggedField.ExtraHop(node2, ShortChannelId(2), 20.msat, 150, CltvExpiryDelta(12)), PaymentRequest.TaggedField.ExtraHop(node3, ShortChannelId(3), 30.msat, 200, CltvExpiryDelta(24))) ) - val expected = NodeRelayPayload( + val expected = PaymentOnion.NodeRelayPayload( TlvStream( listOf( - OnionTlv.AmountToForward(561.msat), - OnionTlv.OutgoingCltv(CltvExpiry(42)), - OnionTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 1105.msat), - OnionTlv.InvoiceFeatures(features), - OnionTlv.OutgoingNodeId(nodeId), - OnionTlv.InvoiceRoutingInfo(routingHints) + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 1105.msat), + OnionPaymentPayloadTlv.InvoiceFeatures(features), + OnionPaymentPayloadTlv.OutgoingNodeId(nodeId), + OnionPaymentPayloadTlv.InvoiceRoutingInfo(routingHints) ) ) ) val bin = Hex.decode("fa 02020231 04012a 0822eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f2836866190451 fe00010231010a fe000102322102eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619 fe000102339b01036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e200000000000000010000000a00000064009002025f7117a78150fe2ef97db7cfc83bd57b2e2c0d0dd25eaf467a4a1c2a45ce148600000000000000020000001400000096000c02a051267759c3a149e3e72372f4e0c4054ba597ebfd0eda78a2273023667205ee00000000000000030000001e000000c80018") - val decoded = NodeRelayPayload.read(bin) + val decoded = PaymentOnion.NodeRelayPayload.read(bin) assertEquals(decoded, expected) assertEquals(decoded.amountToForward, 561.msat) assertEquals(decoded.totalAmount, 1105.msat) @@ -119,54 +119,61 @@ class OnionTestsCommon : LightningTestSuite() { val testCases = mapOf( TlvStream( listOf( - OnionTlv.AmountToForward(561.msat), - OnionTlv.OutgoingCltv(CltvExpiry(42)), - OnionTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 0.msat) + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 0.msat) ) ) to Hex.decode("29 02020231 04012a 0820eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), TlvStream( listOf( - OnionTlv.AmountToForward(561.msat), - OnionTlv.OutgoingCltv(CltvExpiry(42)), - OnionTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 1105.msat) + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 1105.msat) ) ) to Hex.decode("2b 02020231 04012a 0822eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f2836866190451"), TlvStream( listOf( - OnionTlv.AmountToForward(561.msat), - OnionTlv.OutgoingCltv(CltvExpiry(42)), - OnionTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 4294967295L.msat) + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 4294967295L.msat) ) ) to Hex.decode("2d 02020231 04012a 0824eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619ffffffff"), TlvStream( listOf( - OnionTlv.AmountToForward(561.msat), - OnionTlv.OutgoingCltv(CltvExpiry(42)), - OnionTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 4294967296L.msat) + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 4294967296L.msat) ) ) to Hex.decode("2e 02020231 04012a 0825eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f2836866190100000000"), TlvStream( listOf( - OnionTlv.AmountToForward(561.msat), - OnionTlv.OutgoingCltv(CltvExpiry(42)), - OnionTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 1099511627775L.msat) + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 1099511627775L.msat) ) ) to Hex.decode("2e 02020231 04012a 0825eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619ffffffffff"), TlvStream( listOf( - OnionTlv.AmountToForward(561.msat), - OnionTlv.OutgoingCltv(CltvExpiry(42)), - OnionTlv.OutgoingChannelId(ShortChannelId(1105)), - OnionTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 0.msat) + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.OutgoingChannelId(ShortChannelId(1105)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 0.msat) ) ) to Hex.decode("33 02020231 04012a 06080000000000000451 0820eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), TlvStream( - listOf(OnionTlv.AmountToForward(561.msat), OnionTlv.OutgoingCltv(CltvExpiry(42)), OnionTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 0.msat)), + listOf( + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 0.msat) + ), listOf(GenericTlv(65535, ByteVector("06c1"))) ) to Hex.decode("2f 02020231 04012a 0820eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619 fdffff0206c1"), TlvStream( listOf( - OnionTlv.AmountToForward(561.msat), OnionTlv.OutgoingCltv(CltvExpiry(42)), OnionTlv.PaymentData(ByteVector32("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), 0.msat), OnionTlv.TrampolineOnion( + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), 0.msat), + OnionPaymentPayloadTlv.TrampolineOnion( OnionRoutingPacket( 0, ByteVector("02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), @@ -182,12 +189,12 @@ class OnionTestsCommon : LightningTestSuite() { testCases.forEach { val expected = it.key - val decoded = FinalPayload.read(it.value) - assertEquals(decoded, FinalPayload(expected)) + val decoded = PaymentOnion.FinalPayload.read(it.value) + assertEquals(decoded, PaymentOnion.FinalPayload(expected)) assertEquals(decoded.amount, 561.msat) assertEquals(decoded.expiry, CltvExpiry(42)) - val encoded = FinalPayload(expected).write() + val encoded = PaymentOnion.FinalPayload(expected).write() assertArrayEquals(it.value, encoded) } } @@ -195,29 +202,29 @@ class OnionTestsCommon : LightningTestSuite() { @Test fun `encode - decode variable-length (tlv) final per-hop payload with custom user records`() { val tlvs = TlvStream( - listOf(OnionTlv.AmountToForward(561.msat), OnionTlv.OutgoingCltv(CltvExpiry(42)), OnionTlv.PaymentData(ByteVector32.Zeroes, 0.msat)), + listOf(OnionPaymentPayloadTlv.AmountToForward(561.msat), OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), OnionPaymentPayloadTlv.PaymentData(ByteVector32.Zeroes, 0.msat)), listOf(GenericTlv(5432123457L, ByteVector("16c7ec71663784ff100b6eface1e60a97b92ea9d18b8ece5e558586bc7453828"))) ) val bin = Hex.decode("53 02020231 04012a 08200000000000000000000000000000000000000000000000000000000000000000 ff0000000143c7a0412016c7ec71663784ff100b6eface1e60a97b92ea9d18b8ece5e558586bc7453828") - val encoded = FinalPayload(tlvs).write() + val encoded = PaymentOnion.FinalPayload(tlvs).write() assertArrayEquals(bin, encoded) - assertEquals(FinalPayload(tlvs), FinalPayload.read(bin)) + assertEquals(PaymentOnion.FinalPayload(tlvs), PaymentOnion.FinalPayload.read(bin)) } @Test fun `decode multi-part final per-hop payload`() { - val notMultiPart = FinalPayload.read(Hex.decode("29 02020231 04012a 08200000000000000000000000000000000000000000000000000000000000000000")) + val notMultiPart = PaymentOnion.FinalPayload.read(Hex.decode("29 02020231 04012a 08200000000000000000000000000000000000000000000000000000000000000000")) assertEquals(notMultiPart.totalAmount, 561.msat) assertEquals(notMultiPart.paymentSecret, ByteVector32.Zeroes) - val multiPart = FinalPayload.read(Hex.decode("2b 02020231 04012a 0822eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f2836866190451")) + val multiPart = PaymentOnion.FinalPayload.read(Hex.decode("2b 02020231 04012a 0822eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f2836866190451")) assertEquals(multiPart.amount, 561.msat) assertEquals(multiPart.expiry, CltvExpiry(42)) assertEquals(multiPart.totalAmount, 1105.msat) assertEquals(multiPart.paymentSecret, ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619")) - val multiPartNoTotalAmount = FinalPayload.read(Hex.decode("29 02020231 04012a 0820eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619")) + val multiPartNoTotalAmount = PaymentOnion.FinalPayload.read(Hex.decode("29 02020231 04012a 0820eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619")) assertEquals(multiPartNoTotalAmount.amount, 561.msat) assertEquals(multiPartNoTotalAmount.expiry, CltvExpiry(42)) assertEquals(multiPartNoTotalAmount.totalAmount, 561.msat) @@ -233,7 +240,7 @@ class OnionTestsCommon : LightningTestSuite() { ) testCases.forEach { - assertFails { FinalPayload.read(it) } + assertFails { PaymentOnion.FinalPayload.read(it) } } } @@ -254,9 +261,9 @@ class OnionTestsCommon : LightningTestSuite() { ) testCases.forEach { - assertFails { ChannelRelayPayload.read(it) } - assertFails { NodeRelayPayload.read(it) } - assertFails { FinalPayload.read(it) } + assertFails { PaymentOnion.ChannelRelayPayload.read(it) } + assertFails { PaymentOnion.NodeRelayPayload.read(it) } + assertFails { PaymentOnion.FinalPayload.read(it) } } } } \ No newline at end of file diff --git a/src/jvmTest/kotlin/fr/acinq/lightning/Node.kt b/src/jvmTest/kotlin/fr/acinq/lightning/Node.kt index ba9e255c8..7ed7dbf98 100644 --- a/src/jvmTest/kotlin/fr/acinq/lightning/Node.kt +++ b/src/jvmTest/kotlin/fr/acinq/lightning/Node.kt @@ -148,6 +148,7 @@ object Node { Feature.StaticRemoteKey to FeatureSupport.Optional, Feature.AnchorOutputs to FeatureSupport.Optional, Feature.ChannelType to FeatureSupport.Mandatory, + Feature.PaymentMetadata to FeatureSupport.Optional, Feature.TrampolinePayment to FeatureSupport.Optional, Feature.ZeroReserveChannels to FeatureSupport.Optional, Feature.ZeroConfChannels to FeatureSupport.Optional,