From 6e2d8a75d1b5fe90550f8c9fa78a8b6ca99264d9 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Wed, 15 Jul 2020 16:45:04 -0500 Subject: [PATCH 01/23] rebase onto latest master --- src/betterproto/plugin.py | 284 ++-------- src/betterproto/plugin_dataclasses.py | 642 +++++++++++++++++++++++ src/betterproto/templates/template.py.j2 | 42 +- 3 files changed, 714 insertions(+), 254 deletions(-) create mode 100644 src/betterproto/plugin_dataclasses.py diff --git a/src/betterproto/plugin.py b/src/betterproto/plugin.py index 4f01c292..eb7261c7 100755 --- a/src/betterproto/plugin.py +++ b/src/betterproto/plugin.py @@ -42,6 +42,19 @@ ) raise SystemExit(1) +from .plugin_dataclasses import ( + OutputTemplate, + ProtoInputFile, + Message, + Field, + OneOfField, + MapField, + EnumDefinition, + Service, + ServiceMethod, + is_map, + is_oneof +) def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str: if field.type in [1, 2]: @@ -151,43 +164,30 @@ def generate_code(request, response): # Initialize Template data for each package for output_package_name, output_package_content in output_package_files.items(): - template_data = { - "input_package": output_package_content["input_package"], - "files": [f.name for f in output_package_content["files"]], - "imports": set(), - "datetime_imports": set(), - "typing_imports": set(), - "messages": [], - "enums": [], - "services": [], - } + template_data = OutputTemplate(input_package=output_package_content["input_package"]) + for input_proto_file in output_package_content["files"]: + input_proto_file_data = ProtoInputFile(parent=template_data, proto_obj=input_proto_file) output_package_content["template_data"] = template_data # Read Messages and Enums - output_types = [] for output_package_name, output_package_content in output_package_files.items(): - for proto_file in output_package_content["files"]: - for item, path in traverse(proto_file): - type_data = read_protobuf_type( - item, path, proto_file, output_package_content - ) - output_types.append(type_data) + for proto_file_data in output_package_content["template_data"].input_files: + for item, path in traverse(proto_file_data.proto_obj): + read_protobuf_type(item=item, path=path, proto_file_data=proto_file_data) # Read Services for output_package_name, output_package_content in output_package_files.items(): - for proto_file in output_package_content["files"]: - for index, service in enumerate(proto_file.service): - read_protobuf_service( - service, index, proto_file, output_package_content, output_types - ) + for proto_file_data in output_package_content["template_data"].input_files: + for index, service in enumerate(proto_file_data.proto_obj.service): + read_protobuf_service(service, index, proto_file_data) # Render files output_paths = set() for output_package_name, output_package_content in output_package_files.items(): template_data = output_package_content["template_data"] - template_data["imports"] = sorted(template_data["imports"]) - template_data["datetime_imports"] = sorted(template_data["datetime_imports"]) - template_data["typing_imports"] = sorted(template_data["typing_imports"]) + template_data.imports = sorted(template_data.imports) + template_data.datetime_imports = sorted(template_data.datetime_imports) + template_data.typing_imports = sorted(template_data.typing_imports) # Fill response output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") @@ -219,224 +219,42 @@ def generate_code(request, response): for output_package_name in sorted(output_paths.union(init_files)): print(f"Writing {output_package_name}", file=sys.stderr) - -def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file, content): - input_package_name = content["input_package"] - template_data = content["template_data"] - data = { - "name": item.name, - "py_name": pythonize_class_name(item.name), - "descriptor": item, - "package": input_package_name, - } +def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file_data: ProtoInputFile): if isinstance(item, DescriptorProto): - # print(item, file=sys.stderr) if item.options.map_entry: # Skip generated map entry messages since we just use dicts return - - data.update( - { - "type": "Message", - "comment": get_comment(proto_file, path), - "properties": [], - } + # Process Message + message_data = Message( + parent=proto_file_data, + proto_obj=item, + path=path ) - - for i, f in enumerate(item.field): - t = py_type(input_package_name, template_data["imports"], f) - zero = get_py_zero(f.type) - - repeated = False - packed = False - - field_type = f.Type.Name(f.type).lower()[5:] - - field_wraps = "" - match_wrapper = re.match(r"\.google\.protobuf\.(.+)Value", f.type_name) - if match_wrapper: - wrapped_type = "TYPE_" + match_wrapper.group(1).upper() - if hasattr(betterproto, wrapped_type): - field_wraps = f"betterproto.{wrapped_type}" - - map_types = None - if f.type == 11: - # This might be a map... - message_type = f.type_name.split(".").pop().lower() - # message_type = py_type(package) - map_entry = f"{f.name.replace('_', '').lower()}entry" - - if message_type == map_entry: - for nested in item.nested_type: - if nested.name.replace("_", "").lower() == map_entry: - if nested.options.map_entry: - # print("Found a map!", file=sys.stderr) - k = py_type( - input_package_name, - template_data["imports"], - nested.field[0], - ) - v = py_type( - input_package_name, - template_data["imports"], - nested.field[1], - ) - t = f"Dict[{k}, {v}]" - field_type = "map" - map_types = ( - f.Type.Name(nested.field[0].type), - f.Type.Name(nested.field[1].type), - ) - template_data["typing_imports"].add("Dict") - - if f.label == 3 and field_type != "map": - # Repeated field - repeated = True - t = f"List[{t}]" - zero = "[]" - template_data["typing_imports"].add("List") - - if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: - packed = True - - one_of = "" - if f.HasField("oneof_index"): - one_of = item.oneof_decl[f.oneof_index].name - - if "Optional[" in t: - template_data["typing_imports"].add("Optional") - - if "timedelta" in t: - template_data["datetime_imports"].add("timedelta") - elif "datetime" in t: - template_data["datetime_imports"].add("datetime") - - data["properties"].append( - { - "name": f.name, - "py_name": pythonize_field_name(f.name), - "number": f.number, - "comment": get_comment(proto_file, path + [2, i]), - "proto_type": int(f.type), - "field_type": field_type, - "field_wraps": field_wraps, - "map_types": map_types, - "type": t, - "zero": zero, - "repeated": repeated, - "packed": packed, - "one_of": one_of, - } - ) - # print(f, file=sys.stderr) - - template_data["messages"].append(data) - return data + for index, field in enumerate(item.field): + if is_map(field, item): + MapField(parent=message_data, proto_obj=field, path=path+[2, index]) + elif is_oneof(field): + OneOfField(parent=message_data, proto_obj=field, path=path+[2, index]) + else: + Field(parent=message_data, proto_obj=field, path=path+[2, index]) elif isinstance(item, EnumDescriptorProto): - # print(item.name, path, file=sys.stderr) - data.update( - { - "type": "Enum", - "comment": get_comment(proto_file, path), - "entries": [ - { - "name": v.name, - "value": v.number, - "comment": get_comment(proto_file, path + [2, i]), - } - for i, v in enumerate(item.value) - ], - } - ) - - template_data["enums"].append(data) - return data + # Enum + EnumDefinition(proto_obj=item, parent=proto_file_data, path=path) -def lookup_method_input_type(method, types): - package, name = parse_source_type_name(method.input_type) - - for known_type in types: - if known_type["type"] != "Message": - continue - - # Nested types are currently flattened without dots. - # Todo: keep a fully quantified name in types, that is comparable with method.input_type - if ( - package == known_type["package"] - and name.replace(".", "") == known_type["name"] - ): - return known_type - - -def is_mutable_field_type(field_type: str) -> bool: - return field_type.startswith("List[") or field_type.startswith("Dict[") - - -def read_protobuf_service( - service: ServiceDescriptorProto, index, proto_file, content, output_types -): - input_package_name = content["input_package"] - template_data = content["template_data"] - # print(service, file=sys.stderr) - data = { - "name": service.name, - "py_name": pythonize_class_name(service.name), - "comment": get_comment(proto_file, [6, index]), - "methods": [], - } +def read_protobuf_service(service: ServiceDescriptorProto, index: int, proto_file_data: ProtoInputFile): + service_data = Service( + parent=proto_file_data, + proto_obj=service, + path=[6, index], + ) for j, method in enumerate(service.method): - method_input_message = lookup_method_input_type(method, output_types) - - # This section ensures that method arguments having a default - # value that is initialised as a List/Dict (mutable) is replaced - # with None and initialisation is deferred to the beginning of the - # method definition. This is done so to avoid any side-effects. - # Reference: https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments - mutable_default_args = [] - - if method_input_message: - for field in method_input_message["properties"]: - if ( - not method.client_streaming - and field["zero"] != "None" - and is_mutable_field_type(field["type"]) - ): - mutable_default_args.append((field["py_name"], field["zero"])) - field["zero"] = "None" - - if field["zero"] == "None": - template_data["typing_imports"].add("Optional") - - data["methods"].append( - { - "name": method.name, - "py_name": pythonize_method_name(method.name), - "comment": get_comment(proto_file, [6, index, 2, j], indent=8), - "route": f"/{input_package_name}.{service.name}/{method.name}", - "input": get_type_reference( - input_package_name, template_data["imports"], method.input_type - ).strip('"'), - "input_message": method_input_message, - "output": get_type_reference( - input_package_name, - template_data["imports"], - method.output_type, - unwrap=False, - ), - "client_streaming": method.client_streaming, - "server_streaming": method.server_streaming, - "mutable_default_args": mutable_default_args, - } - ) - if method.client_streaming: - template_data["typing_imports"].add("AsyncIterable") - template_data["typing_imports"].add("Iterable") - template_data["typing_imports"].add("Union") - if method.server_streaming: - template_data["typing_imports"].add("AsyncIterator") - template_data["services"].append(data) + ServiceMethod( + parent=service_data, + proto_obj=method, + path=[6, index, 2, j], + ) def main(): diff --git a/src/betterproto/plugin_dataclasses.py b/src/betterproto/plugin_dataclasses.py new file mode 100644 index 00000000..1360254c --- /dev/null +++ b/src/betterproto/plugin_dataclasses.py @@ -0,0 +1,642 @@ +#!/usr/bin/env python +from __future__ import annotations + +import re +from dataclasses import dataclass +from dataclasses import field +from typing import ( + Union, + Type, + List, + Set, + Text, +) +import textwrap + +import betterproto +from betterproto.compile.importing import get_type_reference, parse_source_type_name +from betterproto.compile.naming import ( + pythonize_class_name, + pythonize_field_name, + pythonize_method_name, +) + +try: + # betterproto[compiler] specific dependencies + from google.protobuf.compiler import plugin_pb2 as plugin + from google.protobuf.descriptor_pb2 import ( + DescriptorProto, + EnumDescriptorProto, + FieldDescriptorProto, + FileDescriptorProto, + ) +except ImportError as err: + missing_import = re.match(r".*(cannot import name .*$)", err.args[0]).group(1) + print( + "\033[31m" + f"Unable to import `{missing_import}` from betterproto plugin! " + "Please ensure that you've installed betterproto as " + '`pip install "betterproto[compiler]"` so that compiler dependencies ' + "are included." + "\033[0m" + ) + raise SystemExit(1) + +# Create a unique placeholder to deal with +# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses +PLACEHOLDER = object() + +# Organize proto types into categories +PROTO_FLOAT_TYPES = ( + FieldDescriptorProto.TYPE_DOUBLE, # 1 + FieldDescriptorProto.TYPE_FLOAT, # 2 +) +PROTO_INT_TYPES = ( + FieldDescriptorProto.TYPE_INT64, # 3 + FieldDescriptorProto.TYPE_UINT64, # 4 + FieldDescriptorProto.TYPE_INT32, # 5 + FieldDescriptorProto.TYPE_FIXED64, # 6 + FieldDescriptorProto.TYPE_FIXED32, # 7 + FieldDescriptorProto.TYPE_UINT32, # 13 + FieldDescriptorProto.TYPE_SFIXED32, # 15 + FieldDescriptorProto.TYPE_SFIXED64, # 16 + FieldDescriptorProto.TYPE_SINT32, # 17 + FieldDescriptorProto.TYPE_SINT64, # 18 +) +PROTO_BOOL_TYPES = (FieldDescriptorProto.TYPE_BOOL,) # 8 +PROTO_STR_TYPES = (FieldDescriptorProto.TYPE_STRING,) # 9 +PROTO_BYTES_TYPES = (FieldDescriptorProto.TYPE_BYTES,) # 12 +PROTO_MESSAGE_TYPES = ( + FieldDescriptorProto.TYPE_MESSAGE, # 11 + FieldDescriptorProto.TYPE_ENUM, # 14 +) +PROTO_MAP_TYPES = (FieldDescriptorProto.TYPE_MESSAGE,) # 11 +PROTO_PACKED_TYPES = ( + FieldDescriptorProto.TYPE_DOUBLE, # 1 + FieldDescriptorProto.TYPE_FLOAT, # 2 + FieldDescriptorProto.TYPE_INT64, # 3 + FieldDescriptorProto.TYPE_UINT64, # 4 + FieldDescriptorProto.TYPE_INT32, # 5 + FieldDescriptorProto.TYPE_FIXED64, # 6 + FieldDescriptorProto.TYPE_FIXED32, # 7 + FieldDescriptorProto.TYPE_BOOL, # 8 + FieldDescriptorProto.TYPE_UINT32, # 13 + FieldDescriptorProto.TYPE_SFIXED32, # 15 + FieldDescriptorProto.TYPE_SFIXED64, # 16 + FieldDescriptorProto.TYPE_SINT32, # 17 + FieldDescriptorProto.TYPE_SINT64, # 18 +) + + +def get_comment(proto_file, path: List[int], indent: int = 4) -> str: + pad = " " * indent + for sci in proto_file.source_code_info.location: + # print(list(sci.path), path, file=sys.stderr) + if list(sci.path) == path and sci.leading_comments: + lines = textwrap.wrap( + sci.leading_comments.strip().replace("\n", ""), + width=79 - indent, + ) + + if path[-2] == 2 and path[-4] != 6: + # This is a field + return f"{pad}# " + f"\n{pad}# ".join(lines) + else: + # This is a message, enum, service, or method + if len(lines) == 1 and len(lines[0]) < 79 - indent - 6: + lines[0] = lines[0].strip('"') + return f'{pad}"""{lines[0]}"""' + else: + joined = f"\n{pad}".join(lines) + return f'{pad}"""\n{pad}{joined}\n{pad}"""' + + return "" + + +class ProtoContentBase: + """Methods common to Message, Service and ServiceMethod.""" + + path: List[int] + comment_indent: int = 4 + + def __post_init__(self): + """Checks that no fake default fields were left as placeholders.""" + for field_name, field_val in self.__dataclass_fields__.items(): + if field_val is PLACEHOLDER: + raise ValueError( + f"`{field_name}` is a required field with no default value." + ) + + @property + def output_file(self) -> OutputTemplate: + current = self.parent + while not isinstance(current, OutputTemplate): + current = current.parent + return current + + @property + def proto_file(self) -> FieldDescriptorProto: + current = self + while not isinstance(current, ProtoInputFile): + current = current.parent + return current.proto_obj + + @property + def comment(self) -> str: + """Crawl the proto source code and retrieve comments for this object.""" + return get_comment( + proto_file=self.proto_file, + path=self.path, + indent=self.comment_indent, + ) + + +@dataclass +class OutputTemplate: + """Representation of an output .py file. + """ + + input_package: str + input_files: List[ProtoInputFile] = field(default_factory=list) + imports: Set[str] = field(default_factory=set) + datetime_imports: Set[str] = field(default_factory=set) + typing_imports: Set[str] = field(default_factory=set) + messages: List[Message] = field(default_factory=list) + enums: List[Enum] = field(default_factory=list) + services: List[Service] = field(default_factory=list) + + @property + def input_filenames(self) -> List[str]: + return [f.name for f in self.input_files] + + +@dataclass +class ProtoInputFile: + """Representation of an input .proto file. + """ + + parent: OutputTemplate + proto_obj: FileDescriptorProto + + @property + def name(self) -> str: + return self.proto_obj.name + + def __post_init__(self): + # Add proto file to output file + self.parent.input_files.append(self) + + +@dataclass +class Message(ProtoContentBase): + """Representation of a protobuf message. + """ + + parent: Union[ProtoInputFile, Message] = PLACEHOLDER + proto_obj: DescriptorProto = PLACEHOLDER + path: List[int] = PLACEHOLDER + fields: List[Union[Field, Message]] = field(default_factory=list) + + def __post_init__(self): + # Add message to output file + if isinstance(self.parent, ProtoInputFile): + if isinstance(self, EnumDefinition): + self.output_file.enums.append(self) + else: + self.output_file.messages.append(self) + super().__post_init__() + + @property + def proto_name(self) -> str: + return self.proto_obj.name + + @property + def py_name(self) -> str: + return pythonize_class_name(self.proto_name) + + @property + def annotation(self) -> str: + if self.repeated: + return f"List[{self.py_name}]" + return self.py_name + + +def is_map( + proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto +) -> bool: + """True if proto_field_obj is a map, otherwise False. + """ + if proto_field_obj.type is FieldDescriptorProto.TYPE_MESSAGE: + # This might be a map... + message_type = proto_field_obj.type_name.split(".").pop().lower() + map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry" + if message_type == map_entry: + for nested in parent_message.nested_type: # parent message + if nested.name.replace("_", "").lower() == map_entry: + if nested.options.map_entry: + return True + return False + + +def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool: + """True if proto_field_obj is a OneOf, otherwise False. + """ + if proto_field_obj.HasField("oneof_index"): + return True + return False + + +@dataclass +class Field(Message): + parent: Message = PLACEHOLDER + proto_obj: FieldDescriptorProto = PLACEHOLDER + + def __post_init__(self): + # Add field to message + self.parent.fields.append(self) + # Check for new imports + annotation = self.annotation + if "Optional[" in annotation: + self.output_file.typing_imports.add("Optional") + if "List[" in annotation: + self.output_file.typing_imports.add("List") + if "Dict[" in annotation: + self.output_file.typing_imports.add("Dict") + if "timedelta" in annotation: + self.output_file.datetime_imports.add("timedelta") + if "datetime" in annotation: + self.output_file.datetime_imports.add("datetime") + super().__post_init__() # call Field -> Message __post_init__ + + def get_field_string(self, indent: int = 4) -> str: + """Construct string representation of this field as a field.""" + name = f"{self.py_name}" + annotations = f": {self.annotation}" + betterproto_field_type = \ + f"betterproto.{self.field_type}_field({self.proto_obj.number}" + \ + f"{self.betterproto_field_args}" + \ + ")" + return name + annotations + " = " + betterproto_field_type + + @property + def betterproto_field_args(self): + args = "" + if self.field_wraps: + args.append(f", wraps={self.field_wraps}") + return args + + @property + def field_wraps(self) -> Union[str, None]: + """Returns betterproto wrapped field type or None. + """ + match_wrapper = re.match( + r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name + ) + if match_wrapper: + wrapped_type = "TYPE_" + match_wrapper.group(1).upper() + if hasattr(betterproto, wrapped_type): + return f"betterproto.{wrapped_type}" + return None + + @property + def repeated(self) -> bool: + if ( + self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED + and not is_map(self.proto_obj, self.parent) + ): + return True + return False + + @property + def mutable(self) -> bool: + """True if the field is a mutable type, otherwise False.""" + return self.field_type.startswith("List[") or self.field_type.startswith("Dict[") + + @property + def field_type(self) -> str: + """String representation of proto field type.""" + return ( + self.proto_obj.Type.Name(self.proto_obj.type) + .lower() + .replace("type_", "") + ) + + @property + def default_value_string(self) -> Union[Text, None, float, int]: + """Python representation of the default proto value. + """ + if self.repeated: + return "[]" + if self.py_type == "int": + return "0" + if self.py_type == "float": + return "0.0" + elif self.py_type == "bool": + return "False" + elif self.py_type == "str": + return '""' + elif self.py_type == "bytes": + return 'b""' + else: + # Message type + return "None" + + @property + def packed(self) -> bool: + """True if the wire representation is a packed format.""" + if self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES: + return True + return False + + @property + def py_name(self) -> str: + """Pythonized name.""" + return pythonize_field_name(self.proto_name) + + @property + def proto_name(self) -> str: + """Original protobuf name.""" + return self.proto_obj.name + + @property + def py_type(self) -> str: + """String representation of Python type.""" + if self.proto_obj.type in PROTO_FLOAT_TYPES: + return "float" + elif self.proto_obj.type in PROTO_INT_TYPES: + return "int" + elif self.proto_obj.type in PROTO_BOOL_TYPES: + return "bool" + elif self.proto_obj.type in PROTO_STR_TYPES: + return "str" + elif self.proto_obj.type in PROTO_BYTES_TYPES: + return "bytes" + elif self.proto_obj.type in PROTO_MESSAGE_TYPES: + # Type referencing another defined Message or a named enum + return get_type_reference( + package=self.output_file.input_package, + imports=self.output_file.imports, + source_type=self.proto_obj.type_name, + ) + else: + raise NotImplementedError(f"Unknown type {field.type}") + + @property + def annotation(self) -> str: + if self.repeated: + return f"List[{self.py_type}]" + return self.py_type + + +@dataclass +class OneOfField(Field): + + @property + def betterproto_field_args(self): + args = super().betterproto_field_args() + args = args + f', group="{self.parent.oneof_decl[self.proto_obj.oneof_index].name}"' + return args + + +@dataclass +class MapField(Field): + py_k_type: Type = PLACEHOLDER + py_v_type: Type = PLACEHOLDER + proto_k_type: str = PLACEHOLDER + proto_v_type: str = PLACEHOLDER + + def __post_init__(self): + """Explore nested types and set k_type and v_type if unset.""" + map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry" + for nested in self.parent.proto_obj.nested_type: + if nested.name.replace("_", "").lower() == map_entry: + if nested.options.map_entry: + # Get Python types + self.py_k_type = Field( + parent=self, proto_obj=nested.field[0], # key + ).py_type + self.py_v_type = Field( + parent=self, proto_obj=nested.field[1], # key + ).py_type + # Get proto types + self.proto_k_type = self.proto_obj.Type.Name( + nested.field[0].type + ) + self.proto_v_type = self.proto_obj.Type.Name( + nested.field[1].type + ) + super().__post_init__() # call Field -> Message __post_init__ + + def get_field_string(self, indent: int = 4) -> str: + """Construct string representation of this field.""" + name = f"{self.py_name}" + annotations = f": {self.annotation}" + betterproto_field_type = ( + f"betterproto.map_field(" + f"{self.proto_obj.number}, betterproto.{self.proto_k_type}, " + f"betterproto.{self.proto_v_type})" + ) + return name + annotations + " = " + betterproto_field_type + + @property + def annotation(self): + return f"Dict[{self.py_k_type}, {self.py_v_type}]" + + @property + def repeated(self): + return False # maps cannot be repeated + + +@dataclass +class EnumDefinition(Message): + """Representation of a proto Enum definition.""" + + proto_obj: EnumDescriptorProto = PLACEHOLDER + entries: List[EnumDefinition.EnumEntry] = PLACEHOLDER + + @dataclass(unsafe_hash=True) + class EnumEntry: + """Representation of an Enum entry.""" + + name: str + value: int + comment: str + + def __post_init__(self): + # Get entries + self.entries = [ + self.EnumEntry( + name=v.name, + value=v.number, + comment=get_comment( + proto_file=self.proto_file, path=self.path + [2, i] + ), + ) + for i, v in enumerate(self.proto_obj.value) + ] + super().__post_init__() # call Message __post_init__ + + @property + def default_value_string(self) -> int: + """Python representation of the default value for Enums. + + As per the spec, this is the first value of the Enum. + """ + return str(self.entries[0].value) # should ALWAYS be int(0)! + + +@dataclass +class Service(ProtoContentBase): + parent: ProtoInputFile = PLACEHOLDER + proto_obj: DescriptorProto = PLACEHOLDER + path: List[int] = PLACEHOLDER + methods: List[ServiceMethod] = field(default_factory=list) + + def __post_init__(self) -> None: + # Add service to output file + self.output_file.services.append(self) + super().__post_init__() # check for unset fields + + @property + def proto_name(self): + return self.proto_obj.name + + @property + def py_name(self): + return pythonize_class_name(self.proto_name) + + +@dataclass +class ServiceMethod(ProtoContentBase): + + parent: Service + proto_obj: MethodDescriptorProto + path: List[int] = PLACEHOLDER + comment_indent: int = 8 + + def __post_init__(self) -> None: + # Add method to service + self.parent.methods.append(self) + + # Check for Optional import + if self.py_input_message: + for field in self.py_input_message.fields: + if field.default_value_string == "None": + self.output_file.typing_imports.add("Optional") + + # Check for Async imports + if self.client_streaming: + self.output_file.typing_imports.add("AsyncIterable") + self.output_file.typing_imports.add("Iterable") + self.output_file.typing_imports.add("Union") + if self.server_streaming: + self.output_file.typing_imports.add("AsyncIterator") + + super().__post_init__() # check for unset fields + + @property + def mutable_default_arguments(self) -> List[Tuple[str, str]]: + """Handle mutable default arguments. + + Returns a list of tuples containing the name and default value + for arguments to this message who's default value is mutable. + The defaults are swapped out for None and replaced back inside + the method's body. + Reference: https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments + + Returns + ------- + List[Tuple[str, str]] + Each tuple contains the name and actual default value (as a string) for each + argument with mutable default values. + """ + mutable_default_arguments = [] + + if self.py_input_message: + for field in self.py_input_message.fields: + if ( + not self.client_streaming + and field.default_value_string != "None" + and field.mutable + ): + mutable_default_arguments.append( + (field.py_name, field.default_value_string) + ) + return mutable_default_arguments + + @property + def py_name(self) -> str: + """Pythonized method name.""" + return pythonize_method_name(self.proto_obj.name) + + @property + def proto_name(self) -> str: + """Original protobuf name.""" + return self.proto_obj.name + + @property + def route(self) -> str: + return f"/{self.output_file.input_package}.{self.parent.proto_name}/{self.proto_name}" + + @property + def py_input_message(self) -> Union[None, Message]: + """Find the input message object. + + Returns + ------- + Union[None, Message] + Method instance representing the input message. + If not input message could be found or there are no + input messages, None is returned. + """ + package, name = parse_source_type_name(self.proto_obj.input_type) + + # Nested types are currently flattened without dots. + # Todo: keep a fully quantified name in types, that is comparable with method.input_type + for msg in self.output_file.messages: + if msg.proto_name == self.proto_name and \ + msg.output_file.input_package == package: + return msg + return None + + @property + def py_input_message_type(self) -> str: + """String representation of the Python type correspoding to the + input message. + + Returns + ------- + str + String representation of the Python type correspoding to the + input message. + """ + return get_type_reference( + package=self.output_file.input_package, + imports=self.output_file.imports, + source_type=self.proto_obj.input_type, + ).strip('"') + + @property + def py_output_message_type(self) -> str: + """String representation of the Python type correspoding to the + output message. + + Returns + ------- + str + String representation of the Python type correspoding to the + output message. + """ + return get_type_reference( + package=self.output_file.input_package, + imports=self.output_file.imports, + source_type=self.proto_obj.output_type, + ).strip('"') + + @property + def client_streaming(self) -> bool: + return self.proto_obj.client_streaming + + @property + def server_streaming(self) -> bool: + return self.proto_obj.server_streaming diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index b7ca89c5..26d9a2a6 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -1,5 +1,5 @@ # Generated by the protocol buffer compiler. DO NOT EDIT! -# sources: {{ ', '.join(description.files) }} +# sources: {{ ', '.join(description.input_filenames) }} # plugin: python-betterproto from dataclasses import dataclass {% if description.datetime_imports %} @@ -40,13 +40,13 @@ class {{ message.py_name }}(betterproto.Message): {{ message.comment }} {% endif %} - {% for field in message.properties %} + {% for field in message.fields %} {% if field.comment %} {{ field.comment }} {% endif %} - {{ field.py_name }}: {{ field.type }} = betterproto.{{ field.field_type }}_field({{ field.number }}{% if field.field_type == 'map'%}, betterproto.{{ field.map_types[0] }}, betterproto.{{ field.map_types[1] }}{% endif %}{% if field.one_of %}, group="{{ field.one_of }}"{% endif %}{% if field.field_wraps %}, wraps={{ field.field_wraps }}{% endif %}) + {{ field.get_field_string() }} {% endfor %} - {% if not message.properties %} + {% if not message.fields %} pass {% endif %} @@ -61,21 +61,21 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% for method in service.methods %} async def {{ method.py_name }}(self {%- if not method.client_streaming -%} - {%- if method.input_message and method.input_message.properties -%}, *, - {%- for field in method.input_message.properties -%} - {{ field.py_name }}: {% if field.zero == "None" and not field.type.startswith("Optional[") -%} - Optional[{{ field.type }}] + {%- if method.py_input_message and method.py_input_message.fields -%}, *, + {%- for field in method.py_input_message.fields -%} + {{ field.py_name }}: {% if field.default_value_string == "None" and not field.annotation.startswith("Optional[") -%} + Optional[{{ field.py_type }}] {%- else -%} - {{ field.type }} - {%- endif -%} = {{ field.zero }} + {{ field.py_type }} + {%- endif -%} = {{ field.default_value_string }} {%- if not loop.last %}, {% endif -%} {%- endfor -%} {%- endif -%} {%- else -%} {# Client streaming: need a request iterator instead #} - , request_iterator: Union[AsyncIterable["{{ method.input }}"], Iterable["{{ method.input }}"]] + , request_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]] {%- endif -%} - ) -> {% if method.server_streaming %}AsyncIterator[{{ method.output }}]{% else %}{{ method.output }}{% endif %}: + ) -> {% if method.server_streaming %}AsyncIterator[{{ method.py_output_message_type }}]{% else %}{{ method.py_output_message_type }}{% endif %}: {% if method.comment %} {{ method.comment }} @@ -85,8 +85,8 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% endfor %} {% if not method.client_streaming %} - request = {{ method.input }}() - {% for field in method.input_message.properties %} + request = {{ method.py_input_message_type }}() + {% for field in method.py_input_message.fields %} {% if field.field_type == 'message' %} if {{ field.py_name }} is not None: request.{{ field.py_name }} = {{ field.py_name }} @@ -101,15 +101,15 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): async for response in self._stream_stream( "{{ method.route }}", request_iterator, - {{ method.input }}, - {{ method.output.strip('"') }}, + {{ method.py_input_message_type }}, + {{ method.py_output_message_type.strip('"') }}, ): yield response {% else %}{# i.e. not client streaming #} async for response in self._unary_stream( "{{ method.route }}", request, - {{ method.output.strip('"') }}, + {{ method.py_output_message_type.strip('"') }}, ): yield response @@ -119,14 +119,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): return await self._stream_unary( "{{ method.route }}", request_iterator, - {{ method.input }}, - {{ method.output.strip('"') }} + {{ method.py_input_message_type }}, + {{ method.py_output_message_type.strip('"') }} ) {% else %}{# i.e. not client streaming #} return await self._unary_unary( "{{ method.route }}", request, - {{ method.output.strip('"') }} + {{ method.py_output_message_type.strip('"') }} ) {% endif %}{# client streaming #} {% endif %} @@ -136,4 +136,4 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% for i in description.imports %} {{ i }} -{% endfor %} \ No newline at end of file +{% endfor %} From e60791b676602fdbfde99046056c8ad5bfc93b64 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 16 Jul 2020 10:41:07 -0500 Subject: [PATCH 02/23] fix bugs --- src/betterproto/plugin.py | 3 --- src/betterproto/plugin_dataclasses.py | 8 +++++--- src/betterproto/templates/template.py.j2 | 6 +++--- 3 files changed, 8 insertions(+), 9 deletions(-) diff --git a/src/betterproto/plugin.py b/src/betterproto/plugin.py index eb7261c7..4e89d0c9 100755 --- a/src/betterproto/plugin.py +++ b/src/betterproto/plugin.py @@ -185,9 +185,6 @@ def generate_code(request, response): output_paths = set() for output_package_name, output_package_content in output_package_files.items(): template_data = output_package_content["template_data"] - template_data.imports = sorted(template_data.imports) - template_data.datetime_imports = sorted(template_data.datetime_imports) - template_data.typing_imports = sorted(template_data.typing_imports) # Fill response output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") diff --git a/src/betterproto/plugin_dataclasses.py b/src/betterproto/plugin_dataclasses.py index 1360254c..4032ecda 100644 --- a/src/betterproto/plugin_dataclasses.py +++ b/src/betterproto/plugin_dataclasses.py @@ -282,7 +282,7 @@ def get_field_string(self, indent: int = 4) -> str: def betterproto_field_args(self): args = "" if self.field_wraps: - args.append(f", wraps={self.field_wraps}") + args = args + f", wraps={self.field_wraps}" return args @property @@ -393,8 +393,8 @@ class OneOfField(Field): @property def betterproto_field_args(self): - args = super().betterproto_field_args() - args = args + f', group="{self.parent.oneof_decl[self.proto_obj.oneof_index].name}"' + args = super().betterproto_field_args + args = args + f', group="{self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name}"' return args @@ -562,6 +562,7 @@ def mutable_default_arguments(self) -> List[Tuple[str, str]]: mutable_default_arguments.append( (field.py_name, field.default_value_string) ) + self.output_file.typing_imports.add("Optional") return mutable_default_arguments @property @@ -599,6 +600,7 @@ def py_input_message(self) -> Union[None, Message]: return msg return None + @property def py_input_message_type(self) -> str: """String representation of the Python type correspoding to the diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 26d9a2a6..03069c61 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -3,11 +3,11 @@ # plugin: python-betterproto from dataclasses import dataclass {% if description.datetime_imports %} -from datetime import {% for i in description.datetime_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} +from datetime import {% for i in description.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} {% endif%} {% if description.typing_imports %} -from typing import {% for i in description.typing_imports %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} +from typing import {% for i in description.typing_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %} {% endif %} @@ -134,6 +134,6 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {% endfor %} {% endfor %} -{% for i in description.imports %} +{% for i in description.imports|sort %} {{ i }} {% endfor %} From 46c10353059d0096edb6880bc92fc9b2c70cd128 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 16 Jul 2020 10:51:37 -0500 Subject: [PATCH 03/23] add import --- Makefile | 4 ++-- src/betterproto/plugin_dataclasses.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 5a3b7571..54bd37f6 100644 --- a/Makefile +++ b/Makefile @@ -8,8 +8,8 @@ help: ## - Show this help. generate: ## - Generate test cases (do this once before running test) poetry run python -m tests.generate -test: ## - Run tests - poetry run pytest --cov betterproto +test: ## - Run tests, ingoring collection errors (ex from missing imports) + poetry run pytest --cov betterproto --continue-on-collection-errors types: ## - Check types with mypy poetry run mypy src/betterproto --ignore-missing-imports diff --git a/src/betterproto/plugin_dataclasses.py b/src/betterproto/plugin_dataclasses.py index 4032ecda..50c9b318 100644 --- a/src/betterproto/plugin_dataclasses.py +++ b/src/betterproto/plugin_dataclasses.py @@ -523,7 +523,9 @@ def __post_init__(self) -> None: for field in self.py_input_message.fields: if field.default_value_string == "None": self.output_file.typing_imports.add("Optional") - + if "Optional" in self.py_output_message_type: + self.output_file.typing_imports.add("Optional") + # Check for Async imports if self.client_streaming: self.output_file.typing_imports.add("AsyncIterable") From faea4047d31c16ab41e40377017c1839619ce5ca Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 16 Jul 2020 11:04:43 -0500 Subject: [PATCH 04/23] fix map type detection --- src/betterproto/plugin_dataclasses.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/betterproto/plugin_dataclasses.py b/src/betterproto/plugin_dataclasses.py index 50c9b318..8dd4699f 100644 --- a/src/betterproto/plugin_dataclasses.py +++ b/src/betterproto/plugin_dataclasses.py @@ -226,7 +226,7 @@ def is_map( ) -> bool: """True if proto_field_obj is a map, otherwise False. """ - if proto_field_obj.type is FieldDescriptorProto.TYPE_MESSAGE: + if proto_field_obj.type == FieldDescriptorProto.TYPE_MESSAGE: # This might be a map... message_type = proto_field_obj.type_name.split(".").pop().lower() map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry" From 092076b8253bc8b5c8a1c7988875ac4cca949af6 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 16 Jul 2020 11:42:32 -0500 Subject: [PATCH 05/23] formatting fixes --- src/betterproto/plugin.py | 20 +++---- src/betterproto/plugin_dataclasses.py | 86 +++++++++++++++++---------- 2 files changed, 60 insertions(+), 46 deletions(-) diff --git a/src/betterproto/plugin.py b/src/betterproto/plugin.py index 4e89d0c9..ba57cde0 100755 --- a/src/betterproto/plugin.py +++ b/src/betterproto/plugin.py @@ -3,21 +3,13 @@ import itertools import os.path import pathlib -import re import sys import textwrap from typing import List, Union -from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest -import betterproto -from betterproto.compile.importing import get_type_reference, parse_source_type_name -from betterproto.compile.naming import ( - pythonize_class_name, - pythonize_field_name, - pythonize_method_name, -) from betterproto.lib.google.protobuf import ServiceDescriptorProto +from betterproto.compile.importing import get_type_reference try: # betterproto[compiler] specific dependencies @@ -28,7 +20,7 @@ EnumDescriptorProto, FieldDescriptorProto, ) - import google.protobuf.wrappers_pb2 as google_wrappers + from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest import jinja2 except ImportError as err: missing_import = err.args[0][17:-1] @@ -56,6 +48,7 @@ is_oneof ) + def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str: if field.type in [1, 2]: return "float" @@ -166,7 +159,7 @@ def generate_code(request, response): for output_package_name, output_package_content in output_package_files.items(): template_data = OutputTemplate(input_package=output_package_content["input_package"]) for input_proto_file in output_package_content["files"]: - input_proto_file_data = ProtoInputFile(parent=template_data, proto_obj=input_proto_file) + ProtoInputFile(parent=template_data, proto_obj=input_proto_file) output_package_content["template_data"] = template_data # Read Messages and Enums @@ -216,6 +209,7 @@ def generate_code(request, response): for output_package_name in sorted(output_paths.union(init_files)): print(f"Writing {output_package_name}", file=sys.stderr) + def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file_data: ProtoInputFile): if isinstance(item, DescriptorProto): if item.options.map_entry: @@ -235,8 +229,8 @@ def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file_data: else: Field(parent=message_data, proto_obj=field, path=path+[2, index]) elif isinstance(item, EnumDescriptorProto): - # Enum - EnumDefinition(proto_obj=item, parent=proto_file_data, path=path) + # Enum + EnumDefinition(proto_obj=item, parent=proto_file_data, path=path) def read_protobuf_service(service: ServiceDescriptorProto, index: int, proto_file_data: ProtoInputFile): diff --git a/src/betterproto/plugin_dataclasses.py b/src/betterproto/plugin_dataclasses.py index 8dd4699f..5da8a6ff 100644 --- a/src/betterproto/plugin_dataclasses.py +++ b/src/betterproto/plugin_dataclasses.py @@ -8,13 +8,17 @@ Union, Type, List, + Tuple, Set, Text, ) import textwrap import betterproto -from betterproto.compile.importing import get_type_reference, parse_source_type_name +from betterproto.compile.importing import ( + get_type_reference, + parse_source_type_name, +) from betterproto.compile.naming import ( pythonize_class_name, pythonize_field_name, @@ -23,15 +27,17 @@ try: # betterproto[compiler] specific dependencies - from google.protobuf.compiler import plugin_pb2 as plugin from google.protobuf.descriptor_pb2 import ( DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorProto, + MethodDescriptorProto, ) except ImportError as err: - missing_import = re.match(r".*(cannot import name .*$)", err.args[0]).group(1) + missing_import = re.match( + r".*(cannot import name .*$)", err.args[0] + ).group(1) print( "\033[31m" f"Unable to import `{missing_import}` from betterproto plugin! " @@ -124,7 +130,7 @@ def __post_init__(self): for field_name, field_val in self.__dataclass_fields__.items(): if field_val is PLACEHOLDER: raise ValueError( - f"`{field_name}` is a required field with no default value." + f"`{field_name}` is a required field." ) @property @@ -143,7 +149,9 @@ def proto_file(self) -> FieldDescriptorProto: @property def comment(self) -> str: - """Crawl the proto source code and retrieve comments for this object.""" + """Crawl the proto source code and retrieve comments + for this object. + """ return get_comment( proto_file=self.proto_file, path=self.path, @@ -162,7 +170,7 @@ class OutputTemplate: datetime_imports: Set[str] = field(default_factory=set) typing_imports: Set[str] = field(default_factory=set) messages: List[Message] = field(default_factory=list) - enums: List[Enum] = field(default_factory=list) + enums: List[EnumDefinition] = field(default_factory=list) services: List[Service] = field(default_factory=list) @property @@ -272,19 +280,20 @@ def get_field_string(self, indent: int = 4) -> str: """Construct string representation of this field as a field.""" name = f"{self.py_name}" annotations = f": {self.annotation}" - betterproto_field_type = \ - f"betterproto.{self.field_type}_field({self.proto_obj.number}" + \ - f"{self.betterproto_field_args}" + \ - ")" + betterproto_field_type = ( + f"betterproto.{self.field_type}_field({self.proto_obj.number}" + + f"{self.betterproto_field_args}" + + ")" + ) return name + annotations + " = " + betterproto_field_type - + @property def betterproto_field_args(self): args = "" if self.field_wraps: args = args + f", wraps={self.field_wraps}" return args - + @property def field_wraps(self) -> Union[str, None]: """Returns betterproto wrapped field type or None. @@ -306,11 +315,13 @@ def repeated(self) -> bool: ): return True return False - + @property def mutable(self) -> bool: """True if the field is a mutable type, otherwise False.""" - return self.field_type.startswith("List[") or self.field_type.startswith("Dict[") + return self.field_type.startswith( + "List[" + ) or self.field_type.startswith("Dict[") @property def field_type(self) -> str: @@ -390,11 +401,16 @@ def annotation(self) -> str: @dataclass class OneOfField(Field): - @property def betterproto_field_args(self): args = super().betterproto_field_args - args = args + f', group="{self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name}"' + group = self.parent.proto_obj.oneof_decl[ + self.proto_obj.oneof_index + ].name + args = ( + args + + f', group="{group}"' + ) return args @@ -520,8 +536,8 @@ def __post_init__(self) -> None: # Check for Optional import if self.py_input_message: - for field in self.py_input_message.fields: - if field.default_value_string == "None": + for f in self.py_input_message.fields: + if f.default_value_string == "None": self.output_file.typing_imports.add("Optional") if "Optional" in self.py_output_message_type: self.output_file.typing_imports.add("Optional") @@ -535,34 +551,35 @@ def __post_init__(self) -> None: self.output_file.typing_imports.add("AsyncIterator") super().__post_init__() # check for unset fields - + @property def mutable_default_arguments(self) -> List[Tuple[str, str]]: """Handle mutable default arguments. - + Returns a list of tuples containing the name and default value for arguments to this message who's default value is mutable. The defaults are swapped out for None and replaced back inside the method's body. - Reference: https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments + Reference: + https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments Returns ------- List[Tuple[str, str]] - Each tuple contains the name and actual default value (as a string) for each - argument with mutable default values. + Each tuple contains the name and actual default value (as a string) + for each argument with mutable default values. """ mutable_default_arguments = [] - + if self.py_input_message: - for field in self.py_input_message.fields: + for f in self.py_input_message.fields: if ( not self.client_streaming - and field.default_value_string != "None" - and field.mutable + and f.default_value_string != "None" + and f.mutable ): mutable_default_arguments.append( - (field.py_name, field.default_value_string) + (f.py_name, f.default_value_string) ) self.output_file.typing_imports.add("Optional") return mutable_default_arguments @@ -579,7 +596,8 @@ def proto_name(self) -> str: @property def route(self) -> str: - return f"/{self.output_file.input_package}.{self.parent.proto_name}/{self.proto_name}" + return f"/{self.output_file.input_package}.\ + {self.parent.proto_name}/{self.proto_name}" @property def py_input_message(self) -> Union[None, Message]: @@ -595,14 +613,16 @@ def py_input_message(self) -> Union[None, Message]: package, name = parse_source_type_name(self.proto_obj.input_type) # Nested types are currently flattened without dots. - # Todo: keep a fully quantified name in types, that is comparable with method.input_type + # Todo: keep a fully quantified name in types, that is + # comparable with method.input_type for msg in self.output_file.messages: - if msg.proto_name == self.proto_name and \ - msg.output_file.input_package == package: + if ( + msg.proto_name == self.proto_name + and msg.output_file.input_package == package + ): return msg return None - @property def py_input_message_type(self) -> str: """String representation of the Python type correspoding to the From e58bbe129810a13a1e423239ce167b51e8dbb717 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 16 Jul 2020 11:54:15 -0500 Subject: [PATCH 06/23] nested msg fix --- src/betterproto/plugin.py | 1 - src/betterproto/plugin_dataclasses.py | 2 +- src/betterproto/templates/template.py.j2 | 2 +- 3 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/betterproto/plugin.py b/src/betterproto/plugin.py index ba57cde0..d090fb65 100755 --- a/src/betterproto/plugin.py +++ b/src/betterproto/plugin.py @@ -240,7 +240,6 @@ def read_protobuf_service(service: ServiceDescriptorProto, index: int, proto_fil path=[6, index], ) for j, method in enumerate(service.method): - ServiceMethod( parent=service_data, proto_obj=method, diff --git a/src/betterproto/plugin_dataclasses.py b/src/betterproto/plugin_dataclasses.py index 5da8a6ff..f82b308f 100644 --- a/src/betterproto/plugin_dataclasses.py +++ b/src/betterproto/plugin_dataclasses.py @@ -617,7 +617,7 @@ def py_input_message(self) -> Union[None, Message]: # comparable with method.input_type for msg in self.output_file.messages: if ( - msg.proto_name == self.proto_name + msg.py_name == self.py_input_message_type and msg.output_file.input_package == package ): return msg diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index 03069c61..fb234614 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -75,7 +75,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {# Client streaming: need a request iterator instead #} , request_iterator: Union[AsyncIterable["{{ method.py_input_message_type }}"], Iterable["{{ method.py_input_message_type }}"]] {%- endif -%} - ) -> {% if method.server_streaming %}AsyncIterator[{{ method.py_output_message_type }}]{% else %}{{ method.py_output_message_type }}{% endif %}: + ) -> {% if method.server_streaming %}AsyncIterator["{{ method.py_output_message_type }}"]{% else %}"{{ method.py_output_message_type }}"{% endif %}: {% if method.comment %} {{ method.comment }} From d51a5b15d6de49db152b3ac78cb47938ff59274f Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 16 Jul 2020 12:47:25 -0500 Subject: [PATCH 07/23] blacken --- src/betterproto/plugin.py | 40 +++++++++++++++++++-------------------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/src/betterproto/plugin.py b/src/betterproto/plugin.py index d090fb65..ca9dc996 100755 --- a/src/betterproto/plugin.py +++ b/src/betterproto/plugin.py @@ -45,7 +45,7 @@ Service, ServiceMethod, is_map, - is_oneof + is_oneof, ) @@ -157,7 +157,9 @@ def generate_code(request, response): # Initialize Template data for each package for output_package_name, output_package_content in output_package_files.items(): - template_data = OutputTemplate(input_package=output_package_content["input_package"]) + template_data = OutputTemplate( + input_package=output_package_content["input_package"] + ) for input_proto_file in output_package_content["files"]: ProtoInputFile(parent=template_data, proto_obj=input_proto_file) output_package_content["template_data"] = template_data @@ -166,7 +168,9 @@ def generate_code(request, response): for output_package_name, output_package_content in output_package_files.items(): for proto_file_data in output_package_content["template_data"].input_files: for item, path in traverse(proto_file_data.proto_obj): - read_protobuf_type(item=item, path=path, proto_file_data=proto_file_data) + read_protobuf_type( + item=item, path=path, proto_file_data=proto_file_data + ) # Read Services for output_package_name, output_package_content in output_package_files.items(): @@ -210,40 +214,34 @@ def generate_code(request, response): print(f"Writing {output_package_name}", file=sys.stderr) -def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file_data: ProtoInputFile): +def read_protobuf_type( + item: DescriptorProto, path: List[int], proto_file_data: ProtoInputFile +): if isinstance(item, DescriptorProto): if item.options.map_entry: # Skip generated map entry messages since we just use dicts return # Process Message - message_data = Message( - parent=proto_file_data, - proto_obj=item, - path=path - ) + message_data = Message(parent=proto_file_data, proto_obj=item, path=path) for index, field in enumerate(item.field): if is_map(field, item): - MapField(parent=message_data, proto_obj=field, path=path+[2, index]) + MapField(parent=message_data, proto_obj=field, path=path + [2, index]) elif is_oneof(field): - OneOfField(parent=message_data, proto_obj=field, path=path+[2, index]) + OneOfField(parent=message_data, proto_obj=field, path=path + [2, index]) else: - Field(parent=message_data, proto_obj=field, path=path+[2, index]) + Field(parent=message_data, proto_obj=field, path=path + [2, index]) elif isinstance(item, EnumDescriptorProto): # Enum EnumDefinition(proto_obj=item, parent=proto_file_data, path=path) -def read_protobuf_service(service: ServiceDescriptorProto, index: int, proto_file_data: ProtoInputFile): - service_data = Service( - parent=proto_file_data, - proto_obj=service, - path=[6, index], - ) +def read_protobuf_service( + service: ServiceDescriptorProto, index: int, proto_file_data: ProtoInputFile +): + service_data = Service(parent=proto_file_data, proto_obj=service, path=[6, index],) for j, method in enumerate(service.method): ServiceMethod( - parent=service_data, - proto_obj=method, - path=[6, index, 2, j], + parent=service_data, proto_obj=method, path=[6, index, 2, j], ) From cefd73f8c8d96533f24718699b32c733952ec24b Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 16 Jul 2020 13:24:36 -0500 Subject: [PATCH 08/23] mutable defaults fix --- src/betterproto/plugin_dataclasses.py | 27 +++++++++++++----------- src/betterproto/templates/template.py.j2 | 9 ++++++-- 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/src/betterproto/plugin_dataclasses.py b/src/betterproto/plugin_dataclasses.py index f82b308f..0014dee3 100644 --- a/src/betterproto/plugin_dataclasses.py +++ b/src/betterproto/plugin_dataclasses.py @@ -8,7 +8,6 @@ Union, Type, List, - Tuple, Set, Text, ) @@ -319,9 +318,8 @@ def repeated(self) -> bool: @property def mutable(self) -> bool: """True if the field is a mutable type, otherwise False.""" - return self.field_type.startswith( - "List[" - ) or self.field_type.startswith("Dict[") + annotation = self.annotation + return annotation.startswith("List[") or annotation.startswith("Dict[") @property def field_type(self) -> str: @@ -553,7 +551,7 @@ def __post_init__(self) -> None: super().__post_init__() # check for unset fields @property - def mutable_default_arguments(self) -> List[Tuple[str, str]]: + def mutable_default_args(self) -> Dict[str, str]: """Handle mutable default arguments. Returns a list of tuples containing the name and default value @@ -565,11 +563,17 @@ def mutable_default_arguments(self) -> List[Tuple[str, str]]: Returns ------- - List[Tuple[str, str]] - Each tuple contains the name and actual default value (as a string) + Dict[str, str] + Name and actual default value (as a string) for each argument with mutable default values. """ - mutable_default_arguments = [] + mutable_default_args = dict() + + # if self.py_name == "do_thing" and self.py_input_message_type == "DoThingRequest": + # import ptvsd + # ptvsd.enable_attach() + # ptvsd.wait_for_attach() # blocks execution until debugger is attached + # print("done") if self.py_input_message: for f in self.py_input_message.fields: @@ -578,11 +582,10 @@ def mutable_default_arguments(self) -> List[Tuple[str, str]]: and f.default_value_string != "None" and f.mutable ): - mutable_default_arguments.append( - (f.py_name, f.default_value_string) - ) + mutable_default_args[f.py_name] = f.default_value_string self.output_file.typing_imports.add("Optional") - return mutable_default_arguments + + return mutable_default_args @property def py_name(self) -> str: diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index fb234614..cfaa50fb 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -67,7 +67,12 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): Optional[{{ field.py_type }}] {%- else -%} {{ field.py_type }} - {%- endif -%} = {{ field.default_value_string }} + {%- endif -%} = + {%- if field.py_name not in method.mutable_default_args -%} + {{ field.default_value_string }} + {%- else -%} + None + {% endif -%} {%- if not loop.last %}, {% endif -%} {%- endfor -%} {%- endif -%} @@ -80,7 +85,7 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {{ method.comment }} {% endif %} - {%- for py_name, zero in method.mutable_default_args %} + {%- for py_name, zero in method.mutable_default_args.items() %} {{ py_name }} = {{ py_name }} or {{ zero }} {% endfor %} From f2b81bdd2ef26a0bca5fce62048b7854a37628f8 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 16 Jul 2020 13:30:20 -0500 Subject: [PATCH 09/23] more mutable default arg fixes --- src/betterproto/plugin_dataclasses.py | 6 ------ src/betterproto/templates/template.py.j2 | 6 +++--- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/betterproto/plugin_dataclasses.py b/src/betterproto/plugin_dataclasses.py index 0014dee3..c4128eb3 100644 --- a/src/betterproto/plugin_dataclasses.py +++ b/src/betterproto/plugin_dataclasses.py @@ -569,12 +569,6 @@ def mutable_default_args(self) -> Dict[str, str]: """ mutable_default_args = dict() - # if self.py_name == "do_thing" and self.py_input_message_type == "DoThingRequest": - # import ptvsd - # ptvsd.enable_attach() - # ptvsd.wait_for_attach() # blocks execution until debugger is attached - # print("done") - if self.py_input_message: for f in self.py_input_message.fields: if ( diff --git a/src/betterproto/templates/template.py.j2 b/src/betterproto/templates/template.py.j2 index cfaa50fb..bbd7cc59 100644 --- a/src/betterproto/templates/template.py.j2 +++ b/src/betterproto/templates/template.py.j2 @@ -63,10 +63,10 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): {%- if not method.client_streaming -%} {%- if method.py_input_message and method.py_input_message.fields -%}, *, {%- for field in method.py_input_message.fields -%} - {{ field.py_name }}: {% if field.default_value_string == "None" and not field.annotation.startswith("Optional[") -%} - Optional[{{ field.py_type }}] + {{ field.py_name }}: {% if field.py_name in method.mutable_default_args and not field.annotation.startswith("Optional[") -%} + Optional[{{ field.annotation }}] {%- else -%} - {{ field.py_type }} + {{ field.annotation }} {%- endif -%} = {%- if field.py_name not in method.mutable_default_args -%} {{ field.default_value_string }} From a77e4626364dde3e5234cc1c1782fcef16980c1a Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 16 Jul 2020 14:57:27 -0500 Subject: [PATCH 10/23] fixes --- src/betterproto/plugin.py | 1 + src/betterproto/plugin_dataclasses.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/betterproto/plugin.py b/src/betterproto/plugin.py index ca9dc996..5cc251c2 100755 --- a/src/betterproto/plugin.py +++ b/src/betterproto/plugin.py @@ -246,6 +246,7 @@ def read_protobuf_service( def main(): + """The plugin's main entry point.""" # Read request message from stdin data = sys.stdin.buffer.read() diff --git a/src/betterproto/plugin_dataclasses.py b/src/betterproto/plugin_dataclasses.py index c4128eb3..3e566c5f 100644 --- a/src/betterproto/plugin_dataclasses.py +++ b/src/betterproto/plugin_dataclasses.py @@ -539,6 +539,7 @@ def __post_init__(self) -> None: self.output_file.typing_imports.add("Optional") if "Optional" in self.py_output_message_type: self.output_file.typing_imports.add("Optional") + self.mutable_default_args # ensure this is called before rendering # Check for Async imports if self.client_streaming: @@ -578,7 +579,7 @@ def mutable_default_args(self) -> Dict[str, str]: ): mutable_default_args[f.py_name] = f.default_value_string self.output_file.typing_imports.add("Optional") - + return mutable_default_args @property From 514548d7576c63a721898c8f72715369644b237a Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 16 Jul 2020 15:03:56 -0500 Subject: [PATCH 11/23] routing fix --- src/betterproto/plugin_dataclasses.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/betterproto/plugin_dataclasses.py b/src/betterproto/plugin_dataclasses.py index 3e566c5f..887cab6f 100644 --- a/src/betterproto/plugin_dataclasses.py +++ b/src/betterproto/plugin_dataclasses.py @@ -594,8 +594,10 @@ def proto_name(self) -> str: @property def route(self) -> str: - return f"/{self.output_file.input_package}.\ - {self.parent.proto_name}/{self.proto_name}" + return ( + f"/{self.output_file.input_package}." + f"{self.parent.proto_name}/{self.proto_name}" + ) @property def py_input_message(self) -> Union[None, Message]: From e970c5a95a1fc5a5411260477859f7d8fa7dcd7d Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 16 Jul 2020 20:37:25 -0500 Subject: [PATCH 12/23] fix all tests --- src/betterproto/plugin.py | 69 +++++++++---------- src/betterproto/plugin_dataclasses.py | 97 ++++++++++++++++++--------- 2 files changed, 97 insertions(+), 69 deletions(-) diff --git a/src/betterproto/plugin.py b/src/betterproto/plugin.py index 5cc251c2..5d7af035 100755 --- a/src/betterproto/plugin.py +++ b/src/betterproto/plugin.py @@ -35,8 +35,8 @@ raise SystemExit(1) from .plugin_dataclasses import ( + Request, OutputTemplate, - ProtoInputFile, Message, Field, OneOfField, @@ -139,58 +139,53 @@ def generate_code(request, response): loader=jinja2.FileSystemLoader("%s/templates/" % os.path.dirname(__file__)), ) template = env.get_template("template.py.j2") - + request_data = Request(plugin_request_obj=request) # Gather output packages - output_package_files = collections.defaultdict() for proto_file in request.proto_file: if ( proto_file.package == "google.protobuf" and "INCLUDE_GOOGLE" not in plugin_options ): + # If not INCLUDE_GOOGLE, + # skip re-compiling Google's well-known types continue - output_package = proto_file.package - output_package_files.setdefault( - output_package, {"input_package": proto_file.package, "files": []} - ) - output_package_files[output_package]["files"].append(proto_file) - - # Initialize Template data for each package - for output_package_name, output_package_content in output_package_files.items(): - template_data = OutputTemplate( - input_package=output_package_content["input_package"] - ) - for input_proto_file in output_package_content["files"]: - ProtoInputFile(parent=template_data, proto_obj=input_proto_file) - output_package_content["template_data"] = template_data + output_package_name = proto_file.package + if output_package_name not in request_data.output_packages: + # Create a new output if there is no output for this package + request_data.output_packages[output_package_name] = OutputTemplate( + parent_request=request_data, + package_proto_obj=proto_file + ) + # Add this input file to the output corresponding to this package + request_data.output_packages[output_package_name].input_files.append(proto_file) # Read Messages and Enums - for output_package_name, output_package_content in output_package_files.items(): - for proto_file_data in output_package_content["template_data"].input_files: - for item, path in traverse(proto_file_data.proto_obj): - read_protobuf_type( - item=item, path=path, proto_file_data=proto_file_data - ) + # We need to read Messages before Services in so that we can + # get the references to input/output messages for each service + for output_package_name, output_package in request_data.output_packages.items(): + for proto_input_file in output_package.input_files: + for item, path in traverse(proto_input_file): + read_protobuf_type(item=item, path=path, output_package=output_package) # Read Services - for output_package_name, output_package_content in output_package_files.items(): - for proto_file_data in output_package_content["template_data"].input_files: - for index, service in enumerate(proto_file_data.proto_obj.service): - read_protobuf_service(service, index, proto_file_data) + for output_package_name, output_package in request_data.output_packages.items(): + for proto_input_file in output_package.input_files: + for index, service in enumerate(proto_input_file.service): + read_protobuf_service(service, index, output_package) - # Render files + # Generate output files output_paths = set() - for output_package_name, output_package_content in output_package_files.items(): - template_data = output_package_content["template_data"] + for output_package_name, template_data in request_data.output_packages.items(): - # Fill response + # Add files to the response object output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") output_paths.add(output_path) f = response.file.add() f.name = str(output_path) - # Render and then format the output file. + # Render and then format the output file f.content = black.format_str( template.render(description=template_data), mode=black.FileMode(target_versions={black.TargetVersion.PY37}), @@ -215,14 +210,14 @@ def generate_code(request, response): def read_protobuf_type( - item: DescriptorProto, path: List[int], proto_file_data: ProtoInputFile + item: DescriptorProto, path: List[int], output_package: OutputTemplate ): if isinstance(item, DescriptorProto): if item.options.map_entry: # Skip generated map entry messages since we just use dicts return # Process Message - message_data = Message(parent=proto_file_data, proto_obj=item, path=path) + message_data = Message(parent=output_package, proto_obj=item, path=path) for index, field in enumerate(item.field): if is_map(field, item): MapField(parent=message_data, proto_obj=field, path=path + [2, index]) @@ -232,13 +227,13 @@ def read_protobuf_type( Field(parent=message_data, proto_obj=field, path=path + [2, index]) elif isinstance(item, EnumDescriptorProto): # Enum - EnumDefinition(proto_obj=item, parent=proto_file_data, path=path) + EnumDefinition(parent=output_package, proto_obj=item, path=path) def read_protobuf_service( - service: ServiceDescriptorProto, index: int, proto_file_data: ProtoInputFile + service: ServiceDescriptorProto, index: int, output_package: OutputTemplate ): - service_data = Service(parent=proto_file_data, proto_obj=service, path=[6, index],) + service_data = Service(parent=output_package, proto_obj=service, path=[6, index],) for j, method in enumerate(service.method): ServiceMethod( parent=service_data, proto_obj=method, path=[6, index, 2, j], diff --git a/src/betterproto/plugin_dataclasses.py b/src/betterproto/plugin_dataclasses.py index 887cab6f..941948d9 100644 --- a/src/betterproto/plugin_dataclasses.py +++ b/src/betterproto/plugin_dataclasses.py @@ -134,7 +134,7 @@ def __post_init__(self): @property def output_file(self) -> OutputTemplate: - current = self.parent + current = self while not isinstance(current, OutputTemplate): current = current.parent return current @@ -142,9 +142,16 @@ def output_file(self) -> OutputTemplate: @property def proto_file(self) -> FieldDescriptorProto: current = self - while not isinstance(current, ProtoInputFile): + while not isinstance(current, OutputTemplate): current = current.parent - return current.proto_obj + return current.package_proto_obj + + @property + def request(self) -> Request: + current = self + while not isinstance(current, OutputTemplate): + current = current.parent + return current.parent_request @property def comment(self) -> str: @@ -157,14 +164,39 @@ def comment(self) -> str: indent=self.comment_indent, ) +@dataclass +class Request: + from typing import Any + plugin_request_obj: Any + output_packages: Dict[str, OutputTemplate] = field(default_factory=dict) + + @property + def all_messages(self) -> List[Message]: + """All of the messages in this request. + + Returns + ------- + List[Message] + List of all of the messages in this request. + """ + return [ + msg + for output in self.output_packages.values() + for msg in output.messages + ] + @dataclass class OutputTemplate: """Representation of an output .py file. - """ - input_package: str - input_files: List[ProtoInputFile] = field(default_factory=list) + Each output file corresponds to a .proto input file, + but may need references to other .proto files to be + built. + """ + parent_request: Request + package_proto_obj: FileDescriptorProto + input_files: List[str] = field(default_factory=list) imports: Set[str] = field(default_factory=set) datetime_imports: Set[str] = field(default_factory=set) typing_imports: Set[str] = field(default_factory=set) @@ -173,40 +205,40 @@ class OutputTemplate: services: List[Service] = field(default_factory=list) @property - def input_filenames(self) -> List[str]: - return [f.name for f in self.input_files] - - -@dataclass -class ProtoInputFile: - """Representation of an input .proto file. - """ + def package(self) -> str: + """Name of input package. - parent: OutputTemplate - proto_obj: FileDescriptorProto + Returns + ------- + str + Name of input package. + """ + return self.package_proto_obj.package @property - def name(self) -> str: - return self.proto_obj.name + def input_filenames(self) -> List[str]: + """Names of the input files used to build this output. - def __post_init__(self): - # Add proto file to output file - self.parent.input_files.append(self) + Returns + ------- + List[str] + Names of the input files used to build this output. + """ + return [f.name for f in self.input_files] @dataclass class Message(ProtoContentBase): """Representation of a protobuf message. """ - - parent: Union[ProtoInputFile, Message] = PLACEHOLDER + parent: Union[Message, OutputTemplate] = PLACEHOLDER proto_obj: DescriptorProto = PLACEHOLDER path: List[int] = PLACEHOLDER fields: List[Union[Field, Message]] = field(default_factory=list) def __post_init__(self): # Add message to output file - if isinstance(self.parent, ProtoInputFile): + if isinstance(self.parent, OutputTemplate): if isinstance(self, EnumDefinition): self.output_file.enums.append(self) else: @@ -383,7 +415,7 @@ def py_type(self) -> str: elif self.proto_obj.type in PROTO_MESSAGE_TYPES: # Type referencing another defined Message or a named enum return get_type_reference( - package=self.output_file.input_package, + package=self.output_file.package, imports=self.output_file.imports, source_type=self.proto_obj.type_name, ) @@ -501,7 +533,7 @@ def default_value_string(self) -> int: @dataclass class Service(ProtoContentBase): - parent: ProtoInputFile = PLACEHOLDER + parent: OutputTemplate = PLACEHOLDER proto_obj: DescriptorProto = PLACEHOLDER path: List[int] = PLACEHOLDER methods: List[ServiceMethod] = field(default_factory=list) @@ -595,7 +627,7 @@ def proto_name(self) -> str: @property def route(self) -> str: return ( - f"/{self.output_file.input_package}." + f"/{self.output_file.package}." f"{self.parent.proto_name}/{self.proto_name}" ) @@ -615,10 +647,10 @@ def py_input_message(self) -> Union[None, Message]: # Nested types are currently flattened without dots. # Todo: keep a fully quantified name in types, that is # comparable with method.input_type - for msg in self.output_file.messages: + for msg in self.request.all_messages: if ( - msg.py_name == self.py_input_message_type - and msg.output_file.input_package == package + msg.py_name == name.replace(".", "") + and msg.output_file.package == package ): return msg return None @@ -635,7 +667,7 @@ def py_input_message_type(self) -> str: input message. """ return get_type_reference( - package=self.output_file.input_package, + package=self.output_file.package, imports=self.output_file.imports, source_type=self.proto_obj.input_type, ).strip('"') @@ -652,9 +684,10 @@ def py_output_message_type(self) -> str: output message. """ return get_type_reference( - package=self.output_file.input_package, + package=self.output_file.package, imports=self.output_file.imports, source_type=self.proto_obj.output_type, + unwrap=False, ).strip('"') @property From 01345c59f91bd04279c1081b7e22d884884df302 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Thu, 16 Jul 2020 21:20:10 -0500 Subject: [PATCH 13/23] black --- src/betterproto/plugin.py | 3 +- src/betterproto/plugin_dataclasses.py | 53 +++++++++------------------ 2 files changed, 19 insertions(+), 37 deletions(-) diff --git a/src/betterproto/plugin.py b/src/betterproto/plugin.py index 5d7af035..9cb65a67 100755 --- a/src/betterproto/plugin.py +++ b/src/betterproto/plugin.py @@ -154,8 +154,7 @@ def generate_code(request, response): if output_package_name not in request_data.output_packages: # Create a new output if there is no output for this package request_data.output_packages[output_package_name] = OutputTemplate( - parent_request=request_data, - package_proto_obj=proto_file + parent_request=request_data, package_proto_obj=proto_file ) # Add this input file to the output corresponding to this package request_data.output_packages[output_package_name].input_files.append(proto_file) diff --git a/src/betterproto/plugin_dataclasses.py b/src/betterproto/plugin_dataclasses.py index 941948d9..bf80ea01 100644 --- a/src/betterproto/plugin_dataclasses.py +++ b/src/betterproto/plugin_dataclasses.py @@ -34,9 +34,7 @@ MethodDescriptorProto, ) except ImportError as err: - missing_import = re.match( - r".*(cannot import name .*$)", err.args[0] - ).group(1) + missing_import = re.match(r".*(cannot import name .*$)", err.args[0]).group(1) print( "\033[31m" f"Unable to import `{missing_import}` from betterproto plugin! " @@ -99,8 +97,7 @@ def get_comment(proto_file, path: List[int], indent: int = 4) -> str: # print(list(sci.path), path, file=sys.stderr) if list(sci.path) == path and sci.leading_comments: lines = textwrap.wrap( - sci.leading_comments.strip().replace("\n", ""), - width=79 - indent, + sci.leading_comments.strip().replace("\n", ""), width=79 - indent, ) if path[-2] == 2 and path[-4] != 6: @@ -128,9 +125,7 @@ def __post_init__(self): """Checks that no fake default fields were left as placeholders.""" for field_name, field_val in self.__dataclass_fields__.items(): if field_val is PLACEHOLDER: - raise ValueError( - f"`{field_name}` is a required field." - ) + raise ValueError(f"`{field_name}` is a required field.") @property def output_file(self) -> OutputTemplate: @@ -145,7 +140,7 @@ def proto_file(self) -> FieldDescriptorProto: while not isinstance(current, OutputTemplate): current = current.parent return current.package_proto_obj - + @property def request(self) -> Request: current = self @@ -159,17 +154,17 @@ def comment(self) -> str: for this object. """ return get_comment( - proto_file=self.proto_file, - path=self.path, - indent=self.comment_indent, + proto_file=self.proto_file, path=self.path, indent=self.comment_indent, ) + @dataclass class Request: from typing import Any + plugin_request_obj: Any output_packages: Dict[str, OutputTemplate] = field(default_factory=dict) - + @property def all_messages(self) -> List[Message]: """All of the messages in this request. @@ -180,9 +175,7 @@ def all_messages(self) -> List[Message]: List of all of the messages in this request. """ return [ - msg - for output in self.output_packages.values() - for msg in output.messages + msg for output in self.output_packages.values() for msg in output.messages ] @@ -194,6 +187,7 @@ class OutputTemplate: but may need references to other .proto files to be built. """ + parent_request: Request package_proto_obj: FileDescriptorProto input_files: List[str] = field(default_factory=list) @@ -231,6 +225,7 @@ def input_filenames(self) -> List[str]: class Message(ProtoContentBase): """Representation of a protobuf message. """ + parent: Union[Message, OutputTemplate] = PLACEHOLDER proto_obj: DescriptorProto = PLACEHOLDER path: List[int] = PLACEHOLDER @@ -340,9 +335,8 @@ def field_wraps(self) -> Union[str, None]: @property def repeated(self) -> bool: - if ( - self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED - and not is_map(self.proto_obj, self.parent) + if self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED and not is_map( + self.proto_obj, self.parent ): return True return False @@ -357,9 +351,7 @@ def mutable(self) -> bool: def field_type(self) -> str: """String representation of proto field type.""" return ( - self.proto_obj.Type.Name(self.proto_obj.type) - .lower() - .replace("type_", "") + self.proto_obj.Type.Name(self.proto_obj.type).lower().replace("type_", "") ) @property @@ -434,13 +426,8 @@ class OneOfField(Field): @property def betterproto_field_args(self): args = super().betterproto_field_args - group = self.parent.proto_obj.oneof_decl[ - self.proto_obj.oneof_index - ].name - args = ( - args - + f', group="{group}"' - ) + group = self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name + args = args + f', group="{group}"' return args @@ -465,12 +452,8 @@ def __post_init__(self): parent=self, proto_obj=nested.field[1], # key ).py_type # Get proto types - self.proto_k_type = self.proto_obj.Type.Name( - nested.field[0].type - ) - self.proto_v_type = self.proto_obj.Type.Name( - nested.field[1].type - ) + self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type) + self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type) super().__post_init__() # call Field -> Message __post_init__ def get_field_string(self, indent: int = 4) -> str: From 8858a34fdbf3d42298fc6fa0fc3f515bbbbdeb07 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 17 Jul 2020 09:04:53 -0500 Subject: [PATCH 14/23] revert changes to Makefile --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 54bd37f6..5a3b7571 100644 --- a/Makefile +++ b/Makefile @@ -8,8 +8,8 @@ help: ## - Show this help. generate: ## - Generate test cases (do this once before running test) poetry run python -m tests.generate -test: ## - Run tests, ingoring collection errors (ex from missing imports) - poetry run pytest --cov betterproto --continue-on-collection-errors +test: ## - Run tests + poetry run pytest --cov betterproto types: ## - Check types with mypy poetry run mypy src/betterproto --ignore-missing-imports From 6747a23f99eb21b57a99d558b9c544f6a5f12def Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 17 Jul 2020 09:19:51 -0500 Subject: [PATCH 15/23] python3.6 support --- src/betterproto/plugin_dataclasses.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/betterproto/plugin_dataclasses.py b/src/betterproto/plugin_dataclasses.py index bf80ea01..0655af70 100644 --- a/src/betterproto/plugin_dataclasses.py +++ b/src/betterproto/plugin_dataclasses.py @@ -1,6 +1,4 @@ #!/usr/bin/env python -from __future__ import annotations - import re from dataclasses import dataclass from dataclasses import field @@ -8,6 +6,7 @@ Union, Type, List, + Dict, Set, Text, ) @@ -128,7 +127,7 @@ def __post_init__(self): raise ValueError(f"`{field_name}` is a required field.") @property - def output_file(self) -> OutputTemplate: + def output_file(self) -> "OutputTemplate": current = self while not isinstance(current, OutputTemplate): current = current.parent @@ -142,7 +141,7 @@ def proto_file(self) -> FieldDescriptorProto: return current.package_proto_obj @property - def request(self) -> Request: + def request(self) -> "Request": current = self while not isinstance(current, OutputTemplate): current = current.parent @@ -163,10 +162,10 @@ class Request: from typing import Any plugin_request_obj: Any - output_packages: Dict[str, OutputTemplate] = field(default_factory=dict) + output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict) @property - def all_messages(self) -> List[Message]: + def all_messages(self) -> List["Message"]: """All of the messages in this request. Returns @@ -194,9 +193,9 @@ class OutputTemplate: imports: Set[str] = field(default_factory=set) datetime_imports: Set[str] = field(default_factory=set) typing_imports: Set[str] = field(default_factory=set) - messages: List[Message] = field(default_factory=list) - enums: List[EnumDefinition] = field(default_factory=list) - services: List[Service] = field(default_factory=list) + messages: List["Message"] = field(default_factory=list) + enums: List["EnumDefinition"] = field(default_factory=list) + services: List["Service"] = field(default_factory=list) @property def package(self) -> str: @@ -226,10 +225,10 @@ class Message(ProtoContentBase): """Representation of a protobuf message. """ - parent: Union[Message, OutputTemplate] = PLACEHOLDER + parent: Union["Message", OutputTemplate] = PLACEHOLDER proto_obj: DescriptorProto = PLACEHOLDER path: List[int] = PLACEHOLDER - fields: List[Union[Field, Message]] = field(default_factory=list) + fields: List[Union["Field", "Message"]] = field(default_factory=list) def __post_init__(self): # Add message to output file @@ -481,7 +480,7 @@ class EnumDefinition(Message): """Representation of a proto Enum definition.""" proto_obj: EnumDescriptorProto = PLACEHOLDER - entries: List[EnumDefinition.EnumEntry] = PLACEHOLDER + entries: List["EnumDefinition.EnumEntry"] = PLACEHOLDER @dataclass(unsafe_hash=True) class EnumEntry: @@ -519,7 +518,7 @@ class Service(ProtoContentBase): parent: OutputTemplate = PLACEHOLDER proto_obj: DescriptorProto = PLACEHOLDER path: List[int] = PLACEHOLDER - methods: List[ServiceMethod] = field(default_factory=list) + methods: List["ServiceMethod"] = field(default_factory=list) def __post_init__(self) -> None: # Add service to output file From 586bb010042402bf2ef3a1bd583aaca5a68885da Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 17 Jul 2020 14:05:42 -0500 Subject: [PATCH 16/23] restructure files --- src/betterproto/plugin/__init__.py | 49 ++ src/betterproto/plugin/models.py | 680 ++++++++++++++++++ .../{plugin.py => plugin/parser.py} | 135 +--- 3 files changed, 746 insertions(+), 118 deletions(-) create mode 100644 src/betterproto/plugin/__init__.py create mode 100644 src/betterproto/plugin/models.py rename src/betterproto/{plugin.py => plugin/parser.py} (61%) mode change 100755 => 100644 diff --git a/src/betterproto/plugin/__init__.py b/src/betterproto/plugin/__init__.py new file mode 100644 index 00000000..7d54c825 --- /dev/null +++ b/src/betterproto/plugin/__init__.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +import sys +import os + +from google.protobuf.compiler import plugin_pb2 as plugin + +from betterproto.plugin.parser import generate_code + + +def main(): + + """The plugin's main entry point.""" + # Read request message from stdin + data = sys.stdin.buffer.read() + + # Parse request + request = plugin.CodeGeneratorRequest() + request.ParseFromString(data) + + dump_file = os.getenv("BETTERPROTO_DUMP") + if dump_file: + dump_request(dump_file, request) + + # Create response + response = plugin.CodeGeneratorResponse() + + # Generate code + generate_code(request, response) + + # Serialise response message + output = response.SerializeToString() + + # Write to stdout + sys.stdout.buffer.write(output) + + +def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest): + """ + For developers: Supports running plugin.py standalone so its possible to debug it. + Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file. + Then run plugin.py from your IDE in debugging mode, and redirect stdin to the file. + """ + with open(str(dump_file), "wb") as fh: + sys.stderr.write(f"\033[31mWriting input from protoc to: {dump_file}\033[0m\n") + fh.write(request.SerializeToString()) + + +if __name__ == "__main__": + main() diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py new file mode 100644 index 00000000..93e83867 --- /dev/null +++ b/src/betterproto/plugin/models.py @@ -0,0 +1,680 @@ +import re +from dataclasses import dataclass +from dataclasses import field +from typing import ( + Union, + Type, + List, + Dict, + Set, + Text, +) +import textwrap + +import betterproto +from betterproto.compile.importing import ( + get_type_reference, + parse_source_type_name, +) +from betterproto.compile.naming import ( + pythonize_class_name, + pythonize_field_name, + pythonize_method_name, +) + +try: + # betterproto[compiler] specific dependencies + from google.protobuf.descriptor_pb2 import ( + DescriptorProto, + EnumDescriptorProto, + FieldDescriptorProto, + FileDescriptorProto, + MethodDescriptorProto, + ) +except ImportError as err: + missing_import = re.match(r".*(cannot import name .*$)", err.args[0]).group(1) + print( + "\033[31m" + f"Unable to import `{missing_import}` from betterproto plugin! " + "Please ensure that you've installed betterproto as " + '`pip install "betterproto[compiler]"` so that compiler dependencies ' + "are included." + "\033[0m" + ) + raise SystemExit(1) + +# Create a unique placeholder to deal with +# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses +PLACEHOLDER = object() + +# Organize proto types into categories +PROTO_FLOAT_TYPES = ( + FieldDescriptorProto.TYPE_DOUBLE, # 1 + FieldDescriptorProto.TYPE_FLOAT, # 2 +) +PROTO_INT_TYPES = ( + FieldDescriptorProto.TYPE_INT64, # 3 + FieldDescriptorProto.TYPE_UINT64, # 4 + FieldDescriptorProto.TYPE_INT32, # 5 + FieldDescriptorProto.TYPE_FIXED64, # 6 + FieldDescriptorProto.TYPE_FIXED32, # 7 + FieldDescriptorProto.TYPE_UINT32, # 13 + FieldDescriptorProto.TYPE_SFIXED32, # 15 + FieldDescriptorProto.TYPE_SFIXED64, # 16 + FieldDescriptorProto.TYPE_SINT32, # 17 + FieldDescriptorProto.TYPE_SINT64, # 18 +) +PROTO_BOOL_TYPES = (FieldDescriptorProto.TYPE_BOOL,) # 8 +PROTO_STR_TYPES = (FieldDescriptorProto.TYPE_STRING,) # 9 +PROTO_BYTES_TYPES = (FieldDescriptorProto.TYPE_BYTES,) # 12 +PROTO_MESSAGE_TYPES = ( + FieldDescriptorProto.TYPE_MESSAGE, # 11 + FieldDescriptorProto.TYPE_ENUM, # 14 +) +PROTO_MAP_TYPES = (FieldDescriptorProto.TYPE_MESSAGE,) # 11 +PROTO_PACKED_TYPES = ( + FieldDescriptorProto.TYPE_DOUBLE, # 1 + FieldDescriptorProto.TYPE_FLOAT, # 2 + FieldDescriptorProto.TYPE_INT64, # 3 + FieldDescriptorProto.TYPE_UINT64, # 4 + FieldDescriptorProto.TYPE_INT32, # 5 + FieldDescriptorProto.TYPE_FIXED64, # 6 + FieldDescriptorProto.TYPE_FIXED32, # 7 + FieldDescriptorProto.TYPE_BOOL, # 8 + FieldDescriptorProto.TYPE_UINT32, # 13 + FieldDescriptorProto.TYPE_SFIXED32, # 15 + FieldDescriptorProto.TYPE_SFIXED64, # 16 + FieldDescriptorProto.TYPE_SINT32, # 17 + FieldDescriptorProto.TYPE_SINT64, # 18 +) + + +def get_comment(proto_file, path: List[int], indent: int = 4) -> str: + pad = " " * indent + for sci in proto_file.source_code_info.location: + # print(list(sci.path), path, file=sys.stderr) + if list(sci.path) == path and sci.leading_comments: + lines = textwrap.wrap( + sci.leading_comments.strip().replace("\n", ""), width=79 - indent, + ) + + if path[-2] == 2 and path[-4] != 6: + # This is a field + return f"{pad}# " + f"\n{pad}# ".join(lines) + else: + # This is a message, enum, service, or method + if len(lines) == 1 and len(lines[0]) < 79 - indent - 6: + lines[0] = lines[0].strip('"') + return f'{pad}"""{lines[0]}"""' + else: + joined = f"\n{pad}".join(lines) + return f'{pad}"""\n{pad}{joined}\n{pad}"""' + + return "" + + +class ProtoContentBase: + """Methods common to Message, Service and ServiceMethod.""" + + path: List[int] + comment_indent: int = 4 + + def __post_init__(self): + """Checks that no fake default fields were left as placeholders.""" + for field_name, field_val in self.__dataclass_fields__.items(): + if field_val is PLACEHOLDER: + raise ValueError(f"`{field_name}` is a required field.") + + @property + def output_file(self) -> "OutputTemplate": + current = self + while not isinstance(current, OutputTemplate): + current = current.parent + return current + + @property + def proto_file(self) -> FieldDescriptorProto: + current = self + while not isinstance(current, OutputTemplate): + current = current.parent + return current.package_proto_obj + + @property + def request(self) -> "Request": + current = self + while not isinstance(current, OutputTemplate): + current = current.parent + return current.parent_request + + @property + def comment(self) -> str: + """Crawl the proto source code and retrieve comments + for this object. + """ + return get_comment( + proto_file=self.proto_file, path=self.path, indent=self.comment_indent, + ) + + +@dataclass +class Request: + from typing import Any + + plugin_request_obj: Any + output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict) + + @property + def all_messages(self) -> List["Message"]: + """All of the messages in this request. + + Returns + ------- + List[Message] + List of all of the messages in this request. + """ + return [ + msg for output in self.output_packages.values() for msg in output.messages + ] + + +@dataclass +class OutputTemplate: + """Representation of an output .py file. + + Each output file corresponds to a .proto input file, + but may need references to other .proto files to be + built. + """ + + parent_request: Request + package_proto_obj: FileDescriptorProto + input_files: List[str] = field(default_factory=list) + imports: Set[str] = field(default_factory=set) + datetime_imports: Set[str] = field(default_factory=set) + typing_imports: Set[str] = field(default_factory=set) + messages: List["Message"] = field(default_factory=list) + enums: List["EnumDefinition"] = field(default_factory=list) + services: List["Service"] = field(default_factory=list) + + @property + def package(self) -> str: + """Name of input package. + + Returns + ------- + str + Name of input package. + """ + return self.package_proto_obj.package + + @property + def input_filenames(self) -> List[str]: + """Names of the input files used to build this output. + + Returns + ------- + List[str] + Names of the input files used to build this output. + """ + return [f.name for f in self.input_files] + + +@dataclass +class Message(ProtoContentBase): + """Representation of a protobuf message. + """ + + parent: Union["Message", OutputTemplate] = PLACEHOLDER + proto_obj: DescriptorProto = PLACEHOLDER + path: List[int] = PLACEHOLDER + fields: List[Union["Field", "Message"]] = field(default_factory=list) + + def __post_init__(self): + # Add message to output file + if isinstance(self.parent, OutputTemplate): + if isinstance(self, EnumDefinition): + self.output_file.enums.append(self) + else: + self.output_file.messages.append(self) + super().__post_init__() + + @property + def proto_name(self) -> str: + return self.proto_obj.name + + @property + def py_name(self) -> str: + return pythonize_class_name(self.proto_name) + + @property + def annotation(self) -> str: + if self.repeated: + return f"List[{self.py_name}]" + return self.py_name + + +def is_map( + proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto +) -> bool: + """True if proto_field_obj is a map, otherwise False. + """ + if proto_field_obj.type == FieldDescriptorProto.TYPE_MESSAGE: + # This might be a map... + message_type = proto_field_obj.type_name.split(".").pop().lower() + map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry" + if message_type == map_entry: + for nested in parent_message.nested_type: # parent message + if nested.name.replace("_", "").lower() == map_entry: + if nested.options.map_entry: + return True + return False + + +def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool: + """True if proto_field_obj is a OneOf, otherwise False. + """ + if proto_field_obj.HasField("oneof_index"): + return True + return False + + +@dataclass +class Field(Message): + parent: Message = PLACEHOLDER + proto_obj: FieldDescriptorProto = PLACEHOLDER + + def __post_init__(self): + # Add field to message + self.parent.fields.append(self) + # Check for new imports + annotation = self.annotation + if "Optional[" in annotation: + self.output_file.typing_imports.add("Optional") + if "List[" in annotation: + self.output_file.typing_imports.add("List") + if "Dict[" in annotation: + self.output_file.typing_imports.add("Dict") + if "timedelta" in annotation: + self.output_file.datetime_imports.add("timedelta") + if "datetime" in annotation: + self.output_file.datetime_imports.add("datetime") + super().__post_init__() # call Field -> Message __post_init__ + + def get_field_string(self, indent: int = 4) -> str: + """Construct string representation of this field as a field.""" + name = f"{self.py_name}" + annotations = f": {self.annotation}" + betterproto_field_type = ( + f"betterproto.{self.field_type}_field({self.proto_obj.number}" + + f"{self.betterproto_field_args}" + + ")" + ) + return name + annotations + " = " + betterproto_field_type + + @property + def betterproto_field_args(self): + args = "" + if self.field_wraps: + args = args + f", wraps={self.field_wraps}" + return args + + @property + def field_wraps(self) -> Union[str, None]: + """Returns betterproto wrapped field type or None. + """ + match_wrapper = re.match( + r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name + ) + if match_wrapper: + wrapped_type = "TYPE_" + match_wrapper.group(1).upper() + if hasattr(betterproto, wrapped_type): + return f"betterproto.{wrapped_type}" + return None + + @property + def repeated(self) -> bool: + if self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED and not is_map( + self.proto_obj, self.parent + ): + return True + return False + + @property + def mutable(self) -> bool: + """True if the field is a mutable type, otherwise False.""" + annotation = self.annotation + return annotation.startswith("List[") or annotation.startswith("Dict[") + + @property + def field_type(self) -> str: + """String representation of proto field type.""" + return ( + self.proto_obj.Type.Name(self.proto_obj.type).lower().replace("type_", "") + ) + + @property + def default_value_string(self) -> Union[Text, None, float, int]: + """Python representation of the default proto value. + """ + if self.repeated: + return "[]" + if self.py_type == "int": + return "0" + if self.py_type == "float": + return "0.0" + elif self.py_type == "bool": + return "False" + elif self.py_type == "str": + return '""' + elif self.py_type == "bytes": + return 'b""' + else: + # Message type + return "None" + + @property + def packed(self) -> bool: + """True if the wire representation is a packed format.""" + if self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES: + return True + return False + + @property + def py_name(self) -> str: + """Pythonized name.""" + return pythonize_field_name(self.proto_name) + + @property + def proto_name(self) -> str: + """Original protobuf name.""" + return self.proto_obj.name + + @property + def py_type(self) -> str: + """String representation of Python type.""" + if self.proto_obj.type in PROTO_FLOAT_TYPES: + return "float" + elif self.proto_obj.type in PROTO_INT_TYPES: + return "int" + elif self.proto_obj.type in PROTO_BOOL_TYPES: + return "bool" + elif self.proto_obj.type in PROTO_STR_TYPES: + return "str" + elif self.proto_obj.type in PROTO_BYTES_TYPES: + return "bytes" + elif self.proto_obj.type in PROTO_MESSAGE_TYPES: + # Type referencing another defined Message or a named enum + return get_type_reference( + package=self.output_file.package, + imports=self.output_file.imports, + source_type=self.proto_obj.type_name, + ) + else: + raise NotImplementedError(f"Unknown type {field.type}") + + @property + def annotation(self) -> str: + if self.repeated: + return f"List[{self.py_type}]" + return self.py_type + + +@dataclass +class OneOfField(Field): + @property + def betterproto_field_args(self): + args = super().betterproto_field_args + group = self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name + args = args + f', group="{group}"' + return args + + +@dataclass +class MapField(Field): + py_k_type: Type = PLACEHOLDER + py_v_type: Type = PLACEHOLDER + proto_k_type: str = PLACEHOLDER + proto_v_type: str = PLACEHOLDER + + def __post_init__(self): + """Explore nested types and set k_type and v_type if unset.""" + map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry" + for nested in self.parent.proto_obj.nested_type: + if nested.name.replace("_", "").lower() == map_entry: + if nested.options.map_entry: + # Get Python types + self.py_k_type = Field( + parent=self, proto_obj=nested.field[0], # key + ).py_type + self.py_v_type = Field( + parent=self, proto_obj=nested.field[1], # key + ).py_type + # Get proto types + self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type) + self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type) + super().__post_init__() # call Field -> Message __post_init__ + + def get_field_string(self, indent: int = 4) -> str: + """Construct string representation of this field.""" + name = f"{self.py_name}" + annotations = f": {self.annotation}" + betterproto_field_type = ( + f"betterproto.map_field(" + f"{self.proto_obj.number}, betterproto.{self.proto_k_type}, " + f"betterproto.{self.proto_v_type})" + ) + return name + annotations + " = " + betterproto_field_type + + @property + def annotation(self): + return f"Dict[{self.py_k_type}, {self.py_v_type}]" + + @property + def repeated(self): + return False # maps cannot be repeated + + +@dataclass +class EnumDefinition(Message): + """Representation of a proto Enum definition.""" + + proto_obj: EnumDescriptorProto = PLACEHOLDER + entries: List["EnumDefinition.EnumEntry"] = PLACEHOLDER + + @dataclass(unsafe_hash=True) + class EnumEntry: + """Representation of an Enum entry.""" + + name: str + value: int + comment: str + + def __post_init__(self): + # Get entries + self.entries = [ + self.EnumEntry( + name=v.name, + value=v.number, + comment=get_comment( + proto_file=self.proto_file, path=self.path + [2, i] + ), + ) + for i, v in enumerate(self.proto_obj.value) + ] + super().__post_init__() # call Message __post_init__ + + @property + def default_value_string(self) -> int: + """Python representation of the default value for Enums. + + As per the spec, this is the first value of the Enum. + """ + return str(self.entries[0].value) # should ALWAYS be int(0)! + + +@dataclass +class Service(ProtoContentBase): + parent: OutputTemplate = PLACEHOLDER + proto_obj: DescriptorProto = PLACEHOLDER + path: List[int] = PLACEHOLDER + methods: List["ServiceMethod"] = field(default_factory=list) + + def __post_init__(self) -> None: + # Add service to output file + self.output_file.services.append(self) + super().__post_init__() # check for unset fields + + @property + def proto_name(self): + return self.proto_obj.name + + @property + def py_name(self): + return pythonize_class_name(self.proto_name) + + +@dataclass +class ServiceMethod(ProtoContentBase): + + parent: Service + proto_obj: MethodDescriptorProto + path: List[int] = PLACEHOLDER + comment_indent: int = 8 + + def __post_init__(self) -> None: + # Add method to service + self.parent.methods.append(self) + + # Check for Optional import + if self.py_input_message: + for f in self.py_input_message.fields: + if f.default_value_string == "None": + self.output_file.typing_imports.add("Optional") + if "Optional" in self.py_output_message_type: + self.output_file.typing_imports.add("Optional") + self.mutable_default_args # ensure this is called before rendering + + # Check for Async imports + if self.client_streaming: + self.output_file.typing_imports.add("AsyncIterable") + self.output_file.typing_imports.add("Iterable") + self.output_file.typing_imports.add("Union") + if self.server_streaming: + self.output_file.typing_imports.add("AsyncIterator") + + super().__post_init__() # check for unset fields + + @property + def mutable_default_args(self) -> Dict[str, str]: + """Handle mutable default arguments. + + Returns a list of tuples containing the name and default value + for arguments to this message who's default value is mutable. + The defaults are swapped out for None and replaced back inside + the method's body. + Reference: + https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments + + Returns + ------- + Dict[str, str] + Name and actual default value (as a string) + for each argument with mutable default values. + """ + mutable_default_args = dict() + + if self.py_input_message: + for f in self.py_input_message.fields: + if ( + not self.client_streaming + and f.default_value_string != "None" + and f.mutable + ): + mutable_default_args[f.py_name] = f.default_value_string + self.output_file.typing_imports.add("Optional") + + return mutable_default_args + + @property + def py_name(self) -> str: + """Pythonized method name.""" + return pythonize_method_name(self.proto_obj.name) + + @property + def proto_name(self) -> str: + """Original protobuf name.""" + return self.proto_obj.name + + @property + def route(self) -> str: + return ( + f"/{self.output_file.package}." + f"{self.parent.proto_name}/{self.proto_name}" + ) + + @property + def py_input_message(self) -> Union[None, Message]: + """Find the input message object. + + Returns + ------- + Union[None, Message] + Method instance representing the input message. + If not input message could be found or there are no + input messages, None is returned. + """ + package, name = parse_source_type_name(self.proto_obj.input_type) + + # Nested types are currently flattened without dots. + # Todo: keep a fully quantified name in types, that is + # comparable with method.input_type + for msg in self.request.all_messages: + if ( + msg.py_name == name.replace(".", "") + and msg.output_file.package == package + ): + return msg + return None + + @property + def py_input_message_type(self) -> str: + """String representation of the Python type correspoding to the + input message. + + Returns + ------- + str + String representation of the Python type correspoding to the + input message. + """ + return get_type_reference( + package=self.output_file.package, + imports=self.output_file.imports, + source_type=self.proto_obj.input_type, + ).strip('"') + + @property + def py_output_message_type(self) -> str: + """String representation of the Python type correspoding to the + output message. + + Returns + ------- + str + String representation of the Python type correspoding to the + output message. + """ + return get_type_reference( + package=self.output_file.package, + imports=self.output_file.imports, + source_type=self.proto_obj.output_type, + unwrap=False, + ).strip('"') + + @property + def client_streaming(self) -> bool: + return self.proto_obj.client_streaming + + @property + def server_streaming(self) -> bool: + return self.proto_obj.server_streaming diff --git a/src/betterproto/plugin.py b/src/betterproto/plugin/parser.py old mode 100755 new mode 100644 similarity index 61% rename from src/betterproto/plugin.py rename to src/betterproto/plugin/parser.py index 9cb65a67..be6b5feb --- a/src/betterproto/plugin.py +++ b/src/betterproto/plugin/parser.py @@ -1,15 +1,9 @@ -#!/usr/bin/env python -import collections import itertools import os.path import pathlib import sys import textwrap -from typing import List, Union - - -from betterproto.lib.google.protobuf import ServiceDescriptorProto -from betterproto.compile.importing import get_type_reference +from typing import List, Union, Iterator try: # betterproto[compiler] specific dependencies @@ -19,8 +13,8 @@ DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, + ServiceDescriptorProto ) - from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest import jinja2 except ImportError as err: missing_import = err.args[0][17:-1] @@ -34,7 +28,7 @@ ) raise SystemExit(1) -from .plugin_dataclasses import ( +from betterproto.plugin.models import ( Request, OutputTemplate, Message, @@ -49,41 +43,7 @@ ) -def py_type(package: str, imports: set, field: FieldDescriptorProto) -> str: - if field.type in [1, 2]: - return "float" - elif field.type in [3, 4, 5, 6, 7, 13, 15, 16, 17, 18]: - return "int" - elif field.type == 8: - return "bool" - elif field.type == 9: - return "str" - elif field.type in [11, 14]: - # Type referencing another defined Message or a named enum - return get_type_reference(package, imports, field.type_name) - elif field.type == 12: - return "bytes" - else: - raise NotImplementedError(f"Unknown type {field.type}") - - -def get_py_zero(type_num: int) -> Union[str, float]: - zero: Union[str, float] = 0 - if type_num in []: - zero = 0.0 - elif type_num == 8: - zero = "False" - elif type_num == 9: - zero = '""' - elif type_num == 11: - zero = "None" - elif type_num == 12: - zero = 'b""' - - return zero - - -def traverse(proto_file): +def traverse(proto_file: FieldDescriptorProto) -> Iterator: # Todo: Keep information about nested hierarchy def _traverse(path, items, prefix=""): for i, item in enumerate(items): @@ -106,37 +66,18 @@ def _traverse(path, items, prefix=""): ) -def get_comment(proto_file, path: List[int], indent: int = 4) -> str: - pad = " " * indent - for sci in proto_file.source_code_info.location: - # print(list(sci.path), path, file=sys.stderr) - if list(sci.path) == path and sci.leading_comments: - lines = textwrap.wrap( - sci.leading_comments.strip().replace("\n", ""), width=79 - indent - ) - - if path[-2] == 2 and path[-4] != 6: - # This is a field - return f"{pad}# " + f"\n{pad}# ".join(lines) - else: - # This is a message, enum, service, or method - if len(lines) == 1 and len(lines[0]) < 79 - indent - 6: - lines[0] = lines[0].strip('"') - return f'{pad}"""{lines[0]}"""' - else: - joined = f"\n{pad}".join(lines) - return f'{pad}"""\n{pad}{joined}\n{pad}"""' - - return "" - - -def generate_code(request, response): +def generate_code( + request: plugin.CodeGeneratorRequest, + response: plugin.CodeGeneratorResponse +) -> None: plugin_options = request.parameter.split(",") if request.parameter else [] + templates_folder = os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..', 'templates')) + env = jinja2.Environment( trim_blocks=True, lstrip_blocks=True, - loader=jinja2.FileSystemLoader("%s/templates/" % os.path.dirname(__file__)), + loader=jinja2.FileSystemLoader(templates_folder), ) template = env.get_template("template.py.j2") request_data = Request(plugin_request_obj=request) @@ -174,18 +115,18 @@ def generate_code(request, response): read_protobuf_service(service, index, output_package) # Generate output files - output_paths = set() + output_paths: pathlib.Path = set() for output_package_name, template_data in request_data.output_packages.items(): # Add files to the response object output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") output_paths.add(output_path) - f = response.file.add() - f.name = str(output_path) + f: response.File = response.file.add() + f.name: str = str(output_path) # Render and then format the output file - f.content = black.format_str( + f.content: str = black.format_str( template.render(description=template_data), mode=black.FileMode(target_versions={black.TargetVersion.PY37}), ) @@ -210,7 +151,7 @@ def generate_code(request, response): def read_protobuf_type( item: DescriptorProto, path: List[int], output_package: OutputTemplate -): +) -> None: if isinstance(item, DescriptorProto): if item.options.map_entry: # Skip generated map entry messages since we just use dicts @@ -231,51 +172,9 @@ def read_protobuf_type( def read_protobuf_service( service: ServiceDescriptorProto, index: int, output_package: OutputTemplate -): +) -> None: service_data = Service(parent=output_package, proto_obj=service, path=[6, index],) for j, method in enumerate(service.method): ServiceMethod( parent=service_data, proto_obj=method, path=[6, index, 2, j], ) - - -def main(): - - """The plugin's main entry point.""" - # Read request message from stdin - data = sys.stdin.buffer.read() - - # Parse request - request = plugin.CodeGeneratorRequest() - request.ParseFromString(data) - - dump_file = os.getenv("BETTERPROTO_DUMP") - if dump_file: - dump_request(dump_file, request) - - # Create response - response = plugin.CodeGeneratorResponse() - - # Generate code - generate_code(request, response) - - # Serialise response message - output = response.SerializeToString() - - # Write to stdout - sys.stdout.buffer.write(output) - - -def dump_request(dump_file: str, request: CodeGeneratorRequest): - """ - For developers: Supports running plugin.py standalone so its possible to debug it. - Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file. - Then run plugin.py from your IDE in debugging mode, and redirect stdin to the file. - """ - with open(str(dump_file), "wb") as fh: - sys.stderr.write(f"\033[31mWriting input from protoc to: {dump_file}\033[0m\n") - fh.write(request.SerializeToString()) - - -if __name__ == "__main__": - main() From f9dd7ebd0631dcc2d56c4906417f702221b635e7 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 17 Jul 2020 14:20:13 -0500 Subject: [PATCH 17/23] docs --- src/betterproto/plugin/models.py | 31 +++++++++++++++++++++++++++++++ src/betterproto/plugin/parser.py | 12 ++++++------ 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 93e83867..c9609c46 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -1,3 +1,34 @@ +"""Plugin model dataclasses. + +These classes are meant to be an intermediate representation +of protbuf objects. They are used to organize the data collected during parsing. + +The general intention is to create a doubly-linked tree-like structure +with the following types of references: +- Downwards references: from message -> fields, from output package -> messages +or from service -> service methods +- Upwards references: from field -> message, message -> package. +- Input/ouput message references: from a service method to it's corresponding +input/output messages, which may even be in another package. + +There are convenience methods to allow climbing up and down this tree, for +example to retrieve the list of all messages that are in the same package as +the current message. + +Most of these classes take as inputs: +- proto_obj: A reference to it's corresponding protobuf object as +presented by the protoc plugin. +- parent: a reference to the parent object in the tree. + +With this information, the class is able to expose attributes, +such as a pythonized name, that will be calculated from proto_obj. + +The instantiation should also attach a reference to the new object +into the corresponding place within it's parent object. For example, +instantiating field `A` with parent message `B` should add a +reference to `A` to `B`'s `fields` attirbute. +""" + import re from dataclasses import dataclass from dataclasses import field diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index be6b5feb..0bc17acc 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -2,8 +2,7 @@ import os.path import pathlib import sys -import textwrap -from typing import List, Union, Iterator +from typing import List, Iterator try: # betterproto[compiler] specific dependencies @@ -13,7 +12,7 @@ DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, - ServiceDescriptorProto + ServiceDescriptorProto, ) import jinja2 except ImportError as err: @@ -67,12 +66,13 @@ def _traverse(path, items, prefix=""): def generate_code( - request: plugin.CodeGeneratorRequest, - response: plugin.CodeGeneratorResponse + request: plugin.CodeGeneratorRequest, response: plugin.CodeGeneratorResponse ) -> None: plugin_options = request.parameter.split(",") if request.parameter else [] - templates_folder = os.path.abspath(os.path.join(os.path.dirname( __file__ ), '..', 'templates')) + templates_folder = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "templates") + ) env = jinja2.Environment( trim_blocks=True, From 62eea354439a31b0a4d28d4934855e078b73a1d0 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 17 Jul 2020 14:26:09 -0500 Subject: [PATCH 18/23] file cleanup --- src/betterproto/plugin.bat | 2 - src/betterproto/plugin/plugin.bat | 2 + src/betterproto/plugin_dataclasses.py | 681 -------------------------- 3 files changed, 2 insertions(+), 683 deletions(-) delete mode 100644 src/betterproto/plugin.bat create mode 100644 src/betterproto/plugin/plugin.bat delete mode 100644 src/betterproto/plugin_dataclasses.py diff --git a/src/betterproto/plugin.bat b/src/betterproto/plugin.bat deleted file mode 100644 index 9b837d7d..00000000 --- a/src/betterproto/plugin.bat +++ /dev/null @@ -1,2 +0,0 @@ -@SET plugin_dir=%~dp0 -@python %plugin_dir%/plugin.py %* \ No newline at end of file diff --git a/src/betterproto/plugin/plugin.bat b/src/betterproto/plugin/plugin.bat new file mode 100644 index 00000000..2a4444db --- /dev/null +++ b/src/betterproto/plugin/plugin.bat @@ -0,0 +1,2 @@ +@SET plugin_dir=%~dp0 +@python -m %plugin_dir% %* \ No newline at end of file diff --git a/src/betterproto/plugin_dataclasses.py b/src/betterproto/plugin_dataclasses.py deleted file mode 100644 index 0655af70..00000000 --- a/src/betterproto/plugin_dataclasses.py +++ /dev/null @@ -1,681 +0,0 @@ -#!/usr/bin/env python -import re -from dataclasses import dataclass -from dataclasses import field -from typing import ( - Union, - Type, - List, - Dict, - Set, - Text, -) -import textwrap - -import betterproto -from betterproto.compile.importing import ( - get_type_reference, - parse_source_type_name, -) -from betterproto.compile.naming import ( - pythonize_class_name, - pythonize_field_name, - pythonize_method_name, -) - -try: - # betterproto[compiler] specific dependencies - from google.protobuf.descriptor_pb2 import ( - DescriptorProto, - EnumDescriptorProto, - FieldDescriptorProto, - FileDescriptorProto, - MethodDescriptorProto, - ) -except ImportError as err: - missing_import = re.match(r".*(cannot import name .*$)", err.args[0]).group(1) - print( - "\033[31m" - f"Unable to import `{missing_import}` from betterproto plugin! " - "Please ensure that you've installed betterproto as " - '`pip install "betterproto[compiler]"` so that compiler dependencies ' - "are included." - "\033[0m" - ) - raise SystemExit(1) - -# Create a unique placeholder to deal with -# https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses -PLACEHOLDER = object() - -# Organize proto types into categories -PROTO_FLOAT_TYPES = ( - FieldDescriptorProto.TYPE_DOUBLE, # 1 - FieldDescriptorProto.TYPE_FLOAT, # 2 -) -PROTO_INT_TYPES = ( - FieldDescriptorProto.TYPE_INT64, # 3 - FieldDescriptorProto.TYPE_UINT64, # 4 - FieldDescriptorProto.TYPE_INT32, # 5 - FieldDescriptorProto.TYPE_FIXED64, # 6 - FieldDescriptorProto.TYPE_FIXED32, # 7 - FieldDescriptorProto.TYPE_UINT32, # 13 - FieldDescriptorProto.TYPE_SFIXED32, # 15 - FieldDescriptorProto.TYPE_SFIXED64, # 16 - FieldDescriptorProto.TYPE_SINT32, # 17 - FieldDescriptorProto.TYPE_SINT64, # 18 -) -PROTO_BOOL_TYPES = (FieldDescriptorProto.TYPE_BOOL,) # 8 -PROTO_STR_TYPES = (FieldDescriptorProto.TYPE_STRING,) # 9 -PROTO_BYTES_TYPES = (FieldDescriptorProto.TYPE_BYTES,) # 12 -PROTO_MESSAGE_TYPES = ( - FieldDescriptorProto.TYPE_MESSAGE, # 11 - FieldDescriptorProto.TYPE_ENUM, # 14 -) -PROTO_MAP_TYPES = (FieldDescriptorProto.TYPE_MESSAGE,) # 11 -PROTO_PACKED_TYPES = ( - FieldDescriptorProto.TYPE_DOUBLE, # 1 - FieldDescriptorProto.TYPE_FLOAT, # 2 - FieldDescriptorProto.TYPE_INT64, # 3 - FieldDescriptorProto.TYPE_UINT64, # 4 - FieldDescriptorProto.TYPE_INT32, # 5 - FieldDescriptorProto.TYPE_FIXED64, # 6 - FieldDescriptorProto.TYPE_FIXED32, # 7 - FieldDescriptorProto.TYPE_BOOL, # 8 - FieldDescriptorProto.TYPE_UINT32, # 13 - FieldDescriptorProto.TYPE_SFIXED32, # 15 - FieldDescriptorProto.TYPE_SFIXED64, # 16 - FieldDescriptorProto.TYPE_SINT32, # 17 - FieldDescriptorProto.TYPE_SINT64, # 18 -) - - -def get_comment(proto_file, path: List[int], indent: int = 4) -> str: - pad = " " * indent - for sci in proto_file.source_code_info.location: - # print(list(sci.path), path, file=sys.stderr) - if list(sci.path) == path and sci.leading_comments: - lines = textwrap.wrap( - sci.leading_comments.strip().replace("\n", ""), width=79 - indent, - ) - - if path[-2] == 2 and path[-4] != 6: - # This is a field - return f"{pad}# " + f"\n{pad}# ".join(lines) - else: - # This is a message, enum, service, or method - if len(lines) == 1 and len(lines[0]) < 79 - indent - 6: - lines[0] = lines[0].strip('"') - return f'{pad}"""{lines[0]}"""' - else: - joined = f"\n{pad}".join(lines) - return f'{pad}"""\n{pad}{joined}\n{pad}"""' - - return "" - - -class ProtoContentBase: - """Methods common to Message, Service and ServiceMethod.""" - - path: List[int] - comment_indent: int = 4 - - def __post_init__(self): - """Checks that no fake default fields were left as placeholders.""" - for field_name, field_val in self.__dataclass_fields__.items(): - if field_val is PLACEHOLDER: - raise ValueError(f"`{field_name}` is a required field.") - - @property - def output_file(self) -> "OutputTemplate": - current = self - while not isinstance(current, OutputTemplate): - current = current.parent - return current - - @property - def proto_file(self) -> FieldDescriptorProto: - current = self - while not isinstance(current, OutputTemplate): - current = current.parent - return current.package_proto_obj - - @property - def request(self) -> "Request": - current = self - while not isinstance(current, OutputTemplate): - current = current.parent - return current.parent_request - - @property - def comment(self) -> str: - """Crawl the proto source code and retrieve comments - for this object. - """ - return get_comment( - proto_file=self.proto_file, path=self.path, indent=self.comment_indent, - ) - - -@dataclass -class Request: - from typing import Any - - plugin_request_obj: Any - output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict) - - @property - def all_messages(self) -> List["Message"]: - """All of the messages in this request. - - Returns - ------- - List[Message] - List of all of the messages in this request. - """ - return [ - msg for output in self.output_packages.values() for msg in output.messages - ] - - -@dataclass -class OutputTemplate: - """Representation of an output .py file. - - Each output file corresponds to a .proto input file, - but may need references to other .proto files to be - built. - """ - - parent_request: Request - package_proto_obj: FileDescriptorProto - input_files: List[str] = field(default_factory=list) - imports: Set[str] = field(default_factory=set) - datetime_imports: Set[str] = field(default_factory=set) - typing_imports: Set[str] = field(default_factory=set) - messages: List["Message"] = field(default_factory=list) - enums: List["EnumDefinition"] = field(default_factory=list) - services: List["Service"] = field(default_factory=list) - - @property - def package(self) -> str: - """Name of input package. - - Returns - ------- - str - Name of input package. - """ - return self.package_proto_obj.package - - @property - def input_filenames(self) -> List[str]: - """Names of the input files used to build this output. - - Returns - ------- - List[str] - Names of the input files used to build this output. - """ - return [f.name for f in self.input_files] - - -@dataclass -class Message(ProtoContentBase): - """Representation of a protobuf message. - """ - - parent: Union["Message", OutputTemplate] = PLACEHOLDER - proto_obj: DescriptorProto = PLACEHOLDER - path: List[int] = PLACEHOLDER - fields: List[Union["Field", "Message"]] = field(default_factory=list) - - def __post_init__(self): - # Add message to output file - if isinstance(self.parent, OutputTemplate): - if isinstance(self, EnumDefinition): - self.output_file.enums.append(self) - else: - self.output_file.messages.append(self) - super().__post_init__() - - @property - def proto_name(self) -> str: - return self.proto_obj.name - - @property - def py_name(self) -> str: - return pythonize_class_name(self.proto_name) - - @property - def annotation(self) -> str: - if self.repeated: - return f"List[{self.py_name}]" - return self.py_name - - -def is_map( - proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto -) -> bool: - """True if proto_field_obj is a map, otherwise False. - """ - if proto_field_obj.type == FieldDescriptorProto.TYPE_MESSAGE: - # This might be a map... - message_type = proto_field_obj.type_name.split(".").pop().lower() - map_entry = f"{proto_field_obj.name.replace('_', '').lower()}entry" - if message_type == map_entry: - for nested in parent_message.nested_type: # parent message - if nested.name.replace("_", "").lower() == map_entry: - if nested.options.map_entry: - return True - return False - - -def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool: - """True if proto_field_obj is a OneOf, otherwise False. - """ - if proto_field_obj.HasField("oneof_index"): - return True - return False - - -@dataclass -class Field(Message): - parent: Message = PLACEHOLDER - proto_obj: FieldDescriptorProto = PLACEHOLDER - - def __post_init__(self): - # Add field to message - self.parent.fields.append(self) - # Check for new imports - annotation = self.annotation - if "Optional[" in annotation: - self.output_file.typing_imports.add("Optional") - if "List[" in annotation: - self.output_file.typing_imports.add("List") - if "Dict[" in annotation: - self.output_file.typing_imports.add("Dict") - if "timedelta" in annotation: - self.output_file.datetime_imports.add("timedelta") - if "datetime" in annotation: - self.output_file.datetime_imports.add("datetime") - super().__post_init__() # call Field -> Message __post_init__ - - def get_field_string(self, indent: int = 4) -> str: - """Construct string representation of this field as a field.""" - name = f"{self.py_name}" - annotations = f": {self.annotation}" - betterproto_field_type = ( - f"betterproto.{self.field_type}_field({self.proto_obj.number}" - + f"{self.betterproto_field_args}" - + ")" - ) - return name + annotations + " = " + betterproto_field_type - - @property - def betterproto_field_args(self): - args = "" - if self.field_wraps: - args = args + f", wraps={self.field_wraps}" - return args - - @property - def field_wraps(self) -> Union[str, None]: - """Returns betterproto wrapped field type or None. - """ - match_wrapper = re.match( - r"\.google\.protobuf\.(.+)Value", self.proto_obj.type_name - ) - if match_wrapper: - wrapped_type = "TYPE_" + match_wrapper.group(1).upper() - if hasattr(betterproto, wrapped_type): - return f"betterproto.{wrapped_type}" - return None - - @property - def repeated(self) -> bool: - if self.proto_obj.label == FieldDescriptorProto.LABEL_REPEATED and not is_map( - self.proto_obj, self.parent - ): - return True - return False - - @property - def mutable(self) -> bool: - """True if the field is a mutable type, otherwise False.""" - annotation = self.annotation - return annotation.startswith("List[") or annotation.startswith("Dict[") - - @property - def field_type(self) -> str: - """String representation of proto field type.""" - return ( - self.proto_obj.Type.Name(self.proto_obj.type).lower().replace("type_", "") - ) - - @property - def default_value_string(self) -> Union[Text, None, float, int]: - """Python representation of the default proto value. - """ - if self.repeated: - return "[]" - if self.py_type == "int": - return "0" - if self.py_type == "float": - return "0.0" - elif self.py_type == "bool": - return "False" - elif self.py_type == "str": - return '""' - elif self.py_type == "bytes": - return 'b""' - else: - # Message type - return "None" - - @property - def packed(self) -> bool: - """True if the wire representation is a packed format.""" - if self.repeated and self.proto_obj.type in PROTO_PACKED_TYPES: - return True - return False - - @property - def py_name(self) -> str: - """Pythonized name.""" - return pythonize_field_name(self.proto_name) - - @property - def proto_name(self) -> str: - """Original protobuf name.""" - return self.proto_obj.name - - @property - def py_type(self) -> str: - """String representation of Python type.""" - if self.proto_obj.type in PROTO_FLOAT_TYPES: - return "float" - elif self.proto_obj.type in PROTO_INT_TYPES: - return "int" - elif self.proto_obj.type in PROTO_BOOL_TYPES: - return "bool" - elif self.proto_obj.type in PROTO_STR_TYPES: - return "str" - elif self.proto_obj.type in PROTO_BYTES_TYPES: - return "bytes" - elif self.proto_obj.type in PROTO_MESSAGE_TYPES: - # Type referencing another defined Message or a named enum - return get_type_reference( - package=self.output_file.package, - imports=self.output_file.imports, - source_type=self.proto_obj.type_name, - ) - else: - raise NotImplementedError(f"Unknown type {field.type}") - - @property - def annotation(self) -> str: - if self.repeated: - return f"List[{self.py_type}]" - return self.py_type - - -@dataclass -class OneOfField(Field): - @property - def betterproto_field_args(self): - args = super().betterproto_field_args - group = self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name - args = args + f', group="{group}"' - return args - - -@dataclass -class MapField(Field): - py_k_type: Type = PLACEHOLDER - py_v_type: Type = PLACEHOLDER - proto_k_type: str = PLACEHOLDER - proto_v_type: str = PLACEHOLDER - - def __post_init__(self): - """Explore nested types and set k_type and v_type if unset.""" - map_entry = f"{self.proto_obj.name.replace('_', '').lower()}entry" - for nested in self.parent.proto_obj.nested_type: - if nested.name.replace("_", "").lower() == map_entry: - if nested.options.map_entry: - # Get Python types - self.py_k_type = Field( - parent=self, proto_obj=nested.field[0], # key - ).py_type - self.py_v_type = Field( - parent=self, proto_obj=nested.field[1], # key - ).py_type - # Get proto types - self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type) - self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type) - super().__post_init__() # call Field -> Message __post_init__ - - def get_field_string(self, indent: int = 4) -> str: - """Construct string representation of this field.""" - name = f"{self.py_name}" - annotations = f": {self.annotation}" - betterproto_field_type = ( - f"betterproto.map_field(" - f"{self.proto_obj.number}, betterproto.{self.proto_k_type}, " - f"betterproto.{self.proto_v_type})" - ) - return name + annotations + " = " + betterproto_field_type - - @property - def annotation(self): - return f"Dict[{self.py_k_type}, {self.py_v_type}]" - - @property - def repeated(self): - return False # maps cannot be repeated - - -@dataclass -class EnumDefinition(Message): - """Representation of a proto Enum definition.""" - - proto_obj: EnumDescriptorProto = PLACEHOLDER - entries: List["EnumDefinition.EnumEntry"] = PLACEHOLDER - - @dataclass(unsafe_hash=True) - class EnumEntry: - """Representation of an Enum entry.""" - - name: str - value: int - comment: str - - def __post_init__(self): - # Get entries - self.entries = [ - self.EnumEntry( - name=v.name, - value=v.number, - comment=get_comment( - proto_file=self.proto_file, path=self.path + [2, i] - ), - ) - for i, v in enumerate(self.proto_obj.value) - ] - super().__post_init__() # call Message __post_init__ - - @property - def default_value_string(self) -> int: - """Python representation of the default value for Enums. - - As per the spec, this is the first value of the Enum. - """ - return str(self.entries[0].value) # should ALWAYS be int(0)! - - -@dataclass -class Service(ProtoContentBase): - parent: OutputTemplate = PLACEHOLDER - proto_obj: DescriptorProto = PLACEHOLDER - path: List[int] = PLACEHOLDER - methods: List["ServiceMethod"] = field(default_factory=list) - - def __post_init__(self) -> None: - # Add service to output file - self.output_file.services.append(self) - super().__post_init__() # check for unset fields - - @property - def proto_name(self): - return self.proto_obj.name - - @property - def py_name(self): - return pythonize_class_name(self.proto_name) - - -@dataclass -class ServiceMethod(ProtoContentBase): - - parent: Service - proto_obj: MethodDescriptorProto - path: List[int] = PLACEHOLDER - comment_indent: int = 8 - - def __post_init__(self) -> None: - # Add method to service - self.parent.methods.append(self) - - # Check for Optional import - if self.py_input_message: - for f in self.py_input_message.fields: - if f.default_value_string == "None": - self.output_file.typing_imports.add("Optional") - if "Optional" in self.py_output_message_type: - self.output_file.typing_imports.add("Optional") - self.mutable_default_args # ensure this is called before rendering - - # Check for Async imports - if self.client_streaming: - self.output_file.typing_imports.add("AsyncIterable") - self.output_file.typing_imports.add("Iterable") - self.output_file.typing_imports.add("Union") - if self.server_streaming: - self.output_file.typing_imports.add("AsyncIterator") - - super().__post_init__() # check for unset fields - - @property - def mutable_default_args(self) -> Dict[str, str]: - """Handle mutable default arguments. - - Returns a list of tuples containing the name and default value - for arguments to this message who's default value is mutable. - The defaults are swapped out for None and replaced back inside - the method's body. - Reference: - https://docs.python-guide.org/writing/gotchas/#mutable-default-arguments - - Returns - ------- - Dict[str, str] - Name and actual default value (as a string) - for each argument with mutable default values. - """ - mutable_default_args = dict() - - if self.py_input_message: - for f in self.py_input_message.fields: - if ( - not self.client_streaming - and f.default_value_string != "None" - and f.mutable - ): - mutable_default_args[f.py_name] = f.default_value_string - self.output_file.typing_imports.add("Optional") - - return mutable_default_args - - @property - def py_name(self) -> str: - """Pythonized method name.""" - return pythonize_method_name(self.proto_obj.name) - - @property - def proto_name(self) -> str: - """Original protobuf name.""" - return self.proto_obj.name - - @property - def route(self) -> str: - return ( - f"/{self.output_file.package}." - f"{self.parent.proto_name}/{self.proto_name}" - ) - - @property - def py_input_message(self) -> Union[None, Message]: - """Find the input message object. - - Returns - ------- - Union[None, Message] - Method instance representing the input message. - If not input message could be found or there are no - input messages, None is returned. - """ - package, name = parse_source_type_name(self.proto_obj.input_type) - - # Nested types are currently flattened without dots. - # Todo: keep a fully quantified name in types, that is - # comparable with method.input_type - for msg in self.request.all_messages: - if ( - msg.py_name == name.replace(".", "") - and msg.output_file.package == package - ): - return msg - return None - - @property - def py_input_message_type(self) -> str: - """String representation of the Python type correspoding to the - input message. - - Returns - ------- - str - String representation of the Python type correspoding to the - input message. - """ - return get_type_reference( - package=self.output_file.package, - imports=self.output_file.imports, - source_type=self.proto_obj.input_type, - ).strip('"') - - @property - def py_output_message_type(self) -> str: - """String representation of the Python type correspoding to the - output message. - - Returns - ------- - str - String representation of the Python type correspoding to the - output message. - """ - return get_type_reference( - package=self.output_file.package, - imports=self.output_file.imports, - source_type=self.proto_obj.output_type, - unwrap=False, - ).strip('"') - - @property - def client_streaming(self) -> bool: - return self.proto_obj.client_streaming - - @property - def server_streaming(self) -> bool: - return self.proto_obj.server_streaming From ecd4b5a46af0717e43458afa861a0091d3fa03a0 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Fri, 17 Jul 2020 16:31:28 -0500 Subject: [PATCH 19/23] structure changes --- src/betterproto/plugin/__init__.py | 46 +----------------------------- src/betterproto/plugin/main.py | 43 ++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 45 deletions(-) create mode 100644 src/betterproto/plugin/main.py diff --git a/src/betterproto/plugin/__init__.py b/src/betterproto/plugin/__init__.py index 7d54c825..6b898caa 100644 --- a/src/betterproto/plugin/__init__.py +++ b/src/betterproto/plugin/__init__.py @@ -1,49 +1,5 @@ #!/usr/bin/env python -import sys -import os - -from google.protobuf.compiler import plugin_pb2 as plugin - -from betterproto.plugin.parser import generate_code - - -def main(): - - """The plugin's main entry point.""" - # Read request message from stdin - data = sys.stdin.buffer.read() - - # Parse request - request = plugin.CodeGeneratorRequest() - request.ParseFromString(data) - - dump_file = os.getenv("BETTERPROTO_DUMP") - if dump_file: - dump_request(dump_file, request) - - # Create response - response = plugin.CodeGeneratorResponse() - - # Generate code - generate_code(request, response) - - # Serialise response message - output = response.SerializeToString() - - # Write to stdout - sys.stdout.buffer.write(output) - - -def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest): - """ - For developers: Supports running plugin.py standalone so its possible to debug it. - Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file. - Then run plugin.py from your IDE in debugging mode, and redirect stdin to the file. - """ - with open(str(dump_file), "wb") as fh: - sys.stderr.write(f"\033[31mWriting input from protoc to: {dump_file}\033[0m\n") - fh.write(request.SerializeToString()) - +from .main import main if __name__ == "__main__": main() diff --git a/src/betterproto/plugin/main.py b/src/betterproto/plugin/main.py new file mode 100644 index 00000000..070d08fb --- /dev/null +++ b/src/betterproto/plugin/main.py @@ -0,0 +1,43 @@ +import sys +import os + +from google.protobuf.compiler import plugin_pb2 as plugin + +from betterproto.plugin.parser import generate_code + + +def main(): + """The plugin's main entry point.""" + # Read request message from stdin + data = sys.stdin.buffer.read() + + # Parse request + request = plugin.CodeGeneratorRequest() + request.ParseFromString(data) + + dump_file = os.getenv("BETTERPROTO_DUMP") + if dump_file: + dump_request(dump_file, request) + + # Create response + response = plugin.CodeGeneratorResponse() + + # Generate code + generate_code(request, response) + + # Serialise response message + output = response.SerializeToString() + + # Write to stdout + sys.stdout.buffer.write(output) + + +def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest): + """ + For developers: Supports running plugin.py standalone so its possible to debug it. + Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file. + Then run plugin.py from your IDE in debugging mode, and redirect stdin to the file. + """ + with open(str(dump_file), "wb") as fh: + sys.stderr.write(f"\033[31mWriting input from protoc to: {dump_file}\033[0m\n") + fh.write(request.SerializeToString()) From 0fdb181990dae4aa3fb0dcb1be640722323b175e Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sun, 19 Jul 2020 13:28:57 -0500 Subject: [PATCH 20/23] restructure, rename --- src/betterproto/plugin/__init__.py | 6 +-- src/betterproto/plugin/__main__.py | 4 ++ src/betterproto/plugin/main.py | 5 +++ src/betterproto/plugin/models.py | 68 +++++++++++++++--------------- src/betterproto/plugin/parser.py | 40 +++++++++++------- 5 files changed, 69 insertions(+), 54 deletions(-) create mode 100644 src/betterproto/plugin/__main__.py diff --git a/src/betterproto/plugin/__init__.py b/src/betterproto/plugin/__init__.py index 6b898caa..b668da91 100644 --- a/src/betterproto/plugin/__init__.py +++ b/src/betterproto/plugin/__init__.py @@ -1,5 +1 @@ -#!/usr/bin/env python -from .main import main - -if __name__ == "__main__": - main() +from .main import main \ No newline at end of file diff --git a/src/betterproto/plugin/__main__.py b/src/betterproto/plugin/__main__.py new file mode 100644 index 00000000..a47523db --- /dev/null +++ b/src/betterproto/plugin/__main__.py @@ -0,0 +1,4 @@ +from .main import main + + +main() \ No newline at end of file diff --git a/src/betterproto/plugin/main.py b/src/betterproto/plugin/main.py index 070d08fb..2604af2b 100644 --- a/src/betterproto/plugin/main.py +++ b/src/betterproto/plugin/main.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python import sys import os @@ -41,3 +42,7 @@ def dump_request(dump_file: str, request: plugin.CodeGeneratorRequest): with open(str(dump_file), "wb") as fh: sys.stderr.write(f"\033[31mWriting input from protoc to: {dump_file}\033[0m\n") fh.write(request.SerializeToString()) + + +if __name__ == "__main__": + main() diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index c9609c46..c4b417bb 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -55,6 +55,7 @@ try: # betterproto[compiler] specific dependencies + from google.protobuf.compiler import plugin_pb2 as plugin from google.protobuf.descriptor_pb2 import ( DescriptorProto, EnumDescriptorProto, @@ -145,7 +146,7 @@ def get_comment(proto_file, path: List[int], indent: int = 4) -> str: class ProtoContentBase: - """Methods common to Message, Service and ServiceMethod.""" + """Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler.""" path: List[int] comment_indent: int = 4 @@ -171,7 +172,7 @@ def proto_file(self) -> FieldDescriptorProto: return current.package_proto_obj @property - def request(self) -> "Request": + def request(self) -> "PluginRequestCompiler": current = self while not isinstance(current, OutputTemplate): current = current.parent @@ -188,19 +189,18 @@ def comment(self) -> str: @dataclass -class Request: - from typing import Any +class PluginRequestCompiler: - plugin_request_obj: Any + plugin_request_obj: plugin.CodeGeneratorRequest output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict) @property - def all_messages(self) -> List["Message"]: + def all_messages(self) -> List["MessageCompiler"]: """All of the messages in this request. Returns ------- - List[Message] + List[MessageCompiler] List of all of the messages in this request. """ return [ @@ -217,15 +217,15 @@ class OutputTemplate: built. """ - parent_request: Request + parent_request: PluginRequestCompiler package_proto_obj: FileDescriptorProto input_files: List[str] = field(default_factory=list) imports: Set[str] = field(default_factory=set) datetime_imports: Set[str] = field(default_factory=set) typing_imports: Set[str] = field(default_factory=set) - messages: List["Message"] = field(default_factory=list) - enums: List["EnumDefinition"] = field(default_factory=list) - services: List["Service"] = field(default_factory=list) + messages: List["MessageCompiler"] = field(default_factory=list) + enums: List["EnumDefinitionCompiler"] = field(default_factory=list) + services: List["ServiceCompiler"] = field(default_factory=list) @property def package(self) -> str: @@ -251,19 +251,21 @@ def input_filenames(self) -> List[str]: @dataclass -class Message(ProtoContentBase): +class MessageCompiler(ProtoContentBase): """Representation of a protobuf message. """ - parent: Union["Message", OutputTemplate] = PLACEHOLDER + parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER proto_obj: DescriptorProto = PLACEHOLDER path: List[int] = PLACEHOLDER - fields: List[Union["Field", "Message"]] = field(default_factory=list) + fields: List[Union["FieldCompiler", "MessageCompiler"]] = field( + default_factory=list + ) def __post_init__(self): # Add message to output file if isinstance(self.parent, OutputTemplate): - if isinstance(self, EnumDefinition): + if isinstance(self, EnumDefinitionCompiler): self.output_file.enums.append(self) else: self.output_file.messages.append(self) @@ -310,8 +312,8 @@ def is_oneof(proto_field_obj: FieldDescriptorProto) -> bool: @dataclass -class Field(Message): - parent: Message = PLACEHOLDER +class FieldCompiler(MessageCompiler): + parent: MessageCompiler = PLACEHOLDER proto_obj: FieldDescriptorProto = PLACEHOLDER def __post_init__(self): @@ -329,7 +331,7 @@ def __post_init__(self): self.output_file.datetime_imports.add("timedelta") if "datetime" in annotation: self.output_file.datetime_imports.add("datetime") - super().__post_init__() # call Field -> Message __post_init__ + super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__ def get_field_string(self, indent: int = 4) -> str: """Construct string representation of this field as a field.""" @@ -451,9 +453,9 @@ def annotation(self) -> str: @dataclass -class OneOfField(Field): +class OneOfFieldCompiler(FieldCompiler): @property - def betterproto_field_args(self): + def betterproto_field_args(self) -> "str": args = super().betterproto_field_args group = self.parent.proto_obj.oneof_decl[self.proto_obj.oneof_index].name args = args + f', group="{group}"' @@ -461,7 +463,7 @@ def betterproto_field_args(self): @dataclass -class MapField(Field): +class MapEntryCompiler(FieldCompiler): py_k_type: Type = PLACEHOLDER py_v_type: Type = PLACEHOLDER proto_k_type: str = PLACEHOLDER @@ -474,16 +476,16 @@ def __post_init__(self): if nested.name.replace("_", "").lower() == map_entry: if nested.options.map_entry: # Get Python types - self.py_k_type = Field( + self.py_k_type = FieldCompiler( parent=self, proto_obj=nested.field[0], # key ).py_type - self.py_v_type = Field( + self.py_v_type = FieldCompiler( parent=self, proto_obj=nested.field[1], # key ).py_type # Get proto types self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type) self.proto_v_type = self.proto_obj.Type.Name(nested.field[1].type) - super().__post_init__() # call Field -> Message __post_init__ + super().__post_init__() # call FieldCompiler-> MessageCompiler __post_init__ def get_field_string(self, indent: int = 4) -> str: """Construct string representation of this field.""" @@ -506,11 +508,11 @@ def repeated(self): @dataclass -class EnumDefinition(Message): +class EnumDefinitionCompiler(MessageCompiler): """Representation of a proto Enum definition.""" proto_obj: EnumDescriptorProto = PLACEHOLDER - entries: List["EnumDefinition.EnumEntry"] = PLACEHOLDER + entries: List["EnumDefinitionCompiler.EnumEntry"] = PLACEHOLDER @dataclass(unsafe_hash=True) class EnumEntry: @@ -532,7 +534,7 @@ def __post_init__(self): ) for i, v in enumerate(self.proto_obj.value) ] - super().__post_init__() # call Message __post_init__ + super().__post_init__() # call MessageCompiler __post_init__ @property def default_value_string(self) -> int: @@ -544,11 +546,11 @@ def default_value_string(self) -> int: @dataclass -class Service(ProtoContentBase): +class ServiceCompiler(ProtoContentBase): parent: OutputTemplate = PLACEHOLDER proto_obj: DescriptorProto = PLACEHOLDER path: List[int] = PLACEHOLDER - methods: List["ServiceMethod"] = field(default_factory=list) + methods: List["ServiceMethodCompiler"] = field(default_factory=list) def __post_init__(self) -> None: # Add service to output file @@ -565,9 +567,9 @@ def py_name(self): @dataclass -class ServiceMethod(ProtoContentBase): +class ServiceMethodCompiler(ProtoContentBase): - parent: Service + parent: ServiceCompiler proto_obj: MethodDescriptorProto path: List[int] = PLACEHOLDER comment_indent: int = 8 @@ -644,12 +646,12 @@ def route(self) -> str: ) @property - def py_input_message(self) -> Union[None, Message]: + def py_input_message(self) -> Union[None, MessageCompiler]: """Find the input message object. Returns ------- - Union[None, Message] + Union[None, MessageCompiler] Method instance representing the input message. If not input message could be found or there are no input messages, None is returned. diff --git a/src/betterproto/plugin/parser.py b/src/betterproto/plugin/parser.py index 0bc17acc..33991ec7 100644 --- a/src/betterproto/plugin/parser.py +++ b/src/betterproto/plugin/parser.py @@ -28,15 +28,15 @@ raise SystemExit(1) from betterproto.plugin.models import ( - Request, + PluginRequestCompiler, OutputTemplate, - Message, - Field, - OneOfField, - MapField, - EnumDefinition, - Service, - ServiceMethod, + MessageCompiler, + FieldCompiler, + OneOfFieldCompiler, + MapEntryCompiler, + EnumDefinitionCompiler, + ServiceCompiler, + ServiceMethodCompiler, is_map, is_oneof, ) @@ -80,7 +80,7 @@ def generate_code( loader=jinja2.FileSystemLoader(templates_folder), ) template = env.get_template("template.py.j2") - request_data = Request(plugin_request_obj=request) + request_data = PluginRequestCompiler(plugin_request_obj=request) # Gather output packages for proto_file in request.proto_file: if ( @@ -157,24 +157,32 @@ def read_protobuf_type( # Skip generated map entry messages since we just use dicts return # Process Message - message_data = Message(parent=output_package, proto_obj=item, path=path) + message_data = MessageCompiler(parent=output_package, proto_obj=item, path=path) for index, field in enumerate(item.field): if is_map(field, item): - MapField(parent=message_data, proto_obj=field, path=path + [2, index]) + MapEntryCompiler( + parent=message_data, proto_obj=field, path=path + [2, index] + ) elif is_oneof(field): - OneOfField(parent=message_data, proto_obj=field, path=path + [2, index]) + OneOfFieldCompiler( + parent=message_data, proto_obj=field, path=path + [2, index] + ) else: - Field(parent=message_data, proto_obj=field, path=path + [2, index]) + FieldCompiler( + parent=message_data, proto_obj=field, path=path + [2, index] + ) elif isinstance(item, EnumDescriptorProto): # Enum - EnumDefinition(parent=output_package, proto_obj=item, path=path) + EnumDefinitionCompiler(parent=output_package, proto_obj=item, path=path) def read_protobuf_service( service: ServiceDescriptorProto, index: int, output_package: OutputTemplate ) -> None: - service_data = Service(parent=output_package, proto_obj=service, path=[6, index],) + service_data = ServiceCompiler( + parent=output_package, proto_obj=service, path=[6, index], + ) for j, method in enumerate(service.method): - ServiceMethod( + ServiceMethodCompiler( parent=service_data, proto_obj=method, path=[6, index, 2, j], ) From 8a99cbd36c63363acd80fb419685b05b350c2731 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Sun, 19 Jul 2020 13:30:36 -0500 Subject: [PATCH 21/23] blacken --- src/betterproto/plugin/__init__.py | 2 +- src/betterproto/plugin/__main__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/betterproto/plugin/__init__.py b/src/betterproto/plugin/__init__.py index b668da91..c28a133f 100644 --- a/src/betterproto/plugin/__init__.py +++ b/src/betterproto/plugin/__init__.py @@ -1 +1 @@ -from .main import main \ No newline at end of file +from .main import main diff --git a/src/betterproto/plugin/__main__.py b/src/betterproto/plugin/__main__.py index a47523db..bd95daea 100644 --- a/src/betterproto/plugin/__main__.py +++ b/src/betterproto/plugin/__main__.py @@ -1,4 +1,4 @@ from .main import main -main() \ No newline at end of file +main() From 66610b0fe8095c52171d14fdf1d6314604cc165b Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 20 Jul 2020 09:12:04 -0500 Subject: [PATCH 22/23] fix comment --- src/betterproto/plugin/models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index c4b417bb..c0d42a6b 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -480,7 +480,7 @@ def __post_init__(self): parent=self, proto_obj=nested.field[0], # key ).py_type self.py_v_type = FieldCompiler( - parent=self, proto_obj=nested.field[1], # key + parent=self, proto_obj=nested.field[1], # value ).py_type # Get proto types self.proto_k_type = self.proto_obj.Type.Name(nested.field[0].type) From 6f523430dd40472308b61970f21aebe49d8eabcb Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 20 Jul 2020 09:17:29 -0500 Subject: [PATCH 23/23] clarify variable names --- src/betterproto/plugin/models.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index c0d42a6b..8e19961c 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -523,16 +523,16 @@ class EnumEntry: comment: str def __post_init__(self): - # Get entries + # Get entries/allowed values for this Enum self.entries = [ self.EnumEntry( - name=v.name, - value=v.number, + name=entry_proto_value.name, + value=entry_proto_value.number, comment=get_comment( - proto_file=self.proto_file, path=self.path + [2, i] + proto_file=self.proto_file, path=self.path + [2, entry_number] ), ) - for i, v in enumerate(self.proto_obj.value) + for entry_number, entry_proto_value in enumerate(self.proto_obj.value) ] super().__post_init__() # call MessageCompiler __post_init__ @@ -542,7 +542,7 @@ def default_value_string(self) -> int: As per the spec, this is the first value of the Enum. """ - return str(self.entries[0].value) # should ALWAYS be int(0)! + return str(self.entries[0].value) # ideally, should ALWAYS be int(0)! @dataclass