Skip to content

Commit

Permalink
Rework Bolt11Features serialization, cleaner, better testable
Browse files Browse the repository at this point in the history
  • Loading branch information
optout21 committed Aug 31, 2024
1 parent 22a53d4 commit 5fa337a
Show file tree
Hide file tree
Showing 5 changed files with 263 additions and 100 deletions.
116 changes: 80 additions & 36 deletions lightning-invoice/src/de.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use core::num::ParseIntError;
use core::str;
use core::str::FromStr;

use bech32::primitives::decode::{CheckedHrpstring, CheckedHrpstringError};
use bech32::{Bech32, Fe32, Fe32IterExt};
use bech32::primitives::decode::{CheckedHrpstring, CheckedHrpstringError, ChecksumError};

use bitcoin::{PubkeyHash, ScriptHash, WitnessVersion};
use bitcoin::hashes::Hash;
Expand Down Expand Up @@ -39,19 +39,37 @@ pub trait FromBase32: Sized {
// FromBase32 implementations are here, because the trait is in this module.

impl FromBase32 for Vec<u8> {
type Err = CheckedHrpstringError;
type Err = Bolt11ParseError;

fn from_base32(data: &[Fe32]) -> Result<Self, Self::Err> {
Ok(data.iter().copied().fes_to_bytes().collect::<Self>())
}
}

impl<const N: usize> FromBase32 for [u8; N] {
type Err = Bolt11ParseError;

fn from_base32(data: &[Fe32]) -> Result<Self, Self::Err> {
data.iter().copied().fes_to_bytes().collect::<Vec<_>>().try_into().map_err(|_| {
Bolt11ParseError::InvalidSliceLength(
data.len(),
(N * 8 + 4) / 5,
"<[u8; N]>::from_base32()".into(),
)
})
}
}

impl FromBase32 for PaymentSecret {
type Err = CheckedHrpstringError;
type Err = Bolt11ParseError;

fn from_base32(field_data: &[Fe32]) -> Result<Self, Self::Err> {
if field_data.len() != 52 {
return Err(CheckedHrpstringError::Checksum(ChecksumError::InvalidLength)) // TODO(bech32): not entirely accurate
return Err(Bolt11ParseError::InvalidSliceLength(
field_data.len(),
52,
"PaymentSecret::from_base32()".into(),
));
}
let data_bytes = Vec::<u8>::from_base32(field_data)?;
let mut payment_secret = [0; 32];
Expand All @@ -61,28 +79,45 @@ impl FromBase32 for PaymentSecret {
}

impl FromBase32 for Bolt11InvoiceFeatures {
type Err = CheckedHrpstringError;
type Err = Bolt11ParseError;

/// Convert to byte values, by packing the 5-bit groups,
/// putting the 5-bit values from left to-right (reverse order),
/// starting from the rightmost bit,
/// and taking the resulting 8-bit values (right to left),
/// with the leading 0's skipped.
fn from_base32(field_data: &[Fe32]) -> Result<Self, Self::Err> {
// Explanation for the "7": the normal way to round up when dividing is to add the divisor
// minus one before dividing
let length_bytes = (field_data.len() * 5 + 7) / 8 as usize;
let mut res_bytes: Vec<u8> = vec![0; length_bytes];
for (u5_idx, chunk) in field_data.iter().enumerate() {
let bit_pos_from_right_0_indexed = (field_data.len() - u5_idx - 1) * 5;
let new_byte_idx = (bit_pos_from_right_0_indexed / 8) as usize;
let new_bit_pos = bit_pos_from_right_0_indexed % 8;
let chunk_u16 = chunk.to_u8() as u16;
res_bytes[new_byte_idx] |= ((chunk_u16 << new_bit_pos) & 0xff) as u8;
if new_byte_idx != length_bytes - 1 {
res_bytes[new_byte_idx + 1] |= ((chunk_u16 >> (8-new_bit_pos)) & 0xff) as u8;
// Fe32 conversion cannot be used, because this unpacks from right, right-to-left
// Carry bits, 0, 1, 2, 3, or 4 bits
let mut carry_bits = 0;
let mut carry = 0u8;
let mut output = Vec::<u8>::new();

// Iterate over input in reverse
for curr_in in field_data.iter().rev() {
let curr_in_as_u8 = curr_in.to_u8();
if carry_bits >= 3 {
// we have a new full byte -- 3, 4 or 5 carry bits, plus 5 new ones
// For combining with carry '|', '^', or '+' can be used (disjoint bit positions)
let next = carry + (curr_in_as_u8 << carry_bits);
output.push(next);
carry = curr_in_as_u8 >> (8 - carry_bits);
carry_bits -= 3; // added 5, removed 8
} else {
// only 0, 1, or 2 carry bits, plus 5 new ones
carry += curr_in_as_u8 << carry_bits;
carry_bits += 5;
}
}
// Trim the highest feature bits.
while !res_bytes.is_empty() && res_bytes[res_bytes.len() - 1] == 0 {
res_bytes.pop();
// No more inputs, output remaining (if any)
if carry_bits > 0 {
output.push(carry);
}
Ok(Bolt11InvoiceFeatures::from_le_bytes(res_bytes))
// Trim the highest feature bits
while !output.is_empty() && output[output.len() - 1] == 0 {
output.pop();
}
Ok(Bolt11InvoiceFeatures::from_le_bytes(output))
}
}

Expand Down Expand Up @@ -342,7 +377,7 @@ impl FromStr for SignedRawBolt11Invoice {
}

let raw_hrp: RawHrp = hrp.to_string().to_lowercase().parse()?;
let data_part = RawDataPart::from_base32(&data[..data.len()-SIGNATURE_LEN5])?;
let data_part = RawDataPart::from_base32(&data[..data.len() - SIGNATURE_LEN5])?;

Ok(SignedRawBolt11Invoice {
raw_invoice: RawBolt11Invoice {
Expand All @@ -351,9 +386,9 @@ impl FromStr for SignedRawBolt11Invoice {
},
hash: RawBolt11Invoice::hash_from_parts(
hrp.to_string().as_bytes(),
&data[..data.len()-SIGNATURE_LEN5]
&data[..data.len() - SIGNATURE_LEN5],
),
signature: Bolt11InvoiceSignature::from_base32(&data[data.len()-SIGNATURE_LEN5..])?,
signature: Bolt11InvoiceSignature::from_base32(&data[data.len() - SIGNATURE_LEN5..])?,
})
}
}
Expand Down Expand Up @@ -415,7 +450,11 @@ impl FromBase32 for PositiveTimestamp {

fn from_base32(b32: &[Fe32]) -> Result<Self, Self::Err> {
if b32.len() != 7 {
return Err(Bolt11ParseError::InvalidSliceLength("PositiveTimestamp::from_base32()".into()));
return Err(Bolt11ParseError::InvalidSliceLength(
b32.len(),
7,
"PositiveTimestamp::from_base32()".into(),
));
}
let timestamp: u64 = parse_u64_be(b32)
.expect("7*5bit < 64bit, no overflow possible");
Expand All @@ -430,7 +469,11 @@ impl FromBase32 for Bolt11InvoiceSignature {
type Err = Bolt11ParseError;
fn from_base32(signature: &[Fe32]) -> Result<Self, Self::Err> {
if signature.len() != 104 {
return Err(Bolt11ParseError::InvalidSliceLength("Bolt11InvoiceSignature::from_base32()".into()));
return Err(Bolt11ParseError::InvalidSliceLength(
signature.len(),
104,
"Bolt11InvoiceSignature::from_base32()".into(),
));
}
let recoverable_signature_bytes = Vec::<u8>::from_base32(signature)?;
let signature = &recoverable_signature_bytes[0..64];
Expand Down Expand Up @@ -483,7 +526,9 @@ fn parse_tagged_parts(data: &[Fe32]) -> Result<Vec<RawTaggedField>, Bolt11ParseE
Ok(field) => {
parts.push(RawTaggedField::KnownSemantics(field))
},
Err(Bolt11ParseError::Skip)|Err(Bolt11ParseError::Bech32Error(_)) => {
Err(Bolt11ParseError::Skip)
| Err(Bolt11ParseError::InvalidSliceLength(_, _, _))
| Err(Bolt11ParseError::Bech32Error(_)) => {
parts.push(RawTaggedField::UnknownSemantics(field.into()))
},
Err(e) => {return Err(e)}
Expand Down Expand Up @@ -688,9 +733,6 @@ impl Display for Bolt11ParseError {
Bolt11ParseError::Bech32Error(ref e) => {
write!(f, "Invalid bech32: {}", e)
}
Bolt11ParseError::GenericBech32Error => {
write!(f, "Invalid bech32")
}
Bolt11ParseError::ParseAmountError(ref e) => {
write!(f, "Invalid amount in hrp ({})", e)
}
Expand All @@ -700,9 +742,13 @@ impl Display for Bolt11ParseError {
Bolt11ParseError::DescriptionDecodeError(ref e) => {
write!(f, "Description is not a valid utf-8 string: {}", e)
}
Bolt11ParseError::InvalidSliceLength(ref function) => {
write!(f, "Slice in function {} had the wrong length", function)
}
Bolt11ParseError::InvalidSliceLength(ref len, ref expected, ref function) => {
write!(
f,
"Slice had length {} instead of {} in function {}",
len, expected, function
)
},
Bolt11ParseError::BadPrefix => f.write_str("did not begin with 'ln'"),
Bolt11ParseError::UnknownCurrency => f.write_str("currency code unknown"),
Bolt11ParseError::UnknownSiPrefix => f.write_str("unknown SI prefix"),
Expand Down Expand Up @@ -767,9 +813,7 @@ from_error!(Bolt11ParseError::DescriptionDecodeError, str::Utf8Error);

impl From<CheckedHrpstringError> for Bolt11ParseError {
fn from(e: CheckedHrpstringError) -> Self {
match e {
_ => Bolt11ParseError::Bech32Error(e)
}
Self::Bech32Error(e)
}
}

Expand Down
26 changes: 17 additions & 9 deletions lightning-invoice/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ extern crate serde;
#[cfg(feature = "std")]
use std::time::SystemTime;

use bech32::Fe32;
use bech32::primitives::decode::CheckedHrpstringError;
use bech32::Fe32;
use bitcoin::{Address, Network, PubkeyHash, ScriptHash, WitnessProgram, WitnessVersion};
use bitcoin::hashes::{Hash, sha256};
use lightning_types::features::Bolt11InvoiceFeatures;
Expand Down Expand Up @@ -79,17 +79,16 @@ use crate::prelude::*;

/// Re-export serialization traits
#[cfg(fuzzing)]
pub use crate::ser::Base32Iterable;
#[cfg(fuzzing)]
pub use crate::de::FromBase32;
#[cfg(fuzzing)]
pub use crate::ser::Base32Iterable;

/// Errors that indicate what is wrong with the invoice. They have some granularity for debug
/// reasons, but should generally result in an "invalid BOLT11 invoice" message for the user.
#[allow(missing_docs)]
#[derive(PartialEq, Eq, Debug, Clone)]
pub enum Bolt11ParseError {
Bech32Error(CheckedHrpstringError),
GenericBech32Error,
ParseAmountError(ParseIntError),
MalformedSignature(bitcoin::secp256k1::Error),
BadPrefix,
Expand All @@ -105,7 +104,8 @@ pub enum Bolt11ParseError {
InvalidPubKeyHashLength,
InvalidScriptHashLength,
InvalidRecoveryId,
InvalidSliceLength(String),
// Invalid length, with actual length, expected length, and function info
InvalidSliceLength(usize, usize, String),

/// Not an error, but used internally to signal that a part of the invoice should be ignored
/// according to BOLT11
Expand Down Expand Up @@ -439,10 +439,18 @@ impl PartialOrd for RawTaggedField {
impl Ord for RawTaggedField {
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
match (self, other) {
(RawTaggedField::KnownSemantics(ref a), RawTaggedField::KnownSemantics(ref b)) => a.cmp(b),
(RawTaggedField::UnknownSemantics(ref a), RawTaggedField::UnknownSemantics(ref b)) => a.iter().map(|a| a.to_u8()).cmp(b.iter().map(|b| b.to_u8())),
(RawTaggedField::KnownSemantics(..), RawTaggedField::UnknownSemantics(..)) => core::cmp::Ordering::Less,
(RawTaggedField::UnknownSemantics(..), RawTaggedField::KnownSemantics(..)) => core::cmp::Ordering::Greater,
(RawTaggedField::KnownSemantics(ref a), RawTaggedField::KnownSemantics(ref b)) => {
a.cmp(b)
},
(RawTaggedField::UnknownSemantics(ref a), RawTaggedField::UnknownSemantics(ref b)) => {
a.iter().map(|a| a.to_u8()).cmp(b.iter().map(|b| b.to_u8()))
},
(RawTaggedField::KnownSemantics(..), RawTaggedField::UnknownSemantics(..)) => {
core::cmp::Ordering::Less
},
(RawTaggedField::UnknownSemantics(..), RawTaggedField::KnownSemantics(..)) => {
core::cmp::Ordering::Greater
},
}
}
}
Expand Down
Loading

0 comments on commit 5fa337a

Please sign in to comment.