From 27fffbf9ed3d012f7969f37096296e3d03a9a491 Mon Sep 17 00:00:00 2001 From: Adam Johnson Date: Mon, 15 Aug 2022 10:14:13 +0100 Subject: [PATCH] Use Mypy's strict mode --- .pre-commit-config.yaml | 3 ++ pyproject.toml | 11 ++-- src/django_mysql/compat.py | 2 + src/django_mysql/models/__init__.py | 59 ++++++++++++++++------ src/django_mysql/models/fields/__init__.py | 4 +- src/django_mysql/models/fields/dynamic.py | 10 ++-- src/django_mysql/operations.py | 34 ++++++------- src/django_mysql/status.py | 6 +-- tests/testapp/models.py | 3 +- tests/testapp/test_functions.py | 2 +- tests/testapp/test_locks.py | 4 +- 11 files changed, 85 insertions(+), 53 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e07dd48fe..5a462b29c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,3 +45,6 @@ repos: rev: v0.971 hooks: - id: mypy + additional_dependencies: + - django-stubs==1.12.0 + - mysqlclient diff --git a/pyproject.toml b/pyproject.toml index 808a4afbf..3c5c17dca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,20 +5,19 @@ build-backend = "setuptools.build_meta" [tool.black] target-version = ['py37'] +[tool.django-stubs] +django_settings_module = "tests.settings" + [tool.isort] profile = "black" add_imports = "from __future__ import annotations" [tool.mypy] -check_untyped_defs = true -disallow_any_generics = true -disallow_incomplete_defs = true -disallow_untyped_defs = true mypy_path = "src/" -no_implicit_optional = true +plugins = ["mypy_django_plugin.main"] show_error_codes = true +strict = true warn_unreachable = true -warn_unused_ignores = true [[tool.mypy.overrides]] module = "tests.*" diff --git a/src/django_mysql/compat.py b/src/django_mysql/compat.py index c7b280167..b7c307ca3 100644 --- a/src/django_mysql/compat.py +++ b/src/django_mysql/compat.py @@ -3,6 +3,8 @@ import sys from typing import Any, Callable, TypeVar, cast +__all__ = ("cache",) + if sys.version_info >= (3, 9): from functools import cache else: diff --git a/src/django_mysql/models/__init__.py b/src/django_mysql/models/__init__.py index 83ca64e4d..416500759 100644 --- a/src/django_mysql/models/__init__.py +++ b/src/django_mysql/models/__init__.py @@ -1,19 +1,9 @@ -""" -isort:skip_file -""" -from django_mysql.models.base import Model # noqa -from django_mysql.models.aggregates import BitAnd, BitOr, BitXor, GroupConcat # noqa -from django_mysql.models.expressions import ListF, SetF # noqa -from django_mysql.models.query import ( # noqa - add_QuerySetMixin, - ApproximateInt, - SmartChunkedIterator, - SmartIterator, - pt_visual_explain, - QuerySet, - QuerySetMixin, -) -from django_mysql.models.fields import ( # noqa +from __future__ import annotations + +from django_mysql.models.aggregates import BitAnd, BitOr, BitXor, GroupConcat +from django_mysql.models.base import Model +from django_mysql.models.expressions import ListF, SetF +from django_mysql.models.fields import ( Bit1BooleanField, DynamicField, EnumField, @@ -26,3 +16,40 @@ SizedBinaryField, SizedTextField, ) +from django_mysql.models.query import ( + ApproximateInt, + QuerySet, + QuerySetMixin, + SmartChunkedIterator, + SmartIterator, + add_QuerySetMixin, + pt_visual_explain, +) + +__all__ = ( + "add_QuerySetMixin", + "ApproximateInt", + "Bit1BooleanField", + "BitAnd", + "BitOr", + "BitXor", + "DynamicField", + "EnumField", + "FixedCharField", + "GroupConcat", + "ListCharField", + "ListF", + "ListTextField", + "Model", + "NullBit1BooleanField", + "pt_visual_explain", + "QuerySet", + "QuerySetMixin", + "SetCharField", + "SetF", + "SetTextField", + "SizedBinaryField", + "SizedTextField", + "SmartChunkedIterator", + "SmartIterator", +) diff --git a/src/django_mysql/models/fields/__init__.py b/src/django_mysql/models/fields/__init__.py index 0284b3b18..d575652a1 100644 --- a/src/django_mysql/models/fields/__init__.py +++ b/src/django_mysql/models/fields/__init__.py @@ -8,7 +8,7 @@ from django_mysql.models.fields.sets import SetCharField, SetTextField from django_mysql.models.fields.sizes import SizedBinaryField, SizedTextField -__all__ = [ +__all__ = ( "Bit1BooleanField", "DynamicField", "EnumField", @@ -20,4 +20,4 @@ "SetTextField", "SizedBinaryField", "SizedTextField", -] +) diff --git a/src/django_mysql/models/fields/dynamic.py b/src/django_mysql/models/fields/dynamic.py index 287b7ae90..09bcb57c7 100644 --- a/src/django_mysql/models/fields/dynamic.py +++ b/src/django_mysql/models/fields/dynamic.py @@ -21,9 +21,9 @@ from django.forms import Field as FormField from django.utils.translation import gettext_lazy as _ -from django_mysql.checks import mysql_connections from django_mysql.models.lookups import DynColHasKey from django_mysql.typing import DeconstructResult +from django_mysql.utils import mysql_connections try: import mariadb_dyncol @@ -85,7 +85,7 @@ def check(self, **kwargs: Any) -> list[checks.CheckMessage]: return errors def _check_mariadb_dyncol(self) -> list[checks.CheckMessage]: - errors = [] + errors: list[checks.CheckMessage] = [] if mariadb_dyncol is None: errors.append( checks.Error( @@ -98,7 +98,7 @@ def _check_mariadb_dyncol(self) -> list[checks.CheckMessage]: return errors def _check_mariadb_version(self) -> list[checks.CheckMessage]: - errors = [] + errors: list[checks.CheckMessage] = [] any_conn_works = any( (conn.vendor == "mysql" and conn.mysql_is_mariadb) @@ -117,7 +117,7 @@ def _check_mariadb_version(self) -> list[checks.CheckMessage]: return errors def _check_character_set(self) -> list[checks.CheckMessage]: - errors = [] + errors: list[checks.CheckMessage] = [] conn = None for _alias, check_conn in mysql_connections(): @@ -149,7 +149,7 @@ def _check_character_set(self) -> list[checks.CheckMessage]: def _check_spec_recursively( self, spec: Any, path: str = "" ) -> list[checks.CheckMessage]: - errors = [] + errors: list[checks.CheckMessage] = [] if not isinstance(spec, dict): errors.append( diff --git a/src/django_mysql/operations.py b/src/django_mysql/operations.py index ea1ff5b4c..6b853dcc3 100644 --- a/src/django_mysql/operations.py +++ b/src/django_mysql/operations.py @@ -2,7 +2,7 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.migrations.operations.base import Operation -from django.db.migrations.state import ModelState +from django.db.migrations.state import ProjectState from django.utils.functional import cached_property @@ -15,15 +15,15 @@ def __init__(self, name: str, soname: str) -> None: self.name = name self.soname = soname - def state_forwards(self, app_label: str, state: ModelState) -> None: + def state_forwards(self, app_label: str, state: ProjectState) -> None: pass # pragma: no cover def database_forwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_st: ModelState, - to_st: ModelState, + from_st: ProjectState, + to_st: ProjectState, ) -> None: if not self.plugin_installed(schema_editor): schema_editor.execute( @@ -34,8 +34,8 @@ def database_backwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_st: ModelState, - to_st: ModelState, + from_st: ProjectState, + to_st: ProjectState, ) -> None: if self.plugin_installed(schema_editor): schema_editor.execute("UNINSTALL PLUGIN %s" % self.name) @@ -63,15 +63,15 @@ class InstallSOName(Operation): def __init__(self, soname: str) -> None: self.soname = soname - def state_forwards(self, app_label: str, state: ModelState) -> None: + def state_forwards(self, app_label: str, state: ProjectState) -> None: pass # pragma: no cover def database_forwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_st: ModelState, - to_st: ModelState, + from_st: ProjectState, + to_st: ProjectState, ) -> None: schema_editor.execute("INSTALL SONAME %s", (self.soname,)) @@ -79,8 +79,8 @@ def database_backwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_st: ModelState, - to_st: ModelState, + from_st: ProjectState, + to_st: ProjectState, ) -> None: schema_editor.execute("UNINSTALL SONAME %s", (self.soname,)) @@ -100,15 +100,15 @@ def __init__( def reversible(self) -> bool: return self.from_engine is not None - def state_forwards(self, app_label: str, state: ModelState) -> None: + def state_forwards(self, app_label: str, state: ProjectState) -> None: pass def database_forwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_state: ModelState, - to_state: ModelState, + from_state: ProjectState, + to_state: ProjectState, ) -> None: self._change_engine(app_label, schema_editor, to_state, engine=self.engine) @@ -116,8 +116,8 @@ def database_backwards( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - from_state: ModelState, - to_state: ModelState, + from_state: ProjectState, + to_state: ProjectState, ) -> None: if self.from_engine is None: raise NotImplementedError("You cannot reverse this operation") @@ -128,7 +128,7 @@ def _change_engine( self, app_label: str, schema_editor: BaseDatabaseSchemaEditor, - to_state: ModelState, + to_state: ProjectState, engine: str, ) -> None: new_model = to_state.apps.get_model(app_label, self.name) diff --git a/src/django_mysql/status.py b/src/django_mysql/status.py index e16346770..a8451b86e 100644 --- a/src/django_mysql/status.py +++ b/src/django_mysql/status.py @@ -6,7 +6,6 @@ from django.db import connections from django.db.backends.utils import CursorWrapper from django.db.utils import DEFAULT_DB_ALIAS -from django.utils.functional import SimpleLazyObject from django_mysql.exceptions import TimeoutError @@ -16,6 +15,7 @@ class BaseStatus: Base class for the status classes """ + __slots__ = ("db",) query = "" def __init__(self, using: str | None = None) -> None: @@ -127,5 +127,5 @@ class SessionStatus(BaseStatus): query = "SHOW SESSION STATUS" -global_status = SimpleLazyObject(GlobalStatus) -session_status = SimpleLazyObject(SessionStatus) +global_status = GlobalStatus() +session_status = SessionStatus() diff --git a/tests/testapp/models.py b/tests/testapp/models.py index 1e655bba2..2abac8adb 100644 --- a/tests/testapp/models.py +++ b/tests/testapp/models.py @@ -5,6 +5,7 @@ from typing import Any import django +from django.core import checks from django.db import connection from django.db.models import ( CASCADE, @@ -119,7 +120,7 @@ class DynamicModel(Model): ) @classmethod - def check(cls, **kwargs): + def check(cls, **kwargs: Any) -> list[checks.CheckMessage]: # Disable the checks on MySQL so that checks tests don't fail if not connection.mysql_is_mariadb: return [] diff --git a/tests/testapp/test_functions.py b/tests/testapp/test_functions.py index c3e134fa7..314c6da21 100644 --- a/tests/testapp/test_functions.py +++ b/tests/testapp/test_functions.py @@ -354,7 +354,7 @@ def test_sha2_bad_hash_len(self): class InformationFunctionTests(TestCase): - databases = ["default", "other"] + databases = {"default", "other"} def test_last_insert_id(self): Alphabet.objects.create(a=7891) diff --git a/tests/testapp/test_locks.py b/tests/testapp/test_locks.py index 83c30df43..24e9a6d30 100644 --- a/tests/testapp/test_locks.py +++ b/tests/testapp/test_locks.py @@ -23,7 +23,7 @@ class LockTests(TestCase): - databases = ["default", "other"] + databases = {"default", "other"} @classmethod def setUpClass(cls): @@ -216,7 +216,7 @@ def test_acquire_release(self): class TableLockTests(TransactionTestCase): - databases = ["default", "other"] + databases = {"default", "other"} def tearDown(self): Alphabet.objects.all().delete()