From c2855b38bec06b5bba31b5e3311a49799027d2f5 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Tue, 9 Jul 2024 15:06:55 +0200 Subject: [PATCH 1/9] Add support for `default_node` in `StableContainer` / `Profile` Implement `default_node` so that incomplete initialization becomes possible, similar to how it works for `Container` when not all fields are explicitly defined during construction. --- remerkleable/stable_container.py | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index 667dc27..902662d 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -173,6 +173,13 @@ def tree_depth(cls) -> int: def item_elem_cls(cls, i: int) -> Type[View]: return list(cls._field_indices.values())[i] + @classmethod + def default_node(cls) -> Node: + return PairNode( + left=subtree_fill_to_contents([], cls.tree_depth()), + right=Bitvector[cls.N].default_node(), + ) + def active_fields(self) -> Bitvector: active_fields_node = super().get_backing().get_right() return Bitvector[self.__class__.N].view_from_backing(active_fields_node) @@ -319,11 +326,11 @@ def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None return super().__new__(cls, backing=backing, hook=hook, **kwargs) extra_kw = kwargs.copy() - for fkey, (_, _, fopt) in cls._field_indices.items(): + for fkey, (_, ftyp, fopt) in cls._field_indices.items(): if fkey in extra_kw: extra_kw.pop(fkey) elif not fopt: - raise AttributeError(f'Field `{fkey}` is required in {cls.__name__}') + kwargs[fkey] = ftyp.view_from_backing(ftyp.default_node()) else: pass if len(extra_kw) > 0: @@ -548,6 +555,19 @@ def tree_depth(cls) -> int: def item_elem_cls(cls, i: int) -> Type[View]: return cls.B.item_elem_cls(i) + @classmethod + def default_node(cls) -> Node: + fnodes = [zero_node(0)] * cls.B.N + active_fields = Bitvector[cls.B.N]() + for (findex, ftyp, fopt) in cls._field_indices.values(): + if not fopt: + fnodes[findex] = ftyp.default_node() + active_fields.set(findex, True) + return PairNode( + left=subtree_fill_to_contents(fnodes, cls.tree_depth()), + right=active_fields.get_backing(), + ) + def active_fields(self) -> Bitvector: active_fields_node = super().get_backing().get_right() return Bitvector[self.__class__.B.N].view_from_backing(active_fields_node) From 72a203c32e2e07e49f4f3d7eca5b0f919e098cdf Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Tue, 16 Jul 2024 15:03:49 +0200 Subject: [PATCH 2/9] Properly configure changed hook to allow modifying returned keys as ref --- remerkleable/stable_container.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index 902662d..2c124f9 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -26,7 +26,7 @@ def stable_get(self, findex, ftyp, n): return None data = self.get_backing().get_left() fnode = data.getter(2**get_depth(n) + findex) - return ftyp.view_from_backing(fnode) + return ftyp.view_from_backing(fnode, lambda v: stable_set(self, findex, ftyp, n, v)) def stable_set(self, findex, ftyp, n, value): From 18c7b2ed28a7fb05825bd176b8e45a4341961ec3 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Wed, 17 Jul 2024 14:35:06 +0200 Subject: [PATCH 3/9] Add `coerce_view` implementations (same as regular `Container`) --- remerkleable/stable_container.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index 2c124f9..cbd113f 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -140,6 +140,10 @@ def __init_subclass__(cls, **kwargs): StableContainerView.__name__ = StableContainerView.type_repr() return StableContainerView + @classmethod + def coerce_view(cls: Type[SV], v: Any) -> SV: + return cls(**{fkey: getattr(v, fkey) for fkey in cls.fields().keys()}) + @classmethod def fields(cls) -> Dict[str, Type[View]]: return { fkey: ftyp for fkey, (_, ftyp) in cls._field_indices.items() } @@ -503,6 +507,10 @@ def __init_subclass__(cls, **kwargs): ProfileView.__name__ = ProfileView.type_repr() return ProfileView + @classmethod + def coerce_view(cls: Type[BV], v: Any) -> BV: + return cls(**{fkey: getattr(v, fkey) for fkey in cls.fields().keys()}) + @classmethod def fields(cls) -> Dict[str, Tuple[Type[View], bool]]: return { fkey: (ftyp, fopt) for fkey, (_, ftyp, fopt) in cls._field_indices.items() } From ec1ca68296cb3881294b1520e297c95ce9bffcb4 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Mon, 30 Sep 2024 16:11:08 +0200 Subject: [PATCH 4/9] Fix warnings when doing arithmetic with different limit types --- remerkleable/stable_container.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index cbd113f..1c6a01f 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -107,7 +107,7 @@ def __init_subclass__(cls, **kwargs): raise TypeError(f'Invalid capacity: `StableContainer[{n}]`') if n <= 0: raise TypeError(f'Unsupported capacity: `StableContainer[{n}]`') - cls.N = n + cls.N = int(n) def __class_getitem__(cls, n: int) -> Type['StableContainer']: class StableContainerMeta(ViewMeta): From 65fb72f432392f9be6f410d138a725d7d3faa9e2 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Tue, 1 Oct 2024 17:09:43 +0200 Subject: [PATCH 5/9] Reject deserialization if there is extra data left over --- remerkleable/stable_container.py | 6 +++ remerkleable/test_impl.py | 75 +++++++------------------------- 2 files changed, 21 insertions(+), 60 deletions(-) diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index 1c6a01f..5b75a2e 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -266,6 +266,9 @@ def deserialize(cls: Type[SV], stream: BinaryIO, scope: int) -> SV: f'{foffset}, next {next_offset}, implied size: {fsize}, ' f'size bounds: [{f_min_size}, {f_max_size}]') field_values[fkey] = ftyp.deserialize(stream, fsize) + else: + if scope != fixed_size: + raise Exception(f'Incorrect object size: {scope}, expected: {fixed_size}') return cls(**field_values) # type: ignore def serialize(self, stream: BinaryIO) -> int: @@ -677,6 +680,9 @@ def deserialize(cls: Type[BV], stream: BinaryIO, scope: int) -> BV: f'{foffset}, next {next_offset}, implied size: {fsize}, ' f'size bounds: [{f_min_size}, {f_max_size}]') field_values[fkey] = ftyp.deserialize(stream, fsize) + else: + if scope != fixed_size: + raise Exception(f'Incorrect object size: {scope}, expected: {fixed_size}') return cls(**field_values) # type: ignore def serialize(self, stream: BinaryIO) -> int: diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index 966bfb2..5c639db 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -541,23 +541,14 @@ class ShapePairRepr(Container): ) assert all(shape.hash_tree_root() == square_root for shape in shapes) assert all(square.hash_tree_root() == square_root for square in squares) - try: + with pytest.raises(Exception): circle = Circle(side=0x42, color=1) - assert False - except: - pass for shape in shapes: - try: + with pytest.raises(Exception): circle = Circle(backing=shape.get_backing()) - assert False - except: - pass for square in squares: - try: + with pytest.raises(Exception): circle = Circle(backing=square.get_backing()) - assert False - except: - pass for shape in shapes: shape.side = 0x1337 for square in squares: @@ -579,17 +570,11 @@ class ShapePairRepr(Container): assert all(shape.hash_tree_root() == square_root for shape in shapes) assert all(square.hash_tree_root() == square_root for square in squares) for square in squares: - try: + with pytest.raises(Exception): square.radius = 0x1337 - assert False - except: - pass for square in squares: - try: + with pytest.raises(Exception): square.side = None - assert False - except: - pass # Circle tests circle_bytes_stable = bytes.fromhex("06014200") @@ -617,23 +602,14 @@ class ShapePairRepr(Container): ) assert all(shape.hash_tree_root() == circle_root for shape in shapes) assert all(circle.hash_tree_root() == circle_root for circle in circles) - try: + with pytest.raises(Exception): square = Square(radius=0x42, color=1) - assert False - except: - pass for shape in shapes: - try: + with pytest.raises(Exception): square = Square(backing=shape.get_backing()) - assert False - except: - pass for circle in circles: - try: + with pytest.raises(Exception): square = Square(backing=circle.get_backing()) - assert False - except: - pass # SquarePair tests square_pair_bytes_stable = bytes.fromhex("080000000c0000000342000103690001") @@ -712,45 +688,24 @@ class ShapePairRepr(Container): shape_bytes = bytes.fromhex("0201") assert shape.encode_bytes() == shape_bytes assert Shape.decode_bytes(shape_bytes) == shape - try: + with pytest.raises(Exception): shape = Square.decode_bytes(shape_bytes) - assert False - except: - pass - try: + with pytest.raises(Exception): shape = Circle.decode_bytes(shape_bytes) - assert False - except: - pass shape = Shape(side=0x42, color=1, radius=0x42) shape_bytes = bytes.fromhex("074200014200") assert shape.encode_bytes() == shape_bytes assert Shape.decode_bytes(shape_bytes) == shape - try: + with pytest.raises(Exception): shape = Square.decode_bytes(shape_bytes) - assert False - except: - pass - try: + with pytest.raises(Exception): shape = Circle.decode_bytes(shape_bytes) - assert False - except: - pass - try: + with pytest.raises(Exception): shape = Shape.decode_bytes("00") - assert False - except: - pass - try: + with pytest.raises(Exception): square = Square(radius=0x42, color=1) - assert False - except: - pass - try: + with pytest.raises(Exception): circle = Circle(side=0x42, color=1) - assert False - except: - pass # Surrounding container tests class ShapeContainer(Container): From 74fadc009e4552c24f396c9158744e6ba13656af Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Tue, 1 Oct 2024 17:12:02 +0200 Subject: [PATCH 6/9] Add checks when coercing between containers that have different profiles --- remerkleable/complex.py | 15 +++++++++++---- remerkleable/stable_container.py | 15 +++++++++++++-- remerkleable/test_impl.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 6 deletions(-) diff --git a/remerkleable/complex.py b/remerkleable/complex.py index 5971925..a70f367 100644 --- a/remerkleable/complex.py +++ b/remerkleable/complex.py @@ -5,7 +5,8 @@ from collections.abc import Sequence as ColSequence from itertools import chain import io -from remerkleable.core import View, BasicView, OFFSET_BYTE_LENGTH, ViewHook, ObjType, ObjParseException +from remerkleable.core import View, BackedView, BasicView, OFFSET_BYTE_LENGTH,\ + ViewHook, ObjType, ObjParseException from remerkleable.basic import uint256, uint8, uint32 from remerkleable.tree import Node, subtree_fill_to_length, subtree_fill_to_contents,\ zero_node, Gindex, PairNode, to_gindex, NavigationError, get_depth, RIGHT_GINDEX @@ -283,7 +284,9 @@ def __new__(cls, *args, backing: Optional[Node] = None, hook: Optional[ViewHook] raise Exception(f"too many list inputs: {len(vals)}, limit is: {limit}") input_views = [] for el in vals: - if isinstance(el, View): + if isinstance(el, BackedView): + input_views.append(elem_cls(backing=el.get_backing())) + elif isinstance(el, View): input_views.append(el) else: input_views.append(elem_cls.coerce_view(el)) @@ -527,7 +530,9 @@ def __new__(cls, *args, backing: Optional[Node] = None, hook: Optional[ViewHook] raise Exception(f"invalid inputs length: {len(vals)}, vector length is: {vector_length}") input_views = [] for el in vals: - if isinstance(el, View): + if isinstance(el, BackedView): + input_views.append(elem_cls(backing=el.get_backing())) + elif isinstance(el, View): input_views.append(el) else: input_views.append(elem_cls.coerce_view(el)) @@ -731,7 +736,9 @@ def __new__(cls, *args, backing: Optional[Node] = None, hook: Optional[ViewHook] fnode: Node if fkey in kwargs: finput = kwargs.pop(fkey) - if isinstance(finput, View): + if isinstance(finput, BackedView): + fnode = ftyp(backing=finput.get_backing()).get_backing() + elif isinstance(finput, View): fnode = finput.get_backing() else: fnode = ftyp.coerce_view(finput).get_backing() diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index 5b75a2e..442b6e9 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -11,7 +11,7 @@ from remerkleable.byte_arrays import ByteList, ByteVector from remerkleable.complex import ComplexView, Container, FieldOffset, List, Vector, \ decode_offset, encode_offset -from remerkleable.core import View, ViewHook, ViewMeta, OFFSET_BYTE_LENGTH +from remerkleable.core import BackedView, View, ViewHook, ViewMeta, OFFSET_BYTE_LENGTH from remerkleable.tree import Gindex, NavigationError, Node, PairNode, \ get_depth, subtree_fill_to_contents, zero_node, \ RIGHT_GINDEX @@ -84,7 +84,9 @@ def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None fnode = zero_node(0) active_fields.set(findex, False) else: - if isinstance(finput, View): + if isinstance(finput, BackedView): + fnode = ftyp(backing=finput.get_backing()).get_backing() + elif isinstance(finput, View): fnode = finput.get_backing() else: fnode = ftyp.coerce_view(finput).get_backing() @@ -330,6 +332,15 @@ def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None if backing is not None: if len(kwargs) != 0: raise Exception('Cannot have both a backing and elements to init fields') + active_fields = Bitvector[cls.B.N].view_from_backing(backing.get_right()) + for fkey, (findex, _) in cls.B._field_indices.items(): + if fkey not in cls._field_indices: + if active_fields.get(findex): + raise ValueError(f'Cannot convert to `{cls.__name__}`: {fkey} unsupported') + else: + (_, _, fopt) = cls._field_indices[fkey] + if not fopt and not active_fields.get(findex): + raise ValueError(f'Cannot convert to `{cls.__name__}`: {fkey} is required') return super().__new__(cls, backing=backing, hook=hook, **kwargs) extra_kw = kwargs.copy() diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index 5c639db..2737ab0 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -706,6 +706,10 @@ class ShapePairRepr(Container): square = Square(radius=0x42, color=1) with pytest.raises(Exception): circle = Circle(side=0x42, color=1) + with pytest.raises(Exception): + square = Square.coerce_view(Shape(radius=0x42, color=1)) + with pytest.raises(Exception): + square = Square(backing=Circle(radius=0x42, color=1).get_backing()) # Surrounding container tests class ShapeContainer(Container): @@ -741,6 +745,34 @@ class ShapeContainerRepr(Container): ), ).hash_tree_root() + # Unsupported surrounding container tests + with pytest.raises(Exception): + shapes = List[Square, 5].coerce_view( + List[Circle, 5](Circle(radius=0x42, color=1))) + with pytest.raises(Exception): + shapes = Vector[Square, 1].coerce_view( + Vector[Circle, 1](Circle(radius=0x42, color=1))) + + class SquareContainer(Container): + shape: Square + + class CircleContainer(Container): + shape: Circle + + with pytest.raises(Exception): + shape = SquareContainer.coerce_view( + CircleContainer(shape=Circle(radius=0x42, color=1))) + + class SquareStableContainer(StableContainer[1]): + shape: Optional[Square] + + class CircleStableContainer(StableContainer[1]): + shape: Optional[Circle] + + with pytest.raises(Exception): + shape = SquareStableContainer.coerce_view( + CircleStableContainer(shape=Circle(radius=0x42, color=1))) + # basic container class Shape1(StableContainer[4]): side: Optional[uint16] From 5cad2b12002f017cb82fc267ffb626014761b75b Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Wed, 2 Oct 2024 11:32:28 +0200 Subject: [PATCH 7/9] Extend nested coercing tests --- remerkleable/complex.py | 22 ++++++++++------- remerkleable/core.py | 6 +++++ remerkleable/stable_container.py | 42 +++++++++++++++++++++++--------- remerkleable/test_impl.py | 29 ++++++++++++++-------- 4 files changed, 68 insertions(+), 31 deletions(-) diff --git a/remerkleable/complex.py b/remerkleable/complex.py index a70f367..0483095 100644 --- a/remerkleable/complex.py +++ b/remerkleable/complex.py @@ -108,6 +108,11 @@ def readonly_iter(self): else: return ComplexElemIter(backing, tree_depth, length, elem_type) + def check_backing(self): + for el in self: + if isinstance(el, BackedView): + el.check_backing() + @classmethod def deserialize(cls: Type[M], stream: BinaryIO, scope: int) -> M: elem_cls = cls.element_cls() @@ -284,9 +289,7 @@ def __new__(cls, *args, backing: Optional[Node] = None, hook: Optional[ViewHook] raise Exception(f"too many list inputs: {len(vals)}, limit is: {limit}") input_views = [] for el in vals: - if isinstance(el, BackedView): - input_views.append(elem_cls(backing=el.get_backing())) - elif isinstance(el, View): + if isinstance(el, View): input_views.append(el) else: input_views.append(elem_cls.coerce_view(el)) @@ -530,9 +533,7 @@ def __new__(cls, *args, backing: Optional[Node] = None, hook: Optional[ViewHook] raise Exception(f"invalid inputs length: {len(vals)}, vector length is: {vector_length}") input_views = [] for el in vals: - if isinstance(el, BackedView): - input_views.append(elem_cls(backing=el.get_backing())) - elif isinstance(el, View): + if isinstance(el, View): input_views.append(el) else: input_views.append(elem_cls.coerce_view(el)) @@ -719,6 +720,11 @@ def __new__(cls, *args, backing: Optional[Node] = None, hook: Optional[ViewHook] def fields(cls) -> Fields: # base condition for the subclasses deriving the fields return {} + def check_backing(self): + for el in self: + if isinstance(el, BackedView): + el.check_backing() + class Container(_ContainerBase): _field_indices: Dict[str, int] @@ -736,9 +742,7 @@ def __new__(cls, *args, backing: Optional[Node] = None, hook: Optional[ViewHook] fnode: Node if fkey in kwargs: finput = kwargs.pop(fkey) - if isinstance(finput, BackedView): - fnode = ftyp(backing=finput.get_backing()).get_backing() - elif isinstance(finput, View): + if isinstance(finput, View): fnode = finput.get_backing() else: fnode = ftyp.coerce_view(finput).get_backing() diff --git a/remerkleable/core.py b/remerkleable/core.py index cac08d9..6dac545 100644 --- a/remerkleable/core.py +++ b/remerkleable/core.py @@ -238,6 +238,9 @@ def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None out._hook = hook return out + def __init__(self, *args, **kwargs): + self.check_backing() + def get_backing(self) -> Node: return self._backing @@ -247,6 +250,9 @@ def set_backing(self, value): if self._hook is not None: self._hook(self) + def check_backing(self): + pass + BV = TypeVar('BV', bound="BasicView") diff --git a/remerkleable/stable_container.py b/remerkleable/stable_container.py index 442b6e9..d258826 100644 --- a/remerkleable/stable_container.py +++ b/remerkleable/stable_container.py @@ -84,9 +84,7 @@ def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None fnode = zero_node(0) active_fields.set(findex, False) else: - if isinstance(finput, BackedView): - fnode = ftyp(backing=finput.get_backing()).get_backing() - elif isinstance(finput, View): + if isinstance(finput, View): fnode = finput.get_backing() else: fnode = ftyp.coerce_view(finput).get_backing() @@ -190,6 +188,17 @@ def active_fields(self) -> Bitvector: active_fields_node = super().get_backing().get_right() return Bitvector[self.__class__.N].view_from_backing(active_fields_node) + def check_backing(self): + active_fields = self.active_fields() + for fkey, (findex, _) in self.__class__._field_indices.items(): + if active_fields.get(findex): + value = getattr(self, fkey) + if isinstance(value, BackedView): + value.check_backing() + for findex in range(len(self.__class__._field_indices), self.__class__.N): + if active_fields.get(findex): + raise ValueError(f'`{self.__class__.__name__}` invalid: Unknown field {findex}') + def __getattribute__(self, item): if item == 'N': raise AttributeError(f'Use `.__class__.{item}` to access `{item}`') @@ -332,15 +341,6 @@ def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None if backing is not None: if len(kwargs) != 0: raise Exception('Cannot have both a backing and elements to init fields') - active_fields = Bitvector[cls.B.N].view_from_backing(backing.get_right()) - for fkey, (findex, _) in cls.B._field_indices.items(): - if fkey not in cls._field_indices: - if active_fields.get(findex): - raise ValueError(f'Cannot convert to `{cls.__name__}`: {fkey} unsupported') - else: - (_, _, fopt) = cls._field_indices[fkey] - if not fopt and not active_fields.get(findex): - raise ValueError(f'Cannot convert to `{cls.__name__}`: {fkey} is required') return super().__new__(cls, backing=backing, hook=hook, **kwargs) extra_kw = kwargs.copy() @@ -606,6 +606,24 @@ def optional_fields(self) -> Bitvector: oindex += 1 return optional_fields + def check_backing(self): + active_fields = self.active_fields() + for fkey, (findex, _) in self.__class__.B._field_indices.items(): + if fkey not in self.__class__._field_indices: + if active_fields.get(findex): + raise ValueError(f'`{self.__class__.__name__}` invalid: {fkey} unsupported') + elif active_fields.get(findex): + value = getattr(self, fkey) + if isinstance(value, BackedView): + value.check_backing() + else: + (_, _, fopt) = self.__class__._field_indices[fkey] + if not fopt: + raise ValueError(f'`{self.__class__.__name__}` invalid: {fkey} is required') + for findex in range(len(self.__class__.B._field_indices), self.__class__.B.N): + if active_fields.get(findex): + raise ValueError(f'`{self.__class__.__name__}` invalid: Unknown field {findex}') + def __getattribute__(self, item): if item == 'B': raise AttributeError(f'Use `.__class__.{item}` to access `{item}`') diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index 2737ab0..0066f07 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -706,8 +706,6 @@ class ShapePairRepr(Container): square = Square(radius=0x42, color=1) with pytest.raises(Exception): circle = Circle(side=0x42, color=1) - with pytest.raises(Exception): - square = Square.coerce_view(Shape(radius=0x42, color=1)) with pytest.raises(Exception): square = Square(backing=Circle(radius=0x42, color=1).get_backing()) @@ -747,11 +745,11 @@ class ShapeContainerRepr(Container): # Unsupported surrounding container tests with pytest.raises(Exception): - shapes = List[Square, 5].coerce_view( - List[Circle, 5](Circle(radius=0x42, color=1))) + shapes = List[Square, 5]( + backing=List[Circle, 5](Circle(radius=0x42, color=1)).get_backing()) with pytest.raises(Exception): - shapes = Vector[Square, 1].coerce_view( - Vector[Circle, 1](Circle(radius=0x42, color=1))) + shapes = Vector[Square, 1]( + backing=Vector[Circle, 1](Circle(radius=0x42, color=1)).get_backing()) class SquareContainer(Container): shape: Square @@ -760,8 +758,8 @@ class CircleContainer(Container): shape: Circle with pytest.raises(Exception): - shape = SquareContainer.coerce_view( - CircleContainer(shape=Circle(radius=0x42, color=1))) + shape = SquareContainer( + backing=CircleContainer(shape=Circle(radius=0x42, color=1)).get_backing()) class SquareStableContainer(StableContainer[1]): shape: Optional[Square] @@ -770,8 +768,19 @@ class CircleStableContainer(StableContainer[1]): shape: Optional[Circle] with pytest.raises(Exception): - shape = SquareStableContainer.coerce_view( - CircleStableContainer(shape=Circle(radius=0x42, color=1))) + shape = SquareStableContainer( + backing=CircleStableContainer(shape=Circle(radius=0x42, color=1)).get_backing()) + + class NestedSquareContainer(Container): + item: SquareContainer + + class NestedCircleContainer(Container): + item: CircleContainer + + with pytest.raises(Exception): + shape = NestedSquareContainer( + backing=NestedCircleContainer( + item=CircleContainer(shape=Circle(radius=0x42, color=1))).get_backing()) # basic container class Shape1(StableContainer[4]): From f55fc28d67d54baeccb5429f22028f91106d7932 Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Thu, 10 Oct 2024 17:09:53 +0200 Subject: [PATCH 8/9] Add `from_base` / `to_base` explicit conversion helpers --- remerkleable/core.py | 12 +++-- remerkleable/test_impl.py | 98 ++++++++++++++++++++++++--------------- 2 files changed, 69 insertions(+), 41 deletions(-) diff --git a/remerkleable/core.py b/remerkleable/core.py index 6dac545..91e89ed 100644 --- a/remerkleable/core.py +++ b/remerkleable/core.py @@ -238,9 +238,6 @@ def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None out._hook = hook return out - def __init__(self, *args, **kwargs): - self.check_backing() - def get_backing(self) -> Node: return self._backing @@ -253,6 +250,15 @@ def set_backing(self, value): def check_backing(self): pass + @classmethod + def from_base(cls: Type[BackedV], value) -> BackedV: + res = cls(backing=value.get_backing()) + res.check_backing() + return res + + def to_base(self, cls: Type[BackedV]) -> BackedV: + return cls(backing=self.get_backing()) + BV = TypeVar('BV', bound="BasicView") diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index 0066f07..97bfaac 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -527,16 +527,18 @@ class ShapePairRepr(Container): ).hash_tree_root() shapes = [Shape(side=0x42, color=1, radius=None)] squares = [Square(side=0x42, color=1)] - squares.extend(list(Square(backing=shape.get_backing()) for shape in shapes)) - shapes.extend(list(Shape(backing=shape.get_backing()) for shape in shapes)) - shapes.extend(list(Shape(backing=square.get_backing()) for square in squares)) - squares.extend(list(Square(backing=square.get_backing()) for square in squares)) + squares.extend(list(Square.from_base(shape) for shape in shapes)) + shapes.extend(list(Shape( + side=shape.side, radius=shape.radius, color=shape.color + ) for shape in shapes)) + shapes.extend(list(square.to_base(Shape) for square in squares)) + squares.extend(list(Square(side=square.side, color=square.color) for square in squares)) assert len(set(shapes)) == 1 assert len(set(squares)) == 1 assert all(shape.encode_bytes() == square_bytes_stable for shape in shapes) assert all(square.encode_bytes() == square_bytes_profile for square in squares) assert ( - Square(backing=Shape.decode_bytes(square_bytes_stable).get_backing()) == + Square.from_base(Shape.decode_bytes(square_bytes_stable)) == Square.decode_bytes(square_bytes_profile) ) assert all(shape.hash_tree_root() == square_root for shape in shapes) @@ -545,10 +547,10 @@ class ShapePairRepr(Container): circle = Circle(side=0x42, color=1) for shape in shapes: with pytest.raises(Exception): - circle = Circle(backing=shape.get_backing()) + circle = Circle.from_base(shape) for square in squares: with pytest.raises(Exception): - circle = Circle(backing=square.get_backing()) + circle = Circle.from_base(square.to_base(Shape)) for shape in shapes: shape.side = 0x1337 for square in squares: @@ -564,7 +566,7 @@ class ShapePairRepr(Container): assert all(shape.encode_bytes() == square_bytes_stable for shape in shapes) assert all(square.encode_bytes() == square_bytes_profile for square in squares) assert ( - Square(backing=Shape.decode_bytes(square_bytes_stable).get_backing()) == + Square.from_base(Shape.decode_bytes(square_bytes_stable)) == Square.decode_bytes(square_bytes_profile) ) assert all(shape.hash_tree_root() == square_root for shape in shapes) @@ -588,16 +590,18 @@ class ShapePairRepr(Container): modified_shape.radius = 0x42 shapes = [Shape(side=None, color=1, radius=0x42), modified_shape] circles = [Circle(radius=0x42, color=1)] - circles.extend(list(Circle(backing=shape.get_backing()) for shape in shapes)) - shapes.extend(list(Shape(backing=shape.get_backing()) for shape in shapes)) - shapes.extend(list(Shape(backing=circle.get_backing()) for circle in circles)) - circles.extend(list(Circle(backing=circle.get_backing()) for circle in circles)) + circles.extend(list(Circle.from_base(shape) for shape in shapes)) + shapes.extend(list(Shape( + side=shape.side, radius=shape.radius, color=shape.color + ) for shape in shapes)) + shapes.extend(list(circle.to_base(Shape) for circle in circles)) + circles.extend(list(Circle(radius=circle.radius, color=circle.color) for circle in circles)) assert len(set(shapes)) == 1 assert len(set(circles)) == 1 assert all(shape.encode_bytes() == circle_bytes_stable for shape in shapes) assert all(circle.encode_bytes() == circle_bytes_profile for circle in circles) assert ( - Circle(backing=Shape.decode_bytes(circle_bytes_stable).get_backing()) == + Circle.from_base(Shape.decode_bytes(circle_bytes_stable)) == Circle.decode_bytes(circle_bytes_profile) ) assert all(shape.hash_tree_root() == circle_root for shape in shapes) @@ -606,10 +610,10 @@ class ShapePairRepr(Container): square = Square(radius=0x42, color=1) for shape in shapes: with pytest.raises(Exception): - square = Square(backing=shape.get_backing()) + square = Square.from_base(shape) for circle in circles: with pytest.raises(Exception): - square = Square(backing=circle.get_backing()) + square = Square.from_base(circle.to_base(Shape)) # SquarePair tests square_pair_bytes_stable = bytes.fromhex("080000000c0000000342000103690001") @@ -632,16 +636,18 @@ class ShapePairRepr(Container): shape_1=Square(side=0x42, color=1), shape_2=Square(side=0x69, color=1), )] - square_pairs.extend(list(SquarePair(backing=pair.get_backing()) for pair in shape_pairs)) - shape_pairs.extend(list(ShapePair(backing=pair.get_backing()) for pair in shape_pairs)) - shape_pairs.extend(list(ShapePair(backing=pair.get_backing()) for pair in square_pairs)) - square_pairs.extend(list(SquarePair(backing=pair.get_backing()) for pair in square_pairs)) + square_pairs.extend(list(SquarePair.from_base(pair) for pair in shape_pairs)) + shape_pairs.extend(list(ShapePair( + shape_1=pair.shape_1, shape_2=pair.shape_2) for pair in shape_pairs)) + shape_pairs.extend(list(pair.to_base(ShapePair) for pair in square_pairs)) + square_pairs.extend(list(SquarePair( + shape_1=pair.shape_1, shape_2=pair.shape_2) for pair in square_pairs)) assert len(set(shape_pairs)) == 1 assert len(set(square_pairs)) == 1 assert all(pair.encode_bytes() == square_pair_bytes_stable for pair in shape_pairs) assert all(pair.encode_bytes() == square_pair_bytes_profile for pair in square_pairs) assert ( - SquarePair(backing=ShapePair.decode_bytes(square_pair_bytes_stable).get_backing()) == + SquarePair.from_base(ShapePair.decode_bytes(square_pair_bytes_stable)) == SquarePair.decode_bytes(square_pair_bytes_profile) ) assert all(pair.hash_tree_root() == square_pair_root for pair in shape_pairs) @@ -668,16 +674,18 @@ class ShapePairRepr(Container): shape_1=Circle(radius=0x42, color=1), shape_2=Circle(radius=0x69, color=1), )] - circle_pairs.extend(list(CirclePair(backing=pair.get_backing()) for pair in shape_pairs)) - shape_pairs.extend(list(ShapePair(backing=pair.get_backing()) for pair in shape_pairs)) - shape_pairs.extend(list(ShapePair(backing=pair.get_backing()) for pair in circle_pairs)) - circle_pairs.extend(list(CirclePair(backing=pair.get_backing()) for pair in circle_pairs)) + circle_pairs.extend(list(CirclePair.from_base(pair) for pair in shape_pairs)) + shape_pairs.extend(list(ShapePair( + shape_1=pair.shape_1, shape_2=pair.shape_2) for pair in shape_pairs)) + shape_pairs.extend(list(pair.to_base(ShapePair) for pair in circle_pairs)) + circle_pairs.extend(list(CirclePair( + shape_1=pair.shape_1, shape_2=pair.shape_2) for pair in circle_pairs)) assert len(set(shape_pairs)) == 1 assert len(set(circle_pairs)) == 1 assert all(pair.encode_bytes() == circle_pair_bytes_stable for pair in shape_pairs) assert all(pair.encode_bytes() == circle_pair_bytes_profile for pair in circle_pairs) assert ( - CirclePair(backing=ShapePair.decode_bytes(circle_pair_bytes_stable).get_backing()) == + CirclePair.from_base(ShapePair.decode_bytes(circle_pair_bytes_stable)) == CirclePair.decode_bytes(circle_pair_bytes_profile) ) assert all(pair.hash_tree_root() == circle_pair_root for pair in shape_pairs) @@ -707,7 +715,7 @@ class ShapePairRepr(Container): with pytest.raises(Exception): circle = Circle(side=0x42, color=1) with pytest.raises(Exception): - square = Square(backing=Circle(radius=0x42, color=1).get_backing()) + square = Square.from_base(Circle(radius=0x42, color=1).to_base(Shape)) # Surrounding container tests class ShapeContainer(Container): @@ -743,13 +751,19 @@ class ShapeContainerRepr(Container): ), ).hash_tree_root() - # Unsupported surrounding container tests + # Nested surrounding container tests + shapes = List[Circle, 5](Circle(radius=0x42, color=1)) + assert List[Circle, 5].from_base(shapes.to_base(List[Shape, 5])) == shapes with pytest.raises(Exception): - shapes = List[Square, 5]( - backing=List[Circle, 5](Circle(radius=0x42, color=1)).get_backing()) + shapes = List[Square, 5].from_base(shapes.to_base(List[Shape, 5])) + + shapes = Vector[Circle, 1](Circle(radius=0x42, color=1)) + assert Vector[Circle, 1].from_base(shapes.to_base(Vector[Shape, 1])) == shapes with pytest.raises(Exception): - shapes = Vector[Square, 1]( - backing=Vector[Circle, 1](Circle(radius=0x42, color=1)).get_backing()) + shapes = Vector[Square, 1].from_base(shapes.to_base(Vector[Shape, 1])) + + class ShapeContainer(Container): + shape: Shape class SquareContainer(Container): shape: Square @@ -757,9 +771,13 @@ class SquareContainer(Container): class CircleContainer(Container): shape: Circle + shape = CircleContainer(shape=Circle(radius=0x42, color=1)) + assert CircleContainer.from_base(shape.to_base(ShapeContainer)) == shape with pytest.raises(Exception): - shape = SquareContainer( - backing=CircleContainer(shape=Circle(radius=0x42, color=1)).get_backing()) + shape = SquareContainer.from_base(shape.to_base(ShapeContainer)) + + class ShapeStableContainer(StableContainer[1]): + shape: Optional[Shape] class SquareStableContainer(StableContainer[1]): shape: Optional[Square] @@ -767,9 +785,13 @@ class SquareStableContainer(StableContainer[1]): class CircleStableContainer(StableContainer[1]): shape: Optional[Circle] + shape = CircleStableContainer(shape=Circle(radius=0x42, color=1)) + assert CircleStableContainer.from_base(shape.to_base(ShapeStableContainer)) == shape with pytest.raises(Exception): - shape = SquareStableContainer( - backing=CircleStableContainer(shape=Circle(radius=0x42, color=1)).get_backing()) + shape = SquareStableContainer.from_base(shape.to_base(ShapeStableContainer)) + + class NestedShapeContainer(Container): + item: ShapeContainer class NestedSquareContainer(Container): item: SquareContainer @@ -777,10 +799,10 @@ class NestedSquareContainer(Container): class NestedCircleContainer(Container): item: CircleContainer + shape = NestedCircleContainer(item=CircleContainer(shape=Circle(radius=0x42, color=1))) + assert NestedCircleContainer.from_base(shape.to_base(NestedShapeContainer)) == shape with pytest.raises(Exception): - shape = NestedSquareContainer( - backing=NestedCircleContainer( - item=CircleContainer(shape=Circle(radius=0x42, color=1))).get_backing()) + shape = NestedSquareContainer.from_base(shape.to_base(NestedShapeContainer)) == shape # basic container class Shape1(StableContainer[4]): From 00eda8d69cb469d1589cb4793231d4eae1017c7a Mon Sep 17 00:00:00 2001 From: Etan Kissling Date: Thu, 10 Oct 2024 20:42:27 +0200 Subject: [PATCH 9/9] Cleanup negative test --- remerkleable/test_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/remerkleable/test_impl.py b/remerkleable/test_impl.py index 97bfaac..a3e4553 100644 --- a/remerkleable/test_impl.py +++ b/remerkleable/test_impl.py @@ -802,7 +802,7 @@ class NestedCircleContainer(Container): shape = NestedCircleContainer(item=CircleContainer(shape=Circle(radius=0x42, color=1))) assert NestedCircleContainer.from_base(shape.to_base(NestedShapeContainer)) == shape with pytest.raises(Exception): - shape = NestedSquareContainer.from_base(shape.to_base(NestedShapeContainer)) == shape + shape = NestedSquareContainer.from_base(shape.to_base(NestedShapeContainer)) # basic container class Shape1(StableContainer[4]):