Skip to content

Commit

Permalink
Add basic support for onion messages
Browse files Browse the repository at this point in the history
- Relay onion messages
- Add functions to create onion messages
  • Loading branch information
thomash-acinq committed Oct 25, 2021
1 parent 1573f7b commit 1eabee0
Show file tree
Hide file tree
Showing 26 changed files with 628 additions and 103 deletions.
5 changes: 5 additions & 0 deletions eclair-core/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ eclair {
option_anchor_outputs = disabled
option_anchors_zero_fee_htlc_tx = disabled
option_shutdown_anysegwit = optional
option_onion_messages = disabled
trampoline_payment = disabled
keysend = disabled
}
Expand Down Expand Up @@ -333,6 +334,10 @@ eclair {
"mempool.space"
]
}

onion-messages {
rate-limit-per-second-per-peer = 10
}
}

akka {
Expand Down
6 changes: 6 additions & 0 deletions eclair-core/src/main/scala/fr/acinq/eclair/Features.scala
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,11 @@ object Features {
val mandatory = 26
}

case object OnionMessages extends Feature {
val rfcName = "option_onion_messages"
val mandatory = 38
}

// TODO: @t-bast: update feature bits once spec-ed (currently reserved here: https://github.com/lightningnetwork/lightning-rfc/issues/605)
// We're not advertising these bits yet in our announcements, clients have to assume support.
// This is why we haven't added them yet to `areSupported`.
Expand Down Expand Up @@ -231,6 +236,7 @@ object Features {
AnchorOutputs,
AnchorOutputsZeroFeeHtlcTx,
ShutdownAnySegwit,
OnionMessages,
KeySend
)

Expand Down
6 changes: 4 additions & 2 deletions eclair-core/src/main/scala/fr/acinq/eclair/NodeParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,8 @@ case class NodeParams(nodeKeyManager: NodeKeyManager,
maxPaymentAttempts: Int,
enableTrampolinePayment: Boolean,
balanceCheckInterval: FiniteDuration,
blockchainWatchdogSources: Seq[String]) {
blockchainWatchdogSources: Seq[String],
onionMessageRateLimitPerSecond: Double) {
val privateKey: Crypto.PrivateKey = nodeKeyManager.nodeKey.privateKey

val nodeId: PublicKey = nodeKeyManager.nodeId
Expand Down Expand Up @@ -457,7 +458,8 @@ object NodeParams extends Logging {
maxPaymentAttempts = config.getInt("max-payment-attempts"),
enableTrampolinePayment = config.getBoolean("trampoline-payments-enable"),
balanceCheckInterval = FiniteDuration(config.getDuration("balance-check-interval").getSeconds, TimeUnit.SECONDS),
blockchainWatchdogSources = config.getStringList("blockchain-watchdog.sources").asScala.toSeq
blockchainWatchdogSources = config.getStringList("blockchain-watchdog.sources").asScala.toSeq,
onionMessageRateLimitPerSecond = config.getDouble("onion-messages.rate-limit-per-second-per-peer")
)
}
}
33 changes: 23 additions & 10 deletions eclair-core/src/main/scala/fr/acinq/eclair/crypto/Sphinx.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import fr.acinq.eclair.wire.protocol._
import grizzled.slf4j.Logging
import scodec.Attempt
import scodec.bits.ByteVector
import scodec.codecs.provide

import scala.annotation.tailrec
import scala.util.{Failure, Success, Try}
Expand Down Expand Up @@ -217,7 +218,7 @@ object Sphinx extends Logging {
* @param onionPayloadFiller optional onion payload filler, needed only when you're constructing the last packet.
* @return the next packet.
*/
def wrap(payload: ByteVector, associatedData: ByteVector32, ephemeralPublicKey: PublicKey, sharedSecret: ByteVector32, packet: Either[ByteVector, protocol.OnionRoutingPacket], onionPayloadFiller: ByteVector = ByteVector.empty): protocol.OnionRoutingPacket = {
def wrap(payload: ByteVector, associatedData: ByteVector, ephemeralPublicKey: PublicKey, sharedSecret: ByteVector32, packet: Either[ByteVector, protocol.OnionRoutingPacket], onionPayloadFiller: ByteVector = ByteVector.empty): protocol.OnionRoutingPacket = {
require(payload.length <= PayloadLength - MacLength, s"packet payload cannot exceed ${PayloadLength - MacLength} bytes")

val (currentMac, currentPayload): (ByteVector32, ByteVector) = packet match {
Expand Down Expand Up @@ -248,7 +249,7 @@ object Sphinx extends Logging {
* @return An onion packet with all shared secrets. The onion packet can be sent to the first node in the list, and
* the shared secrets (one per node) can be used to parse returned failure messages if needed.
*/
def create(sessionKey: PrivateKey, publicKeys: Seq[PublicKey], payloads: Seq[ByteVector], associatedData: ByteVector32): PacketAndSecrets = {
def create(sessionKey: PrivateKey, publicKeys: Seq[PublicKey], payloads: Seq[ByteVector], associatedData: ByteVector): PacketAndSecrets = {
val (ephemeralPublicKeys, sharedsecrets) = computeEphemeralPublicKeysAndSharedSecrets(sessionKey, publicKeys)
val filler = generateFiller("rho", sharedsecrets.dropRight(1), payloads.dropRight(1))

Expand All @@ -272,7 +273,7 @@ object Sphinx extends Logging {
* When an invalid onion is received, its hash should be included in the failure message.
*/
def hash(onion: protocol.OnionRoutingPacket): ByteVector32 =
Crypto.sha256(OnionCodecs.onionRoutingPacketCodec(onion.payload.length.toInt).encode(onion).require.toByteVector)
Crypto.sha256(OnionCodecs.onionRoutingPacketCodec(provide(onion.payload.length.toInt)).encode(onion).require.toByteVector)

}

Expand All @@ -291,6 +292,12 @@ object Sphinx extends Logging {
override val PayloadLength = 400
}

/**
* A message onion packet is used when requesting/sending an invoice from/to a remote node when using offers (BOLT12).
* @param PayloadLength SHOULD be 1300 or 32768
*/
case class MessagePacket(PayloadLength: Int) extends OnionRoutingPacket[Onion.MessagePacket]

/**
* A properly decrypted failure from a node in the route.
*
Expand Down Expand Up @@ -396,12 +403,19 @@ object Sphinx extends Logging {
case class BlindedNode(blindedPublicKey: PublicKey, encryptedPayload: ByteVector)

/**
* @param introductionNode the first node should not be blinded, otherwise the sender cannot locate it.
* @param blindedNodes blinded nodes (not including the introduction node).
* @param introductionNodeId the first node, not be blinded so that the sender can locate it.
* @param blindingKey blinding tweak that can be used by the introduction node to derive the private key that
* matches the blinded public key.
* @param blindedNodes blinded nodes (including the introduction node).
*/
case class BlindedRoute(introductionNode: IntroductionNode, blindedNodes: Seq[BlindedNode]) {
val nodeIds: Seq[PublicKey] = introductionNode.publicKey +: blindedNodes.map(_.blindedPublicKey)
val encryptedPayloads: Seq[ByteVector] = introductionNode.encryptedPayload +: blindedNodes.map(_.encryptedPayload)
case class BlindedRoute(introductionNodeId: PublicKey, blindingKey: PublicKey, blindedNodes: Seq[BlindedNode]) {
val introductionNode: IntroductionNode = IntroductionNode(introductionNodeId, blindedNodes.head.blindedPublicKey, blindingKey, blindedNodes.head.encryptedPayload)
val subsequentNodes: Seq[BlindedNode] = blindedNodes.tail

val blindedNodeIds: Seq[PublicKey] = blindedNodes.map(_.blindedPublicKey)
val encryptedPayloads: Seq[ByteVector] = blindedNodes.map(_.encryptedPayload)

val nodeIds: Seq[PublicKey] = introductionNodeId +: blindedNodeIds.tail
}

/**
Expand All @@ -424,8 +438,7 @@ object Sphinx extends Logging {
e = e.multiply(PrivateKey(Crypto.sha256(blindingKey.value ++ sharedSecret.bytes)))
(BlindedNode(blindedPublicKey, encryptedPayload ++ mac), blindingKey)
}.unzip
val introductionNode = IntroductionNode(publicKeys.head, blindedHops.head.blindedPublicKey, blindingKeys.head, blindedHops.head.encryptedPayload)
BlindedRoute(introductionNode, blindedHops.tail)
BlindedRoute(publicKeys.head, blindingKeys.head, blindedHops)
}

/**
Expand Down
21 changes: 20 additions & 1 deletion eclair-core/src/main/scala/fr/acinq/eclair/io/Peer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import akka.event.Logging.MDC
import akka.event.{BusLogging, DiagnosticLoggingAdapter}
import akka.util.Timeout
import com.google.common.net.HostAndPort
import com.google.common.util.concurrent.RateLimiter
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.bitcoin.{ByteVector32, Satoshi, SatoshiLong, Script}
import fr.acinq.eclair.Features.Wumbo
Expand All @@ -35,7 +36,7 @@ import fr.acinq.eclair.io.Monitoring.Metrics
import fr.acinq.eclair.io.PeerConnection.KillReason
import fr.acinq.eclair.remote.EclairInternalsSerializer.RemoteTypes
import fr.acinq.eclair.wire.protocol
import fr.acinq.eclair.wire.protocol.{Error, HasChannelId, HasTemporaryChannelId, LightningMessage, NodeAddress, RoutingMessage, UnknownMessage, Warning}
import fr.acinq.eclair.wire.protocol.{Error, HasChannelId, HasTemporaryChannelId, LightningMessage, NodeAddress, OnionMessage, OnionMessages, RoutingMessage, UnknownMessage, Warning}
import scodec.bits.ByteVector

import java.net.InetSocketAddress
Expand All @@ -55,6 +56,8 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnChainA

import Peer._

private val messageRelayRateLimiter = RateLimiter.create(nodeParams.onionMessageRateLimitPerSecond)

startWith(INSTANTIATING, Nothing)

when(INSTANTIATING) {
Expand Down Expand Up @@ -243,6 +246,20 @@ class Peer(val nodeParams: NodeParams, remoteNodeId: PublicKey, wallet: OnChainA
d.channels.values.toSet[ActorRef].foreach(_ ! INPUT_DISCONNECTED) // we deduplicate with toSet because there might be two entries per channel (tmp id and final id)
gotoConnected(connectionReady, d.channels)

case Event(msg: OnionMessage, d: ConnectedData) =>
if (nodeParams.features.hasFeature(Features.OnionMessages) && messageRelayRateLimiter.tryAcquire()) {
OnionMessages.process(nodeParams.privateKey, msg) match {
case OnionMessages.DropMessage(_) => () // We ignore bad messages
case OnionMessages.RelayMessage(nextNodeId, dataToRelay) => context.parent ! Peer.SendOnionMessage(nextNodeId, dataToRelay)
case OnionMessages.ReceiveMessage(_, _) => () // We only relay messages
}
}
stay()

case Event(Peer.SendOnionMessage(toNodeId, msg), d: ConnectedData) if toNodeId == remoteNodeId =>
d.peerConnection ! msg
stay()

case Event(unknownMsg: UnknownMessage, d: ConnectedData) if nodeParams.pluginMessageTags.contains(unknownMsg.tag) =>
context.system.eventStream.publish(UnknownMessageReceived(self, remoteNodeId, unknownMsg, d.connectionInfo))
stay()
Expand Down Expand Up @@ -427,6 +444,8 @@ object Peer {
def apply(uri: NodeURI): Connect = new Connect(uri.nodeId, Some(uri.address))
}

case class SendOnionMessage(nodeId: PublicKey, message: OnionMessage) extends PossiblyHarmful

case class Disconnect(nodeId: PublicKey) extends PossiblyHarmful
case class OpenChannel(remoteNodeId: PublicKey, fundingSatoshis: Satoshi, pushMsat: MilliSatoshi, channelType_opt: Option[SupportedChannelType], fundingTxFeeratePerKw_opt: Option[FeeratePerKw], channelFlags: Option[Byte], timeout_opt: Option[Timeout]) extends PossiblyHarmful {
require(pushMsat <= fundingSatoshis, s"pushMsat must be less or equal to fundingSatoshis")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ class Switchboard(nodeParams: NodeParams, peerFactory: Switchboard.PeerFactory)
case None => sender() ! Status.Failure(new RuntimeException("peer not found"))
}

case s@Peer.SendOnionMessage(nodeId, _) =>
val peer = createOrGetPeer(nodeId, offlineChannels = Set.empty)
peer forward s

case o: Peer.OpenChannel =>
getPeer(o.remoteNodeId) match {
case Some(peer) => peer forward o
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ object IncomingPacket {
case Left(failure) => Left(failure)
// NB: we don't validate the ChannelRelayPacket here because its fees and cltv depend on what channel we'll choose to use.
case Right(DecodedOnionPacket(payload: Onion.ChannelRelayPayload, next)) => Right(ChannelRelayPacket(add, payload, next))
case Right(DecodedOnionPacket(payload: Onion.FinalTlvPayload, _)) => payload.records.get[OnionTlv.TrampolineOnion] match {
case Some(OnionTlv.TrampolineOnion(trampolinePacket)) => decryptOnion(add.paymentHash, privateKey)(trampolinePacket, Sphinx.TrampolinePacket) match {
case Right(DecodedOnionPacket(payload: Onion.FinalTlvPayload, _)) => payload.records.get[OnionPaymentPayloadTlv.TrampolineOnion] match {
case Some(OnionPaymentPayloadTlv.TrampolineOnion(trampolinePacket)) => decryptOnion(add.paymentHash, privateKey)(trampolinePacket, Sphinx.TrampolinePacket) match {
case Left(failure) => Left(failure)
case Right(DecodedOnionPacket(innerPayload: Onion.NodeRelayPayload, next)) => validateNodeRelay(add, payload, innerPayload, next)
case Right(DecodedOnionPacket(innerPayload: Onion.FinalPayload, _)) => validateFinal(add, payload, innerPayload)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ class NodeRelay private(nodeParams: NodeParams,
context.log.debug("sending the payment to the next trampoline node")
val payFSM = outgoingPaymentFactory.spawnOutgoingPayFSM(context, paymentCfg, multiPart = true)
val paymentSecret = randomBytes32() // we generate a new secret to protect against probing attacks
val payment = SendMultiPartPayment(payFsmAdapters, paymentSecret, payloadOut.outgoingNodeId, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, routeParams = routeParams, additionalTlvs = Seq(OnionTlv.TrampolineOnion(packetOut)))
val payment = SendMultiPartPayment(payFsmAdapters, paymentSecret, payloadOut.outgoingNodeId, payloadOut.amountToForward, payloadOut.outgoingCltv, nodeParams.maxPaymentAttempts, routeParams = routeParams, additionalTlvs = Seq(OnionPaymentPayloadTlv.TrampolineOnion(packetOut)))
payFSM ! payment
payFSM
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ object MultiPartPaymentLifecycle {
maxAttempts: Int,
assistedRoutes: Seq[Seq[ExtraHop]] = Nil,
routeParams: RouteParams,
additionalTlvs: Seq[OnionTlv] = Nil,
additionalTlvs: Seq[OnionPaymentPayloadTlv] = Nil,
userCustomTlvs: Seq[GenericTlv] = Nil) {
require(totalAmount > 0.msat, s"total amount must be > 0")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn
sender() ! paymentId
val paymentCfg = SendPaymentConfig(paymentId, paymentId, r.externalId, r.paymentHash, r.recipientAmount, r.recipientNodeId, Upstream.Local(paymentId), None, storeInDb = true, publishEvent = true, recordPathFindingMetrics = r.recordPathFindingMetrics, Nil)
val finalExpiry = Channel.MIN_CLTV_EXPIRY_DELTA.toCltvExpiry(nodeParams.currentBlockHeight + 1)
val finalPayload = Onion.FinalTlvPayload(TlvStream(Seq(OnionTlv.AmountToForward(r.recipientAmount), OnionTlv.OutgoingCltv(finalExpiry), OnionTlv.PaymentData(randomBytes32(), r.recipientAmount), OnionTlv.KeySend(r.paymentPreimage)), r.userCustomTlvs))
val finalPayload = Onion.FinalTlvPayload(TlvStream(Seq(OnionPaymentPayloadTlv.AmountToForward(r.recipientAmount), OnionPaymentPayloadTlv.OutgoingCltv(finalExpiry), OnionPaymentPayloadTlv.PaymentData(randomBytes32(), r.recipientAmount), OnionPaymentPayloadTlv.KeySend(r.paymentPreimage)), r.userCustomTlvs))
val fsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg)
fsm ! PaymentLifecycle.SendPaymentToNode(sender(), r.recipientNodeId, finalPayload, r.maxAttempts, routeParams = r.routeParams)

Expand Down Expand Up @@ -140,7 +140,7 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn
sender() ! SendPaymentToRouteResponse(paymentId, parentPaymentId, Some(trampolineSecret))
val payFsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg)
val (trampolineAmount, trampolineExpiry, trampolineOnion) = buildTrampolinePayment(r, trampoline, r.trampolineFees, r.trampolineExpiryDelta)
payFsm ! PaymentLifecycle.SendPaymentToRoute(sender(), Left(r.route), Onion.createMultiPartPayload(r.amount, trampolineAmount, trampolineExpiry, trampolineSecret, Seq(OnionTlv.TrampolineOnion(trampolineOnion))), r.paymentRequest.routingInfo)
payFsm ! PaymentLifecycle.SendPaymentToRoute(sender(), Left(r.route), Onion.createMultiPartPayload(r.amount, trampolineAmount, trampolineExpiry, trampolineSecret, Seq(OnionPaymentPayloadTlv.TrampolineOnion(trampolineOnion))), r.paymentRequest.routingInfo)
case Nil =>
sender() ! SendPaymentToRouteResponse(paymentId, parentPaymentId, None)
val payFsm = outgoingPaymentFactory.spawnOutgoingPayment(context, paymentCfg)
Expand Down Expand Up @@ -175,7 +175,7 @@ class PaymentInitiator(nodeParams: NodeParams, outgoingPaymentFactory: PaymentIn
val trampolineSecret = randomBytes32()
val (trampolineAmount, trampolineExpiry, trampolineOnion) = buildTrampolinePayment(r, r.trampolineNodeId, trampolineFees, trampolineExpiryDelta)
val fsm = outgoingPaymentFactory.spawnOutgoingMultiPartPayment(context, paymentCfg)
fsm ! SendMultiPartPayment(self, trampolineSecret, r.trampolineNodeId, trampolineAmount, trampolineExpiry, nodeParams.maxPaymentAttempts, r.paymentRequest.routingInfo, r.routeParams, Seq(OnionTlv.TrampolineOnion(trampolineOnion)))
fsm ! SendMultiPartPayment(self, trampolineSecret, r.trampolineNodeId, trampolineAmount, trampolineExpiry, nodeParams.maxPaymentAttempts, r.paymentRequest.routingInfo, r.routeParams, Seq(OnionPaymentPayloadTlv.TrampolineOnion(trampolineOnion)))
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package fr.acinq.eclair.wire.protocol

import fr.acinq.eclair.wire.Monitoring.{Metrics, Tags}
import fr.acinq.eclair.wire.protocol.CommonCodecs._
import fr.acinq.eclair.wire.protocol.OnionCodecs.onionRoutingPacketCodec
import fr.acinq.eclair.{Features, KamonExt}
import scodec.bits.{BitVector, ByteVector}
import scodec.codecs._
Expand Down Expand Up @@ -308,6 +309,11 @@ object LightningMessageCodecs {
("timestampRange" | uint32) ::
("tlvStream" | GossipTimestampFilterTlv.gossipTimestampFilterTlvCodec)).as[GossipTimestampFilter]

val onionMessageCodec: Codec[OnionMessage] = (
("blindingKey" | publicKey) ::
("onionPacket" | OnionCodecs.messageOnionPacketCodec) ::
("tlvStream" | OnionMessageTlv.onionMessageTlvCodec)).as[OnionMessage]

// NB: blank lines to minimize merge conflicts

//
Expand Down Expand Up @@ -361,6 +367,7 @@ object LightningMessageCodecs {
.typecase(263, queryChannelRangeCodec)
.typecase(264, replyChannelRangeCodec)
.typecase(265, gossipTimestampFilterCodec)
.typecase(513, onionMessageCodec)
// NB: blank lines to minimize merge conflicts

//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ object ReplyChannelRange {

case class GossipTimestampFilter(chainHash: ByteVector32, firstTimestamp: TimestampSecond, timestampRange: Long, tlvStream: TlvStream[GossipTimestampFilterTlv] = TlvStream.empty) extends RoutingMessage with HasChainHash

case class OnionMessage(blindingKey: PublicKey, onionRoutingPacket: OnionRoutingPacket, tlvStream: TlvStream[OnionMessageTlv] = TlvStream.empty) extends LightningMessage

// NB: blank lines to minimize merge conflicts

//
Expand Down
Loading

0 comments on commit 1eabee0

Please sign in to comment.