Skip to content

Commit

Permalink
Merge branch 'feat/trie-mutex-refactor' into trie-refactor-logging
Browse files Browse the repository at this point in the history
  • Loading branch information
BeniaminDrasovean authored Feb 3, 2025
2 parents ad95ed2 + 7dd3353 commit 4faa114
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 59 deletions.
16 changes: 10 additions & 6 deletions trie/branchNode.go
Original file line number Diff line number Diff line change
Expand Up @@ -309,24 +309,28 @@ func (bn *branchNode) tryGet(key []byte, currentDepth uint32, db common.TrieStor
return child.tryGet(key, currentDepth+1, db)
}

func (bn *branchNode) getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, []byte, error) {
func (bn *branchNode) getNext(key []byte, db common.TrieStorageInteractor) (*nodeData, error) {
if len(key) == 0 {
return nil, nil, nil, ErrValueTooShort
return nil, ErrValueTooShort
}
childPos := key[firstByte]
if childPosOutOfRange(childPos) {
return nil, nil, nil, ErrChildPosOutOfRange
return nil, ErrChildPosOutOfRange
}
key = key[1:]
if len(bn.EncodedChildren[childPos]) == 0 {
return nil, nil, nil, ErrNodeNotFound
return nil, ErrNodeNotFound
}
childNode, encodedNode, err := getNodeFromDBAndDecode(bn.EncodedChildren[childPos], db, bn.marsh, bn.hasher)
if err != nil {
return nil, nil, nil, err
return nil, err
}

return childNode, encodedNode, key, nil
return &nodeData{
currentNode: childNode,
encodedNode: encodedNode,
hexKey: key,
}, nil
}

func (bn *branchNode) insert(
Expand Down
29 changes: 13 additions & 16 deletions trie/branchNode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,14 +348,15 @@ func TestBranchNode_getNext(t *testing.T) {
key := append([]byte{childPos}, []byte("dog")...)
db := testscommon.NewMemDbMock()
bn.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db)
n, nodeBytes, key, err := bn.getNext(key, db)
data, err := bn.getNext(key, db)
assert.NotNil(t, data)

h1, _ := encodeNodeAndGetHash(nextNode)
h2, _ := encodeNodeAndGetHash(n)
h2, _ := encodeNodeAndGetHash(data.currentNode)
nextNodeBytes, _ := nextNode.getEncodedNode()
assert.Equal(t, nextNodeBytes, nodeBytes)
assert.Equal(t, nextNodeBytes, data.encodedNode)
assert.Equal(t, h1, h2)
assert.Equal(t, []byte("dog"), key)
assert.Equal(t, []byte("dog"), data.hexKey)
assert.Nil(t, err)
}

Expand All @@ -365,10 +366,8 @@ func TestBranchNode_getNextWrongKey(t *testing.T) {
bn, _ := getBnAndCollapsedBn(getTestMarshalizerAndHasher())
key := []byte("dog")

n, nodeBytes, key, err := bn.getNext(key, nil)
assert.Nil(t, n)
assert.Nil(t, key)
assert.Nil(t, nodeBytes)
data, err := bn.getNext(key, nil)
assert.Nil(t, data)
assert.Equal(t, ErrChildPosOutOfRange, err)
}

Expand All @@ -379,10 +378,8 @@ func TestBranchNode_getNextNilChild(t *testing.T) {
nilChildPos := byte(4)
key := append([]byte{nilChildPos}, []byte("dog")...)

n, nodeBytes, key, err := bn.getNext(key, nil)
assert.Nil(t, n)
assert.Nil(t, key)
assert.Nil(t, nodeBytes)
data, err := bn.getNext(key, nil)
assert.Nil(t, data)
assert.Equal(t, ErrNodeNotFound, err)
}

Expand Down Expand Up @@ -463,8 +460,8 @@ func TestBranchNode_insertInStoredBnOnExistingPos(t *testing.T) {

bn.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db)
bnHash := bn.getHash()
ln, _, _, _ := bn.getNext(key, db)
lnHash := ln.getHash()
nd, _ := bn.getNext(key, db)
lnHash := nd.currentNode.getHash()
expectedHashes := [][]byte{lnHash, bnHash}

goRoutinesManager := getTestGoroutinesManager()
Expand Down Expand Up @@ -591,8 +588,8 @@ func TestBranchNode_deleteFromStoredBn(t *testing.T) {

bn.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db)
bnHash := bn.getHash()
ln, _, _, _ := bn.getNext(lnKey, db)
lnHash := ln.getHash()
nd, _ := bn.getNext(lnKey, db)
lnHash := nd.currentNode.getHash()
expectedHashes := [][]byte{lnHash, bnHash}

goRoutinesManager := getTestGoroutinesManager()
Expand Down
14 changes: 9 additions & 5 deletions trie/extensionNode.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,22 +246,26 @@ func (en *extensionNode) tryGet(key []byte, currentDepth uint32, db common.TrieS
return child.tryGet(key, currentDepth+1, db)
}

func (en *extensionNode) getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, []byte, error) {
func (en *extensionNode) getNext(key []byte, db common.TrieStorageInteractor) (*nodeData, error) {
keyTooShort := len(key) < len(en.Key)
if keyTooShort {
return nil, nil, nil, ErrNodeNotFound
return nil, ErrNodeNotFound
}
keysDontMatch := !bytes.Equal(en.Key, key[:len(en.Key)])
if keysDontMatch {
return nil, nil, nil, ErrNodeNotFound
return nil, ErrNodeNotFound
}
child, encodedChild, err := getNodeFromDBAndDecode(en.EncodedChild, db, en.marsh, en.hasher)
if err != nil {
return nil, nil, nil, err
return nil, err
}

key = key[len(en.Key):]
return child, encodedChild, key, nil
return &nodeData{
currentNode: child,
encodedNode: encodedChild,
hexKey: key,
}, nil
}

func (en *extensionNode) insert(
Expand Down
25 changes: 12 additions & 13 deletions trie/extensionNode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -284,11 +284,12 @@ func TestExtensionNode_getNext(t *testing.T) {
key := append(enKey, bnKey...)
key = append(key, lnKey...)

n, nodeBytes, newKey, err := en.getNext(key, db)
data, err := en.getNext(key, db)
child, childBytes, _ := getNodeFromDBAndDecode(en.EncodedChild, db, en.marsh, en.hasher)
assert.Equal(t, childBytes, nodeBytes)
assert.Equal(t, child, n)
assert.Equal(t, key[1:], newKey)
assert.NotNil(t, data)
assert.Equal(t, childBytes, data.encodedNode)
assert.Equal(t, child, data.currentNode)
assert.Equal(t, key[1:], data.hexKey)
assert.Nil(t, err)
}

Expand All @@ -300,10 +301,8 @@ func TestExtensionNode_getNextWrongKey(t *testing.T) {
lnKey := []byte("dog")
key := append(bnKey, lnKey...)

n, nodeBytes, key, err := en.getNext(key, nil)
assert.Nil(t, n)
assert.Nil(t, key)
assert.Nil(t, nodeBytes)
data, err := en.getNext(key, nil)
assert.Nil(t, data)
assert.Equal(t, ErrNodeNotFound, err)
}

Expand Down Expand Up @@ -356,8 +355,8 @@ func TestExtensionNode_insertInStoredEnSameKey(t *testing.T) {

en.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db)
enHash := en.getHash()
bn, _, _, _ := en.getNext(enKey, db)
bnHash := bn.getHash()
nd, _ := en.getNext(enKey, db)
bnHash := nd.currentNode.getHash()
expectedHashes := [][]byte{bnHash, enHash}

goRoutinesManager := getTestGoroutinesManager()
Expand Down Expand Up @@ -465,9 +464,9 @@ func TestExtensionNode_deleteFromStoredEn(t *testing.T) {
en.setHash(getTestGoroutinesManager())

en.commitDirty(0, 5, getTestGoroutinesManager(), hashesCollector.NewDisabledHashesCollector(), db, db)
bn, _, key, _ := en.getNext(key, db)
ln, _, _, _ := bn.getNext(key, db)
expectedHashes := [][]byte{ln.getHash(), bn.getHash(), en.getHash()}
bnData, _ := en.getNext(key, db)
lnData, _ := bnData.currentNode.getNext(bnData.hexKey, db)
expectedHashes := [][]byte{lnData.currentNode.getHash(), bnData.currentNode.getHash(), en.getHash()}
data := []core.TrieData{{Key: lnPathKey}}

goRoutinesManager := getTestGoroutinesManager()
Expand Down
8 changes: 7 additions & 1 deletion trie/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ import (
vmcommon "github.com/multiversx/mx-chain-vm-common-go"
)

type nodeData struct {
currentNode node
encodedNode []byte
hexKey []byte
}

type baseTrieNode interface {
getHash() []byte
setGivenHash([]byte)
Expand All @@ -28,7 +34,7 @@ type node interface {
setHash(goRoutinesManager common.TrieGoroutinesManager)
getEncodedNode() ([]byte, error)
tryGet(key []byte, depth uint32, db common.TrieStorageInteractor) ([]byte, uint32, error)
getNext(key []byte, db common.TrieStorageInteractor) (node, []byte, []byte, error)
getNext(key []byte, db common.TrieStorageInteractor) (*nodeData, error)
insert(newData []core.TrieData, goRoutinesManager common.TrieGoroutinesManager, modifiedHashes common.AtomicBytesSlice, db common.TrieStorageInteractor) node
delete(data []core.TrieData, goRoutinesManager common.TrieGoroutinesManager, modifiedHashes common.AtomicBytesSlice, db common.TrieStorageInteractor) (bool, node)
reduceNode(pos int, db common.TrieStorageInteractor) (node, bool, error)
Expand Down
6 changes: 3 additions & 3 deletions trie/leafNode.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ func (ln *leafNode) tryGet(key []byte, currentDepth uint32, _ common.TrieStorage
return nil, currentDepth, nil
}

func (ln *leafNode) getNext(key []byte, _ common.TrieStorageInteractor) (node, []byte, []byte, error) {
func (ln *leafNode) getNext(key []byte, _ common.TrieStorageInteractor) (*nodeData, error) {
if bytes.Equal(key, ln.Key) {
return nil, nil, nil, nil
return nil, nil
}
return nil, nil, nil, ErrNodeNotFound
return nil, ErrNodeNotFound
}

func (ln *leafNode) insert(
Expand Down
12 changes: 4 additions & 8 deletions trie/leafNode_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,8 @@ func TestLeafNode_getNext(t *testing.T) {
ln := getLn(getTestMarshalizerAndHasher())
key := []byte("dog")

n, nodeBytes, key, err := ln.getNext(key, nil)
assert.Nil(t, n)
assert.Nil(t, key)
assert.Nil(t, nodeBytes)
data, err := ln.getNext(key, nil)
assert.Nil(t, data)
assert.Nil(t, err)
}

Expand All @@ -178,10 +176,8 @@ func TestLeafNode_getNextWrongKey(t *testing.T) {
ln := getLn(getTestMarshalizerAndHasher())
wrongKey := append([]byte{2}, []byte("dog")...)

n, nodeBytes, key, err := ln.getNext(wrongKey, nil)
assert.Nil(t, n)
assert.Nil(t, key)
assert.Nil(t, nodeBytes)
data, err := ln.getNext(wrongKey, nil)
assert.Nil(t, data)
assert.Equal(t, ErrNodeNotFound, err)
}

Expand Down
17 changes: 10 additions & 7 deletions trie/patriciaMerkleTrie.go
Original file line number Diff line number Diff line change
Expand Up @@ -653,19 +653,22 @@ func (tr *patriciaMerkleTrie) GetProof(key []byte, rootHash []byte) ([][]byte, [

var proof [][]byte
var errGet error
hexKey := keyBytesToHex(key)
currentNode := rootNode

data := &nodeData{
currentNode: rootNode,
encodedNode: encodedNode,
hexKey: keyBytesToHex(key),
}

for {
proof = append(proof, encodedNode)
value := currentNode.getValue()
proof = append(proof, data.encodedNode)
value := data.currentNode.getValue()

currentNode, encodedNode, hexKey, errGet = currentNode.getNext(hexKey, tr.trieStorage)
data, errGet = data.currentNode.getNext(data.hexKey, tr.trieStorage)
if errGet != nil {
return nil, nil, errGet
}

if currentNode == nil {
if data == nil {
return proof, value, nil
}
}
Expand Down

0 comments on commit 4faa114

Please sign in to comment.