Skip to content

Commit

Permalink
Fix compilation of fields with name identical to their type (#294)
Browse files Browse the repository at this point in the history
* Revert "Fix compilation of fields named 'bytes' or 'str' (#226)"

This reverts commit deb623e.

* Fix compilation of fileds with name identical to their type

* Added test for field-name identical to python type

Co-authored-by: Guy Szweigman <guysz@nvidia.com>
  • Loading branch information
guysz and guysz authored Dec 1, 2021
1 parent a4d2d39 commit b0a36d1
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 13 deletions.
12 changes: 1 addition & 11 deletions src/betterproto/casing.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,6 @@ def lowercase_first(value: str) -> str:
return value[0:1].lower() + value[1:]


def is_reserved_name(value: str) -> bool:
if keyword.iskeyword(value):
return True

if value in ("bytes", "str"):
return True

return False


def sanitize_name(value: str) -> str:
# https://www.python.org/dev/peps/pep-0008/#descriptive-naming-styles
return f"{value}_" if is_reserved_name(value) else value
return f"{value}_" if keyword.iskeyword(value) else value
21 changes: 19 additions & 2 deletions src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"""


import builtins
import betterproto
from betterproto import which_one_of
from betterproto.casing import sanitize_name
Expand Down Expand Up @@ -237,6 +238,7 @@ class OutputTemplate:
imports: Set[str] = field(default_factory=set)
datetime_imports: Set[str] = field(default_factory=set)
typing_imports: Set[str] = field(default_factory=set)
builtins_import: bool = False
messages: List["MessageCompiler"] = field(default_factory=list)
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
services: List["ServiceCompiler"] = field(default_factory=list)
Expand Down Expand Up @@ -268,6 +270,8 @@ def python_module_imports(self) -> Set[str]:
imports = set()
if any(x for x in self.messages if any(x.deprecated_fields)):
imports.add("warnings")
if self.builtins_import:
imports.add("builtins")
return imports


Expand All @@ -283,6 +287,7 @@ class MessageCompiler(ProtoContentBase):
default_factory=list
)
deprecated: bool = field(default=False, init=False)
builtins_types: Set[str] = field(default_factory=set)

def __post_init__(self) -> None:
# Add message to output file
Expand Down Expand Up @@ -376,6 +381,8 @@ def get_field_string(self, indent: int = 4) -> str:
betterproto_field_type = (
f"betterproto.{self.field_type}_field({self.proto_obj.number}{field_args})"
)
if self.py_name in dir(builtins):
self.parent.builtins_types.add(self.py_name)
return f"{name}{annotations} = {betterproto_field_type}"

@property
Expand Down Expand Up @@ -408,9 +415,16 @@ def typing_imports(self) -> Set[str]:
imports.add("Dict")
return imports

@property
def use_builtins(self) -> bool:
return self.py_type in self.parent.builtins_types or (
self.py_type == self.py_name and self.py_name in dir(builtins)
)

def add_imports_to(self, output_file: OutputTemplate) -> None:
output_file.datetime_imports.update(self.datetime_imports)
output_file.typing_imports.update(self.typing_imports)
output_file.builtins_import = output_file.builtins_import or self.use_builtins

@property
def field_wraps(self) -> Optional[str]:
Expand Down Expand Up @@ -504,9 +518,12 @@ def py_type(self) -> str:

@property
def annotation(self) -> str:
py_type = self.py_type
if self.use_builtins:
py_type = f"builtins.{py_type}"
if self.repeated:
return f"List[{self.py_type}]"
return self.py_type
return f"List[{py_type}]"
return py_type


@dataclass
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"int": 26,
"float": 26.0,
"str": "value-for-str",
"bytes": "001a",
"bool": true
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
syntax = "proto3";

// Tests that messages may contain fields with names that are identical to their python types (PR #294)

message Test {
int32 int = 1;
float float = 2;
string str = 3;
bytes bytes = 4;
bool bool = 5;
}

0 comments on commit b0a36d1

Please sign in to comment.