Skip to content

Commit

Permalink
Add support for default_node in StableContainer / Profile
Browse files Browse the repository at this point in the history
Implement `default_node` so that incomplete initialization becomes
possible, similar to how it works for `Container` when not all fields
are explicitly defined during construction.
  • Loading branch information
etan-status committed Jul 9, 2024
1 parent 1099d0a commit c2855b3
Showing 1 changed file with 22 additions and 2 deletions.
24 changes: 22 additions & 2 deletions remerkleable/stable_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,13 @@ def tree_depth(cls) -> int:
def item_elem_cls(cls, i: int) -> Type[View]:
return list(cls._field_indices.values())[i]

@classmethod
def default_node(cls) -> Node:
return PairNode(
left=subtree_fill_to_contents([], cls.tree_depth()),
right=Bitvector[cls.N].default_node(),
)

def active_fields(self) -> Bitvector:
active_fields_node = super().get_backing().get_right()
return Bitvector[self.__class__.N].view_from_backing(active_fields_node)
Expand Down Expand Up @@ -319,11 +326,11 @@ def __new__(cls, backing: Optional[Node] = None, hook: Optional[ViewHook] = None
return super().__new__(cls, backing=backing, hook=hook, **kwargs)

extra_kw = kwargs.copy()
for fkey, (_, _, fopt) in cls._field_indices.items():
for fkey, (_, ftyp, fopt) in cls._field_indices.items():
if fkey in extra_kw:
extra_kw.pop(fkey)
elif not fopt:
raise AttributeError(f'Field `{fkey}` is required in {cls.__name__}')
kwargs[fkey] = ftyp.view_from_backing(ftyp.default_node())
else:
pass
if len(extra_kw) > 0:
Expand Down Expand Up @@ -548,6 +555,19 @@ def tree_depth(cls) -> int:
def item_elem_cls(cls, i: int) -> Type[View]:
return cls.B.item_elem_cls(i)

@classmethod
def default_node(cls) -> Node:
fnodes = [zero_node(0)] * cls.B.N
active_fields = Bitvector[cls.B.N]()
for (findex, ftyp, fopt) in cls._field_indices.values():
if not fopt:
fnodes[findex] = ftyp.default_node()
active_fields.set(findex, True)
return PairNode(
left=subtree_fill_to_contents(fnodes, cls.tree_depth()),
right=active_fields.get_backing(),
)

def active_fields(self) -> Bitvector:
active_fields_node = super().get_backing().get_right()
return Bitvector[self.__class__.B.N].view_from_backing(active_fields_node)
Expand Down

0 comments on commit c2855b3

Please sign in to comment.