Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for pydantic dataclasses #406

Merged
merged 30 commits into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
9a35f98
Pull down the `include_default_values` argument to `to_json`
SamuelYvon Aug 3, 2022
8127732
Also add the `casing` arg.
SamuelYvon Aug 3, 2022
f58ba68
Add tests
SamuelYvon Aug 4, 2022
a2fc143
Merge branch 'master' of github.com:danielgtaylor/python-betterproto
SamuelYvon Aug 8, 2022
f21a792
Add support for pydantic dataclasses
SamuelYvon Aug 4, 2022
ff9bba1
Use the TYPE_CHECKING trick so mypy is happy.
SamuelYvon Aug 5, 2022
f751354
Add forward refs update after generating pydantic dataclasses
SamuelYvon Aug 5, 2022
6e643b7
Adding support for oneofs w/pydantic
SamuelYvon Aug 8, 2022
fb58f1b
These will never properly resolve
SamuelYvon Aug 8, 2022
d1cde0f
Only updating forward-refs on fields w/forward refs
SamuelYvon Aug 8, 2022
5ceae21
Generate code through the pydantic generator
SamuelYvon Aug 8, 2022
26dbed7
Use the plugin way?
SamuelYvon Aug 8, 2022
d0b807d
Adding tests
SamuelYvon Aug 8, 2022
74a0104
Those are my tests; not important
SamuelYvon Aug 8, 2022
b6b1d93
Fix sort
SamuelYvon Aug 8, 2022
a38f9b1
Required!
SamuelYvon Aug 8, 2022
56cfff0
Older version
SamuelYvon Aug 8, 2022
ad841d3
If there's a single field, it will be optional, so it's OK
SamuelYvon Aug 10, 2022
c36b65d
Windows unit tests
SamuelYvon Oct 18, 2022
53998aa
Smaller windows section
SamuelYvon Oct 18, 2022
d70b0c7
Some type fixing
SamuelYvon Oct 18, 2022
6007e56
Applying requested changes
SamuelYvon Feb 13, 2023
02cdc86
Merge remote-tracking branch 'upstream/master' into opt_pydantic_dc
SamuelYvon Feb 13, 2023
3527b98
Add missing parameter
SamuelYvon Feb 13, 2023
49c3f92
Adding instructions
SamuelYvon Feb 13, 2023
87f7b6d
Update README.md
SamuelYvon Feb 13, 2023
1a6f1e1
Update src/betterproto/plugin/models.py
SamuelYvon Feb 13, 2023
14bb373
Update src/betterproto/templates/template.py.j2
SamuelYvon Feb 13, 2023
30e8c38
Fix models
SamuelYvon Feb 13, 2023
f65ab53
Remove TYPE_CHECKING verification for dataclass
SamuelYvon Feb 13, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ This project aims to provide an improved experience when using Protobuf / gRPC i
- Timezone-aware `datetime` and `timedelta` objects
- Relative imports
- Mypy type checking
- [Pydantic Models](https://docs.pydantic.dev/) generation (see #generating-pydantic-models)

This project is heavily inspired by, and borrows functionality from:

Expand Down Expand Up @@ -364,6 +365,25 @@ datetime.datetime(2019, 1, 1, 11, 59, 58, 800000, tzinfo=datetime.timezone.utc)
{'ts': '2019-01-01T12:00:00Z', 'duration': '1.200s'}
```

## Generating Pydantic Models

You can use python-betterproto to generate pydantic based models, using
pydantic dataclasses. This means the results of the protobuf unmarshalling will
be typed checked. The usage is the same, but you need to add a custom option
when calling the protobuf compiler:


```
protoc -I . --custom_opt=pydantic_dataclasses --python_betterproto_out=lib example.proto
```

With the important change being `--custom_opt=pydantic_dataclasses`. This will
swap the dataclass implementation from the builtin python dataclass to the
pydantic dataclass. You must have pydantic as a dependency in your project for
this to work.



## Development

- _Join us on [Slack](https://join.slack.com/t/betterproto/shared_invite/zt-f0n0uolx-iN8gBNrkPxtKHTLpG3o1OQ)!_
Expand Down
62 changes: 61 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ sphinx-rtd-theme = "0.5.0"
tomlkit = "^0.7.0"
tox = "^3.15.1"
pre-commit = "^2.17.0"
pydantic = ">=1.8.0"


[tool.poetry.scripts]
Expand Down
19 changes: 18 additions & 1 deletion src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,6 @@ def __post_init__(self) -> None:
# Set current field of each group after `__init__` has already been run.
group_current: Dict[str, Optional[str]] = {}
for field_name, meta in self._betterproto.meta_by_field_name.items():

if meta.group:
group_current.setdefault(meta.group)

Expand Down Expand Up @@ -1470,6 +1469,24 @@ def is_set(self, name: str) -> bool:
)
return self.__raw_get(name) is not default

@classmethod
def _validate_field_groups(cls, values):
meta = cls._betterproto_meta.oneof_field_by_group # type: ignore

for group, field_set in meta.items():
set_fields = [
field.name for field in field_set if values[field.name] is not None
]
if not set_fields:
raise ValueError(f"Group {group} has no value; all fields are None")
elif len(set_fields) > 1:
set_fields_str = ", ".join(set_fields)
raise ValueError(
f"Group {group} has more than one value; fields {set_fields_str} are not None"
)

return values


def serialized_on_wire(message: Message) -> bool:
"""
Expand Down
37 changes: 35 additions & 2 deletions src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ def comment(self) -> str:

@dataclass
class PluginRequestCompiler:

plugin_request_obj: CodeGeneratorRequest
output_packages: Dict[str, "OutputTemplate"] = field(default_factory=dict)

Expand Down Expand Up @@ -247,11 +246,13 @@ class OutputTemplate:
imports: Set[str] = field(default_factory=set)
datetime_imports: Set[str] = field(default_factory=set)
typing_imports: Set[str] = field(default_factory=set)
pydantic_imports: Set[str] = field(default_factory=set)
builtins_import: bool = False
messages: List["MessageCompiler"] = field(default_factory=list)
enums: List["EnumDefinitionCompiler"] = field(default_factory=list)
services: List["ServiceCompiler"] = field(default_factory=list)
imports_type_checking_only: Set[str] = field(default_factory=set)
pydantic_dataclasses: bool = False
output: bool = True

@property
Expand Down Expand Up @@ -334,6 +335,20 @@ def deprecated_fields(self) -> Iterator[str]:
def has_deprecated_fields(self) -> bool:
return any(self.deprecated_fields)

@property
def has_oneof_fields(self) -> bool:
return any(isinstance(field, OneOfFieldCompiler) for field in self.fields)

@property
def has_message_field(self) -> bool:
return any(
(
field.proto_obj.type in PROTO_MESSAGE_TYPES
for field in self.fields
if isinstance(field.proto_obj, FieldDescriptorProto)
)
)


def is_map(
proto_field_obj: FieldDescriptorProto, parent_message: DescriptorProto
Expand Down Expand Up @@ -431,6 +446,10 @@ def typing_imports(self) -> Set[str]:
imports.add("Dict")
return imports

@property
def pydantic_imports(self) -> Set[str]:
return set()

@property
def use_builtins(self) -> bool:
return self.py_type in self.parent.builtins_types or (
Expand All @@ -440,6 +459,7 @@ def use_builtins(self) -> bool:
def add_imports_to(self, output_file: OutputTemplate) -> None:
output_file.datetime_imports.update(self.datetime_imports)
output_file.typing_imports.update(self.typing_imports)
output_file.pydantic_imports.update(self.pydantic_imports)
output_file.builtins_import = output_file.builtins_import or self.use_builtins

@property
Expand Down Expand Up @@ -568,6 +588,20 @@ def betterproto_field_args(self) -> List[str]:
return args


@dataclass
class PydanticOneOfFieldCompiler(OneOfFieldCompiler):
@property
def optional(self) -> bool:
# Force the optional to be True. This will allow the pydantic dataclass
# to validate the object correctly by allowing the field to be let empty.
# We add a pydantic validator later to ensure exactly one field is defined.
return True

@property
def pydantic_imports(self) -> Set[str]:
return {"root_validator"}


@dataclass
class MapEntryCompiler(FieldCompiler):
py_k_type: Type = PLACEHOLDER
Expand Down Expand Up @@ -679,7 +713,6 @@ def py_name(self) -> str:

@dataclass
class ServiceMethodCompiler(ProtoContentBase):

parent: ServiceCompiler
proto_obj: MethodDescriptorProto
path: List[int] = PLACEHOLDER
Expand Down
32 changes: 27 additions & 5 deletions src/betterproto/plugin/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from betterproto.lib.google.protobuf import (
DescriptorProto,
EnumDescriptorProto,
FieldDescriptorProto,
FileDescriptorProto,
ServiceDescriptorProto,
)
Expand All @@ -30,6 +31,7 @@
OneOfFieldCompiler,
OutputTemplate,
PluginRequestCompiler,
PydanticOneOfFieldCompiler,
ServiceCompiler,
ServiceMethodCompiler,
is_map,
Expand Down Expand Up @@ -91,6 +93,11 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
# skip outputting Google's well-known types
request_data.output_packages[output_package_name].output = False

if "pydantic_dataclasses" in plugin_options:
request_data.output_packages[
output_package_name
].pydantic_dataclasses = True

# Read Messages and Enums
# We need to read Messages before Services in so that we can
# get the references to input/output messages for each service
Expand Down Expand Up @@ -145,6 +152,24 @@ def generate_code(request: CodeGeneratorRequest) -> CodeGeneratorResponse:
return response


def _make_one_of_field_compiler(
output_package: OutputTemplate,
source_file: "FileDescriptorProto",
parent: MessageCompiler,
proto_obj: "FieldDescriptorProto",
path: List[int],
) -> FieldCompiler:
SamuelYvon marked this conversation as resolved.
Show resolved Hide resolved

pydantic = output_package.pydantic_dataclasses
Cls = PydanticOneOfFieldCompiler if pydantic else OneOfFieldCompiler
return Cls(
source_file=source_file,
parent=parent,
proto_obj=proto_obj,
path=path,
)


def read_protobuf_type(
item: DescriptorProto,
path: List[int],
Expand All @@ -168,11 +193,8 @@ def read_protobuf_type(
path=path + [2, index],
)
elif is_oneof(field):
OneOfFieldCompiler(
source_file=source_file,
parent=message_data,
proto_obj=field,
path=path + [2, index],
_make_one_of_field_compiler(
output_package, source_file, message_data, field, path + [2, index]
)
else:
FieldCompiler(
Expand Down
24 changes: 24 additions & 0 deletions src/betterproto/templates/template.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,13 @@
{% for i in output_file.python_module_imports|sort %}
import {{ i }}
{% endfor %}

{% if output_file.pydantic_dataclasses %}
from pydantic.dataclasses import dataclass
{%- else -%}
from dataclasses import dataclass
{% endif %}

{% if output_file.datetime_imports %}
from datetime import {% for i in output_file.datetime_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}

Expand All @@ -15,6 +21,11 @@ from typing import {% for i in output_file.typing_imports|sort %}{{ i }}{% if no

{% endif %}

{% if output_file.pydantic_imports %}
from pydantic import {% for i in output_file.pydantic_imports|sort %}{{ i }}{% if not loop.last %}, {% endif %}{% endfor %}

{% endif %}

import betterproto
{% if output_file.services %}
from betterproto.grpc.grpclib_server import ServiceBase
Expand Down Expand Up @@ -80,6 +91,11 @@ class {{ message.py_name }}(betterproto.Message):
{% endfor %}
{% endif %}

{% if output_file.pydantic_dataclasses and message.has_oneof_fields %}
@root_validator()
def check_oneof(cls, values):
return cls._validate_field_groups(values)
{% endif %}

{% endfor %}
{% for service in output_file.services %}
Expand Down Expand Up @@ -226,3 +242,11 @@ class {{ service.py_name }}Base(ServiceBase):
}

{% endfor %}

{% if output_file.pydantic_dataclasses %}
{% for message in output_file.messages %}
{% if message.has_message_field %}
{{ message.py_name }}.__pydantic_model__.update_forward_refs() # type: ignore
{% endif %}
{% endfor %}
{% endif %}
Loading