From 75de33e75604c2f007adfc77b6b627672ddbd118 Mon Sep 17 00:00:00 2001 From: Hsiao-Wei Wang Date: Thu, 4 Jul 2019 15:35:42 +0800 Subject: [PATCH] Add `Bitlist` --- ssz/__init__.py | 2 + ssz/sedes/__init__.py | 3 + ssz/sedes/bitlist.py | 80 +++++++++++++++++++++ tests/sedes/test_bitlist_serializer.py | 49 +++++++++++++ tests/tree_hash/test_composite_tree_hash.py | 18 +++++ 5 files changed, 152 insertions(+) create mode 100644 ssz/sedes/bitlist.py create mode 100644 tests/sedes/test_bitlist_serializer.py diff --git a/ssz/__init__.py b/ssz/__init__.py index 765095ac..9b294e5d 100644 --- a/ssz/__init__.py +++ b/ssz/__init__.py @@ -13,6 +13,8 @@ from .sedes import ( # noqa: F401 BaseSedes, BasicSedes, + Bitlist, + Bitvector, Boolean, Byte, ByteList, diff --git a/ssz/sedes/__init__.py b/ssz/sedes/__init__.py index d996a548..2542a144 100644 --- a/ssz/sedes/__init__.py +++ b/ssz/sedes/__init__.py @@ -8,6 +8,9 @@ BasicSedes, CompositeSedes, ) +from .bitlist import ( # noqa: F401 + Bitlist, +) from .bitvector import ( # noqa: F401 Bitvector, ) diff --git a/ssz/sedes/bitlist.py b/ssz/sedes/bitlist.py new file mode 100644 index 00000000..ffeaca3b --- /dev/null +++ b/ssz/sedes/bitlist.py @@ -0,0 +1,80 @@ +from typing import ( + Sequence, + Union, +) + +from ssz.exceptions import ( + DeserializationError, + SerializationError, +) +from ssz.sedes.base import ( + BaseCompositeSedes, +) +from ssz.utils import ( + merkleize, + mix_in_length, + pack_bitvector_bitlist, +) + +BytesOrByteArray = Union[bytes, bytearray] + + +class Bitlist(BaseCompositeSedes[BytesOrByteArray, bytes]): + def __init__(self, length: int) -> None: + if length < 0: + raise TypeError("Max length cannot be negative") + self.length = length + + # + # Size + # + is_fixed_sized = False + + def get_fixed_size(self): + raise ValueError("byte list has no static size") + + # + # Serialization + # + def serialize(self, value: BytesOrByteArray) -> bytes: + len_value = len(value) + if len_value > self.length: + raise SerializationError( + f"Cannot serialize length {len_value} bytes as Bitlist[{self.length}]" + ) + + if len_value == 0: + return b'\x01' + + as_bytearray = [0] * (len_value // 8 + 1) + for i in range(len_value): + as_bytearray[i // 8] |= value[i] << (i % 8) + as_bytearray[len_value // 8] |= 1 << (len_value % 8) + return bytes(as_bytearray) + + # + # Deserialization + # + def deserialize(self, data: bytes) -> bytes: + as_integer = int.from_bytes(data, 'little') + len_value = get_bitlist_len(as_integer) + + if len_value > self.length: + raise DeserializationError( + f"Cannot deserialize length {len_value} data as bytes{self.length}" + ) + + return tuple( + bool((data[index // 8] >> index % 8) % 2) + for index in range(len_value) + ) + + # + # Tree hashing + # + def hash_tree_root(self, value: Sequence[bool]) -> bytes: + return mix_in_length(merkleize(pack_bitvector_bitlist(value)), len(value)) + + +def get_bitlist_len(x): + return x.bit_length() - 1 diff --git a/tests/sedes/test_bitlist_serializer.py b/tests/sedes/test_bitlist_serializer.py new file mode 100644 index 00000000..796dc87a --- /dev/null +++ b/tests/sedes/test_bitlist_serializer.py @@ -0,0 +1,49 @@ +import pytest + +from ssz import ( + decode, + encode, +) +from ssz.sedes import ( + Bitlist, +) + + +@pytest.mark.parametrize( + 'size, value, expected', + ( + (16, tuple(), b'\x01'), + (16, (0b1, 0b0,), b'\x05'), + (16, (0b1,) + (0b0,) * 15, b'\x01\x00\x01'), + ), +) +def test_bitlist_serialize_values(size, value, expected): + Foo = Bitlist(size) + assert encode(value, Foo) == expected + assert Foo.serialize(bytearray(value)) == expected + + +@pytest.mark.parametrize( + 'size, value,expected', + ( + (16, b'\x01', tuple()), + (16, b'\x05', (True, False,)), + (16, b'\x01\x00\x01', (True,) + (False,) * 15), + ), +) +def test_bitlist_deserialize_values(size, value, expected): + Foo = Bitlist(size) + assert Foo.deserialize(value) == expected + + +@pytest.mark.parametrize( + 'size, value', + ( + # (16, tuple()), + (16, (True, False,)), + (16, (True,) + (False,) * 15), + ), +) +def test_bitlist_round_trip_no_sedes(size, value): + Foo = Bitlist(size) + assert decode(encode(value, Foo), Foo) == value diff --git a/tests/tree_hash/test_composite_tree_hash.py b/tests/tree_hash/test_composite_tree_hash.py index cf889079..91070810 100644 --- a/tests/tree_hash/test_composite_tree_hash.py +++ b/tests/tree_hash/test_composite_tree_hash.py @@ -11,6 +11,7 @@ hash_eth2 as h, ) from ssz.sedes import ( + Bitlist, Bitvector, ByteVector, Container, @@ -18,6 +19,9 @@ Vector, uint128, ) +from ssz.utils import ( + pad_zeros, +) bytes16 = ByteVector(16) EMPTY_BYTES = b"\x00" * 16 @@ -169,3 +173,17 @@ def test_container(bytes16_fields, result): def test_bitvector(size, value, result): Foo = Bitvector(size) assert ssz.hash_tree_root(value, Foo) == result + + +@pytest.mark.parametrize( + ("size", "value", "result"), + ( + (8, (1, 1, 0, 1, 0, 1, 0, 0), h(pad_zeros(b"\x2b") + pad_zeros(b"\x08"))), + (512, tuple(1 for i in range(512)), h( + h(b"\xff" * 32 + b"\xff" * 32) + pad_zeros(b'\x00\x02') + )), + ) +) +def test_bitlist(size, value, result): + Foo = Bitlist(size) + assert ssz.hash_tree_root(value, Foo) == result