Skip to content

Commit

Permalink
Support default_factory for structs
Browse files Browse the repository at this point in the history
  • Loading branch information
unmade committed Dec 17, 2024
1 parent 8afcd96 commit 6b98004
Show file tree
Hide file tree
Showing 16 changed files with 127 additions and 28 deletions.
4 changes: 4 additions & 0 deletions example/app/interfaces/dates.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ class DateTime:

@dataclass
class Date: ...

EPOCH: DateTime = DateTime(
year=1970, month=1, day=1, hour=0, minute=0, second=0, microsecond=0
)
8 changes: 7 additions & 1 deletion example/app/interfaces/shared.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ from dataclasses import dataclass
from typing import *

class NotFound(Exception):
def __init__(self, message: Optional[str] = "Not Found"): ...
message: Optional[str] = "Not Found"

def __init__(self, message: Optional[str] = "Not Found") -> None: ...

class EmptyException(Exception): ...

Expand All @@ -11,5 +13,9 @@ class LimitOffset:
limit: Optional[int] = None
offset: Optional[int] = None

INT_CONST_1: int = 1234
MAP_CONST: Dict[str, str] = {"hello": "world", "goodnight": "moon"}
INT_CONST_2: int = 1234

class Service:
def ping(self) -> str: ...
17 changes: 12 additions & 5 deletions example/app/interfaces/todo.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import IntEnum
from typing import *
from . import dates
Expand All @@ -13,19 +13,26 @@ class TodoType(IntEnum):
class TodoItem:
id: int
text: str
type: int
type: TodoType
created: dates.DateTime
is_deleted: bool
picture: Optional[str] = None
picture: Optional[bytes] = None
is_favorite: bool = False

@dataclass
class TodoCounter:
todos: Dict[int, TodoItem] = field(default_factory=dict)
plain_ids: Set[int] = field(default_factory=lambda: {1, 2, 3})
note_ids: List[int] = field(default_factory=list)
checkboxes_ids: Set[int] = field(default_factory=set)

class Todo:
def create(self, text: str, type: int) -> int: ...
def create(self, text: str, type: TodoType) -> int: ...
def update(self, item: TodoItem) -> None: ...
def get(self, id: int) -> TodoItem: ...
def all(self, pager: shared.LimitOffset) -> List[TodoItem]: ...
def filter(self, ids: List[int]) -> List[TodoItem]: ...
def stats(self) -> Dict[int, float]: ...
def types(self) -> Set[int]: ...
def groupby(self) -> Dict[int, List[TodoItem]]: ...
def groupby(self) -> Dict[TodoType, List[TodoItem]]: ...
def ping(self) -> str: ...
4 changes: 2 additions & 2 deletions example/app/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from thriftpy2.rpc import make_server

if TYPE_CHECKING:
from example.app.interfaces.todo import TodoItem
from example.app.interfaces.todo import TodoItem, TodoType

todos: Dict[int, "TodoItem"] = {}


class Dispatcher:
def create(self, text: str, type: int) -> int:
def create(self, text: str, type: TodoType) -> int:
todo_id = max(todos.keys() or [0]) + 1
created = datetime.datetime.now()
todos[todo_id] = interfaces.todo.TodoItem(
Expand Down
9 changes: 9 additions & 0 deletions example/interfaces/todo.thrift
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,18 @@ struct TodoItem {
7: required bool is_favorite = false
}


typedef list<TodoItem> TodoItemList


struct TodoCounter {
1: required map<i32, TodoItem> todos = {}
2: required set<i32> plain_ids = [1, 2, 3]
3: required list<i32> note_ids = []
4: required set<i32> checkboxes_ids = []
}


service Todo extends shared.Service {
i32 create(
1: string text,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "thrift-pyi"
version = "1.0.0"
version = "1.1.0"
description = "This is simple `.pyi` stubs generator from thrift interfaces"
readme = "README.rst"
repository = "https://github.com/unmade/thrift-pyi"
Expand Down
41 changes: 38 additions & 3 deletions src/thriftpyi/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@

import ast
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Sequence, Type, Union
from typing import (
TYPE_CHECKING,
MutableMapping,
MutableSequence,
MutableSet,
Sequence,
Type,
Union,
)

if TYPE_CHECKING:
AnyFunctionDef = Union[ast.AsyncFunctionDef, ast.FunctionDef]
Expand Down Expand Up @@ -161,7 +169,7 @@ def as_ast(self) -> ast.AnnAssign | ast.Assign:
if self.type is None:
return ast.Assign(
targets=[ast.Name(id=self.name, ctx=ast.Store())],
value=ast.Constant(value=self.value, kind=None),
value=self._make_ast_value(),
lineno=0,
)

Expand All @@ -173,11 +181,38 @@ def as_ast(self) -> ast.AnnAssign | ast.Assign:
if self.required and self.value is None:
value = None
else:
value = ast.Constant(value=self.value, kind=None)
value = self._make_ast_value()

return ast.AnnAssign(
target=ast.Name(id=self.name, ctx=ast.Store()),
annotation=annotation,
value=value,
simple=1,
)

def _make_ast_value(self) -> ast.expr:
return ast.Constant(value=self.value, kind=None)


@dataclass
class StructField(Field):
def _make_ast_value(self) -> ast.expr:
if isinstance(self.value, (MutableSequence, MutableSet, MutableMapping)):
if self.value:
value = ast.Lambda(
args=[], body=ast.Constant(value=self.value, kind=None)
)
else:
value = ast.Name(id=self.value.__class__.__name__, ctx=ast.Load())

return ast.Call(
func=ast.Name(id="field", ctx=ast.Load()),
args=[],
keywords=[
ast.keyword(
arg="default_factory",
value=value,
)
],
)
return super()._make_ast_value()
17 changes: 15 additions & 2 deletions src/thriftpyi/proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import cast

from thriftpyi.entities import Field, FieldValue, Method, ModuleItem
from thriftpyi.entities import Field, FieldValue, Method, ModuleItem, StructField
from thriftpyi.utils import get_python_type, guess_type


Expand Down Expand Up @@ -127,7 +127,7 @@ def _make_service(tservice) -> ModuleItem:

@staticmethod
def _make_struct(tclass) -> ModuleItem:
spec = TSpecProxy(
spec = TStructSpecProxy(
module_name=tclass.__module__,
thrift_spec=tclass.thrift_spec,
default_spec=dict(tclass.default_spec),
Expand Down Expand Up @@ -188,3 +188,16 @@ def _get_python_type(self, item: TSpecItemProxy) -> str:
def _get_default_value(self, item: TSpecItemProxy) -> FieldValue:
default_value = self.default_spec.get(item.name)
return cast(FieldValue, default_value)


class TStructSpecProxy(TSpecProxy):
def get_fields(self, *, ignore_type: bool = False) -> list[Field]:
return [
StructField(
name=item.name,
type=self._get_python_type(item) if not ignore_type else None,
value=self._get_default_value(item),
required=item.required,
)
for item in self.thrift_spec
]
2 changes: 1 addition & 1 deletion src/thriftpyi/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def build_init(imports: Iterable[str]) -> ast.Module:
def _make_imports(proxy: TModuleProxy) -> list[ast.ImportFrom]:
imports = []
if proxy.has_structs():
imports.append(_make_absolute_import("dataclasses", "dataclass"))
imports.append(_make_absolute_import("dataclasses", "dataclass, field"))
if proxy.has_enums():
imports.append(_make_absolute_import("enum", "IntEnum"))

Expand Down
4 changes: 4 additions & 0 deletions tests/stubs/expected/async/dates.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ class DateTime:

@dataclass
class Date: ...

EPOCH: DateTime = DateTime(
year=1970, month=1, day=1, hour=0, minute=0, second=0, microsecond=0
)
8 changes: 4 additions & 4 deletions tests/stubs/expected/async/shared.pyi
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
from dataclasses import dataclass
from typing import *

INT_CONST_1: int = 1234
MAP_CONST: Dict[str, str] = {"hello": "world", "goodnight": "moon"}
INT_CONST_2: int = 1234

class NotFound(Exception):
message: Optional[str] = "Not Found"

Expand All @@ -17,5 +13,9 @@ class LimitOffset:
limit: Optional[int] = None
offset: Optional[int] = None

INT_CONST_1: int = 1234
MAP_CONST: Dict[str, str] = {"hello": "world", "goodnight": "moon"}
INT_CONST_2: int = 1234

class Service:
async def ping(self) -> str: ...
9 changes: 8 additions & 1 deletion tests/stubs/expected/async/todo.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import IntEnum
from typing import *
from . import dates
Expand All @@ -19,6 +19,13 @@ class TodoItem:
picture: Optional[bytes] = None
is_favorite: bool = False

@dataclass
class TodoCounter:
todos: Dict[int, TodoItem] = field(default_factory=dict)
plain_ids: Set[int] = field(default_factory=lambda: {1, 2, 3})
note_ids: List[int] = field(default_factory=list)
checkboxes_ids: Set[int] = field(default_factory=set)

class Todo:
async def create(self, text: str, type: TodoType) -> int: ...
async def update(self, item: TodoItem) -> None: ...
Expand Down
8 changes: 4 additions & 4 deletions tests/stubs/expected/optional/shared.pyi
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
from dataclasses import dataclass
from typing import *

INT_CONST_1: int = 1234
MAP_CONST: Dict[str, str] = {"hello": "world", "goodnight": "moon"}
INT_CONST_2: int = 1234

class NotFound(Exception):
message: Optional[str] = "Not Found"

Expand All @@ -17,5 +13,9 @@ class LimitOffset:
limit: Optional[int] = None
offset: Optional[int] = None

INT_CONST_1: int = 1234
MAP_CONST: Dict[str, str] = {"hello": "world", "goodnight": "moon"}
INT_CONST_2: int = 1234

class Service:
def ping(self) -> str: ...
9 changes: 8 additions & 1 deletion tests/stubs/expected/optional/todo.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import IntEnum
from typing import *
from . import dates
Expand All @@ -19,6 +19,13 @@ class TodoItem:
picture: Optional[bytes] = None
is_favorite: Optional[bool] = False

@dataclass
class TodoCounter:
todos: Optional[Dict[int, TodoItem]] = field(default_factory=dict)
plain_ids: Optional[Set[int]] = field(default_factory=lambda: {1, 2, 3})
note_ids: Optional[List[int]] = field(default_factory=list)
checkboxes_ids: Optional[Set[int]] = field(default_factory=set)

class Todo:
def create(self, text: str, type: TodoType) -> int: ...
def update(self, item: TodoItem) -> None: ...
Expand Down
9 changes: 8 additions & 1 deletion tests/stubs/expected/sync/todo.pyi
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from enum import IntEnum
from typing import *
from . import dates
Expand All @@ -19,6 +19,13 @@ class TodoItem:
picture: Optional[bytes] = None
is_favorite: bool = False

@dataclass
class TodoCounter:
todos: Dict[int, TodoItem] = field(default_factory=dict)
plain_ids: Set[int] = field(default_factory=lambda: {1, 2, 3})
note_ids: List[int] = field(default_factory=list)
checkboxes_ids: Set[int] = field(default_factory=set)

class Todo:
def create(self, text: str, type: TodoType) -> int: ...
def update(self, item: TodoItem) -> None: ...
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
"expected_dir,args",
[
("tests/stubs/expected/sync", ["--strict-optional"]),
# ("tests/stubs/expected/async", ["--async", "--strict-optional"]),
# ("tests/stubs/expected/optional", []),
("tests/stubs/expected/async", ["--async", "--strict-optional"]),
("tests/stubs/expected/optional", []),
],
)
def test_main(capsys, expected_dir, args):
Expand Down

0 comments on commit 6b98004

Please sign in to comment.