Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

stubgen: generate valid dataclass stubs #15625

Merged
merged 4 commits into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 51 additions & 6 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,7 @@ def __init__(
self.defined_names: set[str] = set()
# Short names of methods defined in the body of the current class
self.method_names: set[str] = set()
self.processing_dataclass = False

def visit_mypy_file(self, o: MypyFile) -> None:
self.module = o.fullname # Current module being processed
Expand Down Expand Up @@ -706,6 +707,12 @@ def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None:
self.clear_decorators()

def visit_func_def(self, o: FuncDef) -> None:
is_dataclass_generated = (
self.analyzed and self.processing_dataclass and o.info.names[o.name].plugin_generated
)
if is_dataclass_generated and o.name != "__init__":
# Skip methods generated by the @dataclass decorator (except for __init__)
return
if (
self.is_private_name(o.name, o.fullname)
or self.is_not_in_all(o.name)
Expand Down Expand Up @@ -771,6 +778,12 @@ def visit_func_def(self, o: FuncDef) -> None:
else:
arg = name + annotation
args.append(arg)
if o.name == "__init__" and is_dataclass_generated and "**" in args:
# The dataclass plugin generates invalid nameless "*" and "**" arguments
new_name = "".join(a.split(":", 1)[0] for a in args).replace("*", "")
args[args.index("*")] = f"*{new_name}_" # this name is guaranteed to be unique
args[args.index("**")] = f"**{new_name}__" # same here

retname = None
if o.name != "__init__" and isinstance(o.unanalyzed_type, CallableType):
if isinstance(get_proper_type(o.unanalyzed_type.ret_type), AnyType):
Expand Down Expand Up @@ -899,6 +912,9 @@ def visit_class_def(self, o: ClassDef) -> None:
if not self._indent and self._state != EMPTY:
sep = len(self._output)
self.add("\n")
decorators = self.get_class_decorators(o)
for d in decorators:
self.add(f"{self._indent}@{d}\n")
self.add(f"{self._indent}class {o.name}")
self.record_name(o.name)
base_types = self.get_base_types(o)
Expand Down Expand Up @@ -934,6 +950,7 @@ def visit_class_def(self, o: ClassDef) -> None:
else:
self._state = CLASS
self.method_names = set()
self.processing_dataclass = False
self._current_class = None

def get_base_types(self, cdef: ClassDef) -> list[str]:
Expand Down Expand Up @@ -979,6 +996,21 @@ def get_base_types(self, cdef: ClassDef) -> list[str]:
base_types.append(f"{name}={value.accept(p)}")
return base_types

def get_class_decorators(self, cdef: ClassDef) -> list[str]:
    """Collect the printed form of every dataclass decorator on ``cdef``.

    Only decorators accepted by ``self.is_dataclass`` are kept. For each one,
    its qualified name is registered with the import tracker and
    ``self.processing_dataclass`` is switched on.
    """
    printer = AliasPrinter(self)
    rendered: list[str] = []
    for decorator in cdef.decorators:
        if not self.is_dataclass(decorator):
            continue
        rendered.append(decorator.accept(printer))
        self.import_tracker.require_name(get_qualified_name(decorator))
        # Remember that the class body being visited belongs to a dataclass.
        self.processing_dataclass = True
    return rendered

def is_dataclass(self, expr: Expression) -> bool:
    """Tell whether ``expr`` refers to ``dataclasses.dataclass``.

    When ``expr`` is a call expression, its callee is inspected instead, so
    both the bare and the called decorator forms are recognised.
    """
    target = expr.callee if isinstance(expr, CallExpr) else expr
    return self.get_fullname(target) == "dataclasses.dataclass"

def visit_block(self, o: Block) -> None:
# Unreachable statements may be partially uninitialized and that may
# cause trouble.
Expand Down Expand Up @@ -1336,19 +1368,30 @@ def get_init(
# Final without type argument is invalid in stubs.
final_arg = self.get_str_type_of_node(rvalue)
typename += f"[{final_arg}]"
elif self.processing_dataclass:
# attribute without annotation is not a dataclass field, don't add annotation.
return f"{self._indent}{lvalue} = ...\n"
else:
typename = self.get_str_type_of_node(rvalue)
initializer = self.get_assign_initializer(rvalue)
return f"{self._indent}{lvalue}: {typename}{initializer}\n"

def get_assign_initializer(self, rvalue: Expression) -> str:
"""Does this rvalue need some special initializer value?"""
if self._current_class and self._current_class.info:
# Current rules
# 1. Return `...` if we are dealing with `NamedTuple` and it has an existing default value
if self._current_class.info.is_named_tuple and not isinstance(rvalue, TempNode):
return " = ..."
# TODO: support other possible cases, where initializer is important
if not self._current_class:
return ""
# Current rules
# 1. Return `...` if we are dealing with `NamedTuple` or `dataclass` field and
# it has an existing default value
if (
self._current_class.info
and self._current_class.info.is_named_tuple
and not isinstance(rvalue, TempNode)
):
return " = ..."
if self.processing_dataclass and not (isinstance(rvalue, TempNode) and rvalue.no_rhs):
return " = ..."
# TODO: support other possible cases, where initializer is important

# By default, no initializer is required:
return ""
Expand Down Expand Up @@ -1410,6 +1453,8 @@ def is_private_name(self, name: str, fullname: str | None = None) -> bool:
return False
if fullname in EXTRA_EXPORTED:
return False
if name == "_":
return False
return name.startswith("_") and (not name.endswith("__") or name in IGNORED_DUNDERS)

def is_private_member(self, fullname: str) -> bool:
Expand Down
11 changes: 11 additions & 0 deletions mypy/test/teststubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,11 +724,22 @@ def run_case_inner(self, testcase: DataDrivenTestCase) -> None:

def parse_flags(self, program_text: str, extra: list[str]) -> Options:
flags = re.search("# flags: (.*)$", program_text, flags=re.MULTILINE)
pyversion = None
if flags:
flag_list = flags.group(1).split()
for i, flag in enumerate(flag_list):
if flag.startswith("--python-version="):
pyversion = flag.split("=", 1)[1]
del flag_list[i]
break
else:
flag_list = []
options = parse_options(flag_list + extra)
if pyversion:
# A hack to allow testing old python versions with new language constructs
# This should be rarely used in general as stubgen output should not be version-specific
major, minor = pyversion.split(".", 1)
options.pyversion = (int(major), int(minor))
if "--verbose" not in flag_list:
options.quiet = True
else:
Expand Down
182 changes: 182 additions & 0 deletions test-data/unit/stubgen.test
Original file line number Diff line number Diff line change
Expand Up @@ -3512,3 +3512,185 @@ def gen2() -> _Generator[_Incomplete, _Incomplete, _Incomplete]: ...

class X(_Incomplete): ...
class Y(_Incomplete): ...

[case testDataclass]
import dataclasses
import dataclasses as dcs
from dataclasses import dataclass, InitVar, KW_ONLY
from dataclasses import dataclass as dc
from typing import ClassVar

@dataclasses.dataclass
class X:
a: int
b: str = "hello"
c: ClassVar
d: ClassVar = 200
f: list[int] = field(init=False, default_factory=list)
g: int = field(default=2, kw_only=True)
_: KW_ONLY
h: int = 1
i: InitVar[str]
j: InitVar = 100
non_field = None

@dcs.dataclass
class Y: ...

@dataclass
class Z: ...

@dc
class W: ...

@dataclass(init=False, repr=False)
class V: ...

[out]
import dataclasses
import dataclasses as dcs
from dataclasses import InitVar, KW_ONLY, dataclass, dataclass as dc
from typing import ClassVar

@dataclasses.dataclass
class X:
a: int
b: str = ...
c: ClassVar
d: ClassVar = ...
f: list[int] = ...
g: int = ...
_: KW_ONLY
h: int = ...
i: InitVar[str]
j: InitVar = ...
non_field = ...

@dcs.dataclass
class Y: ...
@dataclass
class Z: ...
@dc
class W: ...
@dataclass(init=False, repr=False)
class V: ...

[case testDataclass_semanal]
from dataclasses import dataclass, InitVar
from typing import ClassVar

@dataclass
class X:
a: int
b: str = "hello"
c: ClassVar
d: ClassVar = 200
f: list[int] = field(init=False, default_factory=list)
g: int = field(default=2, kw_only=True)
h: int = 1
i: InitVar[str]
j: InitVar = 100
non_field = None

@dataclass(init=False, repr=False, frozen=True)
class Y: ...

[out]
from dataclasses import InitVar, dataclass
from typing import ClassVar

@dataclass
class X:
a: int
b: str = ...
c: ClassVar
d: ClassVar = ...
f: list[int] = ...
g: int = ...
h: int = ...
i: InitVar[str]
j: InitVar = ...
non_field = ...
def __init__(self, a, b, f, g, h, i, j) -> None: ...
Copy link

@thomasgilgenast thomasgilgenast Oct 25, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry to comment on an old, merged PR, but I wanted to ask if it's expected that the arguments to __init__() are missing type annotations in the generated stub.

If I install this stub and then try to typecheck

X(a="foo", ...)

I would expect mypy to error with

error: Argument "a" to "X" has incompatible type "str"; expected "int"  [arg-type]

but instead it passes. I think the type of a is implicitly Any when it's not annotated in the signature, which seems incorrect: the type of a should be int, not Any.

Let me know if this is actually better positioned as a new issue (with reproducible example etc.), but I just wanted to see if I was missing something obvious before opening a new one.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stubgen doesn't use inferred types of parameters in generated stubs unless it sees a parameter with a primitive default value. You'll always need to tweak the generated stubs by hand if you want a better experience with the type checker.

Also note that the __init__ method can be safely deleted from the stub most of the time. It is included because a field declaration like some_field: int = field(init=False), which affects the signature of __init__, cannot currently be represented in the stub.

With that being said, I agree that perhaps we can do better, but I am not a mypy maintainer. So yes, please open a new issue asking for this feature.


@dataclass(init=False, repr=False, frozen=True)
class Y: ...

[case testDataclassWithKwOnlyField_semanal]
# flags: --python-version=3.10
from dataclasses import dataclass, InitVar, KW_ONLY
from typing import ClassVar

@dataclass
class X:
a: int
b: str = "hello"
c: ClassVar
d: ClassVar = 200
f: list[int] = field(init=False, default_factory=list)
g: int = field(default=2, kw_only=True)
_: KW_ONLY
h: int = 1
i: InitVar[str]
j: InitVar = 100
non_field = None

@dataclass(init=False, repr=False, frozen=True)
class Y: ...

[out]
from dataclasses import InitVar, KW_ONLY, dataclass
from typing import ClassVar

@dataclass
class X:
a: int
b: str = ...
c: ClassVar
d: ClassVar = ...
f: list[int] = ...
g: int = ...
_: KW_ONLY
h: int = ...
i: InitVar[str]
j: InitVar = ...
non_field = ...
def __init__(self, a, b, f, g, *, h, i, j) -> None: ...

@dataclass(init=False, repr=False, frozen=True)
class Y: ...

[case testDataclassWithExplicitGeneratedMethodsOverrides_semanal]
from dataclasses import dataclass

@dataclass
class X:
a: int
def __init__(self, a: int, b: str = ...) -> None: ...
def __post_init__(self) -> None: ...

[out]
from dataclasses import dataclass

@dataclass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that init=False is assumed here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to the dataclass documentation,

init: If true (the default), a __init__() method will be generated.
If the class already defines __init__(), this parameter is ignored.

So it is ignored in this case. That was not the goal of the test anyway; the goal was to check that user-defined methods are always included in the stub.

class X:
a: int
def __init__(self, a: int, b: str = ...) -> None: ...
def __post_init__(self) -> None: ...

[case testDataclassInheritsFromAny_semanal]
from dataclasses import dataclass
import missing

@dataclass
class X(missing.Base):
a: int

[out]
import missing
from dataclasses import dataclass

@dataclass
class X(missing.Base):
a: int
def __init__(self, *selfa_, a, **selfa__) -> None: ...