From fb47a1b274f8f97852b18340a02d7f3c4d724c28 Mon Sep 17 00:00:00 2001
From: Valentin Staykov
Date: Thu, 1 Aug 2024 12:41:27 +0000
Subject: [PATCH 1/2] simplify datastream entry parsing

---
 zk/datastream/client/stream_client.go | 130 ++++---------
 zk/hermez_db/db.go                    |   2 +-
 zk/stages/stage_batches.go            | 259 +++++++++++++-------------
 3 files changed, 170 insertions(+), 221 deletions(-)

diff --git a/zk/datastream/client/stream_client.go b/zk/datastream/client/stream_client.go
index a1a74d896e4..f20cb36534b 100644
--- a/zk/datastream/client/stream_client.go
+++ b/zk/datastream/client/stream_client.go
@@ -45,11 +45,7 @@ type StreamClient struct {
 	progress atomic.Uint64

 	// Channels
-	batchStartChan chan types.BatchStart
-	batchEndChan   chan types.BatchEnd
-	l2BlockChan    chan types.FullL2Block
-	l2TxChan       chan types.L2TransactionProto
-	gerUpdatesChan chan types.GerUpdate // NB: unused from etrog onwards (forkid 7)
+	entryChan chan interface{}

 	// keeps track of the latest fork from the stream to assign to l2 blocks
 	currentFork uint64
@@ -70,17 +66,14 @@ const (
 // server must be in format "url:port"
 func NewClient(ctx context.Context, server string, version int, checkTimeout time.Duration, latestDownloadedForkId uint16) *StreamClient {
 	c := &StreamClient{
-		ctx:            ctx,
-		checkTimeout:   checkTimeout,
-		server:         server,
-		version:        version,
-		streamType:     StSequencer,
-		id:             "",
-		batchStartChan: make(chan types.BatchStart, 100),
-		batchEndChan:   make(chan types.BatchEnd, 100),
-		l2BlockChan:    make(chan types.FullL2Block, 100000),
-		gerUpdatesChan: make(chan types.GerUpdate, 1000),
-		currentFork:    uint64(latestDownloadedForkId),
+		ctx:          ctx,
+		checkTimeout: checkTimeout,
+		server:       server,
+		version:      version,
+		streamType:   StSequencer,
+		id:           "",
+		entryChan:    make(chan interface{}, 100000),
+		currentFork:  uint64(latestDownloadedForkId),
 	}

 	return c
@@ -90,20 +83,8 @@ func (c *StreamClient) IsVersion3() bool {
 	return c.version >= versionAddedBlockEnd
 }

-func (c *StreamClient) GetBatchStartChan() chan types.BatchStart {
-	return c.batchStartChan
-}
-func (c *StreamClient) GetBatchEndChan() chan types.BatchEnd {
-	return c.batchEndChan
-}
-func (c *StreamClient) GetL2BlockChan() chan types.FullL2Block {
-	return c.l2BlockChan
-}
-func (c *StreamClient) GetL2TxChan() chan types.L2TransactionProto {
-	return c.l2TxChan
-}
-func (c *StreamClient) GetGerUpdatesChan() chan types.GerUpdate {
-	return c.gerUpdatesChan
+func (c *StreamClient) GetEntryChan() chan interface{} {
+	return c.entryChan
 }
 func (c *StreamClient) GetLastWrittenTimeAtomic() *atomic.Int64 {
 	return &c.lastWrittenTime
@@ -132,8 +113,7 @@ func (c *StreamClient) Start() error {

 func (c *StreamClient) Stop() {
 	c.conn.Close()
-	close(c.l2BlockChan)
-	close(c.gerUpdatesChan)
+	close(c.entryChan)
 }

 // Command header: Get status
@@ -319,38 +299,30 @@ LOOP:
 			c.conn.SetReadDeadline(time.Now().Add(c.checkTimeout))
 		}

-		fullBlock, batchStart, batchEnd, gerUpdate, batchBookmark, blockBookmark, localErr := c.readFullBlockProto()
+		parsedProto, localErr := c.readParsedProto()
 		if localErr != nil {
 			err = localErr
 			break
 		}
 		c.lastWrittenTime.Store(time.Now().UnixNano())

-		// skip over bookmarks (but only when fullblock is nil or will miss l2 blocks)
-		if batchBookmark != nil || blockBookmark != nil {
+		switch parsedProto := parsedProto.(type) {
+		case *types.BookmarkProto:
 			continue
-		}
-
-		// write batch starts to channel
-		if batchStart != nil {
-			c.currentFork = (*batchStart).ForkId
-			c.batchStartChan <- *batchStart
-		}
-
-		if gerUpdate != nil {
-			c.gerUpdatesChan <- *gerUpdate
-		}
-
-		if batchEnd != nil {
-			// this check was inside c.readFullBlockProto() but it is better to move it here
-			c.batchEndChan <- *batchEnd
-		}
-
-		// ensure the block is assigned the currently known fork
-		if fullBlock != nil {
-			fullBlock.ForkId = c.currentFork
-			log.Trace("writing block to channel", "blockNumber", fullBlock.L2BlockNumber, "batchNumber", fullBlock.BatchNumber)
-			c.l2BlockChan <- *fullBlock
+		case *types.BatchStart:
+			c.currentFork = parsedProto.ForkId
+			c.entryChan <- parsedProto
+		case *types.GerUpdateProto:
+			c.entryChan <- parsedProto
+		case *types.BatchEnd:
+			c.entryChan <- parsedProto
+		case *types.FullL2Block:
+			parsedProto.ForkId = c.currentFork
+			log.Trace("writing block to channel", "blockNumber", parsedProto.L2BlockNumber, "batchNumber", parsedProto.BatchNumber)
+			c.entryChan <- parsedProto
+		default:
+			err = fmt.Errorf("unexpected entry type: %v", parsedProto)
+			break LOOP
 		}
 	}

@@ -376,13 +348,8 @@ func (c *StreamClient) tryReConnect() error {
 	return err
 }

-func (c *StreamClient) readFullBlockProto() (
-	l2Block *types.FullL2Block,
-	batchStart *types.BatchStart,
-	batchEnd *types.BatchEnd,
-	gerUpdate *types.GerUpdate,
-	batchBookmark *types.BookmarkProto,
-	blockBookmark *types.BookmarkProto,
+func (c *StreamClient) readParsedProto() (
+	parsedEntry interface{},
 	err error,
 ) {
 	file, err := c.readFileEntry()
@@ -393,34 +360,15 @@
 	switch file.EntryType {
 	case types.BookmarkEntryType:
-		var bookmark *types.BookmarkProto
-		if bookmark, err = types.UnmarshalBookmark(file.Data); err != nil {
-			return
-		}
-		if bookmark.BookmarkType() == datastream.BookmarkType_BOOKMARK_TYPE_BATCH {
-			batchBookmark = bookmark
-			return
-		} else {
-			blockBookmark = bookmark
-			return
-		}
+		parsedEntry, err = types.UnmarshalBookmark(file.Data)
 	case types.EntryTypeGerUpdate:
-		if gerUpdate, err = types.DecodeGerUpdateProto(file.Data); err != nil {
-			return
-		}
-		log.Trace("ger update", "ger", gerUpdate)
-		return
+		parsedEntry, err = types.DecodeGerUpdateProto(file.Data)
 	case types.EntryTypeBatchStart:
-		if batchStart, err = types.UnmarshalBatchStart(file.Data); err != nil {
-			return
-		}
-		return
+		parsedEntry, err = types.UnmarshalBatchStart(file.Data)
 	case types.EntryTypeBatchEnd:
-		if batchEnd, err = types.UnmarshalBatchEnd(file.Data); err != nil {
-			return
-		}
-		return
+		parsedEntry, err = types.UnmarshalBatchEnd(file.Data)
 	case types.EntryTypeL2Block:
+		var l2Block *types.FullL2Block
 		if l2Block, err = types.UnmarshalL2Block(file.Data); err != nil {
 			return
 		}
@@ -462,7 +410,7 @@
 				return
 			}
 		} else if innerFile.IsBatchEnd() {
-			if batchEnd, err = types.UnmarshalBatchEnd(file.Data); err != nil {
+			if _, err = types.UnmarshalBatchEnd(file.Data); err != nil {
 				return
 			}
 			break LOOP
@@ -473,14 +421,14 @@
 		}

 		l2Block.L2Txs = txs
+		parsedEntry = l2Block
 		return
 	case types.EntryTypeL2Tx:
 		err = fmt.Errorf("unexpected l2Tx out of block")
-		return
 	default:
 		err = fmt.Errorf("unexpected entry type: %d", file.EntryType)
-		return
 	}
+	return
 }

 // reads file bytes from socket and tries to parse them
diff --git a/zk/hermez_db/db.go b/zk/hermez_db/db.go
index 2e263e4be71..64a68cb9da6 100644
--- a/zk/hermez_db/db.go
+++ b/zk/hermez_db/db.go
@@ -751,7 +751,7 @@ func (db *HermezDbReader) GetBlockL1BlockHashes(fromBlockNo, toBlockNo uint64) (
 	return l1BlockHashes, nil
 }

-func (db *HermezDb) WriteBatchGlobalExitRoot(batchNumber uint64, ger dstypes.GerUpdate) error {
+func (db *HermezDb) WriteBatchGlobalExitRoot(batchNumber uint64, ger *dstypes.GerUpdate) error {
 	return db.tx.Put(GLOBAL_EXIT_ROOTS_BATCHES, Uint64ToBytes(batchNumber), ger.EncodeToBytes())
 }

diff --git a/zk/stages/stage_batches.go b/zk/stages/stage_batches.go
index c421bc18655..d32980aff13 100644
--- a/zk/stages/stage_batches.go
+++ b/zk/stages/stage_batches.go
@@ -64,7 +64,7 @@ type HermezDb interface {
 	DeleteL1BlockHashes(l1BlockHashes *[]common.Hash) error
 	WriteGerForL1BlockHash(l1BlockHash, ger common.Hash) error
 	DeleteL1BlockHashGers(l1BlockHashes *[]common.Hash) error
-	WriteBatchGlobalExitRoot(batchNumber uint64, ger types.GerUpdate) error
+	WriteBatchGlobalExitRoot(batchNumber uint64, ger *types.GerUpdate) error
 	WriteIntermediateTxStateRoot(l2BlockNumber uint64, txHash common.Hash, rpcRoot common.Hash) error
 	WriteBlockL1InfoTreeIndex(blockNumber uint64, l1Index uint64) error
 	WriteLatestUsedGer(batchNo uint64, ger common.Hash) error
@@ -73,11 +73,7 @@

 type DatastreamClient interface {
 	ReadAllEntriesToChannel() error
-	GetL2BlockChan() chan types.FullL2Block
-	GetL2TxChan() chan types.L2TransactionProto
-	GetBatchStartChan() chan types.BatchStart
-	GetBatchEndChan() chan types.BatchEnd
-	GetGerUpdatesChan() chan types.GerUpdate
+	GetEntryChan() chan interface{}
 	GetLastWrittenTimeAtomic() *atomic.Int64
 	GetStreamingAtomic() *atomic.Bool
 	GetProgressAtomic() *atomic.Uint64
@@ -215,10 +211,7 @@ func SpawnStageBatches(

 	log.Info(fmt.Sprintf("[%s] Reading blocks from the datastream.", logPrefix))

-	l2BlockChan := cfg.dsClient.GetL2BlockChan()
-	batchStartChan := cfg.dsClient.GetBatchStartChan()
-	batchEndChan := cfg.dsClient.GetBatchEndChan()
-	gerUpdateChan := cfg.dsClient.GetGerUpdatesChan()
+	entryChan := cfg.dsClient.GetEntryChan()
 	lastWrittenTimeAtomic := cfg.dsClient.GetLastWrittenTimeAtomic()
 	streamingAtomic := cfg.dsClient.GetStreamingAtomic()

@@ -231,149 +224,149 @@ LOOP:
 		// if download routine finished, should continue to read from channel until it's empty
 		// if both download routine stopped and channel empty - stop loop
 		select {
-		case batchStart := <-batchStartChan:
-			// check if the batch is invalid so that we can replicate this over in the stream
-			// when we re-populate it
-			if batchStart.BatchType == types.BatchTypeInvalid {
-				if err = hermezDb.WriteInvalidBatch(batchStart.Number); err != nil {
-					return err
+		case entry := <-entryChan:
+			switch entry := entry.(type) {
+			case *types.BatchStart:
+				// check if the batch is invalid so that we can replicate this over in the stream
+				// when we re-populate it
+				if entry.BatchType == types.BatchTypeInvalid {
+					if err = hermezDb.WriteInvalidBatch(entry.Number); err != nil {
+						return err
+					}
+					// we need to write the fork here as well because the batch will never get processed as it is invalid
+					// but, we need it re-populate our own stream
+					if err = hermezDb.WriteForkId(entry.Number, entry.ForkId); err != nil {
+						return err
+					}
 				}
-				// we need to write the fork here as well because the batch will never get processed as it is invalid
-				// but, we need it re-populate our own stream
-				if err = hermezDb.WriteForkId(batchStart.Number, batchStart.ForkId); err != nil {
-					return err
+			case *types.BatchEnd:
+				if err := writeBatchEnd(hermezDb, entry); err != nil {
+					return fmt.Errorf("write batch end error: %v", err)
 				}
-			}
-			_ = batchStart
-		case batchEnd := <-batchEndChan:
-			if batchEnd.LocalExitRoot != emptyHash {
-				if err := hermezDb.WriteLocalExitRootForBatchNo(batchEnd.Number, batchEnd.LocalExitRoot); err != nil {
-					return fmt.Errorf("write local exit root for l1 block hash error: %v", err)
+			case *types.FullL2Block:
+				if cfg.zkCfg.SyncLimit > 0 && entry.L2BlockNumber >= cfg.zkCfg.SyncLimit {
+					// stop the node going into a crazy loop
+					time.Sleep(2 * time.Second)
+					break LOOP
 				}
-			}
-		case l2Block := <-l2BlockChan:
-			if cfg.zkCfg.SyncLimit > 0 && l2Block.L2BlockNumber >= cfg.zkCfg.SyncLimit {
-				// stop the node going into a crazy loop
-				time.Sleep(2 * time.Second)
-				break LOOP
-			}
-			// handle batch boundary changes - we do this here instead of reading the batch start channel because
-			// channels can be read in random orders which then creates problems in detecting fork changes during
-			// execution
-			if l2Block.BatchNumber > highestSeenBatchNo && lastForkId < l2Block.ForkId {
-				if l2Block.ForkId > HIGHEST_KNOWN_FORK {
-					message := fmt.Sprintf("unsupported fork id %v received from the data stream", l2Block.ForkId)
-					panic(message)
+				// handle batch boundary changes - we do this here instead of reading the batch start channel because
+				// channels can be read in random orders which then creates problems in detecting fork changes during
+				// execution
+				if entry.BatchNumber > highestSeenBatchNo && lastForkId < entry.ForkId {
+					if entry.ForkId > HIGHEST_KNOWN_FORK {
+						message := fmt.Sprintf("unsupported fork id %v received from the data stream", entry.ForkId)
+						panic(message)
+					}
+					err = stages.SaveStageProgress(tx, stages.ForkId, entry.ForkId)
+					if err != nil {
+						return fmt.Errorf("save stage progress error: %v", err)
+					}
+					lastForkId = entry.ForkId
+					err = hermezDb.WriteForkId(entry.BatchNumber, entry.ForkId)
+					if err != nil {
+						return fmt.Errorf("write fork id error: %v", err)
+					}
+					// NOTE (RPC): avoided use of 'writeForkIdBlockOnce' by reading instead batch by forkId, and then lowest block number in batch
 				}
-				err = stages.SaveStageProgress(tx, stages.ForkId, l2Block.ForkId)
-				if err != nil {
-					return fmt.Errorf("save stage progress error: %v", err)
+
+				// ignore genesis or a repeat of the last block
+				if entry.L2BlockNumber == 0 {
+					continue
 				}
-				lastForkId = l2Block.ForkId
-				err = hermezDb.WriteForkId(l2Block.BatchNumber, l2Block.ForkId)
-				if err != nil {
-					return fmt.Errorf("write fork id error: %v", err)
+				// skip but warn on already processed blocks
+				if entry.L2BlockNumber <= stageProgressBlockNo {
+					if entry.L2BlockNumber < stageProgressBlockNo {
+						// only warn if the block is very old, we expect the very latest block to be requested
+						// when the stage is fired up for the first time
+						log.Warn(fmt.Sprintf("[%s] Skipping block %d, already processed", logPrefix, entry.L2BlockNumber))
+					}
+					continue
 				}
-				// NOTE (RPC): avoided use of 'writeForkIdBlockOnce' by reading instead batch by forkId, and then lowest block number in batch
-			}
-			// ignore genesis or a repeat of the last block
-			if l2Block.L2BlockNumber == 0 {
-				continue
-			}
-			// skip but warn on already processed blocks
-			if l2Block.L2BlockNumber <= stageProgressBlockNo {
-				if l2Block.L2BlockNumber < stageProgressBlockNo {
-					// only warn if the block is very old, we expect the very latest block to be requested
-					// when the stage is fired up for the first time
-					log.Warn(fmt.Sprintf("[%s] Skipping block %d, already processed", logPrefix, l2Block.L2BlockNumber))
+				// skip if we already have this block
+				if entry.L2BlockNumber < lastBlockHeight+1 {
+					log.Warn(fmt.Sprintf("[%s] Unwinding to block %d", logPrefix, entry.L2BlockNumber))
+					badBlock, err := eriDb.ReadCanonicalHash(entry.L2BlockNumber)
+					if err != nil {
+						return fmt.Errorf("failed to get bad block: %v", err)
fmt.Errorf("failed to get bad block: %v", err) + } + u.UnwindTo(entry.L2BlockNumber, badBlock) } - continue - } - // skip if we already have this block - if l2Block.L2BlockNumber < lastBlockHeight+1 { - log.Warn(fmt.Sprintf("[%s] Unwinding to block %d", logPrefix, l2Block.L2BlockNumber)) - badBlock, err := eriDb.ReadCanonicalHash(l2Block.L2BlockNumber) - if err != nil { - return fmt.Errorf("failed to get bad block: %v", err) + // check for sequential block numbers + if entry.L2BlockNumber != lastBlockHeight+1 { + return fmt.Errorf("block number is not sequential, expected %d, got %d", lastBlockHeight+1, entry.L2BlockNumber) } - u.UnwindTo(l2Block.L2BlockNumber, badBlock) - } - // check for sequential block numbers - if l2Block.L2BlockNumber != lastBlockHeight+1 { - return fmt.Errorf("block number is not sequential, expected %d, got %d", lastBlockHeight+1, l2Block.L2BlockNumber) - } - - // batch boundary - record the highest hashable block number (last block in last full batch) - if l2Block.BatchNumber > highestSeenBatchNo { - highestHashableL2BlockNo = l2Block.L2BlockNumber - 1 - } - highestSeenBatchNo = l2Block.BatchNumber - - /////// DEBUG BISECTION /////// - // exit stage when debug bisection flags set and we're at the limit block - if cfg.zkCfg.DebugLimit > 0 && l2Block.L2BlockNumber > cfg.zkCfg.DebugLimit { - fmt.Printf("[%s] Debug limit reached, stopping stage\n", logPrefix) - endLoop = true - } + // batch boundary - record the highest hashable block number (last block in last full batch) + if entry.BatchNumber > highestSeenBatchNo { + highestHashableL2BlockNo = entry.L2BlockNumber - 1 + } + highestSeenBatchNo = entry.BatchNumber - // if we're above StepAfter, and we're at a step, move the stages on - if cfg.zkCfg.DebugStep > 0 && cfg.zkCfg.DebugStepAfter > 0 && l2Block.L2BlockNumber > cfg.zkCfg.DebugStepAfter { - if l2Block.L2BlockNumber%cfg.zkCfg.DebugStep == 0 { - fmt.Printf("[%s] Debug step reached, stopping stage\n", logPrefix) + /////// DEBUG BISECTION /////// + // exit stage when debug bisection flags set and we're at the limit block + if cfg.zkCfg.DebugLimit > 0 && entry.L2BlockNumber > cfg.zkCfg.DebugLimit { + fmt.Printf("[%s] Debug limit reached, stopping stage\n", logPrefix) endLoop = true } - } - /////// END DEBUG BISECTION /////// - // store our finalized state if this batch matches the highest verified batch number on the L1 - if l2Block.BatchNumber == highestVerifiedBatch { - rawdb.WriteForkchoiceFinalized(tx, l2Block.L2Blockhash) - } + // if we're above StepAfter, and we're at a step, move the stages on + if cfg.zkCfg.DebugStep > 0 && cfg.zkCfg.DebugStepAfter > 0 && entry.L2BlockNumber > cfg.zkCfg.DebugStepAfter { + if entry.L2BlockNumber%cfg.zkCfg.DebugStep == 0 { + fmt.Printf("[%s] Debug step reached, stopping stage\n", logPrefix) + endLoop = true + } + } + /////// END DEBUG BISECTION /////// - if lastHash != emptyHash { - l2Block.ParentHash = lastHash - } else { - // first block in the loop so read the parent hash - previousHash, err := eriDb.ReadCanonicalHash(l2Block.L2BlockNumber - 1) - if err != nil { - return fmt.Errorf("failed to get genesis header: %v", err) + // store our finalized state if this batch matches the highest verified batch number on the L1 + if entry.BatchNumber == highestVerifiedBatch { + rawdb.WriteForkchoiceFinalized(tx, entry.L2Blockhash) } - l2Block.ParentHash = previousHash - } - if err := writeL2Block(eriDb, hermezDb, &l2Block, highestL1InfoTreeIndex); err != nil { - return fmt.Errorf("writeL2Block error: %v", err) - } - 
dsClientProgress.Store(l2Block.L2BlockNumber) + if lastHash != emptyHash { + entry.ParentHash = lastHash + } else { + // first block in the loop so read the parent hash + previousHash, err := eriDb.ReadCanonicalHash(entry.L2BlockNumber - 1) + if err != nil { + return fmt.Errorf("failed to get genesis header: %v", err) + } + entry.ParentHash = previousHash + } - // make sure to capture the l1 info tree index changes so we can store progress - if uint64(l2Block.L1InfoTreeIndex) > highestL1InfoTreeIndex { - highestL1InfoTreeIndex = uint64(l2Block.L1InfoTreeIndex) - } + if err := writeL2Block(eriDb, hermezDb, entry, highestL1InfoTreeIndex); err != nil { + return fmt.Errorf("writeL2Block error: %v", err) + } + dsClientProgress.Store(entry.L2BlockNumber) - lastHash = l2Block.L2Blockhash + // make sure to capture the l1 info tree index changes so we can store progress + if uint64(entry.L1InfoTreeIndex) > highestL1InfoTreeIndex { + highestL1InfoTreeIndex = uint64(entry.L1InfoTreeIndex) + } - atLeastOneBlockWritten = true - lastBlockHeight = l2Block.L2BlockNumber - blocksWritten++ - progressChan <- blocksWritten + lastHash = entry.L2Blockhash - if endLoop && cfg.zkCfg.DebugLimit > 0 { - break LOOP - } - case gerUpdate := <-gerUpdateChan: - if gerUpdate.GlobalExitRoot == emptyHash { - log.Warn(fmt.Sprintf("[%s] Skipping GER update with empty root", logPrefix)) - break - } + atLeastOneBlockWritten = true + lastBlockHeight = entry.L2BlockNumber + blocksWritten++ + progressChan <- blocksWritten - // NB: we won't get these post Etrog (fork id 7) - if err := hermezDb.WriteBatchGlobalExitRoot(gerUpdate.BatchNumber, gerUpdate); err != nil { - return fmt.Errorf("write batch global exit root error: %v", err) + if endLoop && cfg.zkCfg.DebugLimit > 0 { + break LOOP + } + case *types.GerUpdate: + if entry.GlobalExitRoot == emptyHash { + log.Warn(fmt.Sprintf("[%s] Skipping GER update with empty root", logPrefix)) + break + } + + // NB: we won't get these post Etrog (fork id 7) + if err := hermezDb.WriteBatchGlobalExitRoot(entry.BatchNumber, entry); err != nil { + return fmt.Errorf("write batch global exit root error: %v", err) + } } case <-ctx.Done(): log.Warn(fmt.Sprintf("[%s] Context done", logPrefix)) @@ -761,6 +754,14 @@ func PruneBatchesStage(s *stagedsync.PruneState, tx kv.RwTx, cfg BatchesCfg, ctx return nil } +func writeBatchEnd(hermezDb HermezDb, batchEnd *types.BatchEnd) (err error) { + // utils.CalculateAccInputHash(oldAccInputHash, batchStart., l1InfoRoot common.Hash, timestampLimit uint64, sequencerAddr common.Address, forcedBlockhashL1 common.Hash) + if batchEnd.LocalExitRoot != emptyHash { + err = hermezDb.WriteLocalExitRootForBatchNo(batchEnd.Number, batchEnd.LocalExitRoot) + } + return +} + // writeL2Block writes L2Block to ErigonDb and HermezDb // writes header, body, forkId and blockBatch func writeL2Block(eriDb ErigonDb, hermezDb HermezDb, l2Block *types.FullL2Block, highestL1InfoTreeIndex uint64) error { From 8963aa8d091dfe2872f198bf3c097e70539cedaa Mon Sep 17 00:00:00 2001 From: Valentin Staykov Date: Thu, 1 Aug 2024 13:49:02 +0000 Subject: [PATCH 2/2] fix tests --- .../test_datastream_compare.go | 14 +++--- zk/stages/test_utils.go | 44 +++++-------------- 2 files changed, 19 insertions(+), 39 deletions(-) diff --git a/zk/datastream/test/data_stream_compare/test_datastream_compare.go b/zk/datastream/test/data_stream_compare/test_datastream_compare.go index 9647bc8a8d0..d5093a482c9 100644 --- a/zk/datastream/test/data_stream_compare/test_datastream_compare.go +++ 
@@ -8,6 +8,7 @@ import (
 	"reflect"

 	"github.com/ledgerwatch/erigon/zk/datastream/client"
+	"github.com/ledgerwatch/erigon/zk/datastream/types"
 	"github.com/nsf/jsondiff"
 )

@@ -80,13 +81,14 @@ func readFromClient(client *client.StreamClient, total int) ([]interface{}, erro
 LOOP:
 	for {
-		select {
-		case d := <-client.GetL2BlockChan():
-			data = append(data, d)
-			count++
-		case d := <-client.GetGerUpdatesChan():
-			data = append(data, d)
+		entry := <-client.GetEntryChan()
+
+		switch entry.(type) {
+		case *types.FullL2Block, *types.GerUpdate:
+			// the client sends pointers, so match on pointer types and collect both kinds
+			data = append(data, entry)
 			count++
+		default:
 		}

 		if count == total {
diff --git a/zk/stages/test_utils.go b/zk/stages/test_utils.go
index 531292d4e90..9295da00166 100644
--- a/zk/stages/test_utils.go
+++ b/zk/stages/test_utils.go
@@ -12,22 +12,16 @@ type TestDatastreamClient struct {
 	lastWrittenTimeAtomic atomic.Int64
 	streamingAtomic       atomic.Bool
 	progress              atomic.Uint64
-	l2BlockChan           chan types.FullL2Block
-	l2TxChan              chan types.L2TransactionProto
-	gerUpdatesChan        chan types.GerUpdate
+	entriesChan           chan interface{}
 	errChan               chan error
-	batchStartChan        chan types.BatchStart
-	batchEndChan          chan types.BatchEnd
 }

 func NewTestDatastreamClient(fullL2Blocks []types.FullL2Block, gerUpdates []types.GerUpdate) *TestDatastreamClient {
 	client := &TestDatastreamClient{
-		fullL2Blocks:   fullL2Blocks,
-		gerUpdates:     gerUpdates,
-		l2BlockChan:    make(chan types.FullL2Block, 100),
-		gerUpdatesChan: make(chan types.GerUpdate, 100),
-		errChan:        make(chan error, 100),
-		batchStartChan: make(chan types.BatchStart, 100),
+		fullL2Blocks: fullL2Blocks,
+		gerUpdates:   gerUpdates,
+		entriesChan:  make(chan interface{}, 1000),
+		errChan:      make(chan error, 100),
 	}

 	return client
@@ -36,40 +30,24 @@ func NewTestDatastreamClient(fullL2Blocks []types.FullL2Block, gerUpdates []type
 func (c *TestDatastreamClient) ReadAllEntriesToChannel() error {
 	c.streamingAtomic.Store(true)

-	for _, block := range c.fullL2Blocks {
-		c.l2BlockChan <- block
+	for i := range c.fullL2Blocks {
+		c.entriesChan <- &c.fullL2Blocks[i]
 	}
-	for _, update := range c.gerUpdates {
-		c.gerUpdatesChan <- update
+	for i := range c.gerUpdates {
+		c.entriesChan <- &c.gerUpdates[i]
 	}

 	return nil
 }

-func (c *TestDatastreamClient) GetL2BlockChan() chan types.FullL2Block {
-	return c.l2BlockChan
-}
-
-func (c *TestDatastreamClient) GetL2TxChan() chan types.L2TransactionProto {
-	return c.l2TxChan
-}
-
-func (c *TestDatastreamClient) GetGerUpdatesChan() chan types.GerUpdate {
-	return c.gerUpdatesChan
+func (c *TestDatastreamClient) GetEntryChan() chan interface{} {
+	return c.entriesChan
 }

 func (c *TestDatastreamClient) GetErrChan() chan error {
 	return c.errChan
 }

-func (c *TestDatastreamClient) GetBatchStartChan() chan types.BatchStart {
-	return c.batchStartChan
-}
-
-func (c *TestDatastreamClient) GetBatchEndChan() chan types.BatchEnd {
-	return c.batchEndChan
-}
-
 func (c *TestDatastreamClient) GetLastWrittenTimeAtomic() *atomic.Int64 {
 	return &c.lastWrittenTimeAtomic
 }
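The two patches above replace five typed channels with a single `chan interface{}` that both the real and the test datastream clients expose via `GetEntryChan()`. Below is a minimal, self-contained sketch of that consumption pattern; the struct types are toy stand-ins for the ones in `zk/datastream/types`, trimmed to the fields the pattern needs, so only the shape of the type switch is meant to mirror the patched code.

```go
package main

import "fmt"

// Toy stand-ins for the entry types in zk/datastream/types.
type BatchStart struct{ Number, ForkId uint64 }
type BatchEnd struct{ Number uint64 }
type FullL2Block struct{ L2BlockNumber, BatchNumber, ForkId uint64 }
type GerUpdate struct{ BatchNumber uint64 }

func main() {
	// One buffered channel carries every entry kind, as in the patched client.
	entryChan := make(chan interface{}, 8)

	// The producer side (ReadAllEntriesToChannel) sends pointers, so the
	// consumer must match on pointer types.
	entryChan <- &BatchStart{Number: 1, ForkId: 9}
	entryChan <- &FullL2Block{L2BlockNumber: 100, BatchNumber: 1, ForkId: 9}
	entryChan <- &BatchEnd{Number: 1}
	close(entryChan)

	// Entries now arrive in stream order on a single channel, which avoids the
	// "channels can be read in random orders" problem the old select-based
	// consumer in stage_batches.go had to work around.
	for entry := range entryChan {
		switch e := entry.(type) {
		case *BatchStart:
			fmt.Printf("batch %d starts (fork %d)\n", e.Number, e.ForkId)
		case *FullL2Block:
			fmt.Printf("block %d in batch %d\n", e.L2BlockNumber, e.BatchNumber)
		case *BatchEnd:
			fmt.Printf("batch %d ends\n", e.Number)
		case *GerUpdate:
			fmt.Printf("ger update for batch %d\n", e.BatchNumber)
		default:
			fmt.Printf("unexpected entry type: %T\n", e)
		}
	}
}
```

The trade-off of this design is that the compiler no longer checks what flows through the channel: every consumer needs a `default` arm (the stream client treats an unknown type as a hard error and breaks its read loop), and producer and consumer must agree on pointer versus value sends, which is exactly the mismatch the `readFromClient` switch in patch 2 has to avoid.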