From fab51758d222d6fbe69adac8496c96051431f09a Mon Sep 17 00:00:00 2001 From: BeniaminDrasovean Date: Mon, 3 Feb 2025 14:42:52 +0200 Subject: [PATCH] fix after review --- trie/branchNode.go | 16 ++++++++++------ trie/branchNode_test.go | 29 +++++++++++++---------------- trie/extensionNode.go | 14 +++++++++----- trie/extensionNode_test.go | 25 ++++++++++++------------- trie/interface.go | 8 +++++++- trie/leafNode.go | 6 +++--- trie/leafNode_test.go | 12 ++++-------- trie/patriciaMerkleTrie.go | 17 ++++++++++------- 8 files changed, 68 insertions(+), 59 deletions(-) diff --git a/trie/branchNode.go b/trie/branchNode.go index e8e3d0a994..b918f8278c 100644 --- a/trie/branchNode.go +++ b/trie/branchNode.go @@ -309,24 +309,28 @@ func (bn *branchNode) tryGet(key []byte, currentDepth uint32, db common.TrieStor return child.tryGet(key, currentDepth+1, db) } -func (bn *branchNode) getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, []byte, error) { +func (bn *branchNode) getNext(key []byte, db common.TrieStorageInteractor) (*nodeData, error) { if len(key) == 0 { - return nil, nil, nil, ErrValueTooShort + return nil, ErrValueTooShort } childPos := key[firstByte] if childPosOutOfRange(childPos) { - return nil, nil, nil, ErrChildPosOutOfRange + return nil, ErrChildPosOutOfRange } key = key[1:] if len(bn.EncodedChildren[childPos]) == 0 { - return nil, nil, nil, ErrNodeNotFound + return nil, ErrNodeNotFound } childNode, encodedNode, err := getNodeFromDBAndDecode(bn.EncodedChildren[childPos], db, bn.marsh, bn.hasher) if err != nil { - return nil, nil, nil, err + return nil, err } - return childNode, encodedNode, key, nil + return &nodeData{ + currentNode: childNode, + encodedNode: encodedNode, + hexKey: key, + }, nil } func (bn *branchNode) insert( diff --git a/trie/branchNode_test.go b/trie/branchNode_test.go index 37edc88096..f0454053e8 100644 --- a/trie/branchNode_test.go +++ b/trie/branchNode_test.go @@ -348,14 +348,15 @@ func TestBranchNode_getNext(t *testing.T) { key := append([]byte{childPos}, []byte("dog")...) db := testscommon.NewMemDbMock() bn.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db) - n, nodeBytes, key, err := bn.getNext(key, db) + data, err := bn.getNext(key, db) + assert.NotNil(t, data) h1, _ := encodeNodeAndGetHash(nextNode) - h2, _ := encodeNodeAndGetHash(n) + h2, _ := encodeNodeAndGetHash(data.currentNode) nextNodeBytes, _ := nextNode.getEncodedNode() - assert.Equal(t, nextNodeBytes, nodeBytes) + assert.Equal(t, nextNodeBytes, data.encodedNode) assert.Equal(t, h1, h2) - assert.Equal(t, []byte("dog"), key) + assert.Equal(t, []byte("dog"), data.hexKey) assert.Nil(t, err) } @@ -365,10 +366,8 @@ func TestBranchNode_getNextWrongKey(t *testing.T) { bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher()) key := []byte("dog") - n, nodeBytes, key, err := bn.getNext(key, nil) - assert.Nil(t, n) - assert.Nil(t, key) - assert.Nil(t, nodeBytes) + data, err := bn.getNext(key, nil) + assert.Nil(t, data) assert.Equal(t, ErrChildPosOutOfRange, err) } @@ -379,10 +378,8 @@ func TestBranchNode_getNextNilChild(t *testing.T) { nilChildPos := byte(4) key := append([]byte{nilChildPos}, []byte("dog")...) - n, nodeBytes, key, err := bn.getNext(key, nil) - assert.Nil(t, n) - assert.Nil(t, key) - assert.Nil(t, nodeBytes) + data, err := bn.getNext(key, nil) + assert.Nil(t, data) assert.Equal(t, ErrNodeNotFound, err) } @@ -463,8 +460,8 @@ func TestBranchNode_insertInStoredBnOnExistingPos(t *testing.T) { bn.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db) bnHash := bn.getHash() - ln, _, _, _ := bn.getNext(key, db) - lnHash := ln.getHash() + nd, _ := bn.getNext(key, db) + lnHash := nd.currentNode.getHash() expectedHashes := [][]byte{lnHash, bnHash} goRoutinesManager := getTestGoroutinesManager() @@ -591,8 +588,8 @@ func TestBranchNode_deleteFromStoredBn(t *testing.T) { bn.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db) bnHash := bn.getHash() - ln, _, _, _ := bn.getNext(lnKey, db) - lnHash := ln.getHash() + nd, _ := bn.getNext(lnKey, db) + lnHash := nd.currentNode.getHash() expectedHashes := [][]byte{lnHash, bnHash} goRoutinesManager := getTestGoroutinesManager() diff --git a/trie/extensionNode.go b/trie/extensionNode.go index 162fd4cb84..acf5ca9c26 100644 --- a/trie/extensionNode.go +++ b/trie/extensionNode.go @@ -246,22 +246,26 @@ func (en *extensionNode) tryGet(key []byte, currentDepth uint32, db common.TrieS return child.tryGet(key, currentDepth+1, db) } -func (en *extensionNode) getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, []byte, error) { +func (en *extensionNode) getNext(key []byte, db common.TrieStorageInteractor) (*nodeData, error) { keyTooShort := len(key) < len(en.Key) if keyTooShort { - return nil, nil, nil, ErrNodeNotFound + return nil, ErrNodeNotFound } keysDontMatch := !bytes.Equal(en.Key, key[:len(en.Key)]) if keysDontMatch { - return nil, nil, nil, ErrNodeNotFound + return nil, ErrNodeNotFound } child, encodedChild, err := getNodeFromDBAndDecode(en.EncodedChild, db, en.marsh, en.hasher) if err != nil { - return nil, nil, nil, err + return nil, err } key = key[len(en.Key):] - return child, encodedChild, key, nil + return &nodeData{ + currentNode: child, + encodedNode: encodedChild, + hexKey: key, + }, nil } func (en *extensionNode) insert( diff --git a/trie/extensionNode_test.go b/trie/extensionNode_test.go index cb03047d08..f67c99a825 100644 --- a/trie/extensionNode_test.go +++ b/trie/extensionNode_test.go @@ -284,11 +284,12 @@ func TestExtensionNode_getNext(t *testing.T) { key := append(enKey, bnKey...) key = append(key, lnKey...) - n, nodeBytes, newKey, err := en.getNext(key, db) + data, err := en.getNext(key, db) child, childBytes, _ := getNodeFromDBAndDecode(en.EncodedChild, db, en.marsh, en.hasher) - assert.Equal(t, childBytes, nodeBytes) - assert.Equal(t, child, n) - assert.Equal(t, key[1:], newKey) + assert.NotNil(t, data) + assert.Equal(t, childBytes, data.encodedNode) + assert.Equal(t, child, data.currentNode) + assert.Equal(t, key[1:], data.hexKey) assert.Nil(t, err) } @@ -300,10 +301,8 @@ func TestExtensionNode_getNextWrongKey(t *testing.T) { lnKey := []byte("dog") key := append(bnKey, lnKey...) - n, nodeBytes, key, err := en.getNext(key, nil) - assert.Nil(t, n) - assert.Nil(t, key) - assert.Nil(t, nodeBytes) + data, err := en.getNext(key, nil) + assert.Nil(t, data) assert.Equal(t, ErrNodeNotFound, err) } @@ -356,8 +355,8 @@ func TestExtensionNode_insertInStoredEnSameKey(t *testing.T) { en.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db) enHash := en.getHash() - bn, _, _, _ := en.getNext(enKey, db) - bnHash := bn.getHash() + nd, _ := en.getNext(enKey, db) + bnHash := nd.currentNode.getHash() expectedHashes := [][]byte{bnHash, enHash} goRoutinesManager := getTestGoroutinesManager() @@ -465,9 +464,9 @@ func TestExtensionNode_deleteFromStoredEn(t *testing.T) { en.setHash(getTestGoroutinesManager()) en.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db) - bn, _, key, _ := en.getNext(key, db) - ln, _, _, _ := bn.getNext(key, db) - expectedHashes := [][]byte{ln.getHash(), bn.getHash(), en.getHash()} + bnData, _ := en.getNext(key, db) + lnData, _ := bnData.currentNode.getNext(bnData.hexKey, db) + expectedHashes := [][]byte{lnData.currentNode.getHash(), bnData.currentNode.getHash(), en.getHash()} data := []core.TrieData{{Key: lnPathKey}} goRoutinesManager := getTestGoroutinesManager() diff --git a/trie/interface.go b/trie/interface.go index 8d0bbc88ed..2b291f074a 100644 --- a/trie/interface.go +++ b/trie/interface.go @@ -12,6 +12,12 @@ import ( vmcommon "github.com/multiversx/mx-chain-vm-common-go" ) +type nodeData struct { + currentNode node + encodedNode []byte + hexKey []byte +} + type baseTrieNode interface { getHash() []byte setGivenHash([]byte) @@ -28,7 +34,7 @@ type node interface { setHash(goRoutinesManager common.TrieGoroutinesManager) getEncodedNode() ([]byte, error) tryGet(key []byte, depth uint32, db common.TrieStorageInteractor) ([]byte, uint32, error) - getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, []byte, error) + getNext(key []byte, db common.TrieStorageInteractor) (*nodeData, error) insert(newData []core.TrieData, goRoutinesManager common.TrieGoroutinesManager, modifiedHashes common.AtomicBytesSlice, db common.TrieStorageInteractor) node delete(data []core.TrieData, goRoutinesManager common.TrieGoroutinesManager, modifiedHashes common.AtomicBytesSlice, db common.TrieStorageInteractor) (bool, node) reduceNode(pos int, db common.TrieStorageInteractor) (node, bool, error) diff --git a/trie/leafNode.go b/trie/leafNode.go index 1a5ce2462f..60cb79b1a1 100644 --- a/trie/leafNode.go +++ b/trie/leafNode.go @@ -144,11 +144,11 @@ func (ln *leafNode) tryGet(key []byte, currentDepth uint32, _ common.TrieStorage return nil, currentDepth, nil } -func (ln *leafNode) getNext(key []byte, _ common.TrieStorageInteractor) (node, []byte, []byte, error) { +func (ln *leafNode) getNext(key []byte, _ common.TrieStorageInteractor) (*nodeData, error) { if bytes.Equal(key, ln.Key) { - return nil, nil, nil, nil + return nil, nil } - return nil, nil, nil, ErrNodeNotFound + return nil, ErrNodeNotFound } func (ln *leafNode) insert( diff --git a/trie/leafNode_test.go b/trie/leafNode_test.go index 77b0018426..ab505f154e 100644 --- a/trie/leafNode_test.go +++ b/trie/leafNode_test.go @@ -165,10 +165,8 @@ func TestLeafNode_getNext(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) key := []byte("dog") - n, nodeBytes, key, err := ln.getNext(key, nil) - assert.Nil(t, n) - assert.Nil(t, key) - assert.Nil(t, nodeBytes) + data, err := ln.getNext(key, nil) + assert.Nil(t, data) assert.Nil(t, err) } @@ -178,10 +176,8 @@ func TestLeafNode_getNextWrongKey(t *testing.T) { ln := getLn(getTestMarshalizerAndHasher()) wrongKey := append([]byte{2}, []byte("dog")...) - n, nodeBytes, key, err := ln.getNext(wrongKey, nil) - assert.Nil(t, n) - assert.Nil(t, key) - assert.Nil(t, nodeBytes) + data, err := ln.getNext(wrongKey, nil) + assert.Nil(t, data) assert.Equal(t, ErrNodeNotFound, err) } diff --git a/trie/patriciaMerkleTrie.go b/trie/patriciaMerkleTrie.go index d9d8c162e1..161010752d 100644 --- a/trie/patriciaMerkleTrie.go +++ b/trie/patriciaMerkleTrie.go @@ -648,19 +648,22 @@ func (tr *patriciaMerkleTrie) GetProof(key []byte, rootHash []byte) ([][]byte, [ var proof [][]byte var errGet error - hexKey := keyBytesToHex(key) - currentNode := rootNode + + data := &nodeData{ + currentNode: rootNode, + encodedNode: encodedNode, + hexKey: keyBytesToHex(key), + } for { - proof = append(proof, encodedNode) - value := currentNode.getValue() + proof = append(proof, data.encodedNode) + value := data.currentNode.getValue() - currentNode, encodedNode, hexKey, errGet = currentNode.getNext(hexKey, tr.trieStorage) + data, errGet = data.currentNode.getNext(data.hexKey, tr.trieStorage) if errGet != nil { return nil, nil, errGet } - - if currentNode == nil { + if data == nil { return proof, value, nil } }