diff --git a/turbo/rpchelper/filters.go b/turbo/rpchelper/filters.go index 2057150c493..28f508c49b3 100644 --- a/turbo/rpchelper/filters.go +++ b/turbo/rpchelper/filters.go @@ -44,7 +44,6 @@ type Filters struct { logsRequestor atomic.Value onNewSnapshot func() - storeMu sync.Mutex logsStores *concurrent.SyncMap[LogsSubID, []*types.Log] pendingHeadsStores *concurrent.SyncMap[HeadsSubID, []*types.Header] pendingTxsStores *concurrent.SyncMap[PendingTxsSubID, [][]types.Transaction] @@ -647,8 +646,20 @@ func (ff *Filters) AddLogs(id LogsSubID, log *types.Log) { maxLogs := ff.config.RpcSubscriptionFiltersMaxLogs if maxLogs > 0 && len(st)+1 > maxLogs { - st = st[len(st)+1-maxLogs:] // Remove oldest logs to make space + // Calculate the number of logs to remove + excessLogs := len(st) + 1 - maxLogs + if excessLogs > 0 { + if excessLogs >= len(st) { + // If excessLogs is greater than or equal to the length of st, remove all + st = []*types.Log{} + } else { + // Otherwise, remove the oldest logs + st = st[excessLogs:] + } + } } + + // Append the new log st = append(st, log) return st }) @@ -672,9 +683,21 @@ func (ff *Filters) AddPendingBlock(id HeadsSubID, block *types.Header) { } maxHeaders := ff.config.RpcSubscriptionFiltersMaxHeaders - if maxHeaders > 0 && len(st) >= maxHeaders { - st = st[1:] // Remove the oldest header to make space + if maxHeaders > 0 && len(st)+1 > maxHeaders { + // Calculate the number of headers to remove + excessHeaders := len(st) + 1 - maxHeaders + if excessHeaders > 0 { + if excessHeaders >= len(st) { + // If excessHeaders is greater than or equal to the length of st, remove all + st = []*types.Header{} + } else { + // Otherwise, remove the oldest headers + st = st[excessHeaders:] + } + } } + + // Append the new header st = append(st, block) return st }) @@ -712,9 +735,16 @@ func (ff *Filters) AddPendingTxs(id PendingTxsSubID, txs []types.Transaction) { flatSt = append(flatSt, txBatch...) } - // Remove the oldest transactions to make space for new ones - if len(flatSt)+len(txs) > maxTxs { - flatSt = flatSt[len(flatSt)+len(txs)-maxTxs:] + // Calculate how many transactions need to be removed + excessTxs := len(flatSt) + len(txs) - maxTxs + if excessTxs > 0 { + if excessTxs >= len(flatSt) { + // If excessTxs is greater than or equal to the length of flatSt, remove all + flatSt = []types.Transaction{} + } else { + // Otherwise, remove the oldest transactions + flatSt = flatSt[excessTxs:] + } } // Convert flatSt back to [][]types.Transaction with a single batch diff --git a/turbo/rpchelper/filters_test.go b/turbo/rpchelper/filters_test.go index 593d2662c7e..3be33ff7f9b 100644 --- a/turbo/rpchelper/filters_test.go +++ b/turbo/rpchelper/filters_test.go @@ -356,131 +356,136 @@ func TestFilters_SubscribeLogsGeneratesCorrectLogFilterRequest(t *testing.T) { } func TestFilters_AddLogs(t *testing.T) { - config := FiltersConfig{RpcSubscriptionFiltersMaxLogs: 5} - f := New(context.TODO(), config, nil, nil, nil, func() {}, log.New()) - logID := LogsSubID("test-log") - logEntry := &types.Log{} - - // Add 10 logs to the store, but limit is 5 - for i := 0; i < 10; i++ { - f.AddLogs(logID, logEntry) - } - - logs, found := f.ReadLogs(logID) - if !found { - t.Error("expected to find logs in the store") - } - if len(logs) != 5 { - t.Errorf("expected 5 logs in the store, got %d", len(logs)) - } -} - -func TestFilters_AddLogs_Unlimited(t *testing.T) { - config := FiltersConfig{RpcSubscriptionFiltersMaxLogs: 0} - f := New(context.TODO(), config, nil, nil, nil, func() {}, log.New()) - logID := LogsSubID("test-log") - logEntry := &types.Log{} - - // Add 10 logs to the store, limit is unlimited - for i := 0; i < 10; i++ { - f.AddLogs(logID, logEntry) - } - - logs, found := f.ReadLogs(logID) - if !found { - t.Error("expected to find logs in the store") - } - if len(logs) != 10 { - t.Errorf("expected 10 logs in the store, got %d", len(logs)) + tests := []struct { + name string + maxLogs int + numToAdd int + expectedLen int + }{ + {"WithinLimit", 5, 5, 5}, + {"ExceedingLimit", 2, 3, 2}, + {"UnlimitedLogs", 0, 10, 10}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := FiltersConfig{RpcSubscriptionFiltersMaxLogs: tt.maxLogs} + f := New(context.TODO(), config, nil, nil, nil, func() {}, log.New()) + logID := LogsSubID("test-log") + logEntry := &types.Log{Address: libcommon.HexToAddress("095e7baea6a6c7c4c2dfeb977efac326af552d87")} + + for i := 0; i < tt.numToAdd; i++ { + f.AddLogs(logID, logEntry) + } + + logs, found := f.logsStores.Get(logID) + if !found { + t.Fatal("Expected to find logs in the store") + } + if len(logs) != tt.expectedLen { + t.Fatalf("Expected %d logs, but got %d", tt.expectedLen, len(logs)) + } + }) } } func TestFilters_AddPendingBlocks(t *testing.T) { - config := FiltersConfig{RpcSubscriptionFiltersMaxHeaders: 3} - f := New(context.TODO(), config, nil, nil, nil, func() {}, log.New()) - headerID := HeadsSubID("test-header") - header := &types.Header{} - - // Add 5 headers to the store, but limit is 3 - for i := 0; i < 5; i++ { - f.AddPendingBlock(headerID, header) - } - - headers, found := f.ReadPendingBlocks(headerID) - if !found { - t.Error("expected to find headers in the store") - } - if len(headers) != 3 { - t.Errorf("expected 3 headers in the store, got %d", len(headers)) - } -} - -func TestFilters_AddPendingBlocks_Unlimited(t *testing.T) { - config := FiltersConfig{RpcSubscriptionFiltersMaxHeaders: 0} - f := New(context.TODO(), config, nil, nil, nil, func() {}, log.New()) - headerID := HeadsSubID("test-header") - header := &types.Header{} - - // Add 5 headers to the store, limit is unlimited - for i := 0; i < 5; i++ { - f.AddPendingBlock(headerID, header) - } - - headers, found := f.ReadPendingBlocks(headerID) - if !found { - t.Error("expected to find headers in the store") - } - if len(headers) != 5 { - t.Errorf("expected 5 headers in the store, got %d", len(headers)) + tests := []struct { + name string + maxHeaders int + numToAdd int + expectedLen int + }{ + {"WithinLimit", 3, 3, 3}, + {"ExceedingLimit", 2, 5, 2}, + {"UnlimitedHeaders", 0, 10, 10}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := FiltersConfig{RpcSubscriptionFiltersMaxHeaders: tt.maxHeaders} + f := New(context.TODO(), config, nil, nil, nil, func() {}, log.New()) + blockID := HeadsSubID("test-block") + header := &types.Header{} + + for i := 0; i < tt.numToAdd; i++ { + f.AddPendingBlock(blockID, header) + } + + blocks, found := f.pendingHeadsStores.Get(blockID) + if !found { + t.Fatal("Expected to find blocks in the store") + } + if len(blocks) != tt.expectedLen { + t.Fatalf("Expected %d blocks, but got %d", tt.expectedLen, len(blocks)) + } + }) } } func TestFilters_AddPendingTxs(t *testing.T) { - config := FiltersConfig{RpcSubscriptionFiltersMaxTxs: 4} - f := New(context.TODO(), config, nil, nil, nil, func() {}, log.New()) - txID := PendingTxsSubID("test-tx") - var tx types.Transaction = types.NewTransaction(0, libcommon.HexToAddress("095e7baea6a6c7c4c2dfeb977efac326af552d87"), uint256.NewInt(10), 50000, uint256.NewInt(10), nil) - tx, _ = tx.WithSignature(*types.LatestSignerForChainID(nil), libcommon.Hex2Bytes("9bea4c4daac7c7c52e093e6a4c35dbbcf8856f1af7b059ba20253e70848d094f8a8fae537ce25ed8cb5af9adac3f141af69bd515bd2ba031522df09b97dd72b100")) - - // Add 6 txs to the store, but limit is 4 - for i := 0; i < 6; i++ { - f.AddPendingTxs(txID, []types.Transaction{tx}) - } - - txs, found := f.ReadPendingTxs(txID) - if !found { - t.Error("expected to find txs in the store") - } - totalTxs := 0 - for _, batch := range txs { - totalTxs += len(batch) - } - if totalTxs != 4 { - t.Errorf("expected 4 txs in the store, got %d", totalTxs) - } -} - -func TestFilters_AddPendingTxs_Unlimited(t *testing.T) { - config := FiltersConfig{RpcSubscriptionFiltersMaxTxs: 0} - f := New(context.TODO(), config, nil, nil, nil, func() {}, log.New()) - txID := PendingTxsSubID("test-tx") - var tx types.Transaction = types.NewTransaction(0, libcommon.HexToAddress("095e7baea6a6c7c4c2dfeb977efac326af552d87"), uint256.NewInt(10), 50000, uint256.NewInt(10), nil) - tx, _ = tx.WithSignature(*types.LatestSignerForChainID(nil), libcommon.Hex2Bytes("9bea4c4daac7c7c52e093e6a4c35dbbcf8856f1af7b059ba20253e70848d094f8a8fae537ce25ed8cb5af9adac3f141af69bd515bd2ba031522df09b97dd72b100")) - - // Add 6 txs to the store, limit is unlimited - for i := 0; i < 6; i++ { - f.AddPendingTxs(txID, []types.Transaction{tx}) - } - - txs, found := f.ReadPendingTxs(txID) - if !found { - t.Error("expected to find txs in the store") - } - totalTxs := 0 - for _, batch := range txs { - totalTxs += len(batch) - } - if totalTxs != 6 { - t.Errorf("expected 6 txs in the store, got %d", totalTxs) + tests := []struct { + name string + maxTxs int + numToAdd int + expectedLen int + }{ + {"WithinLimit", 5, 5, 5}, + {"ExceedingLimit", 2, 6, 2}, + {"UnlimitedTxs", 0, 10, 10}, + {"TriggerPanic", 5, 10, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := FiltersConfig{RpcSubscriptionFiltersMaxTxs: tt.maxTxs} + f := New(context.TODO(), config, nil, nil, nil, func() {}, log.New()) + txID := PendingTxsSubID("test-tx") + var tx types.Transaction = types.NewTransaction(0, libcommon.HexToAddress("095e7baea6a6c7c4c2dfeb977efac326af552d87"), uint256.NewInt(10), 50000, uint256.NewInt(10), nil) + tx, _ = tx.WithSignature(*types.LatestSignerForChainID(nil), libcommon.Hex2Bytes("9bea4c4daac7c7c52e093e6a4c35dbbcf8856f1af7b059ba20253e70848d094f8a8fae537ce25ed8cb5af9adac3f141af69bd515bd2ba031522df09b97dd72b100")) + + // Testing for panic + if tt.name == "TriggerPanic" { + defer func() { + if r := recover(); r != nil { + t.Errorf("AddPendingTxs caused a panic: %v", r) + } + }() + + // Add transactions to trigger panic + // Initial batch to set the stage + for i := 0; i < 4; i++ { + f.AddPendingTxs(txID, []types.Transaction{tx}) + } + + // Adding more transactions in smaller increments to ensure the panic + for i := 0; i < 2; i++ { + f.AddPendingTxs(txID, []types.Transaction{tx}) + } + + // Adding another large batch to ensure it exceeds the limit and triggers the panic + largeBatch := make([]types.Transaction, 10) + for i := range largeBatch { + largeBatch[i] = tx + } + f.AddPendingTxs(txID, largeBatch) + } else { + for i := 0; i < tt.numToAdd; i++ { + f.AddPendingTxs(txID, []types.Transaction{tx}) + } + + txs, found := f.ReadPendingTxs(txID) + if !found { + t.Fatal("Expected to find transactions in the store") + } + totalTxs := 0 + for _, batch := range txs { + totalTxs += len(batch) + } + if totalTxs != tt.expectedLen { + t.Fatalf("Expected %d transactions, but got %d", tt.expectedLen, totalTxs) + } + } + }) } }