diff --git a/lightning/Cargo.toml b/lightning/Cargo.toml index e326c65803b..8f907513862 100644 --- a/lightning/Cargo.toml +++ b/lightning/Cargo.toml @@ -70,6 +70,7 @@ tokio = { version = "1.14.1", features = [ "macros", "rt-multi-thread", ] } +rusqlite = { version = "0.30", features = ["bundled"] } [dev-dependencies] regex = "1.5.6" diff --git a/lightning/src/color_ext/database.rs b/lightning/src/color_ext/database.rs index 44a5599c4b3..518f6e6f284 100644 --- a/lightning/src/color_ext/database.rs +++ b/lightning/src/color_ext/database.rs @@ -1,233 +1,368 @@ use core::fmt::{Display, Formatter}; use std::{ - collections::HashMap, - io::{self, Read, Write}, - sync::{Arc, Mutex}, + io::{self, Read, Write}, + path::PathBuf, + sync::{Arc, Mutex}, }; use bitcoin::Txid; +use rusqlite::{Connection, params}; +use serde_json; use crate::{ - ln::{ChannelId, PaymentHash}, - rgb_utils::{RgbInfo, RgbPaymentInfo, TransferInfo}, + ln::{ChannelId, PaymentHash}, + rgb_utils::{RgbInfo, RgbPaymentInfo, TransferInfo}, }; #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)] pub enum PaymentDirection { - Inbound, - Outbound, + Inbound, + Outbound, } impl From<bool> for PaymentDirection { - fn from(inbound: bool) -> Self { - if inbound { - PaymentDirection::Inbound - } else { - PaymentDirection::Outbound - } - } + fn from(inbound: bool) -> Self { + if inbound { + PaymentDirection::Inbound + } else { + PaymentDirection::Outbound + } + } } impl Display for PaymentDirection { - fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> { - match self { - PaymentDirection::Inbound => write!(f, "Inbound"), - PaymentDirection::Outbound => write!(f, "Outbound"), - } - } + fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> { + match self { + PaymentDirection::Inbound => write!(f, "Inbound"), + PaymentDirection::Outbound => write!(f, "Outbound"), + } + } } #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct ProxyIdKey { - channel_id: ChannelId, - payment_hash: PaymentHash, -
direction: PaymentDirection, + channel_id: ChannelId, + payment_hash: PaymentHash, + direction: PaymentDirection, } impl Display for ProxyIdKey { - fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> { - write!(f, "{}.{}.{}", self.channel_id, self.payment_hash, self.direction) - } + fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> { + write!(f, "{}.{}.{}", self.channel_id, self.payment_hash, self.direction) + } } impl ProxyIdKey { - pub fn new( - channel_id: &ChannelId, payment_hash: &PaymentHash, direction: PaymentDirection, - ) -> Self { - Self { channel_id: channel_id.clone(), payment_hash: payment_hash.clone(), direction } - } + pub fn new( + channel_id: &ChannelId, payment_hash: &PaymentHash, direction: PaymentDirection, + ) -> Self { + Self { channel_id: channel_id.clone(), payment_hash: payment_hash.clone(), direction } + } } #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub struct PaymentHashKey { - payment_hash: PaymentHash, - direction: PaymentDirection, + payment_hash: PaymentHash, + direction: PaymentDirection, } impl From<&ProxyIdKey> for PaymentHashKey { - fn from(proxy_id_key: &ProxyIdKey) -> Self { - Self { payment_hash: proxy_id_key.payment_hash, direction: proxy_id_key.direction } - } + fn from(proxy_id_key: &ProxyIdKey) -> Self { + Self { payment_hash: proxy_id_key.payment_hash, direction: proxy_id_key.direction } + } } impl Display for PaymentHashKey { - fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> { - write!(f, "{}.{}", self.payment_hash, self.direction) - } + fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> { + write!(f, "{}.{}", self.payment_hash, self.direction) + } } impl PaymentHashKey { - pub fn new(payment_hash: PaymentHash, direction: PaymentDirection) -> Self { - Self { payment_hash, direction } - } + pub fn new(payment_hash: PaymentHash, direction: PaymentDirection) -> Self { + Self { payment_hash, direction } + } } +fn hex_str_decode(s: &str) -> Result<Vec<u8>, &'static str> { + if
s.len() % 2 != 0 { + return Err("Hex string has an odd length"); + } -#[derive(Clone, Debug, Default)] + let mut bytes = Vec::with_capacity(s.len() / 2); + + for i in (0..s.len()).step_by(2) { + let hex_pair = &s[i..i + 2]; + + match u8::from_str_radix(hex_pair, 16) { + Ok(byte) => bytes.push(byte), + Err(_) => return Err("Invalid hex string"), + } + } + + Ok(bytes) +} + +fn channel_from_str(s: &str) -> ChannelId { + let array: [u8; 32] = hex_str_decode(s).expect("Invalid hex string").try_into().expect("Invalid length"); + ChannelId::from_bytes(array) +} + +#[derive(Clone, Debug)] pub(crate) struct RgbPaymentCache { - by_proxy_id: HashMap<ProxyIdKey, RgbPaymentInfo>, - by_payment_hash_key: HashMap<PaymentHashKey, RgbPaymentInfo>, - by_payment_hash: HashMap<PaymentHash, RgbPaymentInfo>, - pending_payments: HashMap<PaymentHash, RgbPaymentInfo>, + conn: Arc<Mutex<Connection>>, } impl RgbPaymentCache { - fn new() -> Self { - Self::default() - } + fn new(conn: Arc<Mutex<Connection>>) -> Self { + Self { conn } + } - pub fn get_by_proxy_id_key(&self, proxy_id: &ProxyIdKey) -> Option<&RgbPaymentInfo> { - self.by_proxy_id.get(proxy_id) - } + pub fn get_by_proxy_id_key(&self, proxy_id: &ProxyIdKey) -> Option<RgbPaymentInfo> { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare("SELECT info FROM rgb_payments WHERE channel_id = ? AND payment_hash = ? AND direction = ?").unwrap(); + let result = stmt.query_row(params![proxy_id.channel_id.to_string(), proxy_id.payment_hash.to_string(), proxy_id.direction.to_string()], + |row| { + let info_str: String = row.get(0)?; + Ok(serde_json::from_str(&info_str).unwrap()) + }); + result.ok() + } - pub fn resolve_channel_id(&self, payment_hash: &PaymentHash) -> Option<ChannelId> { - for key in self.by_proxy_id.keys() { - if key.payment_hash == *payment_hash { - return Some(key.channel_id.clone()); - } - } - None - } + pub fn resolve_channel_id(&self, payment_hash: &PaymentHash) -> Option<ChannelId> { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare("SELECT channel_id FROM rgb_payments WHERE payment_hash = ?
LIMIT 1").unwrap(); + let result = stmt.query_row(params![payment_hash.to_string()], + |row| { + let channel_id_str: String = row.get(0)?; + Ok(channel_from_str(&channel_id_str)) + }); + result.ok() + } - pub fn get_by_payment_hash(&self, payment_hash: &PaymentHash) -> Option<&RgbPaymentInfo> { - self.by_payment_hash.get(payment_hash) - } + pub fn get_by_payment_hash(&self, payment_hash: &PaymentHash) -> Option { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare("SELECT info FROM rgb_payments WHERE payment_hash = ? LIMIT 1").unwrap(); + let result = stmt.query_row(params![payment_hash.to_string()], + |row| { + let info_str: String = row.get(0)?; + Ok(serde_json::from_str(&info_str).unwrap()) + }); + result.ok() + } - pub fn get_by_payment_hash_key( - &self, payment_hash_key: &PaymentHashKey, - ) -> Option<&RgbPaymentInfo> { - self.by_payment_hash_key.get(payment_hash_key) - } + pub fn get_by_payment_hash_key(&self, payment_hash_key: &PaymentHashKey) -> Option { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare("SELECT info FROM rgb_payments WHERE payment_hash = ? 
AND direction = ?").unwrap(); + let result = stmt.query_row(params![payment_hash_key.payment_hash.to_string(), payment_hash_key.direction.to_string()], + |row| { + let info_str: String = row.get(0)?; + Ok(serde_json::from_str(&info_str).unwrap()) + }); + result.ok() + } - pub fn is_pending(&self, payment_hash: &PaymentHash) -> bool { - self.pending_payments.contains_key(payment_hash) - } + pub fn is_pending(&self, payment_hash: &PaymentHash) -> bool { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare("SELECT 1 FROM pending_payments WHERE payment_hash = ?").unwrap(); + stmt.exists(params![payment_hash.to_string()]).unwrap() + } - pub fn get_pending_payment(&self, payment_hash: &PaymentHash) -> Option<&RgbPaymentInfo> { - self.pending_payments.get(payment_hash) - } + pub fn get_pending_payment(&self, payment_hash: &PaymentHash) -> Option { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare("SELECT info FROM pending_payments WHERE payment_hash = ?").unwrap(); + let result = stmt.query_row(params![payment_hash.to_string()], + |row| { + let info_str: String = row.get(0)?; + Ok(serde_json::from_str(&info_str).unwrap()) + }); + result.ok() + } pub fn insert(&mut self, proxy_id_key: &ProxyIdKey, info: RgbPaymentInfo, is_pending: bool) { - self.by_proxy_id.insert(proxy_id_key.clone(), info.clone()); - self.by_payment_hash_key.insert(proxy_id_key.into(), info.clone()); - self.by_payment_hash.insert(proxy_id_key.payment_hash, info.clone()); + let conn = self.conn.lock().unwrap(); + let info_str = serde_json::to_string(&info).unwrap(); + + // Check if the payment_hash already exists in the rgb_payments table + let mut stmt = conn.prepare("SELECT COUNT(*) FROM rgb_payments WHERE payment_hash = ?").unwrap(); + let mut rows = stmt.query(params![proxy_id_key.payment_hash.to_string()]).unwrap(); + let exists = rows.next().unwrap().unwrap().get::<_, i64>(0).unwrap() > 0; + + if exists { + // Update the existing entry + conn.execute( + "UPDATE 
rgb_payments SET channel_id = ?, direction = ?, info = ? WHERE payment_hash = ?", + params![proxy_id_key.channel_id.to_string(), proxy_id_key.direction.to_string(), info_str, proxy_id_key.payment_hash.to_string()], + ).unwrap(); + } else { + // Insert a new entry + conn.execute( + "INSERT INTO rgb_payments (channel_id, payment_hash, direction, info) VALUES (?, ?, ?, ?)", + params![proxy_id_key.channel_id.to_string(), proxy_id_key.payment_hash.to_string(), proxy_id_key.direction.to_string(), info_str], + ).unwrap(); + } + if is_pending { - self.pending_payments.insert(proxy_id_key.payment_hash, info.clone()); + // Check if the payment_hash already exists in the pending_payments table + let mut stmt = conn.prepare("SELECT COUNT(*) FROM pending_payments WHERE payment_hash = ?").unwrap(); + let mut rows = stmt.query(params![proxy_id_key.payment_hash.to_string()]).unwrap(); + let pending_exists = rows.next().unwrap().unwrap().get::<_, i64>(0).unwrap() > 0; + + if pending_exists { + // Update the existing pending entry + conn.execute( + "UPDATE pending_payments SET info = ? 
WHERE payment_hash = ?", + params![info_str, proxy_id_key.payment_hash.to_string()], + ).unwrap(); + } else { + // Insert a new pending entry + conn.execute( + "INSERT INTO pending_payments (payment_hash, info) VALUES (?, ?)", + params![proxy_id_key.payment_hash.to_string(), info_str], + ).unwrap(); + } } else { - self.pending_payments.remove(&proxy_id_key.payment_hash); + // Delete the entry from pending_payments if it exists + conn.execute( + "DELETE FROM pending_payments WHERE payment_hash = ?", + params![proxy_id_key.payment_hash.to_string()], + ).unwrap(); } } + + pub fn insert_without_proxy_id(&mut self, payment_hash_key: &PaymentHashKey, info: RgbPaymentInfo) { + let conn = self.conn.lock().unwrap(); + let info_str = serde_json::to_string(&info).unwrap(); + conn.execute( + "INSERT OR REPLACE INTO rgb_payments (payment_hash, direction, info) VALUES (?, ?, ?)", + params![payment_hash_key.payment_hash.to_string(), payment_hash_key.direction.to_string(), info_str], + ).unwrap(); + conn.execute( + "INSERT OR REPLACE INTO pending_payments (payment_hash, info) VALUES (?, ?)", + params![payment_hash_key.payment_hash.to_string(), info_str], + ).unwrap(); + } - pub fn insert_without_proxy_id( - &mut self, payment_hash_key: &PaymentHashKey, info: RgbPaymentInfo, - ) { - self.by_payment_hash_key.insert(payment_hash_key.clone(), info.clone()); - self.by_payment_hash.insert(payment_hash_key.payment_hash, info.clone()); - self.pending_payments.insert(payment_hash_key.payment_hash, info.clone()); - } - - pub fn remove(&mut self, proxy_id: &ProxyIdKey) { - self.by_proxy_id.remove(proxy_id); - self.by_payment_hash_key.remove(&proxy_id.into()); - self.by_payment_hash.remove(&proxy_id.payment_hash); - } + pub fn remove(&mut self, proxy_id: &ProxyIdKey) { + let conn = self.conn.lock().unwrap(); + conn.execute( + "DELETE FROM rgb_payments WHERE channel_id = ? AND payment_hash = ? 
AND direction = ?", + params![proxy_id.channel_id.to_string(), proxy_id.payment_hash.to_string(), proxy_id.direction.to_string()], + ).unwrap(); + conn.execute( + "DELETE FROM pending_payments WHERE payment_hash = ?", + params![proxy_id.payment_hash.to_string()], + ).unwrap(); + } } -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug)] pub(crate) struct TransferInfoCache { - by_txid: HashMap, + conn: Arc>, } impl TransferInfoCache { - fn new() -> Self { - Self::default() - } + fn new(conn: Arc>) -> Self { + Self { conn } + } - pub fn get_by_txid(&self, txid: &Txid) -> Option<&TransferInfo> { - self.by_txid.get(txid) - } + pub fn get_by_txid(&self, txid: &Txid) -> Option { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare("SELECT info FROM transfer_info WHERE txid = ?").unwrap(); + let result = stmt.query_row(params![txid.to_string()], + |row| { + let info_str: String = row.get(0)?; + Ok(serde_json::from_str(&info_str).unwrap()) + }); + result.ok() + } - pub fn insert(&mut self, txid: Txid, info: TransferInfo) { - self.by_txid.insert(txid, info); - } + pub fn insert(&mut self, txid: Txid, info: TransferInfo) { + let conn = self.conn.lock().unwrap(); + let info_str = serde_json::to_string(&info).unwrap(); + conn.execute( + "INSERT OR REPLACE INTO transfer_info (txid, info) VALUES (?, ?)", + params![txid.to_string(), info_str], + ).unwrap(); + } - pub fn remove(&mut self, txid: &Txid) { - self.by_txid.remove(txid); - } + pub fn remove(&mut self, txid: &Txid) { + let conn = self.conn.lock().unwrap(); + conn.execute( + "DELETE FROM transfer_info WHERE txid = ?", + params![txid.to_string()], + ).unwrap(); + } } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct RgbInfoKey { - channel_id: ChannelId, - is_pending: bool, + channel_id: ChannelId, + is_pending: bool, } impl Display for RgbInfoKey { - fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> { - write!(f, "{}.{}", self.channel_id, if self.is_pending { "pending" } else 
{ "" }) - } + fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> { + write!(f, "{}.{}", self.channel_id, if self.is_pending { "pending" } else { "" }) + } } impl RgbInfoKey { - pub fn new(channel_id: &ChannelId, is_pending: bool) -> Self { - Self { channel_id: channel_id.clone(), is_pending } - } + pub fn new(channel_id: &ChannelId, is_pending: bool) -> Self { + Self { channel_id: channel_id.clone(), is_pending } + } } -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug)] pub(crate) struct RgbInfoCache { - by_rgb_info_key: HashMap<RgbInfoKey, RgbInfo>, + conn: Arc<Mutex<Connection>>, } impl RgbInfoCache { - fn new() -> Self { - Self::default() - } + fn new(conn: Arc<Mutex<Connection>>) -> Self { + Self { conn } + } - pub fn get_by_rgb_info_key(&self, rgb_info_key: &RgbInfoKey) -> Option<&RgbInfo> { - self.by_rgb_info_key.get(rgb_info_key) - } + pub fn get_by_rgb_info_key(&self, rgb_info_key: &RgbInfoKey) -> Option<RgbInfo> { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare("SELECT info FROM rgb_info WHERE channel_id = ?
AND is_pending = ?").unwrap(); + let result = stmt.query_row(params![rgb_info_key.channel_id.to_string(), rgb_info_key.is_pending], + |row| { + let info_str: String = row.get(0)?; + Ok(serde_json::from_str(&info_str).unwrap()) + }); + result.ok() + } - pub fn insert(&mut self, rgb_info_key: RgbInfoKey, info: RgbInfo) { - self.by_rgb_info_key.insert(rgb_info_key, info); - } + pub fn insert(&mut self, rgb_info_key: RgbInfoKey, info: RgbInfo) { + let conn = self.conn.lock().unwrap(); + let info_str = serde_json::to_string(&info).unwrap(); + conn.execute( + "INSERT OR REPLACE INTO rgb_info (channel_id, is_pending, info) VALUES (?, ?, ?)", + params![rgb_info_key.channel_id.to_string(), rgb_info_key.is_pending, info_str], + ).unwrap(); + } - pub fn remove(&mut self, rgb_info_key: &RgbInfoKey) { - self.by_rgb_info_key.remove(rgb_info_key); - } + pub fn remove(&mut self, rgb_info_key: &RgbInfoKey) { + let conn = self.conn.lock().unwrap(); + conn.execute( + "DELETE FROM rgb_info WHERE channel_id = ? 
AND is_pending = ?", + params![rgb_info_key.channel_id.to_string(), rgb_info_key.is_pending], + ).unwrap(); + } } #[derive(Clone, Debug, Default)] pub struct ConsignmentBinaryData(Vec<u8>); impl Write for ConsignmentBinaryData { - fn write(&mut self, buf: &[u8]) -> io::Result<usize> { - self.0.extend_from_slice(buf); - Ok(buf.len()) - } + fn write(&mut self, buf: &[u8]) -> io::Result<usize> { + self.0.extend_from_slice(buf); + Ok(buf.len()) + } - fn flush(&mut self) -> io::Result<()> { - Ok(()) - } + fn flush(&mut self) -> io::Result<()> { + Ok(()) + } } impl Read for ConsignmentBinaryData { @@ -245,123 +380,224 @@ impl Read for ConsignmentBinaryData { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] pub struct ConsignmentHandle(usize); -#[derive(Clone, Debug, Default)] +#[derive(Clone, Debug)] pub(crate) struct ConsignmentCache { - by_channel_id: HashMap<ChannelId, ConsignmentHandle>, - by_funding_txid: HashMap<Txid, ConsignmentHandle>, - data_store: HashMap<ConsignmentHandle, ConsignmentBinaryData>, - next_handle: usize, + conn: Arc<Mutex<Connection>>, + data_root: PathBuf, } impl ConsignmentCache { - fn new() -> Self { - Self::default() - } + fn new(conn: Arc<Mutex<Connection>>, data_root: PathBuf) -> Self { + Self { conn, data_root } + } - pub fn get_by_channel_id(&self, channel_id: &ChannelId) -> Option<ConsignmentHandle> { - self.by_channel_id.get(channel_id).copied() - } + pub fn get_by_channel_id(&self, channel_id: &ChannelId) -> Option<ConsignmentHandle> { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare("SELECT handle FROM consignments WHERE channel_id = ?").unwrap(); + let result = stmt.query_row(params![channel_id.to_string()], + |row| { + let handle: usize = row.get(0)?; + Ok(ConsignmentHandle(handle)) + }); + result.ok() + } - pub fn get_by_funding_txid(&self, funding_txid: &Txid) -> Option<ConsignmentHandle> { - self.by_funding_txid.get(funding_txid).copied() - } + pub fn get_by_funding_txid(&self, funding_txid: &Txid) -> Option<ConsignmentHandle> { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare("SELECT handle FROM consignments WHERE funding_txid = ?").unwrap(); + let result = stmt.query_row(params![funding_txid.to_string()], +
|row| { + let handle: usize = row.get(0)?; + Ok(ConsignmentHandle(handle)) + }); + result.ok() + } - pub fn insert( - &mut self, channel_id: &ChannelId, funding_txid: Txid, info: ConsignmentBinaryData, - ) -> ConsignmentHandle { - let handle = ConsignmentHandle(self.next_handle); - self.next_handle += 1; - self.data_store.insert(handle, info); - self.by_channel_id.insert(channel_id.clone(), handle); - self.by_funding_txid.insert(funding_txid, handle); + pub fn insert( + &mut self, channel_id: &ChannelId, funding_txid: Txid, info: ConsignmentBinaryData, + ) -> ConsignmentHandle { + let conn = self.conn.lock().unwrap(); + let handle = conn.query_row( + "SELECT COALESCE(MAX(handle), 0) + 1 FROM consignments", + [], + |row| row.get::<_, usize>(0), + ).unwrap(); + + let file_path = self.data_root.join(format!("consignment_{}.bin", handle)); + std::fs::write(&file_path, &info.0).unwrap(); + + conn.execute( + "INSERT INTO consignments (handle, channel_id, funding_txid, file_path) VALUES (?, ?, ?, ?)", + params![handle, channel_id.to_string(), funding_txid.to_string(), file_path.to_str().unwrap()], + ).unwrap(); + + ConsignmentHandle(handle) + } - handle - } + pub fn remove(&mut self, channel_id: &ChannelId, funding_txid: Txid) { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare("SELECT file_path FROM consignments WHERE channel_id = ? 
OR funding_txid = ?").unwrap(); + let file_paths: Vec = stmt.query_map( + params![channel_id.to_string(), funding_txid.to_string()], + |row| row.get(0) + ).unwrap().filter_map(Result::ok).collect(); - pub fn remove(&mut self, channel_id: &ChannelId, funding_txid: Txid) { - if let Some(handle) = self.by_channel_id.remove(channel_id) { - self.by_funding_txid.retain(|_, &mut v| v != handle); - self.data_store.remove(&handle); - } - self.by_funding_txid.remove(&funding_txid); - } + for file_path in file_paths { + std::fs::remove_file(file_path).unwrap_or_else(|e| eprintln!("Failed to remove file: {}", e)); + } - pub fn resolve(&self, handle: ConsignmentHandle) -> Option<&ConsignmentBinaryData> { - self.data_store.get(&handle) - } + conn.execute( + "DELETE FROM consignments WHERE channel_id = ? OR funding_txid = ?", + params![channel_id.to_string(), funding_txid.to_string()], + ).unwrap(); + } - pub fn rename_channel_id( - &mut self, handle: ConsignmentHandle, old_channel_id: &ChannelId, - new_channel_id: &ChannelId, - ) { - if let Some(old_handle) = self.by_channel_id.remove(old_channel_id) { - if old_handle == handle { - self.by_channel_id.insert(new_channel_id.clone(), handle); - } - } - } + pub fn resolve(&self, handle: ConsignmentHandle) -> Option { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare("SELECT file_path FROM consignments WHERE handle = ?").unwrap(); + let result = stmt.query_row(params![handle.0], + |row| { + let file_path: String = row.get(0)?; + Ok(file_path) + }); + + if let Ok(file_path) = result { + std::fs::read(file_path).ok().map(ConsignmentBinaryData) + } else { + None + } + } + + pub fn rename_channel_id( + &mut self, handle: ConsignmentHandle, old_channel_id: &ChannelId, + new_channel_id: &ChannelId, + ) { + let conn = self.conn.lock().unwrap(); + conn.execute( + "UPDATE consignments SET channel_id = ? WHERE handle = ? 
AND channel_id = ?", + params![new_channel_id.to_string(), handle.0, old_channel_id.to_string()], + ).unwrap(); + } +} + +pub struct SqliteConnection { + conn: Arc>, +} + +impl SqliteConnection { + pub fn new(path: &PathBuf) -> Result { + let conn = Connection::open(path)?; + + // Create tables if they don't exist + conn.execute( + "CREATE TABLE IF NOT EXISTS rgb_payments ( + channel_id TEXT, + payment_hash TEXT, + direction TEXT, + info TEXT, + PRIMARY KEY (channel_id, payment_hash, direction) + )", + [], + )?; + + conn.execute( + "CREATE TABLE IF NOT EXISTS pending_payments ( + payment_hash TEXT PRIMARY KEY, + info TEXT + )", + [], + )?; + + conn.execute( + "CREATE TABLE IF NOT EXISTS transfer_info ( + txid TEXT PRIMARY KEY, + info TEXT + )", + [], + )?; + + conn.execute( + "CREATE TABLE IF NOT EXISTS rgb_info ( + channel_id TEXT, + is_pending BOOLEAN, + info TEXT, + PRIMARY KEY (channel_id, is_pending) + )", + [], + )?; + + conn.execute( + "CREATE TABLE IF NOT EXISTS consignments ( + handle INTEGER PRIMARY KEY, + channel_id TEXT, + funding_txid TEXT, + file_path TEXT + )", + [], + )?; + + Ok(Self { conn: Arc::new(Mutex::new(conn)) }) + } } -#[derive(Default, Debug)] +#[derive(Clone, Debug)] pub struct ColorDatabaseImpl { - rgb_payment_cache: Arc>, - transfer_info: Arc>, - rgb_info: Arc>, - consignment_cache: Arc>, + rgb_payment_cache: Arc>, + transfer_info: Arc>, + rgb_info: Arc>, + consignment_cache: Arc>, } impl ColorDatabaseImpl { - pub fn new() -> Self { - Self::default() - } + pub fn new(sqlite_conn: SqliteConnection, data_root: PathBuf) -> Self { + Self { + rgb_payment_cache: Arc::new(Mutex::new(RgbPaymentCache::new(sqlite_conn.conn.clone()))), + transfer_info: Arc::new(Mutex::new(TransferInfoCache::new(sqlite_conn.conn.clone()))), + rgb_info: Arc::new(Mutex::new(RgbInfoCache::new(sqlite_conn.conn.clone()))), + consignment_cache: Arc::new(Mutex::new(ConsignmentCache::new(sqlite_conn.conn.clone(), data_root))), + } + } - pub fn rgb_payment(&self) -> Arc> { 
- self.rgb_payment_cache.clone() - } + pub fn rgb_payment(&self) -> Arc> { + Arc::clone(&self.rgb_payment_cache) + } - pub fn transfer_info(&self) -> Arc> { - self.transfer_info.clone() - } + pub fn transfer_info(&self) -> Arc> { + Arc::clone(&self.transfer_info) + } - pub fn rgb_info(&self) -> Arc> { - self.rgb_info.clone() - } + pub fn rgb_info(&self) -> Arc> { + Arc::clone(&self.rgb_info) + } - pub fn consignment(&self) -> Arc> { - self.consignment_cache.clone() - } + pub fn consignment(&self) -> Arc> { + Arc::clone(&self.consignment_cache) + } - pub fn rename_channel_id(&self, old_channel_id: &ChannelId, new_channel_id: &ChannelId) { - let rgb_info_key = RgbInfoKey::new(old_channel_id, false); - let info = self.rgb_info().lock().unwrap().get_by_rgb_info_key(&rgb_info_key).cloned(); - if let Some(info) = info { - let new_info = info.clone(); - self.rgb_info() - .lock() - .unwrap() - .insert(RgbInfoKey::new(new_channel_id, false), new_info); - self.rgb_info().lock().unwrap().remove(&rgb_info_key); - } - println!("rename_channel_id before rgb_info_key_pending"); - - let rgb_info_key_pending = RgbInfoKey::new(old_channel_id, true); - let info = - self.rgb_info().lock().unwrap().get_by_rgb_info_key(&rgb_info_key_pending).cloned(); - if let Some(info) = info { - let new_info = info.clone(); - self.rgb_info().lock().unwrap().insert(RgbInfoKey::new(new_channel_id, true), new_info); - self.rgb_info().lock().unwrap().remove(&rgb_info_key_pending); - } - println!("rename_channel_id before consignment"); - - let consignment_handle = self.consignment().lock().unwrap().get_by_channel_id(old_channel_id); - if let Some(consignment_handle) = consignment_handle { - println!("rename_channel_id consignment"); - self.consignment().lock().unwrap().rename_channel_id( - consignment_handle, - old_channel_id, - new_channel_id, - ); - } - } + pub fn rename_channel_id(&self, old_channel_id: &ChannelId, new_channel_id: &ChannelId) { + let rgb_info_key = RgbInfoKey::new(old_channel_id, 
false); + let info = self.rgb_info().lock().unwrap().get_by_rgb_info_key(&rgb_info_key); + if let Some(info) = info { + let new_info = info.clone(); + self.rgb_info().lock().unwrap().insert(RgbInfoKey::new(new_channel_id, false), new_info); + self.rgb_info().lock().unwrap().remove(&rgb_info_key); + } + + let rgb_info_key_pending = RgbInfoKey::new(old_channel_id, true); + let info = self.rgb_info().lock().unwrap().get_by_rgb_info_key(&rgb_info_key_pending); + if let Some(info) = info { + let new_info = info.clone(); + self.rgb_info().lock().unwrap().insert(RgbInfoKey::new(new_channel_id, true), new_info); + self.rgb_info().lock().unwrap().remove(&rgb_info_key_pending); + } + + let consignment_handle = self.consignment().lock().unwrap().get_by_channel_id(old_channel_id); + if let Some(consignment_handle) = consignment_handle { + self.consignment().lock().unwrap().rename_channel_id( + consignment_handle, + old_channel_id, + new_channel_id, + ); + } + } } diff --git a/lightning/src/color_ext/mod.rs b/lightning/src/color_ext/mod.rs index c1a2e948938..6156ac620f0 100644 --- a/lightning/src/color_ext/mod.rs +++ b/lightning/src/color_ext/mod.rs @@ -158,7 +158,9 @@ impl ColorSourceImpl { ldk_data_dir: PathBuf, network: BitcoinNetwork, xpub: ExtendedPubKey, xprv: ExtendedPrivKey, ) -> Self { let ldk_data_dir = Arc::new(ldk_data_dir); - + let data_root = ldk_data_dir.clone().join("db_data"); + std::fs::create_dir_all(&data_root).expect("Failed to create data root directory"); + let sqlite_conn = database::SqliteConnection::new(&data_root.join("rgb.db")).unwrap(); let instance = Self { ldk_data_dir: Arc::clone(&ldk_data_dir).to_path_buf(), network, @@ -168,7 +170,10 @@ impl ColorSourceImpl { xprv, Arc::clone(&ldk_data_dir).to_path_buf(), ), - database: ColorDatabaseImpl::new(), + database: ColorDatabaseImpl::new( + sqlite_conn, + data_root + ), }; instance @@ -298,8 +303,7 @@ impl ColorSourceImpl { .rgb_payment() .lock() .unwrap() - .get_pending_payment(&htlc.payment_hash) - 
.cloned(); + .get_pending_payment(&htlc.payment_hash); if let Some(mut rgb_payment_info) = info { rgb_payment_info.local_rgb_amount = rgb_info.local_rgb_amount; @@ -316,8 +320,7 @@ impl ColorSourceImpl { .rgb_payment() .lock() .unwrap() - .get_by_proxy_id_key(&proxy_id_key) - .cloned(); + .get_by_proxy_id_key(&proxy_id_key); let rgb_payment_info = info.unwrap_or_else(|| { let info = RgbPaymentInfo { @@ -690,7 +693,7 @@ impl ColorSourceImpl { ) -> Option { let handle = self.database.consignment().lock().unwrap().get_by_funding_txid(funding_txid)?; - self.database.consignment().lock().unwrap().resolve(handle).cloned() + self.database.consignment().lock().unwrap().resolve(handle) } /// Update RGB channel amount @@ -786,8 +789,7 @@ impl ColorSourceImpl { .get_by_payment_hash_key(&PaymentHashKey::new( payment_hash.clone(), PaymentDirection::from(receiver), - )) - .cloned(); + )); if payment.is_none() { return; } @@ -797,7 +799,7 @@ impl ColorSourceImpl { self.database.rgb_payment().lock().unwrap().resolve_channel_id(payment_hash); if channel_id.is_none() { - panic!("failed to resolve channel id, which is a bug or of broken data."); + panic!("failed to resolve channel id by paymenthash {}, which is a bug or of broken data.", payment_hash); } let channel_id = channel_id.unwrap(); @@ -808,6 +810,6 @@ impl ColorSourceImpl { } pub fn get_transfer_info(&self, txid: &Txid) -> Option { - self.database.transfer_info().lock().unwrap().get_by_txid(txid).cloned() + self.database.transfer_info().lock().unwrap().get_by_txid(txid) } }