diff --git a/src/test_utils/katana/mod.rs b/src/test_utils/katana/mod.rs index 33d6cbdc7..9d72eb1e2 100644 --- a/src/test_utils/katana/mod.rs +++ b/src/test_utils/katana/mod.rs @@ -24,6 +24,7 @@ use crate::eth_provider::{ provider::EthDataProvider, }; use crate::test_utils::eoa::KakarotEOA; +use reth_primitives::Address; #[cfg(any(test, feature = "arbitrary", feature = "testing"))] use { @@ -266,6 +267,58 @@ impl<'a> Katana { }) } + pub fn logs_by_address(&self, addresses: &[Address]) -> Vec { + self.mock_data.get(&CollectionDB::Logs).map_or_else(Vec::new, |logs| { + logs.iter() + .filter_map(|data| data.extract_stored_log()) + .filter(|stored_log| { + let address = stored_log.log.address(); + addresses.iter().any(|addr| *addr == address) + }) + .map(|stored_log| stored_log.log.clone()) + .collect() + }) + } + + pub fn logs_by_block_number(&self, block_number: u64) -> Vec { + self.mock_data.get(&CollectionDB::Logs).map_or_else(Vec::new, |logs| { + logs.iter() + .filter_map(|data| data.extract_stored_log()) + .filter(|stored_log| stored_log.log.block_number.unwrap_or_default() == block_number) + .map(|stored_log| stored_log.log.clone()) + .collect() + }) + } + + pub fn logs_by_block_range(&self, block_range: std::ops::Range) -> Vec { + self.mock_data.get(&CollectionDB::Logs).map_or_else(Vec::new, |logs| { + logs.iter() + .filter_map(|data| data.extract_stored_log()) + .filter(|stored_log| { + let block_number = stored_log.log.block_number.unwrap_or_default(); + block_range.contains(&block_number) + }) + .map(|stored_log| stored_log.log.clone()) + .collect() + }) + } + + pub fn logs_by_block_hash(&self, block_hash: B256) -> Vec { + self.mock_data.get(&CollectionDB::Logs).map_or_else(Vec::new, |logs| { + logs.iter() + .filter_map(|data| data.extract_stored_log()) + .filter(|stored_log| stored_log.log.block_hash.unwrap_or_default() == block_hash) + .map(|stored_log| stored_log.log.clone()) + .collect() + }) + } + + pub fn all_logs(&self) -> Vec { + self.mock_data.get(&CollectionDB::Logs).map_or_else(Vec::new, |logs| { + logs.iter().filter_map(|data| data.extract_stored_log()).map(|stored_log| stored_log.log.clone()).collect() + }) + } + /// Retrieves the number of blocks in the database pub fn count_block(&self) -> usize { self.mock_data.get(&CollectionDB::Headers).map_or(0, std::vec::Vec::len) diff --git a/src/test_utils/mongo/mod.rs b/src/test_utils/mongo/mod.rs index 502a0c47f..83de74222 100644 --- a/src/test_utils/mongo/mod.rs +++ b/src/test_utils/mongo/mod.rs @@ -251,17 +251,6 @@ impl MongoFuzzer { Ok(()) } - /// Gets the highest block number in the logs collection. - pub fn max_block_number_in_logs(&self) -> u64 { - self.documents - .get(&CollectionDB::Logs) - .unwrap() - .iter() - .map(|log| log.extract_stored_log().unwrap().log.block_number.unwrap_or_default()) - .max() - .unwrap_or_default() - } - /// Gets the highest block number in the transactions collection. pub fn max_block_number(&self) -> u64 { self.documents @@ -310,6 +299,7 @@ impl MongoFuzzer { let mut modified_logs = (*receipt.receipt.inner.as_receipt_with_bloom().unwrap()).clone(); for log in &mut modified_logs.receipt.logs { log.block_number = Some(transaction.block_number.unwrap_or_default()); + log.block_hash = transaction.block_hash; } receipt.receipt.transaction_hash = transaction.hash; diff --git a/tests/tests/eth_provider.rs b/tests/tests/eth_provider.rs index 04ffd82e4..a16fff8d9 100644 --- a/tests/tests/eth_provider.rs +++ b/tests/tests/eth_provider.rs @@ -283,14 +283,83 @@ async fn test_get_logs_block_range(#[future] katana: Katana, _setup: ()) { assert!(!logs.is_empty()); } +/// Utility function to filter logs using the Ethereum provider. +/// Takes a filter and a provider, and returns the corresponding logs. async fn filter_logs(filter: Filter, provider: Arc) -> Vec { + // Call the provider to get logs using the filter. let logs = provider.get_logs(filter).await.expect("Failed to get logs"); + // If the result contains logs, return them, otherwise panic with an error. match logs { FilterChanges::Logs(logs) => logs, _ => panic!("Expected logs"), } } +#[rstest] +#[awt] +#[tokio::test(flavor = "multi_thread")] +async fn test_get_logs_block_filter(#[future] katana: Katana, _setup: ()) { + // Get the Ethereum provider from Katana. + let provider = katana.eth_provider(); + + // Get the first transaction from Katana. + let first_transaction = katana.first_transaction().unwrap(); + let block_number = first_transaction.block_number.unwrap(); + let block_hash = first_transaction.block_hash.unwrap(); + + // Get logs by block number from Katana. + let logs_katana_block_number = katana.logs_by_block_number(block_number); + // Get logs for a range of blocks from Katana. + let logs_katana_block_range = katana.logs_by_block_range(0..u64::MAX / 2); + // Get logs by block hash from Katana. + let logs_katana_block_hash = katana.logs_by_block_hash(block_hash); + // Get all logs from Katana. + let all_logs_katana = katana.all_logs(); + + // Verify logs filtered by block number. + assert_eq!(filter_logs(Filter::default().select(block_number), provider.clone()).await, logs_katana_block_number); + // Verify logs filtered by block hash. + assert_eq!(filter_logs(Filter::default().select(block_hash), provider.clone()).await, logs_katana_block_hash); + // Verify all logs. + assert_eq!(filter_logs(Filter::default().select(0..), provider.clone()).await, all_logs_katana); + // Verify logs filtered by a range of blocks. + assert_eq!(filter_logs(Filter::default().select(0..u64::MAX / 2), provider.clone()).await, logs_katana_block_range); + // Verify that filtering by an empty range returns an empty result. + assert!(filter_logs(Filter::default().select(0..0), provider.clone()).await.is_empty()); +} + +#[rstest] +#[awt] +#[tokio::test(flavor = "multi_thread")] +async fn test_get_logs_address_filter(#[future] katana: Katana, _setup: ()) { + // Get the Ethereum provider from Katana. + let provider = katana.eth_provider(); + + // Get all logs from Katana. + let all_logs_katana = katana.all_logs(); + + // Get the first log address, or default address if logs are empty. + let first_address = if all_logs_katana.is_empty() { Address::default() } else { all_logs_katana[0].address() }; + // Verify logs filtered by the first address. + assert_eq!( + filter_logs(Filter::new().address(vec![first_address]), provider.clone()).await, + katana.logs_by_address(&[first_address]) + ); + + // Create a vector to store a few addresses. + let some_addresses: Vec<_> = all_logs_katana.iter().take(2).map(Log::address).collect(); + // Verify logs filtered by these few addresses. + assert_eq!( + filter_logs(Filter::new().address(some_addresses.clone()), provider.clone()).await, + katana.logs_by_address(&some_addresses) + ); + + // Create a vector to store all addresses. + let all_addresses: Vec<_> = all_logs_katana.iter().map(Log::address).collect(); + // Verify that all logs are retrieved when filtered by all addresses. + assert_eq!(filter_logs(Filter::new().address(all_addresses), provider.clone()).await, all_logs_katana); +} + #[rstest] #[awt] #[tokio::test(flavor = "multi_thread")]