diff --git a/p2p/ipld/read.go b/p2p/ipld/read.go index e08136499a..d58255c76b 100644 --- a/p2p/ipld/read.go +++ b/p2p/ipld/read.go @@ -2,7 +2,6 @@ package ipld import ( "context" - "errors" "fmt" "math" "math/rand" @@ -246,8 +245,8 @@ func GetLeafData( func leafPath(index, total uint32) ([]string, error) { // ensure that the total is a power of two - if total != nextPowerOf2(total) { - return nil, errors.New("expected total to be a power of 2") + if !isPowerOf2(total) { + return nil, fmt.Errorf("expected total to be a power of 2, got %d", total) } if total == 0 { @@ -269,29 +268,7 @@ func leafPath(index, total uint32) ([]string, error) { return path, nil } -// nextPowerOf2 returns the next lowest power of 2 unless the input is a power -// of two, in which case it returns the input -func nextPowerOf2(v uint32) uint32 { - if v == 1 { - return 1 - } - // keep track of the input - i := v - - // find the next highest power using bit mashing - v-- - v |= v >> 1 - v |= v >> 2 - v |= v >> 4 - v |= v >> 8 - v |= v >> 16 - v++ - - // check if the input was the next highest power - if i == v { - return v - } - - // return the next lowest power - return v / 2 +// isPowerOf2 returns checks if a given number is a power of two +func isPowerOf2(v uint32) bool { + return math.Ceil(math.Log2(float64(v))) == math.Floor(math.Log2(float64(v))) } diff --git a/p2p/ipld/read_test.go b/p2p/ipld/read_test.go index 715597298f..aca4b46780 100644 --- a/p2p/ipld/read_test.go +++ b/p2p/ipld/read_test.go @@ -58,36 +58,41 @@ func TestLeafPath(t *testing.T) { } } -func TestNextPowerOf2(t *testing.T) { +func Test_isPowerOf2(t *testing.T) { type test struct { input uint32 - expected uint32 + expected bool } tests := []test{ { input: 2, - expected: 2, + expected: true, }, { input: 11, - expected: 8, + expected: false, }, { input: 511, - expected: 256, + expected: false, + }, + + { + input: 0, + expected: true, }, { input: 1, - expected: 1, + expected: true, }, { - input: 0, - expected: 0, + input: 16, + expected: true, }, } for _, tt := range tests { - res := nextPowerOf2(tt.input) - assert.Equal(t, tt.expected, res) + res := isPowerOf2(tt.input) + assert.Equal(t, tt.expected, res, fmt.Sprintf("input was %d", tt.input)) } } diff --git a/types/block.go b/types/block.go index a77a6e58f4..481fdc9105 100644 --- a/types/block.go +++ b/types/block.go @@ -1381,9 +1381,9 @@ func (data *Data) ComputeShares() (NamespacedShares, int) { msgShares := data.Messages.splitIntoShares() curLen := len(txShares) + len(intermRootsShares) + len(evidenceShares) + len(msgShares) - // FIXME(ismail): this is not a power of two - // see: https://github.com/lazyledger/lazyledger-specs/issues/80 and - wantLen := getNextSquareNum(curLen) + // find the number of shares needed to create a square that has a power of + // two width + wantLen := paddedLen(curLen) // ensure that the min square size is used if wantLen < minSharecount { @@ -1400,10 +1400,32 @@ func (data *Data) ComputeShares() (NamespacedShares, int) { tailShares...), curLen } -func getNextSquareNum(length int) int { - width := int(math.Ceil(math.Sqrt(float64(length)))) - // TODO(ismail): make width a power of two instead - return width * width +// paddedLen calculates the number of shares needed to make a power of 2 square +// given the current number of shares +func paddedLen(length int) int { + width := uint32(math.Ceil(math.Sqrt(float64(length)))) + width = nextHighestPowerOf2(width) + return int(width * width) +} + +// nextPowerOf2 returns the next highest power of 2 unless the input is a power +// of two, in which case it returns the input +func nextHighestPowerOf2(v uint32) uint32 { + if v == 0 { + return 0 + } + + // find the next highest power using bit mashing + v-- + v |= v >> 1 + v |= v >> 2 + v |= v >> 4 + v |= v >> 8 + v |= v >> 16 + v++ + + // return the next highest power + return v } type Message struct { diff --git a/types/block_test.go b/types/block_test.go index 7cc9432426..667e1c303a 100644 --- a/types/block_test.go +++ b/types/block_test.go @@ -1398,6 +1398,70 @@ func TestPutBlock(t *testing.T) { } } +func TestPaddedLength(t *testing.T) { + type test struct { + input, expected int + } + tests := []test{ + {0, 0}, + {1, 1}, + {2, 4}, + {4, 4}, + {5, 16}, + {11, 16}, + {128, 256}, + } + for _, tt := range tests { + res := paddedLen(tt.input) + assert.Equal(t, tt.expected, res) + } +} + +func TestNextHighestPowerOf2(t *testing.T) { + type test struct { + input uint32 + expected uint32 + } + tests := []test{ + { + input: 2, + expected: 2, + }, + { + input: 11, + expected: 16, + }, + { + input: 511, + expected: 512, + }, + { + input: 1, + expected: 1, + }, + { + input: 0, + expected: 0, + }, + { + input: 5, + expected: 8, + }, + { + input: 6, + expected: 8, + }, + { + input: 16, + expected: 16, + }, + } + for _, tt := range tests { + res := nextHighestPowerOf2(tt.input) + assert.Equal(t, tt.expected, res) + } +} + func generateRandomMsgOnlyData(msgCount int) Data { out := make([]Message, msgCount) for i, msg := range generateRandNamespacedRawData(msgCount, NamespaceSize, MsgShareSize-2) {