diff --git a/.github/workflows/ci_zkevm.yml b/.github/workflows/ci_zkevm.yml
index 698b871408e..daa2c3efe43 100644
--- a/.github/workflows/ci_zkevm.yml
+++ b/.github/workflows/ci_zkevm.yml
@@ -23,9 +23,9 @@ jobs:
   tests:
     strategy:
       matrix:
-        os: [ ubuntu-22.04, macos-14 ] # list of os: https://github.com/actions/virtual-environments
+        os: [ ubuntu-22.04, macos-14-xlarge ] # list of os: https://github.com/actions/virtual-environments
     runs-on: ${{ matrix.os }}
-    timeout-minutes: ${{ matrix.os == 'macos-14' && 40 || 30 }}
+    timeout-minutes: ${{ matrix.os == 'macos-14-xlarge' && 40 || 30 }}
 
     steps:
       - uses: actions/checkout@v3
diff --git a/Makefile b/Makefile
index 21688ce2e4a..c84c657c03f 100644
--- a/Makefile
+++ b/Makefile
@@ -167,7 +167,7 @@ test-erigon-ext:
 
 ## test: run unit tests with a 100s timeout
 test:
-	$(GOTEST) --timeout 200s
+	$(GOTEST) --timeout 10m
 
 test3:
 	$(GOTEST) --timeout 200s -tags $(BUILD_TAGS),erigon3
diff --git a/cmd/integration/commands/flags.go b/cmd/integration/commands/flags.go
index 0f567cc14e4..43bda19d835 100644
--- a/cmd/integration/commands/flags.go
+++ b/cmd/integration/commands/flags.go
@@ -35,6 +35,7 @@ var (
 	unwindTypes   []string
 	chain         string // Which chain to use (mainnet, goerli, sepolia, etc.)
 	outputCsvFile string
+	config        string
 
 	commitmentMode string
 	commitmentTrie string
@@ -53,7 +54,7 @@ func must(err error) {
 }
 
 func withConfig(cmd *cobra.Command) {
-	cmd.Flags().String("config", "", "yaml/toml config file location")
+	cmd.Flags().StringVar(&config, "config", "", "yaml/toml config file location")
 }
 
 func withMining(cmd *cobra.Command) {
diff --git a/cmd/integration/commands/stage_stages_zkevm.go b/cmd/integration/commands/stage_stages_zkevm.go
index 3a5a7f9dd6d..d6df5d1aaa9 100644
--- a/cmd/integration/commands/stage_stages_zkevm.go
+++ b/cmd/integration/commands/stage_stages_zkevm.go
@@ -9,11 +9,8 @@ import (
 	common2 "github.com/ledgerwatch/erigon-lib/common"
 	"github.com/ledgerwatch/erigon-lib/kv"
 	"github.com/ledgerwatch/erigon-lib/wrap"
-	"github.com/ledgerwatch/erigon/core"
-	"github.com/ledgerwatch/erigon/eth/ethconfig"
 	"github.com/ledgerwatch/erigon/eth/stagedsync/stages"
 	smtdb "github.com/ledgerwatch/erigon/smt/pkg/db"
-	erigoncli "github.com/ledgerwatch/erigon/turbo/cli"
 	"github.com/ledgerwatch/erigon/zk/hermez_db"
 	"github.com/ledgerwatch/log/v3"
 	"github.com/spf13/cobra"
@@ -29,9 +26,6 @@ state_stages_zkevm --datadir=/datadirs/hermez-mainnet --unwind-batch-no=2 --chai
 	Example: "go run ./cmd/integration state_stages_zkevm --config=... --verbosity=3 --unwind-batch-no=100",
 	Run: func(cmd *cobra.Command, args []string) {
 		ctx, _ := common2.RootContext()
-		ethConfig := &ethconfig.Defaults
-		ethConfig.Genesis = core.GenesisBlockByChainName(chain)
-		erigoncli.ApplyFlagsForEthConfigCobra(cmd.Flags(), ethConfig)
 		logger := log.New()
 		db, err := openDB(dbCfg(kv.ChainDB, chaindata), true, logger)
 		if err != nil {
diff --git a/cmd/integration/commands/stages_zkevm.go b/cmd/integration/commands/stages_zkevm.go
index 92332de1d06..cd59fe35224 100644
--- a/cmd/integration/commands/stages_zkevm.go
+++ b/cmd/integration/commands/stages_zkevm.go
@@ -2,6 +2,12 @@ package commands
 
 import (
 	"context"
+	"encoding/json"
+	"math/big"
+	"os"
+	"path"
+	"path/filepath"
+	"strings"
 
 	"github.com/c2h5oh/datasize"
 	chain3 "github.com/ledgerwatch/erigon-lib/chain"
@@ -9,6 +15,7 @@ import (
 	"github.com/ledgerwatch/erigon-lib/kv"
 	"github.com/ledgerwatch/erigon-lib/kv/kvcfg"
 	"github.com/ledgerwatch/erigon/cmd/hack/tool/fromdb"
+	"github.com/ledgerwatch/erigon/cmd/utils"
 	"github.com/ledgerwatch/erigon/consensus"
 	"github.com/ledgerwatch/erigon/core"
 	"github.com/ledgerwatch/erigon/core/types"
@@ -17,6 +24,7 @@ import (
 	"github.com/ledgerwatch/erigon/eth/stagedsync"
 	"github.com/ledgerwatch/erigon/p2p/sentry"
 	"github.com/ledgerwatch/erigon/p2p/sentry/sentry_multi_client"
+	"github.com/ledgerwatch/erigon/params"
 	"github.com/ledgerwatch/erigon/turbo/shards"
 	stages2 "github.com/ledgerwatch/erigon/turbo/stages"
 	"github.com/ledgerwatch/erigon/zk/sequencer"
@@ -29,7 +37,36 @@ func newSyncZk(ctx context.Context, db kv.RwDB) (consensus.Engine, *vm.Config, *
 	vmConfig := &vm.Config{}
 
-	genesis := core.GenesisBlockByChainName(chain)
+	var genesis *types.Genesis
+
+	if strings.HasPrefix(chain, "dynamic") {
+		if config == "" {
+			panic("Config file is required for dynamic chain")
+		}
+
+		params.DynamicChainConfigPath = filepath.Dir(config)
+		genesis = core.GenesisBlockByChainName(chain)
+		filename := path.Join(params.DynamicChainConfigPath, chain+"-conf.json")
+
+		dConf := utils.DynamicConfig{}
+
+		if _, err := os.Stat(filename); err == nil {
+			dConfBytes, err := os.ReadFile(filename)
+			if err != nil {
+				panic(err)
+			}
+			if err := json.Unmarshal(dConfBytes, &dConf); err != nil {
+				panic(err)
+			}
+		}
+
+		genesis.Timestamp = dConf.Timestamp
+		genesis.GasLimit = dConf.GasLimit
+		genesis.Difficulty = big.NewInt(dConf.Difficulty)
+	} else {
+		genesis = core.GenesisBlockByChainName(chain)
+	}
+
 	chainConfig, genesisBlock, genesisErr := core.CommitGenesisBlock(db, genesis, "", log.New())
 	if _, ok := genesisErr.(*chain3.ConfigCompatError); genesisErr != nil && !ok {
 		panic(genesisErr)
diff --git a/cmd/utils/flags.go b/cmd/utils/flags.go
index c3b581b14dc..6f88a533dce 100644
--- a/cmd/utils/flags.go
+++ b/cmd/utils/flags.go
@@ -496,6 +496,11 @@ var (
 		Usage: "First block to start syncing from on the L1",
 		Value: 0,
 	}
+	L1FinalizedBlockRequirementFlag = cli.Uint64Flag{
+		Name:  "zkevm.l1-finalized-block-requirement",
+		Usage: "The given block must be finalized before sequencer L1 sync continues",
+		Value: 0,
+	}
 	L1ContractAddressCheckFlag = cli.BoolFlag{
 		Name:  "zkevm.l1-contract-address-check",
 		Usage: "Check the contract address on the L1",
@@ -541,6 +546,21 @@ var (
 		Usage: "Halt the sequencer on this batch number",
 		Value: 0,
 	}
+	SequencerResequence = cli.BoolFlag{
+		Name:  "zkevm.sequencer-resequence",
+		Usage: "When enabled, the sequencer will automatically resequence unseen batches stored in the data stream",
+		Value: false,
+	}
+	SequencerResequenceStrict = cli.BoolFlag{
"zkevm.sequencer-resequence-strict", + Usage: "Strictly resequence the rolledback batches", + Value: true, + } + SequencerResequenceReuseL1InfoIndex = cli.BoolFlag{ + Name: "zkevm.sequencer-resequence-reuse-l1-info-index", + Usage: "Reuse the L1 info index for resequencing", + Value: true, + } ExecutorUrls = cli.StringFlag{ Name: "zkevm.executor-urls", Usage: "A comma separated list of grpc addresses that host executors", diff --git a/core/state/intra_block_state.go b/core/state/intra_block_state.go index 528dfbe954a..832977cb864 100644 --- a/core/state/intra_block_state.go +++ b/core/state/intra_block_state.go @@ -369,7 +369,7 @@ func (sdb *IntraBlockState) SetCode(addr libcommon.Address, code []byte) { return } - hashedBytecode, _ := utils.HashContractBytecode(hex.EncodeToString(code)) + hashedBytecode := utils.HashContractBytecode(hex.EncodeToString(code)) stateObject.SetCode(libcommon.HexToHash(hashedBytecode), code) } } diff --git a/core/state/trie_db.go b/core/state/trie_db.go index 3657794a198..fb23f799cf8 100644 --- a/core/state/trie_db.go +++ b/core/state/trie_db.go @@ -913,36 +913,16 @@ func (tds *TrieDbState) ResolveSMTRetainList() (*trie.RetainList, error) { for _, addrHash := range accountTouches { addr := common.BytesToAddress(tds.preimageMap[addrHash]).String() - nonceKey, err := utils.KeyEthAddrNonce(addr) - - if err != nil { - return nil, err - } - + nonceKey := utils.KeyEthAddrNonce(addr) keys = append(keys, nonceKey.GetPath()) - balanceKey, err := utils.KeyEthAddrBalance(addr) - - if err != nil { - return nil, err - } - + balanceKey := utils.KeyEthAddrBalance(addr) keys = append(keys, balanceKey.GetPath()) - codeKey, err := utils.KeyContractCode(addr) - - if err != nil { - return nil, err - } - + codeKey := utils.KeyContractCode(addr) keys = append(keys, codeKey.GetPath()) - codeLengthKey, err := utils.KeyContractLength(addr) - - if err != nil { - return nil, err - } - + codeLengthKey := utils.KeyContractLength(addr) keys = append(keys, codeLengthKey.GetPath()) } @@ -950,11 +930,7 @@ func (tds *TrieDbState) ResolveSMTRetainList() (*trie.RetainList, error) { a := utils.ConvertHexToBigInt(ethAddr) addr := utils.ScalarToArrayBig(a) - storageKey, err := utils.KeyContractStorage(addr, key) - - if err != nil { - return nil, err - } + storageKey := utils.KeyContractStorage(addr, key) return storageKey.GetPath(), nil } diff --git a/eth/ethconfig/config_zkevm.go b/eth/ethconfig/config_zkevm.go index 3d171233dc8..6c6bab30bef 100644 --- a/eth/ethconfig/config_zkevm.go +++ b/eth/ethconfig/config_zkevm.go @@ -28,6 +28,7 @@ type Zk struct { L1HighestBlockType string L1MaticContractAddress common.Address L1FirstBlock uint64 + L1FinalizedBlockRequirement uint64 L1CacheEnabled bool L1CachePort uint RpcRateLimits int @@ -38,6 +39,9 @@ type Zk struct { SequencerBatchVerificationTimeout time.Duration SequencerTimeoutOnEmptyTxPool time.Duration SequencerHaltOnBatchNumber uint64 + SequencerResequence bool + SequencerResequenceStrict bool + SequencerResequenceReuseL1InfoIndex bool ExecutorUrls []string ExecutorStrictMode bool ExecutorRequestTimeout time.Duration diff --git a/eth/tracers/logger/json_stream_zkevm.go b/eth/tracers/logger/json_stream_zkevm.go index b7ba1b77baa..3498bd41a5a 100644 --- a/eth/tracers/logger/json_stream_zkevm.go +++ b/eth/tracers/logger/json_stream_zkevm.go @@ -33,6 +33,7 @@ type JsonStreamLogger_ZkEvm struct { counterCollector *vm.CounterCollector stateClosed bool + memSize int } // NewStructLogger returns a new logger @@ -204,36 +205,23 @@ func (l 
*JsonStreamLogger_ZkEvm) writeMemory(memory *vm.Memory) { if !l.cfg.DisableMemory { memData := memory.Data() - //[zkevm] don't print empty bytes in memory array after the last non-empty byte line - filteredByteLines := [][]byte{} - foundValueLine := false - for i := len(memData); i-32 >= 0; i -= 32 { - bytes := memData[i-32 : i] - - isEmpty := true - if !foundValueLine { - for _, b := range bytes { - if b != 0 { - isEmpty = false - foundValueLine = true - break - } - } - } - - if !isEmpty || foundValueLine { - filteredByteLines = append(filteredByteLines, bytes) - } + // on first occurance don't expand memory + // this is because in interpreter we expand the memory before we execute the opcode + // and the state for traced opcode should be before the execution of the opcode + if l.memSize < len(memData) { + size := len(memData) + memData = memData[:l.memSize] + l.memSize = size } l.stream.WriteMore() l.stream.WriteObjectField("memory") l.stream.WriteArrayStart() - for i := len(filteredByteLines) - 1; i >= 0; i-- { - if i != len(filteredByteLines)-1 { + for i := len(memData); i-32 >= 0; i -= 32 { + if i != len(memData) { // first 32 bytes, don't add a comma l.stream.WriteMore() } - l.stream.WriteString(string(l.hexEncodeBuf[0:hex.Encode(l.hexEncodeBuf[:], filteredByteLines[i])])) + l.stream.WriteString(string(l.hexEncodeBuf[0:hex.Encode(l.hexEncodeBuf[:], memData[i-32:i])])) } l.stream.WriteArrayEnd() diff --git a/rpc/types.go b/rpc/types.go index ca54feb8b41..7a6f5759caf 100644 --- a/rpc/types.go +++ b/rpc/types.go @@ -480,3 +480,11 @@ func (ts *Timestamp) UnmarshalJSON(data []byte) error { return nil } + +type ForkInterval struct { + ForkId hexutil.Uint64 `json:"forkId"` + FromBatchNumber hexutil.Uint64 `json:"fromBatchNumber"` + ToBatchNumber hexutil.Uint64 `json:"toBatchNumber"` + Version string `json:"version"` + BlockNumber hexutil.Uint64 `json:"blockNumber"` +} diff --git a/smt/pkg/blockinfo/block_info.go b/smt/pkg/blockinfo/block_info.go index fc3a1f7cdd9..d9deddde26f 100644 --- a/smt/pkg/blockinfo/block_info.go +++ b/smt/pkg/blockinfo/block_info.go @@ -325,11 +325,7 @@ func (b *BlockInfoTree) GenerateBlockTxKeysVals( logToEncode := "0x" + hex.EncodeToString(rLog.Data) + reducedTopics - hash, err := utils.HashContractBytecode(logToEncode) - if err != nil { - return nil, nil, err - } - + hash := utils.HashContractBytecode(logToEncode) logEncodedBig := utils.ConvertHexToBigInt(hash) key, val, err = generateTxLog(txIndexBig, big.NewInt(logIndex), logEncodedBig) if err != nil { diff --git a/smt/pkg/blockinfo/keys.go b/smt/pkg/blockinfo/keys.go index af9cfdef475..3f67fa4b58d 100644 --- a/smt/pkg/blockinfo/keys.go +++ b/smt/pkg/blockinfo/keys.go @@ -55,14 +55,8 @@ func KeyTxLogs(txIndex, logIndex *big.Int) (*utils.NodeKey, error) { return nil, err } - hk0, err := utils.Hash(lia.ToUintArray(), utils.BranchCapacity) - if err != nil { - return nil, err - } - hkRes, err := utils.Hash(key1.ToUintArray(), hk0) - if err != nil { - return nil, err - } + hk0 := utils.Hash(lia.ToUintArray(), utils.BranchCapacity) + hkRes := utils.Hash(key1.ToUintArray(), hk0) return &utils.NodeKey{hkRes[0], hkRes[1], hkRes[2], hkRes[3]}, nil } diff --git a/smt/pkg/db/mem-db.go b/smt/pkg/db/mem-db.go index a3d38be8626..949f267b402 100644 --- a/smt/pkg/db/mem-db.go +++ b/smt/pkg/db/mem-db.go @@ -253,11 +253,7 @@ func (m *MemDb) AddCode(code []byte) error { m.lock.Lock() // Lock for writing defer m.lock.Unlock() // Make sure to unlock when done - codeHash, err := utils.HashContractBytecode(hex.EncodeToString(code)) - 
-	if err != nil {
-		return err
-	}
-
+	codeHash := utils.HashContractBytecode(hex.EncodeToString(code))
 	m.DbCode[codeHash] = code
 	return nil
 }
diff --git a/smt/pkg/smt/entity_storage.go b/smt/pkg/smt/entity_storage.go
index 13c00219b5a..089f41d798a 100644
--- a/smt/pkg/smt/entity_storage.go
+++ b/smt/pkg/smt/entity_storage.go
@@ -15,39 +15,29 @@ import (
 )
 
 func (s *SMT) SetAccountState(ethAddr string, balance, nonce *big.Int) (*big.Int, error) {
-	keyBalance, err := utils.KeyEthAddrBalance(ethAddr)
-	if err != nil {
-		return nil, err
-	}
-	keyNonce, err := utils.KeyEthAddrNonce(ethAddr)
-	if err != nil {
-		return nil, err
-	}
+	keyBalance := utils.KeyEthAddrBalance(ethAddr)
+	keyNonce := utils.KeyEthAddrNonce(ethAddr)
 
-	_, err = s.InsertKA(keyBalance, balance)
-	if err != nil {
+	if _, err := s.InsertKA(keyBalance, balance); err != nil {
 		return nil, err
 	}
 
 	ks := utils.EncodeKeySource(utils.KEY_BALANCE, utils.ConvertHexToAddress(ethAddr), common.Hash{})
-	err = s.Db.InsertKeySource(keyBalance, ks)
-	if err != nil {
+	if err := s.Db.InsertKeySource(keyBalance, ks); err != nil {
 		return nil, err
 	}
 
 	auxRes, err := s.InsertKA(keyNonce, nonce)
-
 	if err != nil {
 		return nil, err
 	}
 
 	ks = utils.EncodeKeySource(utils.KEY_NONCE, utils.ConvertHexToAddress(ethAddr), common.Hash{})
-	err = s.Db.InsertKeySource(keyNonce, ks)
-	if err != nil {
+	if err := s.Db.InsertKeySource(keyNonce, ks); err != nil {
 		return nil, err
 	}
 
-	return auxRes.NewRootScalar.ToBigInt(), err
+	return auxRes.NewRootScalar.ToBigInt(), nil
 }
 
 func (s *SMT) SetAccountStorage(addr libcommon.Address, acc *accounts.Account) error {
@@ -62,14 +52,8 @@ func (s *SMT) SetAccountStorage(addr libcommon.Address, acc *accounts.Account) e
 }
 
 func (s *SMT) SetContractBytecode(ethAddr string, bytecode string) error {
-	keyContractCode, err := utils.KeyContractCode(ethAddr)
-	if err != nil {
-		return err
-	}
-	keyContractLength, err := utils.KeyContractLength(ethAddr)
-	if err != nil {
-		return err
-	}
+	keyContractCode := utils.KeyContractCode(ethAddr)
+	keyContractLength := utils.KeyContractLength(ethAddr)
 
 	bi, bytecodeLength, err := convertBytecodeToBigInt(bytecode)
 	if err != nil {
@@ -205,6 +189,7 @@ func (s *SMT) SetContractStorage(ethAddr string, storage map[string]string, prog
 
 func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[libcommon.Address]*accounts.Account, codeChanges map[libcommon.Address]string, storageChanges map[libcommon.Address]map[string]string) ([]*utils.NodeKey, []*utils.NodeValue8, error) {
 	var isDelete bool
+	var err error
 
 	storageChangesInitialCapacity := 0
 	for _, storage := range storageChanges {
@@ -222,14 +207,8 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l
 		default:
 		}
 		ethAddr := addr.String()
-		keyBalance, err := utils.KeyEthAddrBalance(ethAddr)
-		if err != nil {
-			return nil, nil, err
-		}
-		keyNonce, err := utils.KeyEthAddrNonce(ethAddr)
-		if err != nil {
-			return nil, nil, err
-		}
+		keyBalance := utils.KeyEthAddrBalance(ethAddr)
+		keyNonce := utils.KeyEthAddrNonce(ethAddr)
 
 		balance := big.NewInt(0)
 		nonce := big.NewInt(0)
@@ -276,14 +255,8 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l
 		}
 
 		ethAddr := addr.String()
-		keyContractCode, err := utils.KeyContractCode(ethAddr)
-		if err != nil {
-			return nil, nil, err
-		}
-		keyContractLength, err := utils.KeyContractLength(ethAddr)
-		if err != nil {
-			return nil, nil, err
-		}
+		keyContractCode := utils.KeyContractCode(ethAddr)
+		keyContractLength := utils.KeyContractLength(ethAddr)
 
 		bi, bytecodeLength, err := convertBytecodeToBigInt(code)
 		if err != nil {
@@ -330,11 +303,7 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l
 		ethAddrBigIngArray := utils.ScalarToArrayBig(ethAddrBigInt)
 
 		for k, v := range storage {
-			keyStoragePosition, err := utils.KeyContractStorage(ethAddrBigIngArray, k)
-			if err != nil {
-				return nil, nil, err
-			}
-
+			keyStoragePosition := utils.KeyContractStorage(ethAddrBigIngArray, k)
 			valueBigInt := convertStrintToBigInt(v)
 			keysBatchStorage = append(keysBatchStorage, &keyStoragePosition)
 			if valuesBatchStorage, isDelete, err = appendToValuesBatchStorageBigInt(valuesBatchStorage, valueBigInt); err != nil {
@@ -355,8 +324,11 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l
 	}
 
 	insertBatchCfg := NewInsertBatchConfig(ctx, logPrefix, true)
-	_, err := s.InsertBatch(insertBatchCfg, keysBatchStorage, valuesBatchStorage, nil, nil)
-	return keysBatchStorage, valuesBatchStorage, err
+	if _, err = s.InsertBatch(insertBatchCfg, keysBatchStorage, valuesBatchStorage, nil, nil); err != nil {
+		return nil, nil, err
+	}
+
+	return keysBatchStorage, valuesBatchStorage, nil
 }
 
 func (s *SMT) InsertKeySource(nodeKey *utils.NodeKey, key int, accountAddr *libcommon.Address, storagePosition *libcommon.Hash) error {
@@ -377,10 +349,7 @@ func calcHashVal(v string) (*utils.NodeValue8, [4]uint64, error) {
 		return nil, [4]uint64{}, err
 	}
 
-	h, err := utils.Hash(value.ToUintArray(), utils.BranchCapacity)
-	if err != nil {
-		return nil, [4]uint64{}, err
-	}
+	h := utils.Hash(value.ToUintArray(), utils.BranchCapacity)
 
 	return value, h, nil
 }
@@ -405,12 +374,8 @@ func appendToValuesBatchStorageBigInt(valuesBatchStorage []*utils.NodeValue8, va
 }
 
 func convertBytecodeToBigInt(bytecode string) (*big.Int, int, error) {
-	hashedBytecode, err := utils.HashContractBytecode(bytecode)
-	if err != nil {
-		return nil, 0, err
-	}
-
 	var parsedBytecode string
+	hashedBytecode := utils.HashContractBytecode(bytecode)
 
 	if strings.HasPrefix(bytecode, "0x") {
 		parsedBytecode = bytecode[2:]
diff --git a/smt/pkg/smt/entity_storage_test.go b/smt/pkg/smt/entity_storage_test.go
index ca057a5674e..c45a12fc313 100644
--- a/smt/pkg/smt/entity_storage_test.go
+++ b/smt/pkg/smt/entity_storage_test.go
@@ -194,11 +194,7 @@ func Test_SetContractBytecode_HashBytecode(t *testing.T) {
 	expected := "0x9257c9a31308a7cb046aba1a95679dd7e3ad695b6900e84a6470b401b1ea416e"
 
-	hashedBytecode, err := utils.HashContractBytecode(byteCode)
-	if err != nil {
-		t.Errorf("setContractBytecode failed: %v", err)
-	}
-
+	hashedBytecode := utils.HashContractBytecode(byteCode)
 	if hashedBytecode != expected {
 		t.Errorf("setContractBytecode failed: expected %v, got %v", expected, hashedBytecode)
 	}
diff --git a/smt/pkg/smt/proof.go b/smt/pkg/smt/proof.go
index 1cc7bcad31f..6558385e4a5 100644
--- a/smt/pkg/smt/proof.go
+++ b/smt/pkg/smt/proof.go
@@ -127,11 +127,8 @@ func VerifyAndGetVal(stateRoot utils.NodeKey, proof []hexutility.Bytes, key util
 	}
 
 	path := key.GetPath()
-
 	curRoot := stateRoot
-
 	foundValue := false
-
 	for i := 0; i < len(proof); i++ {
 		isFinalNode := len(proof[i]) == 65
@@ -147,12 +144,7 @@ func VerifyAndGetVal(stateRoot utils.NodeKey, proof []hexutility.Bytes, key util
 		leftChildNode := [4]uint64{leftChild[0], leftChild[1], leftChild[2], leftChild[3]}
 		rightChildNode := [4]uint64{rightChild[0], rightChild[1], rightChild[2], rightChild[3]}
 
-		h, err := utils.Hash(utils.ConcatArrays4(leftChildNode, rightChildNode), capacity)
-
-		if err != nil {
-			return nil, err
-		}
-
+		h := utils.Hash(utils.ConcatArrays4(leftChildNode, rightChildNode), capacity)
 		if curRoot != h {
 			return nil, fmt.Errorf("root mismatch at level %d, expected %d, got %d", i, curRoot, h)
 		}
@@ -193,12 +185,7 @@ func VerifyAndGetVal(stateRoot utils.NodeKey, proof []hexutility.Bytes, key util
 			return nil, err
 		}
 
-		h, err := utils.Hash(nodeValue.ToUintArray(), utils.BranchCapacity)
-
-		if err != nil {
-			return nil, err
-		}
-
+		h := utils.Hash(nodeValue.ToUintArray(), utils.BranchCapacity)
 		if h != curRoot {
 			return nil, fmt.Errorf("root mismatch at level %d, expected %d, got %d", len(proof)-1, curRoot, h)
 		}
diff --git a/smt/pkg/smt/proof_test.go b/smt/pkg/smt/proof_test.go
index 29fe762f718..0c68b3ae612 100644
--- a/smt/pkg/smt/proof_test.go
+++ b/smt/pkg/smt/proof_test.go
@@ -74,12 +74,7 @@ func TestVerifyAndGetVal(t *testing.T) {
 	root := utils.ScalarToRoot(smtRoot)
 
 	t.Run("Value exists and proof is correct", func(t *testing.T) {
-		storageKey, err := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String())
-
-		if err != nil {
-			t.Fatalf("KeyContractStorage() error = %v", err)
-		}
-
+		storageKey := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String())
 		storageProof := smt.FilterProofs(proofs, storageKey)
 
 		val, err := smt.VerifyAndGetVal(root, storageProof, storageKey)
@@ -101,19 +96,12 @@ func TestVerifyAndGetVal(t *testing.T) {
 
 		// Fuzz with 1000 non-existent keys
 		for i := 0; i < 1000; i++ {
-			nonExistentKey, err := utils.KeyContractStorage(
+			nonExistentKey := utils.KeyContractStorage(
 				address,
 				libcommon.HexToHash(fmt.Sprintf("0xdeadbeefabcd1234%d", i)).String(),
 			)
-
 			nonExistentKeys = append(nonExistentKeys, nonExistentKey)
-
-			if err != nil {
-				t.Fatalf("KeyContractStorage() error = %v", err)
-			}
-
 			nonExistentKeyPath := nonExistentKey.GetPath()
-
 			keyBytes := make([]byte, 0, len(nonExistentKeyPath))
 
 			for _, v := range nonExistentKeyPath {
@@ -144,7 +132,7 @@ func TestVerifyAndGetVal(t *testing.T) {
 	t.Run("Value doesn't exist but non-existent proof is insufficient", func(t *testing.T) {
 		nonExistentRl := trie.NewRetainList(0)
-		nonExistentKey, _ := utils.KeyContractStorage(address, libcommon.HexToHash("0x999").String())
+		nonExistentKey := utils.KeyContractStorage(address, libcommon.HexToHash("0x999").String())
 		nonExistentKeyPath := nonExistentKey.GetPath()
 
 		keyBytes := make([]byte, 0, len(nonExistentKeyPath))
@@ -177,7 +165,7 @@ func TestVerifyAndGetVal(t *testing.T) {
 	})
 
 	t.Run("Value exists but proof is incorrect (first value corrupted)", func(t *testing.T) {
-		storageKey, _ := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String())
+		storageKey := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String())
 		storageProof := smt.FilterProofs(proofs, storageKey)
 
 		// Corrupt the proof by changing a byte
@@ -195,7 +183,7 @@ func TestVerifyAndGetVal(t *testing.T) {
 	})
 
 	t.Run("Value exists but proof is incorrect (last value corrupted)", func(t *testing.T) {
-		storageKey, _ := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String())
+		storageKey := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String())
 		storageProof := smt.FilterProofs(proofs, storageKey)
 
 		// Corrupt the proof by changing the last byte of the last proof element
@@ -216,7 +204,7 @@ func TestVerifyAndGetVal(t *testing.T) {
 	})
 
 	t.Run("Value exists but proof is insufficient", func(t *testing.T) {
-		storageKey, _ := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String())
+		storageKey := utils.KeyContractStorage(address, libcommon.HexToHash("0x5").String())
 		storageProof := smt.FilterProofs(proofs, storageKey)
 
 		// Modify the proof to claim the value doesn't exist
diff --git a/smt/pkg/smt/smt.go b/smt/pkg/smt/smt.go
index 0a2733b7103..a978878a20d 100644
--- a/smt/pkg/smt/smt.go
+++ b/smt/pkg/smt/smt.go
@@ -170,11 +170,7 @@ func (s *SMT) InsertStorage(ethAddr string, storage *map[string]string, chm *map
 		NewRootScalar: &or,
 	}
 	for k := range *storage {
-		keyStoragePosition, err := utils.KeyContractStorage(add, k)
-		if err != nil {
-			return nil, err
-		}
-
+		keyStoragePosition := utils.KeyContractStorage(add, k)
 		smtr, err = s.insert(keyStoragePosition, *(*chm)[k], (*vhm)[k], *smtr.NewRootScalar)
 		if err != nil {
 			return nil, err
@@ -538,14 +534,12 @@ func (s *SMT) insert(k utils.NodeKey, v utils.NodeValue8, newValH [4]uint64, old
 }
 
 func prepareHashValueForSave(in [8]uint64, capacity [4]uint64) utils.NodeValue12 {
-	var sl []uint64
-	sl = append(sl, in[:]...)
-	sl = append(sl, capacity[:]...)
-
 	v := utils.NodeValue12{}
-	for i, val := range sl {
-		b := new(big.Int)
-		v[i] = b.SetUint64(val)
+	for i, val := range in {
+		v[i] = new(big.Int).SetUint64(val)
+	}
+	for i, val := range capacity {
+		v[i+8] = new(big.Int).SetUint64(val)
 	}
 
 	return v
@@ -561,20 +555,12 @@ func (s *SMT) hashSave(in [8]uint64, capacity, h [4]uint64) error {
 }
 
 func (s *SMT) hashcalcAndSave(in [8]uint64, capacity [4]uint64) ([4]uint64, error) {
-	h, err := utils.Hash(in, capacity)
-	if err != nil {
-		return [4]uint64{}, err
-	}
-
+	h := utils.Hash(in, capacity)
 	return h, s.hashSave(in, capacity, h)
 }
 
 func hashCalcAndPrepareForSave(in [8]uint64, capacity [4]uint64) ([4]uint64, utils.NodeValue12, error) {
-	h, err := utils.Hash(in, capacity)
-	if err != nil {
-		return [4]uint64{}, utils.NodeValue12{}, err
-	}
-
+	h := utils.Hash(in, capacity)
 	return h, prepareHashValueForSave(in, capacity), nil
 }
diff --git a/smt/pkg/smt/smt_batch.go b/smt/pkg/smt/smt_batch.go
index 4ee0e071f2a..a08b6765994 100644
--- a/smt/pkg/smt/smt_batch.go
+++ b/smt/pkg/smt/smt_batch.go
@@ -58,7 +58,7 @@ func (s *SMT) InsertBatch(cfg InsertBatchConfig, nodeKeys []*utils.NodeKey, node
 	}
 	progressChanPre <- uint64(1)
 
-	if err = calculateNodeValueHashesIfMissing(s, nodeValues, &nodeValuesHashes); err != nil {
+	if err = calculateNodeValueHashesIfMissing(nodeValues, &nodeValuesHashes); err != nil {
 		return nil, err
 	}
 	progressChanPre <- uint64(1)
@@ -313,7 +313,7 @@ func removeDuplicateEntriesByKeys(size *int, nodeKeys *[]*utils.NodeKey, nodeVal
 	return nil
 }
 
-func calculateNodeValueHashesIfMissing(s *SMT, nodeValues []*utils.NodeValue8, nodeValuesHashes *[]*[4]uint64) error {
+func calculateNodeValueHashesIfMissing(nodeValues []*utils.NodeValue8, nodeValuesHashes *[]*[4]uint64) error {
 	var globalError error
 	size := len(nodeValues)
 	cpuNum := parallel.DefaultNumGoroutines()
@@ -353,11 +353,7 @@ func calculateNodeValueHashesIfMissingInInterval(nodeValues []*utils.NodeValue8,
 			continue
 		}
 
-		nodeValueHashObj, err := utils.Hash(nodeValues[i].ToUintArray(), utils.BranchCapacity)
-		if err != nil {
-			return err
-		}
-
+		nodeValueHashObj := utils.Hash(nodeValues[i].ToUintArray(), utils.BranchCapacity)
 		(*nodeValuesHashes)[i] = &nodeValueHashObj
 	}
diff --git a/smt/pkg/smt/smt_batch_advance_test.go b/smt/pkg/smt/smt_batch_advance_test.go
new file mode 100644
index 00000000000..59c09de2324
--- /dev/null
+++ b/smt/pkg/smt/smt_batch_advance_test.go
@@ -0,0 +1,194 @@
+package smt_test
+
+import (
+	"context"
+	"fmt"
+	"math/big"
+	"testing"
+
+	"github.com/ledgerwatch/erigon/smt/pkg/smt"
+	"github.com/ledgerwatch/erigon/smt/pkg/utils"
"gotest.tools/v3/assert" +) + +func TestBatchWitness(t *testing.T) { + keys := []utils.NodeKey{ + utils.NodeKey{17822804428864912231, 4683868963463720294, 2947512351908939790, 2330225637707749973}, + utils.NodeKey{15928606457751385034, 926210564408807848, 3634217732472610234, 18021748560357139965}, + utils.NodeKey{1623861826376204094, 570263533561698889, 4654109133431364496, 7281957057362652730}, + utils.NodeKey{13644513224119225920, 15807577943241006501, 9942496498562648573, 15190659753926523377}, + utils.NodeKey{9275812266666786730, 4204572028245381139, 3605834086260069958, 10007478335141208804}, + utils.NodeKey{8235907590678154663, 6691762687086189695, 15487167600723075149, 10984821506434298343}, + utils.NodeKey{16417603439618455829, 5362127645905990998, 10661203900902368419, 16076124886006448905}, + utils.NodeKey{11707747219427568787, 933117036015558858, 16439357349021750126, 14064521656451211675}, + utils.NodeKey{10768458483543229763, 12393104588695647110, 7306859896719697582, 4178785141502415085}, + utils.NodeKey{7512520260500009967, 3751662918911081259, 9113133324668552163, 12072005766952080289}, + utils.NodeKey{9944065905482556519, 8594459084728876791, 17786637052462706859, 15521772847998069525}, + utils.NodeKey{5036431633232956882, 16658186702978753823, 2870215478624537606, 11907126160741124846}, + utils.NodeKey{17938814940856978076, 13147879352039549979, 1303554763666506875, 14953772317105337015}, + utils.NodeKey{17398863357602626404, 4841219907503399295, 2992012704517273588, 16471435007473943078}, + utils.NodeKey{4763654225644445738, 5354841943603308259, 16476366216865814029, 10492509060169249179}, + utils.NodeKey{3554925909441560661, 16583852156861238748, 15693104712527552035, 8799937559790156794}, + utils.NodeKey{9617343367546549815, 6562355304138083186, 4016039301486039807, 10864657160754550133}, + utils.NodeKey{17933907347870222658, 16190350511466382228, 13330881818854499962, 1410294862891786839}, + utils.NodeKey{17260204906255015513, 15380909239227623493, 8567606678138088594, 4899143890802672405}, + utils.NodeKey{12539511585850227228, 3973200204826286539, 8108069613182344498, 11385621942985713904}, + utils.NodeKey{5984161349947667925, 7514232801604484380, 16331057190188025237, 2178913139230121631}, + utils.NodeKey{1993407781442332939, 1513605408256072860, 9533711780544200094, 4407755968940168245}, + utils.NodeKey{10660689026092155967, 7772873226204509526, 940412750970337957, 11934396459574454979}, + utils.NodeKey{13517500090161376813, 3430655983873553997, 5375259408796912397, 1582918923617071297}, + utils.NodeKey{1530581473737529386, 12702896566116465736, 5914767264290477911, 17646414071976395527}, + utils.NodeKey{16058468518382574435, 17573595348125839734, 14299084025723850432, 9173086175977268459}, + utils.NodeKey{3492167051156683621, 5113280701490269535, 3519293511105800335, 4519124618482063071}, + utils.NodeKey{18174025977752953446, 170880634573707059, 1420648486923115869, 7650935848186468717}, + utils.NodeKey{16208859541132551432, 6618660032536205153, 10385910322459208315, 8083618043937979883}, + utils.NodeKey{18055381843795531980, 13462709273291510955, 680380512647919587, 11342529403284590651}, + utils.NodeKey{14208409806025064162, 3405833321788641051, 10002545051615441056, 3286956713137532874}, + utils.NodeKey{5680425176740212736, 8706205589048866541, 1439054882559309464, 17935966873927915285}, + utils.NodeKey{110533614413158858, 1569162572987050699, 17606018854685897411, 14063722484766563720}, + utils.NodeKey{11233753640608616570, 12359586935502800882, 
9900310098552340970, 2424696158120948624}, + utils.NodeKey{17470957289258137535, 89496548814733839, 13431046055752824170, 4863600257776330164}, + utils.NodeKey{12096080439449907754, 3586504186348650027, 16024032131582461863, 3698791599656620348}, + utils.NodeKey{12011265607191854676, 16995709771660398040, 10097323095148987140, 5271835541457063617}, + utils.NodeKey{13774341565485367328, 12574592232097177017, 13203533943886016969, 15689605306663468445}, + utils.NodeKey{17673889518692219847, 6954332541823247394, 954524149166700463, 10005323665613190430}, + utils.NodeKey{3390665384912132081, 273113266583762518, 15391923996500582086, 16937300536792272468}, + utils.NodeKey{3282365570547600329, 2269401659256178523, 12133143125482037239, 9431318293795439322}, + utils.NodeKey{10308056630015396434, 9302651503878791339, 1753436441509383136, 12655301298828119054}, + utils.NodeKey{4866095004323601391, 7715812469294898395, 13448442241363136994, 12560331541471347748}, + utils.NodeKey{9555357893875481640, 14044231432423634485, 2076021859364793876, 2098251167883986095}, + utils.NodeKey{13166561572768359955, 8774399027495495913, 17115924986198600732, 14679213838814779978}, + utils.NodeKey{1830856192880052688, 16817835989594317540, 6792141515706996611, 13263912888227522233}, + utils.NodeKey{8580776493878106180, 13275268150083925070, 1298114825004489111, 6818033484593972896}, + utils.NodeKey{2562799672200229655, 18444468184514201072, 17883941549041529369, 4070387813552736545}, + utils.NodeKey{9268691730026813326, 11545055880246569979, 1187823334319829775, 17259421874098825958}, + utils.NodeKey{9994578653598857505, 13890799434279521010, 6971431511534499255, 9998397274436059169}, + utils.NodeKey{18287575540870662480, 11943532407729972209, 15340299232888708073, 10838674117466297196}, + utils.NodeKey{14761821088000158583, 964796443048506502, 5721781221240658401, 13211032425907534953}, + utils.NodeKey{18144880475727242601, 4972225809077124674, 14334455111087919063, 8111397810232896953}, + utils.NodeKey{16933784929062172058, 9574268379822183272, 4944644580885359493, 3289128208877342006}, + utils.NodeKey{8619895206600224966, 15003370087833528133, 8252241585179054714, 9201580897217580981}, + utils.NodeKey{16332458695522739594, 7936008380823170261, 1848556403564669799, 17993420240804923523}, + utils.NodeKey{6515233280772008301, 4313177990083710387, 4012549955023285042, 12696650320500651942}, + utils.NodeKey{6070193153822371132, 14833198544694594099, 8041604520195724295, 569408677969141468}, + utils.NodeKey{18121124933744588643, 14019823252026845797, 497098216249706813, 14507670067050817524}, + utils.NodeKey{10768458483543229763, 12393104588695647110, 7306859896719697582, 4178785141502415085}, + utils.NodeKey{7512520260500009967, 3751662918911081259, 9113133324668552163, 12072005766952080289}, + utils.NodeKey{5911840277575969690, 14631288768946722660, 9289463458792995190, 11361263549285604206}, + utils.NodeKey{5112807231234019664, 3952289862952962911, 12826043220050158925, 4455878876833215993}, + utils.NodeKey{16417603439618455829, 5362127645905990998, 10661203900902368419, 16076124886006448905}, + utils.NodeKey{11707747219427568787, 933117036015558858, 16439357349021750126, 14064521656451211675}, + utils.NodeKey{16208859541132551432, 6618660032536205153, 10385910322459208315, 8083618043937979883}, + utils.NodeKey{18055381843795531980, 13462709273291510955, 680380512647919587, 11342529403284590651}, + utils.NodeKey{2562799672200229655, 18444468184514201072, 17883941549041529369, 4070387813552736545}, + 
utils.NodeKey{16339509425341743973, 7562720126843377837, 6087776866015284100, 13287333209707648581}, + utils.NodeKey{1830856192880052688, 16817835989594317540, 6792141515706996611, 13263912888227522233}, + } + + valuesTemp := [][8]uint64{ + [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{2802548736, 3113182143, 10842021, 0, 0, 0, 0, 0}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{552894464, 46566, 0, 0, 0, 0, 0, 0}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{4, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{8, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{92883624, 129402807, 3239216982, 1921492768, 41803744, 3662741242, 922499619, 611206845}, + [8]uint64{2149, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{1220686685, 2241513088, 3059933278, 877008478, 3450374550, 2577819195, 3646855908, 1714882695}, + [8]uint64{433, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{1807748760, 2873297298, 945201229, 411604167, 1063664423, 1763702642, 2637524917, 1284041408}, + [8]uint64{2112, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{2407450438, 2021315520, 3591671307, 1981785129, 893348094, 802675915, 3804752326, 2006944699}, + [8]uint64{2583, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{1400751902, 190749285, 93436423, 2918498711, 3630577401, 3928294404, 1037307865, 2336717508}, + [8]uint64{10043, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{2040622618, 1654767043, 2359080366, 3993652948, 2990917507, 41202511, 3266270425, 2537679611}, + [8]uint64{2971, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{2958032465, 981708138, 2081777150, 750201226, 3046928486, 2765783602, 2851559840, 1406574120}, + [8]uint64{23683, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{1741943335, 1540916232, 1327285029, 2450002482, 2695899944, 0, 0, 0}, + [8]uint64{3109587049, 2273239893, 220080300, 1823520391, 35937659, 0, 0, 0}, + [8]uint64{1677672755, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{337379899, 3225725520, 234013414, 1425864754, 2013026225, 0, 0, 0}, + [8]uint64{1031512883, 3743101878, 2828268606, 2468973124, 1081703471, 0, 0, 0}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{256, 481884672, 2932392155, 111365737, 1511099657, 224351860, 164, 0}, + [8]uint64{632216695, 2300948800, 3904328458, 2148496278, 971473112, 0, 0, 0}, + [8]uint64{1031512883, 3743101878, 2828268606, 2468973124, 1081703471, 0, 0, 0}, + [8]uint64{4, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{2401446452, 1128446136, 4183588423, 3903755242, 16083787, 848717237, 2276372267, 2020002041}, + [8]uint64{2793696421, 3373683791, 3597304417, 3609426094, 2371386802, 1021540367, 828590482, 1599660962}, + [8]uint64{2793696421, 3373683791, 3597304417, 3609426094, 2371386802, 1021540367, 828590482, 1599660962}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{2793696421, 3373683791, 3597304417, 3609426094, 2371386802, 1021540367, 828590482, 1599660962}, + [8]uint64{2793696421, 3373683791, 3597304417, 3609426094, 2371386802, 1021540367, 828590482, 1599660962}, + [8]uint64{86400, 0, 0, 0, 
0, 0, 0, 0}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{1321730048, 465661287, 0, 0, 0, 0, 0, 0}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{1480818688, 2647520856, 10842021, 0, 0, 0, 0, 0}, + [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{2407450438, 2021315520, 3591671307, 1981785129, 893348094, 802675915, 3804752326, 2006944699}, + [8]uint64{2583, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{2, 0, 0, 0, 0, 0, 0, 0}, + [8]uint64{4210873971, 1869123984, 4035019538, 1823911763, 1097145772, 827956438, 819220988, 1111695650}, + [8]uint64{20, 0, 0, 0, 0, 0, 0, 0}, + } + + values := make([]utils.NodeValue8, 0) + for _, vT := range valuesTemp { + values = append(values, utils.NodeValue8{ + big.NewInt(0).SetUint64(vT[0]), + big.NewInt(0).SetUint64(vT[1]), + big.NewInt(0).SetUint64(vT[2]), + big.NewInt(0).SetUint64(vT[3]), + big.NewInt(0).SetUint64(vT[4]), + big.NewInt(0).SetUint64(vT[5]), + big.NewInt(0).SetUint64(vT[6]), + big.NewInt(0).SetUint64(vT[7]), + }) + } + + smtIncremental := smt.NewSMT(nil, false) + smtBatch := smt.NewSMT(nil, false) + insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false) + for i, k := range keys { + smtIncremental.Insert(k, values[i]) + _, err := smtBatch.InsertBatch(insertBatchCfg, []*utils.NodeKey{&k}, []*utils.NodeValue8{&values[i]}, nil, nil) + assert.NilError(t, err) + + smtIncrementalRootHash, _ := smtIncremental.Db.GetLastRoot() + smtBatchRootHash, _ := smtBatch.Db.GetLastRoot() + assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtIncrementalRootHash)) + } + + smtIncremental.DumpTree() + fmt.Println() + smtBatch.DumpTree() + fmt.Println() + + assertSmtDbStructure(t, smtBatch, false) +} diff --git a/smt/pkg/smt/smt_batch_asserts_test.go b/smt/pkg/smt/smt_batch_asserts_test.go new file mode 100644 index 00000000000..171c3af6e0f --- /dev/null +++ b/smt/pkg/smt/smt_batch_asserts_test.go @@ -0,0 +1,125 @@ +package smt_test + +import ( + "context" + "fmt" + "testing" + + "github.com/ledgerwatch/erigon/smt/pkg/db" + "github.com/ledgerwatch/erigon/smt/pkg/smt" + "github.com/ledgerwatch/erigon/smt/pkg/utils" + "gotest.tools/v3/assert" +) + +func assertSmtDbStructure(t *testing.T, s *smt.SMT, testMetadata bool) { + smtBatchRootHash, _ := s.Db.GetLastRoot() + + actualDb, ok := s.Db.(*db.MemDb) + if !ok { + return + } + + usedNodeHashesMap := make(map[string]*utils.NodeKey) + assertSmtTreeDbStructure(t, s, utils.ScalarToRoot(smtBatchRootHash), usedNodeHashesMap) + + // EXPLAIN THE LINE BELOW: db could have more values because values' hashes are not deleted + assert.Equal(t, true, len(actualDb.Db)-len(usedNodeHashesMap) >= 0) + for k := range usedNodeHashesMap { + _, found := actualDb.Db[k] + assert.Equal(t, true, found) + } + + totalLeaves := assertHashToKeyDbStrcture(t, s, utils.ScalarToRoot(smtBatchRootHash), testMetadata) + assert.Equal(t, totalLeaves, len(actualDb.DbHashKey)) + if testMetadata { + assert.Equal(t, totalLeaves, len(actualDb.DbKeySource)) + } + + assertTraverse(t, s) +} + +func assertSmtTreeDbStructure(t *testing.T, s *smt.SMT, nodeHash utils.NodeKey, usedNodeHashesMap map[string]*utils.NodeKey) { + if nodeHash.IsZero() { + return + } + + dbNodeValue, err := s.Db.Get(nodeHash) + assert.NilError(t, err) + + nodeHashHex := utils.ConvertBigIntToHex(utils.ArrayToScalar(nodeHash[:])) + usedNodeHashesMap[nodeHashHex] = &nodeHash + + if dbNodeValue.IsFinalNode() { + nodeValueHash := 
+		nodeValueHash := utils.NodeKeyFromBigIntArray(dbNodeValue[4:8])
+		dbNodeValue, err = s.Db.Get(nodeValueHash)
+		assert.NilError(t, err)
+
+		nodeHashHex := utils.ConvertBigIntToHex(utils.ArrayToScalar(nodeValueHash[:]))
+		usedNodeHashesMap[nodeHashHex] = &nodeValueHash
+		return
+	}
+
+	assertSmtTreeDbStructure(t, s, utils.NodeKeyFromBigIntArray(dbNodeValue[0:4]), usedNodeHashesMap)
+	assertSmtTreeDbStructure(t, s, utils.NodeKeyFromBigIntArray(dbNodeValue[4:8]), usedNodeHashesMap)
+}
+
+func assertHashToKeyDbStructure(t *testing.T, smtBatch *smt.SMT, nodeHash utils.NodeKey, testMetadata bool) int {
+	if nodeHash.IsZero() {
+		return 0
+	}
+
+	dbNodeValue, err := smtBatch.Db.Get(nodeHash)
+	assert.NilError(t, err)
+
+	if dbNodeValue.IsFinalNode() {
+		memDb := smtBatch.Db.(*db.MemDb)
+
+		nodeKey, err := smtBatch.Db.GetHashKey(nodeHash)
+		assert.NilError(t, err)
+
+		keyConc := utils.ArrayToScalar(nodeHash[:])
+		k := utils.ConvertBigIntToHex(keyConc)
+		_, found := memDb.DbHashKey[k]
+		assert.Equal(t, found, true)
+
+		if testMetadata {
+			keyConc = utils.ArrayToScalar(nodeKey[:])
+
+			_, found = memDb.DbKeySource[keyConc.String()]
+			assert.Equal(t, found, true)
+		}
+		return 1
+	}
+
+	return assertHashToKeyDbStructure(t, smtBatch, utils.NodeKeyFromBigIntArray(dbNodeValue[0:4]), testMetadata) + assertHashToKeyDbStructure(t, smtBatch, utils.NodeKeyFromBigIntArray(dbNodeValue[4:8]), testMetadata)
+}
+
+func assertTraverse(t *testing.T, s *smt.SMT) {
+	smtBatchRootHash, _ := s.Db.GetLastRoot()
+
+	ctx := context.Background()
+	action := func(prefix []byte, k utils.NodeKey, v utils.NodeValue12) (bool, error) {
+		if v.IsFinalNode() {
+			valHash := v.Get4to8()
+			v, err := s.Db.Get(*valHash)
+			if err != nil {
+				return false, err
+			}
+
+			if v[0] == nil {
+				return false, fmt.Errorf("value is missing in the db")
+			}
+
+			vInBytes := utils.ArrayBigToScalar(utils.BigIntArrayFromNodeValue8(v.GetNodeValue8())).Bytes()
+			if vInBytes == nil {
+				return false, fmt.Errorf("error in converting to bytes")
+			}
+
+			return false, nil
+		}
+
+		return true, nil
+	}
+	err := s.Traverse(ctx, smtBatchRootHash, action)
+	assert.NilError(t, err)
+}
diff --git a/smt/pkg/smt/smt_batch_compare_test.go b/smt/pkg/smt/smt_batch_compare_test.go
new file mode 100644
index 00000000000..a3fd57c3984
--- /dev/null
+++ b/smt/pkg/smt/smt_batch_compare_test.go
@@ -0,0 +1,120 @@
+package smt_test
+
+import (
+	"context"
+	"os"
+	"testing"
+	"time"
+
+	libcommon "github.com/ledgerwatch/erigon-lib/common"
+	"github.com/ledgerwatch/erigon/core/types/accounts"
+	"github.com/ledgerwatch/erigon/smt/pkg/smt"
+	"github.com/ledgerwatch/erigon/smt/pkg/utils"
+	"gotest.tools/v3/assert"
+)
+
+func TestCompareAllTreesInsertTimesAndFinalHashesUsingDiskDb(t *testing.T) {
+	incrementalDbPath := "/tmp/smt-incremental"
+	smtIncrementalDb, smtIncrementalTx, smtIncrementalSmtDb := initDb(t, incrementalDbPath)
+
+	bulkDbPath := "/tmp/smt-bulk"
+	smtBulkDb, smtBulkTx, smtBulkSmtDb := initDb(t, bulkDbPath)
+
+	batchDbPath := "/tmp/smt-batch"
+	smtBatchDb, smtBatchTx, smtBatchSmtDb := initDb(t, batchDbPath)
+
+	smtIncremental := smt.NewSMT(smtIncrementalSmtDb, false)
+	smtBulk := smt.NewSMT(smtBulkSmtDb, false)
+	smtBatch := smt.NewSMT(smtBatchSmtDb, false)
+
+	compareAllTreesInsertTimesAndFinalHashes(t, smtIncremental, smtBulk, smtBatch)
+
+	smtIncrementalTx.Commit()
+	tt := time.Now()
+	t.Logf("1: %v\n", time.Since(tt))
+	smtBulkTx.Commit()
+	tt = time.Now()
+	t.Logf("2: %v\n", time.Since(tt))
+	smtBatchTx.Commit()
+	tt = time.Now()
+	t.Logf("3: %v\n", time.Since(tt))
+	t.Cleanup(func() {
+		smtIncrementalDb.Close()
+		smtBulkDb.Close()
+		smtBatchDb.Close()
+		os.RemoveAll(incrementalDbPath)
+		os.RemoveAll(bulkDbPath)
+		os.RemoveAll(batchDbPath)
+	})
+}
+
+func TestCompareAllTreesInsertTimesAndFinalHashesUsingInMemoryDb(t *testing.T) {
+	smtIncremental := smt.NewSMT(nil, false)
+	smtBulk := smt.NewSMT(nil, false)
+	smtBatch := smt.NewSMT(nil, false)
+
+	compareAllTreesInsertTimesAndFinalHashes(t, smtIncremental, smtBulk, smtBatch)
+}
+
+func compareAllTreesInsertTimesAndFinalHashes(t *testing.T, smtIncremental, smtBulk, smtBatch *smt.SMT) {
+	batchInsertDataHolders, totalInserts := prepareData()
+	ctx := context.Background()
+	var incrementalError error
+
+	accChanges := make(map[libcommon.Address]*accounts.Account)
+	codeChanges := make(map[libcommon.Address]string)
+	storageChanges := make(map[libcommon.Address]map[string]string)
+
+	for _, batchInsertDataHolder := range batchInsertDataHolders {
+		accChanges[batchInsertDataHolder.AddressAccount] = &batchInsertDataHolder.acc
+		codeChanges[batchInsertDataHolder.AddressContract] = batchInsertDataHolder.Bytecode
+		storageChanges[batchInsertDataHolder.AddressContract] = batchInsertDataHolder.Storage
+	}
+
+	startTime := time.Now()
+	for addr, acc := range accChanges {
+		if err := smtIncremental.SetAccountStorage(addr, acc); err != nil {
+			incrementalError = err
+		}
+	}
+
+	for addr, code := range codeChanges {
+		if err := smtIncremental.SetContractBytecode(addr.String(), code); err != nil {
+			incrementalError = err
+		}
+	}
+
+	for addr, storage := range storageChanges {
+		if _, err := smtIncremental.SetContractStorage(addr.String(), storage, nil); err != nil {
+			incrementalError = err
+		}
+	}
+
+	assert.NilError(t, incrementalError)
+	t.Logf("Incremental insert %d values in %v\n", totalInserts, time.Since(startTime))
+
+	startTime = time.Now()
+	keyPointers, valuePointers, err := smtBatch.SetStorage(ctx, "", accChanges, codeChanges, storageChanges)
+	assert.NilError(t, err)
+	t.Logf("Batch insert %d values in %v\n", totalInserts, time.Since(startTime))
+
+	keys := []utils.NodeKey{}
+	for i, key := range keyPointers {
+		v := valuePointers[i]
+		if !v.IsZero() {
+			smtBulk.Db.InsertAccountValue(*key, *v)
+			keys = append(keys, *key)
+		}
+	}
+	startTime = time.Now()
+	smtBulk.GenerateFromKVBulk(ctx, "", keys)
+	t.Logf("Bulk insert %d values in %v\n", totalInserts, time.Since(startTime))
+
+	smtIncrementalRootHash, _ := smtIncremental.Db.GetLastRoot()
+	smtBatchRootHash, _ := smtBatch.Db.GetLastRoot()
+	smtBulkRootHash, _ := smtBulk.Db.GetLastRoot()
+	assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtIncrementalRootHash))
+	assert.Equal(t, utils.ConvertBigIntToHex(smtBulkRootHash), utils.ConvertBigIntToHex(smtIncrementalRootHash))
+
+	assertSmtDbStructure(t, smtBatch, true)
+}
diff --git a/smt/pkg/smt/smt_batch_delete_test.go b/smt/pkg/smt/smt_batch_delete_test.go
new file mode 100644
index 00000000000..469d938e314
--- /dev/null
+++ b/smt/pkg/smt/smt_batch_delete_test.go
@@ -0,0 +1,76 @@
+package smt_test
+
+import (
+	"context"
+	"fmt"
+	"math/big"
+	"testing"
+
+	"github.com/ledgerwatch/erigon/smt/pkg/smt"
+	"github.com/ledgerwatch/erigon/smt/pkg/utils"
+	"gotest.tools/v3/assert"
+)
+
+func TestBatchDelete(t *testing.T) {
+	keys := []utils.NodeKey{
+		utils.NodeKey{10768458483543229763, 12393104588695647110, 7306859896719697582, 4178785141502415085},
+		utils.NodeKey{7512520260500009967, 3751662918911081259, 9113133324668552163, 12072005766952080289},
+		utils.NodeKey{4755722537892498409, 14621988746728905818, 15452350668109735064, 8819587610951133148},
+		utils.NodeKey{6340777516277056037, 6264482673611175884, 1063722098746108599, 9062208133640346025},
+		utils.NodeKey{6319287575763093444, 10809750365832475266, 6426706394050518186, 9463173325157812560},
+		utils.NodeKey{15155415624738072211, 3736290188193138617, 8461047487943769832, 12188454615342744806},
+		utils.NodeKey{15276670325385989216, 10944726794004460540, 9369946489424614125, 817372649097925902},
+		utils.NodeKey{2562799672200229655, 18444468184514201072, 17883941549041529369, 407038781355273654},
+		utils.NodeKey{10768458483543229763, 12393104588695647110, 7306859896719697582, 4178785141502415085},
+		utils.NodeKey{7512520260500009967, 3751662918911081259, 9113133324668552163, 12072005766952080289},
+		utils.NodeKey{4755722537892498409, 14621988746728905818, 15452350668109735064, 8819587610951133148},
+	}
+
+	valuesTemp := [][8]uint64{
+		[8]uint64{0, 1, 0, 0, 0, 0, 0, 0},
+		[8]uint64{0, 1, 0, 0, 0, 0, 0, 0},
+		[8]uint64{0, 1, 0, 0, 0, 0, 0, 0},
+		[8]uint64{1, 0, 0, 0, 0, 0, 0, 0},
+		[8]uint64{103184848, 115613322, 0, 0, 0, 0, 0, 0},
+		[8]uint64{2, 0, 0, 0, 0, 0, 0, 0},
+		[8]uint64{3038602192, 2317586098, 794977000, 2442751483, 2309555181, 2028447238, 1023640522, 2687173865},
+		[8]uint64{3100, 0, 0, 0, 0, 0, 0, 0},
+		[8]uint64{0, 0, 0, 0, 0, 0, 0, 0},
+		[8]uint64{0, 0, 0, 0, 0, 0, 0, 0},
+		[8]uint64{0, 0, 0, 0, 0, 0, 0, 0},
+	}
+
+	values := make([]utils.NodeValue8, 0)
+	for _, vT := range valuesTemp {
+		values = append(values, utils.NodeValue8{
+			big.NewInt(0).SetUint64(vT[0]),
+			big.NewInt(0).SetUint64(vT[1]),
+			big.NewInt(0).SetUint64(vT[2]),
+			big.NewInt(0).SetUint64(vT[3]),
+			big.NewInt(0).SetUint64(vT[4]),
+			big.NewInt(0).SetUint64(vT[5]),
+			big.NewInt(0).SetUint64(vT[6]),
+			big.NewInt(0).SetUint64(vT[7]),
+		})
+	}
+
+	smtIncremental := smt.NewSMT(nil, false)
+	smtBatch := smt.NewSMT(nil, false)
+	insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
+	for i, k := range keys {
+		smtIncremental.Insert(k, values[i])
+		_, err := smtBatch.InsertBatch(insertBatchCfg, []*utils.NodeKey{&k}, []*utils.NodeValue8{&values[i]}, nil, nil)
+		assert.NilError(t, err)
+
+		smtIncrementalRootHash, _ := smtIncremental.Db.GetLastRoot()
+		smtBatchRootHash, _ := smtBatch.Db.GetLastRoot()
+		assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtIncrementalRootHash))
+	}
+
+	smtIncremental.DumpTree()
+	fmt.Println()
+	smtBatch.DumpTree()
+	fmt.Println()
+
+	assertSmtDbStructure(t, smtBatch, false)
+}
diff --git a/smt/pkg/smt/smt_batch_insert_test.go b/smt/pkg/smt/smt_batch_insert_test.go
new file mode 100644
index 00000000000..fc6b7b497ac
--- /dev/null
+++ b/smt/pkg/smt/smt_batch_insert_test.go
@@ -0,0 +1,267 @@
+package smt_test
+
+import (
+	"context"
+	"fmt"
+	"math/big"
+	"math/rand"
+	"testing"
+	"time"
+
+	"github.com/ledgerwatch/erigon/smt/pkg/smt"
+	"github.com/ledgerwatch/erigon/smt/pkg/utils"
+	"gotest.tools/v3/assert"
+)
+
+func TestBatchSimpleInsert(t *testing.T) {
+	keysRaw := []*big.Int{
+		big.NewInt(8),
+		big.NewInt(8),
+		big.NewInt(1),
+		big.NewInt(31),
+		big.NewInt(31),
+		big.NewInt(0),
+		big.NewInt(8),
+	}
+	valuesRaw := []*big.Int{
+		big.NewInt(17),
+		big.NewInt(18),
+		big.NewInt(19),
+		big.NewInt(20),
+		big.NewInt(0),
+		big.NewInt(0),
+		big.NewInt(0),
+	}
+
+	keyPointers := []*utils.NodeKey{}
+	valuePointers := []*utils.NodeValue8{}
+
+	smtIncremental := smt.NewSMT(nil, false)
+	smtBatch := smt.NewSMT(nil, false)
+	smtBatchNoSave := smt.NewSMT(nil, true)
+
+	for i := range keysRaw {
+		k := utils.ScalarToNodeKey(keysRaw[i])
+		vArray := utils.ScalarToArrayBig(valuesRaw[i])
+		v, _ := utils.NodeValue8FromBigIntArray(vArray)
+
+		keyPointers = append(keyPointers, &k)
+		valuePointers = append(valuePointers, v)
+
+		smtIncremental.InsertKA(k, valuesRaw[i])
+	}
+
+	insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
+	_, err := smtBatch.InsertBatch(insertBatchCfg, keyPointers, valuePointers, nil, nil)
+	assert.NilError(t, err)
+
+	_, err = smtBatchNoSave.InsertBatch(insertBatchCfg, keyPointers, valuePointers, nil, nil)
+	assert.NilError(t, err)
+
+	smtIncremental.DumpTree()
+	fmt.Println()
+	smtBatch.DumpTree()
+	fmt.Println()
+	fmt.Println()
+	fmt.Println()
+
+	smtIncrementalRootHash, _ := smtIncremental.Db.GetLastRoot()
+	smtBatchRootHash, _ := smtBatch.Db.GetLastRoot()
+	smtBatchNoSaveRootHash, _ := smtBatchNoSave.Db.GetLastRoot()
+	assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtIncrementalRootHash))
+	assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtBatchNoSaveRootHash))
+
+	assertSmtDbStructure(t, smtBatch, false)
+}
+
+func TestBatchRawInsert(t *testing.T) {
+	keysForBatch := []*utils.NodeKey{}
+	valuesForBatch := []*utils.NodeValue8{}
+
+	keysForIncremental := []utils.NodeKey{}
+	valuesForIncremental := []utils.NodeValue8{}
+
+	smtIncremental := smt.NewSMT(nil, false)
+	smtBatch := smt.NewSMT(nil, false)
+
+	rand.Seed(1)
+	size := 1 << 10
+	for i := 0; i < size; i++ {
+		rawKey := big.NewInt(rand.Int63())
+		rawValue := big.NewInt(rand.Int63())
+
+		k := utils.ScalarToNodeKey(rawKey)
+		vArray := utils.ScalarToArrayBig(rawValue)
+		v, _ := utils.NodeValue8FromBigIntArray(vArray)
+
+		keysForBatch = append(keysForBatch, &k)
+		valuesForBatch = append(valuesForBatch, v)
+
+		keysForIncremental = append(keysForIncremental, k)
+		valuesForIncremental = append(valuesForIncremental, *v)
+
+	}
+
+	startTime := time.Now()
+	for i := range keysForIncremental {
+		smtIncremental.Insert(keysForIncremental[i], valuesForIncremental[i])
+	}
+	t.Logf("Incremental insert %d values in %v\n", len(keysForIncremental), time.Since(startTime))
+
+	startTime = time.Now()
+
+	insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", true)
+	_, err := smtBatch.InsertBatch(insertBatchCfg, keysForBatch, valuesForBatch, nil, nil)
+	assert.NilError(t, err)
+	t.Logf("Batch insert %d values in %v\n", len(keysForBatch), time.Since(startTime))
+
+	smtIncrementalRootHash, _ := smtIncremental.Db.GetLastRoot()
+	smtBatchRootHash, _ := smtBatch.Db.GetLastRoot()
+	assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtIncrementalRootHash))
+
+	assertSmtDbStructure(t, smtBatch, false)
+
+	// DELETE
+	keysForBatchDelete := []*utils.NodeKey{}
+	valuesForBatchDelete := []*utils.NodeValue8{}
+
+	keysForIncrementalDelete := []utils.NodeKey{}
+	valuesForIncrementalDelete := []utils.NodeValue8{}
+
+	sizeToDelete := 1 << 14
+	for i := 0; i < sizeToDelete; i++ {
+		rawValue := big.NewInt(0)
+		vArray := utils.ScalarToArrayBig(rawValue)
+		v, _ := utils.NodeValue8FromBigIntArray(vArray)
+
+		deleteIndex := rand.Intn(size)
+
+		keyForBatchDelete := keysForBatch[deleteIndex]
+		keyForIncrementalDelete := keysForIncremental[deleteIndex]
+
+		keysForBatchDelete = append(keysForBatchDelete, keyForBatchDelete)
+		valuesForBatchDelete = append(valuesForBatchDelete, v)
+
+		keysForIncrementalDelete = append(keysForIncrementalDelete, keyForIncrementalDelete)
+		valuesForIncrementalDelete = append(valuesForIncrementalDelete, *v)
+	}
+
+	startTime = time.Now()
+	for i := range keysForIncrementalDelete {
+		smtIncremental.Insert(keysForIncrementalDelete[i], valuesForIncrementalDelete[i])
+	}
+	t.Logf("Incremental delete %d values in %v\n", len(keysForIncrementalDelete), time.Since(startTime))
+
+	startTime = time.Now()
+
+	_, err = smtBatch.InsertBatch(insertBatchCfg, keysForBatchDelete, valuesForBatchDelete, nil, nil)
+	assert.NilError(t, err)
+	t.Logf("Batch delete %d values in %v\n", len(keysForBatchDelete), time.Since(startTime))
+
+	assertSmtDbStructure(t, smtBatch, false)
+}
+
+func BenchmarkIncrementalInsert(b *testing.B) {
+	keys := []*big.Int{}
+	vals := []*big.Int{}
+	for i := 0; i < 1000; i++ {
+		rand.Seed(time.Now().UnixNano())
+		keys = append(keys, big.NewInt(int64(rand.Intn(10000))))
+
+		rand.Seed(time.Now().UnixNano())
+		vals = append(vals, big.NewInt(int64(rand.Intn(10000))))
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		smtIncremental := smt.NewSMT(nil, false)
+		incrementalInsert(smtIncremental, keys, vals)
+	}
+}
+
+func BenchmarkBatchInsert(b *testing.B) {
+	keys := []*big.Int{}
+	vals := []*big.Int{}
+	for i := 0; i < 1000; i++ {
+		rand.Seed(time.Now().UnixNano())
+		keys = append(keys, big.NewInt(int64(rand.Intn(10000))))
+
+		rand.Seed(time.Now().UnixNano())
+		vals = append(vals, big.NewInt(int64(rand.Intn(10000))))
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		smtBatch := smt.NewSMT(nil, false)
+		batchInsert(smtBatch, keys, vals)
+	}
+}
+
+func BenchmarkBatchInsertNoSave(b *testing.B) {
+	keys := []*big.Int{}
+	vals := []*big.Int{}
+	for i := 0; i < 1000; i++ {
+		rand.Seed(time.Now().UnixNano())
+		keys = append(keys, big.NewInt(int64(rand.Intn(10000))))
+
+		rand.Seed(time.Now().UnixNano())
+		vals = append(vals, big.NewInt(int64(rand.Intn(10000))))
+	}
+
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		smtBatch := smt.NewSMT(nil, true)
+		batchInsert(smtBatch, keys, vals)
+	}
+}
+
+func TestBatchSimpleInsert2(t *testing.T) {
+	keys := []*big.Int{}
+	vals := []*big.Int{}
+	for i := 0; i < 1000; i++ {
+		rand.Seed(time.Now().UnixNano())
+		keys = append(keys, big.NewInt(int64(rand.Intn(10000))))
+
+		rand.Seed(time.Now().UnixNano())
+		vals = append(vals, big.NewInt(int64(rand.Intn(10000))))
+	}
+
+	smtIncremental := smt.NewSMT(nil, false)
+	incrementalInsert(smtIncremental, keys, vals)
+
+	smtBatch := smt.NewSMT(nil, false)
+	batchInsert(smtBatch, keys, vals)
+
+	smtBatchNoSave := smt.NewSMT(nil, false)
+	batchInsert(smtBatchNoSave, keys, vals)
+
+	smtIncrementalRootHash, _ := smtIncremental.Db.GetLastRoot()
+	smtBatchRootHash, _ := smtBatch.Db.GetLastRoot()
+	smtBatchNoSaveRootHash, _ := smtBatchNoSave.Db.GetLastRoot()
+
+	assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtIncrementalRootHash))
+	assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtBatchNoSaveRootHash))
+}
+
+func incrementalInsert(tree *smt.SMT, key, val []*big.Int) {
+	for i := range key {
+		k := utils.ScalarToNodeKey(key[i])
+		tree.InsertKA(k, val[i])
+	}
+}
+
+func batchInsert(tree *smt.SMT, key, val []*big.Int) {
+	keyPointers := []*utils.NodeKey{}
+	valuePointers := []*utils.NodeValue8{}
+
+	for i := range key {
+		k := utils.ScalarToNodeKey(key[i])
+		vArray := utils.ScalarToArrayBig(val[i])
+		v, _ := utils.NodeValue8FromBigIntArray(vArray)
+
+		keyPointers = append(keyPointers, &k)
+		valuePointers = append(valuePointers, v)
+	}
+	insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
+	tree.InsertBatch(insertBatchCfg, keyPointers, valuePointers, nil, nil)
+}
diff --git a/smt/pkg/smt/smt_batch_test.go b/smt/pkg/smt/smt_batch_test.go
deleted file mode 100644
index 29beb88fef1..00000000000
--- a/smt/pkg/smt/smt_batch_test.go
+++ /dev/null
@@ -1,815 +0,0 @@
-package smt_test
-
-import (
-	"context"
-	"fmt"
-	"math/big"
-	"math/rand"
-	"os"
-	"testing"
-	"time"
-
-	"github.com/c2h5oh/datasize"
-	"github.com/holiman/uint256"
-	libcommon "github.com/ledgerwatch/erigon-lib/common"
-	"github.com/ledgerwatch/erigon-lib/kv"
-	"github.com/ledgerwatch/erigon-lib/kv/mdbx"
-	"github.com/ledgerwatch/erigon/core/types/accounts"
-	"github.com/ledgerwatch/erigon/migrations"
-	"github.com/ledgerwatch/erigon/smt/pkg/db"
-	"github.com/ledgerwatch/erigon/smt/pkg/smt"
-	"github.com/ledgerwatch/erigon/smt/pkg/utils"
-	"github.com/ledgerwatch/log/v3"
-	"golang.org/x/sync/semaphore"
-	"gotest.tools/v3/assert"
-)
-
-type BatchInsertDataHolder struct {
-	acc             accounts.Account
-	AddressAccount  libcommon.Address
-	AddressContract libcommon.Address
-	Bytecode        string
-	Storage         map[string]string
-}
-
-func TestBatchSimpleInsert(t *testing.T) {
-	keysRaw := []*big.Int{
-		big.NewInt(8),
-		big.NewInt(8),
-		big.NewInt(1),
-		big.NewInt(31),
-		big.NewInt(31),
-		big.NewInt(0),
-		big.NewInt(8),
-	}
-	valuesRaw := []*big.Int{
-		big.NewInt(17),
-		big.NewInt(18),
-		big.NewInt(19),
-		big.NewInt(20),
-		big.NewInt(0),
-		big.NewInt(0),
-		big.NewInt(0),
-	}
-
-	keyPointers := []*utils.NodeKey{}
-	valuePointers := []*utils.NodeValue8{}
-
-	smtIncremental := smt.NewSMT(nil, false)
-	smtBatch := smt.NewSMT(nil, false)
-	smtBatchNoSave := smt.NewSMT(nil, true)
-
-	for i := range keysRaw {
-		k := utils.ScalarToNodeKey(keysRaw[i])
-		vArray := utils.ScalarToArrayBig(valuesRaw[i])
-		v, _ := utils.NodeValue8FromBigIntArray(vArray)
-
-		keyPointers = append(keyPointers, &k)
-		valuePointers = append(valuePointers, v)
-
-		smtIncremental.InsertKA(k, valuesRaw[i])
-	}
-
-	insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
-	_, err := smtBatch.InsertBatch(insertBatchCfg, keyPointers, valuePointers, nil, nil)
-	assert.NilError(t, err)
-
-	_, err = smtBatchNoSave.InsertBatch(insertBatchCfg, keyPointers, valuePointers, nil, nil)
-	assert.NilError(t, err)
-
-	smtIncremental.DumpTree()
-	fmt.Println()
-	smtBatch.DumpTree()
-	fmt.Println()
-	fmt.Println()
-	fmt.Println()
-
-	smtIncrementalRootHash, _ := smtIncremental.Db.GetLastRoot()
-	smtBatchRootHash, _ := smtBatch.Db.GetLastRoot()
-	smtBatchNoSaveRootHash, _ := smtBatchNoSave.Db.GetLastRoot()
-	assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtIncrementalRootHash))
-	assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtBatchNoSaveRootHash))
-
-	assertSmtDbStructure(t, smtBatch, false)
-}
-
-func incrementalInsert(tree *smt.SMT, key, val []*big.Int) {
-	for i := range key {
-		k := utils.ScalarToNodeKey(key[i])
-		tree.InsertKA(k, val[i])
-	}
-}
-
-func batchInsert(tree *smt.SMT, key, val []*big.Int) {
-	keyPointers := []*utils.NodeKey{}
-	valuePointers := []*utils.NodeValue8{}
-
-	for i := range key {
-		k := utils.ScalarToNodeKey(key[i])
-		vArray := utils.ScalarToArrayBig(val[i])
-		v, _ := utils.NodeValue8FromBigIntArray(vArray)
-
-		keyPointers = append(keyPointers, &k)
-		valuePointers = append(valuePointers, v)
-	}
-	insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false)
-	tree.InsertBatch(insertBatchCfg, keyPointers, valuePointers, nil, nil)
-}
-
-func BenchmarkIncrementalInsert(b *testing.B) {
-	keys := []*big.Int{}
-	vals := []*big.Int{}
-	for i := 0; i < 1000; i++ {
-		rand.Seed(time.Now().UnixNano())
-		keys = append(keys, big.NewInt(int64(rand.Intn(10000))))
-
-		rand.Seed(time.Now().UnixNano())
-		vals = append(vals, big.NewInt(int64(rand.Intn(10000))))
-	}
-
-	b.ResetTimer()
-	for i := 0; i < b.N; i++ {
-		smtIncremental := smt.NewSMT(nil, false)
-		incrementalInsert(smtIncremental, keys, vals)
-	}
-}
-
-func BenchmarkBatchInsert(b *testing.B) {
-	keys := []*big.Int{}
-	vals := []*big.Int{}
-	for i := 0; i < 1000; i++ {
-		rand.Seed(time.Now().UnixNano())
-		keys = append(keys, big.NewInt(int64(rand.Intn(10000))))
-
-		rand.Seed(time.Now().UnixNano())
-		vals = append(vals, big.NewInt(int64(rand.Intn(10000))))
-	}
-
-	b.ResetTimer()
-	for i := 0; i < b.N; i++ {
-		smtBatch := smt.NewSMT(nil, false)
-		batchInsert(smtBatch, keys, vals)
-	}
-}
-
-func BenchmarkBatchInsertNoSave(b *testing.B) {
-	keys := []*big.Int{}
-	vals := []*big.Int{}
-	for i := 0; i < 1000; i++ {
-		rand.Seed(time.Now().UnixNano())
-		keys = append(keys, big.NewInt(int64(rand.Intn(10000))))
-
-		rand.Seed(time.Now().UnixNano())
-		vals = append(vals, big.NewInt(int64(rand.Intn(10000))))
-	}
-
-	b.ResetTimer()
-	for i := 0; i < b.N; i++ {
-		smtBatch := smt.NewSMT(nil, true)
-		batchInsert(smtBatch, keys, vals)
-	}
-}
-
-func TestBatchSimpleInsert2(t *testing.T) {
-	keys := []*big.Int{}
-	vals := []*big.Int{}
-	for i := 0; i < 1000; i++ {
-		rand.Seed(time.Now().UnixNano())
-		keys = append(keys, big.NewInt(int64(rand.Intn(10000))))
-
-		rand.Seed(time.Now().UnixNano())
-		vals = append(vals, big.NewInt(int64(rand.Intn(10000))))
-	}
-
-	smtIncremental := smt.NewSMT(nil, false)
-	incrementalInsert(smtIncremental, keys, vals)
-
-	smtBatch := smt.NewSMT(nil, false)
-	batchInsert(smtBatch, keys, vals)
-
-	smtBatchNoSave := smt.NewSMT(nil, false)
-	batchInsert(smtBatchNoSave, keys, vals)
-
-	smtIncrementalRootHash, _ := smtIncremental.Db.GetLastRoot()
-	smtBatchRootHash, _ := smtBatch.Db.GetLastRoot()
-	smtBatchNoSaveRootHash, _ := smtBatchNoSave.Db.GetLastRoot()
-
-	assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtIncrementalRootHash))
-	assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtBatchNoSaveRootHash))
-}
-
-func TestBatchWitness(t *testing.T) {
-	keys := []utils.NodeKey{
-		utils.NodeKey{17822804428864912231, 4683868963463720294, 2947512351908939790, 2330225637707749973},
-		utils.NodeKey{15928606457751385034, 926210564408807848, 3634217732472610234, 18021748560357139965},
-		utils.NodeKey{1623861826376204094, 570263533561698889, 4654109133431364496, 7281957057362652730},
-		utils.NodeKey{13644513224119225920, 15807577943241006501, 9942496498562648573, 15190659753926523377},
-		utils.NodeKey{9275812266666786730, 4204572028245381139, 3605834086260069958, 10007478335141208804},
-		utils.NodeKey{8235907590678154663, 6691762687086189695, 15487167600723075149, 10984821506434298343},
-		utils.NodeKey{16417603439618455829, 5362127645905990998, 10661203900902368419, 16076124886006448905},
-		utils.NodeKey{11707747219427568787, 933117036015558858, 16439357349021750126, 14064521656451211675},
-		utils.NodeKey{10768458483543229763, 12393104588695647110, 7306859896719697582, 4178785141502415085},
-		utils.NodeKey{7512520260500009967, 3751662918911081259, 9113133324668552163, 12072005766952080289},
-		utils.NodeKey{9944065905482556519, 8594459084728876791, 17786637052462706859,
15521772847998069525}, - utils.NodeKey{5036431633232956882, 16658186702978753823, 2870215478624537606, 11907126160741124846}, - utils.NodeKey{17938814940856978076, 13147879352039549979, 1303554763666506875, 14953772317105337015}, - utils.NodeKey{17398863357602626404, 4841219907503399295, 2992012704517273588, 16471435007473943078}, - utils.NodeKey{4763654225644445738, 5354841943603308259, 16476366216865814029, 10492509060169249179}, - utils.NodeKey{3554925909441560661, 16583852156861238748, 15693104712527552035, 8799937559790156794}, - utils.NodeKey{9617343367546549815, 6562355304138083186, 4016039301486039807, 10864657160754550133}, - utils.NodeKey{17933907347870222658, 16190350511466382228, 13330881818854499962, 1410294862891786839}, - utils.NodeKey{17260204906255015513, 15380909239227623493, 8567606678138088594, 4899143890802672405}, - utils.NodeKey{12539511585850227228, 3973200204826286539, 8108069613182344498, 11385621942985713904}, - utils.NodeKey{5984161349947667925, 7514232801604484380, 16331057190188025237, 2178913139230121631}, - utils.NodeKey{1993407781442332939, 1513605408256072860, 9533711780544200094, 4407755968940168245}, - utils.NodeKey{10660689026092155967, 7772873226204509526, 940412750970337957, 11934396459574454979}, - utils.NodeKey{13517500090161376813, 3430655983873553997, 5375259408796912397, 1582918923617071297}, - utils.NodeKey{1530581473737529386, 12702896566116465736, 5914767264290477911, 17646414071976395527}, - utils.NodeKey{16058468518382574435, 17573595348125839734, 14299084025723850432, 9173086175977268459}, - utils.NodeKey{3492167051156683621, 5113280701490269535, 3519293511105800335, 4519124618482063071}, - utils.NodeKey{18174025977752953446, 170880634573707059, 1420648486923115869, 7650935848186468717}, - utils.NodeKey{16208859541132551432, 6618660032536205153, 10385910322459208315, 8083618043937979883}, - utils.NodeKey{18055381843795531980, 13462709273291510955, 680380512647919587, 11342529403284590651}, - utils.NodeKey{14208409806025064162, 3405833321788641051, 10002545051615441056, 3286956713137532874}, - utils.NodeKey{5680425176740212736, 8706205589048866541, 1439054882559309464, 17935966873927915285}, - utils.NodeKey{110533614413158858, 1569162572987050699, 17606018854685897411, 14063722484766563720}, - utils.NodeKey{11233753640608616570, 12359586935502800882, 9900310098552340970, 2424696158120948624}, - utils.NodeKey{17470957289258137535, 89496548814733839, 13431046055752824170, 4863600257776330164}, - utils.NodeKey{12096080439449907754, 3586504186348650027, 16024032131582461863, 3698791599656620348}, - utils.NodeKey{12011265607191854676, 16995709771660398040, 10097323095148987140, 5271835541457063617}, - utils.NodeKey{13774341565485367328, 12574592232097177017, 13203533943886016969, 15689605306663468445}, - utils.NodeKey{17673889518692219847, 6954332541823247394, 954524149166700463, 10005323665613190430}, - utils.NodeKey{3390665384912132081, 273113266583762518, 15391923996500582086, 16937300536792272468}, - utils.NodeKey{3282365570547600329, 2269401659256178523, 12133143125482037239, 9431318293795439322}, - utils.NodeKey{10308056630015396434, 9302651503878791339, 1753436441509383136, 12655301298828119054}, - utils.NodeKey{4866095004323601391, 7715812469294898395, 13448442241363136994, 12560331541471347748}, - utils.NodeKey{9555357893875481640, 14044231432423634485, 2076021859364793876, 2098251167883986095}, - utils.NodeKey{13166561572768359955, 8774399027495495913, 17115924986198600732, 14679213838814779978}, - utils.NodeKey{1830856192880052688, 
16817835989594317540, 6792141515706996611, 13263912888227522233}, - utils.NodeKey{8580776493878106180, 13275268150083925070, 1298114825004489111, 6818033484593972896}, - utils.NodeKey{2562799672200229655, 18444468184514201072, 17883941549041529369, 4070387813552736545}, - utils.NodeKey{9268691730026813326, 11545055880246569979, 1187823334319829775, 17259421874098825958}, - utils.NodeKey{9994578653598857505, 13890799434279521010, 6971431511534499255, 9998397274436059169}, - utils.NodeKey{18287575540870662480, 11943532407729972209, 15340299232888708073, 10838674117466297196}, - utils.NodeKey{14761821088000158583, 964796443048506502, 5721781221240658401, 13211032425907534953}, - utils.NodeKey{18144880475727242601, 4972225809077124674, 14334455111087919063, 8111397810232896953}, - utils.NodeKey{16933784929062172058, 9574268379822183272, 4944644580885359493, 3289128208877342006}, - utils.NodeKey{8619895206600224966, 15003370087833528133, 8252241585179054714, 9201580897217580981}, - utils.NodeKey{16332458695522739594, 7936008380823170261, 1848556403564669799, 17993420240804923523}, - utils.NodeKey{6515233280772008301, 4313177990083710387, 4012549955023285042, 12696650320500651942}, - utils.NodeKey{6070193153822371132, 14833198544694594099, 8041604520195724295, 569408677969141468}, - utils.NodeKey{18121124933744588643, 14019823252026845797, 497098216249706813, 14507670067050817524}, - utils.NodeKey{10768458483543229763, 12393104588695647110, 7306859896719697582, 4178785141502415085}, - utils.NodeKey{7512520260500009967, 3751662918911081259, 9113133324668552163, 12072005766952080289}, - utils.NodeKey{5911840277575969690, 14631288768946722660, 9289463458792995190, 11361263549285604206}, - utils.NodeKey{5112807231234019664, 3952289862952962911, 12826043220050158925, 4455878876833215993}, - utils.NodeKey{16417603439618455829, 5362127645905990998, 10661203900902368419, 16076124886006448905}, - utils.NodeKey{11707747219427568787, 933117036015558858, 16439357349021750126, 14064521656451211675}, - utils.NodeKey{16208859541132551432, 6618660032536205153, 10385910322459208315, 8083618043937979883}, - utils.NodeKey{18055381843795531980, 13462709273291510955, 680380512647919587, 11342529403284590651}, - utils.NodeKey{2562799672200229655, 18444468184514201072, 17883941549041529369, 4070387813552736545}, - utils.NodeKey{16339509425341743973, 7562720126843377837, 6087776866015284100, 13287333209707648581}, - utils.NodeKey{1830856192880052688, 16817835989594317540, 6792141515706996611, 13263912888227522233}, - } - - valuesTemp := [][8]uint64{ - [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{2802548736, 3113182143, 10842021, 0, 0, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{552894464, 46566, 0, 0, 0, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{4, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{8, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{92883624, 129402807, 3239216982, 1921492768, 41803744, 3662741242, 922499619, 611206845}, - [8]uint64{2149, 0, 0, 0, 0, 0, 0, 
0}, - [8]uint64{1220686685, 2241513088, 3059933278, 877008478, 3450374550, 2577819195, 3646855908, 1714882695}, - [8]uint64{433, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{1807748760, 2873297298, 945201229, 411604167, 1063664423, 1763702642, 2637524917, 1284041408}, - [8]uint64{2112, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{2407450438, 2021315520, 3591671307, 1981785129, 893348094, 802675915, 3804752326, 2006944699}, - [8]uint64{2583, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{1400751902, 190749285, 93436423, 2918498711, 3630577401, 3928294404, 1037307865, 2336717508}, - [8]uint64{10043, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{2040622618, 1654767043, 2359080366, 3993652948, 2990917507, 41202511, 3266270425, 2537679611}, - [8]uint64{2971, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{2958032465, 981708138, 2081777150, 750201226, 3046928486, 2765783602, 2851559840, 1406574120}, - [8]uint64{23683, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{1741943335, 1540916232, 1327285029, 2450002482, 2695899944, 0, 0, 0}, - [8]uint64{3109587049, 2273239893, 220080300, 1823520391, 35937659, 0, 0, 0}, - [8]uint64{1677672755, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{337379899, 3225725520, 234013414, 1425864754, 2013026225, 0, 0, 0}, - [8]uint64{1031512883, 3743101878, 2828268606, 2468973124, 1081703471, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{256, 481884672, 2932392155, 111365737, 1511099657, 224351860, 164, 0}, - [8]uint64{632216695, 2300948800, 3904328458, 2148496278, 971473112, 0, 0, 0}, - [8]uint64{1031512883, 3743101878, 2828268606, 2468973124, 1081703471, 0, 0, 0}, - [8]uint64{4, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{2401446452, 1128446136, 4183588423, 3903755242, 16083787, 848717237, 2276372267, 2020002041}, - [8]uint64{2793696421, 3373683791, 3597304417, 3609426094, 2371386802, 1021540367, 828590482, 1599660962}, - [8]uint64{2793696421, 3373683791, 3597304417, 3609426094, 2371386802, 1021540367, 828590482, 1599660962}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{2793696421, 3373683791, 3597304417, 3609426094, 2371386802, 1021540367, 828590482, 1599660962}, - [8]uint64{2793696421, 3373683791, 3597304417, 3609426094, 2371386802, 1021540367, 828590482, 1599660962}, - [8]uint64{86400, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{1321730048, 465661287, 0, 0, 0, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{1480818688, 2647520856, 10842021, 0, 0, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{2407450438, 2021315520, 3591671307, 1981785129, 893348094, 802675915, 3804752326, 2006944699}, - [8]uint64{2583, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{2, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{4210873971, 1869123984, 4035019538, 1823911763, 1097145772, 827956438, 819220988, 1111695650}, - [8]uint64{20, 0, 0, 0, 0, 0, 0, 0}, - } - - values := make([]utils.NodeValue8, 0) - for _, vT := range valuesTemp { - values = append(values, utils.NodeValue8{ - big.NewInt(0).SetUint64(vT[0]), - big.NewInt(0).SetUint64(vT[1]), - big.NewInt(0).SetUint64(vT[2]), - big.NewInt(0).SetUint64(vT[3]), - big.NewInt(0).SetUint64(vT[4]), - big.NewInt(0).SetUint64(vT[5]), - big.NewInt(0).SetUint64(vT[6]), - big.NewInt(0).SetUint64(vT[7]), - }) - } - - smtIncremental := smt.NewSMT(nil, false) - smtBatch := smt.NewSMT(nil, false) - insertBatchCfg := 
smt.NewInsertBatchConfig(context.Background(), "", false) - for i, k := range keys { - smtIncremental.Insert(k, values[i]) - _, err := smtBatch.InsertBatch(insertBatchCfg, []*utils.NodeKey{&k}, []*utils.NodeValue8{&values[i]}, nil, nil) - assert.NilError(t, err) - - smtIncrementalRootHash, _ := smtIncremental.Db.GetLastRoot() - smtBatchRootHash, _ := smtBatch.Db.GetLastRoot() - assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtIncrementalRootHash)) - } - - smtIncremental.DumpTree() - fmt.Println() - smtBatch.DumpTree() - fmt.Println() - - assertSmtDbStructure(t, smtBatch, false) -} - -func TestBatchDelete(t *testing.T) { - keys := []utils.NodeKey{ - utils.NodeKey{10768458483543229763, 12393104588695647110, 7306859896719697582, 4178785141502415085}, - utils.NodeKey{7512520260500009967, 3751662918911081259, 9113133324668552163, 12072005766952080289}, - utils.NodeKey{4755722537892498409, 14621988746728905818, 15452350668109735064, 8819587610951133148}, - utils.NodeKey{6340777516277056037, 6264482673611175884, 1063722098746108599, 9062208133640346025}, - utils.NodeKey{6319287575763093444, 10809750365832475266, 6426706394050518186, 9463173325157812560}, - utils.NodeKey{15155415624738072211, 3736290188193138617, 8461047487943769832, 12188454615342744806}, - utils.NodeKey{15276670325385989216, 10944726794004460540, 9369946489424614125, 817372649097925902}, - utils.NodeKey{2562799672200229655, 18444468184514201072, 17883941549041529369, 407038781355273654}, - utils.NodeKey{10768458483543229763, 12393104588695647110, 7306859896719697582, 4178785141502415085}, - utils.NodeKey{7512520260500009967, 3751662918911081259, 9113133324668552163, 12072005766952080289}, - utils.NodeKey{4755722537892498409, 14621988746728905818, 15452350668109735064, 8819587610951133148}, - } - - valuesTemp := [][8]uint64{ - [8]uint64{0, 1, 0, 0, 0, 0, 0, 0}, - [8]uint64{0, 1, 0, 0, 0, 0, 0, 0}, - [8]uint64{0, 1, 0, 0, 0, 0, 0, 0}, - [8]uint64{1, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{103184848, 115613322, 0, 0, 0, 0, 0, 0}, - [8]uint64{2, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{3038602192, 2317586098, 794977000, 2442751483, 2309555181, 2028447238, 1023640522, 2687173865}, - [8]uint64{3100, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, - [8]uint64{0, 0, 0, 0, 0, 0, 0, 0}, - } - - values := make([]utils.NodeValue8, 0) - for _, vT := range valuesTemp { - values = append(values, utils.NodeValue8{ - big.NewInt(0).SetUint64(vT[0]), - big.NewInt(0).SetUint64(vT[1]), - big.NewInt(0).SetUint64(vT[2]), - big.NewInt(0).SetUint64(vT[3]), - big.NewInt(0).SetUint64(vT[4]), - big.NewInt(0).SetUint64(vT[5]), - big.NewInt(0).SetUint64(vT[6]), - big.NewInt(0).SetUint64(vT[7]), - }) - } - - smtIncremental := smt.NewSMT(nil, false) - smtBatch := smt.NewSMT(nil, false) - insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", false) - for i, k := range keys { - smtIncremental.Insert(k, values[i]) - _, err := smtBatch.InsertBatch(insertBatchCfg, []*utils.NodeKey{&k}, []*utils.NodeValue8{&values[i]}, nil, nil) - assert.NilError(t, err) - - smtIncrementalRootHash, _ := smtIncremental.Db.GetLastRoot() - smtBatchRootHash, _ := smtBatch.Db.GetLastRoot() - assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtIncrementalRootHash)) - } - - smtIncremental.DumpTree() - fmt.Println() - smtBatch.DumpTree() - fmt.Println() - - assertSmtDbStructure(t, smtBatch, false) -} - -func TestBatchRawInsert(t *testing.T) { - keysForBatch := 
[]*utils.NodeKey{} - valuesForBatch := []*utils.NodeValue8{} - - keysForIncremental := []utils.NodeKey{} - valuesForIncremental := []utils.NodeValue8{} - - smtIncremental := smt.NewSMT(nil, false) - smtBatch := smt.NewSMT(nil, false) - - rand.Seed(1) - size := 1 << 10 - for i := 0; i < size; i++ { - rawKey := big.NewInt(rand.Int63()) - rawValue := big.NewInt(rand.Int63()) - - k := utils.ScalarToNodeKey(rawKey) - vArray := utils.ScalarToArrayBig(rawValue) - v, _ := utils.NodeValue8FromBigIntArray(vArray) - - keysForBatch = append(keysForBatch, &k) - valuesForBatch = append(valuesForBatch, v) - - keysForIncremental = append(keysForIncremental, k) - valuesForIncremental = append(valuesForIncremental, *v) - - } - - startTime := time.Now() - for i := range keysForIncremental { - smtIncremental.Insert(keysForIncremental[i], valuesForIncremental[i]) - } - t.Logf("Incremental insert %d values in %v\n", len(keysForIncremental), time.Since(startTime)) - - startTime = time.Now() - - insertBatchCfg := smt.NewInsertBatchConfig(context.Background(), "", true) - _, err := smtBatch.InsertBatch(insertBatchCfg, keysForBatch, valuesForBatch, nil, nil) - assert.NilError(t, err) - t.Logf("Batch insert %d values in %v\n", len(keysForBatch), time.Since(startTime)) - - smtIncrementalRootHash, _ := smtIncremental.Db.GetLastRoot() - smtBatchRootHash, _ := smtBatch.Db.GetLastRoot() - assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtIncrementalRootHash)) - - assertSmtDbStructure(t, smtBatch, false) - - // DELETE - keysForBatchDelete := []*utils.NodeKey{} - valuesForBatchDelete := []*utils.NodeValue8{} - - keysForIncrementalDelete := []utils.NodeKey{} - valuesForIncrementalDelete := []utils.NodeValue8{} - - sizeToDelete := 1 << 14 - for i := 0; i < sizeToDelete; i++ { - rawValue := big.NewInt(0) - vArray := utils.ScalarToArrayBig(rawValue) - v, _ := utils.NodeValue8FromBigIntArray(vArray) - - deleteIndex := rand.Intn(size) - - keyForBatchDelete := keysForBatch[deleteIndex] - keyForIncrementalDelete := keysForIncremental[deleteIndex] - - keysForBatchDelete = append(keysForBatchDelete, keyForBatchDelete) - valuesForBatchDelete = append(valuesForBatchDelete, v) - - keysForIncrementalDelete = append(keysForIncrementalDelete, keyForIncrementalDelete) - valuesForIncrementalDelete = append(valuesForIncrementalDelete, *v) - } - - startTime = time.Now() - for i := range keysForIncrementalDelete { - smtIncremental.Insert(keysForIncrementalDelete[i], valuesForIncrementalDelete[i]) - } - t.Logf("Incremental delete %d values in %v\n", len(keysForIncrementalDelete), time.Since(startTime)) - - startTime = time.Now() - - _, err = smtBatch.InsertBatch(insertBatchCfg, keysForBatchDelete, valuesForBatchDelete, nil, nil) - assert.NilError(t, err) - t.Logf("Batch delete %d values in %v\n", len(keysForBatchDelete), time.Since(startTime)) - - assertSmtDbStructure(t, smtBatch, false) -} - -func TestCompareAllTreesInsertTimesAndFinalHashesUsingDiskDb(t *testing.T) { - incrementalDbPath := "/tmp/smt-incremental" - smtIncrementalDb, smtIncrementalTx, smtIncrementalSmtDb := initDb(t, incrementalDbPath) - - bulkDbPath := "/tmp/smt-bulk" - smtBulkDb, smtBulkTx, smtBulkSmtDb := initDb(t, bulkDbPath) - - batchDbPath := "/tmp/smt-batch" - smtBatchDb, smtBatchTx, smtBatchSmtDb := initDb(t, batchDbPath) - - smtIncremental := smt.NewSMT(smtIncrementalSmtDb, false) - smtBulk := smt.NewSMT(smtBulkSmtDb, false) - smtBatch := smt.NewSMT(smtBatchSmtDb, false) - - compareAllTreesInsertTimesAndFinalHashes(t, 
smtIncremental, smtBulk, smtBatch) - - smtIncrementalTx.Commit() - smtBulkTx.Commit() - smtBatchTx.Commit() - t.Cleanup(func() { - smtIncrementalDb.Close() - smtBulkDb.Close() - smtBatchDb.Close() - os.RemoveAll(incrementalDbPath) - os.RemoveAll(bulkDbPath) - os.RemoveAll(batchDbPath) - }) -} - -func TestCompareAllTreesInsertTimesAndFinalHashesUsingInMemoryDb(t *testing.T) { - smtIncremental := smt.NewSMT(nil, false) - smtBulk := smt.NewSMT(nil, false) - smtBatch := smt.NewSMT(nil, false) - - compareAllTreesInsertTimesAndFinalHashes(t, smtIncremental, smtBulk, smtBatch) -} - -func compareAllTreesInsertTimesAndFinalHashes(t *testing.T, smtIncremental, smtBulk, smtBatch *smt.SMT) { - batchInsertDataHolders, totalInserts := prepareData() - ctx := context.Background() - var incrementalError error - - accChanges := make(map[libcommon.Address]*accounts.Account) - codeChanges := make(map[libcommon.Address]string) - storageChanges := make(map[libcommon.Address]map[string]string) - - for _, batchInsertDataHolder := range batchInsertDataHolders { - accChanges[batchInsertDataHolder.AddressAccount] = &batchInsertDataHolder.acc - codeChanges[batchInsertDataHolder.AddressContract] = batchInsertDataHolder.Bytecode - storageChanges[batchInsertDataHolder.AddressContract] = batchInsertDataHolder.Storage - } - - startTime := time.Now() - for addr, acc := range accChanges { - if err := smtIncremental.SetAccountStorage(addr, acc); err != nil { - incrementalError = err - } - } - - for addr, code := range codeChanges { - if err := smtIncremental.SetContractBytecode(addr.String(), code); err != nil { - incrementalError = err - } - } - - for addr, storage := range storageChanges { - if _, err := smtIncremental.SetContractStorage(addr.String(), storage, nil); err != nil { - incrementalError = err - } - } - - assert.NilError(t, incrementalError) - t.Logf("Incremental insert %d values in %v\n", totalInserts, time.Since(startTime)) - - startTime = time.Now() - keyPointers, valuePointers, err := smtBatch.SetStorage(ctx, "", accChanges, codeChanges, storageChanges) - assert.NilError(t, err) - t.Logf("Batch insert %d values in %v\n", totalInserts, time.Since(startTime)) - - keys := []utils.NodeKey{} - for i, key := range keyPointers { - v := valuePointers[i] - if !v.IsZero() { - smtBulk.Db.InsertAccountValue(*key, *v) - keys = append(keys, *key) - } - } - startTime = time.Now() - smtBulk.GenerateFromKVBulk(ctx, "", keys) - t.Logf("Bulk insert %d values in %v\n", totalInserts, time.Since(startTime)) - - smtIncrementalRootHash, _ := smtIncremental.Db.GetLastRoot() - smtBatchRootHash, _ := smtBatch.Db.GetLastRoot() - smtBulkRootHash, _ := smtBulk.Db.GetLastRoot() - assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtIncrementalRootHash)) - assert.Equal(t, utils.ConvertBigIntToHex(smtBulkRootHash), utils.ConvertBigIntToHex(smtIncrementalRootHash)) - - assertSmtDbStructure(t, smtBatch, true) -} - -func initDb(t *testing.T, dbPath string) (kv.RwDB, kv.RwTx, *db.EriDb) { - ctx := context.Background() - - os.RemoveAll(dbPath) - - dbOpts := mdbx.NewMDBX(log.Root()).Path(dbPath).Label(kv.ChainDB).GrowthStep(16 * datasize.MB).RoTxsLimiter(semaphore.NewWeighted(128)) - database, err := dbOpts.Open(ctx) - if err != nil { - t.Fatalf("Cannot create db %e", err) - } - - migrator := migrations.NewMigrator(kv.ChainDB) - if err := migrator.VerifyVersion(database); err != nil { - t.Fatalf("Cannot verify db version %e", err) - } - // if err = migrator.Apply(database, dbPath); err != nil { - // 
t.Fatalf("Cannot migrate db %e", err) - // } - - // if err := database.Update(context.Background(), func(tx kv.RwTx) (err error) { - // return params.SetErigonVersion(tx, "test") - // }); err != nil { - // t.Fatalf("Cannot update db") - // } - - dbTransaction, err := database.BeginRw(ctx) - if err != nil { - t.Fatalf("Cannot craete db transaction") - } - - db.CreateEriDbBuckets(dbTransaction) - return database, dbTransaction, db.NewEriDb(dbTransaction) -} - -func prepareData() ([]*BatchInsertDataHolder, int) { - treeSize := 150 - storageSize := 96 - batchInsertDataHolders := make([]*BatchInsertDataHolder, 0) - rand.Seed(1) - for i := 0; i < treeSize; i++ { - storage := make(map[string]string) - addressAccountBytes := make([]byte, 20) - addressContractBytes := make([]byte, 20) - storageKeyBytes := make([]byte, 20) - storageValueBytes := make([]byte, 20) - rand.Read(addressAccountBytes) - rand.Read(addressContractBytes) - - for j := 0; j < storageSize; j++ { - rand.Read(storageKeyBytes) - rand.Read(storageValueBytes) - storage[libcommon.BytesToAddress(storageKeyBytes).Hex()] = libcommon.BytesToAddress(storageValueBytes).Hex() - } - - acc := accounts.NewAccount() - acc.Balance = *uint256.NewInt(rand.Uint64()) - acc.Nonce = rand.Uint64() - - batchInsertDataHolders = append(batchInsertDataHolders, &BatchInsertDataHolder{ - acc: acc, - AddressAccount: libcommon.BytesToAddress(addressAccountBytes), - AddressContract: libcommon.BytesToAddress(addressContractBytes), - Bytecode: "0x60806040526004361061007b5760003560e01c80639623609d1161004e5780639623609d1461012b57806399a88ec41461013e578063f2fde38b1461015e578063f3b7dead1461017e57600080fd5b8063204e1c7a14610080578063715018a6146100c95780637eff275e146100e05780638da5cb5b14610100575b600080fd5b34801561008c57600080fd5b506100a061009b366004610608565b61019e565b60405173ffffffffffffffffffffffffffffffffffffffff909116815260200160405180910390f35b3480156100d557600080fd5b506100de610255565b005b3480156100ec57600080fd5b506100de6100fb36600461062c565b610269565b34801561010c57600080fd5b5060005473ffffffffffffffffffffffffffffffffffffffff166100a0565b6100de610139366004610694565b6102f7565b34801561014a57600080fd5b506100de61015936600461062c565b61038c565b34801561016a57600080fd5b506100de610179366004610608565b6103e8565b34801561018a57600080fd5b506100a0610199366004610608565b6104a4565b60008060008373ffffffffffffffffffffffffffffffffffffffff166040516101ea907f5c60da1b00000000000000000000000000000000000000000000000000000000815260040190565b600060405180830381855afa9150503d8060008114610225576040519150601f19603f3d011682016040523d82523d6000602084013e61022a565b606091505b50915091508161023957600080fd5b8080602001905181019061024d9190610788565b949350505050565b61025d6104f0565b6102676000610571565b565b6102716104f0565b6040517f8f28397000000000000000000000000000000000000000000000000000000000815273ffffffffffffffffffffffffffffffffffffffff8281166004830152831690638f283970906024015b600060405180830381600087803b1580156102db57600080fd5b505af11580156102ef573d6000803e3d6000fd5b505050505050565b6102ff6104f0565b6040517f4f1ef28600000000000000000000000000000000000000000000000000000000815273ffffffffffffffffffffffffffffffffffffffff841690634f1ef28690349061035590869086906004016107a5565b6000604051808303818588803b15801561036e57600080fd5b505af1158015610382573d6000803e3d6000fd5b5050505050505050565b6103946104f0565b6040517f3659cfe600000000000000000000000000000000000000000000000000000000815273ffffffffffffffffffffffffffffffffffffffff8281166004830152831690633659cfe6906024016102c1565b6103f06104f0565b73ffffffffffffffffffffffffffffffffffffff
ff8116610498576040517f08c379a000000000000000000000000000000000000000000000000000000000815260206004820152602660248201527f4f776e61626c653a206e6577206f776e657220697320746865207a65726f206160448201527f646472657373000000000000000000000000000000000000000000000000000060648201526084015b60405180910390fd5b6104a181610571565b50565b60008060008373ffffffffffffffffffffffffffffffffffffffff166040516101ea907ff851a44000000000000000000000000000000000000000000000000000000000815260040190565b60005473ffffffffffffffffffffffffffffffffffffffff163314610267576040517f08c379a000000000000000000000000000000000000000000000000000000000815260206004820181905260248201527f4f776e61626c653a2063616c6c6572206973206e6f7420746865206f776e6572604482015260640161048f565b6000805473ffffffffffffffffffffffffffffffffffffffff8381167fffffffffffffffffffffffff0000000000000000000000000000000000000000831681178455604051919092169283917f8be0079c531659141344cd1fd0a4f28419497f9722a3daafe3b4186f6b6457e09190a35050565b73ffffffffffffffffffffffffffffffffffffffff811681146104a157600080fd5b60006020828403121561061a57600080fd5b8135610625816105e6565b9392505050565b6000806040838503121561063f57600080fd5b823561064a816105e6565b9150602083013561065a816105e6565b809150509250929050565b7f4e487b7100000000000000000000000000000000000000000000000000000000600052604160045260246000fd5b6000806000606084860312156106a957600080fd5b83356106b4816105e6565b925060208401356106c4816105e6565b9150604084013567ffffffffffffffff808211156106e157600080fd5b818601915086601f8301126106f557600080fd5b81358181111561070757610707610665565b604051601f82017fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe0908116603f0116810190838211818310171561074d5761074d610665565b8160405282815289602084870101111561076657600080fd5b8260208601602083013760006020848301015280955050505050509250925092565b60006020828403121561079a57600080fd5b8151610625816105e6565b73ffffffffffffffffffffffffffffffffffffffff8316815260006020604081840152835180604085015260005b818110156107ef578581018301518582016060015282016107d3565b5060006060828601015260607fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe0601f83011685010192505050939250505056fea2646970667358221220372a0e10eebea1b7fa43ae4c976994e6ed01d85eedc3637b83f01d3f06be442064736f6c63430008110033", - Storage: storage, - }) - } - - return batchInsertDataHolders, treeSize*4 + treeSize*storageSize -} - -func assertSmtDbStructure(t *testing.T, s *smt.SMT, testMetadata bool) { - smtBatchRootHash, _ := s.Db.GetLastRoot() - - actualDb, ok := s.Db.(*db.MemDb) - if !ok { - return - } - - usedNodeHashesMap := make(map[string]*utils.NodeKey) - assertSmtTreeDbStructure(t, s, utils.ScalarToRoot(smtBatchRootHash), usedNodeHashesMap) - - // EXPLAIN THE LINE BELOW: db could have more values because values' hashes are not deleted - assert.Equal(t, true, len(actualDb.Db)-len(usedNodeHashesMap) >= 0) - for k := range usedNodeHashesMap { - _, found := actualDb.Db[k] - assert.Equal(t, true, found) - } - - totalLeaves := assertHashToKeyDbStrcture(t, s, utils.ScalarToRoot(smtBatchRootHash), testMetadata) - assert.Equal(t, totalLeaves, len(actualDb.DbHashKey)) - if testMetadata { - assert.Equal(t, totalLeaves, len(actualDb.DbKeySource)) - } - - assertTraverse(t, s) -} - -func assertSmtTreeDbStructure(t *testing.T, s *smt.SMT, nodeHash utils.NodeKey, usedNodeHashesMap map[string]*utils.NodeKey) { - if nodeHash.IsZero() { - return - } - - dbNodeValue, err := s.Db.Get(nodeHash) - assert.NilError(t, err) - - nodeHashHex := utils.ConvertBigIntToHex(utils.ArrayToScalar(nodeHash[:])) - 
usedNodeHashesMap[nodeHashHex] = &nodeHash - - if dbNodeValue.IsFinalNode() { - nodeValueHash := utils.NodeKeyFromBigIntArray(dbNodeValue[4:8]) - dbNodeValue, err = s.Db.Get(nodeValueHash) - assert.NilError(t, err) - - nodeHashHex := utils.ConvertBigIntToHex(utils.ArrayToScalar(nodeValueHash[:])) - usedNodeHashesMap[nodeHashHex] = &nodeValueHash - return - } - - assertSmtTreeDbStructure(t, s, utils.NodeKeyFromBigIntArray(dbNodeValue[0:4]), usedNodeHashesMap) - assertSmtTreeDbStructure(t, s, utils.NodeKeyFromBigIntArray(dbNodeValue[4:8]), usedNodeHashesMap) -} - -func assertHashToKeyDbStrcture(t *testing.T, smtBatch *smt.SMT, nodeHash utils.NodeKey, testMetadata bool) int { - if nodeHash.IsZero() { - return 0 - } - - dbNodeValue, err := smtBatch.Db.Get(nodeHash) - assert.NilError(t, err) - - if dbNodeValue.IsFinalNode() { - memDb := smtBatch.Db.(*db.MemDb) - - nodeKey, err := smtBatch.Db.GetHashKey(nodeHash) - assert.NilError(t, err) - - keyConc := utils.ArrayToScalar(nodeHash[:]) - k := utils.ConvertBigIntToHex(keyConc) - _, found := memDb.DbHashKey[k] - assert.Equal(t, found, true) - - if testMetadata { - keyConc = utils.ArrayToScalar(nodeKey[:]) - - _, found = memDb.DbKeySource[keyConc.String()] - assert.Equal(t, found, true) - } - return 1 - } - - return assertHashToKeyDbStrcture(t, smtBatch, utils.NodeKeyFromBigIntArray(dbNodeValue[0:4]), testMetadata) + assertHashToKeyDbStrcture(t, smtBatch, utils.NodeKeyFromBigIntArray(dbNodeValue[4:8]), testMetadata) -} - -func assertTraverse(t *testing.T, s *smt.SMT) { - smtBatchRootHash, _ := s.Db.GetLastRoot() - - ctx := context.Background() - action := func(prefix []byte, k utils.NodeKey, v utils.NodeValue12) (bool, error) { - if v.IsFinalNode() { - valHash := v.Get4to8() - v, err := s.Db.Get(*valHash) - if err != nil { - return false, err - } - - if v[0] == nil { - return false, fmt.Errorf("value is missing in the db") - } - - vInBytes := utils.ArrayBigToScalar(utils.BigIntArrayFromNodeValue8(v.GetNodeValue8())).Bytes() - if vInBytes == nil { - return false, fmt.Errorf("error in converting to bytes") - } - - return false, nil - } - - return true, nil - } - err := s.Traverse(ctx, smtBatchRootHash, action) - assert.NilError(t, err) -} diff --git a/smt/pkg/smt/smt_batch_types_test.go b/smt/pkg/smt/smt_batch_types_test.go new file mode 100644 index 00000000000..253d6073ed7 --- /dev/null +++ b/smt/pkg/smt/smt_batch_types_test.go @@ -0,0 +1,14 @@ +package smt_test + +import ( + libcommon "github.com/ledgerwatch/erigon-lib/common" + "github.com/ledgerwatch/erigon/core/types/accounts" +) + +type BatchInsertDataHolder struct { + acc accounts.Account + AddressAccount libcommon.Address + AddressContract libcommon.Address + Bytecode string + Storage map[string]string +} diff --git a/smt/pkg/smt/smt_batch_utils_test.go b/smt/pkg/smt/smt_batch_utils_test.go new file mode 100644 index 00000000000..05dc69f3b64 --- /dev/null +++ b/smt/pkg/smt/smt_batch_utils_test.go @@ -0,0 +1,89 @@ +package smt_test + +import ( + "context" + "math/rand" + "os" + "testing" + + "github.com/c2h5oh/datasize" + "github.com/holiman/uint256" + libcommon "github.com/ledgerwatch/erigon-lib/common" + "github.com/ledgerwatch/erigon-lib/kv" + "github.com/ledgerwatch/erigon-lib/kv/mdbx" + "github.com/ledgerwatch/erigon/core/types/accounts" + "github.com/ledgerwatch/erigon/migrations" + "github.com/ledgerwatch/erigon/smt/pkg/db" + "github.com/ledgerwatch/log/v3" + "golang.org/x/sync/semaphore" +) + +func initDb(t *testing.T, dbPath string) (kv.RwDB, kv.RwTx, *db.EriDb) { + ctx := 
context.Background() + + os.RemoveAll(dbPath) + + dbOpts := mdbx.NewMDBX(log.Root()).Path(dbPath).Label(kv.ChainDB).GrowthStep(16 * datasize.MB).RoTxsLimiter(semaphore.NewWeighted(128)) + database, err := dbOpts.Open(ctx) + if err != nil { + t.Fatalf("Cannot create db %v", err) + } + + migrator := migrations.NewMigrator(kv.ChainDB) + if err := migrator.VerifyVersion(database); err != nil { + t.Fatalf("Cannot verify db version %v", err) + } + // if err = migrator.Apply(database, dbPath); err != nil { + // t.Fatalf("Cannot migrate db %v", err) + // } + + // if err := database.Update(context.Background(), func(tx kv.RwTx) (err error) { + // return params.SetErigonVersion(tx, "test") + // }); err != nil { + // t.Fatalf("Cannot update db") + // } + + dbTransaction, err := database.BeginRw(ctx) + if err != nil { + t.Fatalf("Cannot create db transaction") + } + + db.CreateEriDbBuckets(dbTransaction) + return database, dbTransaction, db.NewEriDb(dbTransaction) +} + +func prepareData() ([]*BatchInsertDataHolder, int) { + treeSize := 150 + storageSize := 96 + batchInsertDataHolders := make([]*BatchInsertDataHolder, 0) + rand.Seed(1) + for i := 0; i < treeSize; i++ { + storage := make(map[string]string) + addressAccountBytes := make([]byte, 20) + addressContractBytes := make([]byte, 20) + storageKeyBytes := make([]byte, 20) + storageValueBytes := make([]byte, 20) + rand.Read(addressAccountBytes) + rand.Read(addressContractBytes) + + for j := 0; j < storageSize; j++ { + rand.Read(storageKeyBytes) + rand.Read(storageValueBytes) + storage[libcommon.BytesToAddress(storageKeyBytes).Hex()] = libcommon.BytesToAddress(storageValueBytes).Hex() + } + + acc := accounts.NewAccount() + acc.Balance = *uint256.NewInt(rand.Uint64()) + acc.Nonce = rand.Uint64() + + batchInsertDataHolders = append(batchInsertDataHolders, &BatchInsertDataHolder{ + acc: acc, + AddressAccount: libcommon.BytesToAddress(addressAccountBytes), + AddressContract: libcommon.BytesToAddress(addressContractBytes), + Bytecode:
"0x60806040526004361061007b5760003560e01c80639623609d1161004e5780639623609d1461012b57806399a88ec41461013e578063f2fde38b1461015e578063f3b7dead1461017e57600080fd5b8063204e1c7a14610080578063715018a6146100c95780637eff275e146100e05780638da5cb5b14610100575b600080fd5b34801561008c57600080fd5b506100a061009b366004610608565b61019e565b60405173ffffffffffffffffffffffffffffffffffffffff909116815260200160405180910390f35b3480156100d557600080fd5b506100de610255565b005b3480156100ec57600080fd5b506100de6100fb36600461062c565b610269565b34801561010c57600080fd5b5060005473ffffffffffffffffffffffffffffffffffffffff166100a0565b6100de610139366004610694565b6102f7565b34801561014a57600080fd5b506100de61015936600461062c565b61038c565b34801561016a57600080fd5b506100de610179366004610608565b6103e8565b34801561018a57600080fd5b506100a0610199366004610608565b6104a4565b60008060008373ffffffffffffffffffffffffffffffffffffffff166040516101ea907f5c60da1b00000000000000000000000000000000000000000000000000000000815260040190565b600060405180830381855afa9150503d8060008114610225576040519150601f19603f3d011682016040523d82523d6000602084013e61022a565b606091505b50915091508161023957600080fd5b8080602001905181019061024d9190610788565b949350505050565b61025d6104f0565b6102676000610571565b565b6102716104f0565b6040517f8f28397000000000000000000000000000000000000000000000000000000000815273ffffffffffffffffffffffffffffffffffffffff8281166004830152831690638f283970906024015b600060405180830381600087803b1580156102db57600080fd5b505af11580156102ef573d6000803e3d6000fd5b505050505050565b6102ff6104f0565b6040517f4f1ef28600000000000000000000000000000000000000000000000000000000815273ffffffffffffffffffffffffffffffffffffffff841690634f1ef28690349061035590869086906004016107a5565b6000604051808303818588803b15801561036e57600080fd5b505af1158015610382573d6000803e3d6000fd5b5050505050505050565b6103946104f0565b6040517f3659cfe600000000000000000000000000000000000000000000000000000000815273ffffffffffffffffffffffffffffffffffffffff8281166004830152831690633659cfe6906024016102c1565b6103f06104f0565b73ffffffffffffffffffffffffffffffffffffffff8116610498576040517f08c379a000000000000000000000000000000000000000000000000000000000815260206004820152602660248201527f4f776e61626c653a206e6577206f776e657220697320746865207a65726f206160448201527f646472657373000000000000000000000000000000000000000000000000000060648201526084015b60405180910390fd5b6104a181610571565b50565b60008060008373ffffffffffffffffffffffffffffffffffffffff166040516101ea907ff851a44000000000000000000000000000000000000000000000000000000000815260040190565b60005473ffffffffffffffffffffffffffffffffffffffff163314610267576040517f08c379a000000000000000000000000000000000000000000000000000000000815260206004820181905260248201527f4f776e61626c653a2063616c6c6572206973206e6f7420746865206f776e6572604482015260640161048f565b6000805473ffffffffffffffffffffffffffffffffffffffff8381167fffffffffffffffffffffffff0000000000000000000000000000000000000000831681178455604051919092169283917f8be0079c531659141344cd1fd0a4f28419497f9722a3daafe3b4186f6b6457e09190a35050565b73ffffffffffffffffffffffffffffffffffffffff811681146104a157600080fd5b60006020828403121561061a57600080fd5b8135610625816105e6565b9392505050565b6000806040838503121561063f57600080fd5b823561064a816105e6565b9150602083013561065a816105e6565b809150509250929050565b7f4e487b7100000000000000000000000000000000000000000000000000000000600052604160045260246000fd5b6000806000606084860312156106a957600080fd5b83356106b4816105e6565b925060208401356106c4816105e6565b9150604084013567ffffffffffffffff808211156106e157600080fd5b818601915086601f8301126106f5
57600080fd5b81358181111561070757610707610665565b604051601f82017fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe0908116603f0116810190838211818310171561074d5761074d610665565b8160405282815289602084870101111561076657600080fd5b8260208601602083013760006020848301015280955050505050509250925092565b60006020828403121561079a57600080fd5b8151610625816105e6565b73ffffffffffffffffffffffffffffffffffffffff8316815260006020604081840152835180604085015260005b818110156107ef578581018301518582016060015282016107d3565b5060006060828601015260607fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe0601f83011685010192505050939250505056fea2646970667358221220372a0e10eebea1b7fa43ae4c976994e6ed01d85eedc3637b83f01d3f06be442064736f6c63430008110033", + Storage: storage, + }) + } + + return batchInsertDataHolders, treeSize*4 + treeSize*storageSize +} diff --git a/smt/pkg/smt/smt_create.go b/smt/pkg/smt/smt_create.go index f204a9df29d..da017b96d82 100644 --- a/smt/pkg/smt/smt_create.go +++ b/smt/pkg/smt/smt_create.go @@ -64,7 +64,7 @@ func (s *SMT) GenerateFromKVBulk(ctx context.Context, logPrefix string, nodeKeys maxReachedLevel := 0 - deletesWorker := utils.NewWorker(ctx, "smt_save_finished", 1000) + deletesWorker := utils.NewWorker(ctx, "smt_save_finished", 5) // start a worker to delete finished parts of the tree and return values to save to the db wg := sync.WaitGroup{} diff --git a/smt/pkg/utils/utils.go b/smt/pkg/utils/utils.go index f932d278c9f..909e0c87b13 100644 --- a/smt/pkg/utils/utils.go +++ b/smt/pkg/utils/utils.go @@ -47,11 +47,13 @@ const ( var ( LeafCapacity = [4]uint64{1, 0, 0, 0} BranchCapacity = [4]uint64{0, 0, 0, 0} - hashFunc = poseidon.Hash + hashFunc = poseidon.HashWithResult ) -func Hash(in [8]uint64, capacity [4]uint64) ([4]uint64, error) { - return hashFunc(in, capacity) +func Hash(in [8]uint64, capacity [4]uint64) [4]uint64 { + var result [4]uint64 = [4]uint64{0, 0, 0, 0} + hashFunc(&in, &capacity, &result) + return result } func (nk *NodeKey) IsZero() bool { @@ -568,30 +570,30 @@ func StringToH4(s string) ([4]uint64, error) { return res, nil } -func KeyEthAddrBalance(ethAddr string) (NodeKey, error) { +func KeyEthAddrBalance(ethAddr string) NodeKey { return Key(ethAddr, KEY_BALANCE) } -func KeyEthAddrNonce(ethAddr string) (NodeKey, error) { +func KeyEthAddrNonce(ethAddr string) NodeKey { return Key(ethAddr, KEY_NONCE) } -func KeyContractCode(ethAddr string) (NodeKey, error) { +func KeyContractCode(ethAddr string) NodeKey { return Key(ethAddr, SC_CODE) } -func KeyContractLength(ethAddr string) (NodeKey, error) { +func KeyContractLength(ethAddr string) NodeKey { return Key(ethAddr, SC_LENGTH) } -func Key(ethAddr string, c int) (NodeKey, error) { +func Key(ethAddr string, c int) NodeKey { a := ConvertHexToBigInt(ethAddr) add := ScalarToArrayBig(a) key1 := NodeValue8{add[0], add[1], add[2], add[3], add[4], add[5], big.NewInt(int64(c)), big.NewInt(0)} key1Capacity, err := StringToH4(HASH_POSEIDON_ALL_ZEROES) if err != nil { - return NodeKey{}, err + return NodeKey{} } return Hash(key1.ToUintArray(), key1Capacity) @@ -609,8 +611,8 @@ func KeyBig(k *big.Int, c int) (*NodeKey, error) { return nil, err } - hk0, err := Hash(key1.ToUintArray(), key1Capacity) - return &NodeKey{hk0[0], hk0[1], hk0[2], hk0[3]}, err + hk0 := Hash(key1.ToUintArray(), key1Capacity) + return &NodeKey{hk0[0], hk0[1], hk0[2], hk0[3]}, nil } func StrValToBigInt(v string) (*big.Int, bool) { @@ -621,25 +623,21 @@ func StrValToBigInt(v string) (*big.Int, bool) { return new(big.Int).SetString(v, 10) } -func 
KeyContractStorage(ethAddr []*big.Int, storagePosition string) (NodeKey, error) { +func KeyContractStorage(ethAddr []*big.Int, storagePosition string) NodeKey { sp, _ := StrValToBigInt(storagePosition) spArray, err := NodeValue8FromBigIntArray(ScalarToArrayBig(sp)) if err != nil { - return NodeKey{}, err + return NodeKey{} } - hk0, err := Hash(spArray.ToUintArray(), [4]uint64{0, 0, 0, 0}) - if err != nil { - return NodeKey{}, err - } + hk0 := Hash(spArray.ToUintArray(), [4]uint64{0, 0, 0, 0}) key1 := NodeValue8{ethAddr[0], ethAddr[1], ethAddr[2], ethAddr[3], ethAddr[4], ethAddr[5], big.NewInt(int64(SC_STORAGE)), big.NewInt(0)} return Hash(key1.ToUintArray(), hk0) } -func HashContractBytecode(bc string) (string, error) { - var err error +func HashContractBytecode(bc string) string { bytecode := bc if strings.HasPrefix(bc, "0x") { @@ -700,15 +698,10 @@ func HashContractBytecode(bc string) (string, error) { var capacity [4]uint64 copy(capacity[:], elementsToHash[:4]) - tmpHash, err = Hash(in, capacity) - if err != nil { - return "", err - } + tmpHash = Hash(in, capacity) } - hex := ConvertBigIntToHex(ArrayToScalar(tmpHash[:])) - - return hex, err + return ConvertBigIntToHex(ArrayToScalar(tmpHash[:])) } func ResizeHashTo32BytesByPrefixingWithZeroes(hashValue []byte) []byte { diff --git a/turbo/cli/default_flags.go b/turbo/cli/default_flags.go index 8e826c2d1f9..f4c744a55b6 100644 --- a/turbo/cli/default_flags.go +++ b/turbo/cli/default_flags.go @@ -190,6 +190,7 @@ var DefaultFlags = []cli.Flag{ &utils.L1HighestBlockTypeFlag, &utils.L1MaticContractAddressFlag, &utils.L1FirstBlockFlag, + &utils.L1FinalizedBlockRequirementFlag, &utils.L1ContractAddressCheckFlag, &utils.RpcRateLimitsFlag, &utils.RpcGetBatchWitnessConcurrencyLimitFlag, @@ -202,6 +203,9 @@ var DefaultFlags = []cli.Flag{ &utils.SequencerBatchVerificationTimeout, &utils.SequencerTimeoutOnEmptyTxPool, &utils.SequencerHaltOnBatchNumber, + &utils.SequencerResequence, + &utils.SequencerResequenceStrict, + &utils.SequencerResequenceReuseL1InfoIndex, &utils.ExecutorUrls, &utils.ExecutorStrictMode, &utils.ExecutorRequestTimeout, diff --git a/turbo/cli/flags_zkevm.go b/turbo/cli/flags_zkevm.go index 1784d3b7443..e33d68515e1 100644 --- a/turbo/cli/flags_zkevm.go +++ b/turbo/cli/flags_zkevm.go @@ -129,6 +129,7 @@ func ApplyFlagsForZkConfig(ctx *cli.Context, cfg *ethconfig.Config) { L1HighestBlockType: ctx.String(utils.L1HighestBlockTypeFlag.Name), L1MaticContractAddress: libcommon.HexToAddress(ctx.String(utils.L1MaticContractAddressFlag.Name)), L1FirstBlock: ctx.Uint64(utils.L1FirstBlockFlag.Name), + L1FinalizedBlockRequirement: ctx.Uint64(utils.L1FinalizedBlockRequirementFlag.Name), L1ContractAddressCheck: ctx.Bool(utils.L1ContractAddressCheckFlag.Name), RpcRateLimits: ctx.Int(utils.RpcRateLimitsFlag.Name), RpcGetBatchWitnessConcurrencyLimit: ctx.Int(utils.RpcGetBatchWitnessConcurrencyLimitFlag.Name), @@ -141,6 +142,9 @@ func ApplyFlagsForZkConfig(ctx *cli.Context, cfg *ethconfig.Config) { SequencerBatchVerificationTimeout: sequencerBatchVerificationTimeout, SequencerTimeoutOnEmptyTxPool: sequencerTimeoutOnEmptyTxPool, SequencerHaltOnBatchNumber: ctx.Uint64(utils.SequencerHaltOnBatchNumber.Name), + SequencerResequence: ctx.Bool(utils.SequencerResequence.Name), + SequencerResequenceStrict: ctx.Bool(utils.SequencerResequenceStrict.Name), + SequencerResequenceReuseL1InfoIndex: ctx.Bool(utils.SequencerResequenceReuseL1InfoIndex.Name), ExecutorUrls: strings.Split(strings.ReplaceAll(ctx.String(utils.ExecutorUrls.Name), " ", ""), ","), 
ExecutorStrictMode: ctx.Bool(utils.ExecutorStrictMode.Name), ExecutorRequestTimeout: ctx.Duration(utils.ExecutorRequestTimeout.Name), diff --git a/turbo/jsonrpc/zkevm_api.go b/turbo/jsonrpc/zkevm_api.go index ece4c7c60cb..0a341686a5f 100644 --- a/turbo/jsonrpc/zkevm_api.go +++ b/turbo/jsonrpc/zkevm_api.go @@ -73,6 +73,10 @@ type ZkEvmAPI interface { GetBatchCountersByNumber(ctx context.Context, batchNumRpc rpc.BlockNumber) (res json.RawMessage, err error) GetExitRootTable(ctx context.Context) ([]l1InfoTreeData, error) GetVersionHistory(ctx context.Context) (json.RawMessage, error) + GetForkId(ctx context.Context) (hexutil.Uint64, error) + GetForkById(ctx context.Context, forkId hexutil.Uint64) (res json.RawMessage, err error) + GetForkIdByBatchNumber(ctx context.Context, batchNumber rpc.BlockNumber) (hexutil.Uint64, error) + GetForks(ctx context.Context) (res json.RawMessage, err error) } const getBatchWitness = "getBatchWitness" @@ -695,13 +699,14 @@ func (api *ZkEvmAPIImpl) GetBatchByNumber(ctx context.Context, batchNumber rpc.B } batch.BatchL2Data = batchL2Data - oldAccInputHash, err := api.l1Syncer.GetOldAccInputHash(ctx, &api.config.AddressRollup, api.config.L1RollupId, batchNo) - if err != nil { - log.Warn("Failed to get old acc input hash", "err", err) - batch.AccInputHash = common.Hash{} + if api.l1Syncer != nil { + oldAccInputHash, err := api.l1Syncer.GetOldAccInputHash(ctx, &api.config.AddressRollup, api.config.L1RollupId, batchNo) + if err != nil { + log.Warn("Failed to get old acc input hash", "err", err) + batch.AccInputHash = common.Hash{} + } + batch.AccInputHash = oldAccInputHash } - batch.AccInputHash = oldAccInputHash - // forkid exit roots logic // if forkid < 12 then we should only set the exit roots if they have changed, otherwise 0x00..00 // if forkid >= 12 then we should always set the exit roots @@ -1226,6 +1231,55 @@ func getBatchNoByL2Block(tx kv.Tx, l2BlockNo uint64) (uint64, error) { return reader.GetBatchNoByL2Block(l2BlockNo) } +func getForkIdByBatchNo(tx kv.Tx, batchNo uint64) (uint64, error) { + reader := hermez_db.NewHermezDbReader(tx) + return reader.GetForkId(batchNo) +} + +func getForkInterval(tx kv.Tx, forkId uint64) (*rpc.ForkInterval, error) { + reader := hermez_db.NewHermezDbReader(tx) + + forkInterval, found, err := reader.GetForkInterval(forkId) + if err != nil { + return nil, err + } else if !found { + return nil, nil + } + + result := rpc.ForkInterval{ + ForkId: hexutil.Uint64(forkInterval.ForkID), + FromBatchNumber: hexutil.Uint64(forkInterval.FromBatchNumber), + ToBatchNumber: hexutil.Uint64(forkInterval.ToBatchNumber), + Version: "", + BlockNumber: hexutil.Uint64(forkInterval.BlockNumber), + } + + return &result, nil +} + +func getForkIntervals(tx kv.Tx) ([]rpc.ForkInterval, error) { + reader := hermez_db.NewHermezDbReader(tx) + + forkIntervals, err := reader.GetAllForkIntervals() + if err != nil { + return nil, err + } + + result := make([]rpc.ForkInterval, 0, len(forkIntervals)) + + for _, forkInterval := range forkIntervals { + result = append(result, rpc.ForkInterval{ + ForkId: hexutil.Uint64(forkInterval.ForkID), + FromBatchNumber: hexutil.Uint64(forkInterval.FromBatchNumber), + ToBatchNumber: hexutil.Uint64(forkInterval.ToBatchNumber), + Version: "", + BlockNumber: hexutil.Uint64(forkInterval.BlockNumber), + }) + } + + return result, nil +} + func convertTransactionsReceipts( txs []eritypes.Transaction, receipts eritypes.Receipts, @@ -1570,25 +1624,10 @@ func (zkapi *ZkEvmAPIImpl) GetProof(ctx context.Context, address common.Address, 
return nil, err } - balanceKey, err := smtUtils.KeyEthAddrBalance(address.String()) - if err != nil { - return nil, err - } - - nonceKey, err := smtUtils.KeyEthAddrNonce(address.String()) - if err != nil { - return nil, err - } - - codeHashKey, err := smtUtils.KeyContractCode(address.String()) - if err != nil { - return nil, err - } - - codeLengthKey, err := smtUtils.KeyContractLength(address.String()) - if err != nil { - return nil, err - } + balanceKey := smtUtils.KeyEthAddrBalance(address.String()) + nonceKey := smtUtils.KeyEthAddrNonce(address.String()) + codeHashKey := smtUtils.KeyContractCode(address.String()) + codeLengthKey := smtUtils.KeyContractLength(address.String()) balanceProofs := smt.FilterProofs(proofs, balanceKey) balanceBytes, err := smt.VerifyAndGetVal(stateRootNode, balanceProofs, balanceKey) @@ -1634,10 +1673,7 @@ func (zkapi *ZkEvmAPIImpl) GetProof(ctx context.Context, address common.Address, addressArrayBig := smtUtils.ScalarToArrayBig(smtUtils.ConvertHexToBigInt(address.String())) for _, k := range storageKeys { - storageKey, err := smtUtils.KeyContractStorage(addressArrayBig, k.String()) - if err != nil { - return nil, err - } + storageKey := smtUtils.KeyContractStorage(addressArrayBig, k.String()) storageProofs := smt.FilterProofs(proofs, storageKey) valueBytes, err := smt.VerifyAndGetVal(stateRootNode, storageProofs, storageKey) @@ -1656,3 +1692,86 @@ func (zkapi *ZkEvmAPIImpl) GetProof(ctx context.Context, address common.Address, return accProof, nil } + +// ForkId returns the network's current fork ID +func (api *ZkEvmAPIImpl) GetForkId(ctx context.Context) (hexutil.Uint64, error) { + tx, err := api.db.BeginRo(ctx) + if err != nil { + return hexutil.Uint64(0), err + } + defer tx.Rollback() + + currentBatchNumber, err := getLatestBatchNumber(tx) + if err != nil { + return 0, err + } + + currentForkId, err := getForkIdByBatchNo(tx, currentBatchNumber) + if err != nil { + return 0, err + } + + return hexutil.Uint64(currentForkId), err +} + +// GetForkById returns the network fork interval given the provided fork id +func (api *ZkEvmAPIImpl) GetForkById(ctx context.Context, forkId hexutil.Uint64) (res json.RawMessage, err error) { + tx, err := api.db.BeginRo(ctx) + if err != nil { + return nil, err + } + defer tx.Rollback() + + forkInterval, err := getForkInterval(tx, uint64(forkId)) + if err != nil { + return nil, err + } + + if forkInterval == nil { + return nil, nil + } + + forkJson, err := json.Marshal(forkInterval) + if err != nil { + return nil, err + } + + return forkJson, err +} + +// GetForkIdByBatchNumber returns the fork ID given the provided batch number +func (api *ZkEvmAPIImpl) GetForkIdByBatchNumber(ctx context.Context, batchNumber rpc.BlockNumber) (hexutil.Uint64, error) { + tx, err := api.db.BeginRo(ctx) + if err != nil { + return hexutil.Uint64(0), err + } + defer tx.Rollback() + + currentForkId, err := getForkIdByBatchNo(tx, uint64(batchNumber)) + if err != nil { + return 0, err + } + + return hexutil.Uint64(currentForkId), err +} + +// GetForks returns the network fork intervals +func (api *ZkEvmAPIImpl) GetForks(ctx context.Context) (res json.RawMessage, err error) { + tx, err := api.db.BeginRo(ctx) + if err != nil { + return nil, err + } + defer tx.Rollback() + + forkIntervals, err := getForkIntervals(tx) + if err != nil { + return nil, err + } + + forksJson, err := json.Marshal(forkIntervals) + if err != nil { + return nil, err + } + + return forksJson, err +} diff --git a/turbo/jsonrpc/zkevm_api_test.go b/turbo/jsonrpc/zkevm_api_test.go 
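[Illustration only, not part of this change: the four fork endpoints added to ZkEvmAPI above are plain read-only queries over the hermez DB. Below is a minimal Go client sketch. It assumes the node follows the usual Erigon namespace mapping, so GetForkId would be served as zkevm_getForkId and so on, and that the RPC listens on http://localhost:8545 — both are assumptions to adjust for your deployment.]

package main

import (
	"bytes"
	"encoding/json"
	"fmt"
	"net/http"
)

// Minimal JSON-RPC 2.0 envelopes; just enough to exercise the new endpoints.
type rpcRequest struct {
	JSONRPC string        `json:"jsonrpc"`
	Method  string        `json:"method"`
	Params  []interface{} `json:"params"`
	ID      int           `json:"id"`
}

type rpcResponse struct {
	Result json.RawMessage `json:"result"`
	Error  *struct {
		Code    int    `json:"code"`
		Message string `json:"message"`
	} `json:"error"`
}

func call(url, method string, params ...interface{}) (json.RawMessage, error) {
	if params == nil {
		params = []interface{}{} // send "params": [] rather than null
	}
	body, err := json.Marshal(rpcRequest{JSONRPC: "2.0", Method: method, Params: params, ID: 1})
	if err != nil {
		return nil, err
	}
	resp, err := http.Post(url, "application/json", bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	var out rpcResponse
	if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
		return nil, err
	}
	if out.Error != nil {
		return nil, fmt.Errorf("rpc error %d: %s", out.Error.Code, out.Error.Message)
	}
	return out.Result, nil
}

func main() {
	url := "http://localhost:8545" // assumed endpoint

	forkID, err := call(url, "zkevm_getForkId") // hexutil.Uint64, e.g. "0xa"
	fmt.Println("current fork id:", string(forkID), err)

	byBatch, err := call(url, "zkevm_getForkIdByBatchNumber", "0x5")
	fmt.Println("fork id for batch 5:", string(byBatch), err)

	forks, err := call(url, "zkevm_getForks") // array of fork intervals
	fmt.Println("fork intervals:", string(forks), err)
}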
index 51ed2e07d59..d80d537bf3a 100644 --- a/turbo/jsonrpc/zkevm_api_test.go +++ b/turbo/jsonrpc/zkevm_api_test.go @@ -5,6 +5,7 @@ import ( "encoding/hex" "encoding/json" "fmt" + "math" "math/big" "testing" "time" @@ -1205,3 +1206,218 @@ func TestGetFullBlockByHash(t *testing.T) { assert.Equal(tx7.Hash(), *block4.Transactions[1].Hash) assert.Equal(tx8.Hash(), *block4.Transactions[2].Hash) } + +func TestGetForkId(t *testing.T) { + assert := assert.New(t) + + ////////////// + contractBackend := backends.NewTestSimulatedBackendWithConfig(t, gspec.Alloc, gspec.Config, gspec.GasLimit) + defer contractBackend.Close() + stateCache := kvcache.New(kvcache.DefaultCoherentConfig) + contractBackend.Commit() + /////////// + + db := contractBackend.DB() + agg := contractBackend.Agg() + + baseApi := NewBaseApi(nil, stateCache, contractBackend.BlockReader(), agg, false, rpccfg.DefaultEvmCallTimeout, contractBackend.Engine(), datadir.New(t.TempDir())) + ethImpl := NewEthAPI(baseApi, db, nil, nil, nil, 5000000, 100_000, 100_000, ðconfig.Defaults, false, 100, 100, log.New()) + var l1Syncer *syncer.L1Syncer + zkEvmImpl := NewZkEvmAPI(ethImpl, db, 100_000, ðconfig.Defaults, l1Syncer, "") + tx, err := db.BeginRw(ctx) + assert.NoError(err) + hDB := hermez_db.NewHermezDb(tx) + + for i := 1; i <= 10; i++ { + err := hDB.WriteBlockBatch(uint64(i), uint64(i)) + assert.NoError(err) + err = hDB.WriteForkId(uint64(i), uint64(i)) + assert.NoError(err) + } + err = hDB.WriteSequence(4, 4, common.HexToHash("0x21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85"), common.HexToHash("0xcefad4e508c098b9a7e1d8feb19955fb02ba9675585078710969d3440f5054e0")) + assert.NoError(err) + err = hDB.WriteSequence(7, 7, common.HexToHash("0x21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba86"), common.HexToHash("0xcefad4e508c098b9a7e1d8feb19955fb02ba9675585078710969d3440f5054e1")) + assert.NoError(err) + for i := 1; i <= 4; i++ { + err = stages.SaveStageProgress(tx, stages.L1VerificationsBatchNo, uint64(i)) + assert.NoError(err) + } + + tx.Commit() + forkId, err := zkEvmImpl.GetForkId(ctx) + assert.NoError(err) + assert.Equal(hexutil.Uint64(10), forkId) +} + +func TestGetForkIdByBatchNumber(t *testing.T) { + assert := assert.New(t) + + ////////////// + contractBackend := backends.NewTestSimulatedBackendWithConfig(t, gspec.Alloc, gspec.Config, gspec.GasLimit) + defer contractBackend.Close() + stateCache := kvcache.New(kvcache.DefaultCoherentConfig) + contractBackend.Commit() + /////////// + + db := contractBackend.DB() + agg := contractBackend.Agg() + + baseApi := NewBaseApi(nil, stateCache, contractBackend.BlockReader(), agg, false, rpccfg.DefaultEvmCallTimeout, contractBackend.Engine(), datadir.New(t.TempDir())) + ethImpl := NewEthAPI(baseApi, db, nil, nil, nil, 5000000, 100_000, 100_000, ðconfig.Defaults, false, 100, 100, log.New()) + var l1Syncer *syncer.L1Syncer + zkEvmImpl := NewZkEvmAPI(ethImpl, db, 100_000, ðconfig.Defaults, l1Syncer, "") + tx, err := db.BeginRw(ctx) + assert.NoError(err) + hDB := hermez_db.NewHermezDb(tx) + + for i := 1; i <= 10; i++ { + err := hDB.WriteBlockBatch(uint64(i), uint64(i)) + assert.NoError(err) + err = hDB.WriteForkId(uint64(i), uint64(i)) + assert.NoError(err) + } + err = hDB.WriteSequence(4, 4, common.HexToHash("0x21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba85"), common.HexToHash("0xcefad4e508c098b9a7e1d8feb19955fb02ba9675585078710969d3440f5054e0")) + assert.NoError(err) + err = hDB.WriteSequence(7, 7, 
common.HexToHash("0x21ddb9a356815c3fac1026b6dec5df3124afbadb485c9ba5a3e3398a04b7ba86"), common.HexToHash("0xcefad4e508c098b9a7e1d8feb19955fb02ba9675585078710969d3440f5054e1")) + assert.NoError(err) + for i := 1; i <= 4; i++ { + err = stages.SaveStageProgress(tx, stages.L1VerificationsBatchNo, uint64(i)) + assert.NoError(err) + } + + tx.Commit() + forkId, err := zkEvmImpl.GetForkIdByBatchNumber(ctx, 5) + assert.NoError(err) + assert.Equal(hexutil.Uint64(5), forkId) + + forkId, err = zkEvmImpl.GetForkIdByBatchNumber(ctx, 7) + assert.NoError(err) + assert.Equal(hexutil.Uint64(7), forkId) +} + +func TestGetForkById(t *testing.T) { + assert := assert.New(t) + + ////////////// + contractBackend := backends.NewTestSimulatedBackendWithConfig(t, gspec.Alloc, gspec.Config, gspec.GasLimit) + defer contractBackend.Close() + stateCache := kvcache.New(kvcache.DefaultCoherentConfig) + contractBackend.Commit() + /////////// + + db := contractBackend.DB() + agg := contractBackend.Agg() + + baseApi := NewBaseApi(nil, stateCache, contractBackend.BlockReader(), agg, false, rpccfg.DefaultEvmCallTimeout, contractBackend.Engine(), datadir.New(t.TempDir())) + ethImpl := NewEthAPI(baseApi, db, nil, nil, nil, 5000000, 100_000, 100_000, ðconfig.Defaults, false, 100, 100, log.New()) + var l1Syncer *syncer.L1Syncer + zkEvmImpl := NewZkEvmAPI(ethImpl, db, 100_000, ðconfig.Defaults, l1Syncer, "") + tx, err := db.BeginRw(ctx) + assert.NoError(err) + hDB := hermez_db.NewHermezDb(tx) + + for f := uint64(1); f <= 3; f++ { + forkId := f + blockNumber := (forkId * uint64(1000)) + err = hDB.WriteForkIdBlockOnce(forkId, blockNumber) + assert.NoError(err) + for i := uint64(1); i <= 10; i++ { + batchNumber := ((forkId - 1) * uint64(10)) + i + err = hDB.WriteForkId(batchNumber, forkId) + assert.NoError(err) + } + } + + tx.Commit() + forkInterval := rpc.ForkInterval{} + + forkIntervalJson, err := zkEvmImpl.GetForkById(ctx, 1) + assert.NoError(err) + err = json.Unmarshal(forkIntervalJson, &forkInterval) + assert.NoError(err) + assert.Equal(hexutil.Uint64(1), forkInterval.ForkId) + assert.Equal(hexutil.Uint64(1), forkInterval.FromBatchNumber) + assert.Equal(hexutil.Uint64(10), forkInterval.ToBatchNumber) + assert.Equal("", forkInterval.Version) + assert.Equal(hexutil.Uint64(1000), forkInterval.BlockNumber) + + forkIntervalJson, err = zkEvmImpl.GetForkById(ctx, 2) + assert.NoError(err) + err = json.Unmarshal(forkIntervalJson, &forkInterval) + assert.NoError(err) + assert.Equal(hexutil.Uint64(2), forkInterval.ForkId) + assert.Equal(hexutil.Uint64(11), forkInterval.FromBatchNumber) + assert.Equal(hexutil.Uint64(20), forkInterval.ToBatchNumber) + assert.Equal("", forkInterval.Version) + assert.Equal(hexutil.Uint64(2000), forkInterval.BlockNumber) + + forkIntervalJson, err = zkEvmImpl.GetForkById(ctx, 3) + assert.NoError(err) + err = json.Unmarshal(forkIntervalJson, &forkInterval) + assert.NoError(err) + assert.Equal(hexutil.Uint64(3), forkInterval.ForkId) + assert.Equal(hexutil.Uint64(21), forkInterval.FromBatchNumber) + assert.Equal(hexutil.Uint64(math.MaxUint64), forkInterval.ToBatchNumber) + assert.Equal("", forkInterval.Version) + assert.Equal(hexutil.Uint64(3000), forkInterval.BlockNumber) +} + +func TestGetForks(t *testing.T) { + assert := assert.New(t) + + ////////////// + contractBackend := backends.NewTestSimulatedBackendWithConfig(t, gspec.Alloc, gspec.Config, gspec.GasLimit) + defer contractBackend.Close() + stateCache := kvcache.New(kvcache.DefaultCoherentConfig) + contractBackend.Commit() + /////////// + + db := 
contractBackend.DB() + agg := contractBackend.Agg() + + baseApi := NewBaseApi(nil, stateCache, contractBackend.BlockReader(), agg, false, rpccfg.DefaultEvmCallTimeout, contractBackend.Engine(), datadir.New(t.TempDir())) + ethImpl := NewEthAPI(baseApi, db, nil, nil, nil, 5000000, 100_000, 100_000, ðconfig.Defaults, false, 100, 100, log.New()) + var l1Syncer *syncer.L1Syncer + zkEvmImpl := NewZkEvmAPI(ethImpl, db, 100_000, ðconfig.Defaults, l1Syncer, "") + tx, err := db.BeginRw(ctx) + assert.NoError(err) + hDB := hermez_db.NewHermezDb(tx) + + for f := uint64(1); f <= 3; f++ { + forkId := f + blockNumber := (forkId * uint64(1000)) + err = hDB.WriteForkIdBlockOnce(forkId, blockNumber) + assert.NoError(err) + for i := uint64(1); i <= 10; i++ { + batchNumber := ((forkId - 1) * uint64(10)) + i + err = hDB.WriteForkId(batchNumber, forkId) + assert.NoError(err) + } + } + + tx.Commit() + forksJson, err := zkEvmImpl.GetForks(ctx) + assert.NoError(err) + + forks := []rpc.ForkInterval{} + err = json.Unmarshal(forksJson, &forks) + assert.NoError(err) + + assert.Equal(forks[0].ForkId, hexutil.Uint64(1)) + assert.Equal(forks[0].FromBatchNumber, hexutil.Uint64(1)) + assert.Equal(forks[0].ToBatchNumber, hexutil.Uint64(10)) + assert.Equal(forks[0].Version, "") + assert.Equal(forks[0].BlockNumber, hexutil.Uint64(1000)) + + assert.Equal(forks[1].ForkId, hexutil.Uint64(2)) + assert.Equal(forks[1].FromBatchNumber, hexutil.Uint64(11)) + assert.Equal(forks[1].ToBatchNumber, hexutil.Uint64(20)) + assert.Equal(forks[1].Version, "") + assert.Equal(forks[1].BlockNumber, hexutil.Uint64(2000)) + + assert.Equal(forks[2].ForkId, hexutil.Uint64(3)) + assert.Equal(forks[2].FromBatchNumber, hexutil.Uint64(21)) + assert.Equal(forks[2].ToBatchNumber, hexutil.Uint64(math.MaxUint64)) + assert.Equal(forks[2].Version, "") + assert.Equal(forks[2].BlockNumber, hexutil.Uint64(3000)) +} diff --git a/turbo/transactions/tracing.go b/turbo/transactions/tracing.go index 4d63942da38..06d5aa85ddb 100644 --- a/turbo/transactions/tracing.go +++ b/turbo/transactions/tracing.go @@ -149,8 +149,12 @@ func TraceTx( var streaming bool var counterCollector *vm.TransactionCounter + var executionCounters *vm.CounterCollector if config != nil { counterCollector = config.CounterCollector + if counterCollector != nil { + executionCounters = counterCollector.ExecutionCounters() + } } switch { case config != nil && config.Tracer != nil: @@ -186,15 +190,15 @@ func TraceTx( streaming = false case config == nil: - tracer = logger.NewJsonStreamLogger_ZkEvm(nil, ctx, stream, counterCollector.ExecutionCounters()) + tracer = logger.NewJsonStreamLogger_ZkEvm(nil, ctx, stream, executionCounters) streaming = true default: - tracer = logger.NewJsonStreamLogger_ZkEvm(config.LogConfig, ctx, stream, counterCollector.ExecutionCounters()) + tracer = logger.NewJsonStreamLogger_ZkEvm(config.LogConfig, ctx, stream, executionCounters) streaming = true } - zkConfig := vm.NewZkConfig(vm.Config{Debug: true, Tracer: tracer}, counterCollector.ExecutionCounters()) + zkConfig := vm.NewZkConfig(vm.Config{Debug: true, Tracer: tracer}, executionCounters) // Run the transaction with tracing enabled. 
vmenv := vm.NewZkEVM(blockCtx, txCtx, ibs, chainConfig, zkConfig) @@ -205,9 +209,9 @@ func TraceTx( if streaming { stream.WriteObjectStart() - if config != nil && config.CounterCollector != nil { + if executionCounters != nil { stream.WriteObjectField("smtLevels") - stream.WriteInt(config.CounterCollector.ExecutionCounters().GetSmtLevels()) + stream.WriteInt(executionCounters.GetSmtLevels()) stream.WriteMore() } diff --git a/zk/datastream/client/commands.go b/zk/datastream/client/commands.go index bb9cbafe825..8676a2807eb 100644 --- a/zk/datastream/client/commands.go +++ b/zk/datastream/client/commands.go @@ -23,23 +23,27 @@ func (c *StreamClient) sendHeaderCmd() error { return nil } -// sendStartBookmarkCmd sends a start command to the server, indicating -// that the client wishes to start streaming from the given bookmark -func (c *StreamClient) sendStartBookmarkCmd(bookmark []byte) error { - err := c.sendCommand(CmdStartBookmark) - if err != nil { - return err +// sendBookmarkCmd sends either CmdStartBookmark or CmdBookmark for the provided bookmark value. +// In case streaming parameter is set to true, the CmdStartBookmark is sent, otherwise the CmdBookmark. +func (c *StreamClient) sendBookmarkCmd(bookmark []byte, streaming bool) error { + // in case we want to stream the entries, CmdStartBookmark is sent, otherwise CmdBookmark command + command := CmdStartBookmark + if !streaming { + command = CmdBookmark } - // Send starting/from entry number - if err := writeFullUint32ToConn(c.conn, uint32(len(bookmark))); err != nil { + // Send the command + if err := c.sendCommand(command); err != nil { return err } - if err := writeBytesToConn(c.conn, bookmark); err != nil { + + // Send bookmark length + if err := writeFullUint32ToConn(c.conn, uint32(len(bookmark))); err != nil { return err } - return nil + // Send the bookmark to retrieve + return writeBytesToConn(c.conn, bookmark) } // sendStartCmd sends a start command to the server, indicating @@ -51,11 +55,18 @@ func (c *StreamClient) sendStartCmd(from uint64) error { } // Send starting/from entry number - if err := writeFullUint64ToConn(c.conn, from); err != nil { + return writeFullUint64ToConn(c.conn, from) +} + +// sendEntryCmd sends the get data stream entry by number command to a TCP connection +func (c *StreamClient) sendEntryCmd(entryNum uint64) error { + // Send CmdEntry command + if err := c.sendCommand(CmdEntry); err != nil { return err } - return nil + // Send entry number + return writeFullUint64ToConn(c.conn, entryNum) } // sendHeaderCmd sends the header command to the server. 
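Both refactored commands share the framing that the new test server validates with readAndValidateUint: an 8-byte big-endian command, an 8-byte big-endian stream type, then the command payload (a 4-byte bookmark length plus the bookmark bytes for CmdStartBookmark/CmdBookmark, or an 8-byte entry number for CmdEntry). A sketch of the bytes sendEntryCmd puts on the wire, using the same AppendUint64 idiom as the repo's Encode helpers; the numeric constant values here are assumptions for illustration only.

```go
// Illustrative encoding of a CmdEntry request as the server reads it back:
// command, stream type, and entry number, all big-endian. The constant
// values are placeholders, not taken from this diff.
package main

import (
	"encoding/binary"
	"fmt"
)

const (
	CmdEntry    uint64 = 5 // assumed command value, for illustration
	StSequencer uint64 = 1 // sequencer stream type
)

// encodeEntryCmd builds the byte sequence sendEntryCmd would produce.
func encodeEntryCmd(entryNum uint64) []byte {
	be := make([]byte, 0, 24)
	be = binary.BigEndian.AppendUint64(be, CmdEntry)    // command
	be = binary.BigEndian.AppendUint64(be, StSequencer) // stream type
	be = binary.BigEndian.AppendUint64(be, entryNum)    // entry number payload
	return be
}

func main() {
	fmt.Printf("% x\n", encodeEntryCmd(42)) // three big-endian uint64 fields
}
```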
diff --git a/zk/datastream/client/stream_client.go b/zk/datastream/client/stream_client.go index f7362743284..461b3d9371c 100644 --- a/zk/datastream/client/stream_client.go +++ b/zk/datastream/client/stream_client.go @@ -29,15 +29,19 @@ const ( versionAddedBlockEnd = 3 // Added block end ) +var ( + // ErrFileEntryNotFound denotes error that is returned when the certain file entry is not found in the datastream + ErrFileEntryNotFound = errors.New("file entry not found") +) + type StreamClient struct { ctx context.Context server string // Server address to connect IP:port version int streamType StreamType conn net.Conn - id string // Client id - Header types.HeaderEntry // Header info received (from Header command) - checkTimeout time.Duration // time to wait for data before reporting an error + id string // Client id + checkTimeout time.Duration // time to wait for data before reporting an error // atomic lastWrittenTime atomic.Int64 @@ -59,6 +63,7 @@ const ( PtPadding = 0 PtHeader = 1 // Just for the header page PtData = 2 // Data entry + PtDataRsp = 0xfe // PtDataRsp is packet type for command response with data PtResult = 0xff // Not stored/present in file (just for client command result) ) @@ -86,6 +91,108 @@ func (c *StreamClient) IsVersion3() bool { func (c *StreamClient) GetEntryChan() chan interface{} { return c.entryChan } + +// GetL2BlockByNumber queries the data stream by sending the L2 block start bookmark for the certain block number +// and streams the changes for that block (including the transactions). +// Note that this function is intended for on demand querying and it disposes the connection after it ends. +func (c *StreamClient) GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block, int, error) { + if _, err := c.EnsureConnected(); err != nil { + return nil, -1, err + } + defer c.Stop() + + var ( + l2Block *types.FullL2Block + err error + isL2Block bool + ) + + bookmark := types.NewBookmarkProto(blockNum, datastream.BookmarkType_BOOKMARK_TYPE_L2_BLOCK) + bookmarkRaw, err := bookmark.Marshal() + if err != nil { + return nil, -1, err + } + + re, err := c.initiateDownloadBookmark(bookmarkRaw) + if err != nil { + errorCode := -1 + if re != nil { + errorCode = int(re.ErrorNum) + } + return nil, errorCode, err + } + + for l2Block == nil { + select { + case <-c.ctx.Done(): + errorCode := -1 + if re != nil { + errorCode = int(re.ErrorNum) + } + return l2Block, errorCode, nil + default: + } + + parsedEntry, err := ReadParsedProto(c) + if err != nil { + return nil, -1, err + } + + l2Block, isL2Block = parsedEntry.(*types.FullL2Block) + if isL2Block { + break + } + } + + if l2Block.L2BlockNumber != blockNum { + return nil, -1, fmt.Errorf("expected block number %d but got %d", blockNum, l2Block.L2BlockNumber) + } + + return l2Block, types.CmdErrOK, nil +} + +// GetLatestL2Block queries the data stream by reading the header entry and based on total entries field, +// it retrieves the latest File entry that is of EntryTypeL2Block type. +// Note that this function is intended for on demand querying and it disposes the connection after it ends. 
+func (c *StreamClient) GetLatestL2Block() (l2Block *types.FullL2Block, err error) { + if _, err := c.EnsureConnected(); err != nil { + return nil, err + } + defer c.Stop() + + h, err := c.GetHeader() + if err != nil { + return nil, err + } + + latestEntryNum := h.TotalEntries - 1 + + for l2Block == nil && latestEntryNum > 0 { + if err := c.sendEntryCmdWrapper(latestEntryNum); err != nil { + return nil, err + } + + entry, err := c.NextFileEntry() + if err != nil { + return nil, err + } + + if entry.EntryType == types.EntryTypeL2Block { + if l2Block, err = types.UnmarshalL2Block(entry.Data); err != nil { + return nil, err + } + } + + latestEntryNum-- + } + + if l2Block == nil { + return nil, errors.New("failed to retrieve the latest block from the data stream") + } + + return l2Block, nil +} + func (c *StreamClient) GetLastWrittenTimeAtomic() *atomic.Int64 { return &c.lastWrittenTime } @@ -111,10 +218,14 @@ func (c *StreamClient) Start() error { } func (c *StreamClient) Stop() { + if c.conn == nil { + return + } if err := c.sendStopCmd(); err != nil { log.Warn(fmt.Sprintf("Failed to send the stop command to the data stream server: %s", err)) } c.conn.Close() + c.conn = nil close(c.entryChan) } @@ -122,45 +233,59 @@ func (c *StreamClient) Stop() { // Command header: Get status // Returns the current status of the header. // If started, terminate the connection. -func (c *StreamClient) GetHeader() error { +func (c *StreamClient) GetHeader() (*types.HeaderEntry, error) { if err := c.sendHeaderCmd(); err != nil { - return fmt.Errorf("%s send header error: %v", c.id, err) + return nil, fmt.Errorf("%s send header error: %v", c.id, err) } // Read packet packet, err := readBuffer(c.conn, 1) if err != nil { - return fmt.Errorf("%s read buffer: %v", c.id, err) + return nil, fmt.Errorf("%s read buffer: %v", c.id, err) } // Check packet type if packet[0] != PtResult { - return fmt.Errorf("%s error expecting result packet type %d and received %d", c.id, PtResult, packet[0]) + return nil, fmt.Errorf("%s error expecting result packet type %d and received %d", c.id, PtResult, packet[0]) } // Read server result entry for the command r, err := c.readResultEntry(packet) if err != nil { - return fmt.Errorf("%s read result entry error: %v", c.id, err) + return nil, fmt.Errorf("%s read result entry error: %v", c.id, err) } if err := r.GetError(); err != nil { - return fmt.Errorf("%s got Result error code %d: %v", c.id, r.ErrorNum, err) + return nil, fmt.Errorf("%s got Result error code %d: %v", c.id, r.ErrorNum, err) } // Read header entry h, err := c.readHeaderEntry() if err != nil { - return fmt.Errorf("%s read header entry error: %v", c.id, err) + return nil, fmt.Errorf("%s read header entry error: %v", c.id, err) + } + + return h, nil +} + +// sendEntryCmdWrapper sends the CmdEntry command, reads the packet type and decodes the result entry.
+func (c *StreamClient) sendEntryCmdWrapper(entryNum uint64) error { + if err := c.sendEntryCmd(entryNum); err != nil { + return err } - c.Header = *h + if re, err := c.readPacketAndDecodeResultEntry(); err != nil { + return fmt.Errorf("failed to retrieve the result entry: %w", err) + } else if err := re.GetError(); err != nil { + return err + } return nil } func (c *StreamClient) ExecutePerFile(bookmark *types.BookmarkProto, function func(file *types.FileEntry) error) error { // Get header from server - if err := c.GetHeader(); err != nil { + header, err := c.GetHeader() + if err != nil { return fmt.Errorf("%s get header error: %v", c.id, err) } @@ -169,7 +294,7 @@ func (c *StreamClient) ExecutePerFile(bookmark *types.BookmarkProto, function fu return fmt.Errorf("failed to marshal bookmark: %v", err) } - if err := c.initiateDownloadBookmark(protoBookmark); err != nil { + if _, err := c.initiateDownloadBookmark(protoBookmark); err != nil { return err } count := uint64(0) @@ -181,10 +306,10 @@ func (c *StreamClient) ExecutePerFile(bookmark *types.BookmarkProto, function fu fmt.Println("Entries read count: ", count) default: } - if c.Header.TotalEntries == count { + if header.TotalEntries == count { break } - file, err := c.readFileEntry() + file, err := c.NextFileEntry() if err != nil { return fmt.Errorf("reading file entry: %v", err) } @@ -203,7 +328,7 @@ func (c *StreamClient) EnsureConnected() (bool, error) { if err := c.tryReConnect(); err != nil { return false, fmt.Errorf("failed to reconnect the datastream client: %w", err) } - log.Info("[datastream_client] Datastream client connected.") + c.entryChan = make(chan interface{}, 100000) } return true, nil @@ -229,7 +354,7 @@ func (c *StreamClient) ReadAllEntriesToChannel() error { } // send start command - if err := c.initiateDownloadBookmark(protoBookmark); err != nil { + if _, err := c.initiateDownloadBookmark(protoBookmark); err != nil { return err } @@ -253,37 +378,31 @@ func (c *StreamClient) ReadAllEntriesToChannel() error { } // runs the prerequisites for entries download -func (c *StreamClient) initiateDownloadBookmark(bookmark []byte) error { - // send start command - if err := c.sendStartBookmarkCmd(bookmark); err != nil { - return err +func (c *StreamClient) initiateDownloadBookmark(bookmark []byte) (*types.ResultEntry, error) { + // send CmdStartBookmark command + if err := c.sendBookmarkCmd(bookmark, true); err != nil { + return nil, err } - if err := c.afterStartCommand(); err != nil { - return fmt.Errorf("after start command error: %v", err) + re, err := c.afterStartCommand() + if err != nil { + return re, fmt.Errorf("after start command error: %v", err) } - return nil + return re, nil } -func (c *StreamClient) afterStartCommand() error { - // Read packet - packet, err := readBuffer(c.conn, 1) - if err != nil { - return fmt.Errorf("read buffer error %v", err) - } - - // Read server result entry for the command - r, err := c.readResultEntry(packet) +func (c *StreamClient) afterStartCommand() (*types.ResultEntry, error) { + re, err := c.readPacketAndDecodeResultEntry() if err != nil { - return fmt.Errorf("read result entry error: %v", err) + return nil, err } - if err := r.GetError(); err != nil { - return fmt.Errorf("got Result error code %d: %v", r.ErrorNum, err) + if err := re.GetError(); err != nil { + return re, fmt.Errorf("got Result error code %d: %v", re.ErrorNum, err) } - return nil + return re, nil } // reads all entries from the server and sends them to a channel @@ -304,7 +423,7 @@ LOOP: 
c.conn.SetReadDeadline(time.Now().Add(c.checkTimeout)) } - parsedProto, localErr := c.readParsedProto() + parsedProto, localErr := ReadParsedProto(c) if localErr != nil { err = localErr break @@ -339,11 +458,13 @@ func (c *StreamClient) tryReConnect() error { for i := 0; i < 50; i++ { if c.conn != nil { if err := c.conn.Close(); err != nil { + log.Warn(fmt.Sprintf("[%d. iteration] failed to close the DS connection: %s", i+1, err)) return err } c.conn = nil } if err = c.Start(); err != nil { + log.Warn(fmt.Sprintf("[%d. iteration] failed to start the DS connection: %s", i+1, err)) time.Sleep(5 * time.Second) continue } @@ -353,16 +474,24 @@ func (c *StreamClient) tryReConnect() error { return err } -func (c *StreamClient) readParsedProto() ( +type FileEntryIterator interface { + NextFileEntry() (*types.FileEntry, error) +} + +func ReadParsedProto(iterator FileEntryIterator) ( parsedEntry interface{}, err error, ) { - file, err := c.readFileEntry() + file, err := iterator.NextFileEntry() if err != nil { - err = fmt.Errorf("read file entry error: %v", err) + err = fmt.Errorf("read file entry error: %w", err) return } + if file == nil { + return nil, nil + } + switch file.EntryType { case types.BookmarkEntryType: parsedEntry, err = types.UnmarshalBookmark(file.Data) @@ -384,7 +513,7 @@ func (c *StreamClient) readParsedProto() ( var l2Tx *types.L2TransactionProto LOOP: for { - if innerFile, err = c.readFileEntry(); err != nil { + if innerFile, err = iterator.NextFileEntry(); err != nil { return } @@ -428,8 +557,11 @@ func (c *StreamClient) readParsedProto() ( l2Block.L2Txs = txs parsedEntry = l2Block return + case types.EntryTypeL2BlockEnd: + log.Debug(fmt.Sprintf("retrieved EntryTypeL2BlockEnd: %+v", file)) + return case types.EntryTypeL2Tx: - err = fmt.Errorf("unexpected l2Tx out of block") + err = errors.New("unexpected L2 tx entry, found outside of block") default: err = fmt.Errorf("unexpected entry type: %d", file.EntryType) } @@ -438,15 +570,16 @@ func (c *StreamClient) readParsedProto() ( // reads file bytes from socket and tries to parse them // returns the parsed FileEntry -func (c *StreamClient) readFileEntry() (file *types.FileEntry, err error) { +func (c *StreamClient) NextFileEntry() (file *types.FileEntry, err error) { // Read packet type packet, err := readBuffer(c.conn, 1) if err != nil { return file, fmt.Errorf("failed to read packet type: %v", err) } + packetType := packet[0] // Check packet type - if packet[0] == PtResult { + if packetType == PtResult { // Read server result entry for the command r, err := c.readResultEntry(packet) if err != nil { @@ -456,8 +589,8 @@ func (c *StreamClient) readFileEntry() (file *types.FileEntry, err error) { return file, fmt.Errorf("got Result error code %d: %v", r.ErrorNum, err) } return file, nil - } else if packet[0] != PtData { - return file, fmt.Errorf("error expecting data packet type %d and received %d", PtData, packet[0]) + } else if packetType != PtData && packetType != PtDataRsp { + return file, fmt.Errorf("expected data packet type %d or %d and received %d", PtData, PtDataRsp, packetType) } // Read the rest of fixed size fields @@ -465,6 +598,10 @@ func (c *StreamClient) readFileEntry() (file *types.FileEntry, err error) { if err != nil { return file, fmt.Errorf("error reading file bytes: %v", err) } + + if packetType != PtData { + packet[0] = PtData + } buffer = append(packet, buffer...) 
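Because ReadParsedProto now depends only on the FileEntryIterator interface shown above rather than on *StreamClient, any entry source can drive it. A minimal in-memory iterator, e.g. for tests or offline replay of captured entries; only the NextFileEntry contract from this diff is assumed.

```go
// A sketch of a slice-backed FileEntryIterator. It yields pre-captured file
// entries in order, then nil to signal the end, which ReadParsedProto treats
// as "nothing parsed".
package client

import "github.com/ledgerwatch/erigon/zk/datastream/types"

type sliceIterator struct {
	entries []*types.FileEntry
	pos     int
}

// NextFileEntry returns the next captured entry, or nil once exhausted.
func (s *sliceIterator) NextFileEntry() (*types.FileEntry, error) {
	if s.pos >= len(s.entries) {
		return nil, nil
	}
	e := s.entries[s.pos]
	s.pos++
	return e, nil
}

// Usage: parsed, err := ReadParsedProto(&sliceIterator{entries: captured})
```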
// Read variable field (data) @@ -485,6 +622,10 @@ func (c *StreamClient) readFileEntry() (file *types.FileEntry, err error) { return file, fmt.Errorf("decode file entry error: %v", err) } + if file.EntryType == types.EntryTypeNotFound { + return file, ErrFileEntryNotFound + } + return } @@ -550,3 +691,20 @@ func (c *StreamClient) readResultEntry(packet []byte) (re *types.ResultEntry, er return re, nil } + +// readPacketAndDecodeResultEntry reads the packet from the connection and tries to decode the ResultEntry from it. +func (c *StreamClient) readPacketAndDecodeResultEntry() (*types.ResultEntry, error) { + // Read packet + packet, err := readBuffer(c.conn, 1) + if err != nil { + return nil, fmt.Errorf("read buffer error: %w", err) + } + + // Read server result entry for the command + r, err := c.readResultEntry(packet) + if err != nil { + return nil, fmt.Errorf("read result entry error: %w", err) + } + + return r, nil +} diff --git a/zk/datastream/client/stream_client_test.go b/zk/datastream/client/stream_client_test.go index 026879aa424..05e0cd80149 100644 --- a/zk/datastream/client/stream_client_test.go +++ b/zk/datastream/client/stream_client_test.go @@ -1,17 +1,28 @@ package client import ( + "bytes" "context" + "encoding/binary" + "errors" "fmt" "net" + "sync" "testing" + "time" + "github.com/ledgerwatch/erigon-lib/common" + "github.com/ledgerwatch/erigon/zk/datastream/proto/github.com/0xPolygonHermez/zkevm-node/state/datastream" "github.com/ledgerwatch/erigon/zk/datastream/types" "github.com/stretchr/testify/require" "gotest.tools/v3/assert" ) -func Test_readHeaderEntry(t *testing.T) { +const ( + streamTypeFieldName = "stream type" +) + +func TestStreamClientReadHeaderEntry(t *testing.T) { type testCase struct { name string input []byte @@ -35,7 +46,7 @@ func Test_readHeaderEntry(t *testing.T) { name: "Invalid byte array length", input: []byte{20, 21, 22, 23, 24, 20}, expectedResult: nil, - expectedError: fmt.Errorf("failed to read header bytes reading from server: unexpected EOF"), + expectedError: errors.New("failed to read header bytes reading from server: unexpected EOF"), }, } @@ -59,7 +70,7 @@ func Test_readHeaderEntry(t *testing.T) { } } -func Test_readResultEntry(t *testing.T) { +func TestStreamClientReadResultEntry(t *testing.T) { type testCase struct { name string input []byte @@ -93,13 +104,13 @@ func Test_readResultEntry(t *testing.T) { name: "Invalid byte array length", input: []byte{20, 21, 22, 23, 24, 20}, expectedResult: nil, - expectedError: fmt.Errorf("failed to read main result bytes reading from server: unexpected EOF"), + expectedError: errors.New("failed to read main result bytes reading from server: unexpected EOF"), }, { name: "Invalid error length", input: []byte{0, 0, 0, 12, 0, 0, 0, 0, 20, 21}, expectedResult: nil, - expectedError: fmt.Errorf("failed to read result errStr bytes reading from server: unexpected EOF"), + expectedError: errors.New("failed to read result errStr bytes reading from server: unexpected EOF"), }, } @@ -123,7 +134,7 @@ func Test_readResultEntry(t *testing.T) { } } -func Test_readFileEntry(t *testing.T) { +func TestStreamClientReadFileEntry(t *testing.T) { type testCase struct { name string input []byte @@ -158,18 +169,18 @@ func Test_readFileEntry(t *testing.T) { name: "Invalid packet type", input: []byte{5, 0, 0, 0, 17, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 45}, expectedResult: nil, - expectedError: fmt.Errorf("error expecting data packet type 2 and received 5"), + expectedError: errors.New("expected data packet type 2 or 254 and 
received 5"), }, { name: "Invalid byte array length", input: []byte{2, 21, 22, 23, 24, 20}, expectedResult: nil, - expectedError: fmt.Errorf("error reading file bytes: reading from server: unexpected EOF"), + expectedError: errors.New("error reading file bytes: reading from server: unexpected EOF"), }, { name: "Invalid data length", input: []byte{2, 0, 0, 0, 31, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 45, 0, 0, 0, 24, 0, 0, 0, 0, 0, 0, 0, 64}, expectedResult: nil, - expectedError: fmt.Errorf("error reading file data bytes: reading from server: unexpected EOF"), + expectedError: errors.New("error reading file data bytes: reading from server: unexpected EOF"), }, } for _, testCase := range testCases { @@ -185,9 +196,377 @@ func Test_readFileEntry(t *testing.T) { server.Close() }() - result, err := c.readFileEntry() + result, err := c.NextFileEntry() require.Equal(t, testCase.expectedError, err) assert.DeepEqual(t, testCase.expectedResult, result) }) } } + +func TestStreamClientReadParsedProto(t *testing.T) { + c := NewClient(context.Background(), "", 0, 0, 0) + serverConn, clientConn := net.Pipe() + c.conn = clientConn + defer func() { + serverConn.Close() + clientConn.Close() + }() + + l2Block, l2Txs := createL2BlockAndTransactions(t, 3, 1) + l2BlockProto := &types.L2BlockProto{L2Block: l2Block} + l2BlockRaw, err := l2BlockProto.Marshal() + require.NoError(t, err) + + l2Tx := l2Txs[0] + l2TxProto := &types.TxProto{Transaction: l2Tx} + l2TxRaw, err := l2TxProto.Marshal() + require.NoError(t, err) + + l2BlockEnd := &types.L2BlockEndProto{Number: l2Block.GetNumber()} + l2BlockEndRaw, err := l2BlockEnd.Marshal() + require.NoError(t, err) + + var ( + errCh = make(chan error) + wg sync.WaitGroup + ) + wg.Add(1) + + go func() { + defer wg.Done() + fileEntries := []*types.FileEntry{ + createFileEntry(t, types.EntryTypeL2Block, 1, l2BlockRaw), + createFileEntry(t, types.EntryTypeL2Tx, 2, l2TxRaw), + createFileEntry(t, types.EntryTypeL2BlockEnd, 3, l2BlockEndRaw), + } + for _, fe := range fileEntries { + _, writeErr := serverConn.Write(fe.Encode()) + if writeErr != nil { + errCh <- writeErr + break + } + } + }() + + go func() { + wg.Wait() + close(errCh) + }() + + parsedEntry, err := ReadParsedProto(c) + require.NoError(t, err) + serverErr := <-errCh + require.NoError(t, serverErr) + expectedL2Tx := types.ConvertToL2TransactionProto(l2Tx) + expectedL2Block := types.ConvertToFullL2Block(l2Block) + expectedL2Block.L2Txs = append(expectedL2Block.L2Txs, *expectedL2Tx) + require.Equal(t, expectedL2Block, parsedEntry) +} + +func TestStreamClientGetLatestL2Block(t *testing.T) { + serverConn, clientConn := net.Pipe() + defer func() { + serverConn.Close() + clientConn.Close() + }() + + c := NewClient(context.Background(), "", 0, 0, 0) + c.conn = clientConn + + expectedL2Block, _ := createL2BlockAndTransactions(t, 5, 0) + l2BlockProto := &types.L2BlockProto{L2Block: expectedL2Block} + l2BlockRaw, err := l2BlockProto.Marshal() + require.NoError(t, err) + + var ( + errCh = make(chan error) + wg sync.WaitGroup + ) + wg.Add(1) + + // Prepare the server to send responses in a separate goroutine + go func() { + defer wg.Done() + + // Read the Command + if err := readAndValidateUint(t, serverConn, uint64(CmdHeader), "command"); err != nil { + errCh <- err + return + } + + // Read the StreamType + if err := readAndValidateUint(t, serverConn, uint64(StSequencer), streamTypeFieldName); err != nil { + errCh <- err + return + } + + // Write ResultEntry + re := createResultEntry(t) + _, err = serverConn.Write(re.Encode()) + if 
err != nil { + errCh <- fmt.Errorf("failed to write result entry to the connection: %w", err) + } + + // Write HeaderEntry + he := &types.HeaderEntry{ + PacketType: uint8(CmdHeader), + HeadLength: types.HeaderSize, + Version: 2, + SystemId: 1, + StreamType: types.StreamType(StSequencer), + TotalEntries: 4, + } + _, err = serverConn.Write(he.Encode()) + if err != nil { + errCh <- fmt.Errorf("failed to write header entry to the connection: %w", err) + } + + // Read the Command + if err := readAndValidateUint(t, serverConn, uint64(CmdEntry), "command"); err != nil { + errCh <- err + return + } + + // Read the StreamType + if err := readAndValidateUint(t, serverConn, uint64(StSequencer), streamTypeFieldName); err != nil { + errCh <- err + return + } + + // Read the EntryNumber + if err := readAndValidateUint(t, serverConn, he.TotalEntries-1, "entry number"); err != nil { + errCh <- err + return + } + + // Write the ResultEntry + _, err = serverConn.Write(re.Encode()) + if err != nil { + errCh <- fmt.Errorf("failed to write result entry to the connection: %w", err) + return + } + + // Write the FileEntry containing the L2 block information + fe := createFileEntry(t, types.EntryTypeL2Block, 1, l2BlockRaw) + _, err = serverConn.Write(fe.Encode()) + if err != nil { + errCh <- fmt.Errorf("failed to write the l2 block file entry to the connection: %w", err) + return + } + + serverConn.Close() + }() + + go func() { + wg.Wait() + close(errCh) + }() + + // ACT + l2Block, err := c.GetLatestL2Block() + require.NoError(t, err) + + // ASSERT + serverErr := <-errCh + require.NoError(t, serverErr) + + expectedFullL2Block := types.ConvertToFullL2Block(expectedL2Block) + require.Equal(t, expectedFullL2Block, l2Block) +} + +func TestStreamClientGetL2BlockByNumber(t *testing.T) { + const blockNum = uint64(5) + + serverConn, clientConn := net.Pipe() + defer func() { + serverConn.Close() + clientConn.Close() + }() + + c := NewClient(context.Background(), "", 0, 0, 0) + c.conn = clientConn + + bookmark := types.NewBookmarkProto(blockNum, datastream.BookmarkType_BOOKMARK_TYPE_L2_BLOCK) + bookmarkRaw, err := bookmark.Marshal() + require.NoError(t, err) + + expectedL2Block, l2Txs := createL2BlockAndTransactions(t, blockNum, 3) + l2BlockProto := &types.L2BlockProto{L2Block: expectedL2Block} + l2BlockRaw, err := l2BlockProto.Marshal() + require.NoError(t, err) + + l2TxsRaw := make([][]byte, len(l2Txs)) + for i, l2Tx := range l2Txs { + l2TxProto := &types.TxProto{Transaction: l2Tx} + l2TxRaw, err := l2TxProto.Marshal() + require.NoError(t, err) + l2TxsRaw[i] = l2TxRaw + } + + l2BlockEnd := &types.L2BlockEndProto{Number: expectedL2Block.GetNumber()} + l2BlockEndRaw, err := l2BlockEnd.Marshal() + require.NoError(t, err) + + errCh := make(chan error) + + createServerResponses := func(t *testing.T, serverConn net.Conn, bookmarkRaw, l2BlockRaw []byte, l2TxsRaw [][]byte, l2BlockEndRaw []byte, errCh chan error) { + defer func() { + close(errCh) + serverConn.Close() + }() + + // Read the command + if err := readAndValidateUint(t, serverConn, uint64(CmdStartBookmark), "command"); err != nil { + errCh <- err + return + } + + // Read the stream type + if err := readAndValidateUint(t, serverConn, uint64(StSequencer), streamTypeFieldName); err != nil { + errCh <- err + return + } + + // Read the bookmark length + if err := readAndValidateUint(t, serverConn, uint32(len(bookmarkRaw)), "bookmark length"); err != nil { + errCh <- err + return + } + + // Read the actual bookmark + actualBookmarkRaw, err := readBuffer(serverConn, 
uint32(len(bookmarkRaw))) + if err != nil { + errCh <- err + return + } + if !bytes.Equal(bookmarkRaw, actualBookmarkRaw) { + errCh <- fmt.Errorf("mismatch between expected %v and actual bookmark %v", bookmarkRaw, actualBookmarkRaw) + return + } + + // Write ResultEntry + re := createResultEntry(t) + if _, err := serverConn.Write(re.Encode()); err != nil { + errCh <- err + return + } + + // Write File entries (EntryTypeL2Block, EntryTypeL2Tx and EntryTypeL2BlockEnd) + fileEntries := make([]*types.FileEntry, 0, len(l2TxsRaw)+2) + fileEntries = append(fileEntries, createFileEntry(t, types.EntryTypeL2Block, 1, l2BlockRaw)) + entryNum := uint64(2) + for _, l2TxRaw := range l2TxsRaw { + fileEntries = append(fileEntries, createFileEntry(t, types.EntryTypeL2Tx, entryNum, l2TxRaw)) + entryNum++ + } + fileEntries = append(fileEntries, createFileEntry(t, types.EntryTypeL2BlockEnd, entryNum, l2BlockEndRaw)) + + for _, fe := range fileEntries { + if _, err := serverConn.Write(fe.Encode()); err != nil { + errCh <- err + return + } + } + + } + + go createServerResponses(t, serverConn, bookmarkRaw, l2BlockRaw, l2TxsRaw, l2BlockEndRaw, errCh) + + l2Block, errCode, err := c.GetL2BlockByNumber(blockNum) + require.NoError(t, err) + require.Equal(t, types.CmdErrOK, errCode) + + serverErr := <-errCh + require.NoError(t, serverErr) + + l2TxsProto := make([]types.L2TransactionProto, len(l2Txs)) + for i, tx := range l2Txs { + l2TxProto := types.ConvertToL2TransactionProto(tx) + l2TxsProto[i] = *l2TxProto + } + expectedFullL2Block := types.ConvertToFullL2Block(expectedL2Block) + expectedFullL2Block.L2Txs = l2TxsProto + require.Equal(t, expectedFullL2Block, l2Block) +} + +// readAndValidateUint reads the uint value and validates it against expected value from the connection in order to unblock future write operations +func readAndValidateUint(t *testing.T, conn net.Conn, expected interface{}, paramName string) error { + t.Helper() + + var length uint32 + switch expected.(type) { + case uint64: + length = 8 + case uint32: + length = 4 + default: + return fmt.Errorf("unsupported expected type for %s: %T", paramName, expected) + } + + valueRaw, err := readBuffer(conn, length) + if err != nil { + return fmt.Errorf("failed to read %s parameter: %w", paramName, err) + } + + switch expectedValue := expected.(type) { + case uint64: + value := binary.BigEndian.Uint64(valueRaw) + if value != expectedValue { + return fmt.Errorf("%s parameter value mismatch between expected %d and actual %d", paramName, expectedValue, value) + } + case uint32: + value := binary.BigEndian.Uint32(valueRaw) + if value != expectedValue { + return fmt.Errorf("%s parameter value mismatch between expected %d and actual %d", paramName, expectedValue, value) + } + } + + return nil +} + +// createFileEntry is a helper function that creates FileEntry +func createFileEntry(t *testing.T, entryType types.EntryType, num uint64, data []byte) *types.FileEntry { + t.Helper() + return &types.FileEntry{ + PacketType: PtData, + Length: types.FileEntryMinSize + uint32(len(data)), + EntryType: entryType, + EntryNum: num, + Data: data, + } +} + +func createResultEntry(t *testing.T) *types.ResultEntry { + t.Helper() + return &types.ResultEntry{ + PacketType: PtResult, + ErrorNum: types.CmdErrOK, + Length: types.ResultEntryMinSize, + ErrorStr: nil, + } +} + +// createL2BlockAndTransactions creates a single L2 block with the transactions +func createL2BlockAndTransactions(t *testing.T, blockNum uint64, txnCount int) (*datastream.L2Block, []*datastream.Transaction) { + 
t.Helper() + txns := make([]*datastream.Transaction, 0, txnCount) + l2Block := &datastream.L2Block{ + Number: blockNum, + BatchNumber: 1, + Timestamp: uint64(time.Now().UnixMilli()), + Hash: common.HexToHash("0x123456987654321").Bytes(), + BlockGasLimit: 1000000000, + } + + for i := 0; i < txnCount; i++ { + txns = append(txns, + &datastream.Transaction{ + L2BlockNumber: l2Block.GetNumber(), + Index: uint64(i), + IsValid: true, + Debug: &datastream.Debug{Message: fmt.Sprintf("Hello %d. transaction!", i+1)}, + }) + } + + return l2Block, txns +} diff --git a/zk/datastream/server/data_stream_server.go b/zk/datastream/server/data_stream_server.go index b293a4a6ef1..80d523618a3 100644 --- a/zk/datastream/server/data_stream_server.go +++ b/zk/datastream/server/data_stream_server.go @@ -8,6 +8,7 @@ import ( "github.com/ledgerwatch/erigon-lib/kv" "github.com/ledgerwatch/erigon/core/rawdb" eritypes "github.com/ledgerwatch/erigon/core/types" + "github.com/ledgerwatch/erigon/zk/datastream/client" "github.com/ledgerwatch/erigon/zk/datastream/proto/github.com/0xPolygonHermez/zkevm-node/state/datastream" "github.com/ledgerwatch/erigon/zk/datastream/types" zktypes "github.com/ledgerwatch/erigon/zk/types" @@ -598,3 +599,88 @@ func (srv *DataStreamServer) getLastEntryOfType(entryType datastreamer.EntryType return emtryEntry, false, nil } + +type dataStreamServerIterator struct { + stream *datastreamer.StreamServer + curEntryNum uint64 + header uint64 +} + +func newDataStreamServerIterator(stream *datastreamer.StreamServer, start uint64) *dataStreamServerIterator { + return &dataStreamServerIterator{ + stream: stream, + curEntryNum: start, + header: stream.GetHeader().TotalEntries - 1, + } +} + +func (it *dataStreamServerIterator) NextFileEntry() (entry *types.FileEntry, err error) { + if it.curEntryNum > it.header { + return nil, nil + } + + var fileEntry datastreamer.FileEntry + fileEntry, err = it.stream.GetEntry(it.curEntryNum) + if err != nil { + return nil, err + } + + it.curEntryNum += 1 + + return &types.FileEntry{ + PacketType: uint8(fileEntry.Type), + Length: fileEntry.Length, + EntryType: types.EntryType(fileEntry.Type), + EntryNum: fileEntry.Number, + Data: fileEntry.Data, + }, nil +} + +func (srv *DataStreamServer) ReadBatches(start uint64, end uint64) ([][]*types.FullL2Block, error) { + bookmark := types.NewBookmarkProto(start, datastream.BookmarkType_BOOKMARK_TYPE_BATCH) + marshalled, err := bookmark.Marshal() + if err != nil { + return nil, err + } + + entryNum, err := srv.stream.GetBookmark(marshalled) + + if err != nil { + return nil, err + } + + iterator := newDataStreamServerIterator(srv.stream, entryNum) + + return ReadBatches(iterator, start, end) +} + +func ReadBatches(iterator client.FileEntryIterator, start uint64, end uint64) ([][]*types.FullL2Block, error) { + batches := make([][]*types.FullL2Block, end-start+1) + +LOOP_ENTRIES: + for { + parsedProto, err := client.ReadParsedProto(iterator) + if err != nil { + return nil, err + } + + if parsedProto == nil { + break + } + + switch parsedProto := parsedProto.(type) { + case *types.BatchStart: + batches[parsedProto.Number-start] = []*types.FullL2Block{} + case *types.BatchEnd: + if parsedProto.Number == end { + break LOOP_ENTRIES + } + case *types.FullL2Block: + batches[parsedProto.BatchNumber-start] = append(batches[parsedProto.BatchNumber-start], parsedProto) + default: + continue + } + } + + return batches, nil +} diff --git a/zk/datastream/types/entry_type.go b/zk/datastream/types/entry_type.go index 827aabfb15c..49a3909b67b 
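ReadBatches walks the stream from the batch-start bookmark and groups every FullL2Block into its batch slot until it reaches the BatchEnd entry of the last requested batch. A hypothetical caller is sketched below, assuming an already-constructed *DataStreamServer; everything outside the signatures shown above is illustrative.

```go
// Illustrative consumer of the new server-side batch reader: fetch batches
// 10..12 and count the blocks in each. The srv value is assumed to be
// constructed elsewhere; batch numbers are placeholders.
package main

import (
	"fmt"

	"github.com/ledgerwatch/erigon/zk/datastream/server"
)

func dumpBatches(srv *server.DataStreamServer) error {
	const start, end = uint64(10), uint64(12)

	batches, err := srv.ReadBatches(start, end)
	if err != nil {
		return err
	}

	// batches[i] holds the blocks of batch start+i, in stream order.
	for i, blocks := range batches {
		fmt.Printf("batch %d: %d blocks\n", start+uint64(i), len(blocks))
		for _, b := range blocks {
			fmt.Printf("  block %d with %d txs\n", b.L2BlockNumber, len(b.L2Txs))
		}
	}
	return nil
}
```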
100644 --- a/zk/datastream/types/entry_type.go +++ b/zk/datastream/types/entry_type.go @@ -1,5 +1,7 @@ package types +import "math" + type EntryType uint32 var ( @@ -11,4 +13,5 @@ var ( EntryTypeGerUpdate EntryType = 5 EntryTypeL2BlockEnd EntryType = 6 BookmarkEntryType EntryType = 176 + EntryTypeNotFound EntryType = math.MaxUint32 ) diff --git a/zk/datastream/types/file.go b/zk/datastream/types/file.go index 41c043caa22..20417460dfe 100644 --- a/zk/datastream/types/file.go +++ b/zk/datastream/types/file.go @@ -34,7 +34,7 @@ func (f *FileEntry) IsBookmarkBlock() bool { } func (f *FileEntry) IsL2BlockEnd() bool { - return uint32(f.EntryType) == uint32(6) //TODO: fix once it is added in the lib + return uint32(f.EntryType) == uint32(datastream.EntryType_ENTRY_TYPE_L2_BLOCK_END) } func (f *FileEntry) IsL2Block() bool { return uint32(f.EntryType) == uint32(datastream.EntryType_ENTRY_TYPE_L2_BLOCK) @@ -62,6 +62,17 @@ func (f *FileEntry) IsGerUpdate() bool { return f.EntryType == EntryTypeGerUpdate } +// Encode encodes file entry to the binary format +func (f *FileEntry) Encode() []byte { + be := make([]byte, 1) + be[0] = f.PacketType + be = binary.BigEndian.AppendUint32(be, f.Length) + be = binary.BigEndian.AppendUint32(be, uint32(f.EntryType)) + be = binary.BigEndian.AppendUint64(be, f.EntryNum) + be = append(be, f.Data...) //nolint:makezero + return be +} + // Decode/convert from binary bytes slice to FileEntry type func DecodeFileEntry(b []byte) (*FileEntry, error) { if uint32(len(b)) < FileEntryMinSize { diff --git a/zk/datastream/types/header.go b/zk/datastream/types/header.go index 5b393a17794..9af3d368e36 100644 --- a/zk/datastream/types/header.go +++ b/zk/datastream/types/header.go @@ -5,14 +5,16 @@ import ( "fmt" ) -const HeaderSize = 38 -const HeaderSizePreEtrog = 29 +const ( + HeaderSize = 38 + HeaderSizePreEtrog = 29 +) type StreamType uint64 type HeaderEntry struct { PacketType uint8 // 1:Header - HeadLength uint32 // 38 oe 29 + HeadLength uint32 // 38 or 29 Version uint8 SystemId uint64 StreamType StreamType // 1:Sequencer @@ -20,6 +22,19 @@ type HeaderEntry struct { TotalEntries uint64 // Total number of data entries (entry type 2) } +// Encode encodes given HeaderEntry into a binary format +func (e *HeaderEntry) Encode() []byte { + be := make([]byte, 1) + be[0] = e.PacketType + be = binary.BigEndian.AppendUint32(be, e.HeadLength) + be = append(be, e.Version) //nolint:makezero + be = binary.BigEndian.AppendUint64(be, e.SystemId) + be = binary.BigEndian.AppendUint64(be, uint64(e.StreamType)) + be = binary.BigEndian.AppendUint64(be, e.TotalLength) + be = binary.BigEndian.AppendUint64(be, e.TotalEntries) + return be +} + // Decode/convert from binary bytes slice to a header entry type func DecodeHeaderEntryPreEtrog(b []byte) (*HeaderEntry, error) { return &HeaderEntry{ diff --git a/zk/datastream/types/l2block_proto.go b/zk/datastream/types/l2block_proto.go index 36be0c9e446..c0251a16af4 100644 --- a/zk/datastream/types/l2block_proto.go +++ b/zk/datastream/types/l2block_proto.go @@ -73,21 +73,24 @@ func UnmarshalL2Block(data []byte) (*FullL2Block, error) { return nil, err } - l2Block := &FullL2Block{ - BatchNumber: block.BatchNumber, - L2BlockNumber: block.Number, - Timestamp: int64(block.Timestamp), - DeltaTimestamp: block.DeltaTimestamp, - L1InfoTreeIndex: block.L1InfotreeIndex, - GlobalExitRoot: libcommon.BytesToHash(block.GlobalExitRoot), - Coinbase: libcommon.BytesToAddress(block.Coinbase), - L1BlockHash: libcommon.BytesToHash(block.L1Blockhash), - L2Blockhash: 
libcommon.BytesToHash(block.Hash), - StateRoot: libcommon.BytesToHash(block.StateRoot), - BlockGasLimit: block.BlockGasLimit, - BlockInfoRoot: libcommon.BytesToHash(block.BlockInfoRoot), - Debug: ProcessDebug(block.Debug), - } + return ConvertToFullL2Block(&block), nil +} - return l2Block, nil +// ConvertToFullL2Block converts the datastream.L2Block to types.FullL2Block +func ConvertToFullL2Block(block *datastream.L2Block) *FullL2Block { + return &FullL2Block{ + BatchNumber: block.GetBatchNumber(), + L2BlockNumber: block.GetNumber(), + Timestamp: int64(block.GetTimestamp()), + DeltaTimestamp: block.GetDeltaTimestamp(), + L1InfoTreeIndex: block.GetL1InfotreeIndex(), + GlobalExitRoot: libcommon.BytesToHash(block.GetGlobalExitRoot()), + Coinbase: libcommon.BytesToAddress(block.GetCoinbase()), + L1BlockHash: libcommon.BytesToHash(block.GetL1Blockhash()), + L2Blockhash: libcommon.BytesToHash(block.GetHash()), + StateRoot: libcommon.BytesToHash(block.GetStateRoot()), + BlockGasLimit: block.GetBlockGasLimit(), + BlockInfoRoot: libcommon.BytesToHash(block.GetBlockInfoRoot()), + Debug: ProcessDebug(block.GetDebug()), + } } diff --git a/zk/datastream/types/result.go b/zk/datastream/types/result.go index 1db1061c0e2..1e6652cbb9d 100644 --- a/zk/datastream/types/result.go +++ b/zk/datastream/types/result.go @@ -12,11 +12,12 @@ const ( ResultEntryMinSize = uint32(9) // Command errors - CmdErrOK = 0 - CmdErrAlreadyStarted = 1 - CmdErrAlreadyStopped = 2 - CmdErrBadFromEntry = 3 - CmdErrInvalidCommand = 9 + CmdErrOK = 0 // CmdErrOK for no error + CmdErrAlreadyStarted = 1 // CmdErrAlreadyStarted for client already started error + CmdErrAlreadyStopped = 2 // CmdErrAlreadyStopped for client already stopped error + CmdErrBadFromEntry = 3 // CmdErrBadFromEntry for invalid starting entry number + CmdErrBadFromBookmark = 4 // CmdErrBadFromBookmark for invalid starting bookmark + CmdErrInvalidCommand = 9 // CmdErrInvalidCommand for invalid/unknown command error ) type ResultEntry struct { @@ -41,9 +42,18 @@ func (r *ResultEntry) GetError() error { return errors.New(string(r.ErrorStr)) } +// Encode encodes result entry to the binary format +func (r *ResultEntry) Encode() []byte { + be := make([]byte, 1) + be[0] = r.PacketType + be = binary.BigEndian.AppendUint32(be, r.Length) + be = binary.BigEndian.AppendUint32(be, r.ErrorNum) + be = append(be, r.ErrorStr...) //nolint:makezero + return be +} + // Decode/convert from binary bytes slice to an entry type func DecodeResultEntry(b []byte) (*ResultEntry, error) { - if uint32(len(b)) < ResultEntryMinSize { return &ResultEntry{}, fmt.Errorf("invalid result entry binary size. 
Expected: >=%d, got: %d", ResultEntryMinSize, len(b)) } diff --git a/zk/datastream/types/result_test.go b/zk/datastream/types/result_test.go index 20a1f6cbabc..bdbbf40c7d6 100644 --- a/zk/datastream/types/result_test.go +++ b/zk/datastream/types/result_test.go @@ -59,3 +59,17 @@ func TestResultDecode(t *testing.T) { }) } } + +func TestEncodeDecodeResult(t *testing.T) { + expectedResult := &ResultEntry{ + PacketType: 1, + Length: 19, + ErrorNum: 5, + ErrorStr: []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, + } + + resultRaw := expectedResult.Encode() + actualResult, err := DecodeResultEntry(resultRaw) + require.NoError(t, err) + require.Equal(t, expectedResult, actualResult) +} diff --git a/zk/datastream/types/tx_proto.go b/zk/datastream/types/tx_proto.go index 28d48d9afa4..31b4b2d1902 100644 --- a/zk/datastream/types/tx_proto.go +++ b/zk/datastream/types/tx_proto.go @@ -35,15 +35,18 @@ func UnmarshalTx(data []byte) (*L2TransactionProto, error) { return nil, err } - l2Tx := &L2TransactionProto{ - L2BlockNumber: tx.L2BlockNumber, - Index: tx.Index, - IsValid: tx.IsValid, - Encoded: tx.Encoded, - EffectiveGasPricePercentage: uint8(tx.EffectiveGasPricePercentage), - IntermediateStateRoot: libcommon.BytesToHash(tx.ImStateRoot), - Debug: ProcessDebug(tx.Debug), - } + return ConvertToL2TransactionProto(&tx), nil +} - return l2Tx, nil +// ConvertToL2TransactionProto converts transaction object from datastream.Transaction to types.L2TransactionProto +func ConvertToL2TransactionProto(tx *datastream.Transaction) *L2TransactionProto { + return &L2TransactionProto{ + L2BlockNumber: tx.GetL2BlockNumber(), + Index: tx.GetIndex(), + IsValid: tx.GetIsValid(), + Encoded: tx.GetEncoded(), + EffectiveGasPricePercentage: uint8(tx.GetEffectiveGasPricePercentage()), + IntermediateStateRoot: libcommon.BytesToHash(tx.GetImStateRoot()), + Debug: ProcessDebug(tx.GetDebug()), + } } diff --git a/zk/debug_tools/test-contracts/contracts/Creates.sol b/zk/debug_tools/test-contracts/contracts/Creates.sol new file mode 100644 index 00000000000..cdd7c517bae --- /dev/null +++ b/zk/debug_tools/test-contracts/contracts/Creates.sol @@ -0,0 +1,67 @@ +// SPDX-License-Identifier: GPL-3.0 + +pragma solidity >=0.7.0 <0.9.0; + +contract Creates { + function opCreate(bytes memory bytecode, uint length) public returns(address) { + address addr; + assembly { + addr := create(0, 0xa0, length) + sstore(0x0, addr) + } + return addr; + } + + function opCreate2(bytes memory bytecode, uint length) public returns(address) { + address addr; + assembly { + addr := create2(0, 0xa0, length, 0x2) + sstore(0x0, addr) + } + return addr; + } + + function opCreate2Complex(bytes memory bytecode, uint length) public returns(address, uint256) { + uint256 number = add(1, 2); + + address addr; + assembly { + addr := create2(0, add(bytecode, 0x20), length, 0x2) + sstore(0x0, addr) + } + + number = add(2, 4); + + return (addr, number); + } + + function add(uint256 a, uint256 b) public pure returns(uint256) { + return a + b; + } + + function sendValue() public payable { + uint bal; + assembly{ + bal := add(bal,callvalue()) + sstore(0x1, bal) + } + } + + function opCreateValue(bytes memory bytecode, uint length) public payable returns(address) { + address addr; + assembly { + addr := create(500, 0xa0, length) + sstore(0x0, addr) + } + return addr; + } + + function opCreate2Value(bytes memory bytecode, uint length) public payable returns(address) { + address addr; + assembly { + addr := create2(300, 0xa0, length, 0x55555) + sstore(0x0, addr) + } + return addr; + } +} \ 
No newline at end of file diff --git a/zk/debug_tools/test-contracts/package.json b/zk/debug_tools/test-contracts/package.json index 25a3edd0745..06514e69be1 100644 --- a/zk/debug_tools/test-contracts/package.json +++ b/zk/debug_tools/test-contracts/package.json @@ -10,6 +10,7 @@ "counter:bali": "npx hardhat compile && npx hardhat run scripts/counter.js --network bali", "counter:cardona": "npx hardhat compile && npx hardhat run scripts/counter.js --network cardona", "counter:mainnet": "npx hardhat compile && npx hardhat run scripts/counter.js --network mainnet", + "emitlog:local": "npx hardhat compile && npx hardhat run scripts/emitlog.js --network local", "emitlog:bali": "npx hardhat compile && npx hardhat run scripts/emitlog.js --network bali", "emitlog:cardona": "npx hardhat compile && npx hardhat run scripts/emitlog.js --network cardona", "emitlog:mainnet": "npx hardhat compile && npx hardhat run scripts/emitlog.js --network mainnet", @@ -18,7 +19,8 @@ "erc20Revert:local": "npx hardhat compile && npx hardhat run scripts/ERC20-revert.js --network local", "erc20Revert:sepolia": "npx hardhat compile && npx hardhat run scripts/ERC20-revert.js --network sepolia", "chainCall:local": "npx hardhat compile && npx hardhat run scripts/chain-call.js --network local", - "chainCall:sepolia": "npx hardhat compile && npx hardhat run scripts/chain-call.js --network sepolia" + "chainCall:sepolia": "npx hardhat compile && npx hardhat run scripts/chain-call.js --network sepolia", + "create:local": "npx hardhat compile && npx hardhat run scripts/create.js --network local" }, "keywords": [], "author": "", diff --git a/zk/debug_tools/test-contracts/scripts/create.js b/zk/debug_tools/test-contracts/scripts/create.js new file mode 100644 index 00000000000..291e7ac0c45 --- /dev/null +++ b/zk/debug_tools/test-contracts/scripts/create.js @@ -0,0 +1,30 @@ + +// deploys the Creates contract and calls a method that executes the CREATE opcode + +async function main() { + const deployableBytecode = "608060405234801561000f575f80fd5b506101778061001d5f395ff3fe608060405234801561000f575f80fd5b506004361061003f575f3560e01c806306661abd14610043578063a87d942c14610061578063d09de08a1461007f575b5f80fd5b61004b610089565b60405161005891906100c8565b60405180910390f35b61006961008e565b60405161007691906100c8565b60405180910390f35b610087610096565b005b5f5481565b5f8054905090565b60015f808282546100a7919061010e565b92505081905550565b5f819050919050565b6100c2816100b0565b82525050565b5f6020820190506100db5f8301846100b9565b92915050565b7f4e487b71000000000000000000000000000000000000000000000000000000005f52601160045260245ffd5b5f610118826100b0565b9150610123836100b0565b925082820190508082111561013b5761013a6100e1565b5b9291505056fea2646970667358221220137ae5cf0fcdf694f11fbe24952b202d62e7154851f6232b7b897dbf37a2d18164736f6c63430008140033" + try { + const Creates = await hre.ethers.getContractFactory("Creates"); + + // Deploy the contract + const createsContract = await Creates.deploy(); + + // Wait for the deployment transaction to be mined + await createsContract.waitForDeployment(); + + console.log(`Creates deployed to: ${await createsContract.getAddress()}`); + + const opCreate = await createsContract.opCreate(hre.ethers.toUtf8Bytes(deployableBytecode), deployableBytecode.length); + console.log('opCreate method call transaction: ', opCreate.hash); + } catch (error) { + console.error(error.toString()); + process.exit(1); + } + } + + main() + .then(() => process.exit(0)) + .catch(error => { + console.error(error); + process.exit(1); + }); \ No newline at end of file diff
--git a/zk/erigon_db/db.go b/zk/erigon_db/db.go index 4338ffc44bc..6a6ae1a512f 100644 --- a/zk/erigon_db/db.go +++ b/zk/erigon_db/db.go @@ -15,6 +15,12 @@ import ( var sha3UncleHash = common.HexToHash("0x1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347") +type ReadOnlyErigonDb interface { + GetBodyTransactions(fromBlockNo, toBlockNo uint64) (*[]ethTypes.Transaction, error) + ReadCanonicalHash(blockNo uint64) (common.Hash, error) + GetHeader(blockNo uint64) (*ethTypes.Header, error) +} + type ErigonDb struct { tx kv.RwTx } diff --git a/zk/hermez_db/db.go b/zk/hermez_db/db.go index c0680cfe8de..0bdbf99ef88 100644 --- a/zk/hermez_db/db.go +++ b/zk/hermez_db/db.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "math" + "sort" "github.com/ledgerwatch/erigon-lib/common" "github.com/ledgerwatch/erigon-lib/kv" @@ -50,7 +51,7 @@ const FORK_HISTORY = "fork_history" // index const JUST_UNWOUND = "just_unwound" // batch number -> true const PLAIN_STATE_VERSION = "plain_state_version" // batch number -> true const ERIGON_VERSIONS = "erigon_versions" // erigon version -> timestamp of startup -const BATCH_ENDS = "batch_ends" +const BATCH_ENDS = "batch_ends" // var HermezDbTables = []string{ L1VERIFICATIONS, @@ -1016,6 +1017,9 @@ func (db *HermezDb) deleteFromBucketWithUintKeysRange(bucket string, fromBlockNu } func (db *HermezDbReader) GetForkId(batchNo uint64) (uint64, error) { + if batchNo == 0 { + batchNo = 1 + } v, err := db.tx.GetOne(FORKIDS, Uint64ToBytes(batchNo)) if err != nil { return 0, err @@ -1763,3 +1767,90 @@ func (db *HermezDbReader) GetBatchEnd(blockNo uint64) (bool, error) { func (db *HermezDb) DeleteBatchEnds(from, to uint64) error { return db.deleteFromBucketWithUintKeysRange(BATCH_ENDS, from, to) } + +func (db *HermezDbReader) GetAllForkIntervals() ([]types.ForkInterval, error) { + return db.getForkIntervals(nil) +} + +func (db *HermezDbReader) GetForkInterval(forkID uint64) (*types.ForkInterval, bool, error) { + forkIntervals, err := db.getForkIntervals(&forkID) + if err != nil { + return nil, false, err + } + + if len(forkIntervals) == 0 { + return nil, false, err + } + + forkInterval := forkIntervals[0] + return &forkInterval, true, nil +} + +func (db *HermezDbReader) getForkIntervals(forkIdFilter *uint64) ([]types.ForkInterval, error) { + mapForkIntervals := map[uint64]types.ForkInterval{} + + c, err := db.tx.Cursor(FORKIDS) + if err != nil { + return nil, err + } + defer c.Close() + + lastForkId := uint64(0) + for k, v, err := c.First(); k != nil; k, v, err = c.Next() { + if err != nil { + return nil, err + } + + batchNumber := BytesToUint64(k) + forkID := BytesToUint64(v) + + if forkID > lastForkId { + lastForkId = forkID + } + + if forkIdFilter != nil && *forkIdFilter != forkID { + continue + } + + mapInterval, found := mapForkIntervals[forkID] + if !found { + mapInterval = types.ForkInterval{ + ForkID: forkID, + FromBatchNumber: batchNumber, + ToBatchNumber: batchNumber, + } + } + + if batchNumber < mapInterval.FromBatchNumber { + mapInterval.FromBatchNumber = batchNumber + } + + if batchNumber > mapInterval.ToBatchNumber { + mapInterval.ToBatchNumber = batchNumber + } + + mapForkIntervals[forkID] = mapInterval + } + + forkIntervals := make([]types.ForkInterval, 0, len(mapForkIntervals)) + for forkId, forkInterval := range mapForkIntervals { + blockNumber, found, err := db.GetForkIdBlock(forkInterval.ForkID) + if err != nil { + return nil, err + } else if found { + forkInterval.BlockNumber = blockNumber + } + + if forkId == lastForkId { + forkInterval.ToBatchNumber = 
+ if forkId == lastForkId { + forkInterval.ToBatchNumber = math.MaxUint64 + } + + forkIntervals = append(forkIntervals, forkInterval) + } + + sort.Slice(forkIntervals, func(i, j int) bool { + return forkIntervals[i].FromBatchNumber < forkIntervals[j].FromBatchNumber + }) + + return forkIntervals, nil +} diff --git a/zk/hermez_db/db_test.go b/zk/hermez_db/db_test.go index f017fcc8ceb..7641393d459 100644 --- a/zk/hermez_db/db_test.go +++ b/zk/hermez_db/db_test.go @@ -3,6 +3,7 @@ package hermez_db import ( "context" "fmt" + "math" "testing" "github.com/ledgerwatch/erigon-lib/common" @@ -191,36 +192,52 @@ func TestGetAndSetLatestUnordered(t *testing.T) { } func TestGetAndSetForkId(t *testing.T) { + tx, cleanup := GetDbTx() + defer cleanup() + db := NewHermezDb(tx) - testCases := []struct { - batchNo uint64 - forkId uint64 + forkIntervals := []struct { + ForkId uint64 + FromBatchNumber uint64 + ToBatchNumber uint64 }{ - {9, 0}, // batchNo < 10 -> forkId = 0 - {10, 1}, // batchNo = 10 -> forkId = 1 - {11, 1}, // batchNo > 10 -> forkId = 1 - {99, 1}, // batchNo < 100 -> forkId = 1 - {100, 2}, // batchNo >= 100 -> forkId = 2 - {1000, 2}, // batchNo > 100 -> forkId = 2 + {ForkId: 1, FromBatchNumber: 1, ToBatchNumber: 10}, + {ForkId: 2, FromBatchNumber: 11, ToBatchNumber: 100}, + {ForkId: 3, FromBatchNumber: 101, ToBatchNumber: 1000}, } - for _, tc := range testCases { - t.Run(fmt.Sprintf("BatchNo: %d ForkId: %d", tc.batchNo, tc.forkId), func(t *testing.T) { - tx, cleanup := GetDbTx() - db := NewHermezDb(tx) - - err := db.WriteForkId(10, 1) - require.NoError(t, err, "Failed to write ForkId") - err = db.WriteForkId(tc.batchNo, tc.forkId) - require.NoError(t, err, "Failed to write ForkId") - err = db.WriteForkId(100, 2) + for _, forkInterval := range forkIntervals { + for b := forkInterval.FromBatchNumber; b <= forkInterval.ToBatchNumber; b++ { + err := db.WriteForkId(b, forkInterval.ForkId) require.NoError(t, err, "Failed to write ForkId") + } + } - fetchedForkId, err := db.GetForkId(tc.batchNo) - require.NoError(t, err, "Failed to get ForkId") - assert.Equal(t, tc.forkId, fetchedForkId, "Fetched ForkId doesn't match expected") - cleanup() - }) + testCases := []struct { + batchNo uint64 + expectedForkId uint64 + }{ + {0, 1}, // batch 0 = forkId 1, special case: batch 0 has the same forkId as batch 1 + + {1, 1}, // batch 1 = forkId 1, first batch for forkId 1 + {5, 1}, // batch 5 = forkId 1, a batch between first and last for forkId 1 + {10, 1}, // batch 10 = forkId 1, last batch for forkId 1 + + {11, 2}, // batch 11 = forkId 2, first batch for forkId 2 + {50, 2}, // batch 50 = forkId 2, a batch between first and last for forkId 2 + {100, 2}, // batch 100 = forkId 2, last batch for forkId 2 + + {101, 3}, // batch 101 = forkId 3, first batch for forkId 3 + {500, 3}, // batch 500 = forkId 3, a batch between first and last for forkId 3 + {1000, 3}, // batch 1000 = forkId 3, last batch for forkId 3 + + {1001, 0}, // batch 1001 = a batch out of the range of the known forks + } + + for _, tc := range testCases { + fetchedForkId, err := db.GetForkId(tc.batchNo) + assert.NoError(t, err) + assert.Equal(t, tc.expectedForkId, fetchedForkId, "fetched fork id doesn't match the expected one for the batch number") } } @@ -503,3 +520,105 @@ func TestBatchBlocks(t *testing.T) { t.Fatal("Expected 1000 blocks") } } + +func TestDeleteForkId(t *testing.T) { + type forkInterval struct { + ForkId uint64 + FromBatchNumber uint64 + ToBatchNumber uint64 + } + forkIntervals := []forkInterval{ + {1, 1, 10}, + {2, 11, 20}, + {3, 21, 30}, + {4, 31, 40}, + {5, 41, 50}, + {6, 51, 60}, + {7, 61, 
70}, + } + + testCases := []struct { + name string + fromBatchToDelete uint64 + toBatchToDelete uint64 + expectedDeletedForksIds []uint64 + expectedRemainingForkIntervals []forkInterval + }{ + {"delete fork id only for the last batch", 70, 70, nil, []forkInterval{ + {1, 1, 10}, + {2, 11, 20}, + {3, 21, 30}, + {4, 31, 40}, + {5, 41, 50}, + {6, 51, 60}, + {7, 61, math.MaxUint64}, + }}, + {"delete fork id for batches that don't exist", 80, 90, nil, []forkInterval{ + {1, 1, 10}, + {2, 11, 20}, + {3, 21, 30}, + {4, 31, 40}, + {5, 41, 50}, + {6, 51, 60}, + {7, 61, math.MaxUint64}, + }}, + {"delete fork id for batches that cross multiple forks from some point until the last one - unwind", 27, 70, []uint64{4, 5, 6, 7}, []forkInterval{ + {1, 1, 10}, + {2, 11, 20}, + {3, 21, math.MaxUint64}, + }}, + {"delete fork id for batches that cross multiple forks from zero to some point - prune", 0, 36, []uint64{1, 2, 3}, []forkInterval{ + {4, 37, 40}, + {5, 41, 50}, + {6, 51, 60}, + {7, 61, math.MaxUint64}, + }}, + {"delete fork id for batches that cross multiple forks from some point after the beginning to some point before the end - hole", 23, 42, []uint64{4}, []forkInterval{ + {1, 1, 10}, + {2, 11, 20}, + {3, 21, 22}, + {5, 43, 50}, + {6, 51, 60}, + {7, 61, math.MaxUint64}, + }}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + tx, cleanup := GetDbTx() + defer cleanup() + db := NewHermezDb(tx) + + for _, forkInterval := range forkIntervals { + for b := forkInterval.FromBatchNumber; b <= forkInterval.ToBatchNumber; b++ { + err := db.WriteForkId(b, forkInterval.ForkId) + require.NoError(t, err, "Failed to write ForkId") + } + } + + err := db.DeleteForkIds(tc.fromBatchToDelete, tc.toBatchToDelete) + require.NoError(t, err) + + for batchNum := tc.fromBatchToDelete; batchNum <= tc.toBatchToDelete; batchNum++ { + forkId, err := db.GetForkId(batchNum) + require.NoError(t, err) + assert.Equal(t, uint64(0), forkId) + } + + for _, forkId := range tc.expectedDeletedForksIds { + forkInterval, found, err := db.GetForkInterval(forkId) + require.NoError(t, err) + assert.False(t, found) + assert.Nil(t, forkInterval) + } + + for _, remainingForkInterval := range tc.expectedRemainingForkIntervals { + forkInterval, found, err := db.GetForkInterval(remainingForkInterval.ForkId) + require.NoError(t, err) + assert.True(t, found) + assert.Equal(t, remainingForkInterval.FromBatchNumber, forkInterval.FromBatchNumber) + assert.Equal(t, remainingForkInterval.ToBatchNumber, forkInterval.ToBatchNumber) + } + }) + } +} diff --git a/zk/l1infotree/tree.go b/zk/l1infotree/tree.go index b8e9bf6c81f..2a544e0678c 100644 --- a/zk/l1infotree/tree.go +++ b/zk/l1infotree/tree.go @@ -36,8 +36,8 @@ func NewL1InfoTree(height uint8, initialLeaves [][32]byte) (*L1InfoTree, error) mt.allLeaves[leaf] = struct{}{} } - log.Debug("Initial count: ", mt.count) - log.Debug("Initial root: ", mt.currentRoot) + log.Debug(fmt.Sprintf("Initial count: %d", mt.count)) + log.Debug(fmt.Sprintf("Initial root: %s", mt.currentRoot)) return mt, nil } diff --git a/zk/legacy_executor_verifier/legacy_executor_verifier.go b/zk/legacy_executor_verifier/legacy_executor_verifier.go index 7933ebec36e..f284f6a152a 100644 --- a/zk/legacy_executor_verifier/legacy_executor_verifier.go +++ b/zk/legacy_executor_verifier/legacy_executor_verifier.go @@ -315,6 +315,13 @@ func (v *LegacyExecutorVerifier) VerifyWithoutExecutor(request *VerifierRequest) return promise } +func (v *LegacyExecutorVerifier) HasPendingVerifications() bool { + v.mtxPromises.Lock() + 
defer v.mtxPromises.Unlock() + + return len(v.promises) > 0 +} + func (v *LegacyExecutorVerifier) ProcessResultsSequentially(logPrefix string) ([]*VerifierBundle, error) { v.mtxPromises.Lock() defer v.mtxPromises.Unlock() @@ -444,7 +451,7 @@ func (v *LegacyExecutorVerifier) GetWholeBatchStreamBytes( txsPerBlock[blockNumber] = filteredTransactions } - entries, err := server.BuildWholeBatchStreamEntriesProto(tx, hermezDb, v.streamServer.GetChainId(), batchNumber, previousBatch, blocks, txsPerBlock, l1InfoTreeMinTimestamps) + entries, err := server.BuildWholeBatchStreamEntriesProto(tx, hermezDb, v.streamServer.GetChainId(), previousBatch, batchNumber, blocks, txsPerBlock, l1InfoTreeMinTimestamps) if err != nil { return nil, err } diff --git a/zk/stages/stage_batches.go b/zk/stages/stage_batches.go index ce19ed52cfd..95647338d7e 100644 --- a/zk/stages/stage_batches.go +++ b/zk/stages/stage_batches.go @@ -25,7 +25,9 @@ import ( txtype "github.com/ledgerwatch/erigon/zk/tx" "github.com/ledgerwatch/erigon/core/rawdb" + "github.com/ledgerwatch/erigon/core/state" "github.com/ledgerwatch/erigon/eth/ethconfig" + "github.com/ledgerwatch/erigon/zk/datastream/client" "github.com/ledgerwatch/erigon/zk/utils" "github.com/ledgerwatch/log/v3" ) @@ -35,6 +37,11 @@ const ( STAGE_PROGRESS_SAVE = 3000000 ) +var ( + // ErrFailedToFindCommonAncestor is returned when the common ancestor block cannot be found in the database + ErrFailedToFindCommonAncestor = errors.New("failed to find common ancestor block in the db") +) + type ErigonDb interface { WriteHeader(batchNo *big.Int, blockHash common.Hash, stateRoot, txHash, parentHash common.Hash, coinbase common.Address, ts, gasLimit uint64, chainConfig *chain.Config) (*ethTypes.Header, error) WriteBody(batchNo *big.Int, headerHash common.Hash, txs []ethTypes.Transaction) error @@ -72,23 +79,30 @@ type HermezDb interface { type DatastreamClient interface { ReadAllEntriesToChannel() error GetEntryChan() chan interface{} + GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block, int, error) + GetLatestL2Block() (*types.FullL2Block, error) GetLastWrittenTimeAtomic() *atomic.Int64 GetStreamingAtomic() *atomic.Bool GetProgressAtomic() *atomic.Uint64 EnsureConnected() (bool, error) + Start() error + Stop() } +type dsClientCreatorHandler func(context.Context, *ethconfig.Zk, uint64) (DatastreamClient, error) + type BatchesCfg struct { - db kv.RwDB - blockRoutineStarted bool - dsClient DatastreamClient - zkCfg *ethconfig.Zk - chainConfig *chain.Config - miningConfig *params.MiningConfig + db kv.RwDB + blockRoutineStarted bool + dsClient DatastreamClient + dsQueryClientCreator dsClientCreatorHandler + zkCfg *ethconfig.Zk + chainConfig *chain.Config + miningConfig *params.MiningConfig } -func StageBatchesCfg(db kv.RwDB, dsClient DatastreamClient, zkCfg *ethconfig.Zk, chainConfig *chain.Config, miningConfig *params.MiningConfig) BatchesCfg { - return BatchesCfg{ +func StageBatchesCfg(db kv.RwDB, dsClient DatastreamClient, zkCfg *ethconfig.Zk, chainConfig *chain.Config, miningConfig *params.MiningConfig, options ...Option) BatchesCfg { + cfg := BatchesCfg{ db: db, blockRoutineStarted: false, dsClient: dsClient, @@ -96,6 +110,21 @@ func StageBatchesCfg(db kv.RwDB, dsClient DatastreamClient, zkCfg *ethconfig.Zk, chainConfig: chainConfig, miningConfig: miningConfig, } + + for _, opt := range options { + opt(&cfg) + } + + return cfg +} + +type Option func(*BatchesCfg) + +// WithDSClientCreator is a functional option to set the datastream client creator callback. 
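+// It is mainly useful in tests, where a mock datastream client can be injected in place of a real connection.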
+func WithDSClientCreator(handler dsClientCreatorHandler) Option { + return func(c *BatchesCfg) { + c.dsQueryClientCreator = handler + } } var emptyHash = common.Hash{0} @@ -136,23 +165,48 @@ func SpawnStageBatches( return fmt.Errorf("save stage progress error: %v", err) } + //// BISECT //// + if cfg.zkCfg.DebugLimit > 0 && stageProgressBlockNo > cfg.zkCfg.DebugLimit { + return nil + } + // get batch for batches progress stageProgressBatchNo, err := hermezDb.GetBatchNoByL2Block(stageProgressBlockNo) if err != nil && !errors.Is(err, hermez_db.ErrorNotStored) { return fmt.Errorf("get batch no by l2 block error: %v", err) } - //// BISECT //// - if cfg.zkCfg.DebugLimit > 0 && stageProgressBlockNo > cfg.zkCfg.DebugLimit { - return nil - } - highestVerifiedBatch, err := stages.GetStageProgress(tx, stages.L1VerificationsBatchNo) if err != nil { return errors.New("could not retrieve l1 verifications batch no progress") } startSyncTime := time.Now() + + latestForkId, err := stages.GetStageProgress(tx, stages.ForkId) + if err != nil { + return err + } + + dsQueryClient, err := newStreamClient(ctx, cfg, latestForkId) + if err != nil { + log.Warn(fmt.Sprintf("[%s] %s", logPrefix, err)) + return err + } + defer dsQueryClient.Stop() + + highestDSL2Block, err := dsQueryClient.GetLatestL2Block() + if err != nil { + return fmt.Errorf("failed to retrieve the latest datastream l2 block: %w", err) + } + + if highestDSL2Block.L2BlockNumber < stageProgressBlockNo { + stageProgressBlockNo = highestDSL2Block.L2BlockNumber + } + + log.Debug(fmt.Sprintf("[%s] Highest block in datastream", logPrefix), "block", highestDSL2Block.L2BlockNumber) + log.Debug(fmt.Sprintf("[%s] Highest block in db", logPrefix), "block", stageProgressBlockNo) + dsClientProgress := cfg.dsClient.GetProgressAtomic() dsClientProgress.Store(stageProgressBlockNo) // start routine to download blocks and push them in a channel @@ -166,7 +220,7 @@ func SpawnStageBatches( for i := 0; i < 5; i++ { connected, err = cfg.dsClient.EnsureConnected() if err != nil { - log.Error("[datastream_client] Error connecting to datastream", "error", err) + log.Error(fmt.Sprintf("[%s] Error connecting to datastream", logPrefix), "error", err) continue } if connected { @@ -180,7 +234,7 @@ func SpawnStageBatches( if connected { if err := cfg.dsClient.ReadAllEntriesToChannel(); err != nil { - log.Error("[datastream_client] Error downloading blocks from datastream", "error", err) + log.Error(fmt.Sprintf("[%s] Error downloading blocks from datastream", logPrefix), "error", err) } } }() @@ -261,6 +315,7 @@ LOOP: return err } case *types.FullL2Block: + log.Debug(fmt.Sprintf("[%s] Retrieved block %d (%s) from the stream", logPrefix, entry.L2BlockNumber, entry.L2Blockhash.String())) if cfg.zkCfg.SyncLimit > 0 && entry.L2BlockNumber >= cfg.zkCfg.SyncLimit { // stop the node going into a crazy loop time.Sleep(2 * time.Second) @@ -298,9 +353,57 @@ LOOP: // when the stage is fired up for the first time log.Warn(fmt.Sprintf("[%s] Skipping block %d, already processed", logPrefix, entry.L2BlockNumber)) } + + dbBatchNum, err := hermezDb.GetBatchNoByL2Block(entry.L2BlockNumber) + if err != nil { + return err + } + + if entry.BatchNumber != dbBatchNum { + // if the batch number mismatches, it means that we need to trigger an unwinding of blocks + log.Warn(fmt.Sprintf("[%s] Batch number mismatch detected. 
Triggering unwind...", logPrefix), + "block", entry.L2BlockNumber, "ds batch", entry.BatchNumber, "db batch", dbBatchNum) + if err := rollback(logPrefix, eriDb, hermezDb, dsQueryClient, entry.L2BlockNumber, tx, u); err != nil { + return err + } + cfg.dsClient.Stop() + return nil + } continue } + var dbParentBlockHash common.Hash + if entry.L2BlockNumber > 0 { + dbParentBlockHash, err = eriDb.ReadCanonicalHash(entry.L2BlockNumber - 1) + if err != nil { + return fmt.Errorf("failed to retrieve parent block hash for datastream block %d: %w", + entry.L2BlockNumber, err) + } + } + + dsParentBlockHash := lastHash + if dsParentBlockHash == emptyHash { + parentBlockDS, _, err := dsQueryClient.GetL2BlockByNumber(entry.L2BlockNumber - 1) + if err != nil { + return err + } + + if parentBlockDS != nil { + dsParentBlockHash = parentBlockDS.L2Blockhash + } + } + + if dbParentBlockHash != dsParentBlockHash { + // unwind/rollback blocks until the latest common ancestor block + log.Warn(fmt.Sprintf("[%s] Parent block hashes mismatch on block %d. Triggering unwind...", logPrefix, entry.L2BlockNumber), + "db parent block hash", dbParentBlockHash, "ds parent block hash", dsParentBlockHash) + if err := rollback(logPrefix, eriDb, hermezDb, dsQueryClient, entry.L2BlockNumber, tx, u); err != nil { + return err + } + cfg.dsClient.Stop() + return nil + } + // skip if we already have this block if entry.L2BlockNumber < lastBlockHeight+1 { log.Warn(fmt.Sprintf("[%s] Unwinding to block %d", logPrefix, entry.L2BlockNumber)) @@ -309,6 +412,7 @@ LOOP: return fmt.Errorf("failed to get bad block: %v", err) } u.UnwindTo(entry.L2BlockNumber, stagedsync.BadBlock(badBlock, fmt.Errorf("received block %d again", entry.L2BlockNumber))) + return nil } // check for sequential block numbers @@ -432,16 +536,18 @@ LOOP: return err } - if err := tx.Commit(); err != nil { - return fmt.Errorf("failed to commit tx, %w", err) - } + if freshTx { + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit tx, %w", err) + } - tx, err = cfg.db.BeginRw(ctx) - if err != nil { - return fmt.Errorf("failed to open tx, %w", err) + tx, err = cfg.db.BeginRw(ctx) + if err != nil { + return fmt.Errorf("failed to open tx, %w", err) + } + hermezDb = hermez_db.NewHermezDb(tx) + eriDb = erigon_db.NewErigonDb(tx) } - hermezDb = hermez_db.NewHermezDb(tx) - eriDb = erigon_db.NewErigonDb(tx) prevAmountBlocksWritten = blocksWritten } @@ -712,7 +818,7 @@ func PruneBatchesStage(s *stagedsync.PruneState, tx kv.RwTx, cfg BatchesCfg, ctx defer tx.Rollback() } - log.Info(fmt.Sprintf("[%s] Pruning barches...", logPrefix)) + log.Info(fmt.Sprintf("[%s] Pruning batches...", logPrefix)) defer log.Info(fmt.Sprintf("[%s] Unwinding batches complete", logPrefix)) hermezDb := hermez_db.NewHermezDb(tx) @@ -862,3 +968,137 @@ func writeL2Block(eriDb ErigonDb, hermezDb HermezDb, l2Block *types.FullL2Block, return nil } + +// rollback performs the unwinding of blocks: +// 1. queries the latest common ancestor for datastream and db, +// 2. resolves the unwind block (as the latest block in the previous batch, comparing to the found ancestor block) +// 3. 
+func rollback(logPrefix string, eriDb *erigon_db.ErigonDb, hermezDb *hermez_db.HermezDb, + dsQueryClient DatastreamClient, latestDSBlockNum uint64, tx kv.RwTx, u stagedsync.Unwinder) error { + ancestorBlockNum, ancestorBlockHash, err := findCommonAncestor(eriDb, hermezDb, dsQueryClient, latestDSBlockNum) + if err != nil { + return err + } + log.Debug(fmt.Sprintf("[%s] The common ancestor for datastream and db is block %d (%s)", logPrefix, ancestorBlockNum, ancestorBlockHash)) + + unwindBlockNum, unwindBlockHash, batchNum, err := getUnwindPoint(eriDb, hermezDb, ancestorBlockNum, ancestorBlockHash) + if err != nil { + return err + } + + if err = stages.SaveStageProgress(tx, stages.HighestSeenBatchNumber, batchNum-1); err != nil { + return err + } + log.Warn(fmt.Sprintf("[%s] Unwinding to block %d (%s)", logPrefix, unwindBlockNum, unwindBlockHash)) + u.UnwindTo(unwindBlockNum, stagedsync.BadBlock(unwindBlockHash, fmt.Errorf("unwind to block %d", unwindBlockNum))) + return nil +} + +// findCommonAncestor searches for the latest common ancestor block number and hash between the data stream and the local db. +// The common ancestor block is the one that matches both l2 block hash and batch number. +func findCommonAncestor( + db erigon_db.ReadOnlyErigonDb, + hermezDb state.ReadOnlyHermezDb, + dsClient DatastreamClient, + latestBlockNum uint64) (uint64, common.Hash, error) { + var ( + startBlockNum = uint64(0) + endBlockNum = latestBlockNum + blockNumber *uint64 + blockHash common.Hash + ) + + if latestBlockNum == 0 { + return 0, emptyHash, ErrFailedToFindCommonAncestor + } + + // binary search for the highest block that matches in both the data stream and the db + for startBlockNum <= endBlockNum { + if endBlockNum == 0 { + return 0, emptyHash, ErrFailedToFindCommonAncestor + } + + midBlockNum := (startBlockNum + endBlockNum) / 2 + midBlockDataStream, errCode, err := dsClient.GetL2BlockByNumber(midBlockNum) + if err != nil && + // the required block might not be in the data stream, so ignore that error + errCode != types.CmdErrBadFromBookmark { + return 0, emptyHash, err + } + + midBlockDbHash, err := db.ReadCanonicalHash(midBlockNum) + if err != nil { + return 0, emptyHash, err + } + + dbBatchNum, err := hermezDb.GetBatchNoByL2Block(midBlockNum) + if err != nil { + return 0, emptyHash, err + } + + if midBlockDataStream != nil && + midBlockDataStream.L2Blockhash == midBlockDbHash && + midBlockDataStream.BatchNumber == dbBatchNum { + startBlockNum = midBlockNum + 1 + + blockNumber = &midBlockNum + blockHash = midBlockDbHash + } else { + endBlockNum = midBlockNum - 1 + } + } + + if blockNumber == nil { + return 0, emptyHash, ErrFailedToFindCommonAncestor + } + + return *blockNumber, blockHash, nil +} + +// getUnwindPoint resolves the unwind block as the latest block in the previous batch, relative to the provided block. 
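+// It returns the unwind block number and hash, along with the batch number of the provided block.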
+func getUnwindPoint(eriDb erigon_db.ReadOnlyErigonDb, hermezDb state.ReadOnlyHermezDb, blockNum uint64, blockHash common.Hash) (uint64, common.Hash, uint64, error) { + batchNum, err := hermezDb.GetBatchNoByL2Block(blockNum) + if err != nil { + return 0, emptyHash, 0, err + } + + if batchNum == 0 { + return 0, emptyHash, 0, + fmt.Errorf("failed to find batch number for block %d (%s)", blockNum, blockHash) + } + + unwindBlockNum, _, err := hermezDb.GetHighestBlockInBatch(batchNum - 1) + if err != nil { + return 0, emptyHash, 0, err + } + + unwindBlockHash, err := eriDb.ReadCanonicalHash(unwindBlockNum) + if err != nil { + return 0, emptyHash, 0, err + } + + return unwindBlockNum, unwindBlockHash, batchNum, nil +} + +// newStreamClient instantiates a new datastream client and starts it. +func newStreamClient(ctx context.Context, cfg BatchesCfg, latestForkId uint64) (DatastreamClient, error) { + var ( + dsClient DatastreamClient + err error + ) + + if cfg.dsQueryClientCreator != nil { + dsClient, err = cfg.dsQueryClientCreator(ctx, cfg.zkCfg, latestForkId) + if err != nil { + return nil, fmt.Errorf("failed to create a datastream client. Reason: %w", err) + } + } else { + zkCfg := cfg.zkCfg + dsClient = client.NewClient(ctx, zkCfg.L2DataStreamerUrl, zkCfg.DatastreamVersion, zkCfg.L2DataStreamerTimeout, uint16(latestForkId)) + } + + if err := dsClient.Start(); err != nil { + return nil, fmt.Errorf("failed to start a datastream client. Reason: %w", err) + } + + return dsClient, nil +} diff --git a/zk/stages/stage_batches_test.go b/zk/stages/stage_batches_test.go index ffbc9b0e686..14d32f09ce5 100644 --- a/zk/stages/stage_batches_test.go +++ b/zk/stages/stage_batches_test.go @@ -3,16 +3,19 @@ package stages import ( "context" "encoding/hex" + "strings" "testing" "github.com/ledgerwatch/erigon-lib/chain" "github.com/ledgerwatch/erigon-lib/common" "github.com/ledgerwatch/erigon-lib/kv" "github.com/ledgerwatch/erigon-lib/kv/memdb" + "github.com/ledgerwatch/erigon/core/rawdb" "github.com/ledgerwatch/erigon/eth/stagedsync" "github.com/ledgerwatch/erigon/eth/stagedsync/stages" "github.com/ledgerwatch/erigon/smt/pkg/db" "github.com/ledgerwatch/erigon/zk/datastream/types" + "github.com/ledgerwatch/erigon/zk/erigon_db" "github.com/ledgerwatch/erigon/zk/hermez_db" "github.com/ledgerwatch/erigon/eth/ethconfig" @@ -20,36 +23,9 @@ import ( ) func TestUnwindBatches(t *testing.T) { - fullL2Blocks := []types.FullL2Block{} - post155 := "0xf86780843b9aca00826163941275fbb540c8efc58b812ba83b0d0b8b9917ae98808464fbb77c1ba0b7d2a666860f3c6b8f5ef96f86c7ec5562e97fd04c2e10f3755ff3a0456f9feba0246df95217bf9082f84f9e40adb0049c6664a5bb4c9cbe34ab1a73e77bab26ed" - post155Bytes, err := hex.DecodeString(post155[2:]) currentBlockNumber := 10 + fullL2Blocks := createTestL2Blocks(t, currentBlockNumber) - require.NoError(t, err) - for i := 1; i <= currentBlockNumber; i++ { - fullL2Blocks = append(fullL2Blocks, types.FullL2Block{ - BatchNumber: 1 + uint64(i/2), - L2BlockNumber: uint64(i), - Timestamp: int64(i) * 10000, - DeltaTimestamp: uint32(i) * 10, - L1InfoTreeIndex: uint32(i) + 20, - GlobalExitRoot: common.Hash{byte(i)}, - Coinbase: common.Address{byte(i)}, - ForkId: 1 + uint64(i)/3, - L1BlockHash: common.Hash{byte(i)}, - L2Blockhash: common.Hash{byte(i)}, - StateRoot: common.Hash{byte(i)}, - L2Txs: []types.L2TransactionProto{ - { - EffectiveGasPricePercentage: 255, - IsValid: true, - IntermediateStateRoot: common.Hash{byte(i + 1)}, - Encoded: post155Bytes, - }, - }, - ParentHash: common.Hash{byte(i - 1)}, - }) - } gerUpdates 
:= []types.GerUpdate{} for i := currentBlockNumber + 1; i <= currentBlockNumber+5; i++ { gerUpdates = append(gerUpdates, types.GerUpdate{ @@ -65,7 +41,7 @@ func TestUnwindBatches(t *testing.T) { ctx, db1 := context.Background(), memdb.NewTestDB(t) tx := memdb.BeginRw(t, db1) - err = hermez_db.CreateHermezBuckets(tx) + err := hermez_db.CreateHermezBuckets(tx) require.NoError(t, err) err = db.CreateEriDbBuckets(tx) @@ -73,7 +49,10 @@ func TestUnwindBatches(t *testing.T) { dsClient := NewTestDatastreamClient(fullL2Blocks, gerUpdates) - cfg := StageBatchesCfg(db1, dsClient, ðconfig.Zk{}, &chain.Config{}, nil) + tmpDSClientCreator := func(_ context.Context, _ *ethconfig.Zk, _ uint64) (DatastreamClient, error) { + return NewTestDatastreamClient(fullL2Blocks, gerUpdates), nil + } + cfg := StageBatchesCfg(db1, dsClient, ðconfig.Zk{}, &chain.Config{}, nil, WithDSClientCreator(tmpDSClientCreator)) s := &stagedsync.StageState{ID: stages.Batches, BlockNumber: 0} u := &stagedsync.Sync{} @@ -132,6 +111,134 @@ func TestUnwindBatches(t *testing.T) { } size, err := tx3.BucketSize(bucket) require.NoError(t, err) - require.Equal(t, bucketSized[bucket], size, "butcket %s is not empty", bucket) + require.Equal(t, bucketSized[bucket], size, "bucket %s is not empty", bucket) + } +} + +func TestFindCommonAncestor(t *testing.T) { + blocksCount := 40 + l2Blocks := createTestL2Blocks(t, blocksCount) + + testCases := []struct { + name string + dbBlocksCount int + dsBlocksCount int + latestBlockNum uint64 + divergentBlockHistory bool + expectedBlockNum uint64 + expectedHash common.Hash + expectedError error + }{ + { + name: "Successful search (db lagging behind the data stream)", + dbBlocksCount: 5, + dsBlocksCount: 10, + latestBlockNum: 5, + expectedBlockNum: 5, + expectedHash: common.Hash{byte(5)}, + expectedError: nil, + }, + { + name: "Successful search (db leading the data stream)", + dbBlocksCount: 20, + dsBlocksCount: 10, + latestBlockNum: 10, + expectedBlockNum: 10, + expectedHash: common.Hash{byte(10)}, + expectedError: nil, + }, + { + name: "Failed to find common ancestor block (latest block number is 0)", + dbBlocksCount: 10, + dsBlocksCount: 10, + latestBlockNum: 0, + expectedError: ErrFailedToFindCommonAncestor, + }, + { + name: "Failed to find common ancestor block (different blocks in the data stream and db)", + dbBlocksCount: 10, + dsBlocksCount: 10, + divergentBlockHistory: true, + latestBlockNum: 20, + expectedError: ErrFailedToFindCommonAncestor, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // ARRANGE + testDb, tx := memdb.NewTestTx(t) + defer testDb.Close() + defer tx.Rollback() + err := hermez_db.CreateHermezBuckets(tx) + require.NoError(t, err) + + err = db.CreateEriDbBuckets(tx) + require.NoError(t, err) + + hermezDb := hermez_db.NewHermezDb(tx) + erigonDb := erigon_db.NewErigonDb(tx) + + dsBlocks := l2Blocks[:tc.dsBlocksCount] + dbBlocks := l2Blocks[:tc.dbBlocksCount] + if tc.divergentBlockHistory { + dbBlocks = l2Blocks[tc.dsBlocksCount : tc.dbBlocksCount+tc.dsBlocksCount] + } + + dsClient := NewTestDatastreamClient(dsBlocks, nil) + for _, l2Block := range dbBlocks { + require.NoError(t, hermezDb.WriteBlockBatch(l2Block.L2BlockNumber, l2Block.BatchNumber)) + require.NoError(t, rawdb.WriteCanonicalHash(tx, l2Block.L2Blockhash, l2Block.L2BlockNumber)) + } + + // ACT + ancestorNum, ancestorHash, err := findCommonAncestor(erigonDb, hermezDb, dsClient, tc.latestBlockNum) + + // ASSERT + if tc.expectedError != nil { + require.Error(t, err) + require.Equal(t, 
tc.expectedError.Error(), err.Error()) + require.Equal(t, uint64(0), ancestorNum) + require.Equal(t, emptyHash, ancestorHash) + } else { + require.NoError(t, err) + require.Equal(t, tc.expectedBlockNum, ancestorNum) + require.Equal(t, tc.expectedHash, ancestorHash) + } + }) } } + +func createTestL2Blocks(t *testing.T, blocksCount int) []types.FullL2Block { + post155 := "0xf86780843b9aca00826163941275fbb540c8efc58b812ba83b0d0b8b9917ae98808464fbb77c1ba0b7d2a666860f3c6b8f5ef96f86c7ec5562e97fd04c2e10f3755ff3a0456f9feba0246df95217bf9082f84f9e40adb0049c6664a5bb4c9cbe34ab1a73e77bab26ed" + post155Bytes, err := hex.DecodeString(strings.TrimPrefix(post155, "0x")) + require.NoError(t, err) + + l2Blocks := make([]types.FullL2Block, 0, blocksCount) + for i := 1; i <= blocksCount; i++ { + l2Blocks = append(l2Blocks, types.FullL2Block{ + BatchNumber: 1 + uint64(i/2), + L2BlockNumber: uint64(i), + Timestamp: int64(i) * 10000, + DeltaTimestamp: uint32(i) * 10, + L1InfoTreeIndex: uint32(i) + 20, + GlobalExitRoot: common.Hash{byte(i)}, + Coinbase: common.Address{byte(i)}, + ForkId: 1 + uint64(i)/3, + L1BlockHash: common.Hash{byte(i)}, + L2Blockhash: common.Hash{byte(i)}, + StateRoot: common.Hash{byte(i)}, + L2Txs: []types.L2TransactionProto{ + { + EffectiveGasPricePercentage: 255, + IsValid: true, + IntermediateStateRoot: common.Hash{byte(i + 1)}, + Encoded: post155Bytes, + }, + }, + ParentHash: common.Hash{byte(i - 1)}, + }) + } + + return l2Blocks +} diff --git a/zk/stages/stage_interhashes.go b/zk/stages/stage_interhashes.go index 1fa410bbf60..b55e86dd1b6 100644 --- a/zk/stages/stage_interhashes.go +++ b/zk/stages/stage_interhashes.go @@ -406,9 +406,6 @@ func zkIncrementIntermediateHashes(ctx context.Context, logPrefix string, s *sta if len(ach) > 0 { hexcc := "0x" + ach codeChanges[addr] = hexcc - if err != nil { - return trie.EmptyRoot, err - } } } @@ -685,20 +682,9 @@ func processAccount(db smt.DB, a *accounts.Account, as map[string]string, inc ui } func insertContractBytecodeToKV(db smt.DB, keys []utils.NodeKey, ethAddr string, bytecode string) ([]utils.NodeKey, error) { - keyContractCode, err := utils.KeyContractCode(ethAddr) - if err != nil { - return []utils.NodeKey{}, err - } - - keyContractLength, err := utils.KeyContractLength(ethAddr) - if err != nil { - return []utils.NodeKey{}, err - } - - hashedBytecode, err := utils.HashContractBytecode(bytecode) - if err != nil { - return []utils.NodeKey{}, err - } + keyContractCode := utils.KeyContractCode(ethAddr) + keyContractLength := utils.KeyContractLength(ethAddr) + hashedBytecode := utils.HashContractBytecode(bytecode) parsedBytecode := strings.TrimPrefix(bytecode, "0x") if len(parsedBytecode)%2 != 0 { @@ -747,10 +733,7 @@ func insertContractStorageToKV(db smt.DB, keys []utils.NodeKey, ethAddr string, continue } - keyStoragePosition, err := utils.KeyContractStorage(add, k) - if err != nil { - return []utils.NodeKey{}, err - } + keyStoragePosition := utils.KeyContractStorage(add, k) base := 10 if strings.HasPrefix(v, "0x") { @@ -780,14 +763,8 @@ func insertContractStorageToKV(db smt.DB, keys []utils.NodeKey, ethAddr string, } func insertAccountStateToKV(db smt.DB, keys []utils.NodeKey, ethAddr string, balance, nonce *big.Int) ([]utils.NodeKey, error) { - keyBalance, err := utils.KeyEthAddrBalance(ethAddr) - if err != nil { - return []utils.NodeKey{}, err - } - keyNonce, err := utils.KeyEthAddrNonce(ethAddr) - if err != nil { - return []utils.NodeKey{}, err - } + keyBalance := utils.KeyEthAddrBalance(ethAddr) + keyNonce := 
utils.KeyEthAddrNonce(ethAddr) x := utils.ScalarToArrayBig(balance) valueBalance, err := utils.NodeValue8FromBigIntArray(x) diff --git a/zk/stages/stage_l1_sequencer_sync.go b/zk/stages/stage_l1_sequencer_sync.go index fd81e80e232..f96825209ad 100644 --- a/zk/stages/stage_l1_sequencer_sync.go +++ b/zk/stages/stage_l1_sequencer_sync.go @@ -66,6 +66,24 @@ func SpawnL1SequencerSyncStage( } if progress == 0 { progress = cfg.zkCfg.L1FirstBlock - 1 + + } + + // if the flag is set, wait for that block to be finalized on L1 before continuing + if progress <= cfg.zkCfg.L1FinalizedBlockRequirement && cfg.zkCfg.L1FinalizedBlockRequirement > 0 { + for { + finalized, finalizedBn, err := cfg.syncer.CheckL1BlockFinalized(cfg.zkCfg.L1FinalizedBlockRequirement) + if err != nil { + // don't return the error here, because it could be a timeout or a "too many requests" error, and we can just retry + log.Error(fmt.Sprintf("[%s] Error checking if L1 block %v is finalized: %v", logPrefix, cfg.zkCfg.L1FinalizedBlockRequirement, err)) + } + + if finalized { + break + } + log.Info(fmt.Sprintf("[%s] Waiting for L1 block %v to be finalized before continuing. Current finalized block is %d", logPrefix, cfg.zkCfg.L1FinalizedBlockRequirement, finalizedBn)) + time.Sleep(1 * time.Minute) // the sleep could be even longer, since finalization takes more than 10 minutes + } } hermezDb := hermez_db.NewHermezDb(tx) diff --git a/zk/stages/stage_l1syncer.go b/zk/stages/stage_l1syncer.go index 5fdac96b6dc..ccf4b856b18 100644 --- a/zk/stages/stage_l1syncer.go +++ b/zk/stages/stage_l1syncer.go @@ -42,6 +42,7 @@ type IL1Syncer interface { StopQueryBlocks() ConsumeQueryBlocks() WaitQueryBlocksToFinish() + CheckL1BlockFinalized(blockNo uint64) (bool, uint64, error) } var ( @@ -352,7 +353,7 @@ func verifyAgainstLocalBlocks(tx kv.RwTx, hermezDb *hermez_db.HermezDb, logPrefi // in this case we need to find the blocknumber that is highest for the last batch // get the batch of the last hashed block hashedBatch, err := hermezDb.GetBatchNoByL2Block(hashedBlockNo) - if err != nil && !errors.Is(err, hermez_db.ErrorNotStored){ + if err != nil && !errors.Is(err, hermez_db.ErrorNotStored) { return err } diff --git a/zk/stages/stage_sequence_execute.go b/zk/stages/stage_sequence_execute.go index a318ce1cbf6..903051f9fc7 100644 --- a/zk/stages/stage_sequence_execute.go +++ b/zk/stages/stage_sequence_execute.go @@ -15,6 +15,7 @@ import ( "github.com/ledgerwatch/erigon/eth/stagedsync" "github.com/ledgerwatch/erigon/eth/stagedsync/stages" "github.com/ledgerwatch/erigon/zk" + zktx "github.com/ledgerwatch/erigon/zk/tx" "github.com/ledgerwatch/erigon/zk/utils" ) @@ -28,6 +29,78 @@ func SpawnSequencingStage( s *stagedsync.StageState, u stagedsync.Unwinder, ctx context.Context, cfg SequenceBlockCfg, historyCfg stagedsync.HistoryCfg, quiet bool, +) (err error) { + roTx, err := cfg.db.BeginRo(ctx) + if err != nil { + return err + } + defer roTx.Rollback() + + lastBatch, err := stages.GetStageProgress(roTx, stages.HighestSeenBatchNumber) + if err != nil { + return err + } + + highestBatchInDS, err := cfg.datastreamServer.GetHighestBatchNumber() + if err != nil { + return err + } + + if !cfg.zk.SequencerResequence || lastBatch >= highestBatchInDS { + if cfg.zk.SequencerResequence { + log.Info(fmt.Sprintf("[%s] Resequencing completed. 
Please restart sequencer without resequence flag.", s.LogPrefix())) + time.Sleep(10 * time.Second) + return nil + } + + err = sequencingStageStep(s, u, ctx, cfg, historyCfg, quiet, nil) + if err != nil { + return err + } + } else { + log.Info(fmt.Sprintf("[%s] Last batch %d is lower than highest batch in datastream %d, resequencing...", s.LogPrefix(), lastBatch, highestBatchInDS)) + + batches, err := cfg.datastreamServer.ReadBatches(lastBatch+1, highestBatchInDS) + if err != nil { + return err + } + + err = cfg.datastreamServer.UnwindToBatchStart(lastBatch + 1) + if err != nil { + return err + } + + log.Info(fmt.Sprintf("[%s] Resequence from batch %d to %d in data stream", s.LogPrefix(), lastBatch+1, highestBatchInDS)) + + for _, batch := range batches { + batchJob := NewResequenceBatchJob(batch) + subBatchCount := 0 + for batchJob.HasMoreBlockToProcess() { + if err = sequencingStageStep(s, u, ctx, cfg, historyCfg, quiet, batchJob); err != nil { + return err + } + + subBatchCount += 1 + } + + log.Info(fmt.Sprintf("[%s] Resequenced original batch %d with %d batches", s.LogPrefix(), batchJob.batchToProcess[0].BatchNumber, subBatchCount)) + if cfg.zk.SequencerResequenceStrict && subBatchCount != 1 { + return fmt.Errorf("strict mode enabled, but resequenced batch %d has %d sub-batches", batchJob.batchToProcess[0].BatchNumber, subBatchCount) + } + } + } + + return nil +} + +func sequencingStageStep( + s *stagedsync.StageState, + u stagedsync.Unwinder, + ctx context.Context, + cfg SequenceBlockCfg, + historyCfg stagedsync.HistoryCfg, + quiet bool, + resequenceBatchJob *ResequenceBatchJob, ) (err error) { logPrefix := s.LogPrefix() log.Info(fmt.Sprintf("[%s] Starting sequencing stage", logPrefix)) @@ -69,7 +142,7 @@ func SpawnSequencingStage( var block *types.Block runLoopBlocks := true batchContext := newBatchContext(ctx, &cfg, &historyCfg, s, sdb) - batchState := newBatchState(forkId, batchNumberForStateInitialization, executionAt+1, cfg.zk.HasExecutors(), cfg.zk.L1SyncStartBlock > 0, cfg.txPool) + batchState := newBatchState(forkId, batchNumberForStateInitialization, executionAt+1, cfg.zk.HasExecutors(), cfg.zk.L1SyncStartBlock > 0, cfg.txPool, resequenceBatchJob) blockDataSizeChecker := NewBlockDataChecker(cfg.zk.ShouldCountersBeUnlimited(batchState.isL1Recovery())) streamWriter := newSequencerBatchStreamWriter(batchContext, batchState) @@ -176,6 +249,18 @@ func SpawnSequencingStage( } } + if batchState.isResequence() { + if !batchState.resequenceBatchJob.HasMoreBlockToProcess() { + for streamWriter.legacyVerifier.HasPendingVerifications() { + streamWriter.CommitNewUpdates() + time.Sleep(1 * time.Second) + } + + runLoopBlocks = false + break + } + } + header, parentBlock, err := prepareHeader(sdb.tx, blockNumber-1, batchState.blockState.getDeltaTimestamp(), batchState.getBlockHeaderForcedTimestamp(), batchState.forkId, batchState.getCoinbase(&cfg), cfg.chainConfig, cfg.miningConfig) if err != nil { return err @@ -189,7 +274,7 @@ func SpawnSequencingStage( // timer: evm + smt t := utils.StartTimer("stage_sequence_execute", "evm", "smt") - infoTreeIndexProgress, l1TreeUpdate, l1TreeUpdateIndex, l1BlockHash, ger, shouldWriteGerToContract, err := prepareL1AndInfoTreeRelatedStuff(sdb, batchState, header.Time) + infoTreeIndexProgress, l1TreeUpdate, l1TreeUpdateIndex, l1BlockHash, ger, shouldWriteGerToContract, err := prepareL1AndInfoTreeRelatedStuff(sdb, batchState, header.Time, cfg.zk.SequencerResequenceReuseL1InfoIndex) if err != nil { return err } @@ -198,7 +283,7 @@ func SpawnSequencingStage( 
if err != nil { return err } - if !batchState.isAnyRecovery() && overflowOnNewBlock { + if (!batchState.isAnyRecovery() || batchState.isResequence()) && overflowOnNewBlock { break } @@ -240,7 +325,7 @@ func SpawnSequencingStage( if err != nil { return err } - } else if !batchState.isL1Recovery() { + } else if !batchState.isL1Recovery() && !batchState.isResequence() { var allConditionsOK bool batchState.blockState.transactionsForInclusion, allConditionsOK, err = getNextPoolTransactions(ctx, cfg, executionAt, batchState.forkId, batchState.yieldedTransactions) if err != nil { @@ -256,6 +341,17 @@ func SpawnSequencingStage( } else { log.Trace(fmt.Sprintf("[%s] Yielded transactions from the pool", logPrefix), "txCount", len(batchState.blockState.transactionsForInclusion)) } + } else if batchState.isResequence() { + batchState.blockState.transactionsForInclusion, err = batchState.resequenceBatchJob.YieldNextBlockTransactions(zktx.DecodeTx) + if err != nil { + return err + } + } + + if len(batchState.blockState.transactionsForInclusion) == 0 { + time.Sleep(batchContext.cfg.zk.SequencerTimeoutOnEmptyTxPool) + } else { + log.Trace(fmt.Sprintf("[%s] Yielded transactions from the pool", logPrefix), "txCount", len(batchState.blockState.transactionsForInclusion)) } for i, transaction := range batchState.blockState.transactionsForInclusion { @@ -270,6 +366,18 @@ func SpawnSequencingStage( panic("limbo transaction has already been executed once so they must not fail while re-executing") } + if batchState.isResequence() { + if cfg.zk.SequencerResequenceStrict { + return fmt.Errorf("strict mode enabled, but resequenced batch %d failed to add transaction %s: %v", batchState.batchNumber, txHash, err) + } else { + log.Warn(fmt.Sprintf("[%s] error adding transaction to batch during resequence: %v", logPrefix, err), + "hash", txHash, + "to", transaction.GetTo(), + ) + continue + } + } + // if we are in recovery just log the error as a warning. If the data is on the L1 then we should consider it as confirmed. // The executor/prover would simply skip a TX with an invalid nonce for example so we don't need to worry about that here. 
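+ // the resequence branch above follows the same idea: log a warning and skip the transaction, unless strict mode aborts instead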
if batchState.isL1Recovery() { @@ -311,12 +419,39 @@ func SpawnSequencingStage( break LOOP_TRANSACTIONS } + if batchState.isResequence() && cfg.zk.SequencerResequenceStrict { + return fmt.Errorf("strict mode enabled, but resequenced batch %d overflowed counters on block %d", batchState.batchNumber, blockNumber) + } + + break LOOP_TRANSACTIONS } if err == nil { blockDataSizeChecker = &backupDataSizeChecker batchState.onAddedTransaction(transaction, receipt, execResult, effectiveGas) } + + // only update the processed index in the resequence job if there was no overflow + if batchState.isResequence() { + batchState.resequenceBatchJob.UpdateLastProcessedTx(txHash) + } + } + + if batchState.isResequence() { + if len(batchState.blockState.transactionsForInclusion) == 0 { + // jump to the next block here if there are no transactions in the current block + batchState.resequenceBatchJob.UpdateLastProcessedTx(batchState.resequenceBatchJob.CurrentBlock().L2Blockhash) + break LOOP_TRANSACTIONS + } + + if batchState.resequenceBatchJob.AtNewBlockBoundary() { + // jump to the next block here if we are at the end of the current block + break LOOP_TRANSACTIONS + } else { + if cfg.zk.SequencerResequenceStrict { + return fmt.Errorf("strict mode enabled, but resequenced batch %d has transactions that overflowed counters or failed transactions", batchState.batchNumber) + } + } } if batchState.isL1Recovery() { diff --git a/zk/stages/stage_sequence_execute_batch.go b/zk/stages/stage_sequence_execute_batch.go index abe9f2764e1..7c45791e90c 100644 --- a/zk/stages/stage_sequence_execute_batch.go +++ b/zk/stages/stage_sequence_execute_batch.go @@ -24,8 +24,16 @@ func prepareBatchNumber(sdb *stageDb, forkId, lastBatch uint64, isL1Recovery boo return 0, err } - if len(blockNumbersInBatchSoFar) < len(recoveredBatchData.DecodedData) { - return lastBatch, nil + if len(blockNumbersInBatchSoFar) < len(recoveredBatchData.DecodedData) { // check if there are more blocks to process + isLastBatchBad, err := sdb.hermezDb.GetInvalidBatch(lastBatch) + if err != nil { + return 0, err + } + + // if the last batch is not bad then continue building it, otherwise return lastBatch+1 (at the end of the function) + if !isLastBatchBad { + return lastBatch, nil + } } } diff --git a/zk/stages/stage_sequence_execute_state.go b/zk/stages/stage_sequence_execute_state.go index 24acf695ec5..53d7f5ed0dc 100644 --- a/zk/stages/stage_sequence_execute_state.go +++ b/zk/stages/stage_sequence_execute_state.go @@ -45,9 +45,10 @@ type BatchState struct { blockState *BlockState batchL1RecoveryData *BatchL1RecoveryData limboRecoveryData *LimboRecoveryData + resequenceBatchJob *ResequenceBatchJob } -func newBatchState(forkId, batchNumber, blockNumber uint64, hasExecutorForThisBatch, l1Recovery bool, txPool *txpool.TxPool) *BatchState { +func newBatchState(forkId, batchNumber, blockNumber uint64, hasExecutorForThisBatch, l1Recovery bool, txPool *txpool.TxPool, resequenceBatchJob *ResequenceBatchJob) *BatchState { batchState := &BatchState{ forkId: forkId, batchNumber: batchNumber, @@ -58,29 +59,32 @@ func newBatchState(forkId, batchNumber, blockNumber uint64, hasExecutorForThisBa blockState: newBlockState(), batchL1RecoveryData: nil, limboRecoveryData: nil, + resequenceBatchJob: resequenceBatchJob, } - if l1Recovery { - batchState.batchL1RecoveryData = newBatchL1RecoveryData(batchState) - } + if batchNumber != injectedBatchBatchNumber { // the injected batch is processed normally, regardless of any recovery mode + if l1Recovery { + 
batchState.batchL1RecoveryData = newBatchL1RecoveryData(batchState) + } - limboBlock, limboTxHash := txPool.GetLimboDetailsForRecovery(blockNumber) - if limboTxHash != nil { - // batchNumber == limboBlock.BatchNumber then we've unwound to the very beginning of the batch. 'limboBlock.BlockNumber' is the 1st block of 'batchNumber' batch. Everything is fine. + limboBlock, limboTxHash := txPool.GetLimboDetailsForRecovery(blockNumber) + if limboTxHash != nil { + // batchNumber == limboBlock.BatchNumber then we've unwound to the very beginning of the batch. 'limboBlock.BlockNumber' is the 1st block of 'batchNumber' batch. Everything is fine. - // batchNumber - 1 == limboBlock.BatchNumber then we've unwound to the middle of a batch. We must set in 'batchState' that we're going to resume a batch build rather than starting a new one. Everything is fine. - if batchNumber-1 == limboBlock.BatchNumber { - batchState.batchNumber = limboBlock.BatchNumber - } else if batchNumber != limboBlock.BatchNumber { - // in any other configuration rather than (batchNumber or batchNumber - 1) == limboBlock.BatchNumber we can only panic - panic(fmt.Errorf("requested batch %d while the network is already on %d", limboBlock.BatchNumber, batchNumber)) - } + // batchNumber - 1 == limboBlock.BatchNumber then we've unwound to the middle of a batch. We must set in 'batchState' that we're going to resume a batch build rather than starting a new one. Everything is fine. + if batchNumber-1 == limboBlock.BatchNumber { + batchState.batchNumber = limboBlock.BatchNumber + } else if batchNumber != limboBlock.BatchNumber { + // in any configuration other than (batchNumber or batchNumber - 1) == limboBlock.BatchNumber we can only panic + panic(fmt.Errorf("requested batch %d while the network is already on %d", limboBlock.BatchNumber, batchNumber)) + } - batchState.limboRecoveryData = newLimboRecoveryData(limboBlock.BlockTimestamp, limboTxHash) - } + batchState.limboRecoveryData = newLimboRecoveryData(limboBlock.BlockTimestamp, limboTxHash) + } - if batchState.isL1Recovery() && batchState.isLimboRecovery() { - panic("Both recoveries cannot be active simultaneously") + if batchState.isL1Recovery() && batchState.isLimboRecovery() { + panic("Both recoveries cannot be active simultaneously") + } } return batchState @@ -94,8 +98,12 @@ func (bs *BatchState) isLimboRecovery() bool { return bs.limboRecoveryData != nil } +func (bs *BatchState) isResequence() bool { + return bs.resequenceBatchJob != nil +} + func (bs *BatchState) isAnyRecovery() bool { - return bs.isL1Recovery() || bs.isLimboRecovery() + return bs.isL1Recovery() || bs.isLimboRecovery() || bs.isResequence() } func (bs *BatchState) isThereAnyTransactionsToRecover() bool { @@ -118,6 +126,10 @@ func (bs *BatchState) getBlockHeaderForcedTimestamp() uint64 { return bs.limboRecoveryData.limboHeaderTimestamp } + if bs.isResequence() { + return uint64(bs.resequenceBatchJob.CurrentBlock().Timestamp) + } + return math.MaxUint64 } diff --git a/zk/stages/stage_sequence_execute_utils.go b/zk/stages/stage_sequence_execute_utils.go index ea31be4b0cc..48b6d33be21 100644 --- a/zk/stages/stage_sequence_execute_utils.go +++ b/zk/stages/stage_sequence_execute_utils.go @@ -31,6 +31,7 @@ import ( "github.com/ledgerwatch/erigon/turbo/shards" "github.com/ledgerwatch/erigon/turbo/stages/headerdownload" "github.com/ledgerwatch/erigon/zk/datastream/server" + dsTypes "github.com/ledgerwatch/erigon/zk/datastream/types" "github.com/ledgerwatch/erigon/zk/hermez_db" verifier 
"github.com/ledgerwatch/erigon/zk/legacy_executor_verifier" "github.com/ledgerwatch/erigon/zk/tx" @@ -252,7 +253,12 @@ func prepareHeader(tx kv.RwTx, previousBlockNumber, deltaTimestamp, forcedTimest return header, parentBlock, nil } -func prepareL1AndInfoTreeRelatedStuff(sdb *stageDb, batchState *BatchState, proposedTimestamp uint64) ( +func prepareL1AndInfoTreeRelatedStuff( + sdb *stageDb, + batchState *BatchState, + proposedTimestamp uint64, + reuseL1InfoIndex bool, +) ( infoTreeIndexProgress uint64, l1TreeUpdate *zktypes.L1InfoTreeUpdate, l1TreeUpdateIndex uint64, @@ -270,8 +276,17 @@ func prepareL1AndInfoTreeRelatedStuff(sdb *stageDb, batchState *BatchState, prop return } - if batchState.isL1Recovery() { - l1TreeUpdateIndex = uint64(batchState.blockState.blockL1RecoveryData.L1InfoTreeIndex) + if batchState.isL1Recovery() || (batchState.isResequence() && reuseL1InfoIndex) { + if batchState.isL1Recovery() { + l1TreeUpdateIndex = uint64(batchState.blockState.blockL1RecoveryData.L1InfoTreeIndex) + } else { + // Resequence mode: + // If we are resequencing at the beginning (AtNewBlockBoundary->true) of a rolledback block, we need to reuse the l1TreeUpdateIndex from the block. + // If we are in the middle of a block (AtNewBlockBoundary -> false), it means the original block will be requenced into multiple blocks, so we will leave l1TreeUpdateIndex as 0 for the rest of blocks. + if batchState.resequenceBatchJob.AtNewBlockBoundary() { + l1TreeUpdateIndex = uint64(batchState.resequenceBatchJob.CurrentBlock().L1InfoTreeIndex) + } + } if l1TreeUpdate, err = sdb.hermezDb.GetL1InfoTreeUpdate(l1TreeUpdateIndex); err != nil { return } @@ -508,3 +523,78 @@ func (bdc *BlockDataChecker) AddTransactionData(txL2Data []byte) bool { return false } + +type txMatadata struct { + blockNum int + txIndex int +} + +type ResequenceBatchJob struct { + batchToProcess []*dsTypes.FullL2Block + StartBlockIndex int + StartTxIndex int + txIndexMap map[common.Hash]txMatadata +} + +func NewResequenceBatchJob(batch []*dsTypes.FullL2Block) *ResequenceBatchJob { + return &ResequenceBatchJob{ + batchToProcess: batch, + StartBlockIndex: 0, + StartTxIndex: 0, + txIndexMap: make(map[common.Hash]txMatadata), + } +} + +func (r *ResequenceBatchJob) HasMoreBlockToProcess() bool { + return r.StartBlockIndex < len(r.batchToProcess) +} + +func (r *ResequenceBatchJob) AtNewBlockBoundary() bool { + return r.StartTxIndex == 0 +} + +func (r *ResequenceBatchJob) CurrentBlock() *dsTypes.FullL2Block { + if r.HasMoreBlockToProcess() { + return r.batchToProcess[r.StartBlockIndex] + } + return nil +} + +func (r *ResequenceBatchJob) YieldNextBlockTransactions(decoder zktx.TxDecoder) ([]types.Transaction, error) { + blockTransactions := make([]types.Transaction, 0) + if r.HasMoreBlockToProcess() { + block := r.CurrentBlock() + r.txIndexMap[block.L2Blockhash] = txMatadata{r.StartBlockIndex, 0} + + for i := r.StartTxIndex; i < len(block.L2Txs); i++ { + transaction := block.L2Txs[i] + tx, _, err := decoder(transaction.Encoded, transaction.EffectiveGasPricePercentage, block.ForkId) + if err != nil { + return nil, fmt.Errorf("decode tx error: %v", err) + } + r.txIndexMap[tx.Hash()] = txMatadata{r.StartBlockIndex, i} + blockTransactions = append(blockTransactions, tx) + } + } + + return blockTransactions, nil +} + +func (r *ResequenceBatchJob) UpdateLastProcessedTx(h common.Hash) { + if idx, ok := r.txIndexMap[h]; ok { + block := r.batchToProcess[idx.blockNum] + + if idx.txIndex >= len(block.L2Txs)-1 { + // we've processed all the transactions in this 
block
+			// move to the next block
+			r.StartBlockIndex = idx.blockNum + 1
+			r.StartTxIndex = 0
+		} else {
+			// move to the next transaction in the block
+			r.StartBlockIndex = idx.blockNum
+			r.StartTxIndex = idx.txIndex + 1
+		}
+	} else {
+		log.Warn("tx hash not found in tx index map", "hash", h)
+	}
+}
diff --git a/zk/stages/stage_sequence_execute_utils_test.go b/zk/stages/stage_sequence_execute_utils_test.go
index 3ff72032840..2b859923f02 100644
--- a/zk/stages/stage_sequence_execute_utils_test.go
+++ b/zk/stages/stage_sequence_execute_utils_test.go
@@ -1,8 +1,13 @@
 package stages
 
 import (
+	"reflect"
 	"testing"
 
+	"github.com/holiman/uint256"
+	"github.com/ledgerwatch/erigon-lib/common"
+	"github.com/ledgerwatch/erigon/core/types"
+	dsTypes "github.com/ledgerwatch/erigon/zk/datastream/types"
 	zktx "github.com/ledgerwatch/erigon/zk/tx"
 	zktypes "github.com/ledgerwatch/erigon/zk/types"
 )
@@ -207,3 +212,252 @@ func Test_PrepareForkId_DuringRecovery(t *testing.T) {
 		})
 	}
 }
+
+// Mock implementation of zktx.DecodeTx for testing purposes
+func mockDecodeTx(encoded []byte, effectiveGasPricePercentage byte, forkId uint64) (types.Transaction, uint8, error) {
+	return types.NewTransaction(0, common.Address{}, uint256.NewInt(0), 0, uint256.NewInt(0), encoded), 0, nil
+}
+
+func TestResequenceBatchJob_HasMoreToProcess(t *testing.T) {
+	tests := []struct {
+		name     string
+		job      ResequenceBatchJob
+		expected bool
+	}{
+		{
+			name: "Has more blocks to process",
+			job: ResequenceBatchJob{
+				batchToProcess:  []*dsTypes.FullL2Block{{}, {}},
+				StartBlockIndex: 1,
+				StartTxIndex:    0,
+			},
+			expected: true,
+		},
+		{
+			name: "Has more transactions to process",
+			job: ResequenceBatchJob{
+				batchToProcess:  []*dsTypes.FullL2Block{{L2Txs: []dsTypes.L2TransactionProto{{}, {}}}},
+				StartBlockIndex: 0,
+				StartTxIndex:    0,
+			},
+			expected: true,
+		},
+		{
+			name: "No more to process",
+			job: ResequenceBatchJob{
+				batchToProcess:  []*dsTypes.FullL2Block{{}},
+				StartBlockIndex: 1,
+				StartTxIndex:    0,
+			},
+			expected: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			if got := tt.job.HasMoreBlockToProcess(); got != tt.expected {
+				t.Errorf("ResequenceBatchJob.HasMoreBlockToProcess() = %v, want %v", got, tt.expected)
+			}
+		})
+	}
+}
+
+func TestResequenceBatchJob_CurrentBlock(t *testing.T) {
+	tests := []struct {
+		name     string
+		job      ResequenceBatchJob
+		expected *dsTypes.FullL2Block
+	}{
+		{
+			name: "Has current block",
+			job: ResequenceBatchJob{
+				batchToProcess:  []*dsTypes.FullL2Block{{L2BlockNumber: 1}, {L2BlockNumber: 2}},
+				StartBlockIndex: 0,
+				StartTxIndex:    0,
+			},
+			expected: &dsTypes.FullL2Block{L2BlockNumber: 1},
+		},
+		{
+			name: "No current block",
+			job: ResequenceBatchJob{
+				batchToProcess:  []*dsTypes.FullL2Block{{L2BlockNumber: 1}},
+				StartBlockIndex: 1,
+				StartTxIndex:    0,
+			},
+			expected: nil,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got := tt.job.CurrentBlock()
+			if (got == nil && tt.expected != nil) || (got != nil && tt.expected == nil) {
+				t.Errorf("ResequenceBatchJob.CurrentBlock() = %v, want %v", got, tt.expected)
+			}
+			if got != nil && tt.expected != nil && got.L2BlockNumber != tt.expected.L2BlockNumber {
+				t.Errorf("ResequenceBatchJob.CurrentBlock().L2BlockNumber = %v, want %v", got.L2BlockNumber, tt.expected.L2BlockNumber)
+			}
+		})
+	}
+}
+
+func TestResequenceBatchJob_YieldNextBlockTransactions(t *testing.T) {
+	// mockDecodeTx is passed in below in place of the real zktx.DecodeTx
+
+	tests := []struct {
+		name            string
+		job             ResequenceBatchJob
+		expectedTxCount int
+		expectedError   bool
+	}{
+		{
+			name: "Yield transactions",
+			job: ResequenceBatchJob{
+				batchToProcess: []*dsTypes.FullL2Block{
+					{
+						L2Txs:  []dsTypes.L2TransactionProto{{}, {}},
+						ForkId: 1,
+					},
+				},
+				StartBlockIndex: 0,
+				StartTxIndex:    0,
+				txIndexMap:      make(map[common.Hash]txMatadata),
+			},
+			expectedTxCount: 2,
+			expectedError:   false,
+		},
+		{
+			name: "No transactions to yield",
+			job: ResequenceBatchJob{
+				batchToProcess:  []*dsTypes.FullL2Block{{}},
+				StartBlockIndex: 1,
+				StartTxIndex:    0,
+				txIndexMap:      make(map[common.Hash]txMatadata),
+			},
+			expectedTxCount: 0,
+			expectedError:   false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			txs, err := tt.job.YieldNextBlockTransactions(mockDecodeTx)
+			if (err != nil) != tt.expectedError {
+				t.Errorf("ResequenceBatchJob.YieldNextBlockTransactions() error = %v, expectedError %v", err, tt.expectedError)
+				return
+			}
+			if len(txs) != tt.expectedTxCount {
+				t.Errorf("ResequenceBatchJob.YieldNextBlockTransactions() returned %d transactions, expected %d", len(txs), tt.expectedTxCount)
+			}
+		})
+	}
+}
+
+func TestResequenceBatchJob_YieldAndUpdate(t *testing.T) {
+	// Setup the batch
+	batch := []*dsTypes.FullL2Block{
+		{L2Txs: []dsTypes.L2TransactionProto{{Encoded: []byte("1")}, {Encoded: []byte("2")}}, L2Blockhash: common.HexToHash("0")},
+		{L2Txs: []dsTypes.L2TransactionProto{}, L2Blockhash: common.HexToHash("1")},
+		{L2Txs: []dsTypes.L2TransactionProto{}, L2Blockhash: common.HexToHash("2")},
+		{L2Txs: []dsTypes.L2TransactionProto{{Encoded: []byte("3")}, {Encoded: []byte("4")}}, L2Blockhash: common.HexToHash("3")},
+	}
+
+	job := ResequenceBatchJob{
+		batchToProcess:  batch,
+		StartBlockIndex: 0,
+		StartTxIndex:    1, // Start at block 0, index 1
+		txIndexMap:      make(map[common.Hash]txMatadata),
+	}
+
+	processTransactions := func(txs []types.Transaction) {
+		for _, tx := range txs {
+			job.UpdateLastProcessedTx(tx.Hash())
+		}
+	}
+
+	// First call - should yield transaction 2 from block 0
+	txs, err := job.YieldNextBlockTransactions(mockDecodeTx)
+	if err != nil {
+		t.Fatalf("First call: Unexpected error: %v", err)
+	}
+	if len(txs) != 1 || string(txs[0].GetData()) != "2" {
+		t.Fatalf("Expected 1 transaction with data '2', got %d transactions", len(txs))
+	}
+	processTransactions(txs)
+	tx2 := txs[0]
+
+	// Second call - should yield empty block (block 1)
+	txs, err = job.YieldNextBlockTransactions(mockDecodeTx)
+	if err != nil {
+		t.Fatalf("Second call: Unexpected error: %v", err)
+	}
+	if len(txs) != 0 {
+		t.Errorf("Expected 0 transactions, got %d", len(txs))
+	}
+	job.UpdateLastProcessedTx(job.CurrentBlock().L2Blockhash)
+
+	// Third call - should yield empty block (block 2)
+	txs, err = job.YieldNextBlockTransactions(mockDecodeTx)
+	if err != nil {
+		t.Fatalf("Third call: Unexpected error: %v", err)
+	}
+	if len(txs) != 0 {
+		t.Errorf("Expected 0 transactions, got %d", len(txs))
+	}
+	job.UpdateLastProcessedTx(job.CurrentBlock().L2Blockhash)
+
+	// Fourth call - should yield transactions 3 and 4, but we'll only process 3
+	txs, err = job.YieldNextBlockTransactions(mockDecodeTx)
+	if err != nil {
+		t.Fatalf("Fourth call: Unexpected error: %v", err)
+	}
+	if len(txs) != 2 || string(txs[0].GetData()) != "3" || string(txs[1].GetData()) != "4" {
+		t.Fatalf("Expected 2 transactions with data '3' and '4', got %d transactions", len(txs))
+	}
+	processTransactions(txs[:1]) // Only process the first transaction (3)
+	tx3 := txs[0]
+	tx4 := txs[1]
+
+	// Check final state
+	if job.StartBlockIndex != 3 {
+		t.Errorf("Expected StartBlockIndex to be 3, got %d", job.StartBlockIndex)
+	}
+
+	if job.StartTxIndex != 1 {
+		t.Errorf("Expected StartTxIndex to be 1, got %d", job.StartTxIndex)
+	}
+
+	// Final call - should yield transaction 4
+	txs, err = job.YieldNextBlockTransactions(mockDecodeTx)
+	if err != nil {
+		t.Fatalf("Final call: Unexpected error: %v", err)
+	}
+	if len(txs) != 1 || string(txs[0].GetData()) != "4" {
+		t.Errorf("Expected 1 transaction with data '4', got %d transactions", len(txs))
+	}
+
+	processTransactions(txs)
+
+	if job.HasMoreBlockToProcess() {
+		t.Errorf("Expected no more blocks to process")
+	}
+
+	// Verify txIndexMap
+	expectedTxIndexMap := map[common.Hash]txMatadata{
+		common.HexToHash("0"): {0, 0},
+		common.HexToHash("1"): {1, 0},
+		common.HexToHash("2"): {2, 0},
+		common.HexToHash("3"): {3, 0},
+		tx2.Hash():            {0, 1}, // Transaction 2
+		tx3.Hash():            {3, 0}, // Transaction 3
+		tx4.Hash():            {3, 1}, // Transaction 4
+	}
+
+	for hash, index := range expectedTxIndexMap {
+		if actualIndex, exists := job.txIndexMap[hash]; !exists {
+			t.Errorf("Expected hash %s to exist in txIndexMap", hash.Hex())
+		} else if !reflect.DeepEqual(actualIndex, index) {
+			t.Errorf("For hash %s, expected index %v, got %v", hash.Hex(), index, actualIndex)
+		}
+	}
+}
diff --git a/zk/stages/test_utils.go b/zk/stages/test_utils.go
index 62b130a9fa0..df250ebf717 100644
--- a/zk/stages/test_utils.go
+++ b/zk/stages/test_utils.go
@@ -14,6 +14,7 @@ type TestDatastreamClient struct {
 	progress              atomic.Uint64
 	entriesChan           chan interface{}
 	errChan               chan error
+	isStarted             bool
 }
 
 func NewTestDatastreamClient(fullL2Blocks []types.FullL2Block, gerUpdates []types.GerUpdate) *TestDatastreamClient {
@@ -33,11 +34,12 @@ func (c *TestDatastreamClient) EnsureConnected() (bool, error) {
 
 func (c *TestDatastreamClient) ReadAllEntriesToChannel() error {
 	c.streamingAtomic.Store(true)
+	defer c.streamingAtomic.Swap(false)
 
-	for i, _ := range c.fullL2Blocks {
+	for i := range c.fullL2Blocks {
 		c.entriesChan <- &c.fullL2Blocks[i]
 	}
-	for i, _ := range c.gerUpdates {
+	for i := range c.gerUpdates {
 		c.entriesChan <- &c.gerUpdates[i]
 	}
 
@@ -52,12 +54,44 @@ func (c *TestDatastreamClient) GetErrChan() chan error {
 	return c.errChan
 }
 
+func (c *TestDatastreamClient) GetL2BlockByNumber(blockNum uint64) (*types.FullL2Block, int, error) {
+	for _, l2Block := range c.fullL2Blocks {
+		if l2Block.L2BlockNumber == blockNum {
+			return &l2Block, types.CmdErrOK, nil
+		}
+	}
+
+	return nil, -1, nil
+}
+
+func (c *TestDatastreamClient) GetLatestL2Block() (*types.FullL2Block, error) {
+	if len(c.fullL2Blocks) == 0 {
+		return nil, nil
+	}
+	return &c.fullL2Blocks[len(c.fullL2Blocks)-1], nil
+}
+
 func (c *TestDatastreamClient) GetLastWrittenTimeAtomic() *atomic.Int64 {
 	return &c.lastWrittenTimeAtomic
 }
+
 func (c *TestDatastreamClient) GetStreamingAtomic() *atomic.Bool {
 	return &c.streamingAtomic
 }
+
 func (c *TestDatastreamClient) GetProgressAtomic() *atomic.Uint64 {
 	return &c.progress
 }
+
+func (c *TestDatastreamClient) ReadBatches(start uint64, end uint64) ([][]*types.FullL2Block, error) {
+	return nil, nil
+}
+
+func (c *TestDatastreamClient) Start() error {
+	c.isStarted = true
+	return nil
+}
+
+func (c *TestDatastreamClient) Stop() {
+	c.isStarted = false
+}
diff --git a/zk/syncer/l1_syncer.go b/zk/syncer/l1_syncer.go
index 15b6ed2e160..0ae2c7f960c 100644
--- a/zk/syncer/l1_syncer.go
+++ b/zk/syncer/l1_syncer.go
@@ -514,3 +514,13 @@ func (s *L1Syncer) callGetAddress(ctx context.Context, addr *common.Address, dat
 	return common.BytesToAddress(resp[len(resp)-20:]), nil
 }
+
+func (s *L1Syncer) CheckL1BlockFinalized(blockNo uint64) (finalized bool, finalizedBn uint64, err error) {
+	em := s.getNextEtherman()
+	block, err := em.BlockByNumber(s.ctx, big.NewInt(rpc.FinalizedBlockNumber.Int64()))
+	if err != nil {
+		return false, 0, err
+	}
+
+	return block.NumberU64() >= blockNo, block.NumberU64(), nil
+}
diff --git a/zk/tx/tx.go b/zk/tx/tx.go
index c938ccf5902..d2e5ce8a0bf 100644
--- a/zk/tx/tx.go
+++ b/zk/tx/tx.go
@@ -189,6 +189,8 @@ func DecodeBatchL2Blocks(txsData []byte, forkID uint64) ([]DecodedBatchL2Data, e
 	return result, nil
 }
 
+type TxDecoder func(encodedTx []byte, gasPricePercentage uint8, forkID uint64) (types.Transaction, uint8, error)
+
 func DecodeTx(encodedTx []byte, efficiencyPercentage byte, forkId uint64) (types.Transaction, uint8, error) {
 	// efficiencyPercentage := uint8(0)
 	if forkId >= uint64(constants.ForkID5Dragonfruit) {
@@ -500,11 +502,7 @@ func ComputeL2TxHash(
 	}
 	hash += fromPart
 
-	hashed, err := utils.HashContractBytecode(hash)
-	if err != nil {
-		return common.Hash{}, err
-	}
-
+	hashed := utils.HashContractBytecode(hash)
 	return common.HexToHash(hashed), nil
 }
diff --git a/zk/types/zk_types.go b/zk/types/zk_types.go
index c48c0c934e9..7734c2c6eb2 100644
--- a/zk/types/zk_types.go
+++ b/zk/types/zk_types.go
@@ -7,11 +7,11 @@
 import (
 	"bytes"
 	"encoding/binary"
+	"fmt"
 
 	"github.com/holiman/uint256"
 	"github.com/ledgerwatch/erigon/cl/utils"
 	ethTypes "github.com/ledgerwatch/erigon/core/types"
-	"fmt"
 )
 
 const EFFECTIVE_GAS_PRICE_PERCENTAGE_DISABLED = 0
@@ -118,3 +118,10 @@ func (ib *L1InjectedBatch) Unmarshall(input []byte) error {
 	ib.Transaction = append([]byte{}, input[132:]...)
 	return nil
 }
+
+type ForkInterval struct {
+	ForkID          uint64
+	FromBatchNumber uint64
+	ToBatchNumber   uint64
+	BlockNumber     uint64
+}
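
Note on usage (not part of the diff): the tests above exercise the consume loop that a caller of ResequenceBatchJob is expected to run - yield the current block's unprocessed transactions, execute them, and advance the cursor with UpdateLastProcessedTx, using the block hash to step past empty blocks. The sketch below illustrates that contract only; it assumes it lives in the zk/stages package (so the unexported fields and methods are visible), passes zktx.DecodeTx as the decoder, and "executeTxs" is a hypothetical stand-in for the sequencer's execution step.

	package stages

	import (
		"github.com/ledgerwatch/erigon/core/types"
		zktx "github.com/ledgerwatch/erigon/zk/tx"
	)

	// drainResequenceJob is a hypothetical helper showing the yield/update loop
	// exercised by the tests above; it is a sketch, not part of this change set.
	func drainResequenceJob(job *ResequenceBatchJob, executeTxs func([]types.Transaction) error) error {
		for job.HasMoreBlockToProcess() {
			// Yield the remaining transactions of the current block, decoded
			// with the real zktx.DecodeTx rather than the test mock.
			txs, err := job.YieldNextBlockTransactions(zktx.DecodeTx)
			if err != nil {
				return err
			}
			if len(txs) == 0 {
				// Empty block: advance the cursor by the block hash,
				// exactly as TestResequenceBatchJob_YieldAndUpdate does.
				job.UpdateLastProcessedTx(job.CurrentBlock().L2Blockhash)
				continue
			}
			if err := executeTxs(txs); err != nil {
				// The cursor was not advanced for this yield, so a retry
				// re-yields this block's transactions from the same point.
				return err
			}
			// Mark each executed transaction; a partially processed block
			// resumes at the first unprocessed transaction on the next yield.
			for _, tx := range txs {
				job.UpdateLastProcessedTx(tx.Hash())
			}
		}
		return nil
	}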