Skip to content

Commit

Permalink
Simplify onion message codec
Browse files Browse the repository at this point in the history
The scodec magic was quite hard to read, and the use of the prefix wasn't
very intuitive since Sphinx uses both a prefix and a suffix.

Also added more codec tests.
  • Loading branch information
t-bast committed Nov 9, 2021
1 parent 333e9ef commit 036ce10
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,10 @@ object MessageOnionCodecs {

def messageOnionPerHopPayloadCodec(isLastPacket: Boolean): Codec[PerHopPayload] = if (isLastPacket) finalPerHopPayloadCodec.upcast[PerHopPayload] else relayPerHopPayloadCodec.upcast[PerHopPayload]

val messageOnionPacketCodec: Codec[OnionRoutingPacket] =
(variableSizePrefixedBytes(uint16.xmap(_ - 66, _ + 66),
("version" | uint8) ~
("publicKey" | bytes(33)),
("onionPayload" | bytes)) ~
("hmac" | bytes32) flattenLeftPairs).as[OnionRoutingPacket]
val messageOnionPacketCodec: Codec[OnionRoutingPacket] = variableSizeBytes(uint16, bytes).exmap[OnionRoutingPacket](
// The Sphinx packet header contains a version (1 byte), a public key (33 bytes) and a mac (32 bytes) -> total 66 bytes
bytes => OnionRoutingCodecs.onionRoutingPacketCodec(bytes.length.toInt - 66).decode(bytes.bits).map(_.value),
onion => OnionRoutingCodecs.onionRoutingPacketCodec(onion.payload.length.toInt).encode(onion).map(_.bytes)
)

}
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,57 @@ class CommonCodecsSpec extends AnyFunSuite {
}
}

test("encode/decode bytevector32") {
val testCases = Seq(
(hex"0000000000000000000000000000000000000000000000000000000000000000", Some(ByteVector32.Zeroes)),
(hex"0101010101010101010101010101010101010101010101010101010101010101", Some(ByteVector32(hex"0101010101010101010101010101010101010101010101010101010101010101"))),
(hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", Some(ByteVector32(hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"))),
// Ignore additional trailing bytes
(hex"000000000000000000000000000000000000000000000000000000000000000000", Some(ByteVector32.Zeroes)),
(hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff00", Some(ByteVector32(hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"))),
// Not enough bytes
(hex"00000000000000000000000000000000000000000000000000000000000000", None),
(hex"", None)
)

for ((encoded, expected_opt) <- testCases) {
expected_opt match {
case Some(expected) =>
val decoded = bytes32.decode(encoded.bits).require.value
assert(decoded === expected)
assert(expected.bytes === bytes32.encode(decoded).require.bytes)
case None =>
assert(bytes32.decode(encoded.bits).isFailure)
}
}
}

test("encode/decode bytevector64") {
val testCases = Seq(
(hex"00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", Some(ByteVector64.Zeroes)),
(hex"01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101", Some(ByteVector64(hex"01010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101010101"))),
(hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff", Some(ByteVector64(hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"))),
// Ignore additional trailing bytes
(hex"0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", Some(ByteVector64.Zeroes)),
(hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff00", Some(ByteVector64(hex"ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff"))),
// Not enough bytes
(hex"000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000", None),
(hex"00000000000000000000000000000000000000000000000000000000000000", None),
(hex"", None)
)

for ((encoded, expected_opt) <- testCases) {
expected_opt match {
case Some(expected) =>
val decoded = bytes64.decode(encoded.bits).require.value
assert(decoded === expected)
assert(expected.bytes === bytes64.encode(decoded).require.bytes)
case None =>
assert(bytes64.decode(encoded.bits).isFailure)
}
}
}

test("encode/decode with private key codec") {
val value = PrivateKey(randomBytes32())
val wire = privateKey.encode(value).require
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,19 +51,39 @@ class MessageOnionCodecsSpec extends AnyFunSuiteLike {
assert(finalPerHopPayloadCodec.decode(serialized.bits).require.value === payload)
}

test("onion packet can be any size"){
test("onion packet can be any size") {
{ // small onion
val onion = OnionRoutingPacket(1, hex"032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991", hex"012345679abcdef", ByteVector32(hex"0000111122223333444455556666777788889999aaaabbbbccccddddeeee0000"))
val serialized = hex"004a01032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e6686809910012345679abcdef0000111122223333444455556666777788889999aaaabbbbccccddddeeee0000"
val onion = OnionRoutingPacket(1, hex"032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991", hex"0012345679abcdef", ByteVector32(hex"0000111122223333444455556666777788889999aaaabbbbccccddddeeee0000"))
val serialized = hex"004a 01 032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991 0012345679abcdef 0000111122223333444455556666777788889999aaaabbbbccccddddeeee0000"
assert(messageOnionPacketCodec.encode(onion).require.bytes === serialized)
assert(messageOnionPacketCodec.decode(serialized.bits).require.value === onion)
}
{ // larger onion
val onion = OnionRoutingPacket(2, hex"027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa2007", hex"012345679abcdef012345679abcdef012345679abcdef012345679abcdef012345679abcdef", ByteVector32(hex"eeee0000111122223333444455556666777788889999aaaabbbbccccddddeeee"))
val serialized = hex"006802027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa20070012345679abcdef012345679abcdef012345679abcdef012345679abcdef012345679abcdefeeee0000111122223333444455556666777788889999aaaabbbbccccddddeeee"
val onion = OnionRoutingPacket(2, hex"027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa2007", hex"0012345679abcdef012345679abcdef012345679abcdef012345679abcdef012345679abcdef", ByteVector32(hex"eeee0000111122223333444455556666777788889999aaaabbbbccccddddeeee"))
val serialized = hex"0068 02 027f31ebc5462c1fdce1b737ecff52d37d75dea43ce11c74d25aa297165faa2007 0012345679abcdef012345679abcdef012345679abcdef012345679abcdef012345679abcdef eeee0000111122223333444455556666777788889999aaaabbbbccccddddeeee"
assert(messageOnionPacketCodec.encode(onion).require.bytes === serialized)
assert(messageOnionPacketCodec.decode(serialized.bits).require.value === onion)
}
{ // onion with trailing additional bytes
val onion = OnionRoutingPacket(0, hex"032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991", hex"ffffffff", ByteVector32.Zeroes)
val serialized = hex"0046 00 032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991 ffffffff 0000000000000000000000000000000000000000000000000000000000000000 0a01020000030400000000"
assert(messageOnionPacketCodec.encode(onion).require.bytes === serialized.dropRight(11))
assert(messageOnionPacketCodec.decode(serialized.bits).require.value === onion)
}
{ // onion with empty payload
val onion = OnionRoutingPacket(0, hex"032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991", hex"", ByteVector32.Zeroes)
val serialized = hex"0042 00 032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991 0000000000000000000000000000000000000000000000000000000000000000"
assert(messageOnionPacketCodec.encode(onion).require.bytes === serialized)
assert(messageOnionPacketCodec.decode(serialized.bits).require.value === onion)
}
{ // onion length too big
val serialized = hex"0048 00 032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991 ffffffff 0000000000000000000000000000000000000000000000000000000000000000"
assert(messageOnionPacketCodec.decode(serialized.bits).isFailure)
}
{ // onion length way too big
val serialized = hex"00ff 00 032c0b7cf95324a07d05398b240174dc0c2be444d96b159aa6c7f7b1e668680991 ffffffff 0000000000000000000000000000000000000000000000000000000000000000"
assert(messageOnionPacketCodec.decode(serialized.bits).isFailure)
}
}

}

0 comments on commit 036ce10

Please sign in to comment.