Skip to content

Commit

Permalink
Use EncryptedRecipientDataTlv
Browse files Browse the repository at this point in the history
  • Loading branch information
thomash-acinq committed Nov 5, 2021
1 parent 1b47940 commit dbcd4a9
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import fr.acinq.bitcoin.Crypto.{PrivateKey, PublicKey}
import fr.acinq.eclair.crypto.Sphinx
import fr.acinq.eclair.wire.protocol.MessageOnion.{finalBlindedTlvCodec, messageOnionPerHopPayloadCodec, messageRelayPayloadCodec, relayBlindedTlvCodec}
import fr.acinq.eclair.wire.protocol.MessageTlv._
import fr.acinq.eclair.wire.protocol.{BadOnion, OnionMessage, OnionMessagePayloadTlv, TlvStream}
import fr.acinq.eclair.wire.protocol.{BadOnion, EncryptedRecipientDataTlv, OnionMessage, OnionMessagePayloadTlv, TlvStream}
import scodec.bits.ByteVector
import scodec.{Attempt, DecodeResult}

Expand All @@ -14,27 +14,27 @@ object OnionMessages {

case class IntermediateNode(nodeId: PublicKey, padding: Option[ByteVector] = None)

case class Recipient(nodeId: PublicKey, pathId: Option[ByteVector], padding: Option[ByteVector] = None)
case class Recipient(nodeId: PublicKey, secret: Option[ByteVector], padding: Option[ByteVector] = None)

def buildRoute(blindingSecret: PrivateKey,
intermediateNodes: Seq[IntermediateNode],
destination: Either[Recipient, Sphinx.RouteBlinding.BlindedRoute]): Sphinx.RouteBlinding.BlindedRoute = {
val last = destination match {
case Left(Recipient(nodeId, _, _)) => NextNodeId(nodeId) :: Nil
case Right(Sphinx.RouteBlinding.BlindedRoute(nodeId, blindingKey, _)) => NextNodeId(nodeId) :: BlindingPoint(blindingKey) :: Nil
case Left(Recipient(nodeId, _, _)) => EncryptedRecipientDataTlv.OutgoingNodeId(nodeId) :: Nil
case Right(Sphinx.RouteBlinding.BlindedRoute(nodeId, blindingKey, _)) => EncryptedRecipientDataTlv.OutgoingNodeId(nodeId) :: EncryptedRecipientDataTlv.NextBlinding(blindingKey) :: Nil
}
val intermediatePayloads =
if (intermediateNodes.isEmpty) {
Nil
} else {
(intermediateNodes.tail.map(node => NextNodeId(node.nodeId) :: Nil) :+ last)
.zip(intermediateNodes).map { case (tlvs, hop) => hop.padding.map(Padding(_) :: Nil).getOrElse(Nil) ++ tlvs }
(intermediateNodes.tail.map(node => EncryptedRecipientDataTlv.OutgoingNodeId(node.nodeId) :: Nil) :+ last)
.zip(intermediateNodes).map { case (tlvs, hop) => hop.padding.map(EncryptedRecipientDataTlv.Padding(_) :: Nil).getOrElse(Nil) ++ tlvs }
.map(tlvs => RelayBlindedTlv(TlvStream(tlvs)))
.map(relayBlindedTlvCodec.encode(_).require.bytes)
}
destination match {
case Left(Recipient(nodeId, pathId, padding)) =>
val tlvs = padding.map(Padding(_) :: Nil).getOrElse(Nil) ++ pathId.map(PathId(_) :: Nil).getOrElse(Nil)
case Left(Recipient(nodeId, recipientSecret, padding)) =>
val tlvs = padding.map(EncryptedRecipientDataTlv.Padding(_) :: Nil).getOrElse(Nil) ++ recipientSecret.map(EncryptedRecipientDataTlv.RecipientSecret(_) :: Nil).getOrElse(Nil)
val lastPayload = finalBlindedTlvCodec.encode(FinalBlindedTlv(TlvStream(tlvs))).require.bytes
Sphinx.RouteBlinding.create(blindingSecret, intermediateNodes.map(_.nodeId) :+ nodeId, intermediatePayloads :+ lastPayload)
case Right(route) =>
Expand Down Expand Up @@ -95,7 +95,7 @@ object OnionMessages {
case Success((decrypted, _)) =>
finalBlindedTlvCodec.decode(decrypted.bits) match {
case Attempt.Successful(DecodeResult(messageToSelf, _)) =>
ReceiveMessage(finalPayload, messageToSelf.pathId)
ReceiveMessage(finalPayload, messageToSelf.recipientSecret)
case Attempt.Failure(_) => DropMessage("Can't decode blinded TLV")
}
case Failure(_) => DropMessage("Can't decrypt blinded TLV")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ object EncryptedRecipientDataTlv {
*/
case class RecipientSecret(data: ByteVector) extends EncryptedRecipientDataTlv

/** Blinding override for the rest of the route. */
case class NextBlinding(blinding: PublicKey) extends EncryptedRecipientDataTlv

}

object EncryptedRecipientDataCodecs {
Expand All @@ -55,12 +58,14 @@ object EncryptedRecipientDataCodecs {
private val outgoingChannelId: Codec[OutgoingChannelId] = variableSizeBytesLong(varintoverflow, "short_channel_id" | shortchannelid).as[OutgoingChannelId]
private val outgoingNodeId: Codec[OutgoingNodeId] = variableSizeBytesLong(varintoverflow, "node_id" | publicKey).as[OutgoingNodeId]
private val recipientSecret: Codec[RecipientSecret] = variableSizeBytesLong(varintoverflow, "recipient_secret" | bytes).as[RecipientSecret]
private val nextBlinding: Codec[NextBlinding] = variableSizeBytesLong(varintoverflow, "blinding" | publicKey).as[NextBlinding]

private val encryptedRecipientDataTlvCodec = discriminated[EncryptedRecipientDataTlv].by(varint)
.typecase(UInt64(1), padding)
.typecase(UInt64(2), outgoingChannelId)
.typecase(UInt64(4), outgoingNodeId)
.typecase(UInt64(6), recipientSecret)
.typecase(UInt64(12), nextBlinding)

val encryptedRecipientDataCodec: Codec[TlvStream[EncryptedRecipientDataTlv]] = TlvCodecs.tlvStream[EncryptedRecipientDataTlv](encryptedRecipientDataTlvCodec).complete

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,18 @@ package fr.acinq.eclair.wire.protocol
import fr.acinq.bitcoin.Crypto.PublicKey
import fr.acinq.eclair.UInt64
import fr.acinq.eclair.crypto.Sphinx.RouteBlinding.{BlindedNode, BlindedRoute}
import fr.acinq.eclair.wire.protocol.EncryptedRecipientDataCodecs.encryptedRecipientDataCodec
import fr.acinq.eclair.wire.protocol.OnionRoutingCodecs.{ForbiddenTlv, MissingRequiredTlv, onionRoutingPacketCodec}
import scodec.bits.ByteVector

sealed trait OnionMessagePayloadTlv extends Tlv

sealed trait BlindedTlv extends Tlv

object MessageTlv {

/** Blinding ephemeral public key that should be used to derive shared secrets when using route blinding. */
case class BlindingPoint(publicKey: PublicKey) extends BlindedTlv

case class ReplyPath(blindedRoute: BlindedRoute) extends OnionMessagePayloadTlv

case class EncTlv(bytes: ByteVector) extends OnionMessagePayloadTlv

case class NextNodeId(nodeId: PublicKey) extends BlindedTlv

case class Padding(bytes: ByteVector) extends BlindedTlv

case class PathId(bytes: ByteVector) extends BlindedTlv

sealed trait MessagePacket

case class MessageRelayPayload(records: TlvStream[OnionMessagePayloadTlv]) extends MessagePacket {
Expand All @@ -36,13 +26,13 @@ object MessageTlv {
val replyPath: Option[MessageTlv.ReplyPath] = records.get[MessageTlv.ReplyPath]
}

case class RelayBlindedTlv(records: TlvStream[BlindedTlv]) {
val nextNodeId: PublicKey = records.get[MessageTlv.NextNodeId].get.nodeId
val nextBlinding: Option[PublicKey] = records.get[MessageTlv.BlindingPoint].map(_.publicKey)
case class RelayBlindedTlv(records: TlvStream[EncryptedRecipientDataTlv]) {
val nextNodeId: PublicKey = records.get[EncryptedRecipientDataTlv.OutgoingNodeId].get.nodeId
val nextBlinding: Option[PublicKey] = records.get[EncryptedRecipientDataTlv.NextBlinding].map(_.blinding)
}

case class FinalBlindedTlv(records: TlvStream[BlindedTlv]) {
val pathId: Option[ByteVector] = records.get[MessageTlv.PathId].map(_.bytes)
case class FinalBlindedTlv(records: TlvStream[EncryptedRecipientDataTlv]) {
val recipientSecret: Option[ByteVector] = records.get[EncryptedRecipientDataTlv.RecipientSecret].map(_.data)
}
}

Expand Down Expand Up @@ -80,30 +70,15 @@ object MessageOnion {
case MessageFinalPayload(tlvs) => tlvs
})

private val padding: Codec[Padding] = variableSizeBytesLong(varintoverflow, "padding" | bytes).as[Padding]

private val nextNodeId: Codec[NextNodeId] = variableSizeBytesLong(varintoverflow, "node_id" | publicKey).as[NextNodeId]

private val blindingKey: Codec[BlindingPoint] = variableSizeBytesLong(varintoverflow, "blinding" | publicKey).as[BlindingPoint]

private val pathId: Codec[PathId] = variableSizeBytesLong(varintoverflow, "path_id" | bytes).as[PathId]

private val blindedTlvCodec: Codec[TlvStream[BlindedTlv]] = TlvCodecs.tlvStream[BlindedTlv](
discriminated[BlindedTlv].by(varint)
.typecase(UInt64(1), padding)
.typecase(UInt64(4), nextNodeId)
.typecase(UInt64(12), blindingKey)
.typecase(UInt64(14), pathId)).complete

val relayBlindedTlvCodec: Codec[RelayBlindedTlv] = blindedTlvCodec.narrow({
case tlvs if tlvs.get[NextNodeId].isEmpty => Attempt.failure(MissingRequiredTlv(UInt64(4)))
case tlvs if tlvs.get[PathId].nonEmpty => Attempt.failure(ForbiddenTlv(UInt64(14)))
val relayBlindedTlvCodec: Codec[RelayBlindedTlv] = encryptedRecipientDataCodec.narrow({
case tlvs if tlvs.get[EncryptedRecipientDataTlv.OutgoingNodeId].isEmpty => Attempt.failure(MissingRequiredTlv(UInt64(4)))
case tlvs if tlvs.get[EncryptedRecipientDataTlv.RecipientSecret].nonEmpty => Attempt.failure(ForbiddenTlv(UInt64(6)))
case tlvs => Attempt.successful(RelayBlindedTlv(tlvs))
}, {
case RelayBlindedTlv(tlvs) => tlvs
})

val finalBlindedTlvCodec: Codec[FinalBlindedTlv] = blindedTlvCodec.narrow(
val finalBlindedTlvCodec: Codec[FinalBlindedTlv] = encryptedRecipientDataCodec.narrow(
tlvs => Attempt.successful(FinalBlindedTlv(tlvs))
, {
case FinalBlindedTlv(tlvs) => tlvs
Expand Down
Loading

0 comments on commit dbcd4a9

Please sign in to comment.