diff --git a/eclair-core/src/main/resources/reference.conf b/eclair-core/src/main/resources/reference.conf index f63859ff35..97ea1eb1a4 100644 --- a/eclair-core/src/main/resources/reference.conf +++ b/eclair-core/src/main/resources/reference.conf @@ -54,6 +54,7 @@ eclair { option_anchor_outputs = disabled option_anchors_zero_fee_htlc_tx = disabled option_shutdown_anysegwit = optional + option_onion_messages = optional trampoline_payment = disabled keysend = disabled } @@ -333,6 +334,10 @@ eclair { "mempool.space" ] } + + onion-messages { + rate-limit-per-second = 10 + } } akka { diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala b/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala index 5dba8a4a7e..4c392d5642 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/Features.scala @@ -203,6 +203,11 @@ object Features { val mandatory = 26 } + case object OnionMessages extends Feature { + val rfcName = "option_onion_messages" + val mandatory = 38 + } + // TODO: @t-bast: update feature bits once spec-ed (currently reserved here: https://github.com/lightningnetwork/lightning-rfc/issues/605) // We're not advertising these bits yet in our announcements, clients have to assume support. // This is why we haven't added them yet to `areSupported`. @@ -231,6 +236,7 @@ object Features { AnchorOutputs, AnchorOutputsZeroFeeHtlcTx, ShutdownAnySegwit, + OnionMessages, KeySend ) diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala index 37fff15cc1..233b47a65b 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala @@ -94,7 +94,8 @@ case class NodeParams(nodeKeyManager: NodeKeyManager, maxPaymentAttempts: Int, enableTrampolinePayment: Boolean, balanceCheckInterval: FiniteDuration, - blockchainWatchdogSources: Seq[String]) { + blockchainWatchdogSources: Seq[String], + onionMessageRateLimitPerSecond: Double) { val privateKey: Crypto.PrivateKey = nodeKeyManager.nodeKey.privateKey val nodeId: PublicKey = nodeKeyManager.nodeId @@ -457,7 +458,8 @@ object NodeParams extends Logging { maxPaymentAttempts = config.getInt("max-payment-attempts"), enableTrampolinePayment = config.getBoolean("trampoline-payments-enable"), balanceCheckInterval = FiniteDuration(config.getDuration("balance-check-interval").getSeconds, TimeUnit.SECONDS), - blockchainWatchdogSources = config.getStringList("blockchain-watchdog.sources").asScala.toSeq + blockchainWatchdogSources = config.getStringList("blockchain-watchdog.sources").asScala.toSeq, + onionMessageRateLimitPerSecond = config.getDouble("onion-messages.rate-limit-per-second") ) } } diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala index 73e8c88a81..79ce4d0564 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala @@ -24,6 +24,7 @@ import fr.acinq.eclair.wire.protocol._ import grizzled.slf4j.Logging import scodec.Attempt import scodec.bits.ByteVector +import scodec.codecs.provide import scala.annotation.tailrec import scala.util.{Failure, Success, Try} @@ -272,7 +273,7 @@ object Sphinx extends Logging { * When an invalid onion is received, its hash should be included in the failure message. */ def hash(onion: protocol.OnionRoutingPacket): ByteVector32 = - Crypto.sha256(OnionCodecs.onionRoutingPacketCodec(onion.payload.length.toInt).encode(onion).require.toByteVector) + Crypto.sha256(OnionCodecs.onionRoutingPacketCodec(provide(onion.payload.length.toInt)).encode(onion).require.toByteVector) } @@ -291,6 +292,12 @@ object Sphinx extends Logging { override val PayloadLength = 400 } + /** + * A message onion packet is used when requesting/sending an invoice from/to a remote node when using offers (BOLT12). + * @param PayloadLength SHOULD be 1300 or 32768 + */ + case class MessagePacket(PayloadLength: Int) extends OnionRoutingPacket[Onion.MessagePacket] + /** * A properly decrypted failure from a node in the route. * @@ -378,15 +385,6 @@ object Sphinx extends Logging { */ object RouteBlinding { - /** - * @param publicKey introduction node's public key (which cannot be blinded since the sender need to find a route to it). - * @param blindingEphemeralKey blinding tweak that can be used by the introduction node to derive the private key that - * lets it decrypt the encrypted payload. - * @param encryptedPayload encrypted payload that can be decrypted with the introduction node's private key and the - * blinding ephemeral key. - */ - case class IntroductionNode(publicKey: PublicKey, blindingEphemeralKey: PublicKey, encryptedPayload: ByteVector) - /** * @param blindedPublicKey blinded public key, which hides the real public key. * @param blindingEphemeralKey blinding tweak that can be used by the receiving node to derive the private key that @@ -397,13 +395,13 @@ object Sphinx extends Logging { case class BlindedNode(blindedPublicKey: PublicKey, blindingEphemeralKey: PublicKey, encryptedPayload: ByteVector) /** - * @param introductionNode the first node should not be blinded, otherwise the sender cannot locate it. - * @param blindedNodes blinded nodes (not including the introduction node). + * @param introductionNodeId the first node, not be blinded so that the sender can locate it. + * @param blindedNodes blinded nodes (including the introduction node). */ - case class BlindedRoute(introductionNode: IntroductionNode, blindedNodes: Seq[BlindedNode]) { - val nodeIds: Seq[PublicKey] = introductionNode.publicKey +: blindedNodes.map(_.blindedPublicKey) - val blindingEphemeralKeys: Seq[PublicKey] = introductionNode.blindingEphemeralKey +: blindedNodes.map(_.blindingEphemeralKey) - val encryptedPayloads: Seq[ByteVector] = introductionNode.encryptedPayload +: blindedNodes.map(_.encryptedPayload) + case class BlindedRoute(introductionNodeId: PublicKey, blindedNodes: Seq[BlindedNode]) { + val nodeIds: Seq[PublicKey] = introductionNodeId +: blindedNodes.tail.map(_.blindedPublicKey) + val blindingEphemeralKeys: Seq[PublicKey] = blindedNodes.map(_.blindingEphemeralKey) + val encryptedPayloads: Seq[ByteVector] = blindedNodes.map(_.encryptedPayload) } /** @@ -426,8 +424,7 @@ object Sphinx extends Logging { e = e.multiply(PrivateKey(Crypto.sha256(blindingKey.value ++ sharedSecret.bytes))) BlindedNode(blindedPublicKey, blindingKey, encryptedPayload ++ mac) } - val introductionNode = IntroductionNode(publicKeys.head, blindedHops.head.blindingEphemeralKey, blindedHops.head.encryptedPayload) - BlindedRoute(introductionNode, blindedHops.tail) + BlindedRoute(publicKeys.head, blindedHops) } /** diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala index a8e0866d16..cf549a9ebc 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala @@ -21,6 +21,7 @@ import akka.event.Logging.MDC import akka.event.{BusLogging, DiagnosticLoggingAdapter} import akka.util.Timeout import com.google.common.net.HostAndPort +import com.google.common.util.concurrent.RateLimiter import fr.acinq.bitcoin.Crypto.PublicKey import fr.acinq.bitcoin.{ByteVector32, Satoshi, SatoshiLong, Script} import fr.acinq.eclair.Features.Wumbo @@ -34,7 +35,7 @@ import fr.acinq.eclair.io.Monitoring.Metrics import fr.acinq.eclair.io.PeerConnection.KillReason import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes import fr.acinq.eclair.wire.protocol -import fr.acinq.eclair.wire.protocol.{Error, HasChannelId, HasTemporaryChannelId, LightningMessage, NodeAddress, RoutingMessage, UnknownMessage, Warning} +import fr.acinq.eclair.wire.protocol.{Error, HasChannelId, HasTemporaryChannelId, LightningMessage, NodeAddress, OnionMessage, OnionMessages, RoutingMessage, UnknownMessage, Warning} import scodec.bits.ByteVector import java.net.InetSocketAddress @@ -54,6 +55,8 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnChainA import Peer._ + private val messageRelayRateLimiter = RateLimiter.create(nodeParams.onionMessageRateLimitPerSecond) + startWith(INSTANTIATING, Nothing) when(INSTANTIATING) { @@ -241,6 +244,16 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnChainA d.channels.values.toSet[ActorRef].foreach(_ ! INPUT_DISCONNECTED) // we deduplicate with toSet because there might be two entries per channel (tmp id and final id) gotoConnected(connectionReady, d.channels) + case Event(msg: OnionMessage, d: ConnectedData) => + if (nodeParams.features.hasFeature(Features.OnionMessages) && messageRelayRateLimiter.tryAcquire()) { + relayOnionMessage(msg) + } + stay() + + case Event(Peer.SendOnionMessage(toNodeId, msg), d: ConnectedData) if toNodeId == remoteNodeId => + d.peerConnection ! msg + stay() + case Event(unknownMsg: UnknownMessage, d: ConnectedData) if nodeParams.pluginMessageTags.contains(unknownMsg.tag) => context.system.eventStream.publish(UnknownMessageReceived(self, remoteNodeId, unknownMsg, d.connectionInfo)) stay() @@ -251,6 +264,14 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnChainA } } + private def relayOnionMessage(msg: OnionMessage): Unit = { + OnionMessages.process(nodeParams.privateKey, msg) match { + case OnionMessages.DropMessage(_) => () // We ignore bad messages + case OnionMessages.RelayMessage(nextNodeId, dataToRelay) => context.parent ! Peer.SendOnionMessage(nextNodeId, dataToRelay) + case OnionMessages.ReceiveMessage(_) => () // We only relay messages + } + } + whenUnhandled { case Event(_: Peer.OpenChannel, _) => sender() ! Status.Failure(new RuntimeException("not connected")) @@ -425,6 +446,8 @@ object Peer { def apply(uri: NodeURI): Connect = new Connect(uri.nodeId, Some(uri.address)) } + case class SendOnionMessage(nodeId: PublicKey, message: OnionMessage) extends PossiblyHarmful + case class Disconnect(nodeId: PublicKey) extends PossiblyHarmful case class OpenChannel(remoteNodeId: PublicKey, fundingSatoshis: Satoshi, pushMsat: MilliSatoshi, channelType_opt: Option[SupportedChannelType], fundingTxFeeratePerKw_opt: Option[FeeratePerKw], channelFlags: Option[Byte], timeout_opt: Option[Timeout]) extends PossiblyHarmful { require(pushMsat <= fundingSatoshis, s"pushMsat must be less or equal to fundingSatoshis") diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala b/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala index d1f13b7f25..e06b537aed 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/io/Switchboard.scala @@ -71,6 +71,10 @@ class Switchboard(nodeParams: NodeParams, peerFactory: Switchboard.PeerFactory) case None => sender() ! Status.Failure(new RuntimeException("peer not found")) } + case s@Peer.SendOnionMessage(nodeId, _) => + val peer = createOrGetPeer(nodeId, offlineChannels = Set.empty) + peer forward s + case o: Peer.OpenChannel => getPeer(o.remoteNodeId) match { case Some(peer) => peer forward o diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecs.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecs.scala index 36b9287b99..beeb6d6868 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecs.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageCodecs.scala @@ -18,6 +18,7 @@ package fr.acinq.eclair.wire.protocol import fr.acinq.eclair.wire.Monitoring.{Metrics, Tags} import fr.acinq.eclair.wire.protocol.CommonCodecs._ +import fr.acinq.eclair.wire.protocol.OnionCodecs.onionRoutingPacketCodec import fr.acinq.eclair.{Features, KamonExt} import scodec.bits.{BitVector, ByteVector} import scodec.codecs._ @@ -308,6 +309,11 @@ object LightningMessageCodecs { ("timestampRange" | uint32) :: ("tlvStream" | GossipTimestampFilterTlv.gossipTimestampFilterTlvCodec)).as[GossipTimestampFilter] + val onionMessageCodec: Codec[OnionMessage] = ( + ("blindingKey" | publicKey) :: + ("onionPacket" | OnionCodecs.messageOnionPacketCodec) :: + ("tlvStream" | OnionMessageTlv.onionMessageTlvCodec)).as[OnionMessage] + // NB: blank lines to minimize merge conflicts // @@ -361,6 +367,7 @@ object LightningMessageCodecs { .typecase(263, queryChannelRangeCodec) .typecase(264, replyChannelRangeCodec) .typecase(265, gossipTimestampFilterCodec) + .typecase(387, onionMessageCodec) // NB: blank lines to minimize merge conflicts // diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala index c7773fc716..3974851076 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/LightningMessageTypes.scala @@ -320,6 +320,8 @@ object ReplyChannelRange { case class GossipTimestampFilter(chainHash: ByteVector32, firstTimestamp: TimestampSecond, timestampRange: Long, tlvStream: TlvStream[GossipTimestampFilterTlv] = TlvStream.empty) extends RoutingMessage with HasChainHash +case class OnionMessage(blindingKey: PublicKey, onionRoutingPacket: OnionRoutingPacket, tlvStream: TlvStream[OnionMessageTlv] = TlvStream.empty) extends LightningMessage + // NB: blank lines to minimize merge conflicts // diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/Onion.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/Onion.scala index 7adf76a155..9c4b7cdfb9 100644 --- a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/Onion.scala +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/Onion.scala @@ -175,6 +175,15 @@ object OnionTlv { /** Pre-image included by the sender of a payment in case of a donation */ case class KeySend(paymentPreimage: ByteVector32) extends OnionTlv + case class ReplyHop(nodeId: PublicKey, encTlv: ByteVector) + + case class ReplyPath(firstNodeId: PublicKey, blinding: PublicKey, path: Seq[ReplyHop]) extends OnionTlv + + case class EncTlv(bytes: ByteVector) extends OnionTlv + + case class Padding(bytes: ByteVector) extends OnionTlv + + case class PathId(bytes: ByteVector) extends OnionTlv } object Onion { @@ -218,6 +227,9 @@ object Onion { /** See [[fr.acinq.eclair.crypto.Sphinx.TrampolinePacket]]. */ sealed trait TrampolinePacket extends PacketType + /** See [[fr.acinq.eclair.crypto.Sphinx.MessagePacket]]. */ + sealed trait MessagePacket extends PacketType + /** Per-hop payload from an HTLC's payment onion (after decryption and decoding). */ sealed trait PerHopPayload @@ -281,6 +293,24 @@ object Onion { override val paymentPreimage = records.get[KeySend].map(_.paymentPreimage) } + case class MessageRelayPayload(records: TlvStream[OnionTlv]) extends MessagePacket with TlvFormat { + val blindedTlv: ByteVector = records.get[EncTlv].get.bytes + } + + case class MessageFinalPayload(records: TlvStream[OnionTlv]) extends MessagePacket with TlvFormat { + val blindedTlv: Option[ByteVector] = records.get[EncTlv].map(_.bytes) + val replyPath: Option[ReplyPath] = records.get[ReplyPath] + } + + case class RelayBlindedTlv(records: TlvStream[OnionTlv]) { + val nextNodeId: PublicKey = records.get[OutgoingNodeId].get.nodeId + val nextBlinding: Option[PublicKey] = records.get[BlindingPoint].map(_.publicKey) + } + + case class FinalBlindedTlv(records: TlvStream[OnionTlv]) { + val pathId: Option[ByteVector] = records.get[PathId].map(_.bytes) + } + def createNodeRelayPayload(amount: MilliSatoshi, expiry: CltvExpiry, nextNodeId: PublicKey): NodeRelayPayload = NodeRelayPayload(TlvStream(AmountToForward(amount), OutgoingCltv(expiry), OutgoingNodeId(nextNodeId))) @@ -310,15 +340,18 @@ object OnionCodecs { import scodec.codecs._ import scodec.{Attempt, Codec, DecodeResult, Decoder, Err} - def onionRoutingPacketCodec(payloadLength: Int): Codec[OnionRoutingPacket] = ( - ("version" | uint8) :: - ("publicKey" | bytes(33)) :: - ("onionPayload" | bytes(payloadLength)) :: - ("hmac" | bytes32)).as[OnionRoutingPacket] + def onionRoutingPacketCodec(payloadLength: Codec[Int]): Codec[OnionRoutingPacket] = ( + variableSizePrefixedBytes(payloadLength, + ("version" | uint8) ~ + ("publicKey" | bytes(33)), + ("onionPayload" | bytes)) ~ + ("hmac" | bytes32) flattenLeftPairs).as[OnionRoutingPacket] + + val paymentOnionPacketCodec: Codec[OnionRoutingPacket] = onionRoutingPacketCodec(provide(Sphinx.PaymentPacket.PayloadLength)) - val paymentOnionPacketCodec: Codec[OnionRoutingPacket] = onionRoutingPacketCodec(Sphinx.PaymentPacket.PayloadLength) + val trampolineOnionPacketCodec: Codec[OnionRoutingPacket] = onionRoutingPacketCodec(provide(Sphinx.TrampolinePacket.PayloadLength)) - val trampolineOnionPacketCodec: Codec[OnionRoutingPacket] = onionRoutingPacketCodec(Sphinx.TrampolinePacket.PayloadLength) + val messageOnionPacketCodec: Codec[OnionRoutingPacket] = onionRoutingPacketCodec(uint16.xmap(_ - 66, _ + 66)) /** * The 1.1 BOLT spec changed the onion frame format to use variable-length per-hop payloads. @@ -412,9 +445,73 @@ object OnionCodecs { case FinalTlvPayload(tlvs) => tlvs }) + private val replyHopCodec: Codec[ReplyHop] = (("nodeId" | publicKey) :: ("encTlv" | variableSizeBytes(uint16, bytes))).as[ReplyHop] + + private val replyPathCodec: Codec[ReplyPath] = (("firstNodeId" | publicKey) :: ("blinding" | publicKey) :: ("path" | list(replyHopCodec).xmap[Seq[ReplyHop]](_.toSeq, _.toList))).as[ReplyPath] + + private val encTlvCodec: Codec[EncTlv] = bytes.as[EncTlv] + + private val messageTlvCodec = discriminated[OnionTlv].by(varint) + .typecase(UInt64(2), replyPathCodec) + .typecase(UInt64(10), encTlvCodec) + + val messagePerHopPayloadCodec: Codec[TlvStream[OnionTlv]] = TlvCodecs.lengthPrefixedTlvStream[OnionTlv](messageTlvCodec).complete + + val messageRelayPayloadCodec: Codec[MessageRelayPayload] = messagePerHopPayloadCodec.narrow({ + case tlvs if tlvs.get[EncTlv].isEmpty => Attempt.failure(MissingRequiredTlv(UInt64(10))) + case tlvs => Attempt.successful(MessageRelayPayload(tlvs)) + }, { + case MessageRelayPayload(tlvs) => tlvs + }) + + val messageFinalPayloadCodec: Codec[MessageFinalPayload] = messagePerHopPayloadCodec.narrow({ + case tlvs => Attempt.successful(MessageFinalPayload(tlvs)) + }, { + case MessageFinalPayload(tlvs) => tlvs + }) + def perHopPayloadCodecByPacketType[T <: PacketType](packetType: Sphinx.OnionRoutingPacket[T], isLastPacket: Boolean): Codec[PacketType] = packetType match { case Sphinx.PaymentPacket => if (isLastPacket) finalPerHopPayloadCodec.upcast[PacketType] else channelRelayPerHopPayloadCodec.upcast[PacketType] case Sphinx.TrampolinePacket => if (isLastPacket) finalPerHopPayloadCodec.upcast[PacketType] else nodeRelayPerHopPayloadCodec.upcast[PacketType] + case Sphinx.MessagePacket(payloadLength) => if (isLastPacket) messageFinalPayloadCodec.upcast[PacketType] else messageRelayPayloadCodec.upcast[PacketType] + } + + private val padding: Codec[Padding] = variableSizeBytesLong(varintoverflow, "padding" | bytes).as[Padding] + + private val blindingKey: Codec[BlindingPoint] = variableSizeBytesLong(varintoverflow, "blinding" | publicKey).as[BlindingPoint] + + private val pathId: Codec[PathId] = variableSizeBytesLong(varintoverflow, "path_id" | bytes).as[PathId] + + private val blindedTlvCodec: Codec[TlvStream[OnionTlv]] = TlvCodecs.tlvStream[OnionTlv]( + discriminated[OnionTlv].by(varint) + .typecase(UInt64(1), padding) + .typecase(UInt64(4), outgoingNodeId) + .typecase(UInt64(12), blindingKey) + .typecase(UInt64(14), pathId)).complete + + case class ForbiddenTlv(tag: UInt64) extends Err { + // @formatter:off + val failureMessage: FailureMessage = InvalidOnionPayload(tag, 0) + + override def message = failureMessage.message + + override def context: List[String] = Nil + + override def pushContext(ctx: String): Err = this + // @formatter:on } -} \ No newline at end of file + val relayBlindedTlvCodec: Codec[RelayBlindedTlv] = blindedTlvCodec.narrow({ + case tlvs if tlvs.get[OutgoingNodeId].isEmpty => Attempt.failure(MissingRequiredTlv(UInt64(4))) + case tlvs if tlvs.get[PathId].nonEmpty => Attempt.failure(ForbiddenTlv(UInt64(14))) + case tlvs => Attempt.successful(RelayBlindedTlv(tlvs)) + }, { + case RelayBlindedTlv(tlvs) => tlvs + }) + + val finalBlindedTlvCodec: Codec[FinalBlindedTlv] = blindedTlvCodec.narrow( + tlvs => Attempt.successful(FinalBlindedTlv(tlvs)) + , { + case FinalBlindedTlv(tlvs) => tlvs + }) +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OnionMessageTlv.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OnionMessageTlv.scala new file mode 100644 index 0000000000..4e8ab13185 --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OnionMessageTlv.scala @@ -0,0 +1,32 @@ +/* + * Copyright 2021 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.wire.protocol + +import fr.acinq.eclair.wire.protocol.CommonCodecs.varint +import fr.acinq.eclair.wire.protocol.TlvCodecs.tlvStream +import scodec.Codec +import scodec.codecs.discriminated + +/** + * Created by thomash on 10/09/2021. + */ + +sealed trait OnionMessageTlv extends Tlv + +object OnionMessageTlv { + val onionMessageTlvCodec: Codec[TlvStream[OnionMessageTlv]] = tlvStream(discriminated[OnionMessageTlv].by(varint)) +} diff --git a/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OnionMessages.scala b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OnionMessages.scala new file mode 100644 index 0000000000..c95b17ce4f --- /dev/null +++ b/eclair-core/src/main/scala/fr/acinq/eclair/wire/protocol/OnionMessages.scala @@ -0,0 +1,44 @@ +package fr.acinq.eclair.wire.protocol + +import fr.acinq.bitcoin.ByteVector32 +import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} +import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.wire.protocol.Onion.{MessageFinalPayload, MessageRelayPayload} +import scodec.bits.ByteVector +import scodec.{Attempt, DecodeResult} + +import scala.util.{Failure, Success} + +object OnionMessages { + + sealed trait Action + + case class DropMessage(reason: String) extends Action + + case class RelayMessage(nextNodeId: PublicKey, dataToRelay: OnionMessage) extends Action + + case class ReceiveMessage(finalPayload: MessageFinalPayload) extends Action + + def process(privateKey: PrivateKey, msg: OnionMessage): Action = { + val packetType = Sphinx.MessagePacket(msg.onionRoutingPacket.payload.length.toInt) + val blindedPrivateKey = Sphinx.RouteBlinding.derivePrivateKey(privateKey, msg.blindingKey) + packetType.peel(blindedPrivateKey, ByteVector32.Zeroes, msg.onionRoutingPacket) match { + case Left(_: BadOnion) => DropMessage("Can't peel onion") + case Right(p@Sphinx.DecryptedPacket(payload, nextPacket, _)) => (OnionCodecs.perHopPayloadCodecByPacketType(packetType, p.isLastPacket).decode(payload.bits): @unchecked) match { + case Attempt.Successful(DecodeResult(relayPayload: MessageRelayPayload, _)) => + Sphinx.RouteBlinding.decryptPayload(privateKey, msg.blindingKey, relayPayload.blindedTlv) match { + case Success((decrypted, nextBlindingKey)) => + OnionCodecs.relayBlindedTlvCodec.decode(decrypted.bits) match { + case Attempt.Successful(DecodeResult(relayNext, _)) => + val toRelay = OnionMessage(relayNext.nextBlinding.getOrElse(nextBlindingKey), nextPacket) + RelayMessage(relayNext.nextNodeId, toRelay) + case Attempt.Failure(_) => DropMessage("Can't decode TLV") + } + case Failure(_) => DropMessage("Can't decrypt blinded TLV") + } + case Attempt.Successful(DecodeResult(finalPayload: MessageFinalPayload, _)) => ReceiveMessage(finalPayload) + case Attempt.Failure(_) => DropMessage("Can't decode packet") + } + } + } +} diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala index daf8ab41db..9b0fb22fa9 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/TestConstants.scala @@ -32,6 +32,7 @@ import org.scalatest.Tag import scodec.bits.ByteVector import java.util.UUID +import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong import scala.concurrent.duration._ @@ -197,7 +198,8 @@ object TestConstants { enableTrampolinePayment = true, instanceId = UUID.fromString("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"), balanceCheckInterval = 1 hour, - blockchainWatchdogSources = blockchainWatchdogSources + blockchainWatchdogSources = blockchainWatchdogSources, + onionMessageRateLimitPerSecond = 10 ) def channelParams: LocalParams = Peer.makeChannelParams( @@ -323,7 +325,8 @@ object TestConstants { enableTrampolinePayment = true, instanceId = UUID.fromString("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb"), balanceCheckInterval = 1 hour, - blockchainWatchdogSources = blockchainWatchdogSources + blockchainWatchdogSources = blockchainWatchdogSources, + onionMessageRateLimitPerSecond = 10 ) def channelParams: LocalParams = Peer.makeChannelParams( diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala index 7818d0eb5b..0de6806491 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/crypto/SphinxSpec.scala @@ -367,31 +367,24 @@ class SphinxSpec extends AnyFunSuite { test("create blinded route (reference test vector)") { val sessionKey = PrivateKey(hex"0101010101010101010101010101010101010101010101010101010101010101") val blindedRoute = RouteBlinding.create(sessionKey, publicKeys, routeBlindingPayloads) - assert(blindedRoute.introductionNode.publicKey === publicKeys(0)) - assert(blindedRoute.introductionNode.blindingEphemeralKey === PublicKey(hex"031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f")) - assert(blindedRoute.introductionNode.encryptedPayload === hex"a245b767bd52520bdf8179b2dc681d1a36c2ededaf59429dfc4bea342fa460c9") - assert(blindedRoute.nodeIds === Seq( - publicKeys(0), - PublicKey(hex"022b09d77fb3374ee3ed9d2153e15e9962944ad1690327cbb0a9acb7d90f168763"), - PublicKey(hex"03d9f889364dc5a173460a2a6cc565b4ca78931792115dd6ef82c0e18ced837372"), - PublicKey(hex"03bfddd2253b42fe12edd37f9071a3883830ed61a4bc347eeac63421629cf032b5"), - PublicKey(hex"03a8588bc4a0a2f0d2fb8d5c0f8d062fb4d78bfba24a85d0ddeb4fd35dd3b34110"), - )) + assert(blindedRoute.introductionNodeId === publicKeys(0)) + assert(blindedRoute.nodeIds(0) === publicKeys(0)) assert(blindedRoute.blindedNodes.map(_.blindedPublicKey) === Seq( + PublicKey(hex"02ec68ed555f5d18b12fe0e2208563c3566032967cf11dc29b20c345449f9a50a2"), PublicKey(hex"022b09d77fb3374ee3ed9d2153e15e9962944ad1690327cbb0a9acb7d90f168763"), PublicKey(hex"03d9f889364dc5a173460a2a6cc565b4ca78931792115dd6ef82c0e18ced837372"), PublicKey(hex"03bfddd2253b42fe12edd37f9071a3883830ed61a4bc347eeac63421629cf032b5"), PublicKey(hex"03a8588bc4a0a2f0d2fb8d5c0f8d062fb4d78bfba24a85d0ddeb4fd35dd3b34110"), )) - assert(blindedRoute.blindingEphemeralKeys === blindedRoute.introductionNode.blindingEphemeralKey +: blindedRoute.blindedNodes.map(_.blindingEphemeralKey)) assert(blindedRoute.blindedNodes.map(_.blindingEphemeralKey) === Seq( + PublicKey(hex"031b84c5567b126440995d3ed5aaba0565d71e1834604819ff9c17f5e9d5dd078f"), PublicKey(hex"035cb4c003d58e16cc9207270b3596c2be3309eca64c36b208c946bbb599bfcad0"), PublicKey(hex"02e105bc01a7af07074a1b0b1d9a112a1d89c6cd87cc4e2b6ba3a824731d9508bd"), PublicKey(hex"0349164db5398925ef234002e62d2834da115b8eafc73436fab98ed12266e797cc"), PublicKey(hex"020a6d1951916adcac22125063f62c35b3686f36e5db2f77073f3d35b19c7a118a"), )) - assert(blindedRoute.encryptedPayloads === blindedRoute.introductionNode.encryptedPayload +: blindedRoute.blindedNodes.map(_.encryptedPayload)) assert(blindedRoute.blindedNodes.map(_.encryptedPayload) === Seq( + hex"a245b767bd52520bdf8179b2dc681d1a36c2ededaf59429dfc4bea342fa460c9", hex"38748f94ead7de2a54fc43e8bb927bfc377dda7ed5a2e36b327b739c3c82a602e43e07e378f17cd46ee32d987eb8b6d03b3403acb095bd2868f640b92ea1", hex"a5ddddd448f15208452f4d65da0d53679e9652c8f9c9882d795388a492b4060afb5f2f556e36aed51d089f60f7c94f714b34cb30f1dac0c17f3855a827cb", hex"7ead52884542d180e76fec6ae2d137b6b4c771dc0d41390e992839dea0f4fcefb4a31589125e2ba535d0dc3bf1bc94e6c9039323579547921686d3b54c22", @@ -399,36 +392,36 @@ class SphinxSpec extends AnyFunSuite { )) // The introduction point can decrypt its encrypted payload and obtain the next ephemeral public key. - val Success((payload0, ephKey1)) = RouteBlinding.decryptPayload(privKeys(0), blindedRoute.blindingEphemeralKeys(0), blindedRoute.encryptedPayloads(0)) + val Success((payload0, ephKey1)) = RouteBlinding.decryptPayload(privKeys(0), blindedRoute.blindedNodes(0).blindingEphemeralKey, blindedRoute.blindedNodes(0).encryptedPayload) assert(payload0 === routeBlindingPayloads(0)) - assert(ephKey1 === blindedRoute.blindingEphemeralKeys(1)) + assert(ephKey1 === blindedRoute.blindedNodes(1).blindingEphemeralKey) // The next node can derive the private key used to unwrap the onion and decrypt its encrypted payload. - assert(RouteBlinding.derivePrivateKey(privKeys(1), ephKey1).publicKey === blindedRoute.nodeIds(1)) - val Success((payload1, ephKey2)) = RouteBlinding.decryptPayload(privKeys(1), ephKey1, blindedRoute.encryptedPayloads(1)) + assert(RouteBlinding.derivePrivateKey(privKeys(1), ephKey1).publicKey === blindedRoute.blindedNodes(1).blindedPublicKey) + val Success((payload1, ephKey2)) = RouteBlinding.decryptPayload(privKeys(1), ephKey1, blindedRoute.blindedNodes(1).encryptedPayload) assert(payload1 === routeBlindingPayloads(1)) - assert(ephKey2 === blindedRoute.blindingEphemeralKeys(2)) + assert(ephKey2 === blindedRoute.blindedNodes(2).blindingEphemeralKey) // The next node can derive the private key used to unwrap the onion and decrypt its encrypted payload. - assert(RouteBlinding.derivePrivateKey(privKeys(2), ephKey2).publicKey === blindedRoute.nodeIds(2)) - val Success((payload2, ephKey3)) = RouteBlinding.decryptPayload(privKeys(2), ephKey2, blindedRoute.encryptedPayloads(2)) + assert(RouteBlinding.derivePrivateKey(privKeys(2), ephKey2).publicKey === blindedRoute.blindedNodes(2).blindedPublicKey) + val Success((payload2, ephKey3)) = RouteBlinding.decryptPayload(privKeys(2), ephKey2, blindedRoute.blindedNodes(2).encryptedPayload) assert(payload2 === routeBlindingPayloads(2)) - assert(ephKey3 === blindedRoute.blindingEphemeralKeys(3)) + assert(ephKey3 === blindedRoute.blindedNodes(3).blindingEphemeralKey) // The next node can derive the private key used to unwrap the onion and decrypt its encrypted payload. - assert(RouteBlinding.derivePrivateKey(privKeys(3), ephKey3).publicKey === blindedRoute.nodeIds(3)) - val Success((payload3, ephKey4)) = RouteBlinding.decryptPayload(privKeys(3), ephKey3, blindedRoute.encryptedPayloads(3)) + assert(RouteBlinding.derivePrivateKey(privKeys(3), ephKey3).publicKey === blindedRoute.blindedNodes(3).blindedPublicKey) + val Success((payload3, ephKey4)) = RouteBlinding.decryptPayload(privKeys(3), ephKey3, blindedRoute.blindedNodes(3).encryptedPayload) assert(payload3 === routeBlindingPayloads(3)) - assert(ephKey4 === blindedRoute.blindingEphemeralKeys(4)) + assert(ephKey4 === blindedRoute.blindedNodes(4).blindingEphemeralKey) // The last node can derive the private key used to unwrap the onion and decrypt its encrypted payload. - assert(RouteBlinding.derivePrivateKey(privKeys(4), ephKey4).publicKey === blindedRoute.nodeIds(4)) - val Success((payload4, _)) = RouteBlinding.decryptPayload(privKeys(4), ephKey4, blindedRoute.encryptedPayloads(4)) + assert(RouteBlinding.derivePrivateKey(privKeys(4), ephKey4).publicKey === blindedRoute.blindedNodes(4).blindedPublicKey) + val Success((payload4, _)) = RouteBlinding.decryptPayload(privKeys(4), ephKey4, blindedRoute.blindedNodes(4).encryptedPayload) assert(payload4 === routeBlindingPayloads(4)) } test("invalid blinded route") { - val encryptedPayloads = RouteBlinding.create(sessionKey, publicKeys, routeBlindingPayloads).encryptedPayloads + val encryptedPayloads = RouteBlinding.create(sessionKey, publicKeys, routeBlindingPayloads).blindedNodes.map(_.encryptedPayload) // Invalid node private key: val ephKey0 = sessionKey.publicKey assert(RouteBlinding.decryptPayload(privKeys(1), ephKey0, encryptedPayloads(0)).isFailure) @@ -447,7 +440,8 @@ class SphinxSpec extends AnyFunSuite { // The sender obtains this information (e.g. from a Bolt11 invoice) and prepends two normal hops to reach the introduction node. val nodeIds = publicKeys.take(2) ++ blindedRoute.nodeIds - assert(blindedRoute.encryptedPayloads === Seq( + val encryptedPayloads = blindedRoute.blindedNodes.map(_.encryptedPayload) + assert(encryptedPayloads === Seq( hex"192256e1c0b289eee9a509bf94455c111838cab3f47010aeedc1367aa77cf44743c6cf49726ddb96b426cdbf6767e462f940638879805b04dd97d3bb823f", hex"38c490e3f4f29cc7af8620002fb497591e043377d19fdf4c9cc913600a4d7ae2842e538181790fe7309c85c845b360eab73c8eaa1068866d1a42fb3afb54", hex"d2706bb65ac8e1c2a319ba53a371d97dc237132b22ce4f7439983545e37164d792dc6925a3c7cde855ac824871c2417052efa103e5b53ec49a2bb4ab7cfc", @@ -457,10 +451,10 @@ class SphinxSpec extends AnyFunSuite { TlvStream[OnionTlv](OnionTlv.AmountToForward(MilliSatoshi(500)), OnionTlv.OutgoingCltv(CltvExpiry(1000)), OnionTlv.OutgoingChannelId(ShortChannelId(10))), TlvStream[OnionTlv](OnionTlv.AmountToForward(MilliSatoshi(450)), OnionTlv.OutgoingCltv(CltvExpiry(900)), OnionTlv.OutgoingChannelId(ShortChannelId(15))), // The sender includes the blinding key and the first encrypted recipient data in the introduction node's payload. - TlvStream[OnionTlv](OnionTlv.AmountToForward(MilliSatoshi(400)), OnionTlv.OutgoingCltv(CltvExpiry(860)), OnionTlv.BlindingPoint(blindingEphemeralKey0), OnionTlv.EncryptedRecipientData(blindedRoute.encryptedPayloads(0))), + TlvStream[OnionTlv](OnionTlv.AmountToForward(MilliSatoshi(400)), OnionTlv.OutgoingCltv(CltvExpiry(860)), OnionTlv.BlindingPoint(blindingEphemeralKey0), OnionTlv.EncryptedRecipientData(encryptedPayloads(0))), // The sender includes the correct encrypted recipient data in each blinded node's payload. - TlvStream[OnionTlv](OnionTlv.AmountToForward(MilliSatoshi(250)), OnionTlv.OutgoingCltv(CltvExpiry(750)), OnionTlv.EncryptedRecipientData(blindedRoute.encryptedPayloads(1))), - TlvStream[OnionTlv](OnionTlv.AmountToForward(MilliSatoshi(250)), OnionTlv.OutgoingCltv(CltvExpiry(750)), OnionTlv.EncryptedRecipientData(blindedRoute.encryptedPayloads(2))), + TlvStream[OnionTlv](OnionTlv.AmountToForward(MilliSatoshi(250)), OnionTlv.OutgoingCltv(CltvExpiry(750)), OnionTlv.EncryptedRecipientData(encryptedPayloads(1))), + TlvStream[OnionTlv](OnionTlv.AmountToForward(MilliSatoshi(250)), OnionTlv.OutgoingCltv(CltvExpiry(750)), OnionTlv.EncryptedRecipientData(encryptedPayloads(2))), ).map(tlvs => OnionCodecs.tlvPerHopPayloadCodec.encode(tlvs).require.bytes) val senderSessionKey = PrivateKey(hex"0202020202020202020202020202020202020202020202020202020202020202") diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/EncryptedRecipientDataSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/EncryptedRecipientDataSpec.scala index 153a0b37dd..631ea02598 100644 --- a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/EncryptedRecipientDataSpec.scala +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/EncryptedRecipientDataSpec.scala @@ -22,11 +22,12 @@ class EncryptedRecipientDataSpec extends AnyFunSuiteLike { ) val blindedRoute = RouteBlinding.create(sessionKey, nodePrivKeys.map(_.publicKey), payloads.map(_._2)) + val encryptedPayloads = blindedRoute.blindedNodes.map(_.encryptedPayload) val blinding0 = sessionKey.publicKey - val Success((decryptedPayload0, blinding1)) = EncryptedRecipientDataCodecs.decode(nodePrivKeys.head, blinding0, blindedRoute.encryptedPayloads(0)) - val Success((decryptedPayload1, blinding2)) = EncryptedRecipientDataCodecs.decode(nodePrivKeys(1), blinding1, blindedRoute.encryptedPayloads(1)) - val Success((decryptedPayload2, blinding3)) = EncryptedRecipientDataCodecs.decode(nodePrivKeys(2), blinding2, blindedRoute.encryptedPayloads(2)) - val Success((decryptedPayload3, _)) = EncryptedRecipientDataCodecs.decode(nodePrivKeys(3), blinding3, blindedRoute.encryptedPayloads(3)) + val Success((decryptedPayload0, blinding1)) = EncryptedRecipientDataCodecs.decode(nodePrivKeys.head, blinding0, encryptedPayloads.head) + val Success((decryptedPayload1, blinding2)) = EncryptedRecipientDataCodecs.decode(nodePrivKeys(1), blinding1, encryptedPayloads(1)) + val Success((decryptedPayload2, blinding3)) = EncryptedRecipientDataCodecs.decode(nodePrivKeys(2), blinding2, encryptedPayloads(2)) + val Success((decryptedPayload3, _)) = EncryptedRecipientDataCodecs.decode(nodePrivKeys(3), blinding3, encryptedPayloads(3)) assert(Seq(decryptedPayload0, decryptedPayload1, decryptedPayload2, decryptedPayload3) === payloads.map(_._1)) } @@ -45,15 +46,16 @@ class EncryptedRecipientDataSpec extends AnyFunSuiteLike { val payloads = Seq(hex"0a 02080000000000000231", testCase) val blindingPrivKey = randomKey() val blindedRoute = RouteBlinding.create(blindingPrivKey, nodePrivKeys.map(_.publicKey), payloads) + val encryptedPayloads = blindedRoute.blindedNodes.map(_.encryptedPayload) // The payload for the first node is valid. val blinding0 = blindingPrivKey.publicKey - val Success((_, blinding1)) = EncryptedRecipientDataCodecs.decode(nodePrivKeys.head, blinding0, blindedRoute.encryptedPayloads.head) + val Success((_, blinding1)) = EncryptedRecipientDataCodecs.decode(nodePrivKeys.head, blinding0, encryptedPayloads.head) // If the first node is given invalid decryption material, it cannot decrypt recipient data. - assert(EncryptedRecipientDataCodecs.decode(nodePrivKeys.last, blinding0, blindedRoute.encryptedPayloads.head).isFailure) - assert(EncryptedRecipientDataCodecs.decode(nodePrivKeys.head, blinding1, blindedRoute.encryptedPayloads.head).isFailure) - assert(EncryptedRecipientDataCodecs.decode(nodePrivKeys.head, blinding0, blindedRoute.encryptedPayloads.last).isFailure) + assert(EncryptedRecipientDataCodecs.decode(nodePrivKeys.last, blinding0, encryptedPayloads.head).isFailure) + assert(EncryptedRecipientDataCodecs.decode(nodePrivKeys.head, blinding1, encryptedPayloads.head).isFailure) + assert(EncryptedRecipientDataCodecs.decode(nodePrivKeys.head, blinding0, encryptedPayloads.last).isFailure) // The payload for the last node is invalid, even with valid decryption material. - assert(EncryptedRecipientDataCodecs.decode(nodePrivKeys.last, blinding1, blindedRoute.encryptedPayloads.last).isFailure) + assert(EncryptedRecipientDataCodecs.decode(nodePrivKeys.last, blinding1, encryptedPayloads.last).isFailure) } } diff --git a/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/OnionMessagesSpec.scala b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/OnionMessagesSpec.scala new file mode 100644 index 0000000000..5eaff80e5d --- /dev/null +++ b/eclair-core/src/test/scala/fr/acinq/eclair/wire/protocol/OnionMessagesSpec.scala @@ -0,0 +1,209 @@ +/* + * Copyright 2021 ACINQ SAS + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package fr.acinq.eclair.wire.protocol + +import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey} +import fr.acinq.bitcoin.{ByteVector32, Crypto} +import fr.acinq.eclair.crypto.Sphinx +import fr.acinq.eclair.crypto.Sphinx.{MessagePacket, PacketAndSecrets} +import fr.acinq.eclair.wire.protocol.Onion.MessageRelayPayload +import fr.acinq.eclair.wire.protocol.OnionCodecs.messageRelayPayloadCodec +import fr.acinq.eclair.wire.protocol.OnionTlv.EncTlv +import org.scalatest.funsuite.AnyFunSuite +import scodec.bits.{ByteVector, HexStringSyntax} +import scodec.{Attempt, DecodeResult} + +import scala.util.{Failure, Success} + +/** + * Created by thomash on 23/09/2021. + */ + +class OnionMessagesSpec extends AnyFunSuite { + + test("Spec tests") { + val alice = PrivateKey(hex"414141414141414141414141414141414141414141414141414141414141414101") + assert(alice.publicKey == PublicKey(hex"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619")) + val bob = PrivateKey(hex"424242424242424242424242424242424242424242424242424242424242424201") + assert(bob.publicKey == PublicKey(hex"0324653eac434488002cc06bbfb7f10fe18991e35f9fe4302dbea6d2353dc0ab1c")) + val carol = PrivateKey(hex"434343434343434343434343434343434343434343434343434343434343434301") + assert(carol.publicKey == PublicKey(hex"027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa2007")) + val dave = PrivateKey(hex"444444444444444444444444444444444444444444444444444444444444444401") + assert(dave.publicKey == PublicKey(hex"032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991")) + + val blindingSecret = PrivateKey(hex"050505050505050505050505050505050505050505050505050505050505050501") + assert(blindingSecret.publicKey == PublicKey(hex"0362c0a046dacce86ddd0343c6d3c7c79c2208ba0d9c9cf24a6d046d21d21f90f7")) + val blindingOverride = PrivateKey(hex"070707070707070707070707070707070707070707070707070707070707070701") + assert(blindingOverride.publicKey == PublicKey(hex"02989c0b76cb563971fdc9bef31ec06c3560f3249d6ee9e5d83c57625596e05f6f")) + + val messageForAlice = Onion.RelayBlindedTlv(TlvStream(OnionTlv.OutgoingNodeId(bob.publicKey))) + val encodedForAlice = OnionCodecs.relayBlindedTlvCodec.encode(messageForAlice).require.bytes + assert(encodedForAlice == hex"04210324653eac434488002cc06bbfb7f10fe18991e35f9fe4302dbea6d2353dc0ab1c") + val messageForBob = Onion.RelayBlindedTlv(TlvStream(OnionTlv.OutgoingNodeId(carol.publicKey), OnionTlv.BlindingPoint(blindingOverride.publicKey))) + val encodedForBob = OnionCodecs.relayBlindedTlvCodec.encode(messageForBob).require.bytes + assert(encodedForBob == hex"0421027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa20070c2102989c0b76cb563971fdc9bef31ec06c3560f3249d6ee9e5d83c57625596e05f6f") + val messageForCarol = Onion.RelayBlindedTlv(TlvStream(OnionTlv.Padding(hex"0000000000000000000000000000000000000000000000000000000000000000000000"), OnionTlv.OutgoingNodeId(dave.publicKey))) + val encodedForCarol = OnionCodecs.relayBlindedTlvCodec.encode(messageForCarol).require.bytes + assert(encodedForCarol == hex"012300000000000000000000000000000000000000000000000000000000000000000000000421032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991") + val messageForDave = Onion.FinalBlindedTlv(TlvStream(OnionTlv.PathId(hex"01234567"))) + val encodedForDave = OnionCodecs.finalBlindedTlvCodec.encode(messageForDave).require.bytes + assert(encodedForDave == hex"0e0401234567") + + // Building blinded path Carol -> Dave + val routeFromCarol = Sphinx.RouteBlinding.create(blindingOverride, carol.publicKey :: dave.publicKey :: Nil, encodedForCarol :: encodedForDave :: Nil) + + // Building blinded path Alice -> Bob + val routeToCarol = Sphinx.RouteBlinding.create(blindingSecret, alice.publicKey :: bob.publicKey :: Nil, encodedForAlice :: encodedForBob :: Nil) + + val publicKeys = routeToCarol.blindedNodes.map(_.blindedPublicKey) concat routeFromCarol.blindedNodes.map(_.blindedPublicKey) + val encryptedPayloads = routeToCarol.encryptedPayloads concat routeFromCarol.encryptedPayloads + val payloads = encryptedPayloads.map(encTlv => messageRelayPayloadCodec.encode(MessageRelayPayload(TlvStream(EncTlv(encTlv)))).require.bytes) + + val sessionKey = PrivateKey(hex"090909090909090909090909090909090909090909090909090909090909090901") + + val PacketAndSecrets(packet, _) = MessagePacket(1300).create(sessionKey, publicKeys, payloads, ByteVector32.Zeroes) + val onionForAlice = OnionMessage(blindingSecret.publicKey, packet) + + OnionMessages.process(alice, onionForAlice) match { + case OnionMessages.RelayMessage(nextNodeId, onionForBob) => + assert(nextNodeId == bob.publicKey) + OnionMessages.process(bob, onionForBob) match { + case OnionMessages.RelayMessage(nextNodeId, onionForCarol) => + assert(nextNodeId == carol.publicKey) + OnionMessages.process(carol, onionForCarol) match { + case OnionMessages.RelayMessage(nextNodeId, onionForDave) => + assert(nextNodeId == dave.publicKey) + OnionMessages.process(dave, onionForDave) match { + case OnionMessages.ReceiveMessage(_) => () + case x => fail(x.toString) + } + case x => fail(x.toString) + } + case x => fail(x.toString) + } + case x => fail(x.toString) + } + } + + test("Simple enctlv for Alice, next is Bob") { + val nodePrivateKey = PrivateKey(hex"414141414141414141414141414141414141414141414141414141414141414101") + val nodeId = PublicKey(hex"02eec7245d6b7d2ccb30380bfbe2a3648cd7a942653f5aa340edcea1f283686619") + assert(nodePrivateKey.publicKey == nodeId) + val blindingSecret = PrivateKey(hex"050505050505050505050505050505050505050505050505050505050505050501") + val blindingKey = PublicKey(hex"0362c0a046dacce86ddd0343c6d3c7c79c2208ba0d9c9cf24a6d046d21d21f90f7") + assert(blindingSecret.publicKey == blindingKey) + val sharedSecret = ByteVector32(hex"2e83e9bc7821d3f6cec7301fa8493aee407557624fb5745bede9084852430e3f") + assert(Sphinx.computeSharedSecret(nodeId, blindingSecret) == sharedSecret) + assert(Sphinx.computeSharedSecret(blindingKey, nodePrivateKey) == sharedSecret) + assert(Sphinx.mac(ByteVector("blinded_node_id".getBytes), sharedSecret) == ByteVector32(hex"7d846b3445621d49a665e5698c52141e9dda8fa2fe0c3da7e0f9008ccc588a38")) + val blindedNodeId = PublicKey(hex"02004b5662061e9db495a6ad112b6c4eba228a079e8e304d9df50d61043acbc014") + val nextNodeId = PublicKey(hex"0324653eac434488002cc06bbfb7f10fe18991e35f9fe4302dbea6d2353dc0ab1c") + val encmsg = hex"04210324653eac434488002cc06bbfb7f10fe18991e35f9fe4302dbea6d2353dc0ab1c" + val Sphinx.RouteBlinding.BlindedRoute(introductionNode, blindedNodes) = Sphinx.RouteBlinding.create(blindingSecret, nodeId :: Nil, encmsg :: Nil) + assert(blindedNodes.head.blindedPublicKey == blindedNodeId) + assert(Crypto.sha256(blindingKey.value ++ sharedSecret.bytes) == ByteVector32(hex"bae3d9ea2b06efd1b7b9b49b6cdcaad0e789474a6939ffa54ff5ec9224d5b76c")) + val enctlv = hex"6970e870b473ddbc27e3098bfa45bb1aa54f1f637f803d957e6271d8ffeba89da2665d62123763d9b634e30714144a1c165ac9" + assert(blindedNodes.head.encryptedPayload == enctlv) + val message = Onion.RelayBlindedTlv(TlvStream(OnionTlv.OutgoingNodeId(nextNodeId))) + OnionCodecs.relayBlindedTlvCodec.encode(message) match { + case Attempt.Successful(bits) => assert(bits.bytes == encmsg) + case Attempt.Failure(err) => fail(err.toString) + } + OnionCodecs.relayBlindedTlvCodec.decode(encmsg.bits) match { + case Attempt.Successful(DecodeResult(relayNext, _)) => + assert(relayNext.nextNodeId == nextNodeId) + assert(relayNext.nextBlinding.isEmpty) + case Attempt.Failure(err) => fail(err.toString) + } + Sphinx.RouteBlinding.decryptPayload(nodePrivateKey, blindingKey, enctlv) match { + case Success((decrypted, _)) => assert(decrypted == encmsg) + case Failure(err) => fail(err.toString) + } + } + + test("Blinding-key-override enctlv for Bob, next is Carol") { + val nodePrivateKey = PrivateKey(hex"424242424242424242424242424242424242424242424242424242424242424201") + val nodeId = PublicKey(hex"0324653eac434488002cc06bbfb7f10fe18991e35f9fe4302dbea6d2353dc0ab1c") + assert(nodePrivateKey.publicKey == nodeId) + val blindingSecret = PrivateKey(hex"76d4de6c329c79623842dcf8f8eaee90c9742df1b5231f5350df4a231d16ebcf01") + val blindingKey = PublicKey(hex"03fc5e56da97b462744c9a6b0ba9d5b3ffbfb1a08367af9cc6ea5ae03c79a78eec") + assert(blindingSecret.publicKey == blindingKey) + val sharedSecret = ByteVector32(hex"f18a1ddb1cb27d8fc4faf2cf317e87524fcc6b7f053496d95bf6e6809d09851e") + assert(Sphinx.computeSharedSecret(nodeId, blindingSecret) == sharedSecret) + assert(Sphinx.computeSharedSecret(blindingKey, nodePrivateKey) == sharedSecret) + assert(Sphinx.mac(ByteVector("blinded_node_id".getBytes), sharedSecret) == ByteVector32(hex"8074773a3745818b0d97dd875023486cc35e7afd95f5e9ec1363f517979e8373")) + val blindedNodeId = PublicKey(hex"026ea8e36f78e038c659beba9229699796127471d9c7a24a0308533371fd63ad48") + val nextNodeId = PublicKey(hex"027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa2007") + val encmsg = hex"0421027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa20070c2102989c0b76cb563971fdc9bef31ec06c3560f3249d6ee9e5d83c57625596e05f6f" + val Sphinx.RouteBlinding.BlindedRoute(_, blindedHops) = Sphinx.RouteBlinding.create(blindingSecret, nodeId :: Nil, encmsg :: Nil) + assert(blindedHops.head.blindedPublicKey == blindedNodeId) + assert(Crypto.sha256(blindingKey.value ++ sharedSecret.bytes) == ByteVector32(hex"9afb8b2ebc174dcf9e270be24771da7796542398d29d4ff6a4e7b6b4b9205cfe")) + val enctlv = hex"1630da85e8759b8f3b94d74a539c6f0d870a87cf03d4986175865a2985553c997b560c36613bd9184c1a6d41a37027aabdab5433009d8409a1b638eb90373778a05716af2c215b3d31db7b2c2659716e663ba3d9c909" + assert(blindedHops.head.encryptedPayload == enctlv) + val message = Onion.RelayBlindedTlv(TlvStream(OnionTlv.OutgoingNodeId(nextNodeId), OnionTlv.BlindingPoint(PrivateKey(hex"070707070707070707070707070707070707070707070707070707070707070701").publicKey))) + OnionCodecs.relayBlindedTlvCodec.encode(message) match { + case Attempt.Successful(bits) => assert(bits.bytes == encmsg) + case Attempt.Failure(err) => fail(err.toString) + } + OnionCodecs.relayBlindedTlvCodec.decode(encmsg.bits) match { + case Attempt.Successful(DecodeResult(relayNext, _)) => + assert(relayNext.nextNodeId == nextNodeId) + assert(relayNext.nextBlinding contains PrivateKey(hex"070707070707070707070707070707070707070707070707070707070707070701").publicKey) + case Attempt.Failure(err) => fail(err.toString) + } + Sphinx.RouteBlinding.decryptPayload(nodePrivateKey, blindingKey, enctlv) match { + case Success((decrypted, _)) => assert(decrypted == encmsg) + case Failure(err) => fail(err.toString) + } + } + + test("Padded enctlv for Carol, next is Dave") { + val nodePrivateKey = PrivateKey(hex"434343434343434343434343434343434343434343434343434343434343434301") + val nodeId = PublicKey(hex"027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa2007") + assert(nodePrivateKey.publicKey == nodeId) + val blindingSecret = PrivateKey(hex"070707070707070707070707070707070707070707070707070707070707070701") + val blindingKey = PublicKey(hex"02989c0b76cb563971fdc9bef31ec06c3560f3249d6ee9e5d83c57625596e05f6f") + assert(blindingSecret.publicKey == blindingKey) + val sharedSecret = ByteVector32(hex"8c0f7716da996c4913d720dbf691b559a4945bf70cdd18e0b61e3e42635efc9c") + assert(Sphinx.computeSharedSecret(nodeId, blindingSecret) == sharedSecret) + assert(Sphinx.computeSharedSecret(blindingKey, nodePrivateKey) == sharedSecret) + assert(Sphinx.mac(ByteVector("blinded_node_id".getBytes), sharedSecret) == ByteVector32(hex"02afb2187075c8af51488242194b44c02624785ccd6fd43b5796c68f3025bf88")) + val blindedNodeId = PublicKey(hex"02f4f524562868a09d5f54fb956ade3fa51ef071d64d923e395cc6db5e290ec67b") + val nextNodeId = PublicKey(hex"032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991") + val encmsg = hex"012300000000000000000000000000000000000000000000000000000000000000000000000421032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991" + val Sphinx.RouteBlinding.BlindedRoute(_, blindedHops) = Sphinx.RouteBlinding.create(blindingSecret, nodeId :: Nil, encmsg :: Nil) + assert(blindedHops.head.blindedPublicKey == blindedNodeId) + assert(Crypto.sha256(blindingKey.value ++ sharedSecret.bytes) == ByteVector32(hex"cc3b918cda6b1b049bdbe469c4dd952935e7c1518dd9c7ed0cd2cd5bc2742b82")) + val enctlv = hex"8285acbceb37dfb38b877a888900539be656233cd74a55c55344fb068f9d8da365340d21db96fb41b76123207daeafdfb1f571e3fea07a22e10da35f03109a0380b3c69fcbed9c698086671809658761cf65ecbc3c07a2e5" + assert(blindedHops.head.encryptedPayload == enctlv) + val message = Onion.RelayBlindedTlv(TlvStream(OnionTlv.Padding(hex"0000000000000000000000000000000000000000000000000000000000000000000000"), OnionTlv.OutgoingNodeId(nextNodeId))) + OnionCodecs.relayBlindedTlvCodec.encode(message) match { + case Attempt.Successful(bits) => assert(bits.bytes == encmsg) + case Attempt.Failure(err) => fail(err.toString) + } + OnionCodecs.relayBlindedTlvCodec.decode(encmsg.bits) match { + case Attempt.Successful(DecodeResult(relayNext, _)) => + assert(relayNext.nextNodeId == nextNodeId) + assert(relayNext.nextBlinding.isEmpty) + case Attempt.Failure(err) => fail(err.toString) + } + Sphinx.RouteBlinding.decryptPayload(nodePrivateKey, blindingKey, enctlv) match { + case Success((decrypted, _)) => assert(decrypted == encmsg) + case Failure(err) => fail(err.toString) + } + } +}