Skip to content

Commit

Permalink
Add Bitlist
Browse files Browse the repository at this point in the history
  • Loading branch information
hwwhww committed Jul 4, 2019
1 parent 9c72b93 commit 75de33e
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 0 deletions.
2 changes: 2 additions & 0 deletions ssz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from .sedes import ( # noqa: F401
BaseSedes,
BasicSedes,
Bitlist,
Bitvector,
Boolean,
Byte,
ByteList,
Expand Down
3 changes: 3 additions & 0 deletions ssz/sedes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
BasicSedes,
CompositeSedes,
)
from .bitlist import ( # noqa: F401
Bitlist,
)
from .bitvector import ( # noqa: F401
Bitvector,
)
Expand Down
80 changes: 80 additions & 0 deletions ssz/sedes/bitlist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import (
Sequence,
Union,
)

from ssz.exceptions import (
DeserializationError,
SerializationError,
)
from ssz.sedes.base import (
BaseCompositeSedes,
)
from ssz.utils import (
merkleize,
mix_in_length,
pack_bitvector_bitlist,
)

BytesOrByteArray = Union[bytes, bytearray]


class Bitlist(BaseCompositeSedes[BytesOrByteArray, bytes]):
def __init__(self, length: int) -> None:
if length < 0:
raise TypeError("Max length cannot be negative")
self.length = length

#
# Size
#
is_fixed_sized = False

def get_fixed_size(self):
raise ValueError("byte list has no static size")

#
# Serialization
#
def serialize(self, value: BytesOrByteArray) -> bytes:
len_value = len(value)
if len_value > self.length:
raise SerializationError(
f"Cannot serialize length {len_value} bytes as Bitlist[{self.length}]"
)

if len_value == 0:
return b'\x01'

as_bytearray = [0] * (len_value // 8 + 1)
for i in range(len_value):
as_bytearray[i // 8] |= value[i] << (i % 8)
as_bytearray[len_value // 8] |= 1 << (len_value % 8)
return bytes(as_bytearray)

#
# Deserialization
#
def deserialize(self, data: bytes) -> bytes:
as_integer = int.from_bytes(data, 'little')
len_value = get_bitlist_len(as_integer)

if len_value > self.length:
raise DeserializationError(
f"Cannot deserialize length {len_value} data as bytes{self.length}"
)

return tuple(
bool((data[index // 8] >> index % 8) % 2)
for index in range(len_value)
)

#
# Tree hashing
#
def hash_tree_root(self, value: Sequence[bool]) -> bytes:
return mix_in_length(merkleize(pack_bitvector_bitlist(value)), len(value))


def get_bitlist_len(x):
return x.bit_length() - 1
49 changes: 49 additions & 0 deletions tests/sedes/test_bitlist_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import pytest

from ssz import (
decode,
encode,
)
from ssz.sedes import (
Bitlist,
)


@pytest.mark.parametrize(
'size, value, expected',
(
(16, tuple(), b'\x01'),
(16, (0b1, 0b0,), b'\x05'),
(16, (0b1,) + (0b0,) * 15, b'\x01\x00\x01'),
),
)
def test_bitlist_serialize_values(size, value, expected):
Foo = Bitlist(size)
assert encode(value, Foo) == expected
assert Foo.serialize(bytearray(value)) == expected


@pytest.mark.parametrize(
'size, value,expected',
(
(16, b'\x01', tuple()),
(16, b'\x05', (True, False,)),
(16, b'\x01\x00\x01', (True,) + (False,) * 15),
),
)
def test_bitlist_deserialize_values(size, value, expected):
Foo = Bitlist(size)
assert Foo.deserialize(value) == expected


@pytest.mark.parametrize(
'size, value',
(
# (16, tuple()),
(16, (True, False,)),
(16, (True,) + (False,) * 15),
),
)
def test_bitlist_round_trip_no_sedes(size, value):
Foo = Bitlist(size)
assert decode(encode(value, Foo), Foo) == value
18 changes: 18 additions & 0 deletions tests/tree_hash/test_composite_tree_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@
hash_eth2 as h,
)
from ssz.sedes import (
Bitlist,
Bitvector,
ByteVector,
Container,
List,
Vector,
uint128,
)
from ssz.utils import (
pad_zeros,
)

bytes16 = ByteVector(16)
EMPTY_BYTES = b"\x00" * 16
Expand Down Expand Up @@ -169,3 +173,17 @@ def test_container(bytes16_fields, result):
def test_bitvector(size, value, result):
Foo = Bitvector(size)
assert ssz.hash_tree_root(value, Foo) == result


@pytest.mark.parametrize(
("size", "value", "result"),
(
(8, (1, 1, 0, 1, 0, 1, 0, 0), h(pad_zeros(b"\x2b") + pad_zeros(b"\x08"))),
(512, tuple(1 for i in range(512)), h(
h(b"\xff" * 32 + b"\xff" * 32) + pad_zeros(b'\x00\x02')
)),
)
)
def test_bitlist(size, value, result):
Foo = Bitlist(size)
assert ssz.hash_tree_root(value, Foo) == result

0 comments on commit 75de33e

Please sign in to comment.