From 0af0cf4bfbd5369a56db7f9da54963ac09f7a277 Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Sat, 4 Jul 2020 15:35:42 +0200 Subject: [PATCH 1/7] Fixes circular import problem when a non-circular dependency triangle is flattened into two python packages --- betterproto/compile/importing.py | 12 ++-- betterproto/plugin.py | 2 +- betterproto/templates/template.py.j2 | 16 +++--- betterproto/tests/inputs/config.py | 1 - betterproto/tests/test_get_ref_type.py | 76 +++++++++++++++----------- 5 files changed, 59 insertions(+), 48 deletions(-) diff --git a/betterproto/compile/importing.py b/betterproto/compile/importing.py index 40441f8c..57ef376f 100644 --- a/betterproto/compile/importing.py +++ b/betterproto/compile/importing.py @@ -86,7 +86,7 @@ def reference_absolute(imports, py_package, py_type): string_import = ".".join(py_package) string_alias = safe_snake_case(string_import) imports.add(f"import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" + return f'"{string_alias}.{py_type}"' def reference_sibling(py_type: str) -> str: @@ -109,10 +109,10 @@ def reference_descendent( if string_from: string_alias = "_".join(importing_descendent) imports.add(f"from .{string_from} import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" + return f'"{string_alias}.{py_type}"' else: imports.add(f"from . import {string_import}") - return f"{string_import}.{py_type}" + return f'"{string_import}.{py_type}"' def reference_ancestor( @@ -130,11 +130,11 @@ def reference_ancestor( string_alias = f"_{'_' * distance_up}{string_import}__" string_from = f"..{'.' * distance_up}" imports.add(f"from {string_from} import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" + return f'"{string_alias}.{py_type}"' else: string_alias = f"{'_' * distance_up}{py_type}__" imports.add(f"from .{'.' * distance_up} import {py_type} as {string_alias}") - return string_alias + return f'"{string_alias}"' def reference_cousin( @@ -157,4 +157,4 @@ def reference_cousin( + "__" ) imports.add(f"from {string_from} import {string_import} as {string_alias}") - return f"{string_alias}.{py_type}" + return f'"{string_alias}.{py_type}"' diff --git a/betterproto/plugin.py b/betterproto/plugin.py index e835fab7..9f4df64e 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -329,7 +329,7 @@ def generate_code(request, response): output["imports"], method.output_type, unwrap=False, - ).strip('"'), + ), "client_streaming": method.client_streaming, "server_streaming": method.server_streaming, } diff --git a/betterproto/templates/template.py.j2 b/betterproto/templates/template.py.j2 index 3894619a..b2d9112a 100644 --- a/betterproto/templates/template.py.j2 +++ b/betterproto/templates/template.py.j2 @@ -16,10 +16,6 @@ import betterproto import grpclib {% endif %} -{% for i in description.imports %} -{{ i }} -{% endfor %} - {% if description.enums %}{% for enum in description.enums %} class {{ enum.py_name }}(betterproto.Enum): @@ -102,14 +98,14 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): "{{ method.route }}", request_iterator, {{ method.input }}, - {{ method.output }}, + {{ method.output.strip('"') }}, ): yield response {% else %}{# i.e. not client streaming #} async for response in self._unary_stream( "{{ method.route }}", request, - {{ method.output }}, + {{ method.output.strip('"') }}, ): yield response @@ -120,16 +116,20 @@ class {{ service.py_name }}Stub(betterproto.ServiceStub): "{{ method.route }}", request_iterator, {{ method.input }}, - {{ method.output }} + {{ method.output.strip('"') }} ) {% else %}{# i.e. not client streaming #} return await self._unary_unary( "{{ method.route }}", request, - {{ method.output }} + {{ method.output.strip('"') }} ) {% endif %}{# client streaming #} {% endif %} {% endfor %} {% endfor %} + +{% for i in description.imports %} +{{ i }} +{% endfor %} \ No newline at end of file diff --git a/betterproto/tests/inputs/config.py b/betterproto/tests/inputs/config.py index eab5ea4c..38e9603f 100644 --- a/betterproto/tests/inputs/config.py +++ b/betterproto/tests/inputs/config.py @@ -1,7 +1,6 @@ # Test cases that are expected to fail, e.g. unimplemented features or bug-fixes. # Remove from list when fixed. xfail = { - "import_circular_dependency", "oneof_enum", # 63 "namespace_keywords", # 70 "namespace_builtin_types", # 53 diff --git a/betterproto/tests/test_get_ref_type.py b/betterproto/tests/test_get_ref_type.py index 2bedf76c..5a1722b1 100644 --- a/betterproto/tests/test_get_ref_type.py +++ b/betterproto/tests/test_get_ref_type.py @@ -8,22 +8,22 @@ [ ( ".google.protobuf.Empty", - "betterproto_lib_google_protobuf.Empty", + '"betterproto_lib_google_protobuf.Empty"', "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf", ), ( ".google.protobuf.Struct", - "betterproto_lib_google_protobuf.Struct", + '"betterproto_lib_google_protobuf.Struct"', "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf", ), ( ".google.protobuf.ListValue", - "betterproto_lib_google_protobuf.ListValue", + '"betterproto_lib_google_protobuf.ListValue"', "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf", ), ( ".google.protobuf.Value", - "betterproto_lib_google_protobuf.Value", + '"betterproto_lib_google_protobuf.Value"', "import betterproto.lib.google.protobuf as betterproto_lib_google_protobuf", ), ], @@ -67,15 +67,27 @@ def test_referenceing_google_wrappers_unwraps_them( @pytest.mark.parametrize( ["google_type", "expected_name"], [ - (".google.protobuf.DoubleValue", "betterproto_lib_google_protobuf.DoubleValue"), - (".google.protobuf.FloatValue", "betterproto_lib_google_protobuf.FloatValue"), - (".google.protobuf.Int32Value", "betterproto_lib_google_protobuf.Int32Value"), - (".google.protobuf.Int64Value", "betterproto_lib_google_protobuf.Int64Value"), - (".google.protobuf.UInt32Value", "betterproto_lib_google_protobuf.UInt32Value"), - (".google.protobuf.UInt64Value", "betterproto_lib_google_protobuf.UInt64Value"), - (".google.protobuf.BoolValue", "betterproto_lib_google_protobuf.BoolValue"), - (".google.protobuf.StringValue", "betterproto_lib_google_protobuf.StringValue"), - (".google.protobuf.BytesValue", "betterproto_lib_google_protobuf.BytesValue"), + ( + ".google.protobuf.DoubleValue", + '"betterproto_lib_google_protobuf.DoubleValue"', + ), + (".google.protobuf.FloatValue", '"betterproto_lib_google_protobuf.FloatValue"'), + (".google.protobuf.Int32Value", '"betterproto_lib_google_protobuf.Int32Value"'), + (".google.protobuf.Int64Value", '"betterproto_lib_google_protobuf.Int64Value"'), + ( + ".google.protobuf.UInt32Value", + '"betterproto_lib_google_protobuf.UInt32Value"', + ), + ( + ".google.protobuf.UInt64Value", + '"betterproto_lib_google_protobuf.UInt64Value"', + ), + (".google.protobuf.BoolValue", '"betterproto_lib_google_protobuf.BoolValue"'), + ( + ".google.protobuf.StringValue", + '"betterproto_lib_google_protobuf.StringValue"', + ), + (".google.protobuf.BytesValue", '"betterproto_lib_google_protobuf.BytesValue"'), ], ) def test_referenceing_google_wrappers_without_unwrapping( @@ -95,7 +107,7 @@ def test_reference_child_package_from_package(): ) assert imports == {"from . import child"} - assert name == "child.Message" + assert name == '"child.Message"' def test_reference_child_package_from_root(): @@ -103,7 +115,7 @@ def test_reference_child_package_from_root(): name = get_type_reference(package="", imports=imports, source_type="child.Message") assert imports == {"from . import child"} - assert name == "child.Message" + assert name == '"child.Message"' def test_reference_camel_cased(): @@ -113,7 +125,7 @@ def test_reference_camel_cased(): ) assert imports == {"from . import child_package"} - assert name == "child_package.ExampleMessage" + assert name == '"child_package.ExampleMessage"' def test_reference_nested_child_from_root(): @@ -123,7 +135,7 @@ def test_reference_nested_child_from_root(): ) assert imports == {"from .nested import child as nested_child"} - assert name == "nested_child.Message" + assert name == '"nested_child.Message"' def test_reference_deeply_nested_child_from_root(): @@ -133,7 +145,7 @@ def test_reference_deeply_nested_child_from_root(): ) assert imports == {"from .deeply.nested import child as deeply_nested_child"} - assert name == "deeply_nested_child.Message" + assert name == '"deeply_nested_child.Message"' def test_reference_deeply_nested_child_from_package(): @@ -145,7 +157,7 @@ def test_reference_deeply_nested_child_from_package(): ) assert imports == {"from .deeply.nested import child as deeply_nested_child"} - assert name == "deeply_nested_child.Message" + assert name == '"deeply_nested_child.Message"' def test_reference_root_sibling(): @@ -181,7 +193,7 @@ def test_reference_parent_package_from_child(): ) assert imports == {"from ... import package as __package__"} - assert name == "__package__.Message" + assert name == '"__package__.Message"' def test_reference_parent_package_from_deeply_nested_child(): @@ -193,7 +205,7 @@ def test_reference_parent_package_from_deeply_nested_child(): ) assert imports == {"from ... import nested as __nested__"} - assert name == "__nested__.Message" + assert name == '"__nested__.Message"' def test_reference_ancestor_package_from_nested_child(): @@ -205,7 +217,7 @@ def test_reference_ancestor_package_from_nested_child(): ) assert imports == {"from .... import ancestor as ___ancestor__"} - assert name == "___ancestor__.Message" + assert name == '"___ancestor__.Message"' def test_reference_root_package_from_child(): @@ -215,7 +227,7 @@ def test_reference_root_package_from_child(): ) assert imports == {"from ... import Message as __Message__"} - assert name == "__Message__" + assert name == '"__Message__"' def test_reference_root_package_from_deeply_nested_child(): @@ -225,7 +237,7 @@ def test_reference_root_package_from_deeply_nested_child(): ) assert imports == {"from ..... import Message as ____Message__"} - assert name == "____Message__" + assert name == '"____Message__"' def test_reference_unrelated_package(): @@ -233,7 +245,7 @@ def test_reference_unrelated_package(): name = get_type_reference(package="a", imports=imports, source_type="p.Message") assert imports == {"from .. import p as _p__"} - assert name == "_p__.Message" + assert name == '"_p__.Message"' def test_reference_unrelated_nested_package(): @@ -241,7 +253,7 @@ def test_reference_unrelated_nested_package(): name = get_type_reference(package="a.b", imports=imports, source_type="p.q.Message") assert imports == {"from ...p import q as __p_q__"} - assert name == "__p_q__.Message" + assert name == '"__p_q__.Message"' def test_reference_unrelated_deeply_nested_package(): @@ -251,7 +263,7 @@ def test_reference_unrelated_deeply_nested_package(): ) assert imports == {"from .....p.q.r import s as ____p_q_r_s__"} - assert name == "____p_q_r_s__.Message" + assert name == '"____p_q_r_s__.Message"' def test_reference_cousin_package(): @@ -259,7 +271,7 @@ def test_reference_cousin_package(): name = get_type_reference(package="a.x", imports=imports, source_type="a.y.Message") assert imports == {"from .. import y as _y__"} - assert name == "_y__.Message" + assert name == '"_y__.Message"' def test_reference_cousin_package_different_name(): @@ -269,7 +281,7 @@ def test_reference_cousin_package_different_name(): ) assert imports == {"from ...cousin import package2 as __cousin_package2__"} - assert name == "__cousin_package2__.Message" + assert name == '"__cousin_package2__.Message"' def test_reference_cousin_package_same_name(): @@ -279,7 +291,7 @@ def test_reference_cousin_package_same_name(): ) assert imports == {"from ...cousin import package as __cousin_package__"} - assert name == "__cousin_package__.Message" + assert name == '"__cousin_package__.Message"' def test_reference_far_cousin_package(): @@ -289,7 +301,7 @@ def test_reference_far_cousin_package(): ) assert imports == {"from ...b import c as __b_c__"} - assert name == "__b_c__.Message" + assert name == '"__b_c__.Message"' def test_reference_far_far_cousin_package(): @@ -299,7 +311,7 @@ def test_reference_far_far_cousin_package(): ) assert imports == {"from ....b.c import d as ___b_c_d__"} - assert name == "___b_c_d__.Message" + assert name == '"___b_c_d__.Message"' @pytest.mark.parametrize( From 98d00f0d217afa9c241e9a43bd7a715c6b7d8fdd Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Sun, 5 Jul 2020 12:13:59 +0200 Subject: [PATCH 2/7] Supports running plugin.py standalone by reading from a dump-file, so its possible to debug it. --- betterproto/plugin.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 9f4df64e..27788afd 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -386,6 +386,10 @@ def main(): request = plugin.CodeGeneratorRequest() request.ParseFromString(data) + dump_file = os.getenv("DUMP_FILE") + if dump_file: + dump_request(dump_file, request) + # Create response response = plugin.CodeGeneratorResponse() @@ -399,5 +403,16 @@ def main(): sys.stdout.buffer.write(output) +def dump_request(dump_file: str, request: CodeGeneratorRequest): + """ + For developers: Supports running plugin.py standalone so its possible to debug it. + Run protoc (or generate.py) with DUMP_FILE="yourfile.bin" to write the request to a file. + Then run plugin.py from your IDE in debugging mode, and redirect stdin to the file. + """ + with open(str(dump_file), "wb") as fh: + sys.stderr.write(f"\033[31mWriting: {dump_file}\033[0m\n") + fh.write(request.SerializeToString()) + + if __name__ == "__main__": main() From f2e87192b0c0602634baed331ffc17311201098f Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Sun, 5 Jul 2020 12:24:21 +0200 Subject: [PATCH 3/7] Clarify variable names --- betterproto/plugin.py | 102 +++++++++++++++++++++--------------------- 1 file changed, 52 insertions(+), 50 deletions(-) diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 27788afd..8791c242 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -1,5 +1,5 @@ #!/usr/bin/env python - +import collections import itertools import os.path import pathlib @@ -8,6 +8,8 @@ import textwrap from typing import List, Union +from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest + import betterproto from betterproto.compile.importing import get_type_reference from betterproto.compile.naming import ( @@ -129,7 +131,8 @@ def generate_code(request, response): ) template = env.get_template("template.py.j2") - output_map = {} + # Gather output packages + output_package_files = collections.defaultdict() for proto_file in request.proto_file: if ( proto_file.package == "google.protobuf" @@ -137,21 +140,18 @@ def generate_code(request, response): ): continue - output_file = str(pathlib.Path(*proto_file.package.split("."), "__init__.py")) - - if output_file not in output_map: - output_map[output_file] = {"package": proto_file.package, "files": []} - output_map[output_file]["files"].append(proto_file) - - # TODO: Figure out how to handle gRPC request/response messages and add - # processing below for Service. - - for filename, options in output_map.items(): - package = options["package"] - # print(package, filename, file=sys.stderr) - output = { - "package": package, - "files": [f.name for f in options["files"]], + output_package = proto_file.package + output_package_files.setdefault( + output_package, {"input_package": proto_file.package, "files": []} + ) + output_package_files[output_package]["files"].append(proto_file) + + output_paths = set() + for output_package_name, output_package_content in output_package_files.items(): + input_package_name = output_package_content["input_package"] + template_data = { + "input_package": input_package_name, + "files": [f.name for f in output_package_content["files"]], "imports": set(), "datetime_imports": set(), "typing_imports": set(), @@ -160,7 +160,7 @@ def generate_code(request, response): "services": [], } - for proto_file in options["files"]: + for proto_file in output_package_content["files"]: item: DescriptorProto for item, path in traverse(proto_file): data = {"name": item.name, "py_name": pythonize_class_name(item.name)} @@ -180,7 +180,7 @@ def generate_code(request, response): ) for i, f in enumerate(item.field): - t = py_type(package, output["imports"], f) + t = py_type(input_package_name, template_data["imports"], f) zero = get_py_zero(f.type) repeated = False @@ -213,13 +213,13 @@ def generate_code(request, response): if nested.options.map_entry: # print("Found a map!", file=sys.stderr) k = py_type( - package, - output["imports"], + input_package_name, + template_data["imports"], nested.field[0], ) v = py_type( - package, - output["imports"], + input_package_name, + template_data["imports"], nested.field[1], ) t = f"Dict[{k}, {v}]" @@ -228,14 +228,14 @@ def generate_code(request, response): f.Type.Name(nested.field[0].type), f.Type.Name(nested.field[1].type), ) - output["typing_imports"].add("Dict") + template_data["typing_imports"].add("Dict") if f.label == 3 and field_type != "map": # Repeated field repeated = True t = f"List[{t}]" zero = "[]" - output["typing_imports"].add("List") + template_data["typing_imports"].add("List") if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: packed = True @@ -245,12 +245,12 @@ def generate_code(request, response): one_of = item.oneof_decl[f.oneof_index].name if "Optional[" in t: - output["typing_imports"].add("Optional") + template_data["typing_imports"].add("Optional") if "timedelta" in t: - output["datetime_imports"].add("timedelta") + template_data["datetime_imports"].add("timedelta") elif "datetime" in t: - output["datetime_imports"].add("datetime") + template_data["datetime_imports"].add("datetime") data["properties"].append( { @@ -271,7 +271,7 @@ def generate_code(request, response): ) # print(f, file=sys.stderr) - output["messages"].append(data) + template_data["messages"].append(data) elif isinstance(item, EnumDescriptorProto): # print(item.name, path, file=sys.stderr) data.update( @@ -289,7 +289,7 @@ def generate_code(request, response): } ) - output["enums"].append(data) + template_data["enums"].append(data) for i, service in enumerate(proto_file.service): # print(service, file=sys.stderr) @@ -304,14 +304,14 @@ def generate_code(request, response): for j, method in enumerate(service.method): input_message = None input_type = get_type_reference( - package, output["imports"], method.input_type + input_package_name, template_data["imports"], method.input_type ).strip('"') - for msg in output["messages"]: + for msg in template_data["messages"]: if msg["name"] == input_type: input_message = msg for field in msg["properties"]: if field["zero"] == "None": - output["typing_imports"].add("Optional") + template_data["typing_imports"].add("Optional") break data["methods"].append( @@ -319,14 +319,14 @@ def generate_code(request, response): "name": method.name, "py_name": pythonize_method_name(method.name), "comment": get_comment(proto_file, [6, i, 2, j], indent=8), - "route": f"/{package}.{service.name}/{method.name}", + "route": f"/{input_package_name}.{service.name}/{method.name}", "input": get_type_reference( - package, output["imports"], method.input_type + input_package_name, template_data["imports"], method.input_type ).strip('"'), "input_message": input_message, "output": get_type_reference( - package, - output["imports"], + input_package_name, + template_data["imports"], method.output_type, unwrap=False, ), @@ -336,30 +336,32 @@ def generate_code(request, response): ) if method.client_streaming: - output["typing_imports"].add("AsyncIterable") - output["typing_imports"].add("Iterable") - output["typing_imports"].add("Union") + template_data["typing_imports"].add("AsyncIterable") + template_data["typing_imports"].add("Iterable") + template_data["typing_imports"].add("Union") if method.server_streaming: - output["typing_imports"].add("AsyncIterator") + template_data["typing_imports"].add("AsyncIterator") - output["services"].append(data) + template_data["services"].append(data) - output["imports"] = sorted(output["imports"]) - output["datetime_imports"] = sorted(output["datetime_imports"]) - output["typing_imports"] = sorted(output["typing_imports"]) + template_data["imports"] = sorted(template_data["imports"]) + template_data["datetime_imports"] = sorted(template_data["datetime_imports"]) + template_data["typing_imports"] = sorted(template_data["typing_imports"]) # Fill response + output_path = pathlib.Path(*output_package_name.split("."), "__init__.py") + output_paths.add(output_path) + f = response.file.add() - f.name = filename + f.name = str(output_path) # Render and then format the output file. f.content = black.format_str( - template.render(description=output), + template.render(description=template_data), mode=black.FileMode(target_versions=set([black.TargetVersion.PY37])), ) # Make each output directory a package with __init__ file - output_paths = set(pathlib.Path(path) for path in output_map.keys()) init_files = ( set( directory.joinpath("__init__.py") @@ -373,8 +375,8 @@ def generate_code(request, response): init = response.file.add() init.name = str(init_file) - for filename in sorted(output_paths.union(init_files)): - print(f"Writing {filename}", file=sys.stderr) + for output_package_name in sorted(output_paths.union(init_files)): + print(f"Writing {output_package_name}", file=sys.stderr) def main(): From 87b3a4b86dcfe433cbfd4a02694d4cc88fe02853 Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Sun, 5 Jul 2020 12:27:06 +0200 Subject: [PATCH 4/7] Move parsing of protobuf data types and services into separate methods --- betterproto/plugin.py | 360 +++++++++++++++++++++--------------------- 1 file changed, 182 insertions(+), 178 deletions(-) diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 8791c242..513503c5 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -163,186 +163,10 @@ def generate_code(request, response): for proto_file in output_package_content["files"]: item: DescriptorProto for item, path in traverse(proto_file): - data = {"name": item.name, "py_name": pythonize_class_name(item.name)} - - if isinstance(item, DescriptorProto): - # print(item, file=sys.stderr) - if item.options.map_entry: - # Skip generated map entry messages since we just use dicts - continue - - data.update( - { - "type": "Message", - "comment": get_comment(proto_file, path), - "properties": [], - } - ) - - for i, f in enumerate(item.field): - t = py_type(input_package_name, template_data["imports"], f) - zero = get_py_zero(f.type) - - repeated = False - packed = False - - field_type = f.Type.Name(f.type).lower()[5:] - - field_wraps = "" - match_wrapper = re.match( - r"\.google\.protobuf\.(.+)Value", f.type_name - ) - if match_wrapper: - wrapped_type = "TYPE_" + match_wrapper.group(1).upper() - if hasattr(betterproto, wrapped_type): - field_wraps = f"betterproto.{wrapped_type}" - - map_types = None - if f.type == 11: - # This might be a map... - message_type = f.type_name.split(".").pop().lower() - # message_type = py_type(package) - map_entry = f"{f.name.replace('_', '').lower()}entry" - - if message_type == map_entry: - for nested in item.nested_type: - if ( - nested.name.replace("_", "").lower() - == map_entry - ): - if nested.options.map_entry: - # print("Found a map!", file=sys.stderr) - k = py_type( - input_package_name, - template_data["imports"], - nested.field[0], - ) - v = py_type( - input_package_name, - template_data["imports"], - nested.field[1], - ) - t = f"Dict[{k}, {v}]" - field_type = "map" - map_types = ( - f.Type.Name(nested.field[0].type), - f.Type.Name(nested.field[1].type), - ) - template_data["typing_imports"].add("Dict") - - if f.label == 3 and field_type != "map": - # Repeated field - repeated = True - t = f"List[{t}]" - zero = "[]" - template_data["typing_imports"].add("List") - - if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: - packed = True - - one_of = "" - if f.HasField("oneof_index"): - one_of = item.oneof_decl[f.oneof_index].name - - if "Optional[" in t: - template_data["typing_imports"].add("Optional") - - if "timedelta" in t: - template_data["datetime_imports"].add("timedelta") - elif "datetime" in t: - template_data["datetime_imports"].add("datetime") - - data["properties"].append( - { - "name": f.name, - "py_name": pythonize_field_name(f.name), - "number": f.number, - "comment": get_comment(proto_file, path + [2, i]), - "proto_type": int(f.type), - "field_type": field_type, - "field_wraps": field_wraps, - "map_types": map_types, - "type": t, - "zero": zero, - "repeated": repeated, - "packed": packed, - "one_of": one_of, - } - ) - # print(f, file=sys.stderr) - - template_data["messages"].append(data) - elif isinstance(item, EnumDescriptorProto): - # print(item.name, path, file=sys.stderr) - data.update( - { - "type": "Enum", - "comment": get_comment(proto_file, path), - "entries": [ - { - "name": v.name, - "value": v.number, - "comment": get_comment(proto_file, path + [2, i]), - } - for i, v in enumerate(item.value) - ], - } - ) - - template_data["enums"].append(data) + read_protobuf_type(input_package_name, item, path, proto_file, template_data) for i, service in enumerate(proto_file.service): - # print(service, file=sys.stderr) - - data = { - "name": service.name, - "py_name": pythonize_class_name(service.name), - "comment": get_comment(proto_file, [6, i]), - "methods": [], - } - - for j, method in enumerate(service.method): - input_message = None - input_type = get_type_reference( - input_package_name, template_data["imports"], method.input_type - ).strip('"') - for msg in template_data["messages"]: - if msg["name"] == input_type: - input_message = msg - for field in msg["properties"]: - if field["zero"] == "None": - template_data["typing_imports"].add("Optional") - break - - data["methods"].append( - { - "name": method.name, - "py_name": pythonize_method_name(method.name), - "comment": get_comment(proto_file, [6, i, 2, j], indent=8), - "route": f"/{input_package_name}.{service.name}/{method.name}", - "input": get_type_reference( - input_package_name, template_data["imports"], method.input_type - ).strip('"'), - "input_message": input_message, - "output": get_type_reference( - input_package_name, - template_data["imports"], - method.output_type, - unwrap=False, - ), - "client_streaming": method.client_streaming, - "server_streaming": method.server_streaming, - } - ) - - if method.client_streaming: - template_data["typing_imports"].add("AsyncIterable") - template_data["typing_imports"].add("Iterable") - template_data["typing_imports"].add("Union") - if method.server_streaming: - template_data["typing_imports"].add("AsyncIterator") - - template_data["services"].append(data) + read_protobuf_service(i, input_package_name, proto_file, service, template_data) template_data["imports"] = sorted(template_data["imports"]) template_data["datetime_imports"] = sorted(template_data["datetime_imports"]) @@ -379,6 +203,186 @@ def generate_code(request, response): print(f"Writing {output_package_name}", file=sys.stderr) +def read_protobuf_service(i, input_package_name, proto_file, service, template_data): + # print(service, file=sys.stderr) + data = { + "name": service.name, + "py_name": pythonize_class_name(service.name), + "comment": get_comment(proto_file, [6, i]), + "methods": [], + } + for j, method in enumerate(service.method): + input_message = None + input_type = get_type_reference( + input_package_name, template_data["imports"], method.input_type + ).strip('"') + for msg in template_data["messages"]: + if msg["name"] == input_type: + input_message = msg + for field in msg["properties"]: + if field["zero"] == "None": + template_data["typing_imports"].add("Optional") + break + + data["methods"].append( + { + "name": method.name, + "py_name": pythonize_method_name(method.name), + "comment": get_comment(proto_file, [6, i, 2, j], indent=8), + "route": f"/{input_package_name}.{service.name}/{method.name}", + "input": get_type_reference( + input_package_name, template_data["imports"], method.input_type + ).strip('"'), + "input_message": input_message, + "output": get_type_reference( + input_package_name, + template_data["imports"], + method.output_type, + unwrap=False, + ), + "client_streaming": method.client_streaming, + "server_streaming": method.server_streaming, + } + ) + + if method.client_streaming: + template_data["typing_imports"].add("AsyncIterable") + template_data["typing_imports"].add("Iterable") + template_data["typing_imports"].add("Union") + if method.server_streaming: + template_data["typing_imports"].add("AsyncIterator") + template_data["services"].append(data) + + +def read_protobuf_type(input_package_name, item, path, proto_file, template_data): + data = {"name": item.name, "py_name": pythonize_class_name(item.name)} + if isinstance(item, DescriptorProto): + # print(item, file=sys.stderr) + if item.options.map_entry: + # Skip generated map entry messages since we just use dicts + return + + data.update( + { + "type": "Message", + "comment": get_comment(proto_file, path), + "properties": [], + } + ) + + for i, f in enumerate(item.field): + t = py_type(input_package_name, template_data["imports"], f) + zero = get_py_zero(f.type) + + repeated = False + packed = False + + field_type = f.Type.Name(f.type).lower()[5:] + + field_wraps = "" + match_wrapper = re.match( + r"\.google\.protobuf\.(.+)Value", f.type_name + ) + if match_wrapper: + wrapped_type = "TYPE_" + match_wrapper.group(1).upper() + if hasattr(betterproto, wrapped_type): + field_wraps = f"betterproto.{wrapped_type}" + + map_types = None + if f.type == 11: + # This might be a map... + message_type = f.type_name.split(".").pop().lower() + # message_type = py_type(package) + map_entry = f"{f.name.replace('_', '').lower()}entry" + + if message_type == map_entry: + for nested in item.nested_type: + if ( + nested.name.replace("_", "").lower() + == map_entry + ): + if nested.options.map_entry: + # print("Found a map!", file=sys.stderr) + k = py_type( + input_package_name, + template_data["imports"], + nested.field[0], + ) + v = py_type( + input_package_name, + template_data["imports"], + nested.field[1], + ) + t = f"Dict[{k}, {v}]" + field_type = "map" + map_types = ( + f.Type.Name(nested.field[0].type), + f.Type.Name(nested.field[1].type), + ) + template_data["typing_imports"].add("Dict") + + if f.label == 3 and field_type != "map": + # Repeated field + repeated = True + t = f"List[{t}]" + zero = "[]" + template_data["typing_imports"].add("List") + + if f.type in [1, 2, 3, 4, 5, 6, 7, 8, 13, 15, 16, 17, 18]: + packed = True + + one_of = "" + if f.HasField("oneof_index"): + one_of = item.oneof_decl[f.oneof_index].name + + if "Optional[" in t: + template_data["typing_imports"].add("Optional") + + if "timedelta" in t: + template_data["datetime_imports"].add("timedelta") + elif "datetime" in t: + template_data["datetime_imports"].add("datetime") + + data["properties"].append( + { + "name": f.name, + "py_name": pythonize_field_name(f.name), + "number": f.number, + "comment": get_comment(proto_file, path + [2, i]), + "proto_type": int(f.type), + "field_type": field_type, + "field_wraps": field_wraps, + "map_types": map_types, + "type": t, + "zero": zero, + "repeated": repeated, + "packed": packed, + "one_of": one_of, + } + ) + # print(f, file=sys.stderr) + + template_data["messages"].append(data) + elif isinstance(item, EnumDescriptorProto): + # print(item.name, path, file=sys.stderr) + data.update( + { + "type": "Enum", + "comment": get_comment(proto_file, path), + "entries": [ + { + "name": v.name, + "value": v.number, + "comment": get_comment(proto_file, path + [2, i]), + } + for i, v in enumerate(item.value) + ], + } + ) + + template_data["enums"].append(data) + + def main(): """The plugin's main entry point.""" # Read request message from stdin From dedead048f1d3e4ea1e14ac40cfad16b400a4c98 Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Sun, 5 Jul 2020 13:10:25 +0200 Subject: [PATCH 5/7] Read proto objects before services --- betterproto/plugin.py | 46 ++++++++++++++++++++++++++----------------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 513503c5..a795efae 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -76,6 +76,7 @@ def get_py_zero(type_num: int) -> Union[str, float]: return zero +# Todo: Keep information about nested hierarchy def traverse(proto_file): def _traverse(path, items, prefix=""): for i, item in enumerate(items): @@ -146,11 +147,10 @@ def generate_code(request, response): ) output_package_files[output_package]["files"].append(proto_file) - output_paths = set() + # Initialize Template data for each package for output_package_name, output_package_content in output_package_files.items(): - input_package_name = output_package_content["input_package"] template_data = { - "input_package": input_package_name, + "input_package": output_package_content["input_package"], "files": [f.name for f in output_package_content["files"]], "imports": set(), "datetime_imports": set(), @@ -159,15 +159,26 @@ def generate_code(request, response): "enums": [], "services": [], } + output_package_content["template_data"] = template_data + # Read Messages and Enums + for output_package_name, output_package_content in output_package_files.items(): for proto_file in output_package_content["files"]: - item: DescriptorProto for item, path in traverse(proto_file): - read_protobuf_type(input_package_name, item, path, proto_file, template_data) + read_protobuf_object(item, path, proto_file, output_package_content) - for i, service in enumerate(proto_file.service): - read_protobuf_service(i, input_package_name, proto_file, service, template_data) + # Read Services + for output_package_name, output_package_content in output_package_files.items(): + for proto_file in output_package_content["files"]: + for index, service in enumerate(proto_file.service): + read_protobuf_service( + service, index, proto_file, output_package_content + ) + # Render files + output_paths = set() + for output_package_name, output_package_content in output_package_files.items(): + template_data = output_package_content["template_data"] template_data["imports"] = sorted(template_data["imports"]) template_data["datetime_imports"] = sorted(template_data["datetime_imports"]) template_data["typing_imports"] = sorted(template_data["typing_imports"]) @@ -203,12 +214,14 @@ def generate_code(request, response): print(f"Writing {output_package_name}", file=sys.stderr) -def read_protobuf_service(i, input_package_name, proto_file, service, template_data): +def read_protobuf_service(service: DescriptorProto, index, proto_file, content): + input_package_name = content["input_package"] + template_data = content["template_data"] # print(service, file=sys.stderr) data = { "name": service.name, "py_name": pythonize_class_name(service.name), - "comment": get_comment(proto_file, [6, i]), + "comment": get_comment(proto_file, [6, index]), "methods": [], } for j, method in enumerate(service.method): @@ -228,7 +241,7 @@ def read_protobuf_service(i, input_package_name, proto_file, service, template_d { "name": method.name, "py_name": pythonize_method_name(method.name), - "comment": get_comment(proto_file, [6, i, 2, j], indent=8), + "comment": get_comment(proto_file, [6, index, 2, j], indent=8), "route": f"/{input_package_name}.{service.name}/{method.name}", "input": get_type_reference( input_package_name, template_data["imports"], method.input_type @@ -254,7 +267,9 @@ def read_protobuf_service(i, input_package_name, proto_file, service, template_d template_data["services"].append(data) -def read_protobuf_type(input_package_name, item, path, proto_file, template_data): +def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, content): + input_package_name = content["input_package"] + template_data = content["template_data"] data = {"name": item.name, "py_name": pythonize_class_name(item.name)} if isinstance(item, DescriptorProto): # print(item, file=sys.stderr) @@ -280,9 +295,7 @@ def read_protobuf_type(input_package_name, item, path, proto_file, template_data field_type = f.Type.Name(f.type).lower()[5:] field_wraps = "" - match_wrapper = re.match( - r"\.google\.protobuf\.(.+)Value", f.type_name - ) + match_wrapper = re.match(r"\.google\.protobuf\.(.+)Value", f.type_name) if match_wrapper: wrapped_type = "TYPE_" + match_wrapper.group(1).upper() if hasattr(betterproto, wrapped_type): @@ -297,10 +310,7 @@ def read_protobuf_type(input_package_name, item, path, proto_file, template_data if message_type == map_entry: for nested in item.nested_type: - if ( - nested.name.replace("_", "").lower() - == map_entry - ): + if nested.name.replace("_", "").lower() == map_entry: if nested.options.map_entry: # print("Found a map!", file=sys.stderr) k = py_type( From 3f519d4fb1216a127ebbc6b228e21b46cc239cfd Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Sun, 5 Jul 2020 17:14:53 +0200 Subject: [PATCH 6/7] Fixes #23 again, a broken test made it seem the issue was fixed before. --- CHANGELOG.md | 3 +- betterproto/plugin.py | 138 +++++++++++------- .../child_package_request_message.proto | 7 + .../import_service_input_message.proto | 8 + .../test_import_service.py | 16 -- .../test_import_service_input_message.py | 31 ++++ 6 files changed, 129 insertions(+), 74 deletions(-) create mode 100644 betterproto/tests/inputs/import_service_input_message/child_package_request_message.proto delete mode 100644 betterproto/tests/inputs/import_service_input_message/test_import_service.py create mode 100644 betterproto/tests/inputs/import_service_input_message/test_import_service_input_message.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 383d3f78..c5c65b4c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,8 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 > `2.0.0` will be released once the interface is stable. - Add support for gRPC and **stream-stream** [#83](https://github.com/danielgtaylor/python-betterproto/pull/83) -- Switch from to `poetry` for development [#75](https://github.com/danielgtaylor/python-betterproto/pull/75) -- Fix No arguments are generated for stub methods when using import with proto definition +- Switch from `pipenv` to `poetry` for development [#75](https://github.com/danielgtaylor/python-betterproto/pull/75) - Fix two packages with the same name suffix should not cause naming conflict [#25](https://github.com/danielgtaylor/python-betterproto/issues/25) - Fix Import child package from root [#57](https://github.com/danielgtaylor/python-betterproto/issues/57) diff --git a/betterproto/plugin.py b/betterproto/plugin.py index a795efae..4ab1b93f 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -11,12 +11,13 @@ from google.protobuf.compiler.plugin_pb2 import CodeGeneratorRequest import betterproto -from betterproto.compile.importing import get_type_reference +from betterproto.compile.importing import get_type_reference, parse_source_type_name from betterproto.compile.naming import ( pythonize_class_name, pythonize_field_name, pythonize_method_name, ) +from betterproto.lib.google.protobuf import ServiceDescriptorProto try: # betterproto[compiler] specific dependencies @@ -76,11 +77,12 @@ def get_py_zero(type_num: int) -> Union[str, float]: return zero -# Todo: Keep information about nested hierarchy def traverse(proto_file): + # Todo: Keep information about nested hierarchy def _traverse(path, items, prefix=""): for i, item in enumerate(items): - # Adjust the name since we flatten the heirarchy. + # Adjust the name since we flatten the hierarchy. + # Todo: don't change the name, but include full name in returned tuple item.name = next_prefix = prefix + item.name yield item, path + [i] @@ -162,17 +164,21 @@ def generate_code(request, response): output_package_content["template_data"] = template_data # Read Messages and Enums + output_types = [] for output_package_name, output_package_content in output_package_files.items(): for proto_file in output_package_content["files"]: for item, path in traverse(proto_file): - read_protobuf_object(item, path, proto_file, output_package_content) + type_data = read_protobuf_type( + item, path, proto_file, output_package_content + ) + output_types.append(type_data) # Read Services for output_package_name, output_package_content in output_package_files.items(): for proto_file in output_package_content["files"]: for index, service in enumerate(proto_file.service): read_protobuf_service( - service, index, proto_file, output_package_content + service, index, proto_file, output_package_content, output_types ) # Render files @@ -214,63 +220,31 @@ def generate_code(request, response): print(f"Writing {output_package_name}", file=sys.stderr) -def read_protobuf_service(service: DescriptorProto, index, proto_file, content): - input_package_name = content["input_package"] - template_data = content["template_data"] - # print(service, file=sys.stderr) - data = { - "name": service.name, - "py_name": pythonize_class_name(service.name), - "comment": get_comment(proto_file, [6, index]), - "methods": [], - } - for j, method in enumerate(service.method): - input_message = None - input_type = get_type_reference( - input_package_name, template_data["imports"], method.input_type - ).strip('"') - for msg in template_data["messages"]: - if msg["name"] == input_type: - input_message = msg - for field in msg["properties"]: - if field["zero"] == "None": - template_data["typing_imports"].add("Optional") - break +def lookup_method_input_type(method, types): + package, name = parse_source_type_name(method.input_type) - data["methods"].append( - { - "name": method.name, - "py_name": pythonize_method_name(method.name), - "comment": get_comment(proto_file, [6, index, 2, j], indent=8), - "route": f"/{input_package_name}.{service.name}/{method.name}", - "input": get_type_reference( - input_package_name, template_data["imports"], method.input_type - ).strip('"'), - "input_message": input_message, - "output": get_type_reference( - input_package_name, - template_data["imports"], - method.output_type, - unwrap=False, - ), - "client_streaming": method.client_streaming, - "server_streaming": method.server_streaming, - } - ) + for known_type in types: + if known_type["type"] != "Message": + continue - if method.client_streaming: - template_data["typing_imports"].add("AsyncIterable") - template_data["typing_imports"].add("Iterable") - template_data["typing_imports"].add("Union") - if method.server_streaming: - template_data["typing_imports"].add("AsyncIterator") - template_data["services"].append(data) + # Nested types are currently flattened without dots. + # Todo: keep a fully quantified name in types, that is comparable with method.input_type + if ( + package == known_type["package"] + and name.replace(".", "") == known_type["name"] + ): + return known_type -def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, content): +def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file, content): input_package_name = content["input_package"] template_data = content["template_data"] - data = {"name": item.name, "py_name": pythonize_class_name(item.name)} + data = { + "name": item.name, + "py_name": pythonize_class_name(item.name), + "descriptor": item, + "package": input_package_name, + } if isinstance(item, DescriptorProto): # print(item, file=sys.stderr) if item.options.map_entry: @@ -373,6 +347,7 @@ def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, con # print(f, file=sys.stderr) template_data["messages"].append(data) + return data elif isinstance(item, EnumDescriptorProto): # print(item.name, path, file=sys.stderr) data.update( @@ -391,6 +366,57 @@ def read_protobuf_object(item: DescriptorProto, path: List[int], proto_file, con ) template_data["enums"].append(data) + return data + + +def read_protobuf_service( + service: ServiceDescriptorProto, index, proto_file, content, output_types +): + input_package_name = content["input_package"] + template_data = content["template_data"] + # print(service, file=sys.stderr) + data = { + "name": service.name, + "py_name": pythonize_class_name(service.name), + "comment": get_comment(proto_file, [6, index]), + "methods": [], + } + for j, method in enumerate(service.method): + method_input_message = lookup_method_input_type(method, output_types) + + if method_input_message: + for field in method_input_message["properties"]: + if field["zero"] == "None": + template_data["typing_imports"].add("Optional") + + data["methods"].append( + { + "name": method.name, + "py_name": pythonize_method_name(method.name), + "comment": get_comment(proto_file, [6, index, 2, j], indent=8), + "route": f"/{input_package_name}.{service.name}/{method.name}", + "input": get_type_reference( + input_package_name, template_data["imports"], method.input_type + ).strip('"'), + "input_message": method_input_message, + "output": get_type_reference( + input_package_name, + template_data["imports"], + method.output_type, + unwrap=False, + ), + "client_streaming": method.client_streaming, + "server_streaming": method.server_streaming, + } + ) + + if method.client_streaming: + template_data["typing_imports"].add("AsyncIterable") + template_data["typing_imports"].add("Iterable") + template_data["typing_imports"].add("Union") + if method.server_streaming: + template_data["typing_imports"].add("AsyncIterator") + template_data["services"].append(data) def main(): diff --git a/betterproto/tests/inputs/import_service_input_message/child_package_request_message.proto b/betterproto/tests/inputs/import_service_input_message/child_package_request_message.proto new file mode 100644 index 00000000..6380db24 --- /dev/null +++ b/betterproto/tests/inputs/import_service_input_message/child_package_request_message.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +package child; + +message ChildRequestMessage { + int32 child_argument = 1; +} \ No newline at end of file diff --git a/betterproto/tests/inputs/import_service_input_message/import_service_input_message.proto b/betterproto/tests/inputs/import_service_input_message/import_service_input_message.proto index a5073db7..7ca9c46f 100644 --- a/betterproto/tests/inputs/import_service_input_message/import_service_input_message.proto +++ b/betterproto/tests/inputs/import_service_input_message/import_service_input_message.proto @@ -1,11 +1,14 @@ syntax = "proto3"; import "request_message.proto"; +import "child_package_request_message.proto"; // Tests generated service correctly imports the RequestMessage service Test { rpc DoThing (RequestMessage) returns (RequestResponse); + rpc DoThing2 (child.ChildRequestMessage) returns (RequestResponse); + rpc DoThing3 (Nested.RequestMessage) returns (RequestResponse); } @@ -13,3 +16,8 @@ message RequestResponse { int32 value = 1; } +message Nested { + message RequestMessage { + int32 nestedArgument = 1; + } +} \ No newline at end of file diff --git a/betterproto/tests/inputs/import_service_input_message/test_import_service.py b/betterproto/tests/inputs/import_service_input_message/test_import_service.py deleted file mode 100644 index 891b77ab..00000000 --- a/betterproto/tests/inputs/import_service_input_message/test_import_service.py +++ /dev/null @@ -1,16 +0,0 @@ -import pytest - -from betterproto.tests.mocks import MockChannel -from betterproto.tests.output_betterproto.import_service_input_message import ( - RequestResponse, - TestStub, -) - - -@pytest.mark.xfail(reason="#68 Request Input Messages are not imported for service") -@pytest.mark.asyncio -async def test_service_correctly_imports_reference_message(): - mock_response = RequestResponse(value=10) - service = TestStub(MockChannel([mock_response])) - response = await service.do_thing() - assert mock_response == response diff --git a/betterproto/tests/inputs/import_service_input_message/test_import_service_input_message.py b/betterproto/tests/inputs/import_service_input_message/test_import_service_input_message.py new file mode 100644 index 00000000..e53fc485 --- /dev/null +++ b/betterproto/tests/inputs/import_service_input_message/test_import_service_input_message.py @@ -0,0 +1,31 @@ +import pytest + +from betterproto.tests.mocks import MockChannel +from betterproto.tests.output_betterproto.import_service_input_message import ( + RequestResponse, + TestStub, +) + + +@pytest.mark.asyncio +async def test_service_correctly_imports_reference_message(): + mock_response = RequestResponse(value=10) + service = TestStub(MockChannel([mock_response])) + response = await service.do_thing(argument=1) + assert mock_response == response + + +@pytest.mark.asyncio +async def test_service_correctly_imports_reference_message_from_child_package(): + mock_response = RequestResponse(value=10) + service = TestStub(MockChannel([mock_response])) + response = await service.do_thing2(child_argument=1) + assert mock_response == response + + +@pytest.mark.asyncio +async def test_service_correctly_imports_nested_reference(): + mock_response = RequestResponse(value=10) + service = TestStub(MockChannel([mock_response])) + response = await service.do_thing3(nested_argument=1) + assert mock_response == response From 1d7ba850e91265ab3e4fcfb9accad4bebff24305 Mon Sep 17 00:00:00 2001 From: boukeversteegh Date: Thu, 9 Jul 2020 23:09:34 +0200 Subject: [PATCH 7/7] Reorder methods, use BETTERPROTO_DUMP for dump env var, docs. --- betterproto/plugin.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/betterproto/plugin.py b/betterproto/plugin.py index 8297a160..0d88d477 100755 --- a/betterproto/plugin.py +++ b/betterproto/plugin.py @@ -220,22 +220,6 @@ def generate_code(request, response): print(f"Writing {output_package_name}", file=sys.stderr) -def lookup_method_input_type(method, types): - package, name = parse_source_type_name(method.input_type) - - for known_type in types: - if known_type["type"] != "Message": - continue - - # Nested types are currently flattened without dots. - # Todo: keep a fully quantified name in types, that is comparable with method.input_type - if ( - package == known_type["package"] - and name.replace(".", "") == known_type["name"] - ): - return known_type - - def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file, content): input_package_name = content["input_package"] template_data = content["template_data"] @@ -369,6 +353,22 @@ def read_protobuf_type(item: DescriptorProto, path: List[int], proto_file, conte return data +def lookup_method_input_type(method, types): + package, name = parse_source_type_name(method.input_type) + + for known_type in types: + if known_type["type"] != "Message": + continue + + # Nested types are currently flattened without dots. + # Todo: keep a fully quantified name in types, that is comparable with method.input_type + if ( + package == known_type["package"] + and name.replace(".", "") == known_type["name"] + ): + return known_type + + def read_protobuf_service( service: ServiceDescriptorProto, index, proto_file, content, output_types ): @@ -428,7 +428,7 @@ def main(): request = plugin.CodeGeneratorRequest() request.ParseFromString(data) - dump_file = os.getenv("DUMP_FILE") + dump_file = os.getenv("BETTERPROTO_DUMP") if dump_file: dump_request(dump_file, request) @@ -448,11 +448,11 @@ def main(): def dump_request(dump_file: str, request: CodeGeneratorRequest): """ For developers: Supports running plugin.py standalone so its possible to debug it. - Run protoc (or generate.py) with DUMP_FILE="yourfile.bin" to write the request to a file. + Run protoc (or generate.py) with BETTERPROTO_DUMP="yourfile.bin" to write the request to a file. Then run plugin.py from your IDE in debugging mode, and redirect stdin to the file. """ with open(str(dump_file), "wb") as fh: - sys.stderr.write(f"\033[31mWriting: {dump_file}\033[0m\n") + sys.stderr.write(f"\033[31mWriting input from protoc to: {dump_file}\033[0m\n") fh.write(request.SerializeToString())