diff --git a/core/genesis_write.go b/core/genesis_write.go index e90c5781da6..feca2c0c68e 100644 --- a/core/genesis_write.go +++ b/core/genesis_write.go @@ -27,12 +27,12 @@ import ( "sync" "github.com/c2h5oh/datasize" - "github.com/holiman/uint256" erigonchain "github.com/gateway-fm/cdk-erigon-lib/chain" libcommon "github.com/gateway-fm/cdk-erigon-lib/common" "github.com/gateway-fm/cdk-erigon-lib/kv" "github.com/gateway-fm/cdk-erigon-lib/kv/mdbx" "github.com/gateway-fm/cdk-erigon-lib/kv/rawdbv3" + "github.com/holiman/uint256" "github.com/ledgerwatch/erigon/chain" "github.com/ledgerwatch/log/v3" @@ -49,8 +49,8 @@ import ( "github.com/ledgerwatch/erigon/eth/ethconfig" "github.com/ledgerwatch/erigon/params" "github.com/ledgerwatch/erigon/params/networkname" - "github.com/ledgerwatch/erigon/smt/pkg/smt" eridb "github.com/ledgerwatch/erigon/smt/pkg/db" + "github.com/ledgerwatch/erigon/smt/pkg/smt" "golang.org/x/exp/slices" ) @@ -525,7 +525,7 @@ func GenesisToBlock(g *types.Genesis, tmpDir string) (*types.Block, *state.Intra wg.Add(1) var err error sparseDb := eridb.NewMemDb() - sparseTree := smt.NewSMT(sparseDb) + sparseTree := smt.NewSMT(sparseDb, false) go func() { // we may run inside write tx, can't open 2nd write tx in same goroutine // TODO(yperbasis): use memdb.MemoryMutation instead defer wg.Done() diff --git a/smt/pkg/blockinfo/block_info.go b/smt/pkg/blockinfo/block_info.go index 3f2e97bbdf3..612df1504b3 100644 --- a/smt/pkg/blockinfo/block_info.go +++ b/smt/pkg/blockinfo/block_info.go @@ -104,7 +104,7 @@ type BlockInfoTree struct { func NewBlockInfoTree() *BlockInfoTree { return &BlockInfoTree{ - smt: smt.NewSMT(nil), + smt: smt.NewSMT(nil, true), } } func (b *BlockInfoTree) GetRoot() *big.Int { @@ -147,83 +147,56 @@ func (b *BlockInfoTree) GenerateBlockHeader(oldBlockHash *common.Hash, coinbase } func generateL2BlockHash(blockHash *common.Hash) (key *utils.NodeKey, value *utils.NodeValue8, err error) { - if key, err = KeyBlockHeaderParams(big.NewInt(IndexBlockHeaderParamBlockHash)); err != nil { - return nil, nil, err - } if value, err = bigInt2NodeVal8(blockHash.Big()); err != nil { return nil, nil, err } - return key, value, nil + return &BlockHeaderBlockHashKey, value, nil } func generateCoinbase(coinbase *common.Address) (key *utils.NodeKey, value *utils.NodeValue8, err error) { - if key, err = KeyBlockHeaderParams(big.NewInt(IndexBlockHeaderParamCoinbase)); err != nil { - return nil, nil, err - } - if value, err = bigInt2NodeVal8(coinbase.Hash().Big()); err != nil { return nil, nil, err } - return key, value, nil + return &BlockHeaderCoinbaseKey, value, nil } func generateGasLimit(gasLimit uint64) (key *utils.NodeKey, value *utils.NodeValue8, err error) { - if key, err = KeyBlockHeaderParams(big.NewInt(IndexBlockHeaderParamGasLimit)); err != nil { - return nil, nil, err - } - if value, err = bigInt2NodeVal8(big.NewInt(0).SetUint64(gasLimit)); err != nil { return nil, nil, err } - return key, value, nil + return &BlockHeaderGasLimitKey, value, nil } func generateBlockNumber(blockNumber uint64) (key *utils.NodeKey, value *utils.NodeValue8, err error) { - if key, err = KeyBlockHeaderParams(big.NewInt(IndexBlockHeaderParamNumber)); err != nil { - return nil, nil, err - } - if value, err = bigInt2NodeVal8(big.NewInt(0).SetUint64(blockNumber)); err != nil { return nil, nil, err } - return key, value, nil + return &BlockHeaderNumberKey, value, nil } func generateTimestamp(timestamp uint64) (key *utils.NodeKey, value *utils.NodeValue8, err error) { - if key, err = KeyBlockHeaderParams(big.NewInt(IndexBlockHeaderParamTimestamp)); err != nil { - return nil, nil, err - } - if value, err = bigInt2NodeVal8(big.NewInt(0).SetUint64(timestamp)); err != nil { return nil, nil, err } - return key, value, nil + return &BlockHeaderTimestampKey, value, nil } func generateGer(ger *common.Hash) (key *utils.NodeKey, value *utils.NodeValue8, err error) { - if key, err = KeyBlockHeaderParams(big.NewInt(IndexBlockHeaderParamGer)); err != nil { - return nil, nil, err - } - if value, err = bigInt2NodeVal8(ger.Big()); err != nil { return nil, nil, err } - return key, value, nil + return &BlockHeaderGerKey, value, nil } func generateL1BlockHash(blockHash *common.Hash) (key *utils.NodeKey, value *utils.NodeValue8, err error) { - if key, err = KeyBlockHeaderParams(big.NewInt(IndexBlockHeaderParamBlockHashL1)); err != nil { - return nil, nil, err - } - if value, err = bigInt2NodeVal8(blockHash.Big()); err != nil { return nil, nil, err } - return key, value, nil + return &BlockHeaderBlockHashL1Key, value, nil } func bigInt2NodeVal8(val *big.Int) (*utils.NodeValue8, error) { @@ -291,15 +264,12 @@ func generateTxEffectivePercentage(txIndex, effectivePercentage *big.Int) (key * } func generateBlockGasUsed(gasUsed uint64) (key *utils.NodeKey, value *utils.NodeValue8, err error) { - if key, err = KeyBlockHeaderParams(big.NewInt(IndexBlockHeaderParamGasUsed)); err != nil { - return nil, nil, err - } gasUsedBig := big.NewInt(0).SetUint64(gasUsed) if value, err = bigInt2NodeVal8(gasUsedBig); err != nil { return nil, nil, err } - return key, value, nil + return &BlockHeaderGasUsedKey, value, nil } func (b *BlockInfoTree) GenerateBlockTxKeysVals( diff --git a/smt/pkg/blockinfo/block_info_test.go b/smt/pkg/blockinfo/block_info_test.go index 1c1f13dd3f9..42ddfaea093 100644 --- a/smt/pkg/blockinfo/block_info_test.go +++ b/smt/pkg/blockinfo/block_info_test.go @@ -288,7 +288,7 @@ func TestSetL2BlockHash(t *testing.T) { if err != nil { t.Fatal(err) } - smt := smt.NewSMT(nil) + smt := smt.NewSMT(nil, true) root, err2 := smt.InsertKA(*key, smtutils.NodeValue8ToBigInt(val)) if err2 != nil { @@ -314,7 +314,7 @@ func TestSetCoinbase(t *testing.T) { } for i, test := range tests { - smt := smt.NewSMT(nil) + smt := smt.NewSMT(nil, true) coinbaseAddress := common.HexToAddress(test.coinbaseAddress) key, val, err := generateCoinbase(&coinbaseAddress) @@ -348,7 +348,7 @@ func TestSetBlockNumber(t *testing.T) { } for i, test := range tests { - smt := smt.NewSMT(nil) + smt := smt.NewSMT(nil, true) key, val, err := generateBlockNumber(test.blockNum) if err != nil { @@ -379,7 +379,7 @@ func TestSetGasLimit(t *testing.T) { } for i, test := range tests { - smt := smt.NewSMT(nil) + smt := smt.NewSMT(nil, true) key, val, err := generateGasLimit(test.gasLimit) if err != nil { @@ -410,7 +410,7 @@ func TestSetTimestamp(t *testing.T) { } for i, test := range tests { - smt := smt.NewSMT(nil) + smt := smt.NewSMT(nil, true) key, val, err := generateTimestamp(test.timestamp) if err != nil { @@ -447,7 +447,7 @@ func TestSetGer(t *testing.T) { } for i, test := range tests { - smt := smt.NewSMT(nil) + smt := smt.NewSMT(nil, true) ger := common.HexToHash(test.ger) key, val, err := generateGer(&ger) @@ -485,7 +485,7 @@ func TestSetL1BlockHash(t *testing.T) { } for i, test := range tests { - smt := smt.NewSMT(nil) + smt := smt.NewSMT(nil, true) l1BlockHash := common.HexToHash(test.l1BlockHash) key, val, err := generateL1BlockHash(&l1BlockHash) @@ -506,7 +506,7 @@ func TestSetL1BlockHash(t *testing.T) { } func TestSetL2TxHash(t *testing.T) { - smt := smt.NewSMT(nil) + smt := smt.NewSMT(nil, true) txIndex := big.NewInt(1) l2TxHash := common.HexToHash("0x000000000000000000000000000000005Ca1aB1E").Big() @@ -529,7 +529,7 @@ func TestSetL2TxHash(t *testing.T) { } func TestSetTxStatus(t *testing.T) { - smt := smt.NewSMT(nil) + smt := smt.NewSMT(nil, true) txIndex := big.NewInt(1) status := common.HexToHash("0x000000000000000000000000000000005Ca1aB1E").Big() @@ -552,7 +552,7 @@ func TestSetTxStatus(t *testing.T) { } func TestSetCumulativeGasUsed(t *testing.T) { - smt := smt.NewSMT(nil) + smt := smt.NewSMT(nil, true) txIndex := big.NewInt(1) cgu := common.HexToHash("0x000000000000000000000000000000005Ca1aB1E").Big() @@ -576,7 +576,7 @@ func TestSetCumulativeGasUsed(t *testing.T) { } func TestSetTxEffectivePercentage(t *testing.T) { - smt := smt.NewSMT(nil) + smt := smt.NewSMT(nil, true) txIndex := big.NewInt(1) egp := common.HexToHash("0x000000000000000000000000000000005Ca1aB1E").Big() @@ -600,7 +600,7 @@ func TestSetTxEffectivePercentage(t *testing.T) { } func TestSetTxLogs(t *testing.T) { - smt := smt.NewSMT(nil) + smt := smt.NewSMT(nil, true) txIndex := big.NewInt(1) logIndex := big.NewInt(1) log := common.HexToHash("0x000000000000000000000000000000005Ca1aB1E").Big() diff --git a/smt/pkg/blockinfo/keys.go b/smt/pkg/blockinfo/keys.go index de3a8a4d17c..af9cfdef475 100644 --- a/smt/pkg/blockinfo/keys.go +++ b/smt/pkg/blockinfo/keys.go @@ -17,6 +17,18 @@ const IndexBlockHeaderParamGer = 5 const IndexBlockHeaderParamBlockHashL1 = 6 const IndexBlockHeaderParamGasUsed = 7 +// generated by KeyBlockHeaderParams so we don't calculate them every time +var ( + BlockHeaderBlockHashKey = utils.NodeKey{17540094328570681229, 15492539097581145461, 7686481670809850401, 16577991319572125169} + BlockHeaderCoinbaseKey = utils.NodeKey{13866806033333411216, 11510953292839890698, 8274877395843603978, 9372332419316597113} + BlockHeaderNumberKey = utils.NodeKey{6024064788222257862, 13049342112699253445, 12127984136733687200, 8398043461199794462} + BlockHeaderGasLimitKey = utils.NodeKey{5319681466197319121, 14057433120745733551, 5638531288094714593, 17204828339478940337} + BlockHeaderTimestampKey = utils.NodeKey{7890158832167317866, 11032486557242372179, 9653801891436451408, 2062577087515942703} + BlockHeaderGerKey = utils.NodeKey{16031278424721309229, 4132999715765882778, 6388713709192801251, 10826219431775251904} + BlockHeaderBlockHashL1Key = utils.NodeKey{5354929451503733866, 3129555839551084896, 2132809659008379950, 8230742270813566472} + BlockHeaderGasUsedKey = utils.NodeKey{8577769200631379655, 8682051454686970557, 5016656739138242322, 16717481432904730287} +) + // SMT block header constant keys const IndexBlockHeaderParam = 7 const IndexBlockHeaderTransactionHash = 8 diff --git a/smt/pkg/blockinfo/keys_test.go b/smt/pkg/blockinfo/keys_test.go index dcd70fa8f8f..6b52b834d9e 100644 --- a/smt/pkg/blockinfo/keys_test.go +++ b/smt/pkg/blockinfo/keys_test.go @@ -11,44 +11,58 @@ import ( func TestKeyBlockHeaderParams(t *testing.T) { scenarios := map[string]struct { param *big.Int + constKey utils.NodeKey expected utils.NodeKey shouldFail bool }{ "KeyBlockHash": { param: big.NewInt(IndexBlockHeaderParamBlockHash), + constKey: BlockHeaderBlockHashKey, expected: utils.NodeKey{17540094328570681229, 15492539097581145461, 7686481670809850401, 16577991319572125169}, shouldFail: false, }, "KeyCoinbase": { param: big.NewInt(IndexBlockHeaderParamCoinbase), + constKey: BlockHeaderCoinbaseKey, expected: utils.NodeKey{13866806033333411216, 11510953292839890698, 8274877395843603978, 9372332419316597113}, shouldFail: false, }, "KeyBlockNumber": { param: big.NewInt(IndexBlockHeaderParamNumber), + constKey: BlockHeaderNumberKey, expected: utils.NodeKey{6024064788222257862, 13049342112699253445, 12127984136733687200, 8398043461199794462}, shouldFail: false, }, "KeyGasLimit": { param: big.NewInt(IndexBlockHeaderParamGasLimit), + constKey: BlockHeaderGasLimitKey, expected: utils.NodeKey{5319681466197319121, 14057433120745733551, 5638531288094714593, 17204828339478940337}, shouldFail: false, }, "KeyTimestamp": { param: big.NewInt(IndexBlockHeaderParamTimestamp), + constKey: BlockHeaderTimestampKey, expected: utils.NodeKey{7890158832167317866, 11032486557242372179, 9653801891436451408, 2062577087515942703}, shouldFail: false, }, "KeyGer": { param: big.NewInt(IndexBlockHeaderParamGer), + constKey: BlockHeaderGerKey, expected: utils.NodeKey{16031278424721309229, 4132999715765882778, 6388713709192801251, 10826219431775251904}, shouldFail: false, }, "KeyBlockHashL1": { param: big.NewInt(IndexBlockHeaderParamBlockHashL1), + constKey: BlockHeaderBlockHashL1Key, expected: utils.NodeKey{5354929451503733866, 3129555839551084896, 2132809659008379950, 8230742270813566472}, shouldFail: false, }, + "KeyGasUsed": { + param: big.NewInt(IndexBlockHeaderParamGasUsed), + constKey: BlockHeaderGasUsedKey, + expected: utils.NodeKey{8577769200631379655, 8682051454686970557, 5016656739138242322, 16717481432904730287}, + shouldFail: false, + }, "NilKey": { param: nil, expected: utils.NodeKey{}, @@ -64,6 +78,7 @@ func TestKeyBlockHeaderParams(t *testing.T) { } else { assert.NoError(t, err) assert.Equal(t, scenario.expected, *val) + assert.Equal(t, scenario.constKey, *val) } }) } diff --git a/smt/pkg/smt/entity_storage_mdbx_test.go b/smt/pkg/smt/entity_storage_mdbx_test.go index ac16e48b570..22c34bc87ca 100644 --- a/smt/pkg/smt/entity_storage_mdbx_test.go +++ b/smt/pkg/smt/entity_storage_mdbx_test.go @@ -27,7 +27,7 @@ func TestSMT_Mdbx_AddRemove1Element(t *testing.T) { } //defer dbi.Close() - s := NewSMT(sdb) + s := NewSMT(sdb, false) r, _ := s.InsertBI(big.NewInt(1), big.NewInt(2)) if r.Mode != "insertNotFound" { @@ -50,7 +50,7 @@ func TestSMT_Mdbx_AddRemove3Elements(t *testing.T) { t.Errorf("Failed to create temp db: %v", err) } - s := NewSMT(sdb) + s := NewSMT(sdb, false) N := 3 var r *SMTResponse @@ -82,7 +82,7 @@ func TestSMT_Mdbx_AddRemove128Elements(t *testing.T) { t.Errorf("Failed to create temp db: %v", err) } - s := NewSMT(sdb) + s := NewSMT(sdb, false) N := 128 var r *SMTResponse @@ -163,7 +163,7 @@ func TestSMT_Mdbx_MultipleInsert(t *testing.T) { tr = x } - s := NewSMT(sdb) + s := NewSMT(sdb, false) s.SetLastRoot(tr) r, err := s.InsertBI(testCase.key, testCase.value) @@ -217,7 +217,7 @@ func runGenesisTestMdbx(tb testing.TB, filename string) { tb.Fatal("Failed to create db buckets: ", err) } - smt := NewSMT(sdb) + smt := NewSMT(sdb, false) for _, addr := range genesis.Genesis { fmt.Println(addr.ContractName) @@ -275,7 +275,7 @@ func runTestVectorsMdbx(t *testing.T, filename string) { for k, tc := range testCases { t.Run(strconv.Itoa(k), func(t *testing.T) { - smt := NewSMT(nil) + smt := NewSMT(nil, false) for _, addr := range tc.Addresses { bal, _ := new(big.Int).SetString(addr.Balance, 10) diff --git a/smt/pkg/smt/entity_storage_test.go b/smt/pkg/smt/entity_storage_test.go index 3688f590e59..ca057a5674e 100644 --- a/smt/pkg/smt/entity_storage_test.go +++ b/smt/pkg/smt/entity_storage_test.go @@ -72,7 +72,7 @@ func runGenesisTest(tb testing.TB, filename string) { tb.Fatal("Failed to parse json: ", err) } - smt := NewSMT(nil) + smt := NewSMT(nil, false) for _, addr := range genesis.Genesis { fmt.Println(addr.ContractName) @@ -130,7 +130,7 @@ func runTestVectors(t *testing.T, filename string) { for k, tc := range testCases { t.Run(strconv.Itoa(k), func(t *testing.T) { - smt := NewSMT(nil) + smt := NewSMT(nil, false) for _, addr := range tc.Addresses { bal, _ := new(big.Int).SetString(addr.Balance, 10) @@ -163,7 +163,7 @@ func runTestVectors(t *testing.T, filename string) { } func Test_FullGenesisTest_Deprecated(t *testing.T) { - s := NewSMT(nil) + s := NewSMT(nil, false) e := utils.NodeKey{ 13946701032480821596, diff --git a/smt/pkg/smt/smt.go b/smt/pkg/smt/smt.go index 31175259a9d..2a2ad956edc 100644 --- a/smt/pkg/smt/smt.go +++ b/smt/pkg/smt/smt.go @@ -51,7 +51,8 @@ type DebuggableDB interface { } type SMT struct { - Db DB + noSaveOnInsert bool + Db DB *RoSMT } @@ -65,14 +66,15 @@ type SMTResponse struct { Mode string } -func NewSMT(database DB) *SMT { +func NewSMT(database DB, noSaveOnInsert bool) *SMT { if database == nil { database = db.NewMemDb() } return &SMT{ - Db: database, - RoSMT: NewRoSMT(database), + noSaveOnInsert: noSaveOnInsert, + Db: database, + RoSMT: NewRoSMT(database), } } @@ -536,19 +538,23 @@ func (s *SMT) insert(k utils.NodeKey, v utils.NodeValue8, newValH [4]uint64, old } func (s *SMT) hashSave(in [8]uint64, capacity, h [4]uint64) ([4]uint64, error) { - var sl []uint64 - sl = append(sl, in[:]...) - sl = append(sl, capacity[:]...) + if !s.noSaveOnInsert { + var sl []uint64 + sl = append(sl, in[:]...) + sl = append(sl, capacity[:]...) + + v := utils.NodeValue12{} + for i, val := range sl { + b := new(big.Int) + v[i] = b.SetUint64(val) + } - v := utils.NodeValue12{} - for i, val := range sl { - b := new(big.Int) - v[i] = b.SetUint64(val) + err := s.Db.Insert(h, v) + if err != nil { + return [4]uint64{}, err + } } - - err := s.Db.Insert(h, v) - - return h, err + return h, nil } func (s *SMT) hashcalcAndSave(in [8]uint64, capacity [4]uint64) ([4]uint64, error) { diff --git a/smt/pkg/smt/smt_batch_test.go b/smt/pkg/smt/smt_batch_test.go index 30cba8fd343..31b74328d6e 100644 --- a/smt/pkg/smt/smt_batch_test.go +++ b/smt/pkg/smt/smt_batch_test.go @@ -55,8 +55,9 @@ func TestBatchSimpleInsert(t *testing.T) { keyPointers := []*utils.NodeKey{} valuePointers := []*utils.NodeValue8{} - smtIncremental := smt.NewSMT(nil) - smtBatch := smt.NewSMT(nil) + smtIncremental := smt.NewSMT(nil, false) + smtBatch := smt.NewSMT(nil, false) + smtBatchNoSave := smt.NewSMT(nil, true) for i := range keysRaw { k := utils.ScalarToNodeKey(keysRaw[i]) @@ -72,6 +73,9 @@ func TestBatchSimpleInsert(t *testing.T) { _, err := smtBatch.InsertBatch(context.Background(), "", keyPointers, valuePointers, nil, nil) assert.NilError(t, err) + _, err = smtBatchNoSave.InsertBatch(context.Background(), "", keyPointers, valuePointers, nil, nil) + assert.NilError(t, err) + smtIncremental.DumpTree() fmt.Println() smtBatch.DumpTree() @@ -81,7 +85,9 @@ func TestBatchSimpleInsert(t *testing.T) { smtIncrementalRootHash, _ := smtIncremental.Db.GetLastRoot() smtBatchRootHash, _ := smtBatch.Db.GetLastRoot() + smtBatchNoSaveRootHash, _ := smtBatchNoSave.Db.GetLastRoot() assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtIncrementalRootHash)) + assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtBatchNoSaveRootHash)) assertSmtDbStructure(t, smtBatch, false) } @@ -121,7 +127,7 @@ func BenchmarkIncrementalInsert(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - smtIncremental := smt.NewSMT(nil) + smtIncremental := smt.NewSMT(nil, false) incrementalInsert(smtIncremental, keys, vals) } } @@ -139,7 +145,25 @@ func BenchmarkBatchInsert(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - smtBatch := smt.NewSMT(nil) + smtBatch := smt.NewSMT(nil, false) + batchInsert(smtBatch, keys, vals) + } +} + +func BenchmarkBatchInsertNoSave(b *testing.B) { + keys := []*big.Int{} + vals := []*big.Int{} + for i := 0; i < 1000; i++ { + rand.Seed(time.Now().UnixNano()) + keys = append(keys, big.NewInt(int64(rand.Intn(10000)))) + + rand.Seed(time.Now().UnixNano()) + vals = append(vals, big.NewInt(int64(rand.Intn(10000)))) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + smtBatch := smt.NewSMT(nil, true) batchInsert(smtBatch, keys, vals) } } @@ -155,15 +179,21 @@ func TestBatchSimpleInsert2(t *testing.T) { vals = append(vals, big.NewInt(int64(rand.Intn(10000)))) } - smtIncremental := smt.NewSMT(nil) + smtIncremental := smt.NewSMT(nil, false) incrementalInsert(smtIncremental, keys, vals) - smtBatch := smt.NewSMT(nil) + smtBatch := smt.NewSMT(nil, false) batchInsert(smtBatch, keys, vals) + smtBatchNoSave := smt.NewSMT(nil, false) + batchInsert(smtBatchNoSave, keys, vals) + smtIncrementalRootHash, _ := smtIncremental.Db.GetLastRoot() smtBatchRootHash, _ := smtBatch.Db.GetLastRoot() + smtBatchNoSaveRootHash, _ := smtBatchNoSave.Db.GetLastRoot() + assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtIncrementalRootHash)) + assert.Equal(t, utils.ConvertBigIntToHex(smtBatchRootHash), utils.ConvertBigIntToHex(smtBatchNoSaveRootHash)) } func TestBatchWitness(t *testing.T) { @@ -327,8 +357,8 @@ func TestBatchWitness(t *testing.T) { }) } - smtIncremental := smt.NewSMT(nil) - smtBatch := smt.NewSMT(nil) + smtIncremental := smt.NewSMT(nil, false) + smtBatch := smt.NewSMT(nil, false) for i, k := range keys { smtIncremental.Insert(k, values[i]) @@ -391,8 +421,8 @@ func TestBatchDelete(t *testing.T) { }) } - smtIncremental := smt.NewSMT(nil) - smtBatch := smt.NewSMT(nil) + smtIncremental := smt.NewSMT(nil, false) + smtBatch := smt.NewSMT(nil, false) for i, k := range keys { smtIncremental.Insert(k, values[i]) @@ -419,8 +449,8 @@ func TestBatchRawInsert(t *testing.T) { keysForIncremental := []utils.NodeKey{} valuesForIncremental := []utils.NodeValue8{} - smtIncremental := smt.NewSMT(nil) - smtBatch := smt.NewSMT(nil) + smtIncremental := smt.NewSMT(nil, false) + smtBatch := smt.NewSMT(nil, false) rand.Seed(1) size := 1 << 10 @@ -506,9 +536,9 @@ func TestCompareAllTreesInsertTimesAndFinalHashesUsingDiskDb(t *testing.T) { batchDbPath := "/tmp/smt-batch" smtBatchDb, smtBatchTx, smtBatchSmtDb := initDb(t, batchDbPath) - smtIncremental := smt.NewSMT(smtIncrementalSmtDb) - smtBulk := smt.NewSMT(smtBulkSmtDb) - smtBatch := smt.NewSMT(smtBatchSmtDb) + smtIncremental := smt.NewSMT(smtIncrementalSmtDb, false) + smtBulk := smt.NewSMT(smtBulkSmtDb, false) + smtBatch := smt.NewSMT(smtBatchSmtDb, false) compareAllTreesInsertTimesAndFinalHashes(t, smtIncremental, smtBulk, smtBatch) @@ -526,9 +556,9 @@ func TestCompareAllTreesInsertTimesAndFinalHashesUsingDiskDb(t *testing.T) { } func TestCompareAllTreesInsertTimesAndFinalHashesUsingInMemoryDb(t *testing.T) { - smtIncremental := smt.NewSMT(nil) - smtBulk := smt.NewSMT(nil) - smtBatch := smt.NewSMT(nil) + smtIncremental := smt.NewSMT(nil, false) + smtBulk := smt.NewSMT(nil, false) + smtBatch := smt.NewSMT(nil, false) compareAllTreesInsertTimesAndFinalHashes(t, smtIncremental, smtBulk, smtBatch) } diff --git a/smt/pkg/smt/smt_create_test.go b/smt/pkg/smt/smt_create_test.go index 886514a94c5..b5d45b13910 100644 --- a/smt/pkg/smt/smt_create_test.go +++ b/smt/pkg/smt/smt_create_test.go @@ -59,7 +59,7 @@ func TestSMT_Create_Insert(t *testing.T) { for _, scenario := range testCases { t.Run(scenario.name, func(t *testing.T) { - s := NewSMT(nil) + s := NewSMT(nil, false) keys := []utils.NodeKey{} for k, v := range scenario.kvMap { if !v.IsZero() { @@ -92,7 +92,7 @@ func TestSMT_Create_CompareWithRandomData(t *testing.T) { //build and benchmark the tree the first way startTime := time.Now() - s1 := NewSMT(nil) + s1 := NewSMT(nil, false) var root1 *big.Int for k, v := range kvMap { @@ -110,7 +110,7 @@ func TestSMT_Create_CompareWithRandomData(t *testing.T) { //build the tree the from kvbulk startTime = time.Now() - s2 := NewSMT(nil) + s2 := NewSMT(nil, false) // set scenario old root if fail keys := []utils.NodeKey{} for k, v := range kvMap { @@ -150,7 +150,7 @@ func TestSMT_Create_Benchmark(t *testing.T) { //build and benchmark the tree the first way startTime := time.Now() //build the tree the from kvbulk - s := NewSMT(nil) + s := NewSMT(nil, false) // set scenario old root if fail keys := []utils.NodeKey{} for k, v := range kvMap { diff --git a/smt/pkg/smt/smt_test.go b/smt/pkg/smt/smt_test.go index 2fdd122e61e..a0faba89f42 100644 --- a/smt/pkg/smt/smt_test.go +++ b/smt/pkg/smt/smt_test.go @@ -44,7 +44,7 @@ func TestSMT_SingleInsert(t *testing.T) { for _, scenario := range scenarios { t.Run(scenario.name, func(t *testing.T) { - s := NewSMT(nil) + s := NewSMT(nil, false) // set scenario old root if fail newRoot, err := s.InsertBI(scenario.k, scenario.v) if err != nil { @@ -60,7 +60,7 @@ func TestSMT_SingleInsert(t *testing.T) { } func TestSMT_MultipleInsert(t *testing.T) { - s := NewSMT(nil) + s := NewSMT(nil, false) testCases := []struct { root *big.Int key *big.Int @@ -122,7 +122,7 @@ func TestSMT_MultipleInsert(t *testing.T) { } func TestSMT_MultipleInsert3(t *testing.T) { - s := NewSMT(nil) + s := NewSMT(nil, false) testCases := []struct { root *big.Int key *big.Int @@ -170,7 +170,7 @@ func TestSMT_MultipleInsert3(t *testing.T) { } func TestSMT_UpdateElement1(t *testing.T) { - s := NewSMT(nil) + s := NewSMT(nil, false) testCases := []struct { root *big.Int key *big.Int @@ -224,7 +224,7 @@ func TestSMT_UpdateElement1(t *testing.T) { } func TestSMT_AddSharedElement2(t *testing.T) { - s := NewSMT(nil) + s := NewSMT(nil, false) r1, err := s.InsertBI(big.NewInt(8), big.NewInt(2)) if err != nil { @@ -249,7 +249,7 @@ func TestSMT_AddSharedElement2(t *testing.T) { } func TestSMT_AddRemove128Elements(t *testing.T) { - s := NewSMT(nil) + s := NewSMT(nil, false) N := 128 var r *SMTResponse @@ -272,7 +272,7 @@ func TestSMT_AddRemove128Elements(t *testing.T) { } func TestSMT_MultipleInsert2(t *testing.T) { - s := NewSMT(nil) + s := NewSMT(nil, false) testCases := []struct { root *big.Int key utils.NodeKey diff --git a/smt/pkg/smt/witness_test.go b/smt/pkg/smt/witness_test.go index 8e17cf3967a..a46d928bcf5 100644 --- a/smt/pkg/smt/witness_test.go +++ b/smt/pkg/smt/witness_test.go @@ -7,9 +7,9 @@ import ( "fmt" "testing" - "github.com/holiman/uint256" libcommon "github.com/gateway-fm/cdk-erigon-lib/common" "github.com/gateway-fm/cdk-erigon-lib/kv/memdb" + "github.com/holiman/uint256" "github.com/ledgerwatch/erigon/chain" "github.com/ledgerwatch/erigon/core/state" "github.com/ledgerwatch/erigon/smt/pkg/db" @@ -58,7 +58,7 @@ func prepareSMT(t *testing.T) (*smt.SMT, *trie.RetainList) { memdb := db.NewMemDb() - smtTrie := smt.NewSMT(memdb) + smtTrie := smt.NewSMT(memdb, false) smtTrie.SetAccountState(contract.String(), balance.ToBig(), uint256.NewInt(1).ToBig()) smtTrie.SetContractBytecode(contract.String(), hex.EncodeToString(code)) diff --git a/smt/pkg/utils/util_test.go b/smt/pkg/utils/util_test.go index 715c4f7d0ec..a1b718b3a87 100644 --- a/smt/pkg/utils/util_test.go +++ b/smt/pkg/utils/util_test.go @@ -141,6 +141,13 @@ func TestScalarToArrayBig(t *testing.T) { } } +func BenchmarkScalarToArrayBig(b *testing.B) { + scalar := big.NewInt(0x1234567890ABCDEF) + for i := 0; i < b.N; i++ { + ScalarToArrayBig(scalar) + } +} + func TestArrayBigToScalar(t *testing.T) { scalar := big.NewInt(0x1234567890ABCDEF) diff --git a/smt/pkg/utils/utils.go b/smt/pkg/utils/utils.go index befe688a3da..d40363a1fb9 100644 --- a/smt/pkg/utils/utils.go +++ b/smt/pkg/utils/utils.go @@ -480,33 +480,31 @@ func ScalarToArrayBig12(scalar *big.Int) []*big.Int { return []*big.Int{r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, r10, r11} } -func ScalarToArrayBig(scalar *big.Int) []*big.Int { - scalar = new(big.Int).Set(scalar) - mask := new(big.Int) - mask.SetString("FFFFFFFF", 16) +var mask = big.NewInt(4294967295) +func ScalarToArrayBig(scalar *big.Int) []*big.Int { r0 := new(big.Int).And(scalar, mask) r1 := new(big.Int).Rsh(scalar, 32) - r1 = new(big.Int).And(r1, mask) + r1.And(r1, mask) r2 := new(big.Int).Rsh(scalar, 64) - r2 = new(big.Int).And(r2, mask) + r2.And(r2, mask) r3 := new(big.Int).Rsh(scalar, 96) - r3 = new(big.Int).And(r3, mask) + r3.And(r3, mask) r4 := new(big.Int).Rsh(scalar, 128) - r4 = new(big.Int).And(r4, mask) + r4.And(r4, mask) r5 := new(big.Int).Rsh(scalar, 160) - r5 = new(big.Int).And(r5, mask) + r5.And(r5, mask) r6 := new(big.Int).Rsh(scalar, 192) - r6 = new(big.Int).And(r6, mask) + r6.And(r6, mask) r7 := new(big.Int).Rsh(scalar, 224) - r7 = new(big.Int).And(r7, mask) + r7.And(r7, mask) return []*big.Int{r0, r1, r2, r3, r4, r5, r6, r7} } diff --git a/zk/stages/stage_interhashes.go b/zk/stages/stage_interhashes.go index e4f25df258e..c63e06e537d 100644 --- a/zk/stages/stage_interhashes.go +++ b/zk/stages/stage_interhashes.go @@ -127,7 +127,7 @@ func SpawnZkIntermediateHashesStage(s *stagedsync.StageState, u stagedsync.Unwin shouldRegenerate := to > s.BlockNumber && to-s.BlockNumber > cfg.zk.RebuildTreeAfter eridb := db2.NewEriDb(tx) - smt := smt.NewSMT(eridb) + smt := smt.NewSMT(eridb, false) if cfg.zk.IncrementTreeAlways { // increment only behaviour @@ -475,7 +475,7 @@ func unwindZkSMT(ctx context.Context, logPrefix string, from, to uint64, db kv.R defer log.Info(fmt.Sprintf("[%s] Unwind ended", logPrefix)) eridb := db2.NewEriDb(db) - dbSmt := smt.NewSMT(eridb) + dbSmt := smt.NewSMT(eridb, false) log.Info(fmt.Sprintf("[%s]", logPrefix), "last root", common.BigToHash(dbSmt.LastRoot())) diff --git a/zk/stages/stage_sequence_execute_utils.go b/zk/stages/stage_sequence_execute_utils.go index c7a6f2cb239..a93ae224d2b 100644 --- a/zk/stages/stage_sequence_execute_utils.go +++ b/zk/stages/stage_sequence_execute_utils.go @@ -180,7 +180,7 @@ func (sdb *stageDb) SetTx(tx kv.RwTx) { sdb.hermezDb = hermez_db.NewHermezDb(tx) sdb.eridb = db2.NewEriDb(tx) sdb.stateReader = state.NewPlainStateReader(tx) - sdb.smt = smtNs.NewSMT(sdb.eridb) + sdb.smt = smtNs.NewSMT(sdb.eridb, false) } type nextBatchL1Data struct { diff --git a/zk/tx/tx.go b/zk/tx/tx.go index 1b4025d25c4..9b0dfe072ba 100644 --- a/zk/tx/tx.go +++ b/zk/tx/tx.go @@ -504,6 +504,8 @@ func ComputeL2TxHash( return common.HexToHash(hashed), nil } +var re = regexp.MustCompile("^[0-9a-fA-F]*$") + func formatL2TxHashParam(param interface{}, paramLength int) (string, error) { var paramStr string @@ -560,11 +562,7 @@ func formatL2TxHashParam(param interface{}, paramLength int) (string, error) { paramStr = "0" + paramStr } - matched, err := regexp.MatchString("^[0-9a-fA-F]+$", paramStr) - if err != nil { - return "", err - } - if !matched { + if !re.MatchString(paramStr) { return "", fmt.Errorf("invalid hex string") } diff --git a/zk/tx/tx_test.go b/zk/tx/tx_test.go index 994e4bfb3db..c120f16abfc 100644 --- a/zk/tx/tx_test.go +++ b/zk/tx/tx_test.go @@ -318,6 +318,21 @@ func TestComputeL2TxHashScenarios(t *testing.T) { } +func BenchmarkComputeL2TxHashSt(b *testing.B) { + chainId := big.NewInt(2440) + nonce := uint64(87) + gasPrice := uint256.NewInt(493000000) + gasLimit := uint64(100000) + value := uint256.NewInt(100) + data := []byte{} + to := common.HexToAddress("0x5751D5b29dA14d5C334A9453cF04181f417aBe4c") + from := common.HexToAddress("0x5751D5b29dA14d5C334A9453cF04181f417aBe4c") + + for i := 0; i < b.N; i++ { + _, _ = ComputeL2TxHash(chainId, value, gasPrice, nonce, gasLimit, &to, &from, data) + } +} + type testCase struct { param interface{} paramLength int @@ -361,6 +376,12 @@ func TestFormatL2TxHashParam(t *testing.T) { } } +func BenchmarkFormatL2TxHashParam(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = formatL2TxHashParam(uint256.NewInt(1000), 8) + } +} + func Test_EncodeToBatchL2DataAndBack(t *testing.T) { toAddress := common.HexToAddress("0x1") tx := &types.LegacyTx{ diff --git a/zk/witness/witness.go b/zk/witness/witness.go index 3febd76b384..450b6ed3833 100644 --- a/zk/witness/witness.go +++ b/zk/witness/witness.go @@ -6,6 +6,8 @@ import ( "errors" "fmt" + "math/big" + libcommon "github.com/gateway-fm/cdk-erigon-lib/common" "github.com/gateway-fm/cdk-erigon-lib/common/datadir" "github.com/gateway-fm/cdk-erigon-lib/kv" @@ -27,11 +29,10 @@ import ( "github.com/ledgerwatch/erigon/turbo/trie" dstypes "github.com/ledgerwatch/erigon/zk/datastream/types" "github.com/ledgerwatch/erigon/zk/hermez_db" + "github.com/ledgerwatch/erigon/zk/l1_data" zkStages "github.com/ledgerwatch/erigon/zk/stages" zkUtils "github.com/ledgerwatch/erigon/zk/utils" "github.com/ledgerwatch/log/v3" - "github.com/ledgerwatch/erigon/zk/l1_data" - "math/big" ) var ( @@ -327,7 +328,7 @@ func (g *Generator) generateWitness(tx kv.Tx, ctx context.Context, blocks []*eri } eridb := db2.NewEriDb(batch) - smtTrie := smt.NewSMT(eridb) + smtTrie := smt.NewSMT(eridb, false) witness, err := smt.BuildWitness(smtTrie, rl, ctx) if err != nil {