Skip to content

Commit

Permalink
pyln.proto.message: expose array types, add set_field for Message class.
Browse files Browse the repository at this point in the history
Exposing the array types is required for our dummyrunner in the lnprototest suite, since
it wants to be able to generate fake fields.

The set_field is similarly useful.

Signed-off-by: Rusty Russell <rusty@rustcorp.com.au>
  • Loading branch information
rustyrussell committed Jun 11, 2020
1 parent 2ead207 commit dd1aab8
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
4 changes: 4 additions & 0 deletions contrib/pyln-proto/pyln/proto/message/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .array_types import SizedArrayType, DynamicArrayType, EllipsisArrayType
from .message import MessageNamespace, MessageType, Message, SubtypeType
from .fundamental_types import split_field, FieldType

Expand All @@ -10,6 +11,9 @@
"SubtypeType",
"FieldType",
"split_field",
"SizedArrayType",
"DynamicArrayType",
"EllipsisArrayType",

# fundamental_types
'byte',
Expand Down
22 changes: 12 additions & 10 deletions contrib/pyln-proto/pyln/proto/message/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,22 +545,24 @@ def __init__(self, messagetype: MessageType, **kwargs):

# Convert arguments from strings to values if necessary.
for field in kwargs:
f = self.messagetype.find_field(field)
if f is None:
raise ValueError("Unknown field {}".format(field))

v = kwargs[field]
if isinstance(v, str):
v, remainder = f.fieldtype.val_from_str(v)
if remainder != '':
raise ValueError('Unexpected {} at end of initializer for {}'.format(remainder, field))
self.fields[field] = v
self.set_field(field, kwargs[field])

bad_lens = self.messagetype.len_fields_bad(self.messagetype.name,
self.fields)
if bad_lens:
raise ValueError("Inconsistent length fields: {}".format(bad_lens))

def set_field(self, field: str, val: Any) -> None:
f = self.messagetype.find_field(field)
if f is None:
raise ValueError("Unknown field {}".format(field))

if isinstance(val, str):
val, remainder = f.fieldtype.val_from_str(val)
if remainder != '':
raise ValueError('Unexpected {} at end of initializer for {}'.format(remainder, field))
self.fields[field] = val

def missing_fields(self) -> List[str]:
"""Are any required fields missing?"""
missing: List[str] = []
Expand Down

0 comments on commit dd1aab8

Please sign in to comment.