From 363ed0646c221d0b7283143510fc33218c177a52 Mon Sep 17 00:00:00 2001 From: Bastien Teinturier <31281497+t-bast@users.noreply.github.com> Date: Tue, 8 Feb 2022 15:03:58 +0100 Subject: [PATCH] Add support for `option_payment_metadata` (#313) * Filter init, node and invoice features We should explicitly filter features based on where they can be included (`init`, `node_announcement` or `invoice`) as specified in Bolt 9. We also introduce the option_payment_metadata feature which helps our test cases since it's only allowed in invoices. * Refactor onion to dedicated namespace This commit doesn't contain any logic, it simply prefixes some classes to make it obvious that they are payment-related, rename files and moves some classes. We will update the payment onion, so it was a good time to do this small refactoring which will also be necessary for onion messages. * Add support for option_payment_metadata Add support for https://github.com/lightning/bolts/pull/912 Whenever we find a payment metadata field in an invoice, we send it in the onion payload for the final recipient. We include a payment metadata in every invoice we generate. This lets us see whether our payers support it or not, which is important data to have before we make it mandatory and use it for storage-less invoices. --- .../fr/acinq/lightning/tests/TestConstants.kt | 2 + .../acinq/lightning/tests/io/peer/builders.kt | 6 +- .../kotlin/fr/acinq/lightning/Features.kt | 48 ++- .../fr/acinq/lightning/channel/Commitments.kt | 4 +- .../kotlin/fr/acinq/lightning/io/Peer.kt | 2 +- .../payment/IncomingPaymentHandler.kt | 31 +- ...mingPacket.kt => IncomingPaymentPacket.kt} | 16 +- .../payment/OutgoingPaymentHandler.kt | 10 +- ...oingPacket.kt => OutgoingPaymentPacket.kt} | 38 +- .../acinq/lightning/payment/PaymentRequest.kt | 40 +-- .../serialization/v1/Serialization.kt | 17 +- .../serialization/v2/Serialization.kt | 17 +- .../serialization/v3/Serialization.kt | 17 +- .../kotlin/fr/acinq/lightning/wire/Onion.kt | 333 ----------------- .../fr/acinq/lightning/wire/OnionRouting.kt | 55 +++ .../fr/acinq/lightning/wire/PaymentOnion.kt | 338 ++++++++++++++++++ .../fr/acinq/lightning/FeaturesTestsCommon.kt | 31 ++ .../fr/acinq/lightning/channel/TestsHelper.kt | 4 +- .../IncomingPaymentHandlerTestsCommon.kt | 37 +- .../OutgoingPaymentHandlerTestsCommon.kt | 8 +- .../payment/PaymentPacketTestsCommon.kt | 156 ++++---- .../payment/PaymentRequestTestsCommon.kt | 56 +-- ...tsCommon.kt => PaymentOnionTestsCommon.kt} | 105 +++--- src/jvmTest/kotlin/fr/acinq/lightning/Node.kt | 1 + 24 files changed, 785 insertions(+), 587 deletions(-) rename src/commonMain/kotlin/fr/acinq/lightning/payment/{IncomingPacket.kt => IncomingPaymentPacket.kt} (81%) rename src/commonMain/kotlin/fr/acinq/lightning/payment/{OutgoingPacket.kt => OutgoingPaymentPacket.kt} (75%) delete mode 100644 src/commonMain/kotlin/fr/acinq/lightning/wire/Onion.kt create mode 100644 src/commonMain/kotlin/fr/acinq/lightning/wire/OnionRouting.kt create mode 100644 src/commonMain/kotlin/fr/acinq/lightning/wire/PaymentOnion.kt rename src/commonTest/kotlin/fr/acinq/lightning/wire/{OnionTestsCommon.kt => PaymentOnionTestsCommon.kt} (76%) diff --git a/lightning-kmp-test-fixtures/src/commonMain/kotlin/fr/acinq/lightning/tests/TestConstants.kt b/lightning-kmp-test-fixtures/src/commonMain/kotlin/fr/acinq/lightning/tests/TestConstants.kt index 62324f21a..97c93ed73 100644 --- a/lightning-kmp-test-fixtures/src/commonMain/kotlin/fr/acinq/lightning/tests/TestConstants.kt +++ b/lightning-kmp-test-fixtures/src/commonMain/kotlin/fr/acinq/lightning/tests/TestConstants.kt @@ -54,6 +54,7 @@ object TestConstants { Feature.StaticRemoteKey to FeatureSupport.Mandatory, Feature.AnchorOutputs to FeatureSupport.Mandatory, Feature.ChannelType to FeatureSupport.Mandatory, + Feature.PaymentMetadata to FeatureSupport.Optional, Feature.TrampolinePayment to FeatureSupport.Optional, Feature.WakeUpNotificationProvider to FeatureSupport.Optional, Feature.PayToOpenProvider to FeatureSupport.Optional, @@ -129,6 +130,7 @@ object TestConstants { Feature.StaticRemoteKey to FeatureSupport.Mandatory, Feature.AnchorOutputs to FeatureSupport.Mandatory, Feature.ChannelType to FeatureSupport.Mandatory, + Feature.PaymentMetadata to FeatureSupport.Optional, Feature.TrampolinePayment to FeatureSupport.Optional, Feature.WakeUpNotificationClient to FeatureSupport.Optional, Feature.PayToOpenClient to FeatureSupport.Optional, diff --git a/lightning-kmp-test-fixtures/src/commonMain/kotlin/fr/acinq/lightning/tests/io/peer/builders.kt b/lightning-kmp-test-fixtures/src/commonMain/kotlin/fr/acinq/lightning/tests/io/peer/builders.kt index a548eb8e4..01ca286a3 100644 --- a/lightning-kmp-test-fixtures/src/commonMain/kotlin/fr/acinq/lightning/tests/io/peer/builders.kt +++ b/lightning-kmp-test-fixtures/src/commonMain/kotlin/fr/acinq/lightning/tests/io/peer/builders.kt @@ -65,9 +65,9 @@ public suspend fun newPeers( } // Initialize Bob with Alice's features - bob.send(BytesReceived(LightningMessage.encode(Init(features = nodeParams.first.features.toByteArray().toByteVector())))) + bob.send(BytesReceived(LightningMessage.encode(Init(features = nodeParams.first.features.initFeatures().toByteArray().toByteVector())))) // Initialize Alice with Bob's features - alice.send(BytesReceived(LightningMessage.encode(Init(features = nodeParams.second.features.toByteArray().toByteVector())))) + alice.send(BytesReceived(LightningMessage.encode(Init(features = nodeParams.second.features.initFeatures().toByteArray().toByteVector())))) // TODO update to depend on the initChannels size if (initChannels.isNotEmpty()) { @@ -124,7 +124,7 @@ public suspend fun CoroutineScope.newPeer( remotedNodeChannelState?.let { state -> // send Init from remote node - val theirInit = Init(features = state.staticParams.nodeParams.features.toByteArray().toByteVector()) + val theirInit = Init(features = state.staticParams.nodeParams.features.initFeatures().toByteArray().toByteVector()) val initMsg = LightningMessage.encode(theirInit) peer.send(BytesReceived(initMsg)) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/Features.kt b/src/commonMain/kotlin/fr/acinq/lightning/Features.kt index 213edc184..1a13ce546 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/Features.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/Features.kt @@ -6,6 +6,9 @@ import fr.acinq.lightning.utils.leftPaddedCopyOf import fr.acinq.lightning.utils.or import kotlinx.serialization.Serializable +/** Feature scope as defined in Bolt 9. */ +enum class FeatureScope { Init, Node, Invoice } + enum class FeatureSupport { Mandatory { override fun toString() = "mandatory" @@ -20,6 +23,7 @@ sealed class Feature { abstract val rfcName: String abstract val mandatory: Int + abstract val scopes: Set val optional: Int get() = mandatory + 1 fun supportBit(support: FeatureSupport): Int = when (support) { @@ -33,6 +37,7 @@ sealed class Feature { object OptionDataLossProtect : Feature() { override val rfcName get() = "option_data_loss_protect" override val mandatory get() = 0 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } @Serializable @@ -41,66 +46,84 @@ sealed class Feature { // reserved but not used as per lightningnetwork/lightning-rfc/pull/178 override val mandatory get() = 2 + override val scopes: Set get() = setOf(FeatureScope.Init) } @Serializable object ChannelRangeQueries : Feature() { override val rfcName get() = "gossip_queries" override val mandatory get() = 6 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } @Serializable object VariableLengthOnion : Feature() { override val rfcName get() = "var_onion_optin" override val mandatory get() = 8 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node, FeatureScope.Invoice) } @Serializable object ChannelRangeQueriesExtended : Feature() { override val rfcName get() = "gossip_queries_ex" override val mandatory get() = 10 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } @Serializable object StaticRemoteKey : Feature() { override val rfcName get() = "option_static_remotekey" override val mandatory get() = 12 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } @Serializable object PaymentSecret : Feature() { override val rfcName get() = "payment_secret" override val mandatory get() = 14 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node, FeatureScope.Invoice) } @Serializable object BasicMultiPartPayment : Feature() { override val rfcName get() = "basic_mpp" override val mandatory get() = 16 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node, FeatureScope.Invoice) } @Serializable object Wumbo : Feature() { override val rfcName get() = "option_support_large_channel" override val mandatory get() = 18 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } @Serializable object AnchorOutputs : Feature() { override val rfcName get() = "option_anchor_outputs" override val mandatory get() = 20 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } @Serializable object ShutdownAnySegwit : Feature() { override val rfcName get() = "option_shutdown_anysegwit" override val mandatory get() = 26 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } @Serializable object ChannelType : Feature() { override val rfcName get() = "option_channel_type" override val mandatory get() = 44 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) + } + + @Serializable + object PaymentMetadata : Feature() { + override val rfcName get() = "option_payment_metadata" + override val mandatory get() = 48 + override val scopes: Set get() = setOf(FeatureScope.Invoice) } // The following features have not been standardised, hence the high feature bits to avoid conflicts. @@ -109,6 +132,7 @@ sealed class Feature { object TrampolinePayment : Feature() { override val rfcName get() = "trampoline_payment" override val mandatory get() = 50 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node, FeatureScope.Invoice) } /** This feature bit should be activated when a node accepts having their channel reserve set to 0. */ @@ -116,6 +140,7 @@ sealed class Feature { object ZeroReserveChannels : Feature() { override val rfcName get() = "zero_reserve_channels" override val mandatory get() = 128 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } /** This feature bit should be activated when a node accepts unconfirmed channels (will set min_depth to 0 in accept_channel). */ @@ -123,6 +148,7 @@ sealed class Feature { object ZeroConfChannels : Feature() { override val rfcName get() = "zero_conf_channels" override val mandatory get() = 130 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } /** This feature bit should be activated when a mobile node supports waking up via push notifications. */ @@ -130,6 +156,7 @@ sealed class Feature { object WakeUpNotificationClient : Feature() { override val rfcName get() = "wake_up_notification_client" override val mandatory get() = 132 + override val scopes: Set get() = setOf(FeatureScope.Init) } /** This feature bit should be activated when a node supports waking up their peers via push notifications. */ @@ -137,6 +164,7 @@ sealed class Feature { object WakeUpNotificationProvider : Feature() { override val rfcName get() = "wake_up_notification_provider" override val mandatory get() = 134 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } /** This feature bit should be activated when a node accepts on-the-fly channel creation. */ @@ -144,6 +172,7 @@ sealed class Feature { object PayToOpenClient : Feature() { override val rfcName get() = "pay_to_open_client" override val mandatory get() = 136 + override val scopes: Set get() = setOf(FeatureScope.Init) } /** This feature bit should be activated when a node supports opening channels on-the-fly when liquidity is missing to receive a payment. */ @@ -151,6 +180,7 @@ sealed class Feature { object PayToOpenProvider : Feature() { override val rfcName get() = "pay_to_open_provider" override val mandatory get() = 138 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } /** This feature bit should be activated when a node accepts channel creation via trusted swaps-in. */ @@ -158,6 +188,7 @@ sealed class Feature { object TrustedSwapInClient : Feature() { override val rfcName get() = "trusted_swap_in_client" override val mandatory get() = 140 + override val scopes: Set get() = setOf(FeatureScope.Init) } /** This feature bit should be activated when a node supports opening channels in exchange for on-chain funds (swap-in). */ @@ -165,6 +196,7 @@ sealed class Feature { object TrustedSwapInProvider : Feature() { override val rfcName get() = "trusted_swap_in_provider" override val mandatory get() = 142 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } /** This feature bit should be activated when a node wants to send channel backups to their peers. */ @@ -172,6 +204,7 @@ sealed class Feature { object ChannelBackupClient : Feature() { override val rfcName get() = "channel_backup_client" override val mandatory get() = 144 + override val scopes: Set get() = setOf(FeatureScope.Init) } /** This feature bit should be activated when a node stores channel backups for their peers. */ @@ -179,6 +212,7 @@ sealed class Feature { object ChannelBackupProvider : Feature() { override val rfcName get() = "channel_backup_provider" override val mandatory get() = 146 + override val scopes: Set get() = setOf(FeatureScope.Init, FeatureScope.Node) } } @@ -188,9 +222,16 @@ data class UnknownFeature(val bitIndex: Int) @Serializable data class Features(val activated: Map, val unknown: Set = emptySet()) { - fun hasFeature(feature: Feature, support: FeatureSupport? = null): Boolean = - if (support != null) activated[feature] == support - else activated.containsKey(feature) + fun hasFeature(feature: Feature, support: FeatureSupport? = null): Boolean = when (support) { + null -> activated.containsKey(feature) + else -> activated[feature] == support + } + + fun initFeatures(): Features = Features(activated.filter { it.key.scopes.contains(FeatureScope.Init) }, unknown) + + fun nodeAnnouncementFeatures(): Features = Features(activated.filter { it.key.scopes.contains(FeatureScope.Node) }, unknown) + + fun invoiceFeatures(): Features = Features(activated.filter { it.key.scopes.contains(FeatureScope.Invoice) }, unknown) /** NB: this method is not reflexive, see [[Features.areCompatible]] if you want symmetric validation. */ fun areSupported(remoteFeatures: Features): Boolean { @@ -236,6 +277,7 @@ data class Features(val activated: Map, val unknown: Se Feature.AnchorOutputs, Feature.ShutdownAnySegwit, Feature.ChannelType, + Feature.PaymentMetadata, Feature.TrampolinePayment, Feature.ZeroReserveChannels, Feature.ZeroConfChannels, diff --git a/src/commonMain/kotlin/fr/acinq/lightning/channel/Commitments.kt b/src/commonMain/kotlin/fr/acinq/lightning/channel/Commitments.kt index 703d25b81..e2605756e 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/channel/Commitments.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/channel/Commitments.kt @@ -11,7 +11,7 @@ import fr.acinq.lightning.blockchain.fee.FeerateTolerance import fr.acinq.lightning.crypto.Generators import fr.acinq.lightning.crypto.KeyManager import fr.acinq.lightning.crypto.ShaChain -import fr.acinq.lightning.payment.OutgoingPacket +import fr.acinq.lightning.payment.OutgoingPaymentPacket import fr.acinq.lightning.transactions.CommitmentSpec import fr.acinq.lightning.transactions.Transactions import fr.acinq.lightning.transactions.Transactions.TransactionWithInputInfo.CommitTx @@ -370,7 +370,7 @@ data class Commitments( // we have already sent a fail/fulfill for this htlc alreadyProposed(localChanges.proposed, htlc.id) -> Either.Left(UnknownHtlcId(channelId, cmd.id)) else -> { - when (val result = OutgoingPacket.buildHtlcFailure(nodeSecret, htlc.paymentHash, htlc.onionRoutingPacket, cmd.reason)) { + when (val result = OutgoingPaymentPacket.buildHtlcFailure(nodeSecret, htlc.paymentHash, htlc.onionRoutingPacket, cmd.reason)) { is Either.Right -> { val fail = UpdateFailHtlc(channelId, cmd.id, result.value) val commitments1 = addLocalProposal(fail) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt b/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt index 76f76817d..ff6df1c2a 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/io/Peer.kt @@ -115,7 +115,7 @@ class Peer( private val features = nodeParams.features - private val ourInit = Init(features.toByteArray().toByteVector()) + private val ourInit = Init(features.initFeatures().toByteArray().toByteVector()) private var theirInit: Init? = null public val currentTipFlow = MutableStateFlow?>(null) diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandler.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandler.kt index 4e51d710d..1689a12c0 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandler.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandler.kt @@ -1,10 +1,15 @@ package fr.acinq.lightning.payment +import fr.acinq.bitcoin.ByteVector import fr.acinq.bitcoin.ByteVector32 import fr.acinq.bitcoin.Crypto import fr.acinq.bitcoin.PrivateKey -import fr.acinq.lightning.* +import fr.acinq.lightning.CltvExpiry +import fr.acinq.lightning.Lightning.randomBytes import fr.acinq.lightning.Lightning.randomBytes32 +import fr.acinq.lightning.MilliSatoshi +import fr.acinq.lightning.NodeParams +import fr.acinq.lightning.WalletParams import fr.acinq.lightning.channel.* import fr.acinq.lightning.db.IncomingPayment import fr.acinq.lightning.db.IncomingPaymentsDb @@ -18,16 +23,16 @@ sealed class PaymentPart { abstract val amount: MilliSatoshi abstract val totalAmount: MilliSatoshi abstract val paymentHash: ByteVector32 - abstract val finalPayload: FinalPayload + abstract val finalPayload: PaymentOnion.FinalPayload } -data class HtlcPart(val htlc: UpdateAddHtlc, override val finalPayload: FinalPayload) : PaymentPart() { +data class HtlcPart(val htlc: UpdateAddHtlc, override val finalPayload: PaymentOnion.FinalPayload) : PaymentPart() { override val amount: MilliSatoshi = htlc.amountMsat override val totalAmount: MilliSatoshi = finalPayload.totalAmount override val paymentHash: ByteVector32 = htlc.paymentHash } -data class PayToOpenPart(val payToOpenRequest: PayToOpenRequest, override val finalPayload: FinalPayload) : PaymentPart() { +data class PayToOpenPart(val payToOpenRequest: PayToOpenRequest, override val finalPayload: PaymentOnion.FinalPayload) : PaymentPart() { override val amount: MilliSatoshi = payToOpenRequest.amountMsat override val totalAmount: MilliSatoshi = finalPayload.totalAmount override val paymentHash: ByteVector32 = payToOpenRequest.paymentHash @@ -75,7 +80,6 @@ class IncomingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle timestampSeconds: Long = currentTimestampSeconds() ): PaymentRequest { val paymentHash = Crypto.sha256(paymentPreimage).toByteVector32() - val invoiceFeatures = PaymentRequest.invoiceFeatures(nodeParams.features) logger.debug { "h:$paymentHash using routing hints $extraHops" } val pr = PaymentRequest.create( nodeParams.chainHash, @@ -84,8 +88,10 @@ class IncomingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle nodeParams.nodePrivateKey, description, PaymentRequest.DEFAULT_MIN_FINAL_EXPIRY_DELTA, - invoiceFeatures, + nodeParams.features.invoiceFeatures(), randomBytes32(), + // We always include a payment metadata in our invoices, which lets us test whether senders support it + ByteVector("2a"), expirySeconds, extraHops, timestampSeconds @@ -246,7 +252,10 @@ class IncomingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle } else -> { // We have received all the payment parts. - logger.info { "h:${paymentPart.paymentHash} payment received (${payment.amountReceived})" } + when (val paymentMetadata = paymentPart.finalPayload.paymentMetadata) { + null -> logger.info { "h:${paymentPart.paymentHash} payment received (${payment.amountReceived}) without payment metadata" } + else -> logger.info { "h:${paymentPart.paymentHash} payment received (${payment.amountReceived}) with payment metadata ($paymentMetadata)" } + } val (actions, receivedWith) = payment.parts.map { part -> when (part) { is HtlcPart -> { @@ -352,7 +361,7 @@ class IncomingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle fun checkPaymentsTimeout(currentTimestampSeconds: Long): List { val actions = mutableListOf() - val keysToRemove = mutableListOf() + val keysToRemove = mutableSetOf() // BOLT 04: // - MUST fail all HTLCs in the HTLC set after some reasonable timeout. @@ -379,7 +388,7 @@ class IncomingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle /** Convert an incoming htlc to a payment part abstraction. Payment parts are then summed together to reach the full payment amount. */ private fun toPaymentPart(privateKey: PrivateKey, htlc: UpdateAddHtlc): Either { // NB: IncomingPacket.decrypt does additional validation on top of IncomingPacket.decryptOnion - return when (val decrypted = IncomingPacket.decrypt(htlc, privateKey)) { + return when (val decrypted = IncomingPaymentPacket.decrypt(htlc, privateKey)) { is Either.Left -> { // Unable to decrypt onion val failureMsg = decrypted.value val action = actionForFailureMessage(failureMsg, htlc) @@ -394,7 +403,7 @@ class IncomingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle * This is very similar to the processing of a htlc, except that we only have a packet, to decrypt into a final payload. */ private fun toPaymentPart(privateKey: PrivateKey, payToOpenRequest: PayToOpenRequest): Either { - return when (val decrypted = IncomingPacket.decryptOnion(payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, payToOpenRequest.finalPacket.payload.size(), privateKey)) { + return when (val decrypted = IncomingPaymentPacket.decryptOnion(payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, payToOpenRequest.finalPacket.payload.size(), privateKey)) { is Either.Left -> { val failureMsg = decrypted.value val action = actionForPayToOpenFailure(privateKey, failureMsg, payToOpenRequest) @@ -424,7 +433,7 @@ class IncomingPaymentHandler(val nodeParams: NodeParams, val walletParams: Walle private fun actionForPayToOpenFailure(privateKey: PrivateKey, failure: FailureMessage, payToOpenRequest: PayToOpenRequest): PayToOpenResponseEvent { val reason = CMD_FAIL_HTLC.Reason.Failure(failure) - val encryptedReason = when (val result = OutgoingPacket.buildHtlcFailure(privateKey, payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, reason)) { + val encryptedReason = when (val result = OutgoingPaymentPacket.buildHtlcFailure(privateKey, payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, reason)) { is Either.Right -> result.value is Either.Left -> null } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPacket.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPaymentPacket.kt similarity index 81% rename from src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPacket.kt rename to src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPaymentPacket.kt index 6232cfc4b..662e8490b 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPacket.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/IncomingPaymentPacket.kt @@ -6,7 +6,7 @@ import fr.acinq.lightning.crypto.sphinx.Sphinx import fr.acinq.lightning.utils.Either import fr.acinq.lightning.wire.* -object IncomingPacket { +object IncomingPaymentPacket { /** * Decrypt the onion packet of a received htlc. We expect to be the final recipient, and we validate that the HTLC @@ -18,12 +18,12 @@ object IncomingPacket { * - a decrypted and valid onion final payload * - or a Bolt4 failure message that can be returned to the sender if the HTLC is invalid */ - fun decrypt(add: UpdateAddHtlc, privateKey: PrivateKey): Either { + fun decrypt(add: UpdateAddHtlc, privateKey: PrivateKey): Either { return when (val decrypted = decryptOnion(add.paymentHash, add.onionRoutingPacket, OnionRoutingPacket.PaymentPacketLength, privateKey)) { is Either.Left -> Either.Left(decrypted.value) is Either.Right -> { val outer = decrypted.value - when (val trampolineOnion = outer.records.get()) { + when (val trampolineOnion = outer.records.get()) { null -> validate(add, outer) else -> { when (val inner = decryptOnion(add.paymentHash, trampolineOnion.packet, OnionRoutingPacket.TrampolinePacketLength, privateKey)) { @@ -37,7 +37,7 @@ object IncomingPacket { } @OptIn(ExperimentalUnsignedTypes::class) - fun decryptOnion(paymentHash: ByteVector32, packet: OnionRoutingPacket, packetLength: Int, privateKey: PrivateKey): Either { + fun decryptOnion(paymentHash: ByteVector32, packet: OnionRoutingPacket, packetLength: Int, privateKey: PrivateKey): Either { return when (val decrypted = Sphinx.peel(privateKey, paymentHash, packet, packetLength)) { is Either.Left -> Either.Left(decrypted.value) is Either.Right -> run { @@ -45,7 +45,7 @@ object IncomingPacket { Either.Left(UnknownNextPeer) } else { try { - Either.Right(FinalPayload.read(decrypted.value.payload.toByteArray())) + Either.Right(PaymentOnion.FinalPayload.read(decrypted.value.payload.toByteArray())) } catch (_: Throwable) { Either.Left(InvalidOnionPayload(0U, 0)) } @@ -54,7 +54,7 @@ object IncomingPacket { } } - private fun validate(add: UpdateAddHtlc, payload: FinalPayload): Either { + private fun validate(add: UpdateAddHtlc, payload: PaymentOnion.FinalPayload): Either { return when { add.amountMsat != payload.amount -> Either.Left(FinalIncorrectHtlcAmount(add.amountMsat)) add.cltvExpiry != payload.expiry -> Either.Left(FinalIncorrectCltvExpiry(add.cltvExpiry)) @@ -63,7 +63,7 @@ object IncomingPacket { } @OptIn(ExperimentalUnsignedTypes::class) - private fun validate(add: UpdateAddHtlc, outerPayload: FinalPayload, innerPayload: FinalPayload): Either { + private fun validate(add: UpdateAddHtlc, outerPayload: PaymentOnion.FinalPayload, innerPayload: PaymentOnion.FinalPayload): Either { return when { add.amountMsat != outerPayload.amount -> Either.Left(FinalIncorrectHtlcAmount(add.amountMsat)) add.cltvExpiry != outerPayload.expiry -> Either.Left(FinalIncorrectCltvExpiry(add.cltvExpiry)) @@ -74,7 +74,7 @@ object IncomingPacket { else -> { // We merge contents from the outer and inner payloads. // We must use the inner payload's total amount and payment secret because the payment may be split between multiple trampoline payments (#reckless). - Either.Right(FinalPayload.createMultiPartPayload(outerPayload.amount, innerPayload.totalAmount, outerPayload.expiry, innerPayload.paymentSecret)) + Either.Right(PaymentOnion.FinalPayload.createMultiPartPayload(outerPayload.amount, innerPayload.totalAmount, outerPayload.expiry, innerPayload.paymentSecret, innerPayload.paymentMetadata)) } } } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt index f3a84af4f..f40265c3d 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandler.kt @@ -294,7 +294,7 @@ class OutgoingPaymentHandler(val nodeId: PublicKey, val walletParams: WalletPara OutgoingPayment.Part.Status.Pending ) val channelHops: List = listOf(ChannelHop(nodeId, route.channel.staticParams.remoteNodeId, route.channel.channelUpdate)) - val (add, secrets) = OutgoingPacket.buildCommand(childId, request.paymentHash, channelHops, trampolinePayload.createFinalPayload(route.amount)) + val (add, secrets) = OutgoingPaymentPacket.buildCommand(childId, request.paymentHash, channelHops, trampolinePayload.createFinalPayload(route.amount)) return Triple(outgoingPayment, secrets, WrappedChannelEvent(route.channel.channelId, ChannelEvent.ExecuteCommand(add))) } @@ -312,13 +312,13 @@ class OutgoingPaymentHandler(val nodeId: PublicKey, val walletParams: WalletPara val finalExpiryDelta = request.details.paymentRequest.minFinalExpiryDelta ?: Channel.MIN_CLTV_EXPIRY_DELTA val finalExpiry = finalExpiryDelta.toCltvExpiry(currentBlockHeight.toLong()) - val finalPayload = FinalPayload.createSinglePartPayload(request.amount, finalExpiry, request.details.paymentRequest.paymentSecret) + val finalPayload = PaymentOnion.FinalPayload.createSinglePartPayload(request.amount, finalExpiry, request.details.paymentRequest.paymentSecret, request.details.paymentRequest.paymentMetadata) val invoiceFeatures = Features(request.details.paymentRequest.features) val (trampolineAmount, trampolineExpiry, trampolineOnion) = if (invoiceFeatures.hasFeature(Feature.TrampolinePayment)) { - OutgoingPacket.buildPacket(request.paymentHash, trampolineRoute, finalPayload, OnionRoutingPacket.TrampolinePacketLength) + OutgoingPaymentPacket.buildPacket(request.paymentHash, trampolineRoute, finalPayload, OnionRoutingPacket.TrampolinePacketLength) } else { - OutgoingPacket.buildTrampolineToLegacyPacket(request.details.paymentRequest, trampolineRoute, finalPayload) + OutgoingPaymentPacket.buildTrampolineToLegacyPacket(request.details.paymentRequest, trampolineRoute, finalPayload) } return Triple(trampolineAmount, trampolineExpiry, trampolineOnion.packet) } @@ -337,7 +337,7 @@ class OutgoingPaymentHandler(val nodeId: PublicKey, val walletParams: WalletPara * @param packet trampoline onion packet. */ data class TrampolinePayload(val totalAmount: MilliSatoshi, val expiry: CltvExpiry, val paymentSecret: ByteVector32, val packet: OnionRoutingPacket) { - fun createFinalPayload(partialAmount: MilliSatoshi): FinalPayload = FinalPayload.createTrampolinePayload(partialAmount, totalAmount, expiry, paymentSecret, packet) + fun createFinalPayload(partialAmount: MilliSatoshi): PaymentOnion.FinalPayload = PaymentOnion.FinalPayload.createTrampolinePayload(partialAmount, totalAmount, expiry, paymentSecret, packet) } /** diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPacket.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt similarity index 75% rename from src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPacket.kt rename to src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt index bed2a8120..6a1e2f392 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPacket.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/OutgoingPaymentPacket.kt @@ -9,7 +9,6 @@ import fr.acinq.lightning.Lightning import fr.acinq.lightning.MilliSatoshi import fr.acinq.lightning.channel.CMD_ADD_HTLC import fr.acinq.lightning.channel.CMD_FAIL_HTLC -import fr.acinq.lightning.channel.Channel import fr.acinq.lightning.crypto.sphinx.FailurePacket import fr.acinq.lightning.crypto.sphinx.PacketAndSecrets import fr.acinq.lightning.crypto.sphinx.SharedSecrets @@ -19,22 +18,24 @@ import fr.acinq.lightning.router.Hop import fr.acinq.lightning.router.NodeHop import fr.acinq.lightning.utils.Either import fr.acinq.lightning.utils.UUID -import fr.acinq.lightning.wire.* +import fr.acinq.lightning.wire.FailureMessage +import fr.acinq.lightning.wire.OnionRoutingPacket +import fr.acinq.lightning.wire.PaymentOnion -object OutgoingPacket { +object OutgoingPaymentPacket { /** * Build an encrypted onion packet from onion payloads and node public keys. */ - private fun buildOnion(nodes: List, payloads: List, associatedData: ByteVector32, payloadLength: Int): PacketAndSecrets { + private fun buildOnion(nodes: List, payloads: List, associatedData: ByteVector32, payloadLength: Int): PacketAndSecrets { require(nodes.size == payloads.size) val sessionKey = Lightning.randomKey() val payloadsBin = payloads .map { when (it) { - is ChannelRelayPayload -> it.write() - is NodeRelayPayload -> it.write() - is FinalPayload -> it.write() + is PaymentOnion.ChannelRelayPayload -> it.write() + is PaymentOnion.NodeRelayPayload -> it.write() + is PaymentOnion.FinalPayload -> it.write() } } return Sphinx.create(sessionKey, nodes, payloadsBin, associatedData, payloadLength) @@ -50,13 +51,13 @@ object OutgoingPacket { * - firstExpiry is the cltv expiry for the first htlc in the route * - a sequence of payloads that will be used to build the onion */ - private fun buildPayloads(hops: List, finalPayload: FinalPayload): Triple> { - return hops.reversed().fold(Triple(finalPayload.amount, finalPayload.expiry, listOf(finalPayload))) { triple, hop -> + private fun buildPayloads(hops: List, finalPayload: PaymentOnion.FinalPayload): Triple> { + return hops.reversed().fold(Triple(finalPayload.amount, finalPayload.expiry, listOf(finalPayload))) { triple, hop -> val (amount, expiry, payloads) = triple val payload = when (hop) { // Since we don't have any scenario where we add tlv data for intermediate hops, we use legacy payloads. - is ChannelHop -> ChannelRelayPayload.create(hop.lastUpdate.shortChannelId, amount, expiry) - is NodeHop -> NodeRelayPayload.create(amount, expiry, hop.nextNodeId) + is ChannelHop -> PaymentOnion.ChannelRelayPayload.create(hop.lastUpdate.shortChannelId, amount, expiry) + is NodeHop -> PaymentOnion.NodeRelayPayload.create(amount, expiry, hop.nextNodeId) } Triple(amount + hop.fee(amount), expiry + hop.cltvExpiryDelta, listOf(payload) + payloads) } @@ -74,13 +75,16 @@ object OutgoingPacket { * - firstExpiry is the cltv expiry for the first trampoline node in the route * - the trampoline onion to include in final payload of a normal onion */ - fun buildTrampolineToLegacyPacket(invoice: PaymentRequest, hops: List, finalPayload: FinalPayload): Triple { - val (firstAmount, firstExpiry, payloads) = hops.drop(1).reversed().fold(Triple(finalPayload.amount, finalPayload.expiry, listOf(finalPayload))) { triple, hop -> + fun buildTrampolineToLegacyPacket(invoice: PaymentRequest, hops: List, finalPayload: PaymentOnion.FinalPayload): Triple { + // NB: the final payload will never reach the recipient, since the next-to-last trampoline hop will convert that to a legacy payment + // We use the smallest final payload possible, otherwise we may overflow the trampoline onion size. + val dummyFinalPayload = PaymentOnion.FinalPayload.createSinglePartPayload(finalPayload.amount, finalPayload.expiry, finalPayload.paymentSecret, null) + val (firstAmount, firstExpiry, payloads) = hops.drop(1).reversed().fold(Triple(finalPayload.amount, finalPayload.expiry, listOf(dummyFinalPayload))) { triple, hop -> val (amount, expiry, payloads) = triple val payload = when (payloads.size) { // The next-to-last trampoline hop must include invoice data to indicate the conversion to a legacy payment. - 1 -> NodeRelayPayload.createNodeRelayToNonTrampolinePayload(finalPayload.amount, finalPayload.totalAmount, finalPayload.expiry, hop.nextNodeId, invoice) - else -> NodeRelayPayload.create(amount, expiry, hop.nextNodeId) + 1 -> PaymentOnion.NodeRelayPayload.createNodeRelayToNonTrampolinePayload(finalPayload.amount, finalPayload.totalAmount, finalPayload.expiry, hop.nextNodeId, invoice) + else -> PaymentOnion.NodeRelayPayload.create(amount, expiry, hop.nextNodeId) } Triple(amount + hop.fee(amount), expiry + hop.cltvExpiryDelta, listOf(payload) + payloads) } @@ -99,7 +103,7 @@ object OutgoingPacket { * - firstExpiry is the cltv expiry for the first htlc in the route * - the onion to include in the HTLC */ - fun buildPacket(paymentHash: ByteVector32, hops: List, finalPayload: FinalPayload, payloadLength: Int): Triple { + fun buildPacket(paymentHash: ByteVector32, hops: List, finalPayload: PaymentOnion.FinalPayload, payloadLength: Int): Triple { val (firstAmount, firstExpiry, payloads) = buildPayloads(hops.drop(1), finalPayload) val nodes = hops.map { it.nextNodeId } // BOLT 2 requires that associatedData == paymentHash @@ -112,7 +116,7 @@ object OutgoingPacket { * * @return the command and the onion shared secrets (used to decrypt the error in case of payment failure) */ - fun buildCommand(paymentId: UUID, paymentHash: ByteVector32, hops: List, finalPayload: FinalPayload): Pair { + fun buildCommand(paymentId: UUID, paymentHash: ByteVector32, hops: List, finalPayload: PaymentOnion.FinalPayload): Pair { val (firstAmount, firstExpiry, onion) = buildPacket(paymentHash, hops, finalPayload, OnionRoutingPacket.PaymentPacketLength) return Pair(CMD_ADD_HTLC(firstAmount, paymentHash, firstExpiry, onion.packet, paymentId, commit = true), onion.sharedSecrets) } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/payment/PaymentRequest.kt b/src/commonMain/kotlin/fr/acinq/lightning/payment/PaymentRequest.kt index c32beadee..88062076f 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/payment/PaymentRequest.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/payment/PaymentRequest.kt @@ -28,6 +28,9 @@ data class PaymentRequest( @Transient val paymentSecret: ByteVector32 = tags.find { it is TaggedField.PaymentSecret }!!.run { (this as TaggedField.PaymentSecret).secret } + @Transient + val paymentMetadata: ByteVector? = tags.find { it is TaggedField.PaymentMetadata }?.run { (this as TaggedField.PaymentMetadata).data } + @Transient val description: String? = tags.find { it is TaggedField.Description }?.run { (this as TaggedField.Description).description } @@ -138,22 +141,6 @@ data class PaymentRequest( Block.LivenetGenesisBlock.hash to "lnbc" ) - // only some features are valid in invoices - // see 'Context' column in https://github.com/lightningnetwork/lightning-rfc/blob/master/09-features.md - private val bolt11Features = setOf( - Feature.VariableLengthOnion, - Feature.PaymentSecret, - Feature.BasicMultiPartPayment, - Feature.TrampolinePayment - ) - - /** - * This filters out all features unrelated to BOLT 11 - */ - fun invoiceFeatures(features: Features): Features { - return Features(activated = features.activated.filter { (f, _) -> bolt11Features.contains(f) }) - } - fun create( chainHash: ByteVector32, amount: MilliSatoshi?, @@ -163,6 +150,7 @@ data class PaymentRequest( minFinalCltvExpiryDelta: CltvExpiryDelta, features: Features, paymentSecret: ByteVector32 = randomBytes32(), + paymentMetadata: ByteVector? = null, expirySeconds: Long? = null, extraHops: List> = listOf(), timestampSeconds: Long = currentTimestampSeconds() @@ -173,11 +161,11 @@ data class PaymentRequest( TaggedField.Description(description), TaggedField.MinFinalCltvExpiry(minFinalCltvExpiryDelta.toLong()), TaggedField.PaymentSecret(paymentSecret), - TaggedField.Features(features.toByteArray().toByteVector()) + // We remove unknown features which could make the invoice too big. + TaggedField.Features(features.invoiceFeatures().copy(unknown = setOf()).toByteArray().toByteVector()) ) - if (expirySeconds != null) { - tags.add(TaggedField.Expiry(expirySeconds)) - } + paymentMetadata?.let { tags.add(TaggedField.PaymentMetadata(it)) } + expirySeconds?.let { tags.add(TaggedField.Expiry(it)) } if (extraHops.isNotEmpty()) { extraHops.forEach { tags.add(TaggedField.RoutingInfo(it)) } } @@ -226,6 +214,7 @@ data class PaymentRequest( when (tag) { TaggedField.PaymentHash.tag -> tags.add(kotlin.runCatching { TaggedField.PaymentHash.decode(value) }.getOrDefault(TaggedField.InvalidTag(tag, value))) TaggedField.PaymentSecret.tag -> tags.add(kotlin.runCatching { TaggedField.PaymentSecret.decode(value) }.getOrDefault(TaggedField.InvalidTag(tag, value))) + TaggedField.PaymentMetadata.tag -> tags.add(kotlin.runCatching { TaggedField.PaymentMetadata.decode(value) }.getOrDefault(TaggedField.InvalidTag(tag, value))) TaggedField.Description.tag -> tags.add(kotlin.runCatching { TaggedField.Description.decode(value) }.getOrDefault(TaggedField.InvalidTag(tag, value))) TaggedField.DescriptionHash.tag -> tags.add(kotlin.runCatching { TaggedField.DescriptionHash.decode(value) }.getOrDefault(TaggedField.InvalidTag(tag, value))) TaggedField.Expiry.tag -> tags.add(kotlin.runCatching { TaggedField.Expiry.decode(value) }.getOrDefault(TaggedField.InvalidTag(tag, value))) @@ -355,6 +344,17 @@ data class PaymentRequest( } } + @Serializable + data class PaymentMetadata(@Contextual val data: ByteVector) : TaggedField() { + override val tag: Int5 = PaymentMetadata.tag + override fun encode(): List = Bech32.eight2five(data.toByteArray()).toList() + + companion object { + const val tag: Int5 = 27 + fun decode(input: List): PaymentMetadata = PaymentMetadata(Bech32.five2eight(input.toTypedArray(), 0).toByteVector()) + } + } + /** @param expirySeconds payment expiry (in seconds) */ @Serializable data class Expiry(val expirySeconds: Long) : TaggedField() { diff --git a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v1/Serialization.kt b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v1/Serialization.kt index 410d95204..c6fe64a4e 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v1/Serialization.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v1/Serialization.kt @@ -45,14 +45,15 @@ object Serialization { subclass(ChannelTlv.ChannelVersionTlv.serializer()) subclass(ChannelTlv.ChannelOriginTlv.serializer()) subclass(InitTlv.Networks.serializer()) - subclass(OnionTlv.AmountToForward.serializer()) - subclass(OnionTlv.OutgoingCltv.serializer()) - subclass(OnionTlv.OutgoingChannelId.serializer()) - subclass(OnionTlv.PaymentData.serializer()) - subclass(OnionTlv.InvoiceFeatures.serializer()) - subclass(OnionTlv.OutgoingNodeId.serializer()) - subclass(OnionTlv.InvoiceRoutingInfo.serializer()) - subclass(OnionTlv.TrampolineOnion.serializer()) + subclass(OnionPaymentPayloadTlv.AmountToForward.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingCltv.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingChannelId.serializer()) + subclass(OnionPaymentPayloadTlv.PaymentData.serializer()) + subclass(OnionPaymentPayloadTlv.PaymentMetadata.serializer()) + subclass(OnionPaymentPayloadTlv.InvoiceFeatures.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingNodeId.serializer()) + subclass(OnionPaymentPayloadTlv.InvoiceRoutingInfo.serializer()) + subclass(OnionPaymentPayloadTlv.TrampolineOnion.serializer()) subclass(GenericTlv.serializer()) } } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v2/Serialization.kt b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v2/Serialization.kt index 71451e357..37f88ca70 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v2/Serialization.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v2/Serialization.kt @@ -55,14 +55,15 @@ object Serialization { subclass(ChannelTlv.ChannelVersionTlv.serializer()) subclass(ChannelTlv.ChannelOriginTlv.serializer()) subclass(InitTlv.Networks.serializer()) - subclass(OnionTlv.AmountToForward.serializer()) - subclass(OnionTlv.OutgoingCltv.serializer()) - subclass(OnionTlv.OutgoingChannelId.serializer()) - subclass(OnionTlv.PaymentData.serializer()) - subclass(OnionTlv.InvoiceFeatures.serializer()) - subclass(OnionTlv.OutgoingNodeId.serializer()) - subclass(OnionTlv.InvoiceRoutingInfo.serializer()) - subclass(OnionTlv.TrampolineOnion.serializer()) + subclass(OnionPaymentPayloadTlv.AmountToForward.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingCltv.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingChannelId.serializer()) + subclass(OnionPaymentPayloadTlv.PaymentData.serializer()) + subclass(OnionPaymentPayloadTlv.PaymentMetadata.serializer()) + subclass(OnionPaymentPayloadTlv.InvoiceFeatures.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingNodeId.serializer()) + subclass(OnionPaymentPayloadTlv.InvoiceRoutingInfo.serializer()) + subclass(OnionPaymentPayloadTlv.TrampolineOnion.serializer()) subclass(GenericTlv.serializer()) } } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v3/Serialization.kt b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v3/Serialization.kt index 119eb008d..1e3632fed 100644 --- a/src/commonMain/kotlin/fr/acinq/lightning/serialization/v3/Serialization.kt +++ b/src/commonMain/kotlin/fr/acinq/lightning/serialization/v3/Serialization.kt @@ -55,14 +55,15 @@ object Serialization { subclass(ChannelTlv.ChannelVersionTlv.serializer()) subclass(ChannelTlv.ChannelOriginTlv.serializer()) subclass(InitTlv.Networks.serializer()) - subclass(OnionTlv.AmountToForward.serializer()) - subclass(OnionTlv.OutgoingCltv.serializer()) - subclass(OnionTlv.OutgoingChannelId.serializer()) - subclass(OnionTlv.PaymentData.serializer()) - subclass(OnionTlv.InvoiceFeatures.serializer()) - subclass(OnionTlv.OutgoingNodeId.serializer()) - subclass(OnionTlv.InvoiceRoutingInfo.serializer()) - subclass(OnionTlv.TrampolineOnion.serializer()) + subclass(OnionPaymentPayloadTlv.AmountToForward.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingCltv.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingChannelId.serializer()) + subclass(OnionPaymentPayloadTlv.PaymentData.serializer()) + subclass(OnionPaymentPayloadTlv.PaymentMetadata.serializer()) + subclass(OnionPaymentPayloadTlv.InvoiceFeatures.serializer()) + subclass(OnionPaymentPayloadTlv.OutgoingNodeId.serializer()) + subclass(OnionPaymentPayloadTlv.InvoiceRoutingInfo.serializer()) + subclass(OnionPaymentPayloadTlv.TrampolineOnion.serializer()) subclass(GenericTlv.serializer()) } } diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/Onion.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/Onion.kt deleted file mode 100644 index cd7933e12..000000000 --- a/src/commonMain/kotlin/fr/acinq/lightning/wire/Onion.kt +++ /dev/null @@ -1,333 +0,0 @@ -package fr.acinq.lightning.wire - -import fr.acinq.bitcoin.ByteVector -import fr.acinq.bitcoin.ByteVector32 -import fr.acinq.bitcoin.PublicKey -import fr.acinq.bitcoin.io.ByteArrayInput -import fr.acinq.bitcoin.io.ByteArrayOutput -import fr.acinq.bitcoin.io.Input -import fr.acinq.bitcoin.io.Output -import fr.acinq.lightning.CltvExpiry -import fr.acinq.lightning.CltvExpiryDelta -import fr.acinq.lightning.MilliSatoshi -import fr.acinq.lightning.ShortChannelId -import fr.acinq.lightning.payment.PaymentRequest -import fr.acinq.lightning.utils.msat -import fr.acinq.lightning.utils.toByteVector -import fr.acinq.lightning.utils.toByteVector32 -import kotlinx.serialization.Contextual -import kotlinx.serialization.Serializable - -@Serializable -data class OnionRoutingPacket( - val version: Int, - @Contextual val publicKey: ByteVector, - @Contextual val payload: ByteVector, - @Contextual val hmac: ByteVector32 -) { - companion object { - const val PaymentPacketLength = 1300 - const val TrampolinePacketLength = 400 - } -} - -/** - * @param payloadLength length of the onion-encrypted payload. - */ -@OptIn(ExperimentalUnsignedTypes::class) -class OnionRoutingPacketSerializer(private val payloadLength: Int) { - fun read(input: Input): OnionRoutingPacket { - return OnionRoutingPacket( - LightningCodecs.byte(input), - LightningCodecs.bytes(input, 33).toByteVector(), - LightningCodecs.bytes(input, payloadLength).toByteVector(), - LightningCodecs.bytes(input, 32).toByteVector32() - ) - } - - fun read(bytes: ByteArray): OnionRoutingPacket = read(ByteArrayInput(bytes)) - - fun write(message: OnionRoutingPacket, out: Output) { - LightningCodecs.writeByte(message.version, out) - LightningCodecs.writeBytes(message.publicKey, out) - LightningCodecs.writeBytes(message.payload, out) - LightningCodecs.writeBytes(message.hmac, out) - } - - fun write(message: OnionRoutingPacket): ByteArray { - val out = ByteArrayOutput() - write(message, out) - return out.toByteArray() - } -} - -@OptIn(ExperimentalUnsignedTypes::class) -@Serializable -sealed class OnionTlv : Tlv { - /** Amount to forward to the next node. */ - @Serializable - data class AmountToForward(val amount: MilliSatoshi) : OnionTlv() { - override val tag: Long get() = AmountToForward.tag - override fun write(out: Output) = LightningCodecs.writeTU64(amount.toLong(), out) - - companion object : TlvValueReader { - const val tag: Long = 2 - override fun read(input: Input): AmountToForward = AmountToForward(MilliSatoshi(LightningCodecs.tu64(input))) - } - } - - /** CLTV value to use for the HTLC offered to the next node. */ - @Serializable - data class OutgoingCltv(val cltv: CltvExpiry) : OnionTlv() { - override val tag: Long get() = OutgoingCltv.tag - override fun write(out: Output) = LightningCodecs.writeTU32(cltv.toLong().toInt(), out) - - companion object : TlvValueReader { - const val tag: Long = 4 - override fun read(input: Input): OutgoingCltv = OutgoingCltv(CltvExpiry(LightningCodecs.tu32(input).toLong())) - } - } - - /** Id of the channel to use to forward a payment to the next node. */ - @Serializable - data class OutgoingChannelId(val shortChannelId: ShortChannelId) : OnionTlv() { - override val tag: Long get() = OutgoingChannelId.tag - override fun write(out: Output) = LightningCodecs.writeU64(shortChannelId.toLong(), out) - - companion object : TlvValueReader { - const val tag: Long = 6 - override fun read(input: Input): OutgoingChannelId = OutgoingChannelId(ShortChannelId(LightningCodecs.u64(input))) - } - } - - /** - * Bolt 11 payment details (only included for the last node). - * - * @param secret payment secret specified in the Bolt 11 invoice. - * @param totalAmount total amount in multi-part payments. When missing, assumed to be equal to AmountToForward. - */ - @Serializable - data class PaymentData(@Contextual val secret: ByteVector32, val totalAmount: MilliSatoshi) : OnionTlv() { - override val tag: Long get() = PaymentData.tag - override fun write(out: Output) { - LightningCodecs.writeBytes(secret, out) - LightningCodecs.writeTU64(totalAmount.toLong(), out) - } - - companion object : TlvValueReader { - const val tag: Long = 8 - override fun read(input: Input): PaymentData = PaymentData(ByteVector32(LightningCodecs.bytes(input, 32)), MilliSatoshi(LightningCodecs.tu64(input))) - } - } - - /** - * Invoice feature bits. Only included for intermediate trampoline nodes when they should convert to a legacy payment - * because the final recipient doesn't support trampoline. - */ - @Serializable - data class InvoiceFeatures(@Contextual val features: ByteVector) : OnionTlv() { - override val tag: Long get() = InvoiceFeatures.tag - override fun write(out: Output) = LightningCodecs.writeBytes(features, out) - - companion object : TlvValueReader { - const val tag: Long = 66097 - override fun read(input: Input): InvoiceFeatures = InvoiceFeatures(ByteVector(LightningCodecs.bytes(input, input.availableBytes))) - } - } - - /** Id of the next node. */ - @Serializable - data class OutgoingNodeId(@Contextual val nodeId: PublicKey) : OnionTlv() { - override val tag: Long get() = OutgoingNodeId.tag - override fun write(out: Output) = LightningCodecs.writeBytes(nodeId.value, out) - - companion object : TlvValueReader { - const val tag: Long = 66098 - override fun read(input: Input): OutgoingNodeId = OutgoingNodeId(PublicKey(LightningCodecs.bytes(input, 33))) - } - } - - /** - * Invoice routing hints. Only included for intermediate trampoline nodes when they should convert to a legacy payment - * because the final recipient doesn't support trampoline. - */ - @Serializable - data class InvoiceRoutingInfo(val extraHops: List>) : OnionTlv() { - override val tag: Long get() = InvoiceRoutingInfo.tag - override fun write(out: Output) { - for (routeHint in extraHops) { - LightningCodecs.writeByte(routeHint.size, out) - routeHint.map { - LightningCodecs.writeBytes(it.nodeId.value, out) - LightningCodecs.writeU64(it.shortChannelId.toLong(), out) - LightningCodecs.writeU32(it.feeBase.toLong().toInt(), out) - LightningCodecs.writeU32(it.feeProportionalMillionths.toInt(), out) - LightningCodecs.writeU16(it.cltvExpiryDelta.toInt(), out) - } - } - } - - companion object : TlvValueReader { - const val tag: Long = 66099 - override fun read(input: Input): InvoiceRoutingInfo { - val extraHops = mutableListOf>() - while (input.availableBytes > 0) { - val hopCount = LightningCodecs.byte(input) - val extraHop = (0 until hopCount).map { - PaymentRequest.TaggedField.ExtraHop( - PublicKey(LightningCodecs.bytes(input, 33)), - ShortChannelId(LightningCodecs.u64(input)), - MilliSatoshi(LightningCodecs.u32(input).toLong()), - LightningCodecs.u32(input).toLong(), - CltvExpiryDelta(LightningCodecs.u16(input)) - ) - } - extraHops.add(extraHop) - } - return InvoiceRoutingInfo(extraHops) - } - } - } - - /** An encrypted trampoline onion packet. */ - @Serializable - data class TrampolineOnion(val packet: OnionRoutingPacket) : OnionTlv() { - override val tag: Long get() = TrampolineOnion.tag - override fun write(out: Output) = OnionRoutingPacketSerializer(OnionRoutingPacket.TrampolinePacketLength).write(packet, out) - - companion object : TlvValueReader { - const val tag: Long = 66100 - override fun read(input: Input): TrampolineOnion = TrampolineOnion(OnionRoutingPacketSerializer(OnionRoutingPacket.TrampolinePacketLength).read(input)) - } - } -} - -sealed class PerHopPayload { - - abstract fun write(out: Output) - - fun write(): ByteArray { - val out = ByteArrayOutput() - write(out) - return out.toByteArray() - } - - companion object { - val tlvSerializer = TlvStreamSerializer( - true, - @Suppress("UNCHECKED_CAST") - mapOf( - OnionTlv.AmountToForward.tag to OnionTlv.AmountToForward.Companion as TlvValueReader, - OnionTlv.OutgoingCltv.tag to OnionTlv.OutgoingCltv.Companion as TlvValueReader, - OnionTlv.OutgoingChannelId.tag to OnionTlv.OutgoingChannelId.Companion as TlvValueReader, - OnionTlv.PaymentData.tag to OnionTlv.PaymentData.Companion as TlvValueReader, - OnionTlv.InvoiceFeatures.tag to OnionTlv.InvoiceFeatures.Companion as TlvValueReader, - OnionTlv.OutgoingNodeId.tag to OnionTlv.OutgoingNodeId.Companion as TlvValueReader, - OnionTlv.InvoiceRoutingInfo.tag to OnionTlv.InvoiceRoutingInfo.Companion as TlvValueReader, - OnionTlv.TrampolineOnion.tag to OnionTlv.TrampolineOnion.Companion as TlvValueReader, - ) - ) - } -} - -interface PerHopPayloadReader { - fun read(input: Input): T - fun read(bytes: ByteArray): T = read(ByteArrayInput(bytes)) -} - -data class FinalPayload(val records: TlvStream) : PerHopPayload() { - val amount = records.get()!!.amount - val expiry = records.get()!!.cltv - val paymentSecret = records.get()!!.secret - val totalAmount = run { - val total = records.get()!!.totalAmount - if (total > 0.msat) total else amount - } - - override fun write(out: Output) = tlvSerializer.write(records, out) - - companion object : PerHopPayloadReader { - override fun read(input: Input): FinalPayload = FinalPayload(tlvSerializer.read(input)) - - /** Create a single-part payment (total amount sent at once). */ - fun createSinglePartPayload(amount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, userCustomTlvs: List = listOf()): FinalPayload = - FinalPayload(TlvStream(listOf(OnionTlv.AmountToForward(amount), OnionTlv.OutgoingCltv(expiry), OnionTlv.PaymentData(paymentSecret, amount)), userCustomTlvs)) - - /** Create a partial payment (total amount split between multiple payments). */ - fun createMultiPartPayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, additionalTlvs: List = listOf(), userCustomTlvs: List = listOf()): FinalPayload = - FinalPayload(TlvStream(listOf(OnionTlv.AmountToForward(amount), OnionTlv.OutgoingCltv(expiry), OnionTlv.PaymentData(paymentSecret, totalAmount)) + additionalTlvs, userCustomTlvs)) - - /** Create a trampoline outer payload. */ - fun createTrampolinePayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, trampolinePacket: OnionRoutingPacket): FinalPayload = - FinalPayload(TlvStream(listOf(OnionTlv.AmountToForward(amount), OnionTlv.OutgoingCltv(expiry), OnionTlv.PaymentData(paymentSecret, totalAmount), OnionTlv.TrampolineOnion(trampolinePacket)))) - } -} - -data class ChannelRelayPayload(val records: TlvStream) : PerHopPayload() { - val amountToForward = records.get()!!.amount - val outgoingCltv = records.get()!!.cltv - val outgoingChannelId = records.get()!!.shortChannelId - - override fun write(out: Output) = tlvSerializer.write(records, out) - - companion object : PerHopPayloadReader { - override fun read(input: Input): ChannelRelayPayload = ChannelRelayPayload(tlvSerializer.read(input)) - - fun create(outgoingChannelId: ShortChannelId, amountToForward: MilliSatoshi, outgoingCltv: CltvExpiry): ChannelRelayPayload = - ChannelRelayPayload(TlvStream(listOf(OnionTlv.AmountToForward(amountToForward), OnionTlv.OutgoingCltv(outgoingCltv), OnionTlv.OutgoingChannelId(outgoingChannelId)))) - } -} - -data class NodeRelayPayload(val records: TlvStream) : PerHopPayload() { - val amountToForward = records.get()!!.amount - val outgoingCltv = records.get()!!.cltv - val outgoingNodeId = records.get()!!.nodeId - val totalAmount = run { - val paymentData = records.get() - when { - paymentData == null -> amountToForward - paymentData.totalAmount == MilliSatoshi(0) -> amountToForward - else -> paymentData.totalAmount - } - } - - // NB: the following fields are only included in the trampoline-to-legacy case. - val paymentSecret = records.get()?.secret - val invoiceFeatures = records.get()?.features - val invoiceRoutingInfo = records.get()?.extraHops - - override fun write(out: Output) = tlvSerializer.write(records, out) - - companion object : PerHopPayloadReader { - override fun read(input: Input): NodeRelayPayload = NodeRelayPayload(tlvSerializer.read(input)) - - fun create(amount: MilliSatoshi, expiry: CltvExpiry, nextNodeId: PublicKey) = NodeRelayPayload(TlvStream(listOf(OnionTlv.AmountToForward(amount), OnionTlv.OutgoingCltv(expiry), OnionTlv.OutgoingNodeId(nextNodeId)))) - - /** Create a trampoline inner payload instructing the trampoline node to relay via a non-trampoline payment. */ - fun createNodeRelayToNonTrampolinePayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, targetNodeId: PublicKey, invoice: PaymentRequest): NodeRelayPayload { - // NB: we limit the number of routing hints to ensure we don't overflow the onion. - // A better solution is to provide the routing hints outside the onion (in the `update_add_htlc` tlv stream). - val prunedRoutingHints = invoice.routingInfo.shuffled().fold(listOf()) { previous, current -> - if (previous.flatMap { it.hints }.size + current.hints.size <= 4) { - previous + current - } else { - previous - } - }.map { it.hints } - return NodeRelayPayload( - TlvStream( - listOf( - OnionTlv.AmountToForward(amount), - OnionTlv.OutgoingCltv(expiry), - OnionTlv.OutgoingNodeId(targetNodeId), - OnionTlv.PaymentData(invoice.paymentSecret, totalAmount), - OnionTlv.InvoiceFeatures(invoice.features), - OnionTlv.InvoiceRoutingInfo(prunedRoutingHints) - ) - ) - ) - } - - } - -} diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/OnionRouting.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/OnionRouting.kt new file mode 100644 index 000000000..8e3b15005 --- /dev/null +++ b/src/commonMain/kotlin/fr/acinq/lightning/wire/OnionRouting.kt @@ -0,0 +1,55 @@ +package fr.acinq.lightning.wire + +import fr.acinq.bitcoin.ByteVector +import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.bitcoin.io.ByteArrayInput +import fr.acinq.bitcoin.io.ByteArrayOutput +import fr.acinq.bitcoin.io.Input +import fr.acinq.bitcoin.io.Output +import fr.acinq.lightning.utils.toByteVector +import fr.acinq.lightning.utils.toByteVector32 +import kotlinx.serialization.Contextual +import kotlinx.serialization.Serializable + +@Serializable +data class OnionRoutingPacket( + val version: Int, + @Contextual val publicKey: ByteVector, + @Contextual val payload: ByteVector, + @Contextual val hmac: ByteVector32 +) { + companion object { + const val PaymentPacketLength = 1300 + const val TrampolinePacketLength = 400 + } +} + +/** + * @param payloadLength length of the onion-encrypted payload. + */ +@OptIn(ExperimentalUnsignedTypes::class) +class OnionRoutingPacketSerializer(private val payloadLength: Int) { + fun read(input: Input): OnionRoutingPacket { + return OnionRoutingPacket( + LightningCodecs.byte(input), + LightningCodecs.bytes(input, 33).toByteVector(), + LightningCodecs.bytes(input, payloadLength).toByteVector(), + LightningCodecs.bytes(input, 32).toByteVector32() + ) + } + + fun read(bytes: ByteArray): OnionRoutingPacket = read(ByteArrayInput(bytes)) + + fun write(message: OnionRoutingPacket, out: Output) { + LightningCodecs.writeByte(message.version, out) + LightningCodecs.writeBytes(message.publicKey, out) + LightningCodecs.writeBytes(message.payload, out) + LightningCodecs.writeBytes(message.hmac, out) + } + + fun write(message: OnionRoutingPacket): ByteArray { + val out = ByteArrayOutput() + write(message, out) + return out.toByteArray() + } +} \ No newline at end of file diff --git a/src/commonMain/kotlin/fr/acinq/lightning/wire/PaymentOnion.kt b/src/commonMain/kotlin/fr/acinq/lightning/wire/PaymentOnion.kt new file mode 100644 index 000000000..2fe59723a --- /dev/null +++ b/src/commonMain/kotlin/fr/acinq/lightning/wire/PaymentOnion.kt @@ -0,0 +1,338 @@ +package fr.acinq.lightning.wire + +import fr.acinq.bitcoin.ByteVector +import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.bitcoin.PublicKey +import fr.acinq.bitcoin.io.ByteArrayInput +import fr.acinq.bitcoin.io.ByteArrayOutput +import fr.acinq.bitcoin.io.Input +import fr.acinq.bitcoin.io.Output +import fr.acinq.lightning.CltvExpiry +import fr.acinq.lightning.CltvExpiryDelta +import fr.acinq.lightning.MilliSatoshi +import fr.acinq.lightning.ShortChannelId +import fr.acinq.lightning.payment.PaymentRequest +import fr.acinq.lightning.utils.msat +import kotlinx.serialization.Contextual +import kotlinx.serialization.Serializable + +@OptIn(ExperimentalUnsignedTypes::class) +@Serializable +sealed class OnionPaymentPayloadTlv : Tlv { + /** Amount to forward to the next node. */ + @Serializable + data class AmountToForward(val amount: MilliSatoshi) : OnionPaymentPayloadTlv() { + override val tag: Long get() = AmountToForward.tag + override fun write(out: Output) = LightningCodecs.writeTU64(amount.toLong(), out) + + companion object : TlvValueReader { + const val tag: Long = 2 + override fun read(input: Input): AmountToForward = AmountToForward(MilliSatoshi(LightningCodecs.tu64(input))) + } + } + + /** CLTV value to use for the HTLC offered to the next node. */ + @Serializable + data class OutgoingCltv(val cltv: CltvExpiry) : OnionPaymentPayloadTlv() { + override val tag: Long get() = OutgoingCltv.tag + override fun write(out: Output) = LightningCodecs.writeTU32(cltv.toLong().toInt(), out) + + companion object : TlvValueReader { + const val tag: Long = 4 + override fun read(input: Input): OutgoingCltv = OutgoingCltv(CltvExpiry(LightningCodecs.tu32(input).toLong())) + } + } + + /** Id of the channel to use to forward a payment to the next node. */ + @Serializable + data class OutgoingChannelId(val shortChannelId: ShortChannelId) : OnionPaymentPayloadTlv() { + override val tag: Long get() = OutgoingChannelId.tag + override fun write(out: Output) = LightningCodecs.writeU64(shortChannelId.toLong(), out) + + companion object : TlvValueReader { + const val tag: Long = 6 + override fun read(input: Input): OutgoingChannelId = OutgoingChannelId(ShortChannelId(LightningCodecs.u64(input))) + } + } + + /** + * Bolt 11 payment details (only included for the last node). + * + * @param secret payment secret specified in the Bolt 11 invoice. + * @param totalAmount total amount in multi-part payments. When missing, assumed to be equal to AmountToForward. + */ + @Serializable + data class PaymentData(@Contextual val secret: ByteVector32, val totalAmount: MilliSatoshi) : OnionPaymentPayloadTlv() { + override val tag: Long get() = PaymentData.tag + override fun write(out: Output) { + LightningCodecs.writeBytes(secret, out) + LightningCodecs.writeTU64(totalAmount.toLong(), out) + } + + companion object : TlvValueReader { + const val tag: Long = 8 + override fun read(input: Input): PaymentData = PaymentData(ByteVector32(LightningCodecs.bytes(input, 32)), MilliSatoshi(LightningCodecs.tu64(input))) + } + } + + /** + * When payment metadata is included in a Bolt 9 invoice, we should send it as-is to the recipient. + * This lets recipients generate invoices without having to store anything on their side until the invoice is paid. + */ + @Serializable + data class PaymentMetadata(@Contextual val data: ByteVector) : OnionPaymentPayloadTlv() { + override val tag: Long get() = PaymentMetadata.tag + override fun write(out: Output) = LightningCodecs.writeBytes(data, out) + + companion object : TlvValueReader { + const val tag: Long = 16 + override fun read(input: Input): PaymentMetadata = PaymentMetadata(ByteVector(LightningCodecs.bytes(input, input.availableBytes))) + } + } + + /** + * Invoice feature bits. Only included for intermediate trampoline nodes when they should convert to a legacy payment + * because the final recipient doesn't support trampoline. + */ + @Serializable + data class InvoiceFeatures(@Contextual val features: ByteVector) : OnionPaymentPayloadTlv() { + override val tag: Long get() = InvoiceFeatures.tag + override fun write(out: Output) = LightningCodecs.writeBytes(features, out) + + companion object : TlvValueReader { + const val tag: Long = 66097 + override fun read(input: Input): InvoiceFeatures = InvoiceFeatures(ByteVector(LightningCodecs.bytes(input, input.availableBytes))) + } + } + + /** Id of the next node. */ + @Serializable + data class OutgoingNodeId(@Contextual val nodeId: PublicKey) : OnionPaymentPayloadTlv() { + override val tag: Long get() = OutgoingNodeId.tag + override fun write(out: Output) = LightningCodecs.writeBytes(nodeId.value, out) + + companion object : TlvValueReader { + const val tag: Long = 66098 + override fun read(input: Input): OutgoingNodeId = OutgoingNodeId(PublicKey(LightningCodecs.bytes(input, 33))) + } + } + + /** + * Invoice routing hints. Only included for intermediate trampoline nodes when they should convert to a legacy payment + * because the final recipient doesn't support trampoline. + */ + @Serializable + data class InvoiceRoutingInfo(val extraHops: List>) : OnionPaymentPayloadTlv() { + override val tag: Long get() = InvoiceRoutingInfo.tag + override fun write(out: Output) { + for (routeHint in extraHops) { + LightningCodecs.writeByte(routeHint.size, out) + routeHint.map { + LightningCodecs.writeBytes(it.nodeId.value, out) + LightningCodecs.writeU64(it.shortChannelId.toLong(), out) + LightningCodecs.writeU32(it.feeBase.toLong().toInt(), out) + LightningCodecs.writeU32(it.feeProportionalMillionths.toInt(), out) + LightningCodecs.writeU16(it.cltvExpiryDelta.toInt(), out) + } + } + } + + companion object : TlvValueReader { + const val tag: Long = 66099 + override fun read(input: Input): InvoiceRoutingInfo { + val extraHops = mutableListOf>() + while (input.availableBytes > 0) { + val hopCount = LightningCodecs.byte(input) + val extraHop = (0 until hopCount).map { + PaymentRequest.TaggedField.ExtraHop( + PublicKey(LightningCodecs.bytes(input, 33)), + ShortChannelId(LightningCodecs.u64(input)), + MilliSatoshi(LightningCodecs.u32(input).toLong()), + LightningCodecs.u32(input).toLong(), + CltvExpiryDelta(LightningCodecs.u16(input)) + ) + } + extraHops.add(extraHop) + } + return InvoiceRoutingInfo(extraHops) + } + } + } + + /** An encrypted trampoline onion packet. */ + @Serializable + data class TrampolineOnion(val packet: OnionRoutingPacket) : OnionPaymentPayloadTlv() { + override val tag: Long get() = TrampolineOnion.tag + override fun write(out: Output) = OnionRoutingPacketSerializer(OnionRoutingPacket.TrampolinePacketLength).write(packet, out) + + companion object : TlvValueReader { + const val tag: Long = 66100 + override fun read(input: Input): TrampolineOnion = TrampolineOnion(OnionRoutingPacketSerializer(OnionRoutingPacket.TrampolinePacketLength).read(input)) + } + } +} + +object PaymentOnion { + + sealed class PerHopPayload { + + abstract fun write(out: Output) + + fun write(): ByteArray { + val out = ByteArrayOutput() + write(out) + return out.toByteArray() + } + + companion object { + val tlvSerializer = TlvStreamSerializer( + true, @Suppress("UNCHECKED_CAST") mapOf( + OnionPaymentPayloadTlv.AmountToForward.tag to OnionPaymentPayloadTlv.AmountToForward.Companion as TlvValueReader, + OnionPaymentPayloadTlv.OutgoingCltv.tag to OnionPaymentPayloadTlv.OutgoingCltv.Companion as TlvValueReader, + OnionPaymentPayloadTlv.OutgoingChannelId.tag to OnionPaymentPayloadTlv.OutgoingChannelId.Companion as TlvValueReader, + OnionPaymentPayloadTlv.PaymentData.tag to OnionPaymentPayloadTlv.PaymentData.Companion as TlvValueReader, + OnionPaymentPayloadTlv.PaymentMetadata.tag to OnionPaymentPayloadTlv.PaymentMetadata.Companion as TlvValueReader, + OnionPaymentPayloadTlv.InvoiceFeatures.tag to OnionPaymentPayloadTlv.InvoiceFeatures.Companion as TlvValueReader, + OnionPaymentPayloadTlv.OutgoingNodeId.tag to OnionPaymentPayloadTlv.OutgoingNodeId.Companion as TlvValueReader, + OnionPaymentPayloadTlv.InvoiceRoutingInfo.tag to OnionPaymentPayloadTlv.InvoiceRoutingInfo.Companion as TlvValueReader, + OnionPaymentPayloadTlv.TrampolineOnion.tag to OnionPaymentPayloadTlv.TrampolineOnion.Companion as TlvValueReader, + ) + ) + } + } + + interface PerHopPayloadReader { + fun read(input: Input): T + fun read(bytes: ByteArray): T = read(ByteArrayInput(bytes)) + } + + data class FinalPayload(val records: TlvStream) : PerHopPayload() { + val amount = records.get()!!.amount + val expiry = records.get()!!.cltv + val paymentSecret = records.get()!!.secret + val totalAmount = run { + val total = records.get()!!.totalAmount + if (total > 0.msat) total else amount + } + val paymentMetadata = records.get()?.data + + override fun write(out: Output) = tlvSerializer.write(records, out) + + companion object : PerHopPayloadReader { + override fun read(input: Input): FinalPayload = FinalPayload(tlvSerializer.read(input)) + + /** Create a single-part payment (total amount sent at once). */ + fun createSinglePartPayload(amount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, paymentMetadata: ByteVector?, userCustomTlvs: List = listOf()): FinalPayload { + val tlvs = buildList { + add(OnionPaymentPayloadTlv.AmountToForward(amount)) + add(OnionPaymentPayloadTlv.OutgoingCltv(expiry)) + add(OnionPaymentPayloadTlv.PaymentData(paymentSecret, amount)) + paymentMetadata?.let { add(OnionPaymentPayloadTlv.PaymentMetadata(it)) } + } + return FinalPayload(TlvStream(tlvs, userCustomTlvs)) + } + + /** Create a partial payment (total amount split between multiple payments). */ + fun createMultiPartPayload( + amount: MilliSatoshi, + totalAmount: MilliSatoshi, + expiry: CltvExpiry, + paymentSecret: ByteVector32, + paymentMetadata: ByteVector?, + additionalTlvs: List = listOf(), + userCustomTlvs: List = listOf() + ): FinalPayload { + val tlvs = buildList { + add(OnionPaymentPayloadTlv.AmountToForward(amount)) + add(OnionPaymentPayloadTlv.OutgoingCltv(expiry)) + add(OnionPaymentPayloadTlv.PaymentData(paymentSecret, totalAmount)) + paymentMetadata?.let { add(OnionPaymentPayloadTlv.PaymentMetadata(it)) } + addAll(additionalTlvs) + } + return FinalPayload(TlvStream(tlvs, userCustomTlvs)) + } + + /** Create a trampoline outer payload. */ + fun createTrampolinePayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, paymentSecret: ByteVector32, trampolinePacket: OnionRoutingPacket): FinalPayload { + val tlvs = buildList { + add(OnionPaymentPayloadTlv.AmountToForward(amount)) + add(OnionPaymentPayloadTlv.OutgoingCltv(expiry)) + add(OnionPaymentPayloadTlv.PaymentData(paymentSecret, totalAmount)) + add(OnionPaymentPayloadTlv.TrampolineOnion(trampolinePacket)) + } + return FinalPayload(TlvStream(tlvs)) + } + } + } + + data class ChannelRelayPayload(val records: TlvStream) : PerHopPayload() { + val amountToForward = records.get()!!.amount + val outgoingCltv = records.get()!!.cltv + val outgoingChannelId = records.get()!!.shortChannelId + + override fun write(out: Output) = tlvSerializer.write(records, out) + + companion object : PerHopPayloadReader { + override fun read(input: Input): ChannelRelayPayload = ChannelRelayPayload(tlvSerializer.read(input)) + + fun create(outgoingChannelId: ShortChannelId, amountToForward: MilliSatoshi, outgoingCltv: CltvExpiry): ChannelRelayPayload = + ChannelRelayPayload(TlvStream(listOf(OnionPaymentPayloadTlv.AmountToForward(amountToForward), OnionPaymentPayloadTlv.OutgoingCltv(outgoingCltv), OnionPaymentPayloadTlv.OutgoingChannelId(outgoingChannelId)))) + } + } + + data class NodeRelayPayload(val records: TlvStream) : PerHopPayload() { + val amountToForward = records.get()!!.amount + val outgoingCltv = records.get()!!.cltv + val outgoingNodeId = records.get()!!.nodeId + val totalAmount = run { + val paymentData = records.get() + when { + paymentData == null -> amountToForward + paymentData.totalAmount == MilliSatoshi(0) -> amountToForward + else -> paymentData.totalAmount + } + } + + // NB: the following fields are only included in the trampoline-to-legacy case. + val paymentSecret = records.get()?.secret + val paymentMetadata = records.get()?.data + val invoiceFeatures = records.get()?.features + val invoiceRoutingInfo = records.get()?.extraHops + + override fun write(out: Output) = tlvSerializer.write(records, out) + + companion object : PerHopPayloadReader { + override fun read(input: Input): NodeRelayPayload = NodeRelayPayload(tlvSerializer.read(input)) + + fun create(amount: MilliSatoshi, expiry: CltvExpiry, nextNodeId: PublicKey) = + NodeRelayPayload(TlvStream(listOf(OnionPaymentPayloadTlv.AmountToForward(amount), OnionPaymentPayloadTlv.OutgoingCltv(expiry), OnionPaymentPayloadTlv.OutgoingNodeId(nextNodeId)))) + + /** Create a trampoline inner payload instructing the trampoline node to relay via a non-trampoline payment. */ + fun createNodeRelayToNonTrampolinePayload(amount: MilliSatoshi, totalAmount: MilliSatoshi, expiry: CltvExpiry, targetNodeId: PublicKey, invoice: PaymentRequest): NodeRelayPayload { + // NB: we limit the number of routing hints to ensure we don't overflow the onion. + // A better solution is to provide the routing hints outside the onion (in the `update_add_htlc` tlv stream). + val prunedRoutingHints = invoice.routingInfo.shuffled().fold(listOf()) { previous, current -> + if (previous.flatMap { it.hints }.size + current.hints.size <= 4) { + previous + current + } else { + previous + } + }.map { it.hints } + return NodeRelayPayload( + TlvStream( + buildList { + add(OnionPaymentPayloadTlv.AmountToForward(amount)) + add(OnionPaymentPayloadTlv.OutgoingCltv(expiry)) + add(OnionPaymentPayloadTlv.OutgoingNodeId(targetNodeId)) + add(OnionPaymentPayloadTlv.PaymentData(invoice.paymentSecret, totalAmount)) + invoice.paymentMetadata?.let { add(OnionPaymentPayloadTlv.PaymentMetadata(it)) } + add(OnionPaymentPayloadTlv.InvoiceFeatures(invoice.features)) + add(OnionPaymentPayloadTlv.InvoiceRoutingInfo(prunedRoutingHints)) + } + ) + ) + } + } + } + +} diff --git a/src/commonTest/kotlin/fr/acinq/lightning/FeaturesTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/FeaturesTestsCommon.kt index 919027da2..4fcc960aa 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/FeaturesTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/FeaturesTestsCommon.kt @@ -191,6 +191,37 @@ class FeaturesTestsCommon : LightningTestSuite() { } } + @Test + fun `filter features based on their usage`() { + val features = Features( + mapOf( + OptionDataLossProtect to FeatureSupport.Optional, + InitialRoutingSync to FeatureSupport.Optional, + VariableLengthOnion to FeatureSupport.Mandatory, + PaymentMetadata to FeatureSupport.Optional, + ), + setOf(UnknownFeature(753), UnknownFeature(852)), + ) + assertEquals( + Features( + mapOf(OptionDataLossProtect to FeatureSupport.Optional, InitialRoutingSync to FeatureSupport.Optional, VariableLengthOnion to FeatureSupport.Mandatory), + setOf(UnknownFeature(753), UnknownFeature(852)), + ), features.initFeatures() + ) + assertEquals( + Features( + mapOf(OptionDataLossProtect to FeatureSupport.Optional, VariableLengthOnion to FeatureSupport.Mandatory), + setOf(UnknownFeature(753), UnknownFeature(852)), + ), features.nodeAnnouncementFeatures() + ) + assertEquals( + Features( + mapOf(VariableLengthOnion to FeatureSupport.Mandatory, PaymentMetadata to FeatureSupport.Optional), + setOf(UnknownFeature(753), UnknownFeature(852)), + ), features.invoiceFeatures() + ) + } + @Test fun `features to bytes`() { val testCases = mapOf( diff --git a/src/commonTest/kotlin/fr/acinq/lightning/channel/TestsHelper.kt b/src/commonTest/kotlin/fr/acinq/lightning/channel/TestsHelper.kt index 4c61d1060..fe891d18e 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/channel/TestsHelper.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/channel/TestsHelper.kt @@ -6,7 +6,7 @@ import fr.acinq.lightning.Lightning.randomBytes32 import fr.acinq.lightning.blockchain.* import fr.acinq.lightning.blockchain.fee.FeeratePerKw import fr.acinq.lightning.blockchain.fee.OnChainFeerates -import fr.acinq.lightning.payment.OutgoingPacket +import fr.acinq.lightning.payment.OutgoingPaymentPacket import fr.acinq.lightning.router.ChannelHop import fr.acinq.lightning.serialization.Serialization import fr.acinq.lightning.tests.TestConstants @@ -320,7 +320,7 @@ object TestsHelper { val expiry = CltvExpiryDelta(144).toCltvExpiry(currentBlockHeight) val dummyKey = PrivateKey(ByteVector32("0101010101010101010101010101010101010101010101010101010101010101")).publicKey() val dummyUpdate = ChannelUpdate(ByteVector64.Zeroes, ByteVector32.Zeroes, ShortChannelId(144, 0, 0), 0, 0, 0, CltvExpiryDelta(1), 0.msat, 0.msat, 0, null) - val cmd = OutgoingPacket.buildCommand(paymentId, paymentHash, listOf(ChannelHop(dummyKey, destination, dummyUpdate)), FinalPayload.createSinglePartPayload(amount, expiry, randomBytes32())).first.copy(commit = false) + val cmd = OutgoingPaymentPacket.buildCommand(paymentId, paymentHash, listOf(ChannelHop(dummyKey, destination, dummyUpdate)), PaymentOnion.FinalPayload.createSinglePartPayload(amount, expiry, randomBytes32(), null)).first.copy(commit = false) return Pair(paymentPreimage, cmd) } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt index 6ea4321d4..eaf92d283 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/payment/IncomingPaymentHandlerTestsCommon.kt @@ -179,7 +179,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenFeeSatoshis = 100.sat, paymentHash = ByteVector32.One, // <-- not associated to a pending invoice expireAt = Long.MAX_VALUE, - finalPacket = OutgoingPacket.buildPacket( + finalPacket = OutgoingPaymentPacket.buildPacket( paymentHash = ByteVector32.One, // <-- has to be the same as the one above otherwise encryption fails hops = channelHops(paymentHandler.nodeParams.nodeId), finalPayload = makeMppPayload(defaultAmount, defaultAmount, randomBytes32()), @@ -196,7 +196,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenRequest.chainHash, payToOpenRequest.paymentHash, PayToOpenResponse.Result.Failure( - OutgoingPacket.buildHtlcFailure( + OutgoingPaymentPacket.buildHtlcFailure( paymentHandler.nodeParams.nodePrivateKey, payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, @@ -222,7 +222,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenRequest.chainHash, payToOpenRequest.paymentHash, PayToOpenResponse.Result.Failure( - OutgoingPacket.buildHtlcFailure( + OutgoingPaymentPacket.buildHtlcFailure( paymentHandler.nodeParams.nodePrivateKey, payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, @@ -248,7 +248,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenRequest.chainHash, payToOpenRequest.paymentHash, PayToOpenResponse.Result.Failure( - OutgoingPacket.buildHtlcFailure( + OutgoingPaymentPacket.buildHtlcFailure( paymentHandler.nodeParams.nodePrivateKey, payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, @@ -281,7 +281,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenRequest.chainHash, payToOpenRequest.paymentHash, PayToOpenResponse.Result.Failure( - OutgoingPacket.buildHtlcFailure( + OutgoingPaymentPacket.buildHtlcFailure( paymentHandler.nodeParams.nodePrivateKey, payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, @@ -307,7 +307,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenFeeSatoshis = 100.sat, paymentHash = incomingPayment.paymentHash, expireAt = Long.MAX_VALUE, - finalPacket = OutgoingPacket.buildPacket( + finalPacket = OutgoingPaymentPacket.buildPacket( paymentHash = incomingPayment.paymentHash, hops = trampolineHops, finalPayload = makeMppPayload(defaultAmount, defaultAmount, paymentSecret.reversed()), // <-- wrong secret @@ -324,7 +324,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenRequest.chainHash, payToOpenRequest.paymentHash, PayToOpenResponse.Result.Failure( - OutgoingPacket.buildHtlcFailure( + OutgoingPaymentPacket.buildHtlcFailure( paymentHandler.nodeParams.nodePrivateKey, payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, @@ -581,7 +581,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenRequest.chainHash, payToOpenRequest.paymentHash, PayToOpenResponse.Result.Failure( - OutgoingPacket.buildHtlcFailure( + OutgoingPaymentPacket.buildHtlcFailure( paymentHandler.nodeParams.nodePrivateKey, payToOpenRequest.paymentHash, payToOpenRequest.finalPacket, @@ -640,7 +640,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenRequest1.chainHash, payToOpenRequest1.paymentHash, PayToOpenResponse.Result.Failure( - OutgoingPacket.buildHtlcFailure( + OutgoingPaymentPacket.buildHtlcFailure( paymentHandler.nodeParams.nodePrivateKey, payToOpenRequest1.paymentHash, payToOpenRequest1.finalPacket, @@ -654,7 +654,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenRequest2.chainHash, payToOpenRequest2.paymentHash, PayToOpenResponse.Result.Failure( - OutgoingPacket.buildHtlcFailure( + OutgoingPaymentPacket.buildHtlcFailure( paymentHandler.nodeParams.nodePrivateKey, payToOpenRequest2.paymentHash, payToOpenRequest2.finalPacket, @@ -1173,12 +1173,12 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { return listOf(channelHop) } - private fun makeCmdAddHtlc(destination: PublicKey, paymentHash: ByteVector32, finalPayload: FinalPayload): CMD_ADD_HTLC { - return OutgoingPacket.buildCommand(UUID.randomUUID(), paymentHash, channelHops(destination), finalPayload).first.copy(commit = true) + private fun makeCmdAddHtlc(destination: PublicKey, paymentHash: ByteVector32, finalPayload: PaymentOnion.FinalPayload): CMD_ADD_HTLC { + return OutgoingPaymentPacket.buildCommand(UUID.randomUUID(), paymentHash, channelHops(destination), finalPayload).first.copy(commit = true) } - private fun makeUpdateAddHtlc(id: Long, channelId: ByteVector32, destination: IncomingPaymentHandler, paymentHash: ByteVector32, finalPayload: FinalPayload): UpdateAddHtlc { - val (_, _, packetAndSecrets) = OutgoingPacket.buildPacket(paymentHash, channelHops(destination.nodeParams.nodeId), finalPayload, OnionRoutingPacket.PaymentPacketLength) + private fun makeUpdateAddHtlc(id: Long, channelId: ByteVector32, destination: IncomingPaymentHandler, paymentHash: ByteVector32, finalPayload: PaymentOnion.FinalPayload): UpdateAddHtlc { + val (_, _, packetAndSecrets) = OutgoingPaymentPacket.buildPacket(paymentHash, channelHops(destination.nodeParams.nodeId), finalPayload, OnionRoutingPacket.PaymentPacketLength) return UpdateAddHtlc(channelId, id, finalPayload.amount, paymentHash, finalPayload.expiry, packetAndSecrets.packet) } @@ -1188,12 +1188,12 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { paymentSecret: ByteVector32, cltvExpiryDelta: CltvExpiryDelta = CltvExpiryDelta(144), currentBlockHeight: Int = TestConstants.defaultBlockHeight - ): FinalPayload { + ): PaymentOnion.FinalPayload { val expiry = cltvExpiryDelta.toCltvExpiry(currentBlockHeight.toLong()) - return FinalPayload.createMultiPartPayload(amount, totalAmount, expiry, paymentSecret) + return PaymentOnion.FinalPayload.createMultiPartPayload(amount, totalAmount, expiry, paymentSecret, null) } - private fun makePayToOpenRequest(incomingPayment: IncomingPayment, finalPayload: FinalPayload, payToOpenMinAmount: MilliSatoshi = 10_000.msat): PayToOpenRequest { + private fun makePayToOpenRequest(incomingPayment: IncomingPayment, finalPayload: PaymentOnion.FinalPayload, payToOpenMinAmount: MilliSatoshi = 10_000.msat): PayToOpenRequest { return PayToOpenRequest( chainHash = ByteVector32.Zeroes, fundingSatoshis = 100_000.sat, @@ -1202,7 +1202,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { payToOpenFeeSatoshis = finalPayload.amount.truncateToSatoshi() * 0.1, // 10% paymentHash = incomingPayment.paymentHash, expireAt = Long.MAX_VALUE, - finalPacket = OutgoingPacket.buildPacket( + finalPacket = OutgoingPaymentPacket.buildPacket( paymentHash = incomingPayment.paymentHash, hops = channelHops(TestConstants.Bob.nodeParams.nodeId), finalPayload = finalPayload, @@ -1213,6 +1213,7 @@ class IncomingPaymentHandlerTestsCommon : LightningTestSuite() { private suspend fun makeIncomingPayment(payee: IncomingPaymentHandler, amount: MilliSatoshi?, expirySeconds: Long? = null, timestamp: Long = currentTimestampSeconds()): Pair { val paymentRequest = payee.createInvoice(defaultPreimage, amount, "unit test", listOf(), expirySeconds, timestamp) + assertNotNull(paymentRequest.paymentMetadata) return Pair(payee.db.getIncomingPayment(paymentRequest.paymentHash)!!, paymentRequest.paymentSecret) } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt index 27c09c24d..560738a8c 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/payment/OutgoingPaymentHandlerTestsCommon.kt @@ -231,7 +231,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { // The recipient should receive the right amount and expiry. val payloadBytesC = Sphinx.peel(recipientKey, payment.paymentHash, packetC, OnionRoutingPacket.TrampolinePacketLength).right!! - val payloadC = FinalPayload.read(payloadBytesC.payload.toByteArray()) + val payloadC = PaymentOnion.FinalPayload.read(payloadBytesC.payload.toByteArray()) assertEquals(200_000.msat, payloadC.amount) assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + PaymentRequest.DEFAULT_MIN_FINAL_EXPIRY_DELTA, payloadC.expiry) assertEquals(payloadC.amount, payloadC.totalAmount) @@ -286,7 +286,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { // The recipient should receive the right amount and expiry. val payloadBytesC = Sphinx.peel(recipientKey, payment.paymentHash, packetC, OnionRoutingPacket.TrampolinePacketLength).right!! - val payloadC = FinalPayload.read(payloadBytesC.payload.toByteArray()) + val payloadC = PaymentOnion.FinalPayload.read(payloadBytesC.payload.toByteArray()) assertEquals(300_000.msat, payloadC.amount) assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + PaymentRequest.DEFAULT_MIN_FINAL_EXPIRY_DELTA, payloadC.expiry) assertEquals(payloadC.amount, payloadC.totalAmount) @@ -423,7 +423,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { adds.forEach { assertEquals(payment.paymentHash, it.second.paymentHash) } adds.forEach { (channelId, add) -> // Bob should receive the right final information. - val payloadB = IncomingPacket.decrypt(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey).right!! + val payloadB = IncomingPaymentPacket.decrypt(makeUpdateAddHtlc(channelId, add), TestConstants.Bob.nodeParams.nodePrivateKey).right!! assertEquals(add.amount, payloadB.amount) assertEquals(300_000.msat, payloadB.totalAmount) assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + PaymentRequest.DEFAULT_MIN_FINAL_EXPIRY_DELTA, payloadB.expiry) @@ -498,7 +498,7 @@ class OutgoingPaymentHandlerTestsCommon : LightningTestSuite() { // The recipient should receive the right amount and expiry. val payloadBytesC = Sphinx.peel(recipientKey, payment.paymentHash, packetC, OnionRoutingPacket.TrampolinePacketLength).right!! - val payloadC = FinalPayload.read(payloadBytesC.payload.toByteArray()) + val payloadC = PaymentOnion.FinalPayload.read(payloadBytesC.payload.toByteArray()) assertEquals(300_000.msat, payloadC.amount) assertEquals(CltvExpiry(TestConstants.defaultBlockHeight.toLong()) + PaymentRequest.DEFAULT_MIN_FINAL_EXPIRY_DELTA, payloadC.expiry) assertEquals(payloadC.amount, payloadC.totalAmount) diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentPacketTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentPacketTestsCommon.kt index 3baca2e40..0073857bc 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentPacketTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentPacketTestsCommon.kt @@ -1,9 +1,13 @@ package fr.acinq.lightning.payment -import fr.acinq.bitcoin.* +import fr.acinq.bitcoin.Block +import fr.acinq.bitcoin.ByteVector +import fr.acinq.bitcoin.Crypto +import fr.acinq.bitcoin.PrivateKey import fr.acinq.bitcoin.io.ByteArrayInput import fr.acinq.lightning.* import fr.acinq.lightning.Lightning.nodeFee +import fr.acinq.lightning.Lightning.randomBytes import fr.acinq.lightning.Lightning.randomBytes32 import fr.acinq.lightning.Lightning.randomBytes64 import fr.acinq.lightning.Lightning.randomKey @@ -50,6 +54,7 @@ class PaymentPacketTestsCommon : LightningTestSuite() { private val paymentPreimage = randomBytes32() private val paymentHash = Crypto.sha256(paymentPreimage).toByteVector32() private val paymentSecret = randomBytes32() + private val paymentMetadata = randomBytes(64).toByteVector() private val expiryDE = finalExpiry private val amountDE = finalAmount @@ -83,8 +88,9 @@ class PaymentPacketTestsCommon : LightningTestSuite() { ) private fun testBuildOnion() { - val finalPayload = FinalPayload(TlvStream(listOf(OnionTlv.AmountToForward(finalAmount), OnionTlv.OutgoingCltv(finalExpiry), OnionTlv.PaymentData(paymentSecret, finalAmount)))) - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket(paymentHash, hops, finalPayload, OnionRoutingPacket.PaymentPacketLength) + val finalPayload = + PaymentOnion.FinalPayload(TlvStream(listOf(OnionPaymentPayloadTlv.AmountToForward(finalAmount), OnionPaymentPayloadTlv.OutgoingCltv(finalExpiry), OnionPaymentPayloadTlv.PaymentData(paymentSecret, finalAmount)))) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket(paymentHash, hops, finalPayload, OnionRoutingPacket.PaymentPacketLength) assertEquals(amountAB, firstAmount) assertEquals(expiryAB, firstExpiry) assertEquals(OnionRoutingPacket.PaymentPacketLength, onion.packet.payload.size()) @@ -115,7 +121,7 @@ class PaymentPacketTestsCommon : LightningTestSuite() { assertEquals(channelUpdateDE.shortChannelId, payloadD.outgoingChannelId) val addE = UpdateAddHtlc(randomBytes32(), 2, amountDE, paymentHash, expiryDE, packetE) - val payloadE = IncomingPacket.decrypt(addE, privE).right!! + val payloadE = IncomingPaymentPacket.decrypt(addE, privE).right!! assertEquals(finalAmount, payloadE.amount) assertEquals(finalAmount, payloadE.totalAmount) assertEquals(finalExpiry, payloadE.expiry) @@ -123,22 +129,22 @@ class PaymentPacketTestsCommon : LightningTestSuite() { } // Wallets don't need to decrypt onions for intermediate nodes, but it's useful to test that encryption works correctly. - fun decryptChannelRelay(add: UpdateAddHtlc, privateKey: PrivateKey): Pair { + fun decryptChannelRelay(add: UpdateAddHtlc, privateKey: PrivateKey): Pair { val decrypted = Sphinx.peel(privateKey, add.paymentHash, add.onionRoutingPacket, OnionRoutingPacket.PaymentPacketLength).right!! assertFalse(decrypted.isLastPacket) - val decoded = ChannelRelayPayload.read(ByteArrayInput(decrypted.payload.toByteArray())) + val decoded = PaymentOnion.ChannelRelayPayload.read(ByteArrayInput(decrypted.payload.toByteArray())) return Pair(decoded, decrypted.nextPacket) } // Wallets don't need to decrypt onions for intermediate nodes, but it's useful to test that encryption works correctly. - fun decryptNodeRelay(add: UpdateAddHtlc, privateKey: PrivateKey): Triple { + fun decryptNodeRelay(add: UpdateAddHtlc, privateKey: PrivateKey): Triple { val decrypted = Sphinx.peel(privateKey, add.paymentHash, add.onionRoutingPacket, OnionRoutingPacket.PaymentPacketLength).right!! assertTrue(decrypted.isLastPacket) - val outerPayload = FinalPayload.read(ByteArrayInput(decrypted.payload.toByteArray())) - val trampolineOnion = outerPayload.records.get() + val outerPayload = PaymentOnion.FinalPayload.read(ByteArrayInput(decrypted.payload.toByteArray())) + val trampolineOnion = outerPayload.records.get() assertNotNull(trampolineOnion) val decryptedInner = Sphinx.peel(privateKey, add.paymentHash, trampolineOnion.packet, OnionRoutingPacket.TrampolinePacketLength).right!! - val innerPayload = NodeRelayPayload.read(ByteArrayInput(decryptedInner.payload.toByteArray())) + val innerPayload = PaymentOnion.NodeRelayPayload.read(ByteArrayInput(decryptedInner.payload.toByteArray())) return Triple(outerPayload, innerPayload, decryptedInner.nextPacket) } @@ -151,7 +157,7 @@ class PaymentPacketTestsCommon : LightningTestSuite() { @Test fun `build a command including the onion`() { - val (add, _) = OutgoingPacket.buildCommand(UUID.randomUUID(), paymentHash, hops, FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret)) + val (add, _) = OutgoingPaymentPacket.buildCommand(UUID.randomUUID(), paymentHash, hops, PaymentOnion.FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, null)) assertTrue(add.amount > finalAmount) assertEquals(add.cltvExpiry, finalExpiry + channelUpdateDE.cltvExpiryDelta + channelUpdateCD.cltvExpiryDelta + channelUpdateBC.cltvExpiryDelta) assertEquals(add.paymentHash, paymentHash) @@ -164,7 +170,7 @@ class PaymentPacketTestsCommon : LightningTestSuite() { @Test fun `build a command with no hops`() { val paymentSecret = randomBytes32() - val (add, _) = OutgoingPacket.buildCommand(UUID.randomUUID(), paymentHash, hops.take(1), FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret)) + val (add, _) = OutgoingPaymentPacket.buildCommand(UUID.randomUUID(), paymentHash, hops.take(1), PaymentOnion.FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, paymentSecret, paymentMetadata)) assertEquals(add.amount, finalAmount) assertEquals(add.cltvExpiry, finalExpiry) assertEquals(add.paymentHash, paymentHash) @@ -172,11 +178,12 @@ class PaymentPacketTestsCommon : LightningTestSuite() { // let's peel the onion val addB = UpdateAddHtlc(randomBytes32(), 0, finalAmount, paymentHash, finalExpiry, add.onion) - val finalPayload = IncomingPacket.decrypt(addB, privB).right!! + val finalPayload = IncomingPaymentPacket.decrypt(addB, privB).right!! assertEquals(finalPayload.amount, finalAmount) assertEquals(finalPayload.totalAmount, finalAmount) assertEquals(finalPayload.expiry, finalExpiry) assertEquals(paymentSecret, finalPayload.paymentSecret) + assertEquals(paymentMetadata, finalPayload.paymentMetadata) } @Test @@ -186,19 +193,19 @@ class PaymentPacketTestsCommon : LightningTestSuite() { // / \ / \ // a -> b -> c d e - val (amountAC, expiryAC, trampolineOnion) = OutgoingPacket.buildPacket( + val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineHops, - FinalPayload.createMultiPartPayload(finalAmount, finalAmount * 3, finalExpiry, paymentSecret), + PaymentOnion.FinalPayload.createMultiPartPayload(finalAmount, finalAmount * 3, finalExpiry, paymentSecret, paymentMetadata), OnionRoutingPacket.TrampolinePacketLength ) assertEquals(amountBC, amountAC) assertEquals(expiryBC, expiryAC) - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket( + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineChannelHops, - FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), + PaymentOnion.FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), OnionRoutingPacket.PaymentPacketLength ) assertEquals(amountAB, firstAmount) @@ -206,7 +213,7 @@ class PaymentPacketTestsCommon : LightningTestSuite() { val addB = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) val (payloadB, packetC) = decryptChannelRelay(addB, privB) - assertEquals(ChannelRelayPayload.create(channelUpdateBC.shortChannelId, amountBC, expiryBC), payloadB) + assertEquals(PaymentOnion.ChannelRelayPayload.create(channelUpdateBC.shortChannelId, amountBC, expiryBC), payloadB) val addC = UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC) val (outerC, innerC, packetD) = decryptNodeRelay(addC, privC) @@ -221,10 +228,10 @@ class PaymentPacketTestsCommon : LightningTestSuite() { assertNull(innerC.paymentSecret) // c forwards the trampoline payment to d. - val (amountD, expiryD, onionD) = OutgoingPacket.buildPacket( + val (amountD, expiryD, onionD) = OutgoingPaymentPacket.buildPacket( paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), - FinalPayload.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), + PaymentOnion.FinalPayload.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), OnionRoutingPacket.PaymentPacketLength ) assertEquals(amountCD, amountD) @@ -242,17 +249,27 @@ class PaymentPacketTestsCommon : LightningTestSuite() { assertNull(innerD.paymentSecret) // d forwards the trampoline payment to e. - val (amountE, expiryE, onionE) = OutgoingPacket.buildPacket( + val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket( paymentHash, listOf(ChannelHop(d, e, channelUpdateDE)), - FinalPayload.createTrampolinePayload(amountDE, amountDE, expiryDE, randomBytes32(), packetE), + PaymentOnion.FinalPayload.createTrampolinePayload(amountDE, amountDE, expiryDE, randomBytes32(), packetE), OnionRoutingPacket.PaymentPacketLength ) assertEquals(amountDE, amountE) assertEquals(expiryDE, expiryE) val addE = UpdateAddHtlc(randomBytes32(), 4, amountE, paymentHash, expiryE, onionE.packet) - val payloadE = IncomingPacket.decrypt(addE, privE).right!! - assertEquals(payloadE, FinalPayload(TlvStream(listOf(OnionTlv.AmountToForward(finalAmount), OnionTlv.OutgoingCltv(finalExpiry), OnionTlv.PaymentData(paymentSecret, finalAmount * 3))))) + val payloadE = IncomingPaymentPacket.decrypt(addE, privE).right!! + val expectedFinalPayload = PaymentOnion.FinalPayload( + TlvStream( + listOf( + OnionPaymentPayloadTlv.AmountToForward(finalAmount), + OnionPaymentPayloadTlv.OutgoingCltv(finalExpiry), + OnionPaymentPayloadTlv.PaymentData(paymentSecret, finalAmount * 3), + OnionPaymentPayloadTlv.PaymentMetadata(paymentMetadata) + ) + ) + ) + assertEquals(payloadE, expectedFinalPayload) } @Test @@ -268,19 +285,20 @@ class PaymentPacketTestsCommon : LightningTestSuite() { "lnbcrt", finalAmount, currentTimestampSeconds(), e, listOf( PaymentRequest.TaggedField.PaymentHash(paymentHash), PaymentRequest.TaggedField.PaymentSecret(paymentSecret), + PaymentRequest.TaggedField.PaymentMetadata(paymentMetadata), PaymentRequest.TaggedField.DescriptionHash(randomBytes32()), PaymentRequest.TaggedField.Features(invoiceFeatures.toByteArray().toByteVector()), PaymentRequest.TaggedField.RoutingInfo(routingHints) ), ByteVector.empty ) - val (amountAC, expiryAC, trampolineOnion) = OutgoingPacket.buildTrampolineToLegacyPacket(invoice, trampolineHops, FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32())) + val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildTrampolineToLegacyPacket(invoice, trampolineHops, PaymentOnion.FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), null)) assertEquals(amountBC, amountAC) assertEquals(expiryBC, expiryAC) - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket( + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineChannelHops, - FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), + PaymentOnion.FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), OnionRoutingPacket.PaymentPacketLength ) assertEquals(amountAB, firstAmount) @@ -303,10 +321,10 @@ class PaymentPacketTestsCommon : LightningTestSuite() { assertNull(innerC.paymentSecret) // c forwards the trampoline payment to d. - val (amountD, expiryD, onionD) = OutgoingPacket.buildPacket( + val (amountD, expiryD, onionD) = OutgoingPaymentPacket.buildPacket( paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), - FinalPayload.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), + PaymentOnion.FinalPayload.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), OnionRoutingPacket.PaymentPacketLength ) assertEquals(amountCD, amountD) @@ -322,6 +340,7 @@ class PaymentPacketTestsCommon : LightningTestSuite() { assertEquals(e, innerD.outgoingNodeId) assertEquals(finalAmount, innerD.totalAmount) assertEquals(invoice.paymentSecret, innerD.paymentSecret) + assertEquals(invoice.paymentMetadata, innerD.paymentMetadata) assertEquals(ByteVector("024100"), innerD.invoiceFeatures) // var_onion_optin, payment_secret, basic_mpp assertEquals(listOf(routingHints), innerD.invoiceRoutingInfo) } @@ -340,134 +359,149 @@ class PaymentPacketTestsCommon : LightningTestSuite() { PaymentRequest.TaggedField.RoutingInfo(routingHintOverflow) ), ByteVector.empty ) - assertFails { OutgoingPacket.buildTrampolineToLegacyPacket(invoice, trampolineHops, FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32())) } + assertFails { OutgoingPaymentPacket.buildTrampolineToLegacyPacket(invoice, trampolineHops, PaymentOnion.FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), null)) } } @Test fun `fail to decrypt when the onion is invalid`() { - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket(paymentHash, hops, FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32()), OnionRoutingPacket.PaymentPacketLength) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket(paymentHash, hops, PaymentOnion.FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), null), OnionRoutingPacket.PaymentPacketLength) val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet.copy(payload = onion.packet.payload.reversed())) - val failure = IncomingPacket.decrypt(add, privB) + val failure = IncomingPaymentPacket.decrypt(add, privB) assertTrue(failure.isLeft) assertEquals(InvalidOnionHmac.code, failure.left!!.code) } @Test fun `fail to decrypt when the trampoline onion is invalid`() { - val (amountAC, expiryAC, trampolineOnion) = OutgoingPacket.buildPacket( + val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineHops, - FinalPayload.createMultiPartPayload(finalAmount, finalAmount * 2, finalExpiry, paymentSecret), + PaymentOnion.FinalPayload.createMultiPartPayload(finalAmount, finalAmount * 2, finalExpiry, paymentSecret, null), OnionRoutingPacket.TrampolinePacketLength ) - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket( + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineChannelHops, - FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet.copy(payload = trampolineOnion.packet.payload.reversed())), + PaymentOnion.FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet.copy(payload = trampolineOnion.packet.payload.reversed())), OnionRoutingPacket.PaymentPacketLength ) val addB = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) val (_, packetC) = decryptChannelRelay(addB, privB) val addC = UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC) - val failure = IncomingPacket.decrypt(addC, privC) + val failure = IncomingPaymentPacket.decrypt(addC, privC) assertTrue(failure.isLeft) assertEquals(InvalidOnionHmac.code, failure.left!!.code) } @Test fun `fail to decrypt when payment hash doesn't match associated data`() { - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket(paymentHash.reversed(), hops, FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32()), OnionRoutingPacket.PaymentPacketLength) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( + paymentHash.reversed(), + hops, + PaymentOnion.FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), paymentMetadata), + OnionRoutingPacket.PaymentPacketLength + ) val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet) - val failure = IncomingPacket.decrypt(add, privB) + val failure = IncomingPaymentPacket.decrypt(add, privB) assertTrue(failure.isLeft) assertEquals(InvalidOnionHmac.code, failure.left!!.code) } @Test fun `fail to decrypt at the final node when amount has been modified by next-to-last node`() { - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket(paymentHash, hops.take(1), FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32()), OnionRoutingPacket.PaymentPacketLength) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( + paymentHash, + hops.take(1), + PaymentOnion.FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), paymentMetadata), + OnionRoutingPacket.PaymentPacketLength + ) val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount - 100.msat, paymentHash, firstExpiry, onion.packet) - val failure = IncomingPacket.decrypt(add, privB) + val failure = IncomingPaymentPacket.decrypt(add, privB) assertEquals(Either.Left(FinalIncorrectHtlcAmount(firstAmount - 100.msat)), failure) } @Test fun `fail to decrypt at the final node when expiry has been modified by next-to-last node`() { - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket(paymentHash, hops.take(1), FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32()), OnionRoutingPacket.PaymentPacketLength) + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( + paymentHash, + hops.take(1), + PaymentOnion.FinalPayload.createSinglePartPayload(finalAmount, finalExpiry, randomBytes32(), paymentMetadata), + OnionRoutingPacket.PaymentPacketLength + ) val add = UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry - CltvExpiryDelta(12), onion.packet) - val failure = IncomingPacket.decrypt(add, privB) + val failure = IncomingPaymentPacket.decrypt(add, privB) assertEquals(Either.Left(FinalIncorrectCltvExpiry(firstExpiry - CltvExpiryDelta(12))), failure) } @Test fun `fail to decrypt at the final trampoline node when amount has been modified by next-to-last trampoline`() { - val (amountAC, expiryAC, trampolineOnion) = OutgoingPacket.buildPacket( + val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineHops, - FinalPayload.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret), + PaymentOnion.FinalPayload.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret, null), OnionRoutingPacket.TrampolinePacketLength ) - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket( + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineChannelHops, - FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), + PaymentOnion.FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), OnionRoutingPacket.PaymentPacketLength ) val (_, packetC) = decryptChannelRelay(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet), privB) val (_, _, packetD) = decryptNodeRelay(UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC), privC) // c forwards the trampoline payment to d. - val (amountD, expiryD, onionD) = OutgoingPacket.buildPacket( + val (amountD, expiryD, onionD) = OutgoingPaymentPacket.buildPacket( paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), - FinalPayload.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), + PaymentOnion.FinalPayload.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), OnionRoutingPacket.PaymentPacketLength ) val (_, _, packetE) = decryptNodeRelay(UpdateAddHtlc(randomBytes32(), 3, amountD, paymentHash, expiryD, onionD.packet), privD) // d forwards an invalid amount to e (the outer total amount doesn't match the inner amount). val invalidTotalAmount = amountDE + 100.msat - val (amountE, expiryE, onionE) = OutgoingPacket.buildPacket( + val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket( paymentHash, listOf(ChannelHop(d, e, channelUpdateDE)), - FinalPayload.createTrampolinePayload(amountDE, invalidTotalAmount, expiryDE, randomBytes32(), packetE), + PaymentOnion.FinalPayload.createTrampolinePayload(amountDE, invalidTotalAmount, expiryDE, randomBytes32(), packetE), OnionRoutingPacket.PaymentPacketLength ) - val failure = IncomingPacket.decrypt(UpdateAddHtlc(randomBytes32(), 4, amountE, paymentHash, expiryE, onionE.packet), privE) + val failure = IncomingPaymentPacket.decrypt(UpdateAddHtlc(randomBytes32(), 4, amountE, paymentHash, expiryE, onionE.packet), privE) assertEquals(Either.Left(FinalIncorrectHtlcAmount(invalidTotalAmount)), failure) } @Test fun `fail to decrypt at the final trampoline node when expiry has been modified by next-to-last trampoline`() { - val (amountAC, expiryAC, trampolineOnion) = OutgoingPacket.buildPacket( + val (amountAC, expiryAC, trampolineOnion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineHops, - FinalPayload.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret), + PaymentOnion.FinalPayload.createMultiPartPayload(finalAmount, finalAmount, finalExpiry, paymentSecret, null), OnionRoutingPacket.TrampolinePacketLength ) - val (firstAmount, firstExpiry, onion) = OutgoingPacket.buildPacket( + val (firstAmount, firstExpiry, onion) = OutgoingPaymentPacket.buildPacket( paymentHash, trampolineChannelHops, - FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), + PaymentOnion.FinalPayload.createTrampolinePayload(amountAC, amountAC, expiryAC, randomBytes32(), trampolineOnion.packet), OnionRoutingPacket.PaymentPacketLength ) val (_, packetC) = decryptChannelRelay(UpdateAddHtlc(randomBytes32(), 1, firstAmount, paymentHash, firstExpiry, onion.packet), privB) val (_, _, packetD) = decryptNodeRelay(UpdateAddHtlc(randomBytes32(), 2, amountBC, paymentHash, expiryBC, packetC), privC) // c forwards the trampoline payment to d. - val (amountD, expiryD, onionD) = OutgoingPacket.buildPacket( + val (amountD, expiryD, onionD) = OutgoingPaymentPacket.buildPacket( paymentHash, listOf(ChannelHop(c, d, channelUpdateCD)), - FinalPayload.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), + PaymentOnion.FinalPayload.createTrampolinePayload(amountCD, amountCD, expiryCD, randomBytes32(), packetD), OnionRoutingPacket.PaymentPacketLength ) val (_, _, packetE) = decryptNodeRelay(UpdateAddHtlc(randomBytes32(), 3, amountD, paymentHash, expiryD, onionD.packet), privD) // d forwards an invalid expiry to e (the outer expiry doesn't match the inner expiry). val invalidExpiry = expiryDE - CltvExpiryDelta(12) - val (amountE, expiryE, onionE) = OutgoingPacket.buildPacket( + val (amountE, expiryE, onionE) = OutgoingPaymentPacket.buildPacket( paymentHash, listOf(ChannelHop(d, e, channelUpdateDE)), - FinalPayload.createTrampolinePayload(amountDE, amountDE, invalidExpiry, randomBytes32(), packetE), + PaymentOnion.FinalPayload.createTrampolinePayload(amountDE, amountDE, invalidExpiry, randomBytes32(), packetE), OnionRoutingPacket.PaymentPacketLength ) - val failure = IncomingPacket.decrypt(UpdateAddHtlc(randomBytes32(), 4, amountE, paymentHash, expiryE, onionE.packet), privE) + val failure = IncomingPaymentPacket.decrypt(UpdateAddHtlc(randomBytes32(), 4, amountE, paymentHash, expiryE, onionE.packet), privE) assertEquals(Either.Left(FinalIncorrectCltvExpiry(invalidExpiry)), failure) } } diff --git a/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentRequestTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentRequestTestsCommon.kt index a300b8ff1..3c69d70d7 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentRequestTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/payment/PaymentRequestTestsCommon.kt @@ -3,6 +3,7 @@ package fr.acinq.lightning.payment import fr.acinq.bitcoin.* import fr.acinq.lightning.* import fr.acinq.lightning.Lightning.randomBytes32 +import fr.acinq.lightning.Lightning.randomKey import fr.acinq.lightning.tests.utils.LightningTestSuite import fr.acinq.lightning.utils.* import fr.acinq.secp256k1.Hex @@ -362,6 +363,24 @@ class PaymentRequestTestsCommon : LightningTestSuite() { assertEquals(ref, check) } + @Test + fun `On mainnet, please send 0,01 BTC with payment metadata 0x01fafaf0`() { + val ref = "lnbc10m1pvjluezpp5qqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqqqsyqcyq5rqwzqfqypqdp9wpshjmt9de6zqmt9w3skgct5vysxjmnnd9jx2mq8q8a04uqsp5zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zyg3zygs9q2gqqqqqqsgq7hf8he7ecf7n4ffphs6awl9t6676rrclv9ckg3d3ncn7fct63p6s365duk5wrk202cfy3aj5xnnp5gs3vrdvruverwwq7yzhkf5a3xqpd05wjc" + val pr = PaymentRequest.read(ref) + assertEquals(pr.prefix, "lnbc") + assertEquals(pr.amount, MilliSatoshi(1000000000)) + assertEquals(pr.paymentHash, ByteVector32("0001020304050607080900010203040506070809000102030405060708090102")) + assertEquals(pr.features, Features(Feature.VariableLengthOnion to FeatureSupport.Mandatory, Feature.PaymentSecret to FeatureSupport.Mandatory, Feature.PaymentMetadata to FeatureSupport.Mandatory).toByteArray().toByteVector()) + assertEquals(pr.timestampSeconds, 1496314658L) + assertEquals(pr.nodeId, PublicKey.fromHex("03e7156ae33b0a208d0744199163177e909e80176e55d97a2f221ede0f934dd9ad")) + assertEquals(pr.paymentSecret, ByteVector32("1111111111111111111111111111111111111111111111111111111111111111")) + assertEquals(pr.description, "payment metadata inside") + assertEquals(pr.paymentMetadata, ByteVector("01fafaf0")) + assertEquals(pr.tags.size, 5) + val check = pr.sign(priv).write() + assertEquals(ref, check) + } + @Test fun `reject invalid invoices`() { val refs = listOf( @@ -408,6 +427,16 @@ class PaymentRequestTestsCommon : LightningTestSuite() { assertEquals(21, unknownTag!!.tag) } + @Test + fun `filter non-invoice features`() { + val nodeFeatures = Features( + mapOf(Feature.VariableLengthOnion to FeatureSupport.Mandatory, Feature.PaymentSecret to FeatureSupport.Mandatory, Feature.ShutdownAnySegwit to FeatureSupport.Optional), + setOf(UnknownFeature(103), UnknownFeature(256)) + ) + val pr = PaymentRequest.create(Block.LivenetGenesisBlock.hash, 500.msat, randomBytes32(), randomKey(), "non-invoice features", CltvExpiryDelta(6), nodeFeatures) + assertEquals(Features(Feature.VariableLengthOnion to FeatureSupport.Mandatory, Feature.PaymentSecret to FeatureSupport.Mandatory), Features(pr.features)) + } + @Test fun `feature bits to minimally-encoded feature bytes`() { val testCases = listOf( @@ -436,8 +465,7 @@ class PaymentRequestTestsCommon : LightningTestSuite() { @Test fun `payment secret`() { - val features = - Features(Feature.VariableLengthOnion to FeatureSupport.Mandatory, Feature.PaymentSecret to FeatureSupport.Mandatory, Feature.BasicMultiPartPayment to FeatureSupport.Optional) + val features = Features(Feature.VariableLengthOnion to FeatureSupport.Mandatory, Feature.PaymentSecret to FeatureSupport.Mandatory, Feature.BasicMultiPartPayment to FeatureSupport.Optional) val pr = PaymentRequest.create(Block.LivenetGenesisBlock.hash, 123.msat, ByteVector32.One, priv, "Some invoice", CltvExpiryDelta(18), features) assertNotNull(pr.paymentSecret) assertEquals(ByteVector("024100"), pr.features) @@ -465,28 +493,4 @@ class PaymentRequestTestsCommon : LightningTestSuite() { } } - @Test - fun filterFeatures() { - assertEquals( - expected = PaymentRequest.invoiceFeatures( - Features( - activated = mapOf( - Feature.InitialRoutingSync to FeatureSupport.Optional, - Feature.StaticRemoteKey to FeatureSupport.Mandatory, - Feature.PaymentSecret to FeatureSupport.Mandatory, - Feature.TrampolinePayment to FeatureSupport.Optional, - ), - unknown = setOf( - UnknownFeature(47) - ) - ) - ), - actual = Features( - activated = mapOf( - Feature.PaymentSecret to FeatureSupport.Mandatory, - Feature.TrampolinePayment to FeatureSupport.Optional - ) - ) - ) - } } \ No newline at end of file diff --git a/src/commonTest/kotlin/fr/acinq/lightning/wire/OnionTestsCommon.kt b/src/commonTest/kotlin/fr/acinq/lightning/wire/PaymentOnionTestsCommon.kt similarity index 76% rename from src/commonTest/kotlin/fr/acinq/lightning/wire/OnionTestsCommon.kt rename to src/commonTest/kotlin/fr/acinq/lightning/wire/PaymentOnionTestsCommon.kt index 5e95bf964..d73185a1e 100644 --- a/src/commonTest/kotlin/fr/acinq/lightning/wire/OnionTestsCommon.kt +++ b/src/commonTest/kotlin/fr/acinq/lightning/wire/PaymentOnionTestsCommon.kt @@ -17,7 +17,7 @@ import kotlin.test.assertEquals import kotlin.test.assertFails import kotlin.test.assertNull -class OnionTestsCommon : LightningTestSuite() { +class PaymentOnionTestsCommon : LightningTestSuite() { @Test fun `encode - decode onion packet`() { val bin = Hex.decode( @@ -40,14 +40,14 @@ class OnionTestsCommon : LightningTestSuite() { @Test fun `encode - decode channel relay per-hop payload`() { val testCases = mapOf( - ChannelRelayPayload.create(ShortChannelId(0), MilliSatoshi(0), CltvExpiry(0)) to Hex.decode("0e 0200 0400 06080000000000000000"), - ChannelRelayPayload.create(ShortChannelId(42), MilliSatoshi(142000), CltvExpiry(500000)) to Hex.decode("14 0203022ab0 040307a120 0608000000000000002a"), - ChannelRelayPayload.create(ShortChannelId(561), MilliSatoshi(1105), CltvExpiry(1729)) to Hex.decode("12 02020451 040206c1 06080000000000000231") + PaymentOnion.ChannelRelayPayload.create(ShortChannelId(0), MilliSatoshi(0), CltvExpiry(0)) to Hex.decode("0e 0200 0400 06080000000000000000"), + PaymentOnion.ChannelRelayPayload.create(ShortChannelId(42), MilliSatoshi(142000), CltvExpiry(500000)) to Hex.decode("14 0203022ab0 040307a120 0608000000000000002a"), + PaymentOnion.ChannelRelayPayload.create(ShortChannelId(561), MilliSatoshi(1105), CltvExpiry(1729)) to Hex.decode("12 02020451 040206c1 06080000000000000231") ) testCases.forEach { val expected = it.key - val decoded = ChannelRelayPayload.read(it.value) + val decoded = PaymentOnion.ChannelRelayPayload.read(it.value) assertEquals(expected, decoded) val encoded = decoded.write() assertArrayEquals(it.value, encoded) @@ -57,10 +57,10 @@ class OnionTestsCommon : LightningTestSuite() { @Test fun `encode - decode variable-length (tlv) node relay per-hop payload`() { val nodeId = PublicKey(Hex.decode("02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619")) - val expected = NodeRelayPayload(TlvStream(listOf(OnionTlv.AmountToForward(561.msat), OnionTlv.OutgoingCltv(CltvExpiry(42)), OnionTlv.OutgoingNodeId(nodeId)))) + val expected = PaymentOnion.NodeRelayPayload(TlvStream(listOf(OnionPaymentPayloadTlv.AmountToForward(561.msat), OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), OnionPaymentPayloadTlv.OutgoingNodeId(nodeId)))) val bin = Hex.decode("2e 02020231 04012a fe000102322102eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619") - val decoded = NodeRelayPayload.read(bin) + val decoded = PaymentOnion.NodeRelayPayload.read(bin) assertEquals(expected, decoded) assertEquals(decoded.amountToForward, 561.msat) assertEquals(decoded.totalAmount, 561.msat) @@ -85,22 +85,22 @@ class OnionTestsCommon : LightningTestSuite() { listOf(PaymentRequest.TaggedField.ExtraHop(node1, ShortChannelId(1), 10.msat, 100, CltvExpiryDelta(144))), listOf(PaymentRequest.TaggedField.ExtraHop(node2, ShortChannelId(2), 20.msat, 150, CltvExpiryDelta(12)), PaymentRequest.TaggedField.ExtraHop(node3, ShortChannelId(3), 30.msat, 200, CltvExpiryDelta(24))) ) - val expected = NodeRelayPayload( + val expected = PaymentOnion.NodeRelayPayload( TlvStream( listOf( - OnionTlv.AmountToForward(561.msat), - OnionTlv.OutgoingCltv(CltvExpiry(42)), - OnionTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 1105.msat), - OnionTlv.InvoiceFeatures(features), - OnionTlv.OutgoingNodeId(nodeId), - OnionTlv.InvoiceRoutingInfo(routingHints) + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 1105.msat), + OnionPaymentPayloadTlv.InvoiceFeatures(features), + OnionPaymentPayloadTlv.OutgoingNodeId(nodeId), + OnionPaymentPayloadTlv.InvoiceRoutingInfo(routingHints) ) ) ) val bin = Hex.decode("fa 02020231 04012a 0822eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f2836866190451 fe00010231010a fe000102322102eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619 fe000102339b01036d6caac248af96f6afa7f904f550253a0f3ef3f5aa2fe6838a95b216691468e200000000000000010000000a00000064009002025f7117a78150fe2ef97db7cfc83bd57b2e2c0d0dd25eaf467a4a1c2a45ce148600000000000000020000001400000096000c02a051267759c3a149e3e72372f4e0c4054ba597ebfd0eda78a2273023667205ee00000000000000030000001e000000c80018") - val decoded = NodeRelayPayload.read(bin) + val decoded = PaymentOnion.NodeRelayPayload.read(bin) assertEquals(decoded, expected) assertEquals(decoded.amountToForward, 561.msat) assertEquals(decoded.totalAmount, 1105.msat) @@ -119,54 +119,61 @@ class OnionTestsCommon : LightningTestSuite() { val testCases = mapOf( TlvStream( listOf( - OnionTlv.AmountToForward(561.msat), - OnionTlv.OutgoingCltv(CltvExpiry(42)), - OnionTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 0.msat) + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 0.msat) ) ) to Hex.decode("29 02020231 04012a 0820eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), TlvStream( listOf( - OnionTlv.AmountToForward(561.msat), - OnionTlv.OutgoingCltv(CltvExpiry(42)), - OnionTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 1105.msat) + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 1105.msat) ) ) to Hex.decode("2b 02020231 04012a 0822eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f2836866190451"), TlvStream( listOf( - OnionTlv.AmountToForward(561.msat), - OnionTlv.OutgoingCltv(CltvExpiry(42)), - OnionTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 4294967295L.msat) + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 4294967295L.msat) ) ) to Hex.decode("2d 02020231 04012a 0824eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619ffffffff"), TlvStream( listOf( - OnionTlv.AmountToForward(561.msat), - OnionTlv.OutgoingCltv(CltvExpiry(42)), - OnionTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 4294967296L.msat) + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 4294967296L.msat) ) ) to Hex.decode("2e 02020231 04012a 0825eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f2836866190100000000"), TlvStream( listOf( - OnionTlv.AmountToForward(561.msat), - OnionTlv.OutgoingCltv(CltvExpiry(42)), - OnionTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 1099511627775L.msat) + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 1099511627775L.msat) ) ) to Hex.decode("2e 02020231 04012a 0825eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619ffffffffff"), TlvStream( listOf( - OnionTlv.AmountToForward(561.msat), - OnionTlv.OutgoingCltv(CltvExpiry(42)), - OnionTlv.OutgoingChannelId(ShortChannelId(1105)), - OnionTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 0.msat) + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.OutgoingChannelId(ShortChannelId(1105)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 0.msat) ) ) to Hex.decode("33 02020231 04012a 06080000000000000451 0820eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), TlvStream( - listOf(OnionTlv.AmountToForward(561.msat), OnionTlv.OutgoingCltv(CltvExpiry(42)), OnionTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 0.msat)), + listOf( + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), 0.msat) + ), listOf(GenericTlv(65535, ByteVector("06c1"))) ) to Hex.decode("2f 02020231 04012a 0820eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619 fdffff0206c1"), TlvStream( listOf( - OnionTlv.AmountToForward(561.msat), OnionTlv.OutgoingCltv(CltvExpiry(42)), OnionTlv.PaymentData(ByteVector32("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), 0.msat), OnionTlv.TrampolineOnion( + OnionPaymentPayloadTlv.AmountToForward(561.msat), + OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), + OnionPaymentPayloadTlv.PaymentData(ByteVector32("ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"), 0.msat), + OnionPaymentPayloadTlv.TrampolineOnion( OnionRoutingPacket( 0, ByteVector("02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619"), @@ -182,12 +189,12 @@ class OnionTestsCommon : LightningTestSuite() { testCases.forEach { val expected = it.key - val decoded = FinalPayload.read(it.value) - assertEquals(decoded, FinalPayload(expected)) + val decoded = PaymentOnion.FinalPayload.read(it.value) + assertEquals(decoded, PaymentOnion.FinalPayload(expected)) assertEquals(decoded.amount, 561.msat) assertEquals(decoded.expiry, CltvExpiry(42)) - val encoded = FinalPayload(expected).write() + val encoded = PaymentOnion.FinalPayload(expected).write() assertArrayEquals(it.value, encoded) } } @@ -195,29 +202,29 @@ class OnionTestsCommon : LightningTestSuite() { @Test fun `encode - decode variable-length (tlv) final per-hop payload with custom user records`() { val tlvs = TlvStream( - listOf(OnionTlv.AmountToForward(561.msat), OnionTlv.OutgoingCltv(CltvExpiry(42)), OnionTlv.PaymentData(ByteVector32.Zeroes, 0.msat)), + listOf(OnionPaymentPayloadTlv.AmountToForward(561.msat), OnionPaymentPayloadTlv.OutgoingCltv(CltvExpiry(42)), OnionPaymentPayloadTlv.PaymentData(ByteVector32.Zeroes, 0.msat)), listOf(GenericTlv(5432123457L, ByteVector("16c7ec71663784ff100b6eface1e60a97b92ea9d18b8ece5e558586bc7453828"))) ) val bin = Hex.decode("53 02020231 04012a 08200000000000000000000000000000000000000000000000000000000000000000 ff0000000143c7a0412016c7ec71663784ff100b6eface1e60a97b92ea9d18b8ece5e558586bc7453828") - val encoded = FinalPayload(tlvs).write() + val encoded = PaymentOnion.FinalPayload(tlvs).write() assertArrayEquals(bin, encoded) - assertEquals(FinalPayload(tlvs), FinalPayload.read(bin)) + assertEquals(PaymentOnion.FinalPayload(tlvs), PaymentOnion.FinalPayload.read(bin)) } @Test fun `decode multi-part final per-hop payload`() { - val notMultiPart = FinalPayload.read(Hex.decode("29 02020231 04012a 08200000000000000000000000000000000000000000000000000000000000000000")) + val notMultiPart = PaymentOnion.FinalPayload.read(Hex.decode("29 02020231 04012a 08200000000000000000000000000000000000000000000000000000000000000000")) assertEquals(notMultiPart.totalAmount, 561.msat) assertEquals(notMultiPart.paymentSecret, ByteVector32.Zeroes) - val multiPart = FinalPayload.read(Hex.decode("2b 02020231 04012a 0822eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f2836866190451")) + val multiPart = PaymentOnion.FinalPayload.read(Hex.decode("2b 02020231 04012a 0822eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f2836866190451")) assertEquals(multiPart.amount, 561.msat) assertEquals(multiPart.expiry, CltvExpiry(42)) assertEquals(multiPart.totalAmount, 1105.msat) assertEquals(multiPart.paymentSecret, ByteVector32("eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619")) - val multiPartNoTotalAmount = FinalPayload.read(Hex.decode("29 02020231 04012a 0820eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619")) + val multiPartNoTotalAmount = PaymentOnion.FinalPayload.read(Hex.decode("29 02020231 04012a 0820eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619")) assertEquals(multiPartNoTotalAmount.amount, 561.msat) assertEquals(multiPartNoTotalAmount.expiry, CltvExpiry(42)) assertEquals(multiPartNoTotalAmount.totalAmount, 561.msat) @@ -233,7 +240,7 @@ class OnionTestsCommon : LightningTestSuite() { ) testCases.forEach { - assertFails { FinalPayload.read(it) } + assertFails { PaymentOnion.FinalPayload.read(it) } } } @@ -254,9 +261,9 @@ class OnionTestsCommon : LightningTestSuite() { ) testCases.forEach { - assertFails { ChannelRelayPayload.read(it) } - assertFails { NodeRelayPayload.read(it) } - assertFails { FinalPayload.read(it) } + assertFails { PaymentOnion.ChannelRelayPayload.read(it) } + assertFails { PaymentOnion.NodeRelayPayload.read(it) } + assertFails { PaymentOnion.FinalPayload.read(it) } } } } \ No newline at end of file diff --git a/src/jvmTest/kotlin/fr/acinq/lightning/Node.kt b/src/jvmTest/kotlin/fr/acinq/lightning/Node.kt index ba9e255c8..7ed7dbf98 100644 --- a/src/jvmTest/kotlin/fr/acinq/lightning/Node.kt +++ b/src/jvmTest/kotlin/fr/acinq/lightning/Node.kt @@ -148,6 +148,7 @@ object Node { Feature.StaticRemoteKey to FeatureSupport.Optional, Feature.AnchorOutputs to FeatureSupport.Optional, Feature.ChannelType to FeatureSupport.Mandatory, + Feature.PaymentMetadata to FeatureSupport.Optional, Feature.TrampolinePayment to FeatureSupport.Optional, Feature.ZeroReserveChannels to FeatureSupport.Optional, Feature.ZeroConfChannels to FeatureSupport.Optional,