Skip to content

Commit

Permalink
Merge pull request #6734 from multiversx/get-proof-refactor
Browse files Browse the repository at this point in the history
refactor GetProof func and add some missing unit tests
  • Loading branch information
BeniaminDrasovean authored Feb 3, 2025
2 parents 778c10c + fab5175 commit 7dd3353
Show file tree
Hide file tree
Showing 10 changed files with 211 additions and 68 deletions.
22 changes: 13 additions & 9 deletions trie/branchNode.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,24 +309,28 @@ func (bn *branchNode) tryGet(key []byte, currentDepth uint32, db common.TrieStor
return child.tryGet(key, currentDepth+1, db)
}

func (bn *branchNode) getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, error) {
func (bn *branchNode) getNext(key []byte, db common.TrieStorageInteractor) (*nodeData, error) {
if len(key) == 0 {
return nil, nil, ErrValueTooShort
return nil, ErrValueTooShort
}
childPos := key[firstByte]
if childPosOutOfRange(childPos) {
return nil, nil, ErrChildPosOutOfRange
return nil, ErrChildPosOutOfRange
}
key = key[1:]
_, err := bn.resolveIfCollapsed(childPos, db)
if len(bn.EncodedChildren[childPos]) == 0 {
return nil, ErrNodeNotFound
}
childNode, encodedNode, err := getNodeFromDBAndDecode(bn.EncodedChildren[childPos], db, bn.marsh, bn.hasher)
if err != nil {
return nil, nil, err
return nil, err
}

if bn.children[childPos] == nil {
return nil, nil, ErrNodeNotFound
}
return bn.children[childPos], key, nil
return &nodeData{
currentNode: childNode,
encodedNode: encodedNode,
hexKey: key,
}, nil
}

func (bn *branchNode) insert(
Expand Down
30 changes: 16 additions & 14 deletions trie/branchNode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -346,13 +346,17 @@ func TestBranchNode_getNext(t *testing.T) {
nextNode, _ := newLeafNode(getTrieDataWithDefaultVersion("dog", "dog"), bn.marsh, bn.hasher)
childPos := byte(2)
key := append([]byte{childPos}, []byte("dog")...)

n, key, err := bn.getNext(key, nil)
db := testscommon.NewMemDbMock()
bn.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db)
data, err := bn.getNext(key, db)
assert.NotNil(t, data)

h1, _ := encodeNodeAndGetHash(nextNode)
h2, _ := encodeNodeAndGetHash(n)
h2, _ := encodeNodeAndGetHash(data.currentNode)
nextNodeBytes, _ := nextNode.getEncodedNode()
assert.Equal(t, nextNodeBytes, data.encodedNode)
assert.Equal(t, h1, h2)
assert.Equal(t, []byte("dog"), key)
assert.Equal(t, []byte("dog"), data.hexKey)
assert.Nil(t, err)
}

Expand All @@ -362,9 +366,8 @@ func TestBranchNode_getNextWrongKey(t *testing.T) {
bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher())
key := []byte("dog")

n, key, err := bn.getNext(key, nil)
assert.Nil(t, n)
assert.Nil(t, key)
data, err := bn.getNext(key, nil)
assert.Nil(t, data)
assert.Equal(t, ErrChildPosOutOfRange, err)
}

Expand All @@ -375,9 +378,8 @@ func TestBranchNode_getNextNilChild(t *testing.T) {
nilChildPos := byte(4)
key := append([]byte{nilChildPos}, []byte("dog")...)

n, key, err := bn.getNext(key, nil)
assert.Nil(t, n)
assert.Nil(t, key)
data, err := bn.getNext(key, nil)
assert.Nil(t, data)
assert.Equal(t, ErrNodeNotFound, err)
}

Expand Down Expand Up @@ -458,8 +460,8 @@ func TestBranchNode_insertInStoredBnOnExistingPos(t *testing.T) {

bn.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db)
bnHash := bn.getHash()
ln, _, _ := bn.getNext(key, db)
lnHash := ln.getHash()
nd, _ := bn.getNext(key, db)
lnHash := nd.currentNode.getHash()
expectedHashes := [][]byte{lnHash, bnHash}

goRoutinesManager := getTestGoroutinesManager()
Expand Down Expand Up @@ -586,8 +588,8 @@ func TestBranchNode_deleteFromStoredBn(t *testing.T) {

bn.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db)
bnHash := bn.getHash()
ln, _, _ := bn.getNext(lnKey, db)
lnHash := ln.getHash()
nd, _ := bn.getNext(lnKey, db)
lnHash := nd.currentNode.getHash()
expectedHashes := [][]byte{lnHash, bnHash}

goRoutinesManager := getTestGoroutinesManager()
Expand Down
16 changes: 10 additions & 6 deletions trie/extensionNode.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,22 +246,26 @@ func (en *extensionNode) tryGet(key []byte, currentDepth uint32, db common.TrieS
return child.tryGet(key, currentDepth+1, db)
}

func (en *extensionNode) getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, error) {
func (en *extensionNode) getNext(key []byte, db common.TrieStorageInteractor) (*nodeData, error) {
keyTooShort := len(key) < len(en.Key)
if keyTooShort {
return nil, nil, ErrNodeNotFound
return nil, ErrNodeNotFound
}
keysDontMatch := !bytes.Equal(en.Key, key[:len(en.Key)])
if keysDontMatch {
return nil, nil, ErrNodeNotFound
return nil, ErrNodeNotFound
}
childNode, err := en.resolveIfCollapsed(db)
child, encodedChild, err := getNodeFromDBAndDecode(en.EncodedChild, db, en.marsh, en.hasher)
if err != nil {
return nil, nil, err
return nil, err
}

key = key[len(en.Key):]
return childNode, key, nil
return &nodeData{
currentNode: child,
encodedNode: encodedChild,
hexKey: key,
}, nil
}

func (en *extensionNode) insert(
Expand Down
27 changes: 15 additions & 12 deletions trie/extensionNode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,17 +275,21 @@ func TestExtensionNode_getNext(t *testing.T) {
t.Parallel()

en, _ := getEnAndCollapsedEn()
nextNode, _ := getBnAndCollapsedBn(en.marsh, en.hasher)
db := testscommon.NewMemDbMock()
en.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db)

enKey := []byte{100}
bnKey := []byte{2}
lnKey := []byte("dog")
key := append(enKey, bnKey...)
key = append(key, lnKey...)

n, newKey, err := en.getNext(key, nil)
assert.Equal(t, nextNode, n)
assert.Equal(t, key[1:], newKey)
data, err := en.getNext(key, db)
child, childBytes, _ := getNodeFromDBAndDecode(en.EncodedChild, db, en.marsh, en.hasher)
assert.NotNil(t, data)
assert.Equal(t, childBytes, data.encodedNode)
assert.Equal(t, child, data.currentNode)
assert.Equal(t, key[1:], data.hexKey)
assert.Nil(t, err)
}

Expand All @@ -297,9 +301,8 @@ func TestExtensionNode_getNextWrongKey(t *testing.T) {
lnKey := []byte("dog")
key := append(bnKey, lnKey...)

n, key, err := en.getNext(key, nil)
assert.Nil(t, n)
assert.Nil(t, key)
data, err := en.getNext(key, nil)
assert.Nil(t, data)
assert.Equal(t, ErrNodeNotFound, err)
}

Expand Down Expand Up @@ -352,8 +355,8 @@ func TestExtensionNode_insertInStoredEnSameKey(t *testing.T) {

en.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db)
enHash := en.getHash()
bn, _, _ := en.getNext(enKey, db)
bnHash := bn.getHash()
nd, _ := en.getNext(enKey, db)
bnHash := nd.currentNode.getHash()
expectedHashes := [][]byte{bnHash, enHash}

goRoutinesManager := getTestGoroutinesManager()
Expand Down Expand Up @@ -461,9 +464,9 @@ func TestExtensionNode_deleteFromStoredEn(t *testing.T) {
en.setHash(getTestGoroutinesManager())

en.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db)
bn, key, _ := en.getNext(key, db)
ln, _, _ := bn.getNext(key, db)
expectedHashes := [][]byte{ln.getHash(), bn.getHash(), en.getHash()}
bnData, _ := en.getNext(key, db)
lnData, _ := bnData.currentNode.getNext(bnData.hexKey, db)
expectedHashes := [][]byte{lnData.currentNode.getHash(), bnData.currentNode.getHash(), en.getHash()}
data := []core.TrieData{{Key: lnPathKey}}

goRoutinesManager := getTestGoroutinesManager()
Expand Down
8 changes: 7 additions & 1 deletion trie/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ import (
vmcommon "github.com/multiversx/mx-chain-vm-common-go"
)

type nodeData struct {
currentNode node
encodedNode []byte
hexKey []byte
}

type baseTrieNode interface {
getHash() []byte
setGivenHash([]byte)
Expand All @@ -28,7 +34,7 @@ type node interface {
setHash(goRoutinesManager common.TrieGoroutinesManager)
getEncodedNode() ([]byte, error)
tryGet(key []byte, depth uint32, db common.TrieStorageInteractor) ([]byte, uint32, error)
getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, error)
getNext(key []byte, db common.TrieStorageInteractor) (*nodeData, error)
insert(newData []core.TrieData, goRoutinesManager common.TrieGoroutinesManager, modifiedHashes common.AtomicBytesSlice, db common.TrieStorageInteractor) node
delete(data []core.TrieData, goRoutinesManager common.TrieGoroutinesManager, modifiedHashes common.AtomicBytesSlice, db common.TrieStorageInteractor) (bool, node)
reduceNode(pos int, db common.TrieStorageInteractor) (node, bool, error)
Expand Down
6 changes: 3 additions & 3 deletions trie/leafNode.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ func (ln *leafNode) tryGet(key []byte, currentDepth uint32, _ common.TrieStorage
return nil, currentDepth, nil
}

func (ln *leafNode) getNext(key []byte, _ common.TrieStorageInteractor) (node, []byte, error) {
func (ln *leafNode) getNext(key []byte, _ common.TrieStorageInteractor) (*nodeData, error) {
if bytes.Equal(key, ln.Key) {
return nil, nil, nil
return nil, nil
}
return nil, nil, ErrNodeNotFound
return nil, ErrNodeNotFound
}

func (ln *leafNode) insert(
Expand Down
10 changes: 4 additions & 6 deletions trie/leafNode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,9 +165,8 @@ func TestLeafNode_getNext(t *testing.T) {
ln := getLn(getTestMarshalizerAndHasher())
key := []byte("dog")

n, key, err := ln.getNext(key, nil)
assert.Nil(t, n)
assert.Nil(t, key)
data, err := ln.getNext(key, nil)
assert.Nil(t, data)
assert.Nil(t, err)
}

Expand All @@ -177,9 +176,8 @@ func TestLeafNode_getNextWrongKey(t *testing.T) {
ln := getLn(getTestMarshalizerAndHasher())
wrongKey := append([]byte{2}, []byte("dog")...)

n, key, err := ln.getNext(wrongKey, nil)
assert.Nil(t, n)
assert.Nil(t, key)
data, err := ln.getNext(wrongKey, nil)
assert.Nil(t, data)
assert.Equal(t, ErrNodeNotFound, err)
}

Expand Down
28 changes: 13 additions & 15 deletions trie/patriciaMerkleTrie.go
Original file line number Diff line number Diff line change
Expand Up @@ -637,35 +637,33 @@ func logMapWithTrace(message string, paramName string, hashes common.ModifiedHas

// GetProof computes a Merkle proof for the node that is present at the given key
func (tr *patriciaMerkleTrie) GetProof(key []byte, rootHash []byte) ([][]byte, []byte, error) {
//TODO refactor this function to avoid encoding the node after it is retrieved from the DB.
// The encoded node is actually the value from db, thus we can use the retrieved value directly
if len(key) == 0 || bytes.Equal(rootHash, common.EmptyTrieHash) {
if common.IsEmptyTrie(rootHash) {
return nil, nil, ErrNilNode
}

rootNode, _, err := getNodeFromDBAndDecode(rootHash, tr.trieStorage, tr.marshalizer, tr.hasher)
rootNode, encodedNode, err := getNodeFromDBAndDecode(rootHash, tr.trieStorage, tr.marshalizer, tr.hasher)
if err != nil {
return nil, nil, fmt.Errorf("trie get proof error: %w", err)
}

var proof [][]byte
hexKey := keyBytesToHex(key)
currentNode := rootNode
var errGet error

data := &nodeData{
currentNode: rootNode,
encodedNode: encodedNode,
hexKey: keyBytesToHex(key),
}

for {
encodedNode, errGet := currentNode.getEncodedNode()
if errGet != nil {
return nil, nil, errGet
}
proof = append(proof, encodedNode)
value := currentNode.getValue()
proof = append(proof, data.encodedNode)
value := data.currentNode.getValue()

currentNode, hexKey, errGet = currentNode.getNext(hexKey, tr.trieStorage)
data, errGet = data.currentNode.getNext(data.hexKey, tr.trieStorage)
if errGet != nil {
return nil, nil, errGet
}

if currentNode == nil {
if data == nil {
return proof, value, nil
}
}
Expand Down
2 changes: 0 additions & 2 deletions trie/rootManager.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ import (
"github.com/multiversx/mx-chain-core-go/core/check"
)

// TODO: add unit tests

type rootManager struct {
root node
oldHashes [][]byte
Expand Down
Loading

0 comments on commit 7dd3353

Please sign in to comment.