Skip to content

Commit

Permalink
tiny smt opt (erigontech#1196)
Browse files Browse the repository at this point in the history
* tiny smt opt

* restore original test's datasize
  • Loading branch information
kstoykov authored Sep 19, 2024
1 parent 1697e62 commit f2a86f1
Show file tree
Hide file tree
Showing 23 changed files with 968 additions and 1,067 deletions.
28 changes: 5 additions & 23 deletions cmd/rpcdaemon/commands/zkevm_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1623,25 +1623,10 @@ func (zkapi *ZkEvmAPIImpl) GetProof(ctx context.Context, address common.Address,
return nil, err
}

balanceKey, err := smtUtils.KeyEthAddrBalance(address.String())
if err != nil {
return nil, err
}

nonceKey, err := smtUtils.KeyEthAddrNonce(address.String())
if err != nil {
return nil, err
}

codeHashKey, err := smtUtils.KeyContractCode(address.String())
if err != nil {
return nil, err
}

codeLengthKey, err := smtUtils.KeyContractLength(address.String())
if err != nil {
return nil, err
}
balanceKey := smtUtils.KeyEthAddrBalance(address.String())
nonceKey := smtUtils.KeyEthAddrNonce(address.String())
codeHashKey := smtUtils.KeyContractCode(address.String())
codeLengthKey := smtUtils.KeyContractLength(address.String())

balanceProofs := smt.FilterProofs(proofs, balanceKey)
balanceBytes, err := smt.VerifyAndGetVal(stateRootNode, balanceProofs, balanceKey)
Expand Down Expand Up @@ -1687,10 +1672,7 @@ func (zkapi *ZkEvmAPIImpl) GetProof(ctx context.Context, address common.Address,

addressArrayBig := smtUtils.ScalarToArrayBig(smtUtils.ConvertHexToBigInt(address.String()))
for _, k := range storageKeys {
storageKey, err := smtUtils.KeyContractStorage(addressArrayBig, k.String())
if err != nil {
return nil, err
}
storageKey := smtUtils.KeyContractStorage(addressArrayBig, k.String())
storageProofs := smt.FilterProofs(proofs, storageKey)

valueBytes, err := smt.VerifyAndGetVal(stateRootNode, storageProofs, storageKey)
Expand Down
4 changes: 2 additions & 2 deletions core/state/intra_block_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import (

"encoding/hex"

"github.com/holiman/uint256"
libcommon "github.com/gateway-fm/cdk-erigon-lib/common"
types2 "github.com/gateway-fm/cdk-erigon-lib/types"
"github.com/holiman/uint256"
"github.com/ledgerwatch/erigon/chain"
"github.com/ledgerwatch/erigon/common/u256"
"github.com/ledgerwatch/erigon/core/types"
Expand Down Expand Up @@ -362,7 +362,7 @@ func (sdb *IntraBlockState) SetCode(addr libcommon.Address, code []byte) {
return
}

hashedBytecode, _ := utils.HashContractBytecode(hex.EncodeToString(code))
hashedBytecode := utils.HashContractBytecode(hex.EncodeToString(code))
stateObject.SetCode(libcommon.HexToHash(hashedBytecode), code)
}
}
Expand Down
34 changes: 5 additions & 29 deletions core/state/trie_db.go
Original file line number Diff line number Diff line change
Expand Up @@ -912,48 +912,24 @@ func (tds *TrieDbState) ResolveSMTRetainList() (*trie.RetainList, error) {
for _, addrHash := range accountTouches {
addr := common.BytesToAddress(tds.preimageMap[addrHash]).String()

nonceKey, err := utils.KeyEthAddrNonce(addr)

if err != nil {
return nil, err
}

nonceKey := utils.KeyEthAddrNonce(addr)
keys = append(keys, nonceKey.GetPath())

balanceKey, err := utils.KeyEthAddrBalance(addr)

if err != nil {
return nil, err
}

balanceKey := utils.KeyEthAddrBalance(addr)
keys = append(keys, balanceKey.GetPath())

codeKey, err := utils.KeyContractCode(addr)

if err != nil {
return nil, err
}

codeKey := utils.KeyContractCode(addr)
keys = append(keys, codeKey.GetPath())

codeLengthKey, err := utils.KeyContractLength(addr)

if err != nil {
return nil, err
}

codeLengthKey := utils.KeyContractLength(addr)
keys = append(keys, codeLengthKey.GetPath())
}

getSMTPath := func(ethAddr string, key string) ([]int, error) {
a := utils.ConvertHexToBigInt(ethAddr)
addr := utils.ScalarToArrayBig(a)

storageKey, err := utils.KeyContractStorage(addr, key)

if err != nil {
return nil, err
}
storageKey := utils.KeyContractStorage(addr, key)

return storageKey.GetPath(), nil
}
Expand Down
6 changes: 1 addition & 5 deletions smt/pkg/blockinfo/block_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,7 @@ func (b *BlockInfoTree) GenerateBlockTxKeysVals(

logToEncode := "0x" + hex.EncodeToString(rLog.Data) + reducedTopics

hash, err := utils.HashContractBytecode(logToEncode)
if err != nil {
return nil, nil, err
}

hash := utils.HashContractBytecode(logToEncode)
logEncodedBig := utils.ConvertHexToBigInt(hash)
key, val, err = generateTxLog(txIndexBig, big.NewInt(logIndex), logEncodedBig)
if err != nil {
Expand Down
10 changes: 2 additions & 8 deletions smt/pkg/blockinfo/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,8 @@ func KeyTxLogs(txIndex, logIndex *big.Int) (*utils.NodeKey, error) {
return nil, err
}

hk0, err := utils.Hash(lia.ToUintArray(), utils.BranchCapacity)
if err != nil {
return nil, err
}
hkRes, err := utils.Hash(key1.ToUintArray(), hk0)
if err != nil {
return nil, err
}
hk0 := utils.Hash(lia.ToUintArray(), utils.BranchCapacity)
hkRes := utils.Hash(key1.ToUintArray(), hk0)

return &utils.NodeKey{hkRes[0], hkRes[1], hkRes[2], hkRes[3]}, nil
}
Expand Down
6 changes: 1 addition & 5 deletions smt/pkg/db/mem-db.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,7 @@ func (m *MemDb) AddCode(code []byte) error {
m.lock.Lock() // Lock for writing
defer m.lock.Unlock() // Make sure to unlock when done

codeHash, err := utils.HashContractBytecode(hex.EncodeToString(code))
if err != nil {
return err
}

codeHash := utils.HashContractBytecode(hex.EncodeToString(code))
m.DbCode[codeHash] = code
return nil
}
Expand Down
77 changes: 21 additions & 56 deletions smt/pkg/smt/entity_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,29 @@ import (
)

func (s *SMT) SetAccountState(ethAddr string, balance, nonce *big.Int) (*big.Int, error) {
keyBalance, err := utils.KeyEthAddrBalance(ethAddr)
if err != nil {
return nil, err
}
keyNonce, err := utils.KeyEthAddrNonce(ethAddr)
if err != nil {
return nil, err
}
keyBalance := utils.KeyEthAddrBalance(ethAddr)
keyNonce := utils.KeyEthAddrNonce(ethAddr)

_, err = s.InsertKA(keyBalance, balance)
if err != nil {
if _, err := s.InsertKA(keyBalance, balance); err != nil {
return nil, err
}

ks := utils.EncodeKeySource(utils.KEY_BALANCE, utils.ConvertHexToAddress(ethAddr), common.Hash{})
err = s.Db.InsertKeySource(keyBalance, ks)
if err != nil {
if err := s.Db.InsertKeySource(keyBalance, ks); err != nil {
return nil, err
}

auxRes, err := s.InsertKA(keyNonce, nonce)

if err != nil {
return nil, err
}

ks = utils.EncodeKeySource(utils.KEY_NONCE, utils.ConvertHexToAddress(ethAddr), common.Hash{})
err = s.Db.InsertKeySource(keyNonce, ks)
if err != nil {
if err := s.Db.InsertKeySource(keyNonce, ks); err != nil {
return nil, err
}

return auxRes.NewRootScalar.ToBigInt(), err
return auxRes.NewRootScalar.ToBigInt(), nil
}

func (s *SMT) SetAccountStorage(addr libcommon.Address, acc *accounts.Account) error {
Expand All @@ -62,14 +52,8 @@ func (s *SMT) SetAccountStorage(addr libcommon.Address, acc *accounts.Account) e
}

func (s *SMT) SetContractBytecode(ethAddr string, bytecode string) error {
keyContractCode, err := utils.KeyContractCode(ethAddr)
if err != nil {
return err
}
keyContractLength, err := utils.KeyContractLength(ethAddr)
if err != nil {
return err
}
keyContractCode := utils.KeyContractCode(ethAddr)
keyContractLength := utils.KeyContractLength(ethAddr)

bi, bytecodeLength, err := convertBytecodeToBigInt(bytecode)
if err != nil {
Expand Down Expand Up @@ -205,6 +189,7 @@ func (s *SMT) SetContractStorage(ethAddr string, storage map[string]string, prog

func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[libcommon.Address]*accounts.Account, codeChanges map[libcommon.Address]string, storageChanges map[libcommon.Address]map[string]string) ([]*utils.NodeKey, []*utils.NodeValue8, error) {
var isDelete bool
var err error

storageChangesInitialCapacity := 0
for _, storage := range storageChanges {
Expand All @@ -222,14 +207,8 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l
default:
}
ethAddr := addr.String()
keyBalance, err := utils.KeyEthAddrBalance(ethAddr)
if err != nil {
return nil, nil, err
}
keyNonce, err := utils.KeyEthAddrNonce(ethAddr)
if err != nil {
return nil, nil, err
}
keyBalance := utils.KeyEthAddrBalance(ethAddr)
keyNonce := utils.KeyEthAddrNonce(ethAddr)

balance := big.NewInt(0)
nonce := big.NewInt(0)
Expand Down Expand Up @@ -276,14 +255,8 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l
}

ethAddr := addr.String()
keyContractCode, err := utils.KeyContractCode(ethAddr)
if err != nil {
return nil, nil, err
}
keyContractLength, err := utils.KeyContractLength(ethAddr)
if err != nil {
return nil, nil, err
}
keyContractCode := utils.KeyContractCode(ethAddr)
keyContractLength := utils.KeyContractLength(ethAddr)

bi, bytecodeLength, err := convertBytecodeToBigInt(code)
if err != nil {
Expand Down Expand Up @@ -330,11 +303,7 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l
ethAddrBigIngArray := utils.ScalarToArrayBig(ethAddrBigInt)

for k, v := range storage {
keyStoragePosition, err := utils.KeyContractStorage(ethAddrBigIngArray, k)
if err != nil {
return nil, nil, err
}

keyStoragePosition := utils.KeyContractStorage(ethAddrBigIngArray, k)
valueBigInt := convertStrintToBigInt(v)
keysBatchStorage = append(keysBatchStorage, &keyStoragePosition)
if valuesBatchStorage, isDelete, err = appendToValuesBatchStorageBigInt(valuesBatchStorage, valueBigInt); err != nil {
Expand All @@ -355,8 +324,11 @@ func (s *SMT) SetStorage(ctx context.Context, logPrefix string, accChanges map[l
}

insertBatchCfg := NewInsertBatchConfig(ctx, logPrefix, true)
_, err := s.InsertBatch(insertBatchCfg, keysBatchStorage, valuesBatchStorage, nil, nil)
return keysBatchStorage, valuesBatchStorage, err
if _, err = s.InsertBatch(insertBatchCfg, keysBatchStorage, valuesBatchStorage, nil, nil); err != nil {
return nil, nil, err
}

return keysBatchStorage, valuesBatchStorage, nil
}

func (s *SMT) InsertKeySource(nodeKey *utils.NodeKey, key int, accountAddr *libcommon.Address, storagePosition *libcommon.Hash) error {
Expand All @@ -377,10 +349,7 @@ func calcHashVal(v string) (*utils.NodeValue8, [4]uint64, error) {
return nil, [4]uint64{}, err
}

h, err := utils.Hash(value.ToUintArray(), utils.BranchCapacity)
if err != nil {
return nil, [4]uint64{}, err
}
h := utils.Hash(value.ToUintArray(), utils.BranchCapacity)

return value, h, nil
}
Expand All @@ -405,12 +374,8 @@ func appendToValuesBatchStorageBigInt(valuesBatchStorage []*utils.NodeValue8, va
}

func convertBytecodeToBigInt(bytecode string) (*big.Int, int, error) {
hashedBytecode, err := utils.HashContractBytecode(bytecode)
if err != nil {
return nil, 0, err
}

var parsedBytecode string
hashedBytecode := utils.HashContractBytecode(bytecode)

if strings.HasPrefix(bytecode, "0x") {
parsedBytecode = bytecode[2:]
Expand Down
6 changes: 1 addition & 5 deletions smt/pkg/smt/entity_storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,7 @@ func Test_SetContractBytecode_HashBytecode(t *testing.T) {

expected := "0x9257c9a31308a7cb046aba1a95679dd7e3ad695b6900e84a6470b401b1ea416e"

hashedBytecode, err := utils.HashContractBytecode(byteCode)
if err != nil {
t.Errorf("setContractBytecode failed: %v", err)
}

hashedBytecode := utils.HashContractBytecode(byteCode)
if hashedBytecode != expected {
t.Errorf("setContractBytecode failed: expected %v, got %v", expected, hashedBytecode)
}
Expand Down
17 changes: 2 additions & 15 deletions smt/pkg/smt/proof.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,11 +127,8 @@ func VerifyAndGetVal(stateRoot utils.NodeKey, proof []hexutility.Bytes, key util
}

path := key.GetPath()

curRoot := stateRoot

foundValue := false

for i := 0; i < len(proof); i++ {
isFinalNode := len(proof[i]) == 65

Expand All @@ -147,12 +144,7 @@ func VerifyAndGetVal(stateRoot utils.NodeKey, proof []hexutility.Bytes, key util
leftChildNode := [4]uint64{leftChild[0], leftChild[1], leftChild[2], leftChild[3]}
rightChildNode := [4]uint64{rightChild[0], rightChild[1], rightChild[2], rightChild[3]}

h, err := utils.Hash(utils.ConcatArrays4(leftChildNode, rightChildNode), capacity)

if err != nil {
return nil, err
}

h := utils.Hash(utils.ConcatArrays4(leftChildNode, rightChildNode), capacity)
if curRoot != h {
return nil, fmt.Errorf("root mismatch at level %d, expected %d, got %d", i, curRoot, h)
}
Expand Down Expand Up @@ -193,12 +185,7 @@ func VerifyAndGetVal(stateRoot utils.NodeKey, proof []hexutility.Bytes, key util
return nil, err
}

h, err := utils.Hash(nodeValue.ToUintArray(), utils.BranchCapacity)

if err != nil {
return nil, err
}

h := utils.Hash(nodeValue.ToUintArray(), utils.BranchCapacity)
if h != curRoot {
return nil, fmt.Errorf("root mismatch at level %d, expected %d, got %d", len(proof)-1, curRoot, h)
}
Expand Down
Loading

0 comments on commit f2a86f1

Please sign in to comment.