Skip to content

Commit

Permalink
recursive_mutex -> mutex
Browse files Browse the repository at this point in the history
  • Loading branch information
bboston7 committed Jan 15, 2025
1 parent 2a5af04 commit 7bb9da0
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 49 deletions.
88 changes: 56 additions & 32 deletions src/herder/TransactionQueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ isDuplicateTx(TransactionFrameBasePtr oldTx, TransactionFrameBasePtr newTx)
// Returns true if `accountID` currently has a transaction pending in the
// queue. Thread-safe: acquires `mTxQueueMutex` for the `mAccountStates`
// lookup.
bool
TransactionQueue::sourceAccountPending(AccountID const& accountID) const
{
    std::lock_guard<std::mutex> guard(mTxQueueMutex);
    return mAccountStates.find(accountID) != mAccountStates.end();
}

Expand Down Expand Up @@ -334,7 +334,7 @@ TransactionQueue::canAdd(
std::vector<std::pair<TransactionFrameBasePtr, bool>>& txsToEvict)
{
ZoneScoped;
if (isBanned(tx->getFullHash()))
if (isBannedInternal(tx->getFullHash()))
{
return AddResult(
TransactionQueue::AddResultCode::ADD_STATUS_TRY_AGAIN_LATER);
Expand Down Expand Up @@ -436,7 +436,7 @@ TransactionQueue::canAdd(
mTxQueueLimiter.canAddTx(tx, currentTx, txsToEvict, ledgerVersion);
if (!canAddRes.first)
{
ban({tx});
banInternal({tx});
if (canAddRes.second != 0)
{
AddResult result(TransactionQueue::AddResultCode::ADD_STATUS_ERROR,
Expand All @@ -454,10 +454,6 @@ TransactionQueue::canAdd(
// This is done so minSeqLedgerGap is validated against the next
// ledgerSeq, which is what will be used at apply time
++ls.getLedgerHeader().currentToModify().ledgerSeq;
// TODO: ^^ I think this is the right thing to do. Was previously the
// commented out line below.
// ls.getLedgerHeader().currentToModify().ledgerSeq =
// mApp.getLedgerManager().getLastClosedLedgerNum() + 1;
}

auto txResult =
Expand Down Expand Up @@ -645,7 +641,7 @@ TransactionQueue::AddResult
TransactionQueue::tryAdd(TransactionFrameBasePtr tx, bool submittedFromSelf)
{
ZoneScoped;
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
std::lock_guard<std::mutex> guard(mTxQueueMutex);

auto c1 =
tx->getEnvelope().type() == ENVELOPE_TYPE_TX_FEE_BUMP &&
Expand Down Expand Up @@ -701,8 +697,9 @@ TransactionQueue::tryAdd(TransactionFrameBasePtr tx, bool submittedFromSelf)
// make space so that we can add this transaction
// this will succeed as `canAdd` ensures that this is the case
mTxQueueLimiter.evictTransactions(
txsToEvict, *tx,
[&](TransactionFrameBasePtr const& txToEvict) { ban({txToEvict}); });
txsToEvict, *tx, [&](TransactionFrameBasePtr const& txToEvict) {
banInternal({txToEvict});
});
mTxQueueLimiter.addTransaction(tx);
mKnownTxHashes[tx->getFullHash()] = tx;

Expand Down Expand Up @@ -806,7 +803,14 @@ void
// Public entry point for banning a set of transactions. Acquires
// `mTxQueueMutex` and delegates to `banInternal`, which requires the lock
// to already be held.
TransactionQueue::ban(Transactions const& banTxs)
{
    ZoneScoped;
    std::lock_guard<std::mutex> guard(mTxQueueMutex);
    banInternal(banTxs);
}

void
TransactionQueue::banInternal(Transactions const& banTxs)
{
ZoneScoped;
auto& bannedFront = mBannedTransactions.front();

// Group the transactions by source account and ban all the transactions
Expand Down Expand Up @@ -852,7 +856,7 @@ TransactionQueue::AccountState
TransactionQueue::getAccountTransactionQueueInfo(
AccountID const& accountID) const
{
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
std::lock_guard<std::mutex> guard(mTxQueueMutex);
auto i = mAccountStates.find(accountID);
if (i == std::end(mAccountStates))
{
Expand All @@ -864,7 +868,7 @@ TransactionQueue::getAccountTransactionQueueInfo(
// Test-only helper: returns how many transactions are banned at ban-depth
// `index`. Thread-safe: acquires `mTxQueueMutex`.
size_t
TransactionQueue::countBanned(int index) const
{
    std::lock_guard<std::mutex> guard(mTxQueueMutex);
    return mBannedTransactions[index].size();
}
#endif
Expand Down Expand Up @@ -939,7 +943,13 @@ TransactionQueue::shift()
// Returns true if the transaction with full hash `hash` is currently
// banned at any ban depth.
bool
TransactionQueue::isBanned(Hash const& hash) const
{
    // Grab the lock, then defer to the internal helper so code paths that
    // already hold `mTxQueueMutex` can share the same logic without
    // re-acquiring a (now non-recursive) mutex.
    std::lock_guard<std::mutex> guard(mTxQueueMutex);
    return isBannedInternal(hash);
}

bool
TransactionQueue::isBannedInternal(Hash const& hash) const
{
return std::any_of(
std::begin(mBannedTransactions), std::end(mBannedTransactions),
[&](UnorderedSet<Hash> const& transactions) {
Expand All @@ -951,7 +961,14 @@ TxFrameList
// Returns the queued transactions eligible for the ledger following `lcl`.
// Thread-safe wrapper: acquires `mTxQueueMutex` and delegates to
// `getTransactionsInternal`, which requires the lock to already be held.
TransactionQueue::getTransactions(LedgerHeader const& lcl) const
{
    ZoneScoped;
    std::lock_guard<std::mutex> guard(mTxQueueMutex);
    return getTransactionsInternal(lcl);
}

TxFrameList
TransactionQueue::getTransactionsInternal(LedgerHeader const& lcl) const
{
ZoneScoped;
TxFrameList txs;

uint32_t const nextLedgerSeq = lcl.ledgerSeq + 1;
Expand All @@ -972,7 +989,7 @@ TransactionFrameBaseConstPtr
TransactionQueue::getTx(Hash const& hash) const
{
ZoneScoped;
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
std::lock_guard<std::mutex> guard(mTxQueueMutex);
auto it = mKnownTxHashes.find(hash);
if (it != mKnownTxHashes.end())
{
Expand Down Expand Up @@ -1184,6 +1201,8 @@ SorobanTransactionQueue::broadcastSome()
size_t
SorobanTransactionQueue::getMaxQueueSizeOps() const
{
ZoneScoped;
std::lock_guard<std::mutex> guard(mTxQueueMutex);
if (protocolVersionStartsFrom(
mBucketSnapshot->getLedgerHeader().ledgerVersion,
SOROBAN_PROTOCOL_VERSION))
Expand Down Expand Up @@ -1264,7 +1283,7 @@ ClassicTransactionQueue::broadcastSome()
std::make_shared<DexLimitingLaneConfig>(opsToFlood, dexOpsToFlood),
mBroadcastSeed);
queue.visitTopTxs(txsToBroadcast, visitor, mBroadcastOpCarryover);
ban(banningTxs);
banInternal(banningTxs);
// carry over remainder, up to MAX_OPS_PER_TX ops
// reason is that if we add 1 next round, we can flood a "worst case fee
// bump" tx
Expand All @@ -1277,15 +1296,12 @@ ClassicTransactionQueue::broadcastSome()
}

void
TransactionQueue::broadcast(bool fromCallback)
TransactionQueue::broadcast(bool fromCallback,
std::lock_guard<std::mutex> const& guard)
{
// Must be called from the main thread due to the use of `mBroadcastTimer`
releaseAssert(threadIsMain());

// NOTE: Although this is not a public function, it can be called from
// `mBroadcastTimer` and so it needs to be synchronized.
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);

if (mShutdown || (!fromCallback && mWaiting))
{
return;
Expand Down Expand Up @@ -1317,7 +1333,14 @@ TransactionQueue::broadcast(bool fromCallback)
}

void
TransactionQueue::rebroadcast()
TransactionQueue::broadcast(bool fromCallback)
{
std::lock_guard<std::mutex> guard(mTxQueueMutex);
broadcast(fromCallback, guard);
}

void
TransactionQueue::rebroadcast(std::lock_guard<std::mutex> const& guard)
{
// For `broadcast` call
releaseAssert(threadIsMain());
Expand All @@ -1331,14 +1354,14 @@ TransactionQueue::rebroadcast()
as.mTransaction->mBroadcasted = false;
}
}
broadcast(false);
broadcast(false, guard);
}

void
TransactionQueue::shutdown()
{
    // Must run on the main thread: `mBroadcastTimer` is owned by the main
    // thread's event loop.
    releaseAssert(threadIsMain());
    std::lock_guard<std::mutex> guard(mTxQueueMutex);
    // Setting mShutdown stops any in-flight `broadcast` from rescheduling,
    // and cancelling the timer drops the pending callback.
    mShutdown = true;
    mBroadcastTimer.cancel();
}
Expand All @@ -1351,7 +1374,7 @@ TransactionQueue::update(
{
ZoneScoped;
releaseAssert(threadIsMain());
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
std::lock_guard<std::mutex> guard(mTxQueueMutex);

mValidationSnapshot =
std::make_shared<ImmutableValidationSnapshot>(mAppConn);
Expand All @@ -1361,11 +1384,11 @@ TransactionQueue::update(
removeApplied(applied);
shift();

auto txs = getTransactions(lcl);
auto txs = getTransactionsInternal(lcl);
auto invalidTxs = filterInvalidTxs(txs);
ban(invalidTxs);
banInternal(invalidTxs);

rebroadcast();
rebroadcast(guard);
}

static bool
Expand Down Expand Up @@ -1409,14 +1432,14 @@ TransactionQueue::isFiltered(TransactionFrameBasePtr tx) const
// Returns the current queue size in operations, as tracked by the limiter.
// Thread-safe: the limiter is only read or written under `mTxQueueMutex`.
size_t
TransactionQueue::getQueueSizeOps() const
{
    std::lock_guard<std::mutex> guard(mTxQueueMutex);
    return mTxQueueLimiter.size();
}

std::optional<int64_t>
TransactionQueue::getInQueueSeqNum(AccountID const& account) const
{
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
std::lock_guard<std::mutex> guard(mTxQueueMutex);
auto stateIter = mAccountStates.find(account);
if (stateIter == mAccountStates.end())
{
Expand All @@ -1433,7 +1456,8 @@ TransactionQueue::getInQueueSeqNum(AccountID const& account) const
// Returns the maximum queue capacity in operations for the classic
// (non-Soroban) queue. Thread-safe: acquires `mTxQueueMutex` before
// querying the limiter.
size_t
ClassicTransactionQueue::getMaxQueueSizeOps() const
{
    ZoneScoped;
    std::lock_guard<std::mutex> guard(mTxQueueMutex);
    auto res = mTxQueueLimiter.maxScaledLedgerResources(false);
    // Classic resources are operation counts only; anything else indicates
    // a limiter misconfiguration.
    releaseAssert(res.size() == NUM_CLASSIC_TX_RESOURCES);
    return res.getVal(Resource::Type::OPERATIONS);
}
Expand Down
60 changes: 43 additions & 17 deletions src/herder/TransactionQueue.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,10 +210,15 @@ class TransactionQueue
virtual std::pair<Resource, std::optional<Resource>>
getMaxResourcesToFloodThisPeriod() const = 0;
virtual bool broadcastSome() = 0;
virtual int getFloodPeriod() const = 0;
virtual bool allowTxBroadcast(TimestampedTx const& tx) = 0;

// `broadcast` is both scheduled on a timer and called internally, so it
// has two overloads. Asynchronous callers use the first overload, which
// acquires `mTxQueueMutex`; internal callers that already hold the lock
// use the second overload, whose guard parameter enforces at compile time
// that the mutex is held.
void broadcast(bool fromCallback);
void broadcast(bool fromCallback, std::lock_guard<std::mutex> const& guard);
// broadcasts a single transaction
enum class BroadcastStatus
{
Expand All @@ -234,6 +239,12 @@ class TransactionQueue

bool isFiltered(TransactionFrameBasePtr tx) const;

// Protected version of `ban` containing the actual implementation, so it
// can be called internally when `mTxQueueMutex` is already held.
void banInternal(Transactions const& banTxs);

// Snapshots to use for transaction validation
ImmutableValidationSnapshotPtr mValidationSnapshot;
SearchableSnapshotConstPtr mBucketSnapshot;
Expand All @@ -245,7 +256,7 @@ class TransactionQueue

size_t mBroadcastSeed;

mutable std::recursive_mutex mTxQueueMutex;
mutable std::mutex mTxQueueMutex;

private:
AppConnector& mAppConn;
Expand All @@ -259,10 +270,24 @@ class TransactionQueue
*/
void shift();

void rebroadcast();
// Takes a lock guard as proof that the caller holds `mTxQueueMutex`,
// which this function passes through to its `broadcast` call.
void rebroadcast(std::lock_guard<std::mutex> const& guard);

// Private versions of `isBanned` and `getTransactions` containing the
// actual implementations, so they can be called internally when
// `mTxQueueMutex` is already held.
bool isBannedInternal(Hash const& hash) const;
TxFrameList getTransactionsInternal(LedgerHeader const& lcl) const;

virtual int getFloodPeriod() const = 0;

#ifdef BUILD_TESTS
public:
// TODO: These tests invoke protected/private functions directly that assume
// things are properly locked. I need to make sure these tests operate in a
// thread-safe manner or change them to not require private member access.
friend class TransactionQueueTest;

size_t getQueueSizeOps() const;
Expand All @@ -278,19 +303,13 @@ class SorobanTransactionQueue : public TransactionQueue
SearchableSnapshotConstPtr bucketSnapshot,
uint32 pendingDepth, uint32 banDepth,
uint32 poolLedgerMultiplier);
int
getFloodPeriod() const override
{
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
return mValidationSnapshot->getConfig().FLOOD_SOROBAN_TX_PERIOD_MS;
}

size_t getMaxQueueSizeOps() const override;
#ifdef BUILD_TESTS
void
clearBroadcastCarryover()
{
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
std::lock_guard<std::mutex> guard(mTxQueueMutex);
mBroadcastOpCarryover.clear();
mBroadcastOpCarryover.resize(1, Resource::makeEmptySoroban());
}
Expand All @@ -307,6 +326,13 @@ class SorobanTransactionQueue : public TransactionQueue
{
return true;
}

int
getFloodPeriod() const override
{
return mValidationSnapshot->getConfig().FLOOD_SOROBAN_TX_PERIOD_MS;
}

};

class ClassicTransactionQueue : public TransactionQueue
Expand All @@ -317,13 +343,6 @@ class ClassicTransactionQueue : public TransactionQueue
uint32 pendingDepth, uint32 banDepth,
uint32 poolLedgerMultiplier);

int
getFloodPeriod() const override
{
std::lock_guard<std::recursive_mutex> guard(mTxQueueMutex);
return mValidationSnapshot->getConfig().FLOOD_TX_PERIOD_MS;
}

size_t getMaxQueueSizeOps() const override;

private:
Expand All @@ -335,6 +354,13 @@ class ClassicTransactionQueue : public TransactionQueue
virtual bool broadcastSome() override;
std::vector<Resource> mBroadcastOpCarryover;
virtual bool allowTxBroadcast(TimestampedTx const& tx) override;

int
getFloodPeriod() const override
{
return mValidationSnapshot->getConfig().FLOOD_TX_PERIOD_MS;
}

};

extern std::array<const char*,
Expand Down

0 comments on commit 7bb9da0

Please sign in to comment.