From f0091b1af12e17003e2238cc48014b0639b5dee8 Mon Sep 17 00:00:00 2001 From: Maarten ter Huurne Date: Mon, 25 Mar 2024 17:30:48 +0100 Subject: [PATCH] Add minimal type annotations to make mypy pass --- tests/models.py | 23 ++++++++++++++--------- tests/test_fields/test_field_tracker.py | 15 +++++++++------ 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/tests/models.py b/tests/models.py index 6a5849e6..90d50e8c 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,7 +1,12 @@ +from __future__ import annotations + +from typing import ClassVar + from django.db import models from django.db.models import Manager from django.db.models.query_utils import DeferredAttribute from django.utils.translation import gettext_lazy as _ +from typing_extensions import Self from model_utils import Choices from model_utils.fields import MonitorField, SplitField, StatusField, UUIDField @@ -32,7 +37,7 @@ class InheritanceManagerTestParent(models.Model): related_self = models.OneToOneField( "self", related_name="imtests_self", null=True, on_delete=models.CASCADE) - objects = InheritanceManager() + objects: ClassVar[Manager[Self]] = InheritanceManager() def __str__(self): return "{}({})".format( @@ -44,7 +49,7 @@ def __str__(self): class InheritanceManagerTestChild1(InheritanceManagerTestParent): non_related_field_using_descriptor_2 = models.FileField(upload_to="test") normal_field_2 = models.TextField() - objects = InheritanceManager() + objects: ClassVar[Manager[Self]] = InheritanceManager() class InheritanceManagerTestGrandChild1(InheritanceManagerTestChild1): @@ -171,8 +176,8 @@ class Post(models.Model): order = models.IntegerField() objects = models.Manager() - public = QueryManager(published=True) - public_confirmed = QueryManager( + public: ClassVar[QueryManager[Self]] = QueryManager(published=True) + public_confirmed: ClassVar[QueryManager[Self]] = QueryManager( models.Q(published=True) & models.Q(confirmed=True)) public_reversed = QueryManager(published=True).order_by("-order") @@ -193,7 +198,7 @@ class Meta: class AbstractTracked(models.Model): - number = 1 + number: models.IntegerField class Meta: abstract = True @@ -339,13 +344,13 @@ class SoftDeletable(SoftDeletableModel): """ name = models.CharField(max_length=20) - all_objects = models.Manager() + all_objects: ClassVar[Manager[SoftDeletable]] = models.Manager() class CustomSoftDelete(SoftDeletableModel): is_read = models.BooleanField(default=False) - objects = CustomSoftDeleteManager() + objects: ClassVar[CustomSoftDeleteManager[Self]] = CustomSoftDeleteManager() class StringyDescriptor: @@ -389,7 +394,7 @@ class ModelWithCustomDescriptor(models.Model): class BoxJoinModel(models.Model): name = models.CharField(max_length=32) - objects = JoinManager() + objects: ClassVar[JoinManager[Self]] = JoinManager() class JoinItemForeignKey(models.Model): @@ -399,7 +404,7 @@ class JoinItemForeignKey(models.Model): null=True, on_delete=models.CASCADE ) - objects = JoinManager() + objects: ClassVar[JoinManager[Self]] = JoinManager() class CustomUUIDModel(UUIDModel): diff --git a/tests/test_fields/test_field_tracker.py b/tests/test_fields/test_field_tracker.py index a2e12c3d..038aaaa7 100644 --- a/tests/test_fields/test_field_tracker.py +++ b/tests/test_fields/test_field_tracker.py @@ -1,7 +1,10 @@ +from __future__ import annotations + from unittest import skip from django.core.cache import cache from django.core.exceptions import FieldError +from django.db import models from django.db.models.fields.files import FieldFile from django.test import TestCase @@ -69,7 +72,7 @@ def test_pre_save_previous(self): class FieldTrackerTests(FieldTrackerTestCase, FieldTrackerCommonTests): - tracked_class = Tracked + tracked_class: type[models.Model] = Tracked def setUp(self): self.instance = self.tracked_class() @@ -276,7 +279,7 @@ def test_with_deferred_fields_access_multiple(self): class FieldTrackedModelCustomTests(FieldTrackerTestCase, FieldTrackerCommonTests): - tracked_class = TrackedNotDefault + tracked_class: type[models.Model] = TrackedNotDefault def setUp(self): self.instance = self.tracked_class() @@ -407,7 +410,7 @@ def test_current(self): class FieldTrackedModelMultiTests(FieldTrackerTestCase, FieldTrackerCommonTests): - tracked_class = TrackedMultiple + tracked_class: type[models.Model] = TrackedMultiple def setUp(self): self.instance = self.tracked_class() @@ -498,8 +501,8 @@ def test_current(self): class FieldTrackerForeignKeyTests(FieldTrackerTestCase): - fk_class = Tracked - tracked_class = TrackedFK + fk_class: type[models.Model] = Tracked + tracked_class: type[models.Model] = TrackedFK def setUp(self): self.old_fk = self.fk_class.objects.create(number=8) @@ -725,7 +728,7 @@ def test_current(self): class ModelTrackerTests(FieldTrackerTests): - tracked_class = ModelTracked + tracked_class: type[models.Model] = ModelTracked def test_cache_compatible(self): cache.set('key', self.instance)