From 10099a3bd3d4d4e58ff7bde0ca17da17711bd295 Mon Sep 17 00:00:00 2001 From: Petter Friberg Date: Fri, 7 Jun 2024 23:02:49 +0200 Subject: [PATCH 1/3] Add type hints to builtin models' fields --- django-stubs/contrib/admin/models.pyi | 22 +- django-stubs/contrib/auth/base_user.pyi | 8 +- django-stubs/contrib/auth/models.pyi | 59 +++-- django-stubs/contrib/contenttypes/models.pyi | 15 +- django-stubs/contrib/flatpages/models.pyi | 33 ++- django-stubs/contrib/redirects/models.pyi | 11 +- .../contrib/sessions/base_session.pyi | 11 +- django-stubs/contrib/sessions/models.pyi | 7 +- django-stubs/contrib/sites/models.pyi | 12 +- .../db/models/fields/related_descriptors.pyi | 8 +- mypy_django_plugin/lib/fullnames.py | 5 +- mypy_django_plugin/transformers/fields.py | 11 +- mypy_django_plugin/transformers/manytomany.py | 8 +- mypy_django_plugin/transformers/models.py | 230 +++++++++++------- .../contrib/admin/test_admin_models.py | 20 ++ .../contrib/auth/test_auth_models.py | 52 ++++ .../contenttypes/test_contenttypes_models.py | 12 + .../flatpages/test_flatpages_models.py | 26 ++ .../contrib/sessions/test_sessions_models.py | 11 + .../contrib/sites/test_sites_models.py | 18 ++ .../typecheck/models/test_contrib_models.yml | 4 +- 21 files changed, 426 insertions(+), 157 deletions(-) create mode 100644 tests/assert_type/contrib/admin/test_admin_models.py create mode 100644 tests/assert_type/contrib/auth/test_auth_models.py create mode 100644 tests/assert_type/contrib/contenttypes/test_contenttypes_models.py create mode 100644 tests/assert_type/contrib/flatpages/test_flatpages_models.py create mode 100644 tests/assert_type/contrib/sessions/test_sessions_models.py create mode 100644 tests/assert_type/contrib/sites/test_sites_models.py diff --git a/django-stubs/contrib/admin/models.pyi b/django-stubs/contrib/admin/models.pyi index a1ac77b49..25ab047f1 100644 --- a/django-stubs/contrib/admin/models.pyi +++ b/django-stubs/contrib/admin/models.pyi @@ -1,8 +1,12 @@ +from datetime import date, datetime from typing import Any, ClassVar from uuid import UUID +from django.contrib.auth.models import AbstractUser +from django.contrib.contenttypes.models import ContentType from django.db import models from django.db.models.base import Model +from django.db.models.expressions import Combinable ADDITION: int CHANGE: int @@ -21,13 +25,17 @@ class LogEntryManager(models.Manager[LogEntry]): ) -> LogEntry: ... class LogEntry(models.Model): - action_time: models.DateTimeField - user: models.ForeignKey - content_type: models.ForeignKey - object_id: models.TextField - object_repr: models.CharField - action_flag: models.PositiveSmallIntegerField - change_message: models.TextField + id: models.AutoField[str | int | Combinable | None, int] + pk: models.AutoField[str | int | Combinable | None, int] + action_time: models.DateTimeField[str | datetime | date | Combinable, datetime] + user: models.ForeignKey[AbstractUser | Combinable, AbstractUser] + user_id: Any + content_type: models.ForeignKey[ContentType | Combinable | None, ContentType | None] + content_type_id: int | None + object_id: models.TextField[str | int | Combinable | None, str | None] + object_repr: models.CharField[str | int | Combinable, str] + action_flag: models.PositiveSmallIntegerField[float | int | str | Combinable, int] + change_message: models.TextField[str | int | Combinable, str] objects: ClassVar[LogEntryManager] def is_addition(self) -> bool: ... def is_change(self) -> bool: ... diff --git a/django-stubs/contrib/auth/base_user.pyi b/django-stubs/contrib/auth/base_user.pyi index 4740dd504..43f948f6c 100644 --- a/django-stubs/contrib/auth/base_user.pyi +++ b/django-stubs/contrib/auth/base_user.pyi @@ -1,9 +1,9 @@ +from datetime import date, datetime from typing import Any, ClassVar, Literal, TypeVar, overload from django.db import models from django.db.models.base import Model from django.db.models.expressions import Combinable -from django.db.models.fields import BooleanField _T = TypeVar("_T", bound=Model) @@ -16,9 +16,9 @@ class BaseUserManager(models.Manager[_T]): class AbstractBaseUser(models.Model): REQUIRED_FIELDS: ClassVar[list[str]] - password = models.CharField(max_length=128) - last_login = models.DateTimeField(blank=True, null=True) - is_active: bool | BooleanField[bool | Combinable, bool] + password: models.CharField[str | int | Combinable, str] + last_login: models.DateTimeField[str | datetime | date | Combinable, datetime | None] + is_active: bool | models.BooleanField[bool | Combinable, bool] def get_username(self) -> str: ... def natural_key(self) -> tuple[str]: ... diff --git a/django-stubs/contrib/auth/models.pyi b/django-stubs/contrib/auth/models.pyi index 78702f1e7..03e50aa5f 100644 --- a/django-stubs/contrib/auth/models.pyi +++ b/django-stubs/contrib/auth/models.pyi @@ -1,5 +1,6 @@ from collections.abc import Iterable -from typing import Any, ClassVar, Literal, TypeVar +from datetime import date, datetime +from typing import Any, ClassVar, Literal, TypeVar, type_check_only from django.contrib.auth.base_user import AbstractBaseUser as AbstractBaseUser from django.contrib.auth.base_user import BaseUserManager as BaseUserManager @@ -8,6 +9,8 @@ from django.contrib.contenttypes.models import ContentType from django.db import models from django.db.models import QuerySet from django.db.models.base import Model +from django.db.models.expressions import Combinable +from django.db.models.fields.related_descriptors import ManyToManyDescriptor from django.db.models.manager import EmptyManager from django.utils.functional import _StrOrPromise from typing_extensions import Self, TypeAlias @@ -20,22 +23,40 @@ class PermissionManager(models.Manager[Permission]): def get_by_natural_key(self, codename: str, app_label: str, model: str) -> Permission: ... class Permission(models.Model): - content_type_id: int objects: ClassVar[PermissionManager] - name = models.CharField(max_length=255) - content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE) - codename = models.CharField(max_length=100) + id: models.AutoField[str | int | Combinable | None, int] + pk: models.AutoField[str | int | Combinable | None, int] + name: models.CharField[str | int | Combinable, str] + content_type: models.ForeignKey[ContentType | Combinable, ContentType] + content_type_id: int + codename: models.CharField[str | int | Combinable, str] + group_set: ManyToManyDescriptor[Group, Group_permissions] def natural_key(self) -> tuple[str, str, str]: ... class GroupManager(models.Manager[Group]): def get_by_natural_key(self, name: str) -> Group: ... +# This is a model that only exists in Django's model registry and doesn't have any +# class statement form. It's the through model between 'Group' and 'Permission'. +@type_check_only +class Group_permissions(models.Model): + objects: ClassVar[models.Manager[Self]] + + id: models.AutoField[str | int | Combinable | None, int] + pk: models.AutoField[str | int | Combinable | None, int] + group: models.ForeignKey[Group | Combinable, Group] + group_id: int + permission: models.ForeignKey[Permission | Combinable, Permission] + permission_id: int + class Group(models.Model): objects: ClassVar[GroupManager] - name = models.CharField(max_length=150) - permissions = models.ManyToManyField(Permission) + id: models.AutoField[str | int | Combinable | None, int] + pk: models.AutoField[str | int | Combinable | None, int] + name: models.CharField[str | int | Combinable, str] + permissions: models.ManyToManyField[Permission, Group_permissions] def natural_key(self) -> tuple[str]: ... _T = TypeVar("_T", bound=Model) @@ -57,9 +78,9 @@ class UserManager(BaseUserManager[_T]): ) -> QuerySet[_T]: ... class PermissionsMixin(models.Model): - is_superuser = models.BooleanField() - groups = models.ManyToManyField(Group) - user_permissions = models.ManyToManyField(Permission) + is_superuser: models.BooleanField[bool | Combinable, bool] + groups: models.ManyToManyField[Group, Any] + user_permissions: models.ManyToManyField[Permission, Any] def get_user_permissions(self, obj: _AnyUser | None = ...) -> set[str]: ... def get_group_permissions(self, obj: _AnyUser | None = ...) -> set[str]: ... @@ -71,13 +92,13 @@ class PermissionsMixin(models.Model): class AbstractUser(AbstractBaseUser, PermissionsMixin): username_validator: UnicodeUsernameValidator - username = models.CharField(max_length=150) - first_name = models.CharField(max_length=30, blank=True) - last_name = models.CharField(max_length=150, blank=True) - email = models.EmailField(blank=True) - is_staff = models.BooleanField() - is_active = models.BooleanField() - date_joined = models.DateTimeField() + username: models.CharField[str | int | Combinable, str] + first_name: models.CharField[str | int | Combinable, str] + last_name: models.CharField[str | int | Combinable, str] + email: models.EmailField[str | Combinable, str] + is_staff: models.BooleanField[bool | Combinable, bool] + is_active: models.BooleanField[bool | Combinable, bool] + date_joined: models.DateTimeField[str | datetime | date | Combinable, datetime] objects: ClassVar[UserManager[Self]] @@ -90,7 +111,9 @@ class AbstractUser(AbstractBaseUser, PermissionsMixin): self, subject: _StrOrPromise, message: _StrOrPromise, from_email: str = ..., **kwargs: Any ) -> None: ... -class User(AbstractUser): ... +class User(AbstractUser): + id: models.AutoField[str | int | Combinable | None, int] + pk: models.AutoField[str | int | Combinable | None, int] class AnonymousUser: id: Any diff --git a/django-stubs/contrib/contenttypes/models.pyi b/django-stubs/contrib/contenttypes/models.pyi index e7206aed0..8becc1b45 100644 --- a/django-stubs/contrib/contenttypes/models.pyi +++ b/django-stubs/contrib/contenttypes/models.pyi @@ -1,7 +1,11 @@ from typing import Any, ClassVar +from django.contrib.admin.models import LogEntry +from django.contrib.auth.models import Permission from django.db import models from django.db.models.base import Model +from django.db.models.expressions import Combinable +from django.db.models.fields.related_descriptors import ReverseManyToOneDescriptor from django.db.models.query import QuerySet class ContentTypeManager(models.Manager[ContentType]): @@ -12,13 +16,16 @@ class ContentTypeManager(models.Manager[ContentType]): def clear_cache(self) -> None: ... class ContentType(models.Model): - id: int - app_label: models.CharField - model: models.CharField + id: models.AutoField[str | int | Combinable | None, int] + pk: models.AutoField[str | int | Combinable | None, int] + app_label: models.CharField[str | int | Combinable, str] + model: models.CharField[str | int | Combinable, str] + logentry_set: ReverseManyToOneDescriptor[LogEntry] + permission_set: ReverseManyToOneDescriptor[Permission] objects: ClassVar[ContentTypeManager] @property def name(self) -> str: ... def model_class(self) -> type[Model] | None: ... def get_object_for_this_type(self, **kwargs: Any) -> Model: ... - def get_all_objects_for_this_type(self, **kwargs: Any) -> QuerySet: ... + def get_all_objects_for_this_type(self, **kwargs: Any) -> QuerySet[Model]: ... def natural_key(self) -> tuple[str, str]: ... diff --git a/django-stubs/contrib/flatpages/models.pyi b/django-stubs/contrib/flatpages/models.pyi index 4b578fe03..44eb29e82 100644 --- a/django-stubs/contrib/flatpages/models.pyi +++ b/django-stubs/contrib/flatpages/models.pyi @@ -1,12 +1,31 @@ +from typing import ClassVar, type_check_only + from django.contrib.sites.models import Site from django.db import models +from django.db.models.expressions import Combinable +from typing_extensions import Self + +# This is a model that only exists in Django's model registry and doesn't have any +# class statement form. It's the through model between 'FlatPage' and 'Site'. +@type_check_only +class FlatPage_sites(models.Model): + objects: ClassVar[models.Manager[Self]] + + id: models.AutoField[str | int | Combinable | None, int] + pk: models.AutoField[str | int | Combinable | None, int] + site: models.ForeignKey[Site | Combinable, Site] + site_id: int + flatpage: models.ForeignKey[FlatPage | Combinable, FlatPage] + flatpage_id: int class FlatPage(models.Model): - url: models.CharField - title: models.CharField - content: models.TextField - enable_comments: models.BooleanField - template_name: models.CharField - registration_required: models.BooleanField - sites: models.ManyToManyField[Site, Site] + id: models.AutoField[str | int | Combinable | None, int] + pk: models.AutoField[str | int | Combinable | None, int] + url: models.CharField[str | int | Combinable, str] + title: models.CharField[str | int | Combinable, str] + content: models.TextField[str | int | Combinable, str] + enable_comments: models.BooleanField[bool | Combinable, bool] + template_name: models.CharField[str | int | Combinable, str] + registration_required: models.BooleanField[bool | Combinable, bool] + sites: models.ManyToManyField[Site, FlatPage_sites] def get_absolute_url(self) -> str: ... diff --git a/django-stubs/contrib/redirects/models.pyi b/django-stubs/contrib/redirects/models.pyi index 444960881..d2b4c283c 100644 --- a/django-stubs/contrib/redirects/models.pyi +++ b/django-stubs/contrib/redirects/models.pyi @@ -1,6 +1,11 @@ +from django.contrib.sites.models import Site from django.db import models +from django.db.models.expressions import Combinable class Redirect(models.Model): - site: models.ForeignKey - old_path: models.CharField - new_path: models.CharField + id: models.AutoField[str | int | Combinable | None, int] + pk: models.AutoField[str | int | Combinable | None, int] + site: models.ForeignKey[Site | Combinable, Site] + site_id: int + old_path: models.CharField[str | int | Combinable, str] + new_path: models.CharField[str | int | Combinable, str] diff --git a/django-stubs/contrib/sessions/base_session.pyi b/django-stubs/contrib/sessions/base_session.pyi index 4da1bdb26..2298dda59 100644 --- a/django-stubs/contrib/sessions/base_session.pyi +++ b/django-stubs/contrib/sessions/base_session.pyi @@ -1,8 +1,9 @@ -from datetime import datetime +from datetime import date, datetime from typing import Any, ClassVar, TypeVar from django.contrib.sessions.backends.base import SessionBase from django.db import models +from django.db.models.expressions import Combinable from typing_extensions import Self _T = TypeVar("_T", bound=AbstractBaseSession) @@ -12,9 +13,11 @@ class BaseSessionManager(models.Manager[_T]): def save(self, session_key: str, session_dict: dict[str, Any], expire_date: datetime) -> _T: ... class AbstractBaseSession(models.Model): - session_key = models.CharField(primary_key=True) - session_data = models.TextField() - expire_date = models.DateTimeField() + session_key: models.CharField[str | int | Combinable | None, str] + # 'session_key' is declared as primary key + pk: models.CharField[str | int | Combinable | None, str] + session_data: models.TextField[str | int | Combinable, str] + expire_date: models.DateTimeField[str | datetime | date | Combinable, datetime] objects: ClassVar[BaseSessionManager[Self]] @classmethod diff --git a/django-stubs/contrib/sessions/models.pyi b/django-stubs/contrib/sessions/models.pyi index 492c9c34d..266e32d34 100644 --- a/django-stubs/contrib/sessions/models.pyi +++ b/django-stubs/contrib/sessions/models.pyi @@ -1,8 +1,11 @@ -from typing import TypeVar +from typing import ClassVar, TypeVar from django.contrib.sessions.base_session import AbstractBaseSession, BaseSessionManager +from typing_extensions import Self _T = TypeVar("_T", bound=Session) class SessionManager(BaseSessionManager[_T]): ... -class Session(AbstractBaseSession): ... + +class Session(AbstractBaseSession): + objects: ClassVar[SessionManager[Self]] diff --git a/django-stubs/contrib/sites/models.pyi b/django-stubs/contrib/sites/models.pyi index e4181d66f..07650ac4f 100644 --- a/django-stubs/contrib/sites/models.pyi +++ b/django-stubs/contrib/sites/models.pyi @@ -1,6 +1,10 @@ from typing import Any, ClassVar +from django.contrib.flatpages.models import FlatPage, FlatPage_sites +from django.contrib.redirects.models import Redirect from django.db import models +from django.db.models.expressions import Combinable +from django.db.models.fields.related_descriptors import ManyToManyDescriptor, ReverseManyToOneDescriptor from django.http.request import HttpRequest SITE_CACHE: Any @@ -13,8 +17,12 @@ class SiteManager(models.Manager[Site]): class Site(models.Model): objects: ClassVar[SiteManager] - domain = models.CharField(max_length=100) - name = models.CharField(max_length=50) + id: models.AutoField[str | int | Combinable | None, int] + pk: models.AutoField[str | int | Combinable | None, int] + domain: models.CharField[str | int | Combinable, str] + name: models.CharField[str | int | Combinable, str] + flatpage_set: ManyToManyDescriptor[FlatPage, FlatPage_sites] + redirect_set: ReverseManyToOneDescriptor[Redirect] def natural_key(self) -> tuple[str]: ... def clear_site_cache(sender: type[Site], **kwargs: Any) -> None: ... diff --git a/django-stubs/db/models/fields/related_descriptors.pyi b/django-stubs/db/models/fields/related_descriptors.pyi index a6efc796a..9a1a67c9b 100644 --- a/django-stubs/db/models/fields/related_descriptors.pyi +++ b/django-stubs/db/models/fields/related_descriptors.pyi @@ -73,7 +73,7 @@ class ReverseOneToOneDescriptor(Generic[_From, _To]): def __set__(self, instance: _From, value: _To | None) -> None: ... def __reduce__(self) -> tuple[Callable[..., Any], tuple[type[_To], str]]: ... -class ReverseManyToOneDescriptor: +class ReverseManyToOneDescriptor(Generic[_To]): """ In the example:: @@ -84,14 +84,14 @@ class ReverseManyToOneDescriptor: """ rel: ManyToOneRel - field: ForeignKey + field: ForeignKey[_To, _To] def __init__(self, rel: ManyToOneRel) -> None: ... @cached_property - def related_manager_cls(self) -> type[RelatedManager[Any]]: ... + def related_manager_cls(self) -> type[RelatedManager[_To]]: ... @overload def __get__(self, instance: None, cls: Any = ...) -> Self: ... @overload - def __get__(self, instance: Model, cls: Any = ...) -> RelatedManager[Any]: ... + def __get__(self, instance: Model, cls: Any = ...) -> RelatedManager[_To]: ... def __set__(self, instance: Any, value: Any) -> NoReturn: ... # Fake class, Django defines 'RelatedManager' inside a function body diff --git a/mypy_django_plugin/lib/fullnames.py b/mypy_django_plugin/lib/fullnames.py index 29497ed10..9300134a6 100644 --- a/mypy_django_plugin/lib/fullnames.py +++ b/mypy_django_plugin/lib/fullnames.py @@ -1,4 +1,7 @@ ABSTRACT_USER_MODEL_FULLNAME = "django.contrib.auth.models.AbstractUser" +USER_MODEL_FULLNAME = "django.contrib.auth.models.User" +PERMISSION_MODEL_FULLNAME = "django.contrib.auth.models.Permission" +GROUP_MODEL_FULLNAME = "django.contrib.auth.models.Group" PERMISSION_MIXIN_CLASS_FULLNAME = "django.contrib.auth.models.PermissionsMixin" MODEL_METACLASS_FULLNAME = "django.db.models.base.ModelBase" MODEL_CLASS_FULLNAME = "django.db.models.base.Model" @@ -12,7 +15,7 @@ ONETOONE_FIELD_FULLNAME = "django.db.models.fields.related.OneToOneField" MANYTOMANY_FIELD_FULLNAME = "django.db.models.fields.related.ManyToManyField" DUMMY_SETTINGS_BASE_CLASS = "django.conf._DjangoConfLazyObject" -AUTH_USER_MODEL_FULLNAME = "django.conf.settings.AUTH_USER_MODEL" +AUTH_USER_MODEL_SETTING_FULLNAME = "django.conf.settings.AUTH_USER_MODEL" QUERYSET_CLASS_FULLNAME = "django.db.models.query.QuerySet" BASE_MANAGER_CLASS_FULLNAME = "django.db.models.manager.BaseManager" diff --git a/mypy_django_plugin/transformers/fields.py b/mypy_django_plugin/transformers/fields.py index 731a102b5..cfd280aa4 100644 --- a/mypy_django_plugin/transformers/fields.py +++ b/mypy_django_plugin/transformers/fields.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union, cast from django.core.exceptions import FieldDoesNotExist from django.db.models.fields import AutoField, Field @@ -114,12 +114,17 @@ def fill_descriptor_types_for_related_field(ctx: FunctionContext, django_context ) +class FieldDescriptorTypes(NamedTuple): + set: MypyType + get: MypyType + + def get_field_descriptor_types( field_info: TypeInfo, *, is_set_nullable: bool, is_get_nullable: bool -) -> Tuple[MypyType, MypyType]: +) -> FieldDescriptorTypes: set_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_set_type", is_nullable=is_set_nullable) get_type = helpers.get_private_descriptor_type(field_info, "_pyi_private_get_type", is_nullable=is_get_nullable) - return set_type, get_type + return FieldDescriptorTypes(set=set_type, get=get_type) def set_descriptor_types_for_field_callback(ctx: FunctionContext, django_context: DjangoContext) -> MypyType: diff --git a/mypy_django_plugin/transformers/manytomany.py b/mypy_django_plugin/transformers/manytomany.py index 9d0f7d1a0..848bc2438 100644 --- a/mypy_django_plugin/transformers/manytomany.py +++ b/mypy_django_plugin/transformers/manytomany.py @@ -1,7 +1,7 @@ from typing import NamedTuple, Optional, Tuple, Union from mypy.checker import TypeChecker -from mypy.nodes import AssignmentStmt, Expression, MemberExpr, NameExpr, RefExpr, StrExpr, TypeInfo +from mypy.nodes import AssignmentStmt, Expression, MemberExpr, NameExpr, Node, RefExpr, StrExpr, TypeInfo from mypy.plugin import FunctionContext, MethodContext from mypy.semanal import SemanticAnalyzer from mypy.types import Instance, ProperType, TypeVarType, UninhabitedType @@ -12,12 +12,12 @@ class M2MThrough(NamedTuple): - arg: Optional[Expression] + arg: Optional[Node] model: ProperType class M2MTo(NamedTuple): - arg: Expression + arg: Node model: ProperType self: bool # ManyToManyField('self', ...) @@ -139,7 +139,7 @@ def get_model_from_expression( elif ( isinstance(expr, MemberExpr) and isinstance(expr.expr, NameExpr) - and f"{expr.expr.fullname}.{expr.name}" == fullnames.AUTH_USER_MODEL_FULLNAME + and f"{expr.expr.fullname}.{expr.name}" == fullnames.AUTH_USER_MODEL_SETTING_FULLNAME ): lazy_reference = django_context.settings.AUTH_USER_MODEL diff --git a/mypy_django_plugin/transformers/models.py b/mypy_django_plugin/transformers/models.py index 24d1ed197..a28c62b45 100644 --- a/mypy_django_plugin/transformers/models.py +++ b/mypy_django_plugin/transformers/models.py @@ -36,7 +36,7 @@ from mypy_django_plugin.exceptions import UnregisteredModelError from mypy_django_plugin.lib import fullnames, helpers from mypy_django_plugin.lib.fullnames import ANNOTATIONS_FULLNAME, ANY_ATTR_ALLOWED_CLASS_FULLNAME, MODEL_CLASS_FULLNAME -from mypy_django_plugin.transformers.fields import get_field_descriptor_types +from mypy_django_plugin.transformers.fields import FieldDescriptorTypes, get_field_descriptor_types from mypy_django_plugin.transformers.managers import ( MANAGER_METHODS_RETURNING_QUERYSET, create_manager_info_from_from_queryset_call, @@ -658,16 +658,27 @@ def run(self) -> None: # TODO: Create abstract through models? return - # Start out by prefetching a couple of dependencies needed to be able to declare any - # new, implicit, through model class. - model_base = self.lookup_typeinfo(fullnames.MODEL_CLASS_FULLNAME) - fk_field = self.lookup_typeinfo(fullnames.FOREIGN_KEY_FULLNAME) - manager_info = self.lookup_typeinfo(fullnames.MANAGER_CLASS_FULLNAME) - if model_base is None or fk_field is None or manager_info is None: - raise helpers.IncompleteDefnException() - - from_pk = self.get_pk_instance(self.model_classdef.info) - fk_set_type, fk_get_type = get_field_descriptor_types(fk_field, is_set_nullable=False, is_get_nullable=False) + if self.model_classdef.fullname == fullnames.USER_MODEL_FULLNAME: + permission_typeinfo = self.lookup_typeinfo_or_incomplete_defn_error(fullnames.PERMISSION_MODEL_FULLNAME) + self.create_through_table_class( + field_name="user_permissions", + model_name="User_user_permissions", + model_fullname="django.contrib.auth.models.User_user_permissions", + m2m_args=M2MArguments( + to=M2MTo(arg=self.model_classdef, model=Instance(permission_typeinfo, []), self=False), through=None + ), + ) + group_typeinfo = self.lookup_typeinfo_or_incomplete_defn_error(fullnames.GROUP_MODEL_FULLNAME) + self.create_through_table_class( + field_name="groups", + model_name="User_groups", + model_fullname="django.contrib.auth.models.User_groups", + m2m_args=M2MArguments( + to=M2MTo(arg=self.model_classdef, model=Instance(group_typeinfo, []), self=False), + through=None, + ), + ) + return for statement in self.statements(): # Check if this part of the class body is an assignment from a 'ManyToManyField' call @@ -689,90 +700,16 @@ def run(self) -> None: continue # Resolve argument information of the 'ManyToManyField(...)' call args = self.resolve_many_to_many_arguments(statement.rvalue, context=statement) - if ( - # Ignore calls without required 'to' argument, mypy will complain - args is None - or not isinstance(args.to.model, Instance) - # Call has explicit 'through=', no need to create any implicit through table - or args.through is not None - ): + # Ignore calls without required 'to' argument, mypy will complain + if args is None: continue - # Get the names of the implicit through model that will be generated through_model_name = f"{self.model_classdef.name}_{m2m_field_name}" - through_model_fullname = f"{self.model_classdef.info.module_name}.{through_model_name}" - # If implicit through model is already declared there's nothing more we should do - through_model = self.lookup_typeinfo(through_model_fullname) - if through_model is not None: - continue - # Declare a new, empty, implicitly generated through model class named: '_' - through_model = self.add_new_class_for_current_module( - through_model_name, bases=[Instance(model_base, [])] - ) - # We attempt to be a bit clever here and store the generated through model's fullname in - # the metadata of the class containing the 'ManyToManyField' call expression, where its - # identifier is the field name of the 'ManyToManyField'. This would allow the containing - # model to always find the implicit through model, so that it doesn't get lost. - model_metadata = helpers.get_django_metadata(self.model_classdef.info) - model_metadata.setdefault("m2m_throughs", {}) - model_metadata["m2m_throughs"][m2m_field_name] = through_model.fullname - # Add a 'pk' symbol to the model class - helpers.add_new_sym_for_info( - through_model, name="pk", sym_type=self.default_pk_instance.copy_modified() - ) - # Add an 'id' symbol to the model class - helpers.add_new_sym_for_info( - through_model, name="id", sym_type=self.default_pk_instance.copy_modified() - ) - # Add the foreign key to the model containing the 'ManyToManyField' call: - # or from_ - from_name = ( - f"from_{self.model_classdef.name.lower()}" if args.to.self else self.model_classdef.name.lower() - ) - helpers.add_new_sym_for_info( - through_model, - name=from_name, - sym_type=Instance( - fk_field, - [ - helpers.convert_any_to_type(fk_set_type, Instance(self.model_classdef.info, [])), - helpers.convert_any_to_type(fk_get_type, Instance(self.model_classdef.info, [])), - ], - ), - ) - # Add the foreign key's '_id' field: _id or from__id - helpers.add_new_sym_for_info(through_model, name=f"{from_name}_id", sym_type=from_pk.copy_modified()) - # Add the foreign key to the model on the opposite side of the relation - # i.e. the model given as 'to' argument to the 'ManyToManyField' call: - # or to_ - to_name = f"to_{args.to.model.type.name.lower()}" if args.to.self else args.to.model.type.name.lower() - helpers.add_new_sym_for_info( - through_model, - name=to_name, - sym_type=Instance( - fk_field, - [ - helpers.convert_any_to_type(fk_set_type, args.to.model), - helpers.convert_any_to_type(fk_get_type, args.to.model), - ], - ), - ) - # Add the foreign key's '_id' field: _id or to__id - other_pk = self.get_pk_instance(args.to.model.type) - helpers.add_new_sym_for_info(through_model, name=f"{to_name}_id", sym_type=other_pk.copy_modified()) - # Add a manager named 'objects' - helpers.add_new_sym_for_info( - through_model, - name="objects", - sym_type=Instance(manager_info, [Instance(through_model, [])]), - is_classvar=True, - ) - # Also add manager as '_default_manager' attribute - helpers.add_new_sym_for_info( - through_model, - name="_default_manager", - sym_type=Instance(manager_info, [Instance(through_model, [])]), - is_classvar=True, + self.create_through_table_class( + field_name=m2m_field_name, + model_name=through_model_name, + model_fullname=f"{self.model_classdef.info.module_name}.{through_model_name}", + m2m_args=args, ) @cached_property @@ -785,6 +722,35 @@ def default_pk_instance(self) -> Instance: list(get_field_descriptor_types(default_pk_field, is_set_nullable=True, is_get_nullable=False)), ) + @cached_property + def model_pk_instance(self) -> Instance: + return self.get_pk_instance(self.model_classdef.info) + + @cached_property + def model_base(self) -> TypeInfo: + info = self.lookup_typeinfo(fullnames.MODEL_CLASS_FULLNAME) + if info is None: + raise helpers.IncompleteDefnException() + return info + + @cached_property + def fk_field(self) -> TypeInfo: + info = self.lookup_typeinfo(fullnames.FOREIGN_KEY_FULLNAME) + if info is None: + raise helpers.IncompleteDefnException() + return info + + @cached_property + def manager_info(self) -> TypeInfo: + info = self.lookup_typeinfo(fullnames.MANAGER_CLASS_FULLNAME) + if info is None: + raise helpers.IncompleteDefnException() + return info + + @cached_property + def fk_field_types(self) -> FieldDescriptorTypes: + return get_field_descriptor_types(self.fk_field, is_set_nullable=False, is_get_nullable=False) + def get_pk_instance(self, model: TypeInfo, /) -> Instance: """ Get a primary key instance of provided model's type info. If primary key can't be resolved, @@ -797,6 +763,86 @@ def get_pk_instance(self, model: TypeInfo, /) -> Instance: return pk.type return self.default_pk_instance + def create_through_table_class( + self, field_name: str, model_name: str, model_fullname: str, m2m_args: M2MArguments + ) -> None: + if ( + not isinstance(m2m_args.to.model, Instance) + # Call has explicit 'through=', no need to create any implicit through table + or m2m_args.through is not None + ): + return + + # If through model is already declared there's nothing more we should do + through_model = self.lookup_typeinfo(model_fullname) + if through_model is not None: + return + # Declare a new, empty, implicitly generated through model class named: '_' + through_model = self.add_new_class_for_current_module(model_name, bases=[Instance(self.model_base, [])]) + # We attempt to be a bit clever here and store the generated through model's fullname in + # the metadata of the class containing the 'ManyToManyField' call expression, where its + # identifier is the field name of the 'ManyToManyField'. This would allow the containing + # model to always find the implicit through model, so that it doesn't get lost. + model_metadata = helpers.get_django_metadata(self.model_classdef.info) + model_metadata.setdefault("m2m_throughs", {}) + model_metadata["m2m_throughs"][field_name] = through_model.fullname + # Add a 'pk' symbol to the model class + helpers.add_new_sym_for_info(through_model, name="pk", sym_type=self.default_pk_instance.copy_modified()) + # Add an 'id' symbol to the model class + helpers.add_new_sym_for_info(through_model, name="id", sym_type=self.default_pk_instance.copy_modified()) + # Add the foreign key to the model containing the 'ManyToManyField' call: + # or from_ + from_name = f"from_{self.model_classdef.name.lower()}" if m2m_args.to.self else self.model_classdef.name.lower() + helpers.add_new_sym_for_info( + through_model, + name=from_name, + sym_type=Instance( + self.fk_field, + [ + helpers.convert_any_to_type(self.fk_field_types.set, Instance(self.model_classdef.info, [])), + helpers.convert_any_to_type(self.fk_field_types.get, Instance(self.model_classdef.info, [])), + ], + ), + ) + # Add the foreign key's '_id' field: _id or from__id + helpers.add_new_sym_for_info( + through_model, name=f"{from_name}_id", sym_type=self.model_pk_instance.copy_modified() + ) + # Add the foreign key to the model on the opposite side of the relation + # i.e. the model given as 'to' argument to the 'ManyToManyField' call: + # or to_ + to_name = ( + f"to_{m2m_args.to.model.type.name.lower()}" if m2m_args.to.self else m2m_args.to.model.type.name.lower() + ) + helpers.add_new_sym_for_info( + through_model, + name=to_name, + sym_type=Instance( + self.fk_field, + [ + helpers.convert_any_to_type(self.fk_field_types.set, m2m_args.to.model), + helpers.convert_any_to_type(self.fk_field_types.get, m2m_args.to.model), + ], + ), + ) + # Add the foreign key's '_id' field: _id or to__id + other_pk = self.get_pk_instance(m2m_args.to.model.type) + helpers.add_new_sym_for_info(through_model, name=f"{to_name}_id", sym_type=other_pk.copy_modified()) + # Add a manager named 'objects' + helpers.add_new_sym_for_info( + through_model, + name="objects", + sym_type=Instance(self.manager_info, [Instance(through_model, [])]), + is_classvar=True, + ) + # Also add manager as '_default_manager' attribute + helpers.add_new_sym_for_info( + through_model, + name="_default_manager", + sym_type=Instance(self.manager_info, [Instance(through_model, [])]), + is_classvar=True, + ) + def resolve_many_to_many_arguments(self, call: CallExpr, /, context: Context) -> Optional[M2MArguments]: """ Inspect a 'ManyToManyField(...)' call to collect argument data on any 'to' and diff --git a/tests/assert_type/contrib/admin/test_admin_models.py b/tests/assert_type/contrib/admin/test_admin_models.py new file mode 100644 index 000000000..8b728e9d3 --- /dev/null +++ b/tests/assert_type/contrib/admin/test_admin_models.py @@ -0,0 +1,20 @@ +from datetime import datetime +from typing import Optional + +from django.contrib.admin.models import LogEntry, LogEntryManager +from django.contrib.auth.models import AbstractUser +from django.contrib.contenttypes.models import ContentType +from typing_extensions import assert_type + +log_entry = LogEntry() +assert_type(log_entry.id, int) +assert_type(log_entry.pk, int) +assert_type(log_entry.action_time, datetime) +assert_type(log_entry.user, AbstractUser) +assert_type(log_entry.content_type, Optional[ContentType]) +assert_type(log_entry.content_type_id, Optional[int]) +assert_type(log_entry.object_id, Optional[str]) +assert_type(log_entry.object_repr, str) +assert_type(log_entry.action_flag, int) +assert_type(log_entry.change_message, str) +assert_type(LogEntry.objects, LogEntryManager) diff --git a/tests/assert_type/contrib/auth/test_auth_models.py b/tests/assert_type/contrib/auth/test_auth_models.py new file mode 100644 index 000000000..97d3cd3d2 --- /dev/null +++ b/tests/assert_type/contrib/auth/test_auth_models.py @@ -0,0 +1,52 @@ +from datetime import datetime +from typing import Any, Optional, Type + +from django.contrib.auth.models import Group, Group_permissions, Permission, User +from django.contrib.contenttypes.models import ContentType +from django.db.models import Manager +from typing_extensions import assert_type + +user = User() +assert_type(user.id, int) +assert_type(user.pk, int) +assert_type(user.password, str) +assert_type(user.last_login, Optional[datetime]) +assert_type(user.is_active, bool) +assert_type(user.username, str) +assert_type(user.first_name, str) +assert_type(user.last_name, str) +assert_type(user.email, str) +assert_type(user.is_staff, bool) +assert_type(user.is_active, bool) +assert_type(user.date_joined, datetime) +assert_type(user.groups.get(), Group) +# '.through' should really by 'Type[Any]' but pyright doesn't follow along +assert_type(user.groups.through, Type[Any]) # pyright: ignore[reportAssertTypeFailure] +assert_type(user.user_permissions.get(), Permission) +# '.through' should really by 'Type[Any]' but pyright doesn't follow along +assert_type(user.user_permissions.through, Type[Any]) # pyright: ignore[reportAssertTypeFailure] + +group = Group() +assert_type(group.permissions.get(), Permission) +# Pyright doesn't allow "runtime" usage of @type_check_only 'Group_permissions' but +# we're only type checking these files so it should be fine. +assert_type(group.permissions.through, Type[Group_permissions]) # pyright: ignore[reportGeneralTypeIssues] +assert_type(Group.permissions.through, Type[Group_permissions]) # pyright: ignore[reportGeneralTypeIssues] +assert_type(Group.permissions.through.objects, Manager[Group_permissions]) # pyright: ignore[reportGeneralTypeIssues] + +group_permissions = Group.permissions.through.objects.get() +assert_type(group_permissions.id, int) +assert_type(group_permissions.pk, int) +assert_type(group_permissions.group, Group) +assert_type(group_permissions.group_id, int) +assert_type(group_permissions.permission, Permission) +assert_type(group_permissions.permission_id, int) + +permission = Permission() +assert_type(permission.id, int) +assert_type(permission.pk, int) +assert_type(permission.name, str) +assert_type(permission.content_type, ContentType) +assert_type(permission.content_type_id, int) +assert_type(permission.group_set.get(), Group) +assert_type(permission.group_set.through.objects.get(), Group_permissions) # pyright: ignore[reportGeneralTypeIssues] diff --git a/tests/assert_type/contrib/contenttypes/test_contenttypes_models.py b/tests/assert_type/contrib/contenttypes/test_contenttypes_models.py new file mode 100644 index 000000000..b7459a23e --- /dev/null +++ b/tests/assert_type/contrib/contenttypes/test_contenttypes_models.py @@ -0,0 +1,12 @@ +from django.contrib.admin.models import LogEntry +from django.contrib.auth.models import Permission +from django.contrib.contenttypes.models import ContentType +from typing_extensions import assert_type + +content_type = ContentType() +assert_type(content_type.id, int) +assert_type(content_type.pk, int) +assert_type(content_type.app_label, str) +assert_type(content_type.model, str) +assert_type(content_type.logentry_set.get(), LogEntry) +assert_type(content_type.permission_set.get(), Permission) diff --git a/tests/assert_type/contrib/flatpages/test_flatpages_models.py b/tests/assert_type/contrib/flatpages/test_flatpages_models.py new file mode 100644 index 000000000..e8a0689a2 --- /dev/null +++ b/tests/assert_type/contrib/flatpages/test_flatpages_models.py @@ -0,0 +1,26 @@ +from django.contrib.flatpages.models import FlatPage, FlatPage_sites +from django.contrib.sites.models import Site +from django.db.models import Manager +from typing_extensions import assert_type + +flat_page = FlatPage() +assert_type(flat_page.id, int) +assert_type(flat_page.pk, int) +assert_type(flat_page.url, str) +assert_type(flat_page.title, str) +assert_type(flat_page.content, str) +assert_type(flat_page.enable_comments, bool) +assert_type(flat_page.template_name, str) +assert_type(flat_page.registration_required, bool) +assert_type(flat_page.sites.get(), Site) + +# Pyright doesn't allow "runtime" usage of @type_check_only 'FlatPage_sites' but +# we're only type checking these files so it should be fine. +assert_type(FlatPage.sites.through.objects, Manager[FlatPage_sites]) # pyright: ignore[reportGeneralTypeIssues] +flat_page_sites = FlatPage.sites.through.objects.get() +assert_type(flat_page_sites.id, int) +assert_type(flat_page_sites.pk, int) +assert_type(flat_page_sites.site, Site) +assert_type(flat_page_sites.site_id, int) +assert_type(flat_page_sites.flatpage, FlatPage) +assert_type(flat_page_sites.flatpage_id, int) diff --git a/tests/assert_type/contrib/sessions/test_sessions_models.py b/tests/assert_type/contrib/sessions/test_sessions_models.py new file mode 100644 index 000000000..d7e719076 --- /dev/null +++ b/tests/assert_type/contrib/sessions/test_sessions_models.py @@ -0,0 +1,11 @@ +from datetime import datetime + +from django.contrib.sessions.models import Session, SessionManager +from typing_extensions import assert_type + +session = Session() +assert_type(session.session_key, str) +assert_type(session.pk, str) +assert_type(session.session_data, str) +assert_type(session.expire_date, datetime) +assert_type(session.objects, SessionManager[Session]) diff --git a/tests/assert_type/contrib/sites/test_sites_models.py b/tests/assert_type/contrib/sites/test_sites_models.py new file mode 100644 index 000000000..d47b6168c --- /dev/null +++ b/tests/assert_type/contrib/sites/test_sites_models.py @@ -0,0 +1,18 @@ +from typing import Type + +from django.contrib.flatpages.models import FlatPage, FlatPage_sites +from django.contrib.redirects.models import Redirect +from django.contrib.sites.models import Site +from typing_extensions import assert_type + +site = Site() +assert_type(site.id, int) +assert_type(site.pk, int) +assert_type(site.domain, str) +assert_type(site.name, str) +assert_type(site.flatpage_set.get(), FlatPage) +assert_type(site.redirect_set.get(), Redirect) + +# Pyright doesn't allow "runtime" usage of @type_check_only 'FlatPage_sites' but +# we're only type checking these files so it should be fine. +assert_type(site.flatpage_set.through, Type[FlatPage_sites]) # pyright: ignore[reportGeneralTypeIssues] diff --git a/tests/typecheck/models/test_contrib_models.yml b/tests/typecheck/models/test_contrib_models.yml index e055640f2..eef6c077d 100644 --- a/tests/typecheck/models/test_contrib_models.yml +++ b/tests/typecheck/models/test_contrib_models.yml @@ -15,8 +15,8 @@ reveal_type(User().is_anonymous) # N: Revealed type is "Literal[False]" reveal_type(User().groups.get()) # N: Revealed type is "django.contrib.auth.models.Group" reveal_type(User().user_permissions.get()) # N: Revealed type is "django.contrib.auth.models.Permission" - reveal_type(User.groups) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[django.contrib.auth.models.Group, django.db.models.base.Model]" - reveal_type(User.user_permissions) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[django.contrib.auth.models.Permission, django.db.models.base.Model]" + reveal_type(User.groups) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[django.contrib.auth.models.Group, Any]" + reveal_type(User.user_permissions) # N: Revealed type is "django.db.models.fields.related_descriptors.ManyToManyDescriptor[django.contrib.auth.models.Permission, Any]" from django.contrib.auth.models import AnonymousUser reveal_type(AnonymousUser().is_authenticated) # N: Revealed type is "Literal[False]" From 9c192e23fdefa9945d943c18eca0bb1a2bc0ae5f Mon Sep 17 00:00:00 2001 From: Petter Friberg Date: Sun, 9 Jun 2024 10:45:15 +0200 Subject: [PATCH 2/3] fixup! Add type hints to builtin models' fields --- django-stubs/contrib/sessions/models.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/django-stubs/contrib/sessions/models.pyi b/django-stubs/contrib/sessions/models.pyi index 266e32d34..013e4915e 100644 --- a/django-stubs/contrib/sessions/models.pyi +++ b/django-stubs/contrib/sessions/models.pyi @@ -8,4 +8,4 @@ _T = TypeVar("_T", bound=Session) class SessionManager(BaseSessionManager[_T]): ... class Session(AbstractBaseSession): - objects: ClassVar[SessionManager[Self]] + objects: ClassVar[SessionManager[Self]] # type: ignore[assignment] From 2715917cb33290e8496a0a4c1a3a4f94f0361985 Mon Sep 17 00:00:00 2001 From: Petter Friberg Date: Sun, 9 Jun 2024 11:17:50 +0200 Subject: [PATCH 3/3] fixup! fixup! Add type hints to builtin models' fields --- scripts/stubtest/allowlist_todo.txt | 4 ---- 1 file changed, 4 deletions(-) diff --git a/scripts/stubtest/allowlist_todo.txt b/scripts/stubtest/allowlist_todo.txt index 63dade198..363800220 100644 --- a/scripts/stubtest/allowlist_todo.txt +++ b/scripts/stubtest/allowlist_todo.txt @@ -169,9 +169,7 @@ django.contrib.contenttypes.management.inject_rename_contenttypes_operations django.contrib.contenttypes.models.ContentType.app_label django.contrib.contenttypes.models.ContentType.app_labeled_name django.contrib.contenttypes.models.ContentType.id -django.contrib.contenttypes.models.ContentType.logentry_set django.contrib.contenttypes.models.ContentType.model -django.contrib.contenttypes.models.ContentType.permission_set django.contrib.contenttypes.models.ContentTypeManager.__init__ django.contrib.contenttypes.models.ContentTypeManager.__slotnames__ django.contrib.flatpages.admin.FlatPageAdmin @@ -513,10 +511,8 @@ django.contrib.sessions.models.SessionManager.__slotnames__ django.contrib.sitemaps.views.SitemapIndexItem django.contrib.sites.admin.SiteAdmin django.contrib.sites.models.Site.domain -django.contrib.sites.models.Site.flatpage_set django.contrib.sites.models.Site.id django.contrib.sites.models.Site.name -django.contrib.sites.models.Site.redirect_set django.contrib.sites.models.SiteManager.__slotnames__ django.contrib.staticfiles.finders.BaseStorageFinder.storage django.contrib.staticfiles.finders.DefaultStorageFinder.storage