diff --git a/internal/trie/node/children.go b/internal/trie/node/children.go index 725366b42e..2a40319a76 100644 --- a/internal/trie/node/children.go +++ b/internal/trie/node/children.go @@ -40,3 +40,7 @@ func (n *Node) HasChild() (has bool) { } return false } + +func (n *Node) ChildAt(i uint) *Node { + return n.Children[i] +} diff --git a/internal/trie/node/decode.go b/internal/trie/node/decode.go index 8fba10c963..3fc45987e4 100644 --- a/internal/trie/node/decode.go +++ b/internal/trie/node/decode.go @@ -25,6 +25,8 @@ var ( ErrDecodeChildHash = errors.New("cannot decode child hash") ) +var EmptyNode = &Node{} + const hashLength = common.HashLength // Decode decodes a node from a reader. @@ -34,21 +36,21 @@ const hashLength = common.HashLength // For branch decoding, see the comments on decodeBranch. // For leaf decoding, see the comments on decodeLeaf. func Decode(reader io.Reader) (n *Node, err error) { - variant, partialKeyLength, err := decodeHeader(reader) + variant, partialKeyLength, err := DecodeHeader(reader) if err != nil { return nil, fmt.Errorf("decoding header: %w", err) } switch variant { - case emptyVariant: - return nil, nil //nolint:nilnil - case leafVariant, leafWithHashedValueVariant: + case EmptyVariant: + return EmptyNode, nil + case LeafVariant, LeafWithHashedValueVariant: n, err = decodeLeaf(reader, variant, partialKeyLength) if err != nil { return nil, fmt.Errorf("cannot decode leaf: %w", err) } return n, nil - case branchVariant, branchWithValueVariant, branchWithHashedValueVariant: + case BranchVariant, BranchWithValueVariant, BranchWithHashedValueVariant: n, err = decodeBranch(reader, variant, partialKeyLength) if err != nil { return nil, fmt.Errorf("cannot decode branch: %w", err) @@ -65,13 +67,13 @@ func Decode(reader io.Reader) (n *Node, err error) { // reconstructing the child nodes from the encoding. This function instead stubs where the // children are known to be with an empty leaf. The children nodes hashes are then used to // find other storage values using the persistent database. -func decodeBranch(reader io.Reader, variant variant, partialKeyLength uint16) ( +func decodeBranch(reader io.Reader, variant Variant, partialKeyLength uint16) ( node *Node, err error) { node = &Node{ Children: make([]*Node, ChildrenCapacity), } - node.PartialKey, err = decodeKey(reader, partialKeyLength) + node.PartialKey, err = DecodeKey(reader, partialKeyLength) if err != nil { return nil, fmt.Errorf("cannot decode key: %w", err) } @@ -85,12 +87,12 @@ func decodeBranch(reader io.Reader, variant variant, partialKeyLength uint16) ( sd := scale.NewDecoder(reader) switch variant { - case branchWithValueVariant: + case BranchWithValueVariant: err := sd.Decode(&node.StorageValue) if err != nil { return nil, fmt.Errorf("%w: %s", ErrDecodeStorageValue, err) } - case branchWithHashedValueVariant: + case BranchWithHashedValueVariant: hashedValue, err := decodeHashedValue(reader) if err != nil { return nil, err @@ -134,17 +136,17 @@ func decodeBranch(reader io.Reader, variant variant, partialKeyLength uint16) ( } // decodeLeaf reads from a reader and decodes to a leaf node. -func decodeLeaf(reader io.Reader, variant variant, partialKeyLength uint16) (node *Node, err error) { +func decodeLeaf(reader io.Reader, variant Variant, partialKeyLength uint16) (node *Node, err error) { node = &Node{} - node.PartialKey, err = decodeKey(reader, partialKeyLength) + node.PartialKey, err = DecodeKey(reader, partialKeyLength) if err != nil { return nil, fmt.Errorf("cannot decode key: %w", err) } sd := scale.NewDecoder(reader) - if variant == leafWithHashedValueVariant { + if variant == LeafWithHashedValueVariant { hashedValue, err := decodeHashedValue(reader) if err != nil { return nil, err diff --git a/internal/trie/node/decode_test.go b/internal/trie/node/decode_test.go index 56983161b7..3fed89cff5 100644 --- a/internal/trie/node/decode_test.go +++ b/internal/trie/node/decode_test.go @@ -59,12 +59,12 @@ func Test_Decode(t *testing.T) { errMessage: "decoding header: decoding header byte: node variant is unknown: for header byte 00001000", }, "empty_node": { - reader: bytes.NewReader([]byte{emptyVariant.bits}), - n: nil, + reader: bytes.NewReader([]byte{EmptyVariant.bits}), + n: EmptyNode, }, "leaf_decoding_error": { reader: bytes.NewReader([]byte{ - leafVariant.bits | 1, // key length 1 + LeafVariant.bits | 1, // key length 1 // missing key data byte }), errWrapped: io.EOF, @@ -73,7 +73,7 @@ func Test_Decode(t *testing.T) { }, "leaf_success": { reader: bytes.NewReader(concatByteSlices([][]byte{ - {leafVariant.bits | 1}, // partial key length 1 + {LeafVariant.bits | 1}, // partial key length 1 {9}, // key data scaleEncodeBytes(t, 1, 2, 3), })), @@ -84,7 +84,7 @@ func Test_Decode(t *testing.T) { }, "branch_decoding_error": { reader: bytes.NewReader([]byte{ - branchVariant.bits | 1, // key length 1 + BranchVariant.bits | 1, // key length 1 // missing key data byte }), errWrapped: io.EOF, @@ -93,7 +93,7 @@ func Test_Decode(t *testing.T) { }, "branch_success": { reader: bytes.NewReader(concatByteSlices([][]byte{ - {branchVariant.bits | 1}, // partial key length 1 + {BranchVariant.bits | 1}, // partial key length 1 {9}, // key data {0b0000_0000, 0b0000_0000}, // no children bitmap })), @@ -104,7 +104,7 @@ func Test_Decode(t *testing.T) { }, "leaf_with_hashed_value_success": { reader: bytes.NewReader(concatByteSlices([][]byte{ - {leafWithHashedValueVariant.bits | 1}, // partial key length 1 + {LeafWithHashedValueVariant.bits | 1}, // partial key length 1 {9}, // key data hashedValue.ToBytes(), })), @@ -116,7 +116,7 @@ func Test_Decode(t *testing.T) { }, "leaf_with_hashed_value_fail_too_short": { reader: bytes.NewReader(concatByteSlices([][]byte{ - {leafWithHashedValueVariant.bits | 1}, // partial key length 1 + {LeafWithHashedValueVariant.bits | 1}, // partial key length 1 {9}, // key data {0b0000_0000}, // less than 32bytes })), @@ -125,7 +125,7 @@ func Test_Decode(t *testing.T) { }, "branch_with_hashed_value_success": { reader: bytes.NewReader(concatByteSlices([][]byte{ - {branchWithHashedValueVariant.bits | 1}, // partial key length 1 + {BranchWithHashedValueVariant.bits | 1}, // partial key length 1 {9}, // key data {0b0000_0000, 0b0000_0000}, // no children bitmap hashedValue.ToBytes(), @@ -139,7 +139,7 @@ func Test_Decode(t *testing.T) { }, "branch_with_hashed_value_fail_too_short": { reader: bytes.NewReader(concatByteSlices([][]byte{ - {branchWithHashedValueVariant.bits | 1}, // partial key length 1 + {BranchWithHashedValueVariant.bits | 1}, // partial key length 1 {9}, // key data {0b0000_0000, 0b0000_0000}, // no children bitmap {0b0000_0000}, @@ -177,7 +177,7 @@ func Test_decodeBranch(t *testing.T) { testCases := map[string]struct { reader io.Reader - nodeVariant variant + nodeVariant Variant partialKeyLength uint16 branch *Node errWrapped error @@ -187,7 +187,7 @@ func Test_decodeBranch(t *testing.T) { reader: bytes.NewBuffer([]byte{ // missing key data byte }), - nodeVariant: branchVariant, + nodeVariant: BranchVariant, partialKeyLength: 1, errWrapped: io.EOF, errMessage: "cannot decode key: reading from reader: EOF", @@ -197,7 +197,7 @@ func Test_decodeBranch(t *testing.T) { 9, // key data // missing children bitmap 2 bytes }), - nodeVariant: branchVariant, + nodeVariant: BranchVariant, partialKeyLength: 1, errWrapped: ErrReadChildrenBitmap, errMessage: "cannot read children bitmap: EOF", @@ -208,7 +208,7 @@ func Test_decodeBranch(t *testing.T) { 0, 4, // children bitmap // missing children scale encoded data }), - nodeVariant: branchVariant, + nodeVariant: BranchVariant, partialKeyLength: 1, errWrapped: ErrDecodeChildHash, errMessage: "cannot decode child hash: at index 10: reading byte: EOF", @@ -221,7 +221,7 @@ func Test_decodeBranch(t *testing.T) { scaleEncodedChildHash, }), ), - nodeVariant: branchVariant, + nodeVariant: BranchVariant, partialKeyLength: 1, branch: &Node{ PartialKey: []byte{9}, @@ -243,7 +243,7 @@ func Test_decodeBranch(t *testing.T) { // missing encoded branch storage value }), ), - nodeVariant: branchWithValueVariant, + nodeVariant: BranchWithValueVariant, partialKeyLength: 1, errWrapped: ErrDecodeStorageValue, errMessage: "cannot decode storage value: reading byte: EOF", @@ -255,7 +255,7 @@ func Test_decodeBranch(t *testing.T) { scaleEncodeBytes(t, 7, 8, 9), // branch storage value scaleEncodedChildHash, })), - nodeVariant: branchWithValueVariant, + nodeVariant: BranchWithValueVariant, partialKeyLength: 1, branch: &Node{ PartialKey: []byte{9}, @@ -277,7 +277,7 @@ func Test_decodeBranch(t *testing.T) { scaleEncodeBytes(t, 1), // branch storage value {0}, // garbage inlined node })), - nodeVariant: branchWithValueVariant, + nodeVariant: BranchWithValueVariant, partialKeyLength: 1, errWrapped: io.EOF, errMessage: "decoding inlined child at index 0: " + @@ -289,25 +289,25 @@ func Test_decodeBranch(t *testing.T) { {0b0000_0011, 0b0000_0000}, // children bitmap // top level inlined leaf less than 32 bytes scaleEncodeByteSlice(t, concatByteSlices([][]byte{ - {leafVariant.bits | 1}, // partial key length of 1 + {LeafVariant.bits | 1}, // partial key length of 1 {2}, // key data scaleEncodeBytes(t, 2), // storage value data })), // top level inlined branch less than 32 bytes scaleEncodeByteSlice(t, concatByteSlices([][]byte{ - {branchWithValueVariant.bits | 1}, // partial key length of 1 + {BranchWithValueVariant.bits | 1}, // partial key length of 1 {3}, // key data {0b0000_0001, 0b0000_0000}, // children bitmap scaleEncodeBytes(t, 3), // branch storage value // bottom level leaf scaleEncodeByteSlice(t, concatByteSlices([][]byte{ - {leafVariant.bits | 1}, // partial key length of 1 + {LeafVariant.bits | 1}, // partial key length of 1 {4}, // key data scaleEncodeBytes(t, 4), // storage value data })), })), })), - nodeVariant: branchVariant, + nodeVariant: BranchVariant, partialKeyLength: 1, branch: &Node{ PartialKey: []byte{1}, @@ -349,7 +349,7 @@ func Test_decodeLeaf(t *testing.T) { testCases := map[string]struct { reader io.Reader - variant variant + variant Variant partialKeyLength uint16 leaf *Node errWrapped error @@ -359,7 +359,7 @@ func Test_decodeLeaf(t *testing.T) { reader: bytes.NewBuffer([]byte{ // missing key data byte }), - variant: leafVariant, + variant: LeafVariant, partialKeyLength: 1, errWrapped: io.EOF, errMessage: "cannot decode key: reading from reader: EOF", @@ -369,7 +369,7 @@ func Test_decodeLeaf(t *testing.T) { {9}, // key data {255, 255}, // bad storage value data })), - variant: leafVariant, + variant: LeafVariant, partialKeyLength: 1, errWrapped: ErrDecodeStorageValue, errMessage: "cannot decode storage value: unknown prefix for compact uint: 255", @@ -379,7 +379,7 @@ func Test_decodeLeaf(t *testing.T) { 9, // key data // missing storage value data }), - variant: leafVariant, + variant: LeafVariant, partialKeyLength: 1, errWrapped: ErrDecodeStorageValue, errMessage: "cannot decode storage value: reading byte: EOF", @@ -389,7 +389,7 @@ func Test_decodeLeaf(t *testing.T) { {9}, // key data scaleEncodeByteSlice(t, []byte{}), // results to []byte{0} })), - variant: leafVariant, + variant: LeafVariant, partialKeyLength: 1, leaf: &Node{ PartialKey: []byte{9}, @@ -401,7 +401,7 @@ func Test_decodeLeaf(t *testing.T) { {9}, // key data scaleEncodeBytes(t, 1, 2, 3, 4, 5), // storage value data })), - variant: leafVariant, + variant: LeafVariant, partialKeyLength: 1, leaf: &Node{ PartialKey: []byte{9}, diff --git a/internal/trie/node/encode_decode_test.go b/internal/trie/node/encode_decode_test.go index 71974ff907..e1b3da06a2 100644 --- a/internal/trie/node/encode_decode_test.go +++ b/internal/trie/node/encode_decode_test.go @@ -121,7 +121,7 @@ func Test_Branch_Encode_Decode(t *testing.T) { err := testCase.branchToEncode.Encode(buffer) require.NoError(t, err) - nodeVariant, partialKeyLength, err := decodeHeader(buffer) + nodeVariant, partialKeyLength, err := DecodeHeader(buffer) require.NoError(t, err) resultBranch, err := decodeBranch(buffer, nodeVariant, partialKeyLength) diff --git a/internal/trie/node/encode_test.go b/internal/trie/node/encode_test.go index 53707e5a3f..2a755bb43e 100644 --- a/internal/trie/node/encode_test.go +++ b/internal/trie/node/encode_test.go @@ -37,7 +37,7 @@ func Test_Node_Encode(t *testing.T) { node: nil, writes: []writeCall{ { - written: []byte{emptyVariant.bits}, + written: []byte{EmptyVariant.bits}, }, }, }, @@ -47,7 +47,7 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { - written: []byte{leafVariant.bits | 1}, + written: []byte{LeafVariant.bits | 1}, err: errTest, }, }, @@ -61,7 +61,7 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { - written: []byte{leafVariant.bits | 3}, // partial key length 3 + written: []byte{LeafVariant.bits | 3}, // partial key length 3 }, { written: []byte{0x01, 0x23}, @@ -78,7 +78,7 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { - written: []byte{leafVariant.bits | 3}, // partial key length 3 + written: []byte{LeafVariant.bits | 3}, // partial key length 3 }, { written: []byte{0x01, 0x23}, @@ -98,7 +98,7 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { - written: []byte{leafVariant.bits | 3}, // partial key length 3 + written: []byte{LeafVariant.bits | 3}, // partial key length 3 }, {written: []byte{0x01, 0x23}}, {written: []byte{12}}, @@ -111,7 +111,7 @@ func Test_Node_Encode(t *testing.T) { StorageValue: []byte{}, }, writes: []writeCall{ - {written: []byte{leafVariant.bits | 3}}, // partial key length 3 + {written: []byte{LeafVariant.bits | 3}}, // partial key length 3 {written: []byte{0x01, 0x23}}, // partial key {written: []byte{0}}, // node storage value encoded length {written: []byte{}}, // node storage value @@ -125,7 +125,7 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { - written: []byte{leafWithHashedValueVariant.bits | 3}, + written: []byte{LeafWithHashedValueVariant.bits | 3}, }, {written: []byte{0x01, 0x23}}, {written: hashedValue.ToBytes()}, @@ -139,7 +139,7 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { - written: []byte{leafWithHashedValueVariant.bits | 3}, + written: []byte{LeafWithHashedValueVariant.bits | 3}, }, { written: []byte{0x01, 0x23}, @@ -160,7 +160,7 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { - written: []byte{leafWithHashedValueVariant.bits | 3}, + written: []byte{LeafWithHashedValueVariant.bits | 3}, }, {written: []byte{0x01, 0x23}}, }, @@ -174,7 +174,7 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { // header - written: []byte{branchVariant.bits | 1}, // partial key length 1 + written: []byte{BranchVariant.bits | 1}, // partial key length 1 err: errTest, }, }, @@ -189,7 +189,7 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { // header - written: []byte{branchWithValueVariant.bits | 3}, // partial key length 3 + written: []byte{BranchWithValueVariant.bits | 3}, // partial key length 3 }, { // key LE written: []byte{0x01, 0x23}, @@ -210,7 +210,7 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { // header - written: []byte{branchWithValueVariant.bits | 3}, // partial key length 3 + written: []byte{BranchWithValueVariant.bits | 3}, // partial key length 3 }, { // key LE written: []byte{0x01, 0x23}, @@ -234,7 +234,7 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { // header - written: []byte{branchWithValueVariant.bits | 3}, // partial key length 3 + written: []byte{BranchWithValueVariant.bits | 3}, // partial key length 3 }, { // key LE written: []byte{0x01, 0x23}, @@ -261,7 +261,7 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { // header - written: []byte{branchWithValueVariant.bits | 3}, // partial key length 3 + written: []byte{BranchWithValueVariant.bits | 3}, // partial key length 3 }, { // key LE written: []byte{0x01, 0x23}, @@ -293,7 +293,7 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { // header - written: []byte{branchWithValueVariant.bits | 3}, // partial key length 3 + written: []byte{BranchWithValueVariant.bits | 3}, // partial key length 3 }, { // key LE written: []byte{0x01, 0x23}, @@ -322,7 +322,7 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { // header - written: []byte{branchVariant.bits | 3}, // partial key length 3 + written: []byte{BranchVariant.bits | 3}, // partial key length 3 }, { // key LE written: []byte{0x01, 0x23}, @@ -350,7 +350,7 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { // header - written: []byte{branchWithHashedValueVariant.bits | 3}, // partial key length 3 + written: []byte{BranchWithHashedValueVariant.bits | 3}, // partial key length 3 }, { // key LE written: []byte{0x01, 0x23}, @@ -381,7 +381,7 @@ func Test_Node_Encode(t *testing.T) { }, writes: []writeCall{ { // header - written: []byte{branchWithHashedValueVariant.bits | 3}, // partial key length 3 + written: []byte{BranchWithHashedValueVariant.bits | 3}, // partial key length 3 }, { // key LE written: []byte{0x01, 0x23}, diff --git a/internal/trie/node/header.go b/internal/trie/node/header.go index 8863b410a8..880104b32a 100644 --- a/internal/trie/node/header.go +++ b/internal/trie/node/header.go @@ -12,7 +12,7 @@ import ( // encodeHeader writes the encoded header for the node. func encodeHeader(node *Node, writer io.Writer) (err error) { if node == nil { - _, err = writer.Write([]byte{emptyVariant.bits}) + _, err = writer.Write([]byte{EmptyVariant.bits}) return err } @@ -22,19 +22,19 @@ func encodeHeader(node *Node, writer io.Writer) (err error) { } // Merge variant byte and partial key length together - var nodeVariant variant + var nodeVariant Variant if node.Kind() == Leaf { if node.HashedValue { - nodeVariant = leafWithHashedValueVariant + nodeVariant = LeafWithHashedValueVariant } else { - nodeVariant = leafVariant + nodeVariant = LeafVariant } } else if node.StorageValue == nil { - nodeVariant = branchVariant + nodeVariant = BranchVariant } else if node.HashedValue { - nodeVariant = branchWithHashedValueVariant + nodeVariant = BranchWithHashedValueVariant } else { - nodeVariant = branchWithValueVariant + nodeVariant = BranchWithValueVariant } buffer := make([]byte, 1) @@ -81,7 +81,7 @@ var ( ErrPartialKeyTooBig = errors.New("partial key length cannot be larger than 2^16") ) -func decodeHeader(reader io.Reader) (nodeVariant variant, +func DecodeHeader(reader io.Reader) (nodeVariant Variant, partialKeyLength uint16, err error) { buffer := make([]byte, 1) _, err = reader.Read(buffer) @@ -91,11 +91,11 @@ func decodeHeader(reader io.Reader) (nodeVariant variant, nodeVariant, partialKeyLengthHeader, err := decodeHeaderByte(buffer[0]) if err != nil { - return variant{}, 0, fmt.Errorf("decoding header byte: %w", err) + return Variant{}, 0, fmt.Errorf("decoding header byte: %w", err) } partialKeyLengthHeaderMask := nodeVariant.partialKeyLengthHeaderMask() - if partialKeyLengthHeaderMask == emptyVariant.bits { + if partialKeyLengthHeaderMask == EmptyVariant.bits { // empty node or compact encoding which have no // partial key. The partial key length mask is // 0b0000_0000 since the variant mask is @@ -119,7 +119,7 @@ func decodeHeader(reader io.Reader) (nodeVariant variant, for { _, err = reader.Read(buffer) if err != nil { - return variant{}, 0, fmt.Errorf("reading key length: %w", err) + return Variant{}, 0, fmt.Errorf("reading key length: %w", err) } previousKeyLength = partialKeyLength @@ -130,7 +130,7 @@ func decodeHeader(reader io.Reader) (nodeVariant variant, // maximum uint16 value; therefore if we overflowed, we went over // this maximum. overflowed := maxPartialKeyLength - previousKeyLength + partialKeyLength - return variant{}, 0, fmt.Errorf("%w: overflowed by %d", ErrPartialKeyTooBig, overflowed) + return Variant{}, 0, fmt.Errorf("%w: overflowed by %d", ErrPartialKeyTooBig, overflowed) } if buffer[0] < 255 { @@ -150,17 +150,17 @@ var ErrVariantUnknown = errors.New("node variant is unknown") // reasons only, instead of having it locally defined in // the decodeHeaderByte function below. // For 7 variants, the performance is improved by ~20%. -var variantsOrderedByBitMask = [...]variant{ - leafVariant, // mask 1100_0000 - branchVariant, // mask 1100_0000 - branchWithValueVariant, // mask 1100_0000 - leafWithHashedValueVariant, // mask 1110_0000 - branchWithHashedValueVariant, // mask 1111_0000 - emptyVariant, // mask 1111_1111 +var variantsOrderedByBitMask = [...]Variant{ + LeafVariant, // mask 1100_0000 + BranchVariant, // mask 1100_0000 + BranchWithValueVariant, // mask 1100_0000 + LeafWithHashedValueVariant, // mask 1110_0000 + BranchWithHashedValueVariant, // mask 1111_0000 + EmptyVariant, // mask 1111_1111 compactEncodingVariant, // mask 1111_1111 } -func decodeHeaderByte(header byte) (nodeVariant variant, +func decodeHeaderByte(header byte) (nodeVariant Variant, partialKeyLengthHeader byte, err error) { var partialKeyLengthHeaderMask byte for i := len(variantsOrderedByBitMask) - 1; i >= 0; i-- { diff --git a/internal/trie/node/header_test.go b/internal/trie/node/header_test.go index 483ffb2730..3e7918c0aa 100644 --- a/internal/trie/node/header_test.go +++ b/internal/trie/node/header_test.go @@ -33,7 +33,7 @@ func Test_encodeHeader(t *testing.T) { Children: make([]*Node, ChildrenCapacity), }, writes: []writeCall{ - {written: []byte{branchVariant.bits}}, + {written: []byte{BranchVariant.bits}}, }, }, "branch_with_value": { @@ -42,7 +42,7 @@ func Test_encodeHeader(t *testing.T) { Children: make([]*Node, ChildrenCapacity), }, writes: []writeCall{ - {written: []byte{branchWithValueVariant.bits}}, + {written: []byte{BranchWithValueVariant.bits}}, }, }, "branch_with_hashed_value": { @@ -52,7 +52,7 @@ func Test_encodeHeader(t *testing.T) { Children: make([]*Node, ChildrenCapacity), }, writes: []writeCall{ - {written: []byte{branchWithHashedValueVariant.bits}}, + {written: []byte{BranchWithHashedValueVariant.bits}}, }, }, "branch_with_key_of_length_30": { @@ -61,7 +61,7 @@ func Test_encodeHeader(t *testing.T) { Children: make([]*Node, ChildrenCapacity), }, writes: []writeCall{ - {written: []byte{branchVariant.bits | 30}}, + {written: []byte{BranchVariant.bits | 30}}, }, }, "branch_with_key_of_length_62": { @@ -70,7 +70,7 @@ func Test_encodeHeader(t *testing.T) { Children: make([]*Node, ChildrenCapacity), }, writes: []writeCall{ - {written: []byte{branchVariant.bits | 62}}, + {written: []byte{BranchVariant.bits | 62}}, }, }, "branch_with_key_of_length_63": { @@ -79,7 +79,7 @@ func Test_encodeHeader(t *testing.T) { Children: make([]*Node, ChildrenCapacity), }, writes: []writeCall{ - {written: []byte{branchVariant.bits | 63}}, + {written: []byte{BranchVariant.bits | 63}}, {written: []byte{0x00}}, // trailing 0 to indicate the partial // key length is done here. }, @@ -90,7 +90,7 @@ func Test_encodeHeader(t *testing.T) { Children: make([]*Node, ChildrenCapacity), }, writes: []writeCall{ - {written: []byte{branchVariant.bits | 63}}, + {written: []byte{BranchVariant.bits | 63}}, {written: []byte{0x01}}, }, }, @@ -100,7 +100,7 @@ func Test_encodeHeader(t *testing.T) { }, writes: []writeCall{ { - written: []byte{branchVariant.bits}, + written: []byte{BranchVariant.bits}, err: errTest, }, }, @@ -109,12 +109,12 @@ func Test_encodeHeader(t *testing.T) { }, "branch_with_long_key_length_write_error": { node: &Node{ - PartialKey: make([]byte, int(^branchVariant.mask)+1), + PartialKey: make([]byte, int(^BranchVariant.mask)+1), Children: make([]*Node, ChildrenCapacity), }, writes: []writeCall{ { - written: []byte{branchVariant.bits | ^branchVariant.mask}, + written: []byte{BranchVariant.bits | ^BranchVariant.mask}, }, { written: []byte{0x01}, @@ -130,13 +130,13 @@ func Test_encodeHeader(t *testing.T) { HashedValue: true, }, writes: []writeCall{ - {written: []byte{leafWithHashedValueVariant.bits}}, + {written: []byte{LeafWithHashedValueVariant.bits}}, }, }, "leaf_with_no_key": { node: &Node{StorageValue: []byte{1}}, writes: []writeCall{ - {written: []byte{leafVariant.bits}}, + {written: []byte{LeafVariant.bits}}, }, }, "leaf_with_key_of_length_30": { @@ -144,7 +144,7 @@ func Test_encodeHeader(t *testing.T) { PartialKey: make([]byte, 30), }, writes: []writeCall{ - {written: []byte{leafVariant.bits | 30}}, + {written: []byte{LeafVariant.bits | 30}}, }, }, "leaf_with_short_key_write_error": { @@ -153,7 +153,7 @@ func Test_encodeHeader(t *testing.T) { }, writes: []writeCall{ { - written: []byte{leafVariant.bits | 30}, + written: []byte{LeafVariant.bits | 30}, err: errTest, }, }, @@ -165,7 +165,7 @@ func Test_encodeHeader(t *testing.T) { PartialKey: make([]byte, 62), }, writes: []writeCall{ - {written: []byte{leafVariant.bits | 62}}, + {written: []byte{LeafVariant.bits | 62}}, }, }, "leaf_with_key_of_length_63": { @@ -173,7 +173,7 @@ func Test_encodeHeader(t *testing.T) { PartialKey: make([]byte, 63), }, writes: []writeCall{ - {written: []byte{leafVariant.bits | 63}}, + {written: []byte{LeafVariant.bits | 63}}, {written: []byte{0x0}}, }, }, @@ -182,7 +182,7 @@ func Test_encodeHeader(t *testing.T) { PartialKey: make([]byte, 64), }, writes: []writeCall{ - {written: []byte{leafVariant.bits | 63}}, + {written: []byte{LeafVariant.bits | 63}}, {written: []byte{0x1}}, }, }, @@ -192,7 +192,7 @@ func Test_encodeHeader(t *testing.T) { }, writes: []writeCall{ { - written: []byte{leafVariant.bits | 63}, + written: []byte{LeafVariant.bits | 63}, err: errTest, }, }, @@ -201,20 +201,20 @@ func Test_encodeHeader(t *testing.T) { }, "leaf_with_key_length_over_3_bytes": { node: &Node{ - PartialKey: make([]byte, int(^leafVariant.mask)+0b1111_1111+0b0000_0001), + PartialKey: make([]byte, int(^LeafVariant.mask)+0b1111_1111+0b0000_0001), }, writes: []writeCall{ - {written: []byte{leafVariant.bits | ^leafVariant.mask}}, + {written: []byte{LeafVariant.bits | ^LeafVariant.mask}}, {written: []byte{0b1111_1111}}, {written: []byte{0b0000_0001}}, }, }, "leaf_with_key_length_over_3_bytes_and_last_byte_zero": { node: &Node{ - PartialKey: make([]byte, int(^leafVariant.mask)+0b1111_1111), + PartialKey: make([]byte, int(^LeafVariant.mask)+0b1111_1111), }, writes: []writeCall{ - {written: []byte{leafVariant.bits | ^leafVariant.mask}}, + {written: []byte{LeafVariant.bits | ^LeafVariant.mask}}, {written: []byte{0b1111_1111}}, {written: []byte{0x00}}, }, @@ -270,7 +270,7 @@ func Test_encodeHeader_At_Maximum(t *testing.T) { // mock writer since it's too slow, so we use // an actual buffer. - variant := leafVariant.bits + Variant := LeafVariant.bits const partialKeyLengthHeaderMask = 0b0011_1111 const keyLength = uint(maxPartialKeyLength) extraKeyBytesNeeded := math.Ceil(float64(maxPartialKeyLength-partialKeyLengthHeaderMask) / 255.0) @@ -278,7 +278,7 @@ func Test_encodeHeader_At_Maximum(t *testing.T) { lengthLeft := maxPartialKeyLength expectedBytes := make([]byte, expectedEncodingLength) - expectedBytes[0] = variant | partialKeyLengthHeaderMask + expectedBytes[0] = Variant | partialKeyLengthHeaderMask lengthLeft -= partialKeyLengthHeaderMask for i := 1; i < len(expectedBytes)-1; i++ { expectedBytes[i] = 255 @@ -304,7 +304,7 @@ func Test_decodeHeader(t *testing.T) { testCases := map[string]struct { reads []readCall - nodeVariant variant + nodeVariant Variant partialKeyLength uint16 errWrapped error errMessage string @@ -325,14 +325,14 @@ func Test_decodeHeader(t *testing.T) { }, "partial_key_length_contained_in_first_byte": { reads: []readCall{ - {buffArgCap: 1, read: []byte{leafVariant.bits | 0b0011_1110}}, + {buffArgCap: 1, read: []byte{LeafVariant.bits | 0b0011_1110}}, }, - nodeVariant: leafVariant, + nodeVariant: LeafVariant, partialKeyLength: uint16(0b0011_1110), }, "long_partial_key_length_and_second_byte_read_error": { reads: []readCall{ - {buffArgCap: 1, read: []byte{leafVariant.bits | 0b0011_1111}}, + {buffArgCap: 1, read: []byte{LeafVariant.bits | 0b0011_1111}}, {buffArgCap: 1, err: errTest}, }, errWrapped: errTest, @@ -340,11 +340,11 @@ func Test_decodeHeader(t *testing.T) { }, "partial_key_length_spread_on_multiple_bytes": { reads: []readCall{ - {buffArgCap: 1, read: []byte{leafVariant.bits | 0b0011_1111}}, + {buffArgCap: 1, read: []byte{LeafVariant.bits | 0b0011_1111}}, {buffArgCap: 1, read: []byte{0b1111_1111}}, {buffArgCap: 1, read: []byte{0b1111_0000}}, }, - nodeVariant: leafVariant, + nodeVariant: LeafVariant, partialKeyLength: uint16(0b0011_1111 + 0b1111_1111 + 0b1111_0000), }, "partial_key_length_too_long": { @@ -380,7 +380,7 @@ func Test_decodeHeader(t *testing.T) { previousCall = call } - nodeVariant, partialKeyLength, err := decodeHeader(reader) + nodeVariant, partialKeyLength, err := DecodeHeader(reader) assert.Equal(t, testCase.nodeVariant, nodeVariant) assert.Equal(t, int(testCase.partialKeyLength), int(partialKeyLength)) @@ -397,39 +397,39 @@ func Test_decodeHeaderByte(t *testing.T) { testCases := map[string]struct { header byte - nodeVariant variant + nodeVariant Variant partialKeyLengthHeader byte errWrapped error errMessage string }{ "empty_variant_header": { header: 0b0000_0000, - nodeVariant: emptyVariant, + nodeVariant: EmptyVariant, partialKeyLengthHeader: 0b0000_0000, }, "branch_with_value_header": { header: 0b1110_1001, - nodeVariant: branchWithValueVariant, + nodeVariant: BranchWithValueVariant, partialKeyLengthHeader: 0b0010_1001, }, "branch_header": { header: 0b1010_1001, - nodeVariant: branchVariant, + nodeVariant: BranchVariant, partialKeyLengthHeader: 0b0010_1001, }, "leaf_header": { header: 0b0110_1001, - nodeVariant: leafVariant, + nodeVariant: LeafVariant, partialKeyLengthHeader: 0b0010_1001, }, "leaf_containing_hashes_header": { header: 0b0011_1001, - nodeVariant: leafWithHashedValueVariant, + nodeVariant: LeafWithHashedValueVariant, partialKeyLengthHeader: 0b0001_1001, }, "branch_containing_hashes_header": { header: 0b0001_1001, - nodeVariant: branchWithHashedValueVariant, + nodeVariant: BranchWithHashedValueVariant, partialKeyLengthHeader: 0b0000_1001, }, "compact_encoding_header": { @@ -465,8 +465,8 @@ func Test_decodeHeaderByte(t *testing.T) { func Test_variantsOrderedByBitMask(t *testing.T) { t.Parallel() - slice := make([]variant, len(variantsOrderedByBitMask)) - sortedSlice := make([]variant, len(variantsOrderedByBitMask)) + slice := make([]Variant, len(variantsOrderedByBitMask)) + sortedSlice := make([]Variant, len(variantsOrderedByBitMask)) copy(slice, variantsOrderedByBitMask[:]) copy(sortedSlice, variantsOrderedByBitMask[:]) @@ -483,7 +483,7 @@ func Benchmark_decodeHeaderByte(b *testing.B) { // 2.987 ns/op 0 B/op 0 allocs/op // With locally scoped variants slice: // 3.873 ns/op 0 B/op 0 allocs/op - header := leafVariant.bits | 0b0000_0001 + header := LeafVariant.bits | 0b0000_0001 b.ResetTimer() for i := 0; i < b.N; i++ { _, _, _ = decodeHeaderByte(header) diff --git a/internal/trie/node/key.go b/internal/trie/node/key.go index 343a5d747d..806aca20bc 100644 --- a/internal/trie/node/key.go +++ b/internal/trie/node/key.go @@ -15,8 +15,8 @@ const maxPartialKeyLength = ^uint16(0) var ErrReaderMismatchCount = errors.New("read unexpected number of bytes from reader") -// decodeKey decodes a key from a reader. -func decodeKey(reader io.Reader, partialKeyLength uint16) (b []byte, err error) { +// DecodeKey decodes a key from a reader. +func DecodeKey(reader io.Reader, partialKeyLength uint16) (b []byte, err error) { if partialKeyLength == 0 { return []byte{}, nil } diff --git a/internal/trie/node/key_test.go b/internal/trie/node/key_test.go index 4a6ad3c47c..500b10ad57 100644 --- a/internal/trie/node/key_test.go +++ b/internal/trie/node/key_test.go @@ -134,7 +134,7 @@ func Test_decodeKey(t *testing.T) { previousCall = call } - b, err := decodeKey(reader, testCase.partialKeyLength) + b, err := DecodeKey(reader, testCase.partialKeyLength) assert.ErrorIs(t, err, testCase.errWrapped) if err != nil { diff --git a/internal/trie/node/node.go b/internal/trie/node/node.go index c5a7c83ce0..4d09d19192 100644 --- a/internal/trie/node/node.go +++ b/internal/trie/node/node.go @@ -7,6 +7,7 @@ package node import ( "fmt" + "strconv" "github.com/qdm12/gotree" ) @@ -57,6 +58,7 @@ func (n *Node) StringNode() (stringNode *gotree.Node) { stringNode.Appendf("Dirty: %t", n.Dirty) stringNode.Appendf("Key: " + bytesToString(n.PartialKey)) stringNode.Appendf("Storage value: " + bytesToString(n.StorageValue)) + stringNode.Appendf("Hashed value: " + strconv.FormatBool(n.HashedValue)) if n.Descendants > 0 { // must be a branch stringNode.Appendf("Descendants: %d", n.Descendants) } @@ -73,6 +75,10 @@ func (n *Node) StringNode() (stringNode *gotree.Node) { return stringNode } +func (n *Node) IsHashed() bool { + return n.HashedValue +} + func bytesToString(b []byte) (s string) { switch { case b == nil: diff --git a/internal/trie/node/node_test.go b/internal/trie/node/node_test.go index af4f2269b8..cf781b7ff7 100644 --- a/internal/trie/node/node_test.go +++ b/internal/trie/node/node_test.go @@ -27,6 +27,7 @@ func Test_Node_String(t *testing.T) { ├── Dirty: true ├── Key: 0x0102 ├── Storage value: 0x0304 +├── Hashed value: false └── Merkle value: nil`, }, "leaf_with_storage_value_higher_than_1024": { @@ -40,6 +41,7 @@ func Test_Node_String(t *testing.T) { ├── Dirty: true ├── Key: 0x0102 ├── Storage value: 0x0000000000000000...0000000000000000 +├── Hashed value: false └── Merkle value: nil`, }, "branch_with_storage_value_smaller_than_1024": { @@ -66,6 +68,7 @@ func Test_Node_String(t *testing.T) { ├── Dirty: true ├── Key: 0x0102 ├── Storage value: 0x0304 +├── Hashed value: false ├── Descendants: 3 ├── Merkle value: nil ├── Child 3 @@ -74,6 +77,7 @@ func Test_Node_String(t *testing.T) { | ├── Dirty: false | ├── Key: nil | ├── Storage value: nil +| ├── Hashed value: false | └── Merkle value: nil ├── Child 7 | └── Branch @@ -81,6 +85,7 @@ func Test_Node_String(t *testing.T) { | ├── Dirty: false | ├── Key: nil | ├── Storage value: nil +| ├── Hashed value: false | ├── Descendants: 1 | ├── Merkle value: nil | └── Child 0 @@ -89,6 +94,7 @@ func Test_Node_String(t *testing.T) { | ├── Dirty: false | ├── Key: nil | ├── Storage value: nil +| ├── Hashed value: false | └── Merkle value: nil └── Child 11 └── Leaf @@ -96,6 +102,7 @@ func Test_Node_String(t *testing.T) { ├── Dirty: false ├── Key: nil ├── Storage value: nil + ├── Hashed value: false └── Merkle value: nil`, }, "branch_with_storage_value_higher_than_1024": { @@ -122,6 +129,7 @@ func Test_Node_String(t *testing.T) { ├── Dirty: true ├── Key: 0x0102 ├── Storage value: 0x0000000000000000...0000000000000000 +├── Hashed value: false ├── Descendants: 3 ├── Merkle value: nil ├── Child 3 @@ -130,6 +138,7 @@ func Test_Node_String(t *testing.T) { | ├── Dirty: false | ├── Key: nil | ├── Storage value: nil +| ├── Hashed value: false | └── Merkle value: nil ├── Child 7 | └── Branch @@ -137,6 +146,7 @@ func Test_Node_String(t *testing.T) { | ├── Dirty: false | ├── Key: nil | ├── Storage value: nil +| ├── Hashed value: false | ├── Descendants: 1 | ├── Merkle value: nil | └── Child 0 @@ -145,6 +155,7 @@ func Test_Node_String(t *testing.T) { | ├── Dirty: false | ├── Key: nil | ├── Storage value: nil +| ├── Hashed value: false | └── Merkle value: nil └── Child 11 └── Leaf @@ -152,6 +163,7 @@ func Test_Node_String(t *testing.T) { ├── Dirty: false ├── Key: nil ├── Storage value: nil + ├── Hashed value: false └── Merkle value: nil`, }, } diff --git a/internal/trie/node/variants.go b/internal/trie/node/variants.go index 58b8080a98..59bb2b33eb 100644 --- a/internal/trie/node/variants.go +++ b/internal/trie/node/variants.go @@ -3,7 +3,7 @@ package node -type variant struct { +type Variant struct { bits byte mask byte } @@ -11,35 +11,35 @@ type variant struct { // Node variants // See https://spec.polkadot.network/#defn-node-header var ( - leafVariant = variant{ // leaf 01 + LeafVariant = Variant{ // leaf 01 bits: 0b0100_0000, mask: 0b1100_0000, } - branchVariant = variant{ // branch 10 + BranchVariant = Variant{ // branch 10 bits: 0b1000_0000, mask: 0b1100_0000, } - branchWithValueVariant = variant{ // branch 11 + BranchWithValueVariant = Variant{ // branch 11 bits: 0b1100_0000, mask: 0b1100_0000, } - leafWithHashedValueVariant = variant{ // leaf containing hashes 001 + LeafWithHashedValueVariant = Variant{ // leaf containing hashes 001 bits: 0b0010_0000, mask: 0b1110_0000, } - branchWithHashedValueVariant = variant{ // branch containing hashes 0001 + BranchWithHashedValueVariant = Variant{ // branch containing hashes 0001 bits: 0b0001_0000, mask: 0b1111_0000, } - emptyVariant = variant{ // empty 0000 0000 + EmptyVariant = Variant{ // empty 0000 0000 bits: 0b0000_0000, mask: 0b1111_1111, } - compactEncodingVariant = variant{ // compact encoding 0001 0000 + compactEncodingVariant = Variant{ // compact encoding 0001 0000 bits: 0b0000_0001, mask: 0b1111_1111, } - invalidVariant = variant{ + invalidVariant = Variant{ bits: 0b0000_0000, mask: 0b0000_0000, } @@ -49,23 +49,23 @@ var ( // header bit mask corresponding to the variant header bit mask. // For example for the leaf variant with variant mask 1100_0000, // the partial key length header mask returned is 0011_1111. -func (v variant) partialKeyLengthHeaderMask() byte { +func (v Variant) partialKeyLengthHeaderMask() byte { return ^v.mask } -func (v variant) String() string { +func (v Variant) String() string { switch v { - case leafVariant: + case LeafVariant: return "Leaf" - case leafWithHashedValueVariant: + case LeafWithHashedValueVariant: return "LeafWithHashedValue" - case branchVariant: + case BranchVariant: return "Branch" - case branchWithValueVariant: + case BranchWithValueVariant: return "BranchWithValue" - case branchWithHashedValueVariant: + case BranchWithHashedValueVariant: return "BranchWithHashedValue" - case emptyVariant: + case EmptyVariant: return "Empty" case compactEncodingVariant: return "Compact" diff --git a/lib/trie/db/db.go b/lib/trie/db/db.go index 2c04b28e05..3c9ca0e995 100644 --- a/lib/trie/db/db.go +++ b/lib/trie/db/db.go @@ -8,11 +8,11 @@ import ( "github.com/ChainSafe/gossamer/lib/common" ) -type MemoryDB struct { +type InMemoryDB struct { data map[common.Hash][]byte } -func NewMemoryDBFromProof(encodedNodes [][]byte) (*MemoryDB, error) { +func NewMemoryDBFromProof(encodedNodes [][]byte) (*InMemoryDB, error) { data := make(map[common.Hash][]byte, len(encodedNodes)) for _, encodedProofNode := range encodedNodes { @@ -24,13 +24,13 @@ func NewMemoryDBFromProof(encodedNodes [][]byte) (*MemoryDB, error) { data[nodeHash] = encodedProofNode } - return &MemoryDB{ + return &InMemoryDB{ data: data, }, nil } -func (mdb *MemoryDB) Get(key []byte) (value []byte, err error) { +func (mdb *InMemoryDB) Get(key []byte) (value []byte, err error) { if len(key) < common.HashLength { return nil, fmt.Errorf("expected %d bytes length key, given %d (%x)", common.HashLength, len(key), value) } diff --git a/lib/trie/print_test.go b/lib/trie/print_test.go index ff71d5123f..0cfeb7b743 100644 --- a/lib/trie/print_test.go +++ b/lib/trie/print_test.go @@ -32,6 +32,7 @@ func Test_Trie_String(t *testing.T) { ├── Dirty: false ├── Key: 0x010203 ├── Storage value: 0x030405 +├── Hashed value: false └── Merkle value: nil`, }, "branch_root": { @@ -60,6 +61,7 @@ func Test_Trie_String(t *testing.T) { ├── Dirty: false ├── Key: nil ├── Storage value: 0x0102 +├── Hashed value: false ├── Descendants: 2 ├── Merkle value: nil ├── Child 0 @@ -68,6 +70,7 @@ func Test_Trie_String(t *testing.T) { | ├── Dirty: false | ├── Key: 0x010203 | ├── Storage value: 0x030405 +| ├── Hashed value: false | └── Merkle value: nil └── Child 3 └── Leaf @@ -75,6 +78,7 @@ func Test_Trie_String(t *testing.T) { ├── Dirty: false ├── Key: 0x010203 ├── Storage value: 0x030405 + ├── Hashed value: false └── Merkle value: nil`, }, } diff --git a/lib/trie/proof/proof_test.go b/lib/trie/proof/proof_test.go index 8f813b4f61..311dc4412f 100644 --- a/lib/trie/proof/proof_test.go +++ b/lib/trie/proof/proof_test.go @@ -10,7 +10,7 @@ import ( "github.com/ChainSafe/chaindb" "github.com/ChainSafe/gossamer/lib/trie" - "github.com/ChainSafe/gossamer/lib/trie/db" + "github.com/ChainSafe/gossamer/pkg/trie/triedb" "github.com/stretchr/testify/require" ) @@ -82,15 +82,15 @@ func TestParachainHeaderStateProof(t *testing.T) { require.NoError(t, err) proof := [][]byte{proof1, proof2, proof3, proof4, proof5, proof6, proof7} - expectedValue := proof7 - proofDB, err := db.NewMemoryDBFromProof(proof) + db := NewStorageProof(proof).toMemoryDB() + trieDB := triedb.NewTrieDBBuilder(db, stateRoot).Build() require.NoError(t, err) - trie, err := buildTrie(proof, stateRoot, proofDB) + value, err := trieDB.GetValue(encodeStorageKey) require.NoError(t, err) - value := trie.Get(encodeStorageKey) + require.Equal(t, expectedValue, value) // Also check that we can verify the proof diff --git a/lib/trie/proof/storage_proof.go b/lib/trie/proof/storage_proof.go new file mode 100644 index 0000000000..2b41f76f59 --- /dev/null +++ b/lib/trie/proof/storage_proof.go @@ -0,0 +1,30 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package proof + +import ( + "github.com/ChainSafe/gossamer/pkg/trie/hashdb" + "github.com/ChainSafe/gossamer/pkg/trie/memorydb" +) + +type StorageProof struct { + //TODO: Improve it using sets + trieNodes [][]byte +} + +func (sp *StorageProof) toMemoryDB() hashdb.HashDB { + db := memorydb.NewMemoryDB() + + for _, proof := range sp.trieNodes { + db.Insert(proof) + } + + return db +} + +func NewStorageProof(proof [][]byte) *StorageProof { + return &StorageProof{ + trieNodes: proof, + } +} diff --git a/lib/trie/proof/verify.go b/lib/trie/proof/verify.go index 1e8457f04f..8e08083258 100644 --- a/lib/trie/proof/verify.go +++ b/lib/trie/proof/verify.go @@ -9,12 +9,10 @@ import ( "fmt" "strings" - "github.com/ChainSafe/gossamer/internal/log" "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/internal/trie/pools" "github.com/ChainSafe/gossamer/lib/common" "github.com/ChainSafe/gossamer/lib/trie" - "github.com/ChainSafe/gossamer/lib/trie/db" ) var ( @@ -22,15 +20,14 @@ var ( ErrValueMismatchProofTrie = errors.New("value found in proof trie does not match") ) -var logger = log.NewFromGlobal(log.AddContext("pkg", "proof")) - // Verify verifies a given key and value belongs to the trie by creating // a proof trie based on the encoded proof nodes given. The order of proofs is ignored. // A nil error is returned on success. // Note this is exported because it is imported and used by: // https://github.com/ComposableFi/ibc-go/blob/6d62edaa1a3cb0768c430dab81bb195e0b0c72db/modules/light-clients/11-beefy/types/client_state.go#L78 func Verify(encodedProofNodes [][]byte, rootHash, key, value []byte) (err error) { - proofDB, err := db.NewMemoryDBFromProof(encodedProofNodes) + storageProof := NewStorageProof(encodedProofNodes) + proofDB := storageProof.toMemoryDB() if err != nil { return err @@ -148,8 +145,6 @@ func loadProof(digestToEncoding map[string][]byte, n *node.Node) (err error) { merkleValue := child.MerkleValue encoding, ok := digestToEncoding[string(merkleValue)] - logger.Infof("Node: %x", encoding) - if !ok { inlinedChild := len(child.StorageValue) > 0 || child.HasChild() if inlinedChild { @@ -170,7 +165,6 @@ func loadProof(digestToEncoding map[string][]byte, n *node.Node) (err error) { continue } - logger.Info("loading proof DECODING...") child, err := node.Decode(bytes.NewReader(encoding)) if err != nil { return fmt.Errorf("decoding child node for hash digest 0x%x: %w", diff --git a/lib/trie/proof/verify_test.go b/lib/trie/proof/verify_test.go index db969e1619..78c539f8ea 100644 --- a/lib/trie/proof/verify_test.go +++ b/lib/trie/proof/verify_test.go @@ -8,7 +8,6 @@ import ( "github.com/ChainSafe/gossamer/internal/trie/node" "github.com/ChainSafe/gossamer/lib/trie" - "github.com/ChainSafe/gossamer/lib/trie/db" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -167,8 +166,8 @@ func Test_buildTrie(t *testing.T) { encodeNode(t, leafAShort), } - proofDB, err := db.NewMemoryDBFromProof(encodedProofNodes) - assert.NoError(t, err) + storageProof := NewStorageProof(encodedProofNodes) + proofDB := storageProof.toMemoryDB() return testCase{ encodedProofNodes: encodedProofNodes, @@ -186,8 +185,8 @@ func Test_buildTrie(t *testing.T) { encodeNode(t, leafBLarge), } - proofDB, err := db.NewMemoryDBFromProof(encodedProofNodes) - assert.NoError(t, err) + storageProof := NewStorageProof(encodedProofNodes) + proofDB := storageProof.toMemoryDB() return testCase{ encodedProofNodes: encodedProofNodes, @@ -206,8 +205,8 @@ func Test_buildTrie(t *testing.T) { encodeNode(t, leafBLarge), } - proofDB, err := db.NewMemoryDBFromProof(encodedProofNodes) - assert.NoError(t, err) + storageProof := NewStorageProof(encodedProofNodes) + proofDB := storageProof.toMemoryDB() return testCase{ encodedProofNodes: encodedProofNodes, @@ -235,8 +234,8 @@ func Test_buildTrie(t *testing.T) { encodeNode(t, leafCLarge), // children 2 } - proofDB, err := db.NewMemoryDBFromProof(encodedProofNodes) - assert.NoError(t, err) + storageProof := NewStorageProof(encodedProofNodes) + proofDB := storageProof.toMemoryDB() return testCase{ encodedProofNodes: encodedProofNodes, diff --git a/pkg/trie/hashdb/hashdb.go b/pkg/trie/hashdb/hashdb.go new file mode 100644 index 0000000000..49f3c27bcb --- /dev/null +++ b/pkg/trie/hashdb/hashdb.go @@ -0,0 +1,11 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package hashdb + +import "github.com/ChainSafe/gossamer/lib/common" + +type HashDB interface { + Get(key []byte) (value []byte, err error) + Insert(value []byte) common.Hash +} diff --git a/pkg/trie/memorydb/memory.go b/pkg/trie/memorydb/memory.go new file mode 100644 index 0000000000..bc75a1b09c --- /dev/null +++ b/pkg/trie/memorydb/memory.go @@ -0,0 +1,78 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package memorydb + +import ( + "bytes" + "fmt" + + "github.com/ChainSafe/gossamer/lib/common" +) + +type MemoryDBItem struct { + data []byte + // Reference count + rc int32 +} + +type MemoryDB struct { + data map[common.Hash]MemoryDBItem + hashedNullNode common.Hash + nullNodeData []byte +} + +func NewMemoryDB() *MemoryDB { + return newMemoryDBFromNullNode([]byte{0}, []byte{0}) +} + +func newMemoryDBFromNullNode(nullKey []byte, nullNodeData []byte) *MemoryDB { + hashedKey := common.MustBlake2bHash(nullKey) + + return &MemoryDB{ + data: make(map[common.Hash]MemoryDBItem), + hashedNullNode: hashedKey, + nullNodeData: nullNodeData, + } +} + +func (mdb *MemoryDB) Get(key []byte) (value []byte, err error) { + if len(key) < common.HashLength { + return nil, fmt.Errorf("expected %d bytes length key, given %d (%x)", common.HashLength, len(key), value) + } + var hash common.Hash + copy(hash[:], key) + + if value, found := mdb.data[hash]; found { + return value.data, nil + } + + return nil, nil +} + +func (mdb *MemoryDB) Insert(value []byte) common.Hash { + if bytes.Equal(value, mdb.nullNodeData) { + return mdb.hashedNullNode + } + + key := common.MustBlake2bHash(value) + mdb.emplace(key, value) + return key +} + +func (mdb *MemoryDB) emplace(key common.Hash, value []byte) { + if bytes.Equal(value, mdb.nullNodeData) { + return + } + + data, ok := mdb.data[key] + if !ok { + mdb.data[key] = MemoryDBItem{value, 0} + return + } + + if data.rc <= 0 { + data.data = value + } + data.rc++ +} diff --git a/pkg/trie/triedb/lookup.go b/pkg/trie/triedb/lookup.go new file mode 100644 index 0000000000..41385db471 --- /dev/null +++ b/pkg/trie/triedb/lookup.go @@ -0,0 +1,119 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package triedb + +import ( + "bytes" + "errors" + "fmt" + + "github.com/ChainSafe/gossamer/pkg/trie/hashdb" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibble" +) + +var ErrInvalidStateRoot = errors.New("invalid state root") +var ErrIncompleteDB = errors.New("incomplete database") + +var EmptyValue = []byte{} + +type Lookup struct { + db hashdb.HashDB + hash []byte + //TODO: implement cache and recorder +} + +func NewLookup(db hashdb.HashDB, hash []byte) *Lookup { + return &Lookup{db, hash} +} + +func (l Lookup) Lookup(nibbleKey *nibble.NibbleSlice) ([]byte, error) { + return l.lookupWithoutCache(nibbleKey) +} + +func (l Lookup) lookupWithoutCache(nibbleKey *nibble.NibbleSlice) ([]byte, error) { + partial := nibbleKey + hash := l.hash + keyNibbles := uint(0) + + depth := 0 + + for { + // Get node from DB + nodeData, err := l.db.Get(hash) + + if err != nil { + if depth == 0 { + return nil, ErrInvalidStateRoot + } + return nil, ErrIncompleteDB + } + + // Iterates children + for { + // Decode node + reader := bytes.NewReader(nodeData) + decodedNode, err := decode(reader) + if err != nil { + return nil, fmt.Errorf("decoding node error %s", err.Error()) + } + + // Empty Node + if decodedNode.Type == Empty { + return EmptyValue, nil + } + + var nextNode *NodeHandle = nil + + switch decodedNode.Type { + case Leaf: + // If leaf and matches return value + if partial.Eq(&decodedNode.Partial) { + return l.loadValue(decodedNode.Value) + } + return EmptyValue, nil + // Nibbled branch + case NibbledBranch: + // Get next node + slice := decodedNode.Partial + + if !partial.StartsWith(&slice) { + return EmptyValue, nil + } + + if partial.Len() == slice.Len() { + if decodedNode.Value != nil { + return l.loadValue(decodedNode.Value) + } + } + + nextNode = decodedNode.Children[partial.At(slice.Len())] + if nextNode == nil { + return EmptyValue, nil + } + + partial = partial.Mid(slice.Len() + 1) + keyNibbles += slice.Len() + 1 + } + + if nextNode.Hashed { + hash = nextNode.Data + break + } + + nodeData = nextNode.Data + } + depth++ + } +} + +func (l Lookup) loadValue(value *NodeValue) ([]byte, error) { + if value == nil { + return nil, fmt.Errorf("trying to load value from nil node") + } + if !value.Hashed { + return value.Data, nil + } + + return l.db.Get(value.Data) +} diff --git a/pkg/trie/triedb/nibble/nibble.go b/pkg/trie/triedb/nibble/nibble.go new file mode 100644 index 0000000000..29ded47f61 --- /dev/null +++ b/pkg/trie/triedb/nibble/nibble.go @@ -0,0 +1,53 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package nibble + +const NibblePerByte uint = 2 +const PaddingBitmask byte = 0x0F +const BitPerNibble = 4 + +func padLeft(b byte) byte { + padded := (b & ^PaddingBitmask) + return padded +} + +func padRight(b byte) byte { + padded := (b & PaddingBitmask) + return padded +} + +func NumberPadding(i uint) uint { + return i % NibblePerByte +} + +// Count the biggest common depth between two left aligned packed nibble slice +func biggestDepth(v1, v2 []byte) uint { + upperBound := minLength(v1, v2) + + for i := uint(0); i < upperBound; i++ { + if v1[i] != v2[i] { + return i*NibblePerByte + leftCommon(v1[i], v2[i]) + } + } + return upperBound * NibblePerByte +} + +// LeftCommon the number of common nibble between two left aligned bytes +func leftCommon(a, b byte) uint { + if a == b { + return 2 + } + if padLeft(a) == padLeft(b) { + return 1 + } else { + return 0 + } +} + +func minLength(v1, v2 []byte) uint { + if len(v1) < len(v2) { + return uint(len(v1)) + } + return uint(len(v2)) +} diff --git a/pkg/trie/triedb/nibble/nibble_slice.go b/pkg/trie/triedb/nibble/nibble_slice.go new file mode 100644 index 0000000000..d2ae942200 --- /dev/null +++ b/pkg/trie/triedb/nibble/nibble_slice.go @@ -0,0 +1,83 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package nibble + +// NibbleSlice is a helper structure to store a slice of nibbles and a moving offset +// this is helpful to use it for example while we are looking for a key, we can define the full key in the data and +// moving the offset while we are going deep in the trie +type NibbleSlice struct { + data []byte + offset uint +} + +func NewNibbleSlice(data []byte) *NibbleSlice { + return &NibbleSlice{data, 0} +} + +func NewNibbleSliceWithPadding(data []byte, padding uint) *NibbleSlice { + return &NibbleSlice{data, padding} +} + +func (ns *NibbleSlice) Data() []byte { + return ns.data +} + +func (ns *NibbleSlice) Offset() uint { + return ns.offset +} + +func (ns *NibbleSlice) Mid(i uint) *NibbleSlice { + return &NibbleSlice{ns.data, ns.offset + i} +} + +func (ns *NibbleSlice) Len() uint { + return uint(len(ns.data))*NibblePerByte - ns.offset +} + +func (ns *NibbleSlice) At(i uint) byte { + ix := (ns.offset + i) / NibblePerByte + pad := (ns.offset + i) % NibblePerByte + b := ns.data[ix] + if pad == 1 { + return b & PaddingBitmask + } + return b >> BitPerNibble +} + +func (ns *NibbleSlice) StartsWith(other *NibbleSlice) bool { + return ns.commonPrefix(other) == other.Len() +} + +func (ns *NibbleSlice) Eq(other *NibbleSlice) bool { + return ns.Len() == other.Len() && ns.StartsWith(other) +} + +func (ns *NibbleSlice) commonPrefix(other *NibbleSlice) uint { + selfAlign := ns.offset % NibblePerByte + otherAlign := other.offset % NibblePerByte + if selfAlign == otherAlign { + selfStart := ns.offset / NibblePerByte + otherStart := other.offset / NibblePerByte + first := uint(0) + if selfAlign != 0 { + if padRight(ns.data[selfStart]) != padRight(other.data[otherStart]) { + return 0 + } + selfStart++ + otherStart++ + first++ + } + return biggestDepth(ns.data[selfStart:], other.data[otherStart:]) + first + } + + s := minLength(ns.data, other.data) + i := uint(0) + for i < s { + if ns.At(i) != other.At(i) { + break + } + i++ + } + return i +} diff --git a/pkg/trie/triedb/node_codec.go b/pkg/trie/triedb/node_codec.go new file mode 100644 index 0000000000..ef27ae8dc4 --- /dev/null +++ b/pkg/trie/triedb/node_codec.go @@ -0,0 +1,195 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package triedb + +import ( + "errors" + "fmt" + "io" + + "github.com/ChainSafe/gossamer/internal/trie/node" + "github.com/ChainSafe/gossamer/lib/common" + "github.com/ChainSafe/gossamer/pkg/scale" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibble" +) + +var EmptyNode = &Node{} + +var ( + ErrDecodeStorageValue = errors.New("cannot decode storage value") + ErrDecodeHashedValueTooShort = errors.New("hashed storage value too short") + ErrReadChildrenBitmap = errors.New("cannot read children bitmap") + ErrDecodeChildHash = errors.New("cannot decode child hash") + ErrReaderMismatchCount = errors.New("read unexpected number of bytes from reader") +) + +func decode(reader io.Reader) (n *Node, err error) { + variant, nibbleCount, err := node.DecodeHeader(reader) + + if err != nil { + return nil, fmt.Errorf("decoding header: %w", err) + } + switch variant { + case node.EmptyVariant: + return EmptyNode, nil + case node.LeafVariant, node.LeafWithHashedValueVariant: + n, err = decodeLeaf(reader, variant, nibbleCount) + if err != nil { + return nil, fmt.Errorf("cannot decode leaf: %w", err) + } + return n, nil + case node.BranchVariant, node.BranchWithValueVariant, node.BranchWithHashedValueVariant: + n, err = decodeBranch(reader, variant, nibbleCount) + if err != nil { + return nil, fmt.Errorf("cannot decode branch: %w", err) + } + return n, nil + default: + // this is a programming error, an unknown node variant should be caught by decodeHeader. + panic(fmt.Sprintf("not implemented for node variant %08b", variant)) + } +} + +func decodeBranch(reader io.Reader, variant node.Variant, nibbleCount uint16) (*Node, error) { + // TODO: find a way to solve this without consuming the byte from the reader + /*padding := nibble.NumberPadding(nibbleCount) != 0 + + buffer := make([]byte, 1) + _, err := reader.Read(buffer) + if err != nil { + return nil, fmt.Errorf("reading header byte: %w", err) + } + + if padding && nibble.PadLeft(buffer[0]) != 0 { + return nil, fmt.Errorf("bad format") + }*/ + + partial, err := decodePartialKey(reader, nibbleCount) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrReadChildrenBitmap, err) + } + + partialPadding := nibble.NumberPadding(uint(nibbleCount)) + + childrenBitmap := make([]byte, 2) + _, err = reader.Read(childrenBitmap) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrReadChildrenBitmap, err) + } + + sd := scale.NewDecoder(reader) + nodeValue := &NodeValue{} + + switch variant { + case node.BranchWithValueVariant: + err := sd.Decode(nodeValue.Data) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrDecodeStorageValue, err) + } + case node.BranchWithHashedValueVariant: + nodeValue, err = decodeHashedValue(reader) + if err != nil { + return nil, err + } + case node.BranchVariant: + nodeValue = nil + default: + // Ignored + } + + children := make([]*NodeHandle, node.ChildrenCapacity) + + for i := 0; i < node.ChildrenCapacity; i++ { + if (childrenBitmap[i/8]>>(i%8))&1 != 1 { + continue + } + + var hash []byte + err := sd.Decode(&hash) + if err != nil { + return nil, fmt.Errorf("%w: at index %d: %s", + ErrDecodeChildHash, i, err) + } + + children[i] = &NodeHandle{ + Data: hash, + Hashed: (len(hash) == common.HashLength), + } + } + + return NewNode(NibbledBranch, *nibble.NewNibbleSliceWithPadding(partial, partialPadding), nodeValue, children), nil +} + +func decodeLeaf(reader io.Reader, variant node.Variant, nibbleCount uint16) (*Node, error) { + // TODO: find a way to solve this without consuming the byte from the reader + /*padding := nibble.NumberPadding(nibbleCount) != 0 + + buffer := make([]byte, 1) + _, err := reader.Read(buffer) + if err != nil { + return nil, fmt.Errorf("reading header byte: %w", err) + } + + if padding && nibble.PadLeft(buffer[0]) != 0 { + return nil, fmt.Errorf("bad format") + }*/ + + partial, err := decodePartialKey(reader, nibbleCount) + + if err != nil { + return nil, fmt.Errorf("cannot decode key: %w", err) + } + + partialPadding := nibble.NumberPadding(uint(nibbleCount)) + + nodeValue := &NodeValue{} + + if variant == node.LeafVariant { + sd := scale.NewDecoder(reader) + err := sd.Decode(&nodeValue.Data) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrDecodeStorageValue, err) + } + return NewNode(Leaf, *nibble.NewNibbleSliceWithPadding(partial, partialPadding), nodeValue, nil), nil + } + + nodeValue, err = decodeHashedValue(reader) + + if err != nil { + return nil, err + } + + return NewNode(Leaf, *nibble.NewNibbleSliceWithPadding(partial, partialPadding), nodeValue, nil), nil +} + +func decodeHashedValue(reader io.Reader) (*NodeValue, error) { + buffer := make([]byte, common.HashLength) + n, err := reader.Read(buffer) + if err != nil { + return nil, fmt.Errorf("%w: %s", ErrDecodeStorageValue, err) + } + if n < common.HashLength { + return nil, fmt.Errorf("%w: expected %d, got: %d", ErrDecodeHashedValueTooShort, common.HashLength, n) + } + + return &NodeValue{buffer, true}, nil +} + +func decodePartialKey(reader io.Reader, partialKeyLength uint16) (b []byte, err error) { + if partialKeyLength == 0 { + return []byte{}, nil + } + + nibblePerByte := uint16(nibble.NibblePerByte) + key := make([]byte, partialKeyLength/nibblePerByte+partialKeyLength%nibblePerByte) + n, err := reader.Read(key) + if err != nil { + return nil, fmt.Errorf("reading from reader: %w", err) + } else if n != len(key) { + return nil, fmt.Errorf("%w: read %d bytes instead of expected %d bytes", + ErrReaderMismatchCount, n, len(key)) + } + + return key, nil +} diff --git a/pkg/trie/triedb/node_codec_test.go b/pkg/trie/triedb/node_codec_test.go new file mode 100644 index 0000000000..1a7974b187 --- /dev/null +++ b/pkg/trie/triedb/node_codec_test.go @@ -0,0 +1,122 @@ +// Copyright 2023 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package triedb + +import ( + "bytes" + "io" + "testing" + + "github.com/ChainSafe/gossamer/internal/trie/node" + "github.com/ChainSafe/gossamer/pkg/scale" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibble" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_decodeLeaf(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + reader io.Reader + variant node.Variant + partialKeyLength uint16 + leaf *Node + errWrapped error + errMessage string + }{ + "key_decoding_error": { + reader: bytes.NewBuffer([]byte{ + // missing key data byte + }), + variant: node.LeafVariant, + partialKeyLength: 1, + errWrapped: io.EOF, + errMessage: "cannot decode key: reading from reader: EOF", + }, + "value_decoding_error": { + reader: bytes.NewBuffer(concatByteSlices([][]byte{ + {9}, // key data + {255, 255}, // bad storage value data + })), + variant: node.LeafVariant, + partialKeyLength: 2, + errWrapped: ErrDecodeStorageValue, + errMessage: "cannot decode storage value: unknown prefix for compact uint: 255", + }, + "missing_storage_value_data": { + reader: bytes.NewBuffer([]byte{ + 9, // key data + // missing storage value data + }), + variant: node.LeafVariant, + partialKeyLength: 2, + errWrapped: ErrDecodeStorageValue, + errMessage: "cannot decode storage value: reading byte: EOF", + }, + "empty_storage_value_data": { + reader: bytes.NewBuffer(concatByteSlices([][]byte{ + {9}, // key data + scaleEncodeByteSlice(t, []byte{}), // results to []byte{0} + })), + variant: node.LeafVariant, + partialKeyLength: 2, + leaf: &Node{ + Type: Leaf, + Partial: *nibble.NewNibbleSliceWithPadding([]byte{9}, 0), + Value: &NodeValue{[]byte{}, false}, + }, + }, + "success": { + reader: bytes.NewBuffer(concatByteSlices([][]byte{ + {9}, // key data + scaleEncodeBytes(t, 1, 2, 3, 4, 5), // storage value data + })), + variant: node.LeafVariant, + partialKeyLength: 2, + leaf: &Node{ + Type: Leaf, + Partial: *nibble.NewNibbleSliceWithPadding([]byte{9}, 0), + Value: &NodeValue{[]byte{1, 2, 3, 4, 5}, false}, + }, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + leaf, err := decodeLeaf(testCase.reader, testCase.variant, testCase.partialKeyLength) + + assert.ErrorIs(t, err, testCase.errWrapped) + if err != nil { + assert.EqualError(t, err, testCase.errMessage) + } + assert.Equal(t, testCase.leaf, leaf) + }) + } +} + +func concatByteSlices(slices [][]byte) (concatenated []byte) { + length := 0 + for i := range slices { + length += len(slices[i]) + } + concatenated = make([]byte, 0, length) + for _, slice := range slices { + concatenated = append(concatenated, slice...) + } + return concatenated +} + +func scaleEncodeByteSlice(t *testing.T, b []byte) (encoded []byte) { + encoded, err := scale.Marshal(b) + require.NoError(t, err) + return encoded +} + +func scaleEncodeBytes(t *testing.T, b ...byte) (encoded []byte) { + return scaleEncodeByteSlice(t, b) +} diff --git a/pkg/trie/triedb/nodes.go b/pkg/trie/triedb/nodes.go new file mode 100644 index 0000000000..64c6548466 --- /dev/null +++ b/pkg/trie/triedb/nodes.go @@ -0,0 +1,80 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package triedb + +import ( + "fmt" + "strconv" + + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibble" + "github.com/qdm12/gotree" +) + +type NodeType uint8 + +const ( + Empty NodeType = iota + Leaf + NibbledBranch +) + +type Node struct { + Type NodeType + Partial nibble.NibbleSlice + Value *NodeValue + Children []*NodeHandle +} + +func NewNode(nodeType NodeType, partial nibble.NibbleSlice, value *NodeValue, children []*NodeHandle) *Node { + return &Node{nodeType, partial, value, children} +} + +func (n *Node) String() string { + return n.StringNode().String() +} + +// StringNode returns a gotree compatible node for String methods. +func (n *Node) StringNode() (stringNode *gotree.Node) { + stringNode = gotree.New(fmt.Sprintf("%d", n.Type)) + stringNode.Appendf("Partial: %s", bytesToString(n.Partial.Data())) + if n.Value != nil { + stringNode.Appendf("Value: %s", bytesToString(n.Value.Data)) + } else { + stringNode.Appendf("Value: nil") + } + stringNode.Appendf("Hashed: %s", strconv.FormatBool(n.Value.Hashed)) + if n.Children != nil && len(n.Children) > 0 { + for i, child := range n.Children { + if child == nil { + continue + } + stringNode.Appendf("Child: %d", i) + stringNode.Appendf("Child data: %s", bytesToString(child.Data)) + stringNode.Appendf("Child hashed: %s", strconv.FormatBool(child.Hashed)) + } + } + + return stringNode +} + +type NodeValue struct { + Data []byte + Hashed bool +} + +type NodeHandle struct { + Data []byte + Hashed bool +} + +func bytesToString(b []byte) (s string) { + switch { + case b == nil: + return "nil" + case len(b) <= 20: + return fmt.Sprintf("0x%x", b) + default: + return fmt.Sprintf("0x%x...%x", b[:8], b[len(b)-8:]) + } +} diff --git a/pkg/trie/triedb/trie_db.go b/pkg/trie/triedb/trie_db.go new file mode 100644 index 0000000000..e890abdb2e --- /dev/null +++ b/pkg/trie/triedb/trie_db.go @@ -0,0 +1,33 @@ +// Copyright 2022 ChainSafe Systems (ON) +// SPDX-License-Identifier: LGPL-3.0-only + +package triedb + +import ( + "github.com/ChainSafe/gossamer/pkg/trie/hashdb" + "github.com/ChainSafe/gossamer/pkg/trie/triedb/nibble" +) + +type TrieDBBuilder struct { + db hashdb.HashDB + root []byte + //TODO: implement cache and recorder +} + +func NewTrieDBBuilder(db hashdb.HashDB, root []byte) *TrieDBBuilder { + return &TrieDBBuilder{db, root} +} + +func (tdbb TrieDBBuilder) Build() *TrieDB { + return &TrieDB{tdbb.db, tdbb.root} +} + +type TrieDB struct { + db hashdb.HashDB + root []byte + //TODO: implement cache and recorder +} + +func (tdb TrieDB) GetValue(key []byte) ([]byte, error) { + return NewLookup(tdb.db, tdb.root).Lookup(nibble.NewNibbleSlice(key)) +}