Skip to content

Commit

Permalink
refactor(python/protobuf): allow field types imported in the same module
Browse files Browse the repository at this point in the history
  • Loading branch information
matejcik committed Jun 25, 2024
1 parent 27fef37 commit cd55d32
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 132 deletions.
199 changes: 95 additions & 104 deletions python/src/trezorlib/protobuf.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
from __future__ import annotations

import logging
import sys
import typing as t
import warnings
from dataclasses import dataclass
from enum import IntEnum
from io import BytesIO
from itertools import zip_longest
import typing as t

import typing_extensions as tx

Expand Down Expand Up @@ -62,10 +63,6 @@ def write(self, __buf: bytes) -> int:
LOG = logging.getLogger(__name__)


def safe_issubclass(value: t.Any, cls: T | tuple[T, ...]) -> tx.TypeGuard[T]:
    """Return True only if `value` is a class and a subclass of `cls`.

    The builtin `issubclass` raises `TypeError` when its first argument is
    not a class. This wrapper checks `isinstance(value, type)` first, so
    non-class values (instances, `None`, strings, ...) yield False instead.

    `cls` is passed straight to `issubclass`, so it may be a single class or
    a tuple of classes.
    """
    return isinstance(value, type) and issubclass(value, cls)


def load_uvarint(reader: Reader) -> int:
buffer = _UVARINT_BUFFER
result = 0
Expand Down Expand Up @@ -135,14 +132,14 @@ def uint_to_sint(uint: int) -> int:
WIRE_TYPE_INT = 0
WIRE_TYPE_LENGTH = 2

WIRE_TYPES = {
"uint32": WIRE_TYPE_INT,
"uint64": WIRE_TYPE_INT,
"sint32": WIRE_TYPE_INT,
"sint64": WIRE_TYPE_INT,
"bool": WIRE_TYPE_INT,
"bytes": WIRE_TYPE_LENGTH,
"string": WIRE_TYPE_LENGTH,
PROTO_TYPES = {
"uint32": int,
"uint64": int,
"sint32": int,
"sint64": int,
"bool": bool,
"bytes": bytes,
"string": str,
}

REQUIRED_FIELD_PLACEHOLDER = object()
Expand All @@ -151,50 +148,76 @@ def uint_to_sint(uint: int) -> int:
@dataclass
class Field:
name: str
type: str
proto_type: str
repeated: bool = False
required: bool = False
default: object = None

_py_type: type | None = None
_owner: type[MessageType] | None = None

@property
def wire_type(self) -> int:
if self.type in WIRE_TYPES:
return WIRE_TYPES[self.type]
def py_type(self) -> type:
if self._py_type is None:
self._py_type = self._resolve_type()
# pyright issue https://github.com/microsoft/pyright/issues/8136
return self._py_type # type: ignore [Type ["Unknown | None"]]

def _resolve_type(self) -> type:
    """Resolve the name in `self.proto_type` to a Python object.

    The name is looked up in three places, in this order, and the first
    non-None hit is returned:

    1. the module-level `PROTO_TYPES` table,
    2. the `__dict__` of the owning class (`self._owner`),
    3. the `__dict__` of the module named by the owner's `__module__`,
       as found in `sys.modules`.

    The object that is found is returned as-is; this method does not check
    that it is actually a class.

    Raises:
        AssertionError: if the name is not in `PROTO_TYPES` and no owner
            has been set on this field.
        TypeError: if the name is not found in any of the three places.
    """
    # look for a type in the builtins
    py_type = PROTO_TYPES.get(self.proto_type)
    if py_type is not None:
        return py_type

    # look for a type in the class locals
    assert self._owner is not None, "Field is not owned by a MessageType"
    py_type = self._owner.__dict__.get(self.proto_type)
    if py_type is not None:
        return py_type

    # look for a type in the class globals
    # (falls back to an empty mapping if the module is not in sys.modules)
    cls_module = sys.modules.get(self._owner.__module__, None)
    cls_globals = getattr(cls_module, "__dict__", {})
    py_type = cls_globals.get(self.proto_type)
    if py_type is not None:
        return py_type

    raise TypeError(f"Could not resolve field type {self.proto_type}")

field_type_object = get_field_type_object(self)
if safe_issubclass(field_type_object, MessageType):
@property
def wire_type(self) -> int:
if issubclass(self.py_type, (MessageType, bytes, str)):
return WIRE_TYPE_LENGTH

if safe_issubclass(field_type_object, IntEnum):
if issubclass(self.py_type, int):
return WIRE_TYPE_INT

raise ValueError(f"Unrecognized type for field {self.name}")

def value_fits(self, value: int) -> bool:
if self.type == "uint32":
if self.proto_type == "uint32":
return 0 <= value < 2**32
if self.type == "uint64":
if self.proto_type == "uint64":
return 0 <= value < 2**64
if self.type == "sint32":
if self.proto_type == "sint32":
return -(2**31) <= value < 2**31
if self.type == "sint64":
if self.proto_type == "sint64":
return -(2**63) <= value < 2**63

raise ValueError(f"Cannot check range bounds for {self.type}")

raise ValueError(f"Cannot check range bounds for {self.proto_type}")

class _MessageTypeMeta(type):
    """Metaclass that installs `MessageType.__init__` on derived classes."""

    def __init__(cls, name: str, bases: tuple, d: dict) -> None:
        """Initialize the new class, then override its `__init__`.

        Any class created through this metaclass whose name is not
        "MessageType" has its `__init__` replaced by `MessageType.__init__`.
        """
        super().__init__(name, bases, d)
        # The check is by class name only: a class literally named
        # "MessageType" keeps whatever __init__ it defined.
        if name != "MessageType":
            cls.__init__ = MessageType.__init__  # type: ignore [Parameter]


class MessageType(metaclass=_MessageTypeMeta):
class MessageType:
MESSAGE_WIRE_TYPE: t.ClassVar[int | None] = None

FIELDS: t.ClassVar[dict[int, Field]] = {}

def __init_subclass__(cls) -> None:
    """Finish setting up a newly created subclass.

    Replaces the subclass's `__init__` with `MessageType.__init__` and
    records the subclass as `_owner` on every `Field` in `cls.FIELDS`.
    """
    super().__init_subclass__()
    # override the generated __init__ methods by the parent method
    cls.__init__ = MessageType.__init__
    # NOTE(review): `cls.FIELDS` is resolved through normal attribute
    # lookup, so a subclass that does not define its own FIELDS will
    # re-assign `_owner` on the Field objects it inherits -- confirm that
    # this is intended.
    for field in cls.FIELDS.values():
        field._owner = cls

@classmethod
def get_field(cls, name: str) -> Field | None:
    """Return the field of this message type whose name is `name`.

    Iterates over `cls.FIELDS.values()` and returns the first field whose
    `name` attribute equals `name`, or None if there is no such field.
    """
    return next((f for f in cls.FIELDS.values() if f.name == name), None)
Expand Down Expand Up @@ -278,15 +301,6 @@ def write(self, buf: bytes) -> int:
return nwritten


def get_field_type_object(field: Field) -> type[MessageType] | type[IntEnum] | None:
    """Look up the class named by `field.type` in the `messages` module.

    Returns the attribute of `messages` with that name if it is a subclass
    of `IntEnum` or `MessageType`. Returns None if there is no such
    attribute or if it is anything else.
    """
    # Imported at call time rather than at module level.
    # NOTE(review): presumably to avoid a circular import with `messages` --
    # confirm.
    from . import messages

    field_type_object = getattr(messages, field.type, None)
    if not safe_issubclass(field_type_object, (IntEnum, MessageType)):
        return None
    return field_type_object


def decode_packed_array_field(field: Field, reader: Reader) -> list[t.Any]:
assert field.repeated, "Not decoding packed array into non-repeated field"
length = load_uvarint(reader)
Expand All @@ -304,33 +318,26 @@ def decode_varint_field(field: Field, reader: Reader) -> int | bool | IntEnum:
assert field.wire_type == WIRE_TYPE_INT, f"Field {field.name} is not varint-encoded"
value = load_uvarint(reader)

field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, IntEnum):
if issubclass(field.py_type, IntEnum):
try:
return field_type_object(value)
return field.py_type(value)
except ValueError as e:
# treat enum errors as warnings
LOG.info(f"On field {field.name}: {e}")
return value

if field.type.startswith("uint"):
if not field.value_fits(value):
LOG.info(
f"On field {field.name}: value {value} out of range for {field.type}"
)
return value
if issubclass(field.py_type, bool):
return bool(value)

if field.type.startswith("sint"):
value = uint_to_sint(value)
if issubclass(field.py_type, int):
if field.proto_type.startswith("sint"):
value = uint_to_sint(value)
if not field.value_fits(value):
LOG.info(
f"On field {field.name}: value {value} out of range for {field.type}"
f"On field {field.name}: value {value} out of range for {field.proto_type}"
)
return value

if field.type == "bool":
return bool(value)

raise TypeError # not a varint field or unknown type


Expand All @@ -341,19 +348,18 @@ def decode_length_delimited_field(
if value > MAX_FIELD_SIZE:
raise ValueError(f"Field {field.name} contents too large ({value} bytes)")

if field.type == "bytes":
if issubclass(field.py_type, bytes):
buf = bytearray(value)
reader.readinto(buf)
return bytes(buf)

if field.type == "string":
if issubclass(field.py_type, str):
buf = bytearray(value)
reader.readinto(buf)
return buf.decode()

field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, MessageType):
return load_message(LimitedReader(reader, value), field_type_object)
if issubclass(field.py_type, MessageType):
return load_message(LimitedReader(reader, value), field.py_type)

raise TypeError # field type is unknown

Expand Down Expand Up @@ -446,47 +452,41 @@ def dump_message(writer: Writer, msg: "MessageType") -> None:
for svalue in fvalue:
dump_uvarint(writer, fkey)

field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, MessageType):
if not isinstance(svalue, field_type_object):
if issubclass(field.py_type, MessageType):
if not isinstance(svalue, field.py_type):
raise ValueError(
f"Value {svalue} in field {field.name} is not {field_type_object.__name__}"
f"Value {svalue} in field {field.name} is not {field.py_type.__name__}"
)
counter = CountingWriter()
dump_message(counter, svalue)
dump_uvarint(writer, counter.size)
dump_message(writer, svalue)

elif safe_issubclass(field_type_object, IntEnum):
if svalue not in field_type_object.__members__.values():
elif issubclass(field.py_type, IntEnum):
if svalue not in field.py_type.__members__.values():
raise ValueError(
f"Value {svalue} in field {field.name} unknown for {field.type}"
f"Value {svalue} in field {field.name} unknown for {field.proto_type}"
)
dump_uvarint(writer, svalue)

elif field.type.startswith("uint"):
if not field.value_fits(svalue):
raise ValueError(
f"Value {svalue} in field {field.name} does not fit into {field.type}"
)
dump_uvarint(writer, svalue)
elif issubclass(field.py_type, bool):
dump_uvarint(writer, int(svalue))

elif field.type.startswith("sint"):
elif issubclass(field.py_type, int):
if not field.value_fits(svalue):
raise ValueError(
f"Value {svalue} in field {field.name} does not fit into {field.type}"
f"Value {svalue} in field {field.name} does not fit into {field.proto_type}"
)
dump_uvarint(writer, sint_to_uint(svalue))

elif field.type == "bool":
dump_uvarint(writer, int(svalue))
if field.proto_type.startswith("sint"):
svalue = sint_to_uint(svalue)
dump_uvarint(writer, svalue)

elif field.type == "bytes":
elif issubclass(field.py_type, bytes):
assert isinstance(svalue, (bytes, bytearray))
dump_uvarint(writer, len(svalue))
writer.write(svalue)

elif field.type == "string":
elif issubclass(field.py_type, str):
assert isinstance(svalue, str)
svalue_bytes = svalue.encode()
dump_uvarint(writer, len(svalue_bytes))
Expand Down Expand Up @@ -549,9 +549,9 @@ def pformat(name: str, value: t.Any, indent: int) -> str:

field = pb.get_field(name)
if field is not None:
if isinstance(value, int) and safe_issubclass(field.type, IntEnum):
if isinstance(value, int) and issubclass(field.py_type, IntEnum):
try:
return f"{field.type(value).name} ({value})"
return f"{field.py_type(value).name} ({value})"
except ValueError:
return str(value)

Expand All @@ -569,37 +569,29 @@ def pformat(name: str, value: t.Any, indent: int) -> str:


def value_to_proto(field: Field, value: t.Any) -> t.Any:
field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, MessageType):
if issubclass(field.py_type, MessageType):
raise TypeError("value_to_proto only converts simple values")

if safe_issubclass(field_type_object, IntEnum):
if issubclass(field.py_type, IntEnum):
if isinstance(value, str):
return field_type_object.__members__[value]
return field.py_type.__members__[value]
else:
try:
return field_type_object(value)
return field.py_type(value)
except ValueError as e:
LOG.info(f"On field {field.name}: {e}")
return int(value)

if "int" in field.type:
return int(value)

if field.type == "bool":
return bool(value)

if field.type == "string":
return str(value)

if field.type == "bytes":
if issubclass(field.py_type, bytes):
if isinstance(value, str):
return bytes.fromhex(value)
elif isinstance(value, bytes):
return value
else:
raise TypeError(f"can't convert {type(value)} value to bytes")

return field.py_type(value)


def dict_to_proto(message_type: type[MT], d: dict[str, t.Any]) -> MT:
params = {}
Expand All @@ -611,9 +603,8 @@ def dict_to_proto(message_type: type[MT], d: dict[str, t.Any]) -> MT:
if not field.repeated:
value = [value]

field_type_object = get_field_type_object(field)
if safe_issubclass(field_type_object, MessageType):
newvalue = [dict_to_proto(field_type_object, v) for v in value]
if issubclass(field.py_type, MessageType):
newvalue = [dict_to_proto(field.py_type, v) for v in value]
else:
newvalue = [value_to_proto(field, v) for v in value]

Expand Down
15 changes: 1 addition & 14 deletions python/tests/test_protobuf_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import pytest

from trezorlib import messages, protobuf
from trezorlib import protobuf


class SomeEnum(IntEnum):
Expand Down Expand Up @@ -94,19 +94,6 @@ class RecursiveMessage(protobuf.MessageType):
}


# message types are read from the messages module so we need to "include" these messages there for now
messages.SomeEnum = SomeEnum
messages.WiderEnum = WiderEnum
messages.NarrowerEnum = NarrowerEnum
messages.PrimitiveMessage = PrimitiveMessage
messages.EnumMessageMoreValues = EnumMessageMoreValues
messages.EnumMessageLessValues = EnumMessageLessValues
messages.RepeatedFields = RepeatedFields
messages.RequiredFields = RequiredFields
messages.DefaultFields = DefaultFields
messages.RecursiveMessage = RecursiveMessage


def load_uvarint(buffer):
    """Test helper: wrap `buffer` in a `BytesIO` and pass it to `protobuf.load_uvarint`."""
    reader = BytesIO(buffer)
    return protobuf.load_uvarint(reader)
Expand Down
Loading

0 comments on commit cd55d32

Please sign in to comment.