From 7bb9da0cac81c5793add740e9dfd5a5eeb342871 Mon Sep 17 00:00:00 2001
From: Brett Boston
Date: Wed, 15 Jan 2025 15:35:20 -0800
Subject: [PATCH] recursive_mutex -> mutex

---
 src/herder/TransactionQueue.cpp | 88 +++++++++++++++++++++------------
 src/herder/TransactionQueue.h   | 60 +++++++++++++++-------
 2 files changed, 99 insertions(+), 49 deletions(-)

diff --git a/src/herder/TransactionQueue.cpp b/src/herder/TransactionQueue.cpp
index 5ce8f8503c..0603048262 100644
--- a/src/herder/TransactionQueue.cpp
+++ b/src/herder/TransactionQueue.cpp
@@ -270,7 +270,7 @@ isDuplicateTx(TransactionFrameBasePtr oldTx, TransactionFrameBasePtr newTx)
 bool
 TransactionQueue::sourceAccountPending(AccountID const& accountID) const
 {
-    std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
+    std::lock_guard<std::mutex> guard(mTxQueueMutex);
     return mAccountStates.find(accountID) != mAccountStates.end();
 }
@@ -334,7 +334,7 @@ TransactionQueue::canAdd(
     std::vector>& txsToEvict)
 {
     ZoneScoped;
-    if (isBanned(tx->getFullHash()))
+    if (isBannedInternal(tx->getFullHash()))
     {
         return AddResult(
             TransactionQueue::AddResultCode::ADD_STATUS_TRY_AGAIN_LATER);
     }
@@ -436,7 +436,7 @@ TransactionQueue::canAdd(
         mTxQueueLimiter.canAddTx(tx, currentTx, txsToEvict, ledgerVersion);
     if (!canAddRes.first)
     {
-        ban({tx});
+        banInternal({tx});
         if (canAddRes.second != 0)
         {
             AddResult result(TransactionQueue::AddResultCode::ADD_STATUS_ERROR,
@@ -454,10 +454,6 @@ TransactionQueue::canAdd(
         // This is done so minSeqLedgerGap is validated against the next
         // ledgerSeq, which is what will be used at apply time
         ++ls.getLedgerHeader().currentToModify().ledgerSeq;
-        // TODO: ^^ I think this is the right thing to do. Was previously the
-        // commented out line below.
-        // ls.getLedgerHeader().currentToModify().ledgerSeq =
-        //     mApp.getLedgerManager().getLastClosedLedgerNum() + 1;
     }
 
     auto txResult =
@@ -645,7 +641,7 @@ TransactionQueue::AddResult
 TransactionQueue::tryAdd(TransactionFrameBasePtr tx, bool submittedFromSelf)
 {
     ZoneScoped;
-    std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
+    std::lock_guard<std::mutex> guard(mTxQueueMutex);
 
     auto c1 =
         tx->getEnvelope().type() == ENVELOPE_TYPE_TX_FEE_BUMP &&
@@ -701,8 +697,9 @@ TransactionQueue::tryAdd(TransactionFrameBasePtr tx, bool submittedFromSelf)
     // make space so that we can add this transaction
     // this will succeed as `canAdd` ensures that this is the case
     mTxQueueLimiter.evictTransactions(
-        txsToEvict, *tx,
-        [&](TransactionFrameBasePtr const& txToEvict) { ban({txToEvict}); });
+        txsToEvict, *tx, [&](TransactionFrameBasePtr const& txToEvict) {
+            banInternal({txToEvict});
+        });
     mTxQueueLimiter.addTransaction(tx);
     mKnownTxHashes[tx->getFullHash()] = tx;
@@ -806,7 +803,14 @@ void
 TransactionQueue::ban(Transactions const& banTxs)
 {
     ZoneScoped;
-    std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
+    std::lock_guard<std::mutex> guard(mTxQueueMutex);
+    banInternal(banTxs);
+}
+
+void
+TransactionQueue::banInternal(Transactions const& banTxs)
+{
+    ZoneScoped;
 
     auto& bannedFront = mBannedTransactions.front();
     // Group the transactions by source account and ban all the transactions
@@ -852,7 +856,7 @@ TransactionQueue::AccountState
 TransactionQueue::getAccountTransactionQueueInfo(
     AccountID const& accountID) const
 {
-    std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
+    std::lock_guard<std::mutex> guard(mTxQueueMutex);
     auto i = mAccountStates.find(accountID);
     if (i == std::end(mAccountStates))
     {
@@ -864,7 +868,7 @@ size_t
 TransactionQueue::countBanned(int index) const
 {
-    std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
+    std::lock_guard<std::mutex> guard(mTxQueueMutex);
     return mBannedTransactions[index].size();
 }
 #endif
@@ -939,7 +943,13 @@ TransactionQueue::shift()
 bool
 TransactionQueue::isBanned(Hash const& hash) const
 {
-    std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
+    std::lock_guard<std::mutex> guard(mTxQueueMutex);
+    return isBannedInternal(hash);
+}
+
+bool
+TransactionQueue::isBannedInternal(Hash const& hash) const
+{
     return std::any_of(
         std::begin(mBannedTransactions), std::end(mBannedTransactions),
         [&](UnorderedSet const& transactions) {
@@ -951,7 +961,14 @@ TxFrameList
 TransactionQueue::getTransactions(LedgerHeader const& lcl) const
 {
     ZoneScoped;
-    std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
+    std::lock_guard<std::mutex> guard(mTxQueueMutex);
+    return getTransactionsInternal(lcl);
+}
+
+TxFrameList
+TransactionQueue::getTransactionsInternal(LedgerHeader const& lcl) const
+{
+    ZoneScoped;
     TxFrameList txs;
 
     uint32_t const nextLedgerSeq = lcl.ledgerSeq + 1;
@@ -972,7 +989,7 @@ TransactionFrameBaseConstPtr
 TransactionQueue::getTx(Hash const& hash) const
 {
     ZoneScoped;
-    std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
+    std::lock_guard<std::mutex> guard(mTxQueueMutex);
     auto it = mKnownTxHashes.find(hash);
     if (it != mKnownTxHashes.end())
     {
@@ -1184,6 +1201,8 @@ SorobanTransactionQueue::broadcastSome()
 size_t
 SorobanTransactionQueue::getMaxQueueSizeOps() const
 {
+    ZoneScoped;
+    std::lock_guard<std::mutex> guard(mTxQueueMutex);
     if (protocolVersionStartsFrom(
             mBucketSnapshot->getLedgerHeader().ledgerVersion,
             SOROBAN_PROTOCOL_VERSION))
@@ -1264,7 +1283,7 @@ ClassicTransactionQueue::broadcastSome()
         std::make_shared(opsToFlood, dexOpsToFlood), mBroadcastSeed);
     queue.visitTopTxs(txsToBroadcast, visitor, mBroadcastOpCarryover);
-    ban(banningTxs);
+    banInternal(banningTxs);
     // carry over remainder, up to MAX_OPS_PER_TX ops
     // reason is that if we add 1 next round, we can flood a "worst case fee
     // bump" tx
@@ -1277,15 +1296,12 @@ ClassicTransactionQueue::broadcastSome()
 }
 
 void
-TransactionQueue::broadcast(bool fromCallback)
+TransactionQueue::broadcast(bool fromCallback,
+                            std::lock_guard<std::mutex> const& guard)
 {
     // Must be called from the main thread due to the use of `mBroadcastTimer`
     releaseAssert(threadIsMain());
-    // NOTE: Although this is not a public function, it can be called from
-    // `mBroadcastTimer` and so it needs to be synchronized.
-    std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
-
     if (mShutdown || (!fromCallback && mWaiting))
     {
         return;
     }
@@ -1317,7 +1333,14 @@ TransactionQueue::broadcast(bool fromCallback)
 }
 
 void
-TransactionQueue::rebroadcast()
+TransactionQueue::broadcast(bool fromCallback)
+{
+    std::lock_guard<std::mutex> guard(mTxQueueMutex);
+    broadcast(fromCallback, guard);
+}
+
+void
+TransactionQueue::rebroadcast(std::lock_guard<std::mutex> const& guard)
 {
     // For `broadcast` call
     releaseAssert(threadIsMain());
@@ -1331,14 +1354,14 @@ TransactionQueue::rebroadcast()
             as.mTransaction->mBroadcasted = false;
         }
     }
-    broadcast(false);
+    broadcast(false, guard);
 }
 
 void
 TransactionQueue::shutdown()
 {
     releaseAssert(threadIsMain());
-    std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
+    std::lock_guard<std::mutex> guard(mTxQueueMutex);
     mShutdown = true;
     mBroadcastTimer.cancel();
 }
@@ -1351,7 +1374,7 @@ TransactionQueue::update(
 {
     ZoneScoped;
     releaseAssert(threadIsMain());
-    std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
+    std::lock_guard<std::mutex> guard(mTxQueueMutex);
 
     mValidationSnapshot =
         std::make_shared(mAppConn);
@@ -1361,11 +1384,11 @@ TransactionQueue::update(
     removeApplied(applied);
     shift();
 
-    auto txs = getTransactions(lcl);
+    auto txs = getTransactionsInternal(lcl);
     auto invalidTxs = filterInvalidTxs(txs);
-    ban(invalidTxs);
+    banInternal(invalidTxs);
 
-    rebroadcast();
+    rebroadcast(guard);
 }
 
 static bool
@@ -1409,14 +1432,14 @@ TransactionQueue::isFiltered(TransactionFrameBasePtr tx) const
 size_t
 TransactionQueue::getQueueSizeOps() const
 {
-    std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
+    std::lock_guard<std::mutex> guard(mTxQueueMutex);
     return mTxQueueLimiter.size();
 }
 
 std::optional
 TransactionQueue::getInQueueSeqNum(AccountID const& account) const
 {
-    std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
+    std::lock_guard<std::mutex> guard(mTxQueueMutex);
     auto stateIter = mAccountStates.find(account);
     if (stateIter == mAccountStates.end())
     {
@@ -1433,7 +1456,8 @@ TransactionQueue::getInQueueSeqNum(AccountID const& account) const
 size_t
 ClassicTransactionQueue::getMaxQueueSizeOps() const
 {
-    std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
+    ZoneScoped;
+    std::lock_guard<std::mutex> guard(mTxQueueMutex);
     auto res = mTxQueueLimiter.maxScaledLedgerResources(false);
     releaseAssert(res.size() == NUM_CLASSIC_TX_RESOURCES);
     return res.getVal(Resource::Type::OPERATIONS);
diff --git a/src/herder/TransactionQueue.h b/src/herder/TransactionQueue.h
index f47df0c307..f5bfbcbc1c 100644
--- a/src/herder/TransactionQueue.h
+++ b/src/herder/TransactionQueue.h
@@ -210,10 +210,15 @@ class TransactionQueue
     virtual std::pair> getMaxResourcesToFloodThisPeriod() const = 0;
     virtual bool broadcastSome() = 0;
-    virtual int getFloodPeriod() const = 0;
     virtual bool allowTxBroadcast(TimestampedTx const& tx) = 0;
 
+    // `broadcast` has an overload that takes a guard because this function
+    // is both called internally and scheduled on a timer. Any asynchronous
+    // call (e.g. from the timer) should use the first overload, which grabs
+    // the lock; any internal call should use the second overload, which
+    // requires that the lock is already held.
     void broadcast(bool fromCallback);
+    void broadcast(bool fromCallback, std::lock_guard<std::mutex> const& guard);
 
     // broadcasts a single transaction
     enum class BroadcastStatus
     {
@@ -234,6 +239,12 @@ class TransactionQueue
     bool isFiltered(TransactionFrameBasePtr tx) const;
 
+    // Protected versions of public functions that contain the actual
+    // implementation, so that they can be called internally by subclasses
+    // and other member functions when the lock is already held by the
+    // caller.
+    void banInternal(Transactions const& banTxs);
+
     // Snapshots to use for transaction validation
     ImmutableValidationSnapshotPtr mValidationSnapshot;
     SearchableSnapshotConstPtr mBucketSnapshot;
@@ -245,7 +256,7 @@ class TransactionQueue
 
     size_t mBroadcastSeed;
 
-    mutable std::recursive_mutex mTxQueueMutex;
+    mutable std::mutex mTxQueueMutex;
 
   private:
     AppConnector& mAppConn;
@@ -259,10 +270,24 @@ class TransactionQueue
      */
     void shift();
 
-    void rebroadcast();
+    // Takes a lock guard because of the `broadcast` call that it makes,
+    // which requires the lock to already be held.
+    void rebroadcast(std::lock_guard<std::mutex> const& guard);
+
+    // Private versions of public functions that contain the actual
+    // implementation, so that they can be called internally by other member
+    // functions when the lock is already held by the caller (see also
+    // `banInternal` above).
+    bool isBannedInternal(Hash const& hash) const;
+    TxFrameList getTransactionsInternal(LedgerHeader const& lcl) const;
+
+    virtual int getFloodPeriod() const = 0;
 
 #ifdef BUILD_TESTS
   public:
+    // TODO: These tests call protected/private functions directly, assuming
+    // the lock is already held. Make sure they operate in a thread-safe
+    // manner or change them to not require private member access.
     friend class TransactionQueueTest;
 
     size_t getQueueSizeOps() const;
@@ -278,19 +303,13 @@ class SorobanTransactionQueue : public TransactionQueue
                             SearchableSnapshotConstPtr bucketSnapshot,
                             uint32 pendingDepth, uint32 banDepth,
                             uint32 poolLedgerMultiplier);
-    int
-    getFloodPeriod() const override
-    {
-        std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
-        return mValidationSnapshot->getConfig().FLOOD_SOROBAN_TX_PERIOD_MS;
-    }
 
     size_t getMaxQueueSizeOps() const override;
 
 #ifdef BUILD_TESTS
     void
     clearBroadcastCarryover()
     {
-        std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
+        std::lock_guard<std::mutex> guard(mTxQueueMutex);
         mBroadcastOpCarryover.clear();
         mBroadcastOpCarryover.resize(1, Resource::makeEmptySoroban());
     }
@@ -307,6 +326,13 @@ class SorobanTransactionQueue : public TransactionQueue
     {
         return true;
     }
+
+    int
+    getFloodPeriod() const override
+    {
+        return mValidationSnapshot->getConfig().FLOOD_SOROBAN_TX_PERIOD_MS;
+    }
+
 };
 
 class ClassicTransactionQueue : public TransactionQueue
@@ -317,13 +343,6 @@ class ClassicTransactionQueue : public TransactionQueue
                             uint32 pendingDepth, uint32 banDepth,
                             uint32 poolLedgerMultiplier);
 
-    int
-    getFloodPeriod() const override
-    {
-        std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
-        return mValidationSnapshot->getConfig().FLOOD_TX_PERIOD_MS;
-    }
-
     size_t getMaxQueueSizeOps() const override;
 
   private:
@@ -335,6 +354,13 @@ class ClassicTransactionQueue : public TransactionQueue
     virtual bool broadcastSome() override;
     std::vector mBroadcastOpCarryover;
     virtual bool allowTxBroadcast(TimestampedTx const& tx) override;
+
+    int
+    getFloodPeriod() const override
+    {
+        return mValidationSnapshot->getConfig().FLOOD_TX_PERIOD_MS;
+    }
+
 };
 
 extern std::array
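
A minimal, self-contained sketch of the locking pattern this patch adopts:
each public entry point acquires the (now non-recursive) mutex exactly once
and forwards to an "Internal" helper that assumes the lock is already held,
optionally taking the caller's std::lock_guard by const reference to make the
precondition explicit. The class and function names below (Widget, doWork,
doWorkInternal, doBatch) are hypothetical illustrations under those
assumptions, not stellar-core code.

    #include <mutex>
    #include <vector>

    class Widget
    {
      public:
        // Public API: grabs the lock, then forwards to the internal
        // implementation.
        void
        doWork(int item)
        {
            std::lock_guard<std::mutex> guard(mMutex);
            doWorkInternal(item, guard);
        }

        // Another public API that also needs the lock. With a plain
        // std::mutex it must not call doWork() (that would lock twice and
        // deadlock); it calls the internal helper instead.
        void
        doBatch(std::vector<int> const& items)
        {
            std::lock_guard<std::mutex> guard(mMutex);
            for (int item : items)
            {
                doWorkInternal(item, guard);
            }
        }

      private:
        // Internal implementation: the (unnamed) guard parameter documents
        // that the caller already holds mMutex.
        void
        doWorkInternal(int item, std::lock_guard<std::mutex> const&)
        {
            mData.push_back(item);
        }

        std::mutex mMutex;
        std::vector<int> mData;
    };

In the patch itself, ban/banInternal, isBanned/isBannedInternal,
getTransactions/getTransactionsInternal, and the two broadcast overloads
follow this same shape.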