diff --git a/nmt.go b/nmt.go index e9c318a..a2b8c1f 100644 --- a/nmt.go +++ b/nmt.go @@ -190,13 +190,176 @@ func (n *NamespacedMerkleTree) ProveRange(start, end int) (Proof, error) { if err := n.validateRange(start, end); err != nil { return NewEmptyRangeProof(isMaxNsIgnored), err } - proof, err := n.buildRangeProof(start, end) + proof, _, err := n.buildRangeProof(start, end, false) if err != nil { return Proof{}, err } return NewInclusionProof(start, end, proof, isMaxNsIgnored), nil } +// Coordinate identifies a tree node using the depth and position +// +// Depth Position +// 0 0 +// / \ +// / \ +// 1 0 1 +// /\ /\ +// 2 0 1 2 3 +// /\ /\ /\ /\ +// 3 0 1 2 3 4 5 6 7 +type Coordinate struct { + // depth is the typical depth of a tree, 0 being the root. + // TODO test 0 case and 7 leaves all nodes + depth int + // position is the index of the node at the provided depth, 0 being the left most + // node. + position int +} + +func (coordinate Coordinate) Validate() error { + if coordinate.depth < 0 { + // TODO: test this case + return fmt.Errorf("depth cannot be negative: %d", coordinate.depth) + } + if coordinate.position < 0 { + // TODO: test this case + return fmt.Errorf("position cannot be negative: %d", coordinate.position) + } + return nil +} + +// ProveInner takes a list of inner nodes coordinates and returns the corresponding +// inner proof. +// The range used to build the proof +// TODO: range is consecutive +// TODO investigate the range in drawIO +// TODO: write test for that case and see if it works +func (n *NamespacedMerkleTree) ProveInner(coordinates []Coordinate) (InnerProof, error) { + isMaxNsIgnored := n.treeHasher.IsMaxNamespaceIDIgnored() + start, end, err := toRange(coordinates, n.Size()) + if err != nil { + return InnerProof{}, err + } + proof, coordinates, err := n.buildRangeProof(start, end, true) + if err != nil { + return InnerProof{}, err + } + return NewInnerInclusionProof(proof, coordinates, n.Size(), isMaxNsIgnored), nil +} + +// toRange takes a list of coordinates and a tree size, then converts +// that range to a list of leaves. +// The returned range is consecutive even if the coordinates refer to +// a disjoint range. +// For example, in an eight leaves tree: +// +// Depth Position +// 0 0 +// / \ +// / \ +// 1 0 1 +// /\ /\ +// 2 0 1 2 3 +// /\ /\ /\ /\ +// 3 0 1 2 3 4 5 6 7 +// +// If the provided coordinates are {2, 0}, which cover the range [0, 2) +// and {2, 3}, which cover the range [6, 8), the returned range will be [0, 8). +func toRange(coordinates []Coordinate, treeSize int) (int, int, error) { + for _, coordinate := range coordinates { + // TODO: test this case + if err := coordinate.Validate(); err != nil { + return 0, 0, fmt.Errorf("coordinate {%d, %d} is invalid: %w", coordinate.depth, coordinate.position, err) + } + } + if treeSize < 0 { + // TODO: test this case + return 0, 0, fmt.Errorf("tree size %d cannot be stricly negative", treeSize) + } + // TODO check the case where treeSize == 0 and multiple coordinates, and what are the possibilities. + start := 0 + end := 0 + maxDepth, err := maxDepth(treeSize) + if err != nil { + return 0, 0, err + } + for _, coordinate := range coordinates { + currentStart, err := startLeafIndex(coordinate, maxDepth) + if err != nil { + return 0, 0, err + } + currentEnd, err := endLeafIndex(coordinate, maxDepth) + if err != nil { + return 0, 0, err + } + if currentEnd < start { + start = currentStart + } + if currentEnd > end { + end = currentEnd + } + } + return start, end, nil +} + +// maxDepth returns the maximum depth of a tree with treeSize +// number of leaves. +func maxDepth(treeSize int) (int, error) { + if treeSize < 0 { + // TODO: test this case + return 0, fmt.Errorf("tree size %d cannot be stricly negative", treeSize) + } + return bits.Len(uint(treeSize)) - 1, nil +} + +// endLeafIndex returns the index of range's end leaf covered by the provided +// inner node coordinates. +// The max depth is provided to know at which level to stop. +// Note: the formula used is based on: +// - end_leaf = start_leaf + (2 ** height) +// with position being the index of the inner node inside the tree +// and the height being the traditional height of a tree, i.e., bottom -> top. +func endLeafIndex(coordinate Coordinate, maxDepth int) (int, error) { + if err := coordinate.Validate(); err != nil { + return 0, err + } + if maxDepth < coordinate.depth { + return 0, fmt.Errorf("max depth %d cannot be stricly smaller than the coordinates depth %d", maxDepth, coordinate.depth) + } + // since the coordinates are expressed in depth, we need to calculate the height + // using: maxDepth = height + depth + height := maxDepth - coordinate.depth + // the bit shift is used to compute 2 ** height. + subtreeSize := 1 << height + return (coordinate.position + 1) * subtreeSize, nil +} + +// startLeafIndex returns the index of the range's start leaf covered by +// the provided inner node coordinates. +// The max depth is provided to know at which level to stop. +// Note: the formula used is based on: +// - start_leaf = position * (2 ** height) +// with position being the index of the inner node inside the tree +// and the height being the traditional height of a tree, i.e., bottom -> top. +func startLeafIndex(coordinate Coordinate, maxDepth int) (int, error) { + if err := coordinate.Validate(); err != nil { + // TODO: test this + return 0, err + } + if maxDepth < coordinate.depth { + // TODO: test this + return 0, fmt.Errorf("max depth %d cannot be stricly smaller than the coordinates depth %d", maxDepth, coordinate.depth) + } + // since the coordinates are expressed in depth, we need to calculate the height + // using: maxDepth = height + depth + height := maxDepth - coordinate.depth + // In an RFC-6962 merkle tree, the tree height increases with every multiple of 2. + // For example, for all the trees of size 4 to 7, the RFC-6962 tree will have a height of 3. + subtreeSize := 1 << height + return coordinate.position * subtreeSize, nil +} + // ProveNamespace returns a range proof for the given NamespaceID. // // case 1) If the namespace nID is out of the range of the tree's min and max @@ -265,7 +428,7 @@ func (n *NamespacedMerkleTree) ProveNamespace(nID namespace.ID) (Proof, error) { // the tree or calculated the range it would be in (to generate a proof of // absence and to return the corresponding leaf hashes). - proof, err := n.buildRangeProof(proofStart, proofEnd) + proof, _, err := n.buildRangeProof(proofStart, proofEnd, false) if err != nil { return Proof{}, err } @@ -289,14 +452,20 @@ func (n *NamespacedMerkleTree) validateRange(start, end int) error { // buildRangeProof returns the nodes (as byte slices) in the range proof of the // supplied range i.e., [proofStart, proofEnd) where proofEnd is non-inclusive. // The nodes are ordered according to in order traversal of the namespaced tree. +// If the saveInnerNodesCoordinates flag is set to true, the method also returns +// the coordinates of the range proof's nodes in the same +// order. These can be used for creating inner nodes proofs. // Any errors returned by this method are irrecoverable and indicate an illegal state of the tree (n). -func (n *NamespacedMerkleTree) buildRangeProof(proofStart, proofEnd int) ([][]byte, error) { - proof := [][]byte{} // it is the list of nodes hashes (as byte slices) with no index +func (n *NamespacedMerkleTree) buildRangeProof(proofStart, proofEnd int, saveProofNodesCoordinates bool) ([][]byte, []Coordinate, error) { + var proof [][]byte // it is the list of nodes hashes (as byte slices) with no index + // the list of the proof nodes coordinates. + // gets populated if the saveProofNodesCoordinates flag is set + var coordinates []Coordinate var recurse func(start, end int, includeNode bool) ([]byte, error) // validate the range if err := n.validateRange(proofStart, proofEnd); err != nil { - return nil, err + return nil, nil, err } // start, end are indices of leaves in the tree hence they should be within @@ -318,6 +487,16 @@ func (n *NamespacedMerkleTree) buildRangeProof(proofStart, proofEnd int) ([][]by if (start < proofStart || start >= proofEnd) && includeNode { // add the leafHash to the proof proof = append(proof, leafHash) + if saveProofNodesCoordinates { + maxDepth, err := maxDepth(n.Size()) + if err != nil { + return nil, err + } + coordinates = append(coordinates, Coordinate{ + depth: maxDepth, + position: start, + }) + } } // if the index of the leaf is within the queried range i.e., // [proofStart, proofEnd] OR if the leaf is not required as part of @@ -368,6 +547,13 @@ func (n *NamespacedMerkleTree) buildRangeProof(proofStart, proofEnd int) ([][]by // of the proof but not its left and right subtrees if includeNode && !newIncludeNode { proof = append(proof, hash) + coordinate, err := ToCoordinate(start, end, n.Size()) + if err != nil { + return nil, err + } + if saveProofNodesCoordinates { + coordinates = append(coordinates, coordinate) + } } return hash, nil @@ -378,9 +564,57 @@ func (n *NamespacedMerkleTree) buildRangeProof(proofStart, proofEnd int) ([][]by fullTreeSize = 1 } if _, err := recurse(0, fullTreeSize, true); err != nil { - return nil, err - } - return proof, nil + return nil, nil, err + } + return proof, coordinates, nil +} + +// ToCoordinate takes a start leaf index, an end exclusive leaf index +// and a tree size and returns the coordinates of the node +// that covers that whole range. +// The target node can either be a leaf node if the range contains +// a single element, i.e. (end-start == 1), or an inner node. +// The coordinate calculation follows the RFC-6962 standard. +// This means that leaves get elevated in trees that have +// a size that is not a power of 2. +// Important: The inputs need to satisfy the following criteria: +// - start >= 0 +// - end > start +// - treeSize >= end +// Otherwise, a sensible error is returned. +// Note: the formula used is based on: +// - start_leaf = position * (2 ** height) +// - end_leaf = start_leaf + (2 ** height) +// with position being the index of the inner node inside the tree +// and the height being the traditional height of a tree, i.e. bottom -> top. +func ToCoordinate(start, end, treeSize int) (Coordinate, error) { + if start < 0 { + return Coordinate{}, fmt.Errorf("start cannot be stricly negative: %d", start) + } + if end <= start { + return Coordinate{}, fmt.Errorf("end %d cannot be smaller than start %d", end, start) + } + if treeSize < end { + return Coordinate{}, fmt.Errorf("tree size %d cannot be smaller than end %d", treeSize, end) + } + + // calculates the height of the smallest subtree + // that can contain the [start, end) range. + // bits.Len() - 1 is used as a fast alternative to compute + // the integer part of the result of log2(end-start). + height := bits.Len(uint(end-start)) - 1 + maxDepth, err := maxDepth(treeSize) + if err != nil { + // TODO test this case + return Coordinate{}, err + } + // 1 << height == 2 ** height. This result is based + // on the formula documented above. + position := start / (1 << height) + return Coordinate{ + depth: maxDepth - height, + position: position, + }, nil } // Get returns leaves for the given namespace.ID. diff --git a/nmt_test.go b/nmt_test.go index 6e0565e..79b1504 100644 --- a/nmt_test.go +++ b/nmt_test.go @@ -906,7 +906,7 @@ func Test_buildRangeProof_Err(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := tt.tree.buildRangeProof(tt.proofStart, tt.proofEnd) + _, _, err := tt.tree.buildRangeProof(tt.proofStart, tt.proofEnd) assert.Equal(t, tt.wantErr, err != nil) if tt.wantErr { assert.True(t, errors.Is(err, tt.errType)) @@ -1175,3 +1175,89 @@ func TestForcedOutOfOrderNamespacedMerkleTree(t *testing.T) { assert.NoError(t, err) } } + +func TestToCoordinate(t *testing.T) { + tests := []struct { + start, end, treeSize int + expectError bool + expectedCoordinate Coordinate + }{ + {start: 0, end: 1, treeSize: 1, expectError: false, expectedCoordinate: Coordinate{depth: 0, position: 0}}, + {start: 0, end: 2, treeSize: 2, expectError: false, expectedCoordinate: Coordinate{depth: 0, position: 0}}, + {start: 0, end: 3, treeSize: 4, expectError: false, expectedCoordinate: Coordinate{depth: 1, position: 0}}, + {start: 1, end: 2, treeSize: 4, expectError: false, expectedCoordinate: Coordinate{depth: 2, position: 1}}, + {start: 0, end: 4, treeSize: 4, expectError: false, expectedCoordinate: Coordinate{depth: 0, position: 0}}, + {start: 0, end: 3, treeSize: 5, expectError: false, expectedCoordinate: Coordinate{depth: 1, position: 0}}, + {start: 0, end: 3, treeSize: 6, expectError: false, expectedCoordinate: Coordinate{depth: 1, position: 0}}, + {start: 0, end: 3, treeSize: 7, expectError: false, expectedCoordinate: Coordinate{depth: 1, position: 0}}, + {start: 2, end: 4, treeSize: 8, expectError: false, expectedCoordinate: Coordinate{depth: 2, position: 1}}, + {start: 3, end: 4, treeSize: 8, expectError: false, expectedCoordinate: Coordinate{depth: 3, position: 3}}, + {start: 0, end: 8, treeSize: 8, expectError: false, expectedCoordinate: Coordinate{depth: 0, position: 0}}, + // TODO add false cases and add all cases for 7 leaves tree + } + + for _, test := range tests { + result, err := ToCoordinate(test.start, test.end, test.treeSize) + if test.expectError { + assert.Error(t, err) + } else { + assert.Equal(t, test.expectedCoordinate, result) + } + } +} + +func TestBuildRangeProofCoordinates(t *testing.T) { + tests := []struct { + leavesNID []byte + proofStart int + proofEnd int + expected []Coordinate + expectError bool + }{ + { + leavesNID: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + proofStart: 2, + proofEnd: 4, + expected: []Coordinate{{depth: 2, position: 0}, {depth: 1, position: 1}}, + expectError: false, + }, + { + leavesNID: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + proofStart: 1, + proofEnd: 3, + expected: []Coordinate{{depth: 3, position: 0}, {depth: 3, position: 3}, {depth: 1, position: 1}}, + expectError: false, + }, + { + leavesNID: []byte{1, 2, 3, 4, 5, 6, 7, 8}, + proofStart: 0, + proofEnd: 8, + expected: []Coordinate{}, + expectError: false, + }, + // TODO add all cases for 7 leaves tree + } + + for _, test := range tests { + tree := exampleNMT(1, true, test.leavesNID...) + _, coords, err := tree.buildRangeProof(test.proofStart, test.proofEnd) + if (err != nil) != test.expectError { + t.Fatalf("expected error: %v, got: %v", test.expectError, err) + } + if !test.expectError && !equalCoordinates(coords, test.expected) { + t.Fatalf("expected: %+v, got: %+v", test.expected, coords) + } + } +} + +func equalCoordinates(a, b []Coordinate) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/proof.go b/proof.go index 998b9a8..798e4bc 100644 --- a/proof.go +++ b/proof.go @@ -9,7 +9,7 @@ import ( "math/bits" "github.com/celestiaorg/nmt/namespace" - pb "github.com/celestiaorg/nmt/pb" + "github.com/celestiaorg/nmt/pb" ) var ( @@ -404,6 +404,168 @@ func (proof Proof) VerifyLeafHashes(nth *NmtHasher, verifyCompleteness bool, nID return bytes.Equal(rootHash, root), nil } +// InnerProof contains an inclusion proof for a set of inner nodes to the NMT. +// Currently, the inner proof generation only supports adjacent ranges even if the +// provided coordinates cover a disjoint range. +// For example, if the inner nodes of the proof represent the ranges [1, 3), [6, 10), +// the generated proof will be targeting the range [1, 10). +// However, the inner proof verification can take any inner proof, even if the represented +// range is disjointed, and will verify it accordingly. +type InnerProof struct { + // nodes the proof inner nodes needed to verify inclusion. + nodes [][]byte + // coordinates the coordinates of the above nodes in the + // same order. + coordinates []Coordinate + // treeSize the size of the tree, i.e., the number of leaves. + treeSize int + // isMaxNamespaceIDIgnored whether to ignore the maximum namespace IDs. + isMaxNamespaceIDIgnored bool +} + +// TODO add marshallers and protobuf definitions + +// NewInnerInclusionProof constructs a proof proving that a set of inner nodes is +// included in an NMT. +// Does not validate the inputs. +func NewInnerInclusionProof(proofNodes [][]byte, coordinates []Coordinate, treeSize int, ignoreMaxNamespace bool) InnerProof { + return InnerProof{ + nodes: proofNodes, + coordinates: coordinates, + treeSize: treeSize, + isMaxNamespaceIDIgnored: ignoreMaxNamespace, + } +} + +// VerifyInnerNodes +// coordinates should be in the same order as inner nodes +func (proof InnerProof) VerifyInnerNodes(nth *NmtHasher, innerNodes [][]byte, coordinates []Coordinate, root []byte) (bool, error) { + if len(innerNodes) != len(coordinates) { + return false, fmt.Errorf("the number of inner nodes %d is different than the number of coordinates %d", len(innerNodes), len(coordinates)) + } + + // check that the root is valid w.r.t the NMT hasher + if err := nth.ValidateNodeFormat(root); err != nil { + return false, fmt.Errorf("root does not match the NMT hasher's hash format: %w", err) + } + // check that all the proof.nodes are valid w.r.t the NMT hasher + for _, node := range proof.nodes { + if err := nth.ValidateNodeFormat(node); err != nil { + return false, fmt.Errorf("proof nodes do not match the NMT hasher's hash format: %w", err) + } + } + // check that all the leafHashes are valid w.r.t the NMT hasher + for _, leafHash := range innerNodes { + if err := nth.ValidateNodeFormat(leafHash); err != nil { + return false, fmt.Errorf("leaf hash does not match the NMT hasher's hash format: %w", err) + } + } + + _, proofEnd, err := toRange(coordinates, proof.treeSize) + if err != nil { + return false, err + } + + allNodes := append(proof.nodes, innerNodes...) + allCoordinates := append(proof.coordinates, coordinates...) + + var computeRoot func(start, end int) ([]byte, error) + // computeRoot can return error iff the HashNode function fails while calculating the root + computeRoot = func(start, end int) ([]byte, error) { + innerNode, found, err := getInnerNode(allNodes, allCoordinates, proof.treeSize, start, end) + if err != nil { + return nil, err + } + if found { + return innerNode, nil + } + + // Recursively get left and right subtree + k := getSplitPoint(end - start) + left, found, err := getInnerNode(allNodes, allCoordinates, proof.treeSize, start, start+k) + if err != nil { + return nil, err + } + // if a node is found, we could optimize by removing it from the list of nodes. + if !found { + left, err = computeRoot(start, start+k) + if err != nil { + return nil, fmt.Errorf("failed to compute subtree root [%d, %d): %w", start, start+k, err) + } + } + right, found, err := getInnerNode(allNodes, allCoordinates, proof.treeSize, start+k, end) + if err != nil { + return nil, err + } + // Similarly, if a node is found, we could optimize by removing it from the list of nodes. + if !found { + right, err = computeRoot(start+k, end) + if err != nil { + return nil, fmt.Errorf("failed to compute subtree root [%d, %d): %w", start+k, end, err) + } + } + + // only right leaf/subtree can be non-existent + if right == nil { + // TODO test with coordinates + return left, nil + } + hash, err := nth.HashNode(left, right) + if err != nil { + return nil, fmt.Errorf("failed to hash node: %w", err) + } + return hash, nil + } + + // estimate the leaf size of the subtree containing the proof range + proofRangeSubtreeEstimate := getSplitPoint(proofEnd) * 2 + if proofRangeSubtreeEstimate < 1 { + proofRangeSubtreeEstimate = 1 + } + rootHash, err := computeRoot(0, proofRangeSubtreeEstimate) + if err != nil { + return false, fmt.Errorf("failed to compute root [%d, %d): %w", 0, proofRangeSubtreeEstimate, err) + } + for i := 0; i < len(proof.nodes); i++ { + rootHash, err = nth.HashNode(rootHash, proof.nodes[i]) + if err != nil { + return false, fmt.Errorf("failed to hash node: %w", err) + } + } + + return bytes.Equal(rootHash, root), nil +} + +// getInnerNode takes a list of nodes and coordinates and returns the inner node +// corresponding to the [start, end) range. +// Expects the number of nodes and coordinates to be in the same order. +// Otherwise, the returned node might not be the correct one. +func getInnerNode(nodes [][]byte, coordinates []Coordinate, treeSize int, start int, end int) ([]byte, bool, error) { + if start < 0 { + return nil, false, fmt.Errorf("range start %d cannot be strictly negative", start) + } + if end <= start { + return nil, false, fmt.Errorf("range end %d cannot be smaller than start %d", end, start) + } + if treeSize < end { + return nil, false, fmt.Errorf("tree size %d cannot be strictly smaller than the end of range %d", treeSize, end) + } + // TODO: test these validates + for index, coordinate := range coordinates { + if err := coordinate.Validate(); err != nil { + return nil, false, err + } + startLeaf, endLeaf, err := toRange([]Coordinate{coordinate}, treeSize) + if err != nil { + return nil, false, err + } + if startLeaf == start && endLeaf == end { + return nodes[index], true, nil + } + } + return nil, false, nil +} + // VerifyInclusion checks that the inclusion proof is valid by using leaf data // and the provided proof to regenerate and compare the root. Note that the leavesWithoutNamespace data should not contain the prefixed namespace, unlike the tree.Push method, // which takes prefixed data. All leaves implicitly have the same namespace ID: diff --git a/proof_test.go b/proof_test.go index 235a403..b6e177c 100644 --- a/proof_test.go +++ b/proof_test.go @@ -122,7 +122,7 @@ func TestProof_VerifyNamespace_False(t *testing.T) { t.Fatalf("invalid test setup: error on ProveNamespace(): %v", err) } // inclusion proof of the leaf index 0 - incProof0, err := n.buildRangeProof(0, 1) + incProof0, _, err := n.buildRangeProof(0, 1) require.NoError(t, err) incompleteFirstNs := NewInclusionProof(0, 1, incProof0, false) type args struct { @@ -135,13 +135,13 @@ func TestProof_VerifyNamespace_False(t *testing.T) { // an invalid absence proof for an existing namespace ID (2) in the constructed tree leafIndex := 3 - inclusionProofOfLeafIndex, err := n.buildRangeProof(leafIndex, leafIndex+1) + inclusionProofOfLeafIndex, _, err := n.buildRangeProof(leafIndex, leafIndex+1) require.NoError(t, err) leafHash := n.leafHashes[leafIndex] // the only data item with namespace ID = 2 in the constructed tree is at index 3 invalidAbsenceProof := NewAbsenceProof(leafIndex, leafIndex+1, inclusionProofOfLeafIndex, leafHash, false) // inclusion proof of the leaf index 10 - incProof10, err := n.buildRangeProof(10, 11) + incProof10, _, err := n.buildRangeProof(10, 11) require.NoError(t, err) // root @@ -229,6 +229,14 @@ func TestProof_VerifyNamespace_False(t *testing.T) { } } +func TestInnerProofs(t *testing.T) { + n := exampleNMT(1, true, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) + proof, err := n.ProveInner([]Coordinate{ + {1, 0}, {3, 4}, {4, 10}, + }) + assert.NoError(t, err) + assert.NotNil(t, proof) // this is just to stop the debugger here and see if the proof is valid +} func TestProof_MultipleLeaves(t *testing.T) { n := New(sha256.New()) ns := []byte{1, 2, 3, 4, 5, 6, 7, 8}