Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for default_node in StableContainer / Profile #25

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
13 changes: 12 additions & 1 deletion remerkleable/complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from collections.abc import Sequence as ColSequence
from itertools import chain
import io
from remerkleable.core import View, BasicView, OFFSET_BYTE_LENGTH, ViewHook, ObjType, ObjParseException
from remerkleable.core import View, BackedView, BasicView, OFFSET_BYTE_LENGTH,\
ViewHook, ObjType, ObjParseException
from remerkleable.basic import uint256, uint8, uint32
from remerkleable.tree import Node, subtree_fill_to_length, subtree_fill_to_contents,\
zero_node, Gindex, PairNode, to_gindex, NavigationError, get_depth, RIGHT_GINDEX
Expand Down Expand Up @@ -107,6 +108,11 @@ def readonly_iter(self):
else:
return ComplexElemIter(backing, tree_depth, length, elem_type)

def check_backing(self):
for el in self:
if isinstance(el, BackedView):
el.check_backing()

@classmethod
def deserialize(cls: Type[M], stream: BinaryIO, scope: int) -> M:
elem_cls = cls.element_cls()
Expand Down Expand Up @@ -714,6 +720,11 @@ def __new__(cls, *args, backing: Optional[Node] = None, hook: Optional[ViewHook]
def fields(cls) -> Fields: # base condition for the subclasses deriving the fields
return {}

def check_backing(self):
for el in self:
if isinstance(el, BackedView):
el.check_backing()


class Container(_ContainerBase):
_field_indices: Dict[str, int]
Expand Down
12 changes: 12 additions & 0 deletions remerkleable/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,18 @@ def set_backing(self, value):
if self._hook is not None:
self._hook(self)

def check_backing(self):
pass

@classmethod
def from_base(cls: Type[BackedV], value) -> BackedV:
res = cls(backing=value.get_backing())
res.check_backing()
return res

def to_base(self, cls: Type[BackedV]) -> BackedV:
return cls(backing=self.get_backing())


BV = TypeVar('BV', bound="BasicView")

Expand Down
73 changes: 68 additions & 5 deletions remerkleable/stable_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from remerkleable.byte_arrays import ByteList, ByteVector
from remerkleable.complex import ComplexView, Container, FieldOffset, List, Vector, \
decode_offset, encode_offset
from remerkleable.core import View, ViewHook, ViewMeta, OFFSET_BYTE_LENGTH
from remerkleable.core import BackedView, View, ViewHook, ViewMeta, OFFSET_BYTE_LENGTH
from remerkleable.tree import Gindex, NavigationError, Node, PairNode, \
get_depth, subtree_fill_to_contents, zero_node, \
RIGHT_GINDEX
Expand All @@ -26,7 +26,7 @@ def stable_get(self, findex, ftyp, n):
return None
data = self.get_backing().get_left()
fnode = data.getter(2**get_depth(n) + findex)
return ftyp.view_from_backing(fnode)
return ftyp.view_from_backing(fnode, lambda v: stable_set(self, findex, ftyp, n, v))


def stable_set(self, findex, ftyp, n, value):
Expand Down Expand Up @@ -107,7 +107,7 @@ def __init_subclass__(cls, **kwargs):
raise TypeError(f'Invalid capacity: `StableContainer[{n}]`')
if n <= 0:
raise TypeError(f'Unsupported capacity: `StableContainer[{n}]`')
cls.N = n
cls.N = int(n)

def __class_getitem__(cls, n: int) -> Type['StableContainer']:
class StableContainerMeta(ViewMeta):
Expand Down Expand Up @@ -140,6 +140,10 @@ def __init_subclass__(cls, **kwargs):
StableContainerView.__name__ = StableContainerView.type_repr()
return StableContainerView

@classmethod
def coerce_view(cls: Type[SV], v: Any) -> SV:
return cls(**{fkey: getattr(v, fkey) for fkey in cls.fields().keys()})

@classmethod
def fields(cls) -> Dict[str, Type[View]]:
return { fkey: ftyp for fkey, (_, ftyp) in cls._field_indices.items() }
Expand Down Expand Up @@ -173,10 +177,28 @@ def tree_depth(cls) -> int:
def item_elem_cls(cls, i: int) -> Type[View]:
return list(cls._field_indices.values())[i]

@classmethod
def default_node(cls) -> Node:
return PairNode(
left=subtree_fill_to_contents([], cls.tree_depth()),
right=Bitvector[cls.N].default_node(),
)

def active_fields(self) -> Bitvector:
active_fields_node = super().get_backing().get_right()
return Bitvector[self.__class__.N].view_from_backing(active_fields_node)

def check_backing(self):
active_fields = self.active_fields()
for fkey, (findex, _) in self.__class__._field_indices.items():
if active_fields.get(findex):
value = getattr(self, fkey)
if isinstance(value, BackedView):
value.check_backing()
for findex in range(len(self.__class__._field_indices), self.__class__.N):
if active_fields.get(findex):
raise ValueError(f'`{self.__class__.__name__}` invalid: Unknown field {findex}')

def __getattribute__(self, item):
if item == 'N':
raise AttributeError(f'Use `.__class__.{item}` to access `{item}`')
Expand Down Expand Up @@ -255,6 +277,9 @@ def deserialize(cls: Type[SV], stream: BinaryIO, scope: int) -> SV:
f'{foffset}, next {next_offset}, implied size: {fsize}, '
f'size bounds: [{f_min_size}, {f_max_size}]')
field_values[fkey] = ftyp.deserialize(stream, fsize)
else:
if scope != fixed_size:
raise Exception(f'Incorrect object size: {scope}, expected: {fixed_size}')
return cls(**field_values) # type: ignore

def serialize(self, stream: BinaryIO) -> int:
Expand Down Expand Up @@ -319,11 +344,11 @@ def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None
return super().__new__(cls, backing=backing, hook=hook, **kwargs)

extra_kw = kwargs.copy()
for fkey, (_, _, fopt) in cls._field_indices.items():
for fkey, (_, ftyp, fopt) in cls._field_indices.items():
if fkey in extra_kw:
extra_kw.pop(fkey)
elif not fopt:
raise AttributeError(f'Field `{fkey}` is required in {cls.__name__}')
kwargs[fkey] = ftyp.view_from_backing(ftyp.default_node())
else:
pass
if len(extra_kw) > 0:
Expand Down Expand Up @@ -496,6 +521,10 @@ def __init_subclass__(cls, **kwargs):
ProfileView.__name__ = ProfileView.type_repr()
return ProfileView

@classmethod
def coerce_view(cls: Type[BV], v: Any) -> BV:
return cls(**{fkey: getattr(v, fkey) for fkey in cls.fields().keys()})

@classmethod
def fields(cls) -> Dict[str, Tuple[Type[View], bool]]:
return { fkey: (ftyp, fopt) for fkey, (_, ftyp, fopt) in cls._field_indices.items() }
Expand Down Expand Up @@ -548,6 +577,19 @@ def tree_depth(cls) -> int:
def item_elem_cls(cls, i: int) -> Type[View]:
return cls.B.item_elem_cls(i)

@classmethod
def default_node(cls) -> Node:
fnodes = [zero_node(0)] * cls.B.N
active_fields = Bitvector[cls.B.N]()
for (findex, ftyp, fopt) in cls._field_indices.values():
if not fopt:
fnodes[findex] = ftyp.default_node()
active_fields.set(findex, True)
return PairNode(
left=subtree_fill_to_contents(fnodes, cls.tree_depth()),
right=active_fields.get_backing(),
)

def active_fields(self) -> Bitvector:
active_fields_node = super().get_backing().get_right()
return Bitvector[self.__class__.B.N].view_from_backing(active_fields_node)
Expand All @@ -564,6 +606,24 @@ def optional_fields(self) -> Bitvector:
oindex += 1
return optional_fields

def check_backing(self):
active_fields = self.active_fields()
for fkey, (findex, _) in self.__class__.B._field_indices.items():
if fkey not in self.__class__._field_indices:
if active_fields.get(findex):
raise ValueError(f'`{self.__class__.__name__}` invalid: {fkey} unsupported')
elif active_fields.get(findex):
value = getattr(self, fkey)
if isinstance(value, BackedView):
value.check_backing()
else:
(_, _, fopt) = self.__class__._field_indices[fkey]
if not fopt:
raise ValueError(f'`{self.__class__.__name__}` invalid: {fkey} is required')
for findex in range(len(self.__class__.B._field_indices), self.__class__.B.N):
if active_fields.get(findex):
raise ValueError(f'`{self.__class__.__name__}` invalid: Unknown field {findex}')

def __getattribute__(self, item):
if item == 'B':
raise AttributeError(f'Use `.__class__.{item}` to access `{item}`')
Expand Down Expand Up @@ -649,6 +709,9 @@ def deserialize(cls: Type[BV], stream: BinaryIO, scope: int) -> BV:
f'{foffset}, next {next_offset}, implied size: {fsize}, '
f'size bounds: [{f_min_size}, {f_max_size}]')
field_values[fkey] = ftyp.deserialize(stream, fsize)
else:
if scope != fixed_size:
raise Exception(f'Incorrect object size: {scope}, expected: {fixed_size}')
return cls(**field_values) # type: ignore

def serialize(self, stream: BinaryIO) -> int:
Expand Down
Loading
Loading