Skip to content

Commit

Permalink
stubgen: generate valid dataclass stubs (#15625)
Browse files: browse the repository at this point in the history
Fixes #12441
Fixes #9986
Fixes #15966
  • Loading branch information
hamdanal committed Sep 15, 2023
1 parent 402c8ff commit 2bbc42f
Show file tree
Hide file tree
Showing 3 changed files with 244 additions and 6 deletions.
57 changes: 51 additions & 6 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,7 @@ def __init__(
self.defined_names: set[str] = set()
# Short names of methods defined in the body of the current class
self.method_names: set[str] = set()
self.processing_dataclass = False

def visit_mypy_file(self, o: MypyFile) -> None:
self.module = o.fullname # Current module being processed
Expand Down Expand Up @@ -706,6 +707,12 @@ def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
self.clear_decorators()

def visit_func_def(self, o: FuncDef) -> None:
is_dataclass_generated = (
self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated
)
if is_dataclass_generated and o.name != "__init__":
# Skip methods generated by the @dataclass decorator (except for __init__)
return
if (
self.is_private_name(o.name, o.fullname)
or self.is_not_in_all(o.name)
Expand Down Expand Up @@ -771,6 +778,12 @@ def visit_func_def(self, o: FuncDef) -> None:
else:
arg = name + annotation
args.append(arg)
if o.name == "__init__" and is_dataclass_generated and "**" in args:
# The dataclass plugin generates invalid nameless "*" and "**" arguments
new_name = "".join(a.split(":", 1)[0] for a in args).replace("*", "")
args[args.index("*")] = f"*{new_name}_" # this name is guaranteed to be unique
args[args.index("**")] = f"**{new_name}__" # same here

retname = None
if o.name != "__init__" and isinstance(o.unanalyzed_type, CallableType):
if isinstance(get_proper_type(o.unanalyzed_type.ret_type), AnyType):
Expand Down Expand Up @@ -899,6 +912,9 @@ def visit_class_def(self, o: ClassDef) -> None:
if not self._indent and self._state != EMPTY:
sep = len(self._output)
self.add("\n")
decorators = self.get_class_decorators(o)
for d in decorators:
self.add(f"{self._indent}@{d}\n")
self.add(f"{self._indent}class {o.name}")
self.record_name(o.name)
base_types = self.get_base_types(o)
Expand Down Expand Up @@ -934,6 +950,7 @@ def visit_class_def(self, o: ClassDef) -> None:
else:
self._state = CLASS
self.method_names = set()
self.processing_dataclass = False
self._current_class = None

def get_base_types(self, cdef: ClassDef) -> list[str]:
Expand Down Expand Up @@ -979,6 +996,21 @@ def get_base_types(self, cdef: ClassDef) -> list[str]:
base_types.append(f"{name}={value.accept(p)}")
return base_types

def get_class_decorators(self, cdef: ClassDef) -> list[str]:
    """Collect the decorators of *cdef* that should be emitted in the stub.

    Only decorators accepted by ``self.is_dataclass`` are kept. Each kept
    decorator is rendered by visiting it with an ``AliasPrinter`` and its
    qualified name is reported to ``self.import_tracker.require_name``.

    Side effect: ``self.processing_dataclass`` is set to ``True`` as soon as
    one dataclass decorator is found; it is left untouched otherwise.

    Returns:
        The rendered decorator expressions (without the leading ``@``), in
        the order they appear on the class.
    """
    printer = AliasPrinter(self)
    kept: list[str] = []
    for decorator in cdef.decorators:
        if not self.is_dataclass(decorator):
            # Anything that is not a dataclass decorator is dropped.
            continue
        kept.append(decorator.accept(printer))
        self.import_tracker.require_name(get_qualified_name(decorator))
        self.processing_dataclass = True
    return kept

def is_dataclass(self, expr: Expression) -> bool:
    """Tell whether *expr* resolves to ``dataclasses.dataclass``.

    When *expr* is a call (e.g. ``dataclass(init=False)``), the callee is
    inspected instead, so both the bare and the parameterised decorator
    forms are recognised. The comparison is made on the full name returned
    by ``self.get_fullname``.
    """
    target = expr.callee if isinstance(expr, CallExpr) else expr
    return self.get_fullname(target) == "dataclasses.dataclass"

def visit_block(self, o: Block) -> None:
# Unreachable statements may be partially uninitialized and that may
# cause trouble.
Expand Down Expand Up @@ -1336,19 +1368,30 @@ def get_init(
# Final without type argument is invalid in stubs.
final_arg = self.get_str_type_of_node(rvalue)
typename += f"[{final_arg}]"
elif self.processing_dataclass:
# attribute without annotation is not a dataclass field, don't add annotation.
return f"{self._indent}{lvalue} = ...\n"
else:
typename = self.get_str_type_of_node(rvalue)
initializer = self.get_assign_initializer(rvalue)
return f"{self._indent}{lvalue}: {typename}{initializer}\n"

def get_assign_initializer(self, rvalue: Expression) -> str:
"""Does this rvalue need some special initializer value?"""
if self._current_class and self._current_class.info:
# Current rules
# 1. Return `...` if we are dealing with `NamedTuple` and it has an existing default value
if self._current_class.info.is_named_tuple and not isinstance(rvalue, TempNode):
return " = ..."
# TODO: support other possible cases, where initializer is important
if not self._current_class:
return ""
# Current rules
# 1. Return `...` if we are dealing with `NamedTuple` or `dataclass` field and
# it has an existing default value
if (
self._current_class.info
and self._current_class.info.is_named_tuple
and not isinstance(rvalue, TempNode)
):
return " = ..."
if self.processing_dataclass and not (isinstance(rvalue, TempNode) and rvalue.no_rhs):
return " = ..."
# TODO: support other possible cases, where initializer is important

# By default, no initializer is required:
return ""
Expand Down Expand Up @@ -1410,6 +1453,8 @@ def is_private_name(self, name: str, fullname: str | None = None) -> bool:
return False
if fullname in EXTRA_EXPORTED:
return False
if name == "_":
return False
return name.startswith("_") and (not name.endswith("__") or name in IGNORED_DUNDERS)

def is_private_member(self, fullname: str) -> bool:
Expand Down
11 changes: 11 additions & 0 deletions mypy/test/teststubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,11 +724,22 @@ def run_case_inner(self, testcase: DataDrivenTestCase) -> None:

def parse_flags(self, program_text: str, extra: list[str]) -> Options:
flags = re.search("# flags: (.*)$", program_text, flags=re.MULTILINE)
pyversion = None
if flags:
flag_list = flags.group(1).split()
for i, flag in enumerate(flag_list):
if flag.startswith("--python-version="):
pyversion = flag.split("=", 1)[1]
del flag_list[i]
break
else:
flag_list = []
options = parse_options(flag_list + extra)
if pyversion:
# A hack to allow testing old python versions with new language constructs
# This should be rarely used in general as stubgen output should not be version-specific
major, minor = pyversion.split(".", 1)
options.pyversion = (int(major), int(minor))
if "--verbose" not in flag_list:
options.quiet = True
else:
Expand Down
182 changes: 182 additions & 0 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -3512,3 +3512,185 @@ def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ...

class X(_Incomplete): ...
class Y(_Incomplete): ...

[case testDataclass]
import dataclasses
import dataclasses as dcs
from dataclasses import dataclass, InitVar, KW_ONLY
from dataclasses import dataclass as dc
from typing import ClassVar

@dataclasses.dataclass
class X:
a: int
b: str = "hello"
c: ClassVar
d: ClassVar = 200
f: list[int] = field(init=False, default_factory=list)
g: int = field(default=2, kw_only=True)
_: KW_ONLY
h: int = 1
i: InitVar[str]
j: InitVar = 100
non_field = None

@dcs.dataclass
class Y: ...

@dataclass
class Z: ...

@dc
class W: ...

@dataclass(init=False, repr=False)
class V: ...

[out]
import dataclasses
import dataclasses as dcs
from dataclasses import InitVar, KW_ONLY, dataclass, dataclass as dc
from typing import ClassVar

@dataclasses.dataclass
class X:
a: int
b: str = ...
c: ClassVar
d: ClassVar = ...
f: list[int] = ...
g: int = ...
_: KW_ONLY
h: int = ...
i: InitVar[str]
j: InitVar = ...
non_field = ...

@dcs.dataclass
class Y: ...
@dataclass
class Z: ...
@dc
class W: ...
@dataclass(init=False, repr=False)
class V: ...

[case testDataclass_semanal]
from dataclasses import dataclass, InitVar
from typing import ClassVar

@dataclass
class X:
a: int
b: str = "hello"
c: ClassVar
d: ClassVar = 200
f: list[int] = field(init=False, default_factory=list)
g: int = field(default=2, kw_only=True)
h: int = 1
i: InitVar[str]
j: InitVar = 100
non_field = None

@dataclass(init=False, repr=False, frozen=True)
class Y: ...

[out]
from dataclasses import InitVar, dataclass
from typing import ClassVar

@dataclass
class X:
a: int
b: str = ...
c: ClassVar
d: ClassVar = ...
f: list[int] = ...
g: int = ...
h: int = ...
i: InitVar[str]
j: InitVar = ...
non_field = ...
def __init__(self, a, b, f, g, h, i, j) -> None: ...

@dataclass(init=False, repr=False, frozen=True)
class Y: ...

[case testDataclassWithKwOnlyField_semanal]
# flags: --python-version=3.10
from dataclasses import dataclass, InitVar, KW_ONLY
from typing import ClassVar

@dataclass
class X:
a: int
b: str = "hello"
c: ClassVar
d: ClassVar = 200
f: list[int] = field(init=False, default_factory=list)
g: int = field(default=2, kw_only=True)
_: KW_ONLY
h: int = 1
i: InitVar[str]
j: InitVar = 100
non_field = None

@dataclass(init=False, repr=False, frozen=True)
class Y: ...

[out]
from dataclasses import InitVar, KW_ONLY, dataclass
from typing import ClassVar

@dataclass
class X:
a: int
b: str = ...
c: ClassVar
d: ClassVar = ...
f: list[int] = ...
g: int = ...
_: KW_ONLY
h: int = ...
i: InitVar[str]
j: InitVar = ...
non_field = ...
def __init__(self, a, b, f, g, *, h, i, j) -> None: ...

@dataclass(init=False, repr=False, frozen=True)
class Y: ...

[case testDataclassWithExplicitGeneratedMethodsOverrides_semanal]
from dataclasses import dataclass

@dataclass
class X:
a: int
def __init__(self, a: int, b: str = ...) -> None: ...
def __post_init__(self) -> None: ...

[out]
from dataclasses import dataclass

@dataclass
class X:
a: int
def __init__(self, a: int, b: str = ...) -> None: ...
def __post_init__(self) -> None: ...

[case testDataclassInheritsFromAny_semanal]
from dataclasses import dataclass
import missing

@dataclass
class X(missing.Base):
a: int

[out]
import missing
from dataclasses import dataclass

@dataclass
class X(missing.Base):
a: int
def __init__(self, *selfa_, a, **selfa__) -> None: ...

0 comments on commit 2bbc42f

Please sign in to comment.