Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(mempool): data race in mempool prepare proposal handler #21413

Merged
merged 12 commits into from
Sep 4, 2024
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Every module contains its own CHANGELOG.md. Please refer to the module you are i
### Bug Fixes

* (baseapp) [#21256](https://github.com/cosmos/cosmos-sdk/pull/21256) Halt height will not commit the block indicated, meaning that if halt-height is set to 10, only blocks until 9 (included) will be committed. This is to go back to the original behavior before a change was introduced in v0.50.0.

* (baseapp) [#21413](https://github.com/cosmos/cosmos-sdk/pull/21413) Fix data race in sdk mempool.

### API Breaking Changes

Expand Down
32 changes: 19 additions & 13 deletions baseapp/abci_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,14 +285,16 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan
return &abci.PrepareProposalResponse{Txs: h.txSelector.SelectedTxs(ctx)}, nil
}

iterator := h.mempool.Select(ctx, req.Txs)
selectedTxsSignersSeqs := make(map[string]uint64)
var selectedTxsNums int
for iterator != nil {
memTx := iterator.Tx()
signerData, err := h.signerExtAdapter.GetSigners(memTx)
var (
err error
selectedTxsNums int
)
h.mempool.SelectBy(ctx, req.Txs, func(memTx sdk.Tx) bool {
var signerData []mempool.SignerData
signerData, err = h.signerExtAdapter.GetSigners(memTx)
if err != nil {
return nil, err
return false
}

// If the signers aren't in selectedTxsSignersSeqs then we haven't seen them before
Expand All @@ -316,24 +318,24 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan
txSignersSeqs[signer.Signer.String()] = signer.Sequence
}
if !shouldAdd {
iterator = iterator.Next()
continue
return true
}

// NOTE: Since transaction verification was already executed in CheckTx,
// which calls mempool.Insert, in theory everything in the pool should be
// valid. But some mempool implementations may insert invalid txs, so we
// check again.
txBz, err := h.txVerifier.PrepareProposalVerifyTx(memTx)
var txBz []byte
txBz, err = h.txVerifier.PrepareProposalVerifyTx(memTx)
if err != nil {
err := h.mempool.Remove(memTx)
err = h.mempool.Remove(memTx)
if err != nil && !errors.Is(err, mempool.ErrTxNotFound) {
return nil, err
return false
}
} else {
stop := h.txSelector.SelectTxForProposal(ctx, uint64(req.MaxTxBytes), maxBlockGas, memTx, txBz)
if stop {
break
return false
}

txsLen := len(h.txSelector.SelectedTxs(ctx))
Expand All @@ -354,7 +356,11 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan
selectedTxsNums = txsLen
}

iterator = iterator.Next()
return true
})

if err != nil {
return nil, err
}

return &abci.PrepareProposalResponse{Txs: h.txSelector.SelectedTxs(ctx)}, nil
Expand Down
6 changes: 4 additions & 2 deletions types/mempool/mempool.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ type Mempool interface {
Insert(context.Context, sdk.Tx) error

// Select returns an Iterator over the app-side mempool. If txs are specified,
// then they shall be incorporated into the Iterator. The Iterator must be
// closed by the caller.
// then they shall be incorporated into the Iterator. The Iterator is not thread-safe to use.
Select(context.Context, [][]byte) Iterator

// SelectBy use callback to iterate over the mempool, it's thread-safe to use.
SelectBy(context.Context, [][]byte, func(sdk.Tx) bool)

// CountTx returns the number of transactions currently in the mempool.
CountTx() int

Expand Down
9 changes: 5 additions & 4 deletions types/mempool/noop.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ var _ Mempool = (*NoOpMempool)(nil)
// is FIFO-ordered by default.
type NoOpMempool struct{}

func (NoOpMempool) Insert(context.Context, sdk.Tx) error { return nil }
func (NoOpMempool) Select(context.Context, [][]byte) Iterator { return nil }
func (NoOpMempool) CountTx() int { return 0 }
func (NoOpMempool) Remove(sdk.Tx) error { return nil }
func (NoOpMempool) Insert(context.Context, sdk.Tx) error { return nil }
func (NoOpMempool) Select(context.Context, [][]byte) Iterator { return nil }
func (NoOpMempool) SelectBy(context.Context, [][]byte, func(sdk.Tx) bool) {}
func (NoOpMempool) CountTx() int { return 0 }
func (NoOpMempool) Remove(sdk.Tx) error { return nil }
17 changes: 16 additions & 1 deletion types/mempool/priority_nonce.go
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,13 @@ func (i *PriorityNonceIterator[C]) Tx() sdk.Tx {
//
// NOTE: It is not safe to use this iterator while removing transactions from
// the underlying mempool.
func (mp *PriorityNonceMempool[C]) Select(_ context.Context, _ [][]byte) Iterator {
func (mp *PriorityNonceMempool[C]) Select(ctx context.Context, txs [][]byte) Iterator {
mp.mtx.Lock()
defer mp.mtx.Unlock()
return mp.doSelect(ctx, txs)
}

func (mp *PriorityNonceMempool[C]) doSelect(_ context.Context, _ [][]byte) Iterator {
if mp.priorityIndex.Len() == 0 {
return nil
}
Expand All @@ -368,6 +372,17 @@ func (mp *PriorityNonceMempool[C]) Select(_ context.Context, _ [][]byte) Iterato
return iterator.iteratePriority()
}

// SelectBy will hold the mutex during the iteration, callback returns if continue.
func (mp *PriorityNonceMempool[C]) SelectBy(ctx context.Context, txs [][]byte, callback func(sdk.Tx) bool) {
mp.mtx.Lock()
defer mp.mtx.Unlock()

iter := mp.doSelect(ctx, txs)
for iter != nil && callback(iter.Tx()) {
iter = iter.Next()
}
}

type reorderKey[C comparable] struct {
deleteKey txMeta[C]
insertKey txMeta[C]
Expand Down
85 changes: 85 additions & 0 deletions types/mempool/priority_nonce_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package mempool_test

import (
"context"
"fmt"
"math"
"math/rand"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -395,6 +397,89 @@ func (s *MempoolTestSuite) TestIterator() {
}
}

func (s *MempoolTestSuite) TestIteratorConcurrency() {
t := s.T()
ctx := sdk.NewContext(nil, false, log.NewNopLogger())
accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 2)
sa := accounts[0].Address
sb := accounts[1].Address

tests := []struct {
txs []txSpec
fail bool
}{
{
txs: []txSpec{
{p: 20, n: 1, a: sa},
{p: 15, n: 1, a: sb},
{p: 6, n: 2, a: sa},
{p: 21, n: 4, a: sa},
{p: 8, n: 2, a: sb},
},
},
{
txs: []txSpec{
{p: 20, n: 1, a: sa},
{p: 15, n: 1, a: sb},
{p: 6, n: 2, a: sa},
{p: 21, n: 4, a: sa},
{p: math.MinInt64, n: 2, a: sb},
},
},
}

for i, tt := range tests {
t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) {
pool := mempool.DefaultPriorityMempool()

// create test txs and insert into mempool
for i, ts := range tt.txs {
tx := testTx{id: i, priority: int64(ts.p), nonce: uint64(ts.n), address: ts.a}
c := ctx.WithPriority(tx.priority)
err := pool.Insert(c, tx)
require.NoError(t, err)
}

// iterate through txs
stdCtx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()

id := len(tt.txs)
for {
select {
case <-stdCtx.Done():
return
default:
id++
tx := testTx{id: id, priority: int64(rand.Intn(100)), nonce: uint64(id), address: sa}
c := ctx.WithPriority(tx.priority)
err := pool.Insert(c, tx)
require.NoError(t, err)
}
}
}()

var i int
pool.SelectBy(ctx, nil, func(memTx sdk.Tx) bool {
tx := memTx.(testTx)
if tx.id < len(tt.txs) {
require.Equal(t, tt.txs[tx.id].p, int(tx.priority))
require.Equal(t, tt.txs[tx.id].n, int(tx.nonce))
require.Equal(t, tt.txs[tx.id].a, tx.address)
i++
}
return i < len(tt.txs)
})
require.Equal(t, i, len(tt.txs))
cancel()
wg.Wait()
})
}
}

func (s *MempoolTestSuite) TestPriorityTies() {
ctx := sdk.NewContext(nil, false, log.NewNopLogger())
accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 3)
Expand Down
17 changes: 16 additions & 1 deletion types/mempool/sender_nonce.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,13 @@ func (snm *SenderNonceMempool) Insert(_ context.Context, tx sdk.Tx) error {
//
// NOTE: It is not safe to use this iterator while removing transactions from
// the underlying mempool.
func (snm *SenderNonceMempool) Select(_ context.Context, _ [][]byte) Iterator {
func (snm *SenderNonceMempool) Select(ctx context.Context, txs [][]byte) Iterator {
snm.mtx.Lock()
defer snm.mtx.Unlock()
return snm.doSelect(ctx, txs)
}

func (snm *SenderNonceMempool) doSelect(_ context.Context, _ [][]byte) Iterator {
var senders []string

senderCursors := make(map[string]*skiplist.Element)
Expand Down Expand Up @@ -189,6 +193,17 @@ func (snm *SenderNonceMempool) Select(_ context.Context, _ [][]byte) Iterator {
return iter.Next()
}

// SelectBy will hold the mutex during the iteration, callback returns if continue.
func (snm *SenderNonceMempool) SelectBy(ctx context.Context, txs [][]byte, callback func(sdk.Tx) bool) {
snm.mtx.Lock()
defer snm.mtx.Unlock()

iter := snm.doSelect(ctx, txs)
for iter != nil && callback(iter.Tx()) {
iter = iter.Next()
}
}

// CountTx returns the total count of txs in the mempool.
func (snm *SenderNonceMempool) CountTx() int {
snm.mtx.Lock()
Expand Down
Loading