diff --git a/relays/client-substrate/src/transaction_tracker.rs b/relays/client-substrate/src/transaction_tracker.rs index b85e859017f7..a84a46240a4a 100644 --- a/relays/client-substrate/src/transaction_tracker.rs +++ b/relays/client-substrate/src/transaction_tracker.rs @@ -19,7 +19,7 @@ use crate::{Chain, HashOf, Subscription, TransactionStatusOf}; use async_trait::async_trait; -use futures::{Stream, StreamExt}; +use futures::{future::Either, Future, FutureExt, Stream, StreamExt}; use relay_utils::TrackedTransactionStatus; use std::time::Duration; @@ -61,28 +61,50 @@ impl TransactionTracker { /// Wait for final transaction status and return it along with last known internal invalidation /// status. - async fn do_wait(self) -> (TrackedTransactionStatus, InvalidationStatus) { - let invalidation_status = watch_transaction_status::( + async fn do_wait( + self, + wait_for_stall_timeout: impl Future, + wait_for_stall_timeout_rest: impl Future, + ) -> (TrackedTransactionStatus, Option) { + // sometimes we want to wait for the rest of the stall timeout even if + // `wait_for_invalidation` has been "select"ed first => it is shared + let wait_for_invalidation = watch_transaction_status::( self.transaction_hash, self.subscription.into_stream(), - ) - .await; - match invalidation_status { - InvalidationStatus::Finalized => - (TrackedTransactionStatus::Finalized, invalidation_status), - InvalidationStatus::Invalid => (TrackedTransactionStatus::Lost, invalidation_status), - InvalidationStatus::Lost => { - async_std::task::sleep(self.stall_timeout).await; - // if someone is still watching for our transaction, then we're reporting - // an error here (which is treated as "transaction lost") + ); + futures::pin_mut!(wait_for_stall_timeout, wait_for_invalidation); + + match futures::future::select(wait_for_stall_timeout, wait_for_invalidation).await { + Either::Left((_, _)) => { log::trace!( target: "bridge", - "{} transaction {:?} is considered lost after timeout", + "{} transaction {:?} is considered lost after timeout (no status response from the node)", C::NAME, self.transaction_hash, ); - (TrackedTransactionStatus::Lost, invalidation_status) + (TrackedTransactionStatus::Lost, None) + }, + Either::Right((invalidation_status, _)) => match invalidation_status { + InvalidationStatus::Finalized => + (TrackedTransactionStatus::Finalized, Some(invalidation_status)), + InvalidationStatus::Invalid => + (TrackedTransactionStatus::Lost, Some(invalidation_status)), + InvalidationStatus::Lost => { + // wait for the rest of stall timeout - this way we'll be sure that the + // transaction is actually dead if it has been crafted properly + wait_for_stall_timeout_rest.await; + // if someone is still watching for our transaction, then we're reporting + // an error here (which is treated as "transaction lost") + log::trace!( + target: "bridge", + "{} transaction {:?} is considered lost after timeout", + C::NAME, + self.transaction_hash, + ); + + (TrackedTransactionStatus::Lost, Some(invalidation_status)) + }, }, } } @@ -91,7 +113,9 @@ impl TransactionTracker { #[async_trait] impl relay_utils::TransactionTracker for TransactionTracker { async fn wait(self) -> TrackedTransactionStatus { - self.do_wait().await.0 + let wait_for_stall_timeout = async_std::task::sleep(self.stall_timeout).shared(); + let wait_for_stall_timeout_rest = wait_for_stall_timeout.clone(); + self.do_wait(wait_for_stall_timeout, wait_for_stall_timeout_rest).await.0 } } @@ -233,8 +257,13 @@ mod tests { Subscription(async_std::sync::Mutex::new(receiver)), ); + let wait_for_stall_timeout = futures::future::pending(); + let wait_for_stall_timeout_rest = futures::future::ready(()); sender.send(Some(status)).await.unwrap(); - tx_tracker.do_wait().now_or_never() + tx_tracker + .do_wait(wait_for_stall_timeout, wait_for_stall_timeout_rest) + .now_or_never() + .map(|(ts, is)| (ts, is.unwrap())) } #[async_std::test] @@ -319,4 +348,22 @@ mod tests { Some(InvalidationStatus::Lost), ); } + + #[async_std::test] + async fn lost_on_timeout_when_waiting_for_invalidation_status() { + let (_sender, receiver) = futures::channel::mpsc::channel(1); + let tx_tracker = TransactionTracker::::new( + Duration::from_secs(0), + Default::default(), + Subscription(async_std::sync::Mutex::new(receiver)), + ); + + let wait_for_stall_timeout = futures::future::ready(()).shared(); + let wait_for_stall_timeout_rest = wait_for_stall_timeout.clone(); + let wait_result = tx_tracker + .do_wait(wait_for_stall_timeout, wait_for_stall_timeout_rest) + .now_or_never(); + + assert_eq!(wait_result, Some((TrackedTransactionStatus::Lost, None))); + } }