Skip to content

Commit

Permalink
Handle typing collisions and add validation to a files module for ove…
Browse files Browse the repository at this point in the history
…rlaping declarations (#582)

* Fix 'typing' import collisions.

* Fix formatting.

* Fix self-test issues.

* Validation for modules, different typing configurations

* add readme

* make warning

* fix format

---------

Co-authored-by: Scott Hendricks <scott.hendricks@confluent.io>
  • Loading branch information
imcdo and scott-hendricks authored Jul 19, 2024
1 parent 7c6c627 commit 8b59234
Show file tree
Hide file tree
Showing 13 changed files with 887 additions and 157 deletions.
43 changes: 43 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,50 @@ swap the dataclass implementation from the builtin python dataclass to the
pydantic dataclass. You must have pydantic as a dependency in your project for
this to work.

## Configuration typing imports

By default typing types will be imported directly from typing. This sometimes can lead to issues in generation if types that are being generated conflict with the name. In this case you can configure the way types are imported from 3 different options:

### Direct
```
protoc -I . --python_betterproto_opt=typing.direct --python_betterproto_out=lib example.proto
```
this configuration is the default, and will import types as follows:
```
from typing import (
List,
Optional,
Union
)
...
value: List[str] = []
value2: Optional[str] = None
value3: Union[str, int] = 1
```
### Root
```
protoc -I . --python_betterproto_opt=typing.root --python_betterproto_out=lib example.proto
```
this configuration loads the root typing module, and then access the types off of it directly:
```
import typing
...
value: typing.List[str] = []
value2: typing.Optional[str] = None
value3: typing.Union[str, int] = 1
```

### 310
```
protoc -I . --python_betterproto_opt=typing.310 --python_betterproto_out=lib example.proto
```
this configuration avoid loading typing all together if possible and uses the python 3.10 pattern:
```
...
value: list[str] = []
value2: str | None = None
value3: str | int = 1
```

## Development

Expand Down
3 changes: 2 additions & 1 deletion src/betterproto/compile/importing.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def get_type_reference(
package: str,
imports: set,
source_type: str,
typing_compiler: "TypingCompiler",
unwrap: bool = True,
pydantic: bool = False,
) -> str:
Expand All @@ -57,7 +58,7 @@ def get_type_reference(
if unwrap:
if source_type in WRAPPER_TYPES:
wrapped_type = type(WRAPPER_TYPES[source_type]().value)
return f"Optional[{wrapped_type.__name__}]"
return typing_compiler.optional(wrapped_type.__name__)

if source_type == ".google.protobuf.Duration":
return "timedelta"
Expand Down
23 changes: 20 additions & 3 deletions src/betterproto/plugin/compiler.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import os.path
import sys

from .module_validation import ModuleValidator


try:
Expand Down Expand Up @@ -30,9 +33,12 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
lstrip_blocks=True,
loader=jinja2.FileSystemLoader(templates_folder),
)
template = env.get_template("template.py.j2")
# Load the body first so we have a compleate list of imports needed.
body_template = env.get_template("template.py.j2")
header_template = env.get_template("header.py.j2")

code = template.render(output_file=output_file)
code = body_template.render(output_file=output_file)
code = header_template.render(output_file=output_file) + code
code = isort.api.sort_code_string(
code=code,
show_diff=False,
Expand All @@ -44,7 +50,18 @@ def outputfile_compiler(output_file: OutputTemplate) -> str:
force_grid_wrap=2,
known_third_party=["grpclib", "betterproto"],
)
return black.format_str(
code = black.format_str(
src_contents=code,
mode=black.Mode(),
)

# Validate the generated code.
validator = ModuleValidator(iter(code.splitlines()))
if not validator.validate():
message_builder = ["[WARNING]: Generated code has collisions in the module:"]
for collision, lines in validator.collisions.items():
message_builder.append(f' "{collision}" on lines:')
for num, line in lines:
message_builder.append(f" {num}:{line}")
print("\n".join(message_builder), file=sys.stderr)
return code
64 changes: 20 additions & 44 deletions src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@
reference to `A` to `B`'s `fields` attribute.
"""


import builtins
import re
import textwrap
from dataclasses import (
dataclass,
field,
Expand All @@ -49,12 +47,6 @@
)

import betterproto
from betterproto import which_one_of
from betterproto.casing import sanitize_name
from betterproto.compile.importing import (
get_type_reference,
parse_source_type_name,
)
from betterproto.compile.naming import (
pythonize_class_name,
pythonize_field_name,
Expand All @@ -72,6 +64,7 @@
)
from betterproto.lib.google.protobuf.compiler import CodeGeneratorRequest

from .. import which_one_of
from ..compile.importing import (
get_type_reference,
parse_source_type_name,
Expand All @@ -82,6 +75,10 @@
pythonize_field_name,
pythonize_method_name,
)
from .typing_compiler import (
DirectImportTypingCompiler,
TypingCompiler,
)


# Create a unique placeholder to deal with
Expand Down Expand Up @@ -173,6 +170,7 @@ class ProtoContentBase:
"""Methods common to MessageCompiler, ServiceCompiler and ServiceMethodCompiler."""

source_file: FileDescriptorProto
typing_compiler: TypingCompiler
path: List[int]
comment_indent: int = 4
parent: Union["betterproto.Message", "OutputTemplate"]
Expand Down Expand Up @@ -242,7 +240,6 @@ class OutputTemplate:
input_files: List[str] = field(default_factory=list)
imports: Set[str] = field(default_factory=set)
datetime_imports: Set[str] = field(default_factory=set)
typing_imports: Set[str] = field(default_factory=set)
pydantic_imports: Set[str] = field(default_factory=set)
builtins_import: bool = False
messages: List["MessageCompiler"] = field(default_factory=list)
Expand All @@ -251,6 +248,7 @@ class OutputTemplate:
imports_type_checking_only: Set[str] = field(default_factory=set)
pydantic_dataclasses: bool = False
output: bool = True
typing_compiler: TypingCompiler = field(default_factory=DirectImportTypingCompiler)

@property
def package(self) -> str:
Expand Down Expand Up @@ -289,6 +287,7 @@ class MessageCompiler(ProtoContentBase):
"""Representation of a protobuf message."""

source_file: FileDescriptorProto
typing_compiler: TypingCompiler
parent: Union["MessageCompiler", OutputTemplate] = PLACEHOLDER
proto_obj: DescriptorProto = PLACEHOLDER
path: List[int] = PLACEHOLDER
Expand Down Expand Up @@ -319,7 +318,7 @@ def py_name(self) -> str:
@property
def annotation(self) -> str:
if self.repeated:
return f"List[{self.py_name}]"
return self.typing_compiler.list(self.py_name)
return self.py_name

@property
Expand Down Expand Up @@ -434,18 +433,6 @@ def datetime_imports(self) -> Set[str]:
imports.add("datetime")
return imports

@property
def typing_imports(self) -> Set[str]:
imports = set()
annotation = self.annotation
if "Optional[" in annotation:
imports.add("Optional")
if "List[" in annotation:
imports.add("List")
if "Dict[" in annotation:
imports.add("Dict")
return imports

@property
def pydantic_imports(self) -> Set[str]:
return set()
Expand All @@ -458,7 +445,6 @@ def use_builtins(self) -> bool:

def add_imports_to(self, output_file: OutputTemplate) -> None:
output_file.datetime_imports.update(self.datetime_imports)
output_file.typing_imports.update(self.typing_imports)
output_file.pydantic_imports.update(self.pydantic_imports)
output_file.builtins_import = output_file.builtins_import or self.use_builtins

Expand Down Expand Up @@ -488,7 +474,9 @@ def optional(self) -> bool:
@property
def mutable(self) -> bool:
"""True if the field is a mutable type, otherwise False."""
return self.annotation.startswith(("List[", "Dict["))
return self.annotation.startswith(
("typing.List[", "typing.Dict[", "dict[", "list[", "Dict[", "List[")
)

@property
def field_type(self) -> str:
Expand Down Expand Up @@ -562,6 +550,7 @@ def py_type(self) -> str:
package=self.output_file.package,
imports=self.output_file.imports,
source_type=self.proto_obj.type_name,
typing_compiler=self.typing_compiler,
pydantic=self.output_file.pydantic_dataclasses,
)
else:
Expand All @@ -573,9 +562,9 @@ def annotation(self) -> str:
if self.use_builtins:
py_type = f"builtins.{py_type}"
if self.repeated:
return f"List[{py_type}]"
return self.typing_compiler.list(py_type)
if self.optional:
return f"Optional[{py_type}]"
return self.typing_compiler.optional(py_type)
return py_type


Expand Down Expand Up @@ -623,11 +612,13 @@ def __post_init__(self) -> None:
source_file=self.source_file,
parent=self,
proto_obj=nested.field[0], # key
typing_compiler=self.typing_compiler,
).py_type
self.py_v_type = FieldCompiler(
source_file=self.source_file,
parent=self,
proto_obj=nested.field[1], # value
typing_compiler=self.typing_compiler,
).py_type

# Get proto types
Expand All @@ -645,7 +636,7 @@ def field_type(self) -> str:

@property
def annotation(self) -> str:
return f"Dict[{self.py_k_type}, {self.py_v_type}]"
return self.typing_compiler.dict(self.py_k_type, self.py_v_type)

@property
def repeated(self) -> bool:
Expand Down Expand Up @@ -702,7 +693,6 @@ class ServiceCompiler(ProtoContentBase):
def __post_init__(self) -> None:
# Add service to output file
self.output_file.services.append(self)
self.output_file.typing_imports.add("Dict")
super().__post_init__() # check for unset fields

@property
Expand All @@ -725,22 +715,6 @@ def __post_init__(self) -> None:
# Add method to service
self.parent.methods.append(self)

# Check for imports
if "Optional" in self.py_output_message_type:
self.output_file.typing_imports.add("Optional")

# Check for Async imports
if self.client_streaming:
self.output_file.typing_imports.add("AsyncIterable")
self.output_file.typing_imports.add("Iterable")
self.output_file.typing_imports.add("Union")

# Required by both client and server
if self.client_streaming or self.server_streaming:
self.output_file.typing_imports.add("AsyncIterator")

# add imports required for request arguments timeout, deadline and metadata
self.output_file.typing_imports.add("Optional")
self.output_file.imports_type_checking_only.add("import grpclib.server")
self.output_file.imports_type_checking_only.add(
"from betterproto.grpc.grpclib_client import MetadataLike"
Expand Down Expand Up @@ -806,6 +780,7 @@ def py_input_message_type(self) -> str:
package=self.output_file.package,
imports=self.output_file.imports,
source_type=self.proto_obj.input_type,
typing_compiler=self.output_file.typing_compiler,
unwrap=False,
pydantic=self.output_file.pydantic_dataclasses,
).strip('"')
Expand Down Expand Up @@ -835,6 +810,7 @@ def py_output_message_type(self) -> str:
package=self.output_file.package,
imports=self.output_file.imports,
source_type=self.proto_obj.output_type,
typing_compiler=self.output_file.typing_compiler,
unwrap=False,
pydantic=self.output_file.pydantic_dataclasses,
).strip('"')
Expand Down
Loading

0 comments on commit 8b59234

Please sign in to comment.