Skip to content

Commit

Permalink
Use Mypy's strict mode
Browse files Browse the repository at this point in the history
  • Loading branch information
adamchainz committed Aug 15, 2022
1 parent 8da03cd commit 27fffbf
Show file tree
Hide file tree
Showing 11 changed files with 85 additions and 53 deletions.
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,6 @@ repos:
rev: v0.971
hooks:
- id: mypy
additional_dependencies:
- django-stubs==1.12.0
- mysqlclient
11 changes: 5 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,19 @@ build-backend = "setuptools.build_meta"
[tool.black]
target-version = ['py37']

[tool.django-stubs]
django_settings_module = "tests.settings"

[tool.isort]
profile = "black"
add_imports = "from __future__ import annotations"

[tool.mypy]
check_untyped_defs = true
disallow_any_generics = true
disallow_incomplete_defs = true
disallow_untyped_defs = true
mypy_path = "src/"
no_implicit_optional = true
plugins = ["mypy_django_plugin.main"]
show_error_codes = true
strict = true
warn_unreachable = true
warn_unused_ignores = true

[[tool.mypy.overrides]]
module = "tests.*"
Expand Down
2 changes: 2 additions & 0 deletions src/django_mysql/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import sys
from typing import Any, Callable, TypeVar, cast

__all__ = ("cache",)

if sys.version_info >= (3, 9):
from functools import cache
else:
Expand Down
59 changes: 43 additions & 16 deletions src/django_mysql/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,9 @@
"""
isort:skip_file
"""
from django_mysql.models.base import Model # noqa
from django_mysql.models.aggregates import BitAnd, BitOr, BitXor, GroupConcat # noqa
from django_mysql.models.expressions import ListF, SetF # noqa
from django_mysql.models.query import ( # noqa
add_QuerySetMixin,
ApproximateInt,
SmartChunkedIterator,
SmartIterator,
pt_visual_explain,
QuerySet,
QuerySetMixin,
)
from django_mysql.models.fields import ( # noqa
from __future__ import annotations

from django_mysql.models.aggregates import BitAnd, BitOr, BitXor, GroupConcat
from django_mysql.models.base import Model
from django_mysql.models.expressions import ListF, SetF
from django_mysql.models.fields import (
Bit1BooleanField,
DynamicField,
EnumField,
Expand All @@ -26,3 +16,40 @@
SizedBinaryField,
SizedTextField,
)
from django_mysql.models.query import (
ApproximateInt,
QuerySet,
QuerySetMixin,
SmartChunkedIterator,
SmartIterator,
add_QuerySetMixin,
pt_visual_explain,
)

__all__ = (
"add_QuerySetMixin",
"ApproximateInt",
"Bit1BooleanField",
"BitAnd",
"BitOr",
"BitXor",
"DynamicField",
"EnumField",
"FixedCharField",
"GroupConcat",
"ListCharField",
"ListF",
"ListTextField",
"Model",
"NullBit1BooleanField",
"pt_visual_explain",
"QuerySet",
"QuerySetMixin",
"SetCharField",
"SetF",
"SetTextField",
"SizedBinaryField",
"SizedTextField",
"SmartChunkedIterator",
"SmartIterator",
)
4 changes: 2 additions & 2 deletions src/django_mysql/models/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from django_mysql.models.fields.sets import SetCharField, SetTextField
from django_mysql.models.fields.sizes import SizedBinaryField, SizedTextField

__all__ = [
__all__ = (
"Bit1BooleanField",
"DynamicField",
"EnumField",
Expand All @@ -20,4 +20,4 @@
"SetTextField",
"SizedBinaryField",
"SizedTextField",
]
)
10 changes: 5 additions & 5 deletions src/django_mysql/models/fields/dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
from django.forms import Field as FormField
from django.utils.translation import gettext_lazy as _

from django_mysql.checks import mysql_connections
from django_mysql.models.lookups import DynColHasKey
from django_mysql.typing import DeconstructResult
from django_mysql.utils import mysql_connections

try:
import mariadb_dyncol
Expand Down Expand Up @@ -85,7 +85,7 @@ def check(self, **kwargs: Any) -> list[checks.CheckMessage]:
return errors

def _check_mariadb_dyncol(self) -> list[checks.CheckMessage]:
errors = []
errors: list[checks.CheckMessage] = []
if mariadb_dyncol is None:
errors.append(
checks.Error(
Expand All @@ -98,7 +98,7 @@ def _check_mariadb_dyncol(self) -> list[checks.CheckMessage]:
return errors

def _check_mariadb_version(self) -> list[checks.CheckMessage]:
errors = []
errors: list[checks.CheckMessage] = []

any_conn_works = any(
(conn.vendor == "mysql" and conn.mysql_is_mariadb)
Expand All @@ -117,7 +117,7 @@ def _check_mariadb_version(self) -> list[checks.CheckMessage]:
return errors

def _check_character_set(self) -> list[checks.CheckMessage]:
errors = []
errors: list[checks.CheckMessage] = []

conn = None
for _alias, check_conn in mysql_connections():
Expand Down Expand Up @@ -149,7 +149,7 @@ def _check_character_set(self) -> list[checks.CheckMessage]:
def _check_spec_recursively(
self, spec: Any, path: str = ""
) -> list[checks.CheckMessage]:
errors = []
errors: list[checks.CheckMessage] = []

if not isinstance(spec, dict):
errors.append(
Expand Down
34 changes: 17 additions & 17 deletions src/django_mysql/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.migrations.operations.base import Operation
from django.db.migrations.state import ModelState
from django.db.migrations.state import ProjectState
from django.utils.functional import cached_property


Expand All @@ -15,15 +15,15 @@ def __init__(self, name: str, soname: str) -> None:
self.name = name
self.soname = soname

def state_forwards(self, app_label: str, state: ModelState) -> None:
def state_forwards(self, app_label: str, state: ProjectState) -> None:
pass # pragma: no cover

def database_forwards(
self,
app_label: str,
schema_editor: BaseDatabaseSchemaEditor,
from_st: ModelState,
to_st: ModelState,
from_st: ProjectState,
to_st: ProjectState,
) -> None:
if not self.plugin_installed(schema_editor):
schema_editor.execute(
Expand All @@ -34,8 +34,8 @@ def database_backwards(
self,
app_label: str,
schema_editor: BaseDatabaseSchemaEditor,
from_st: ModelState,
to_st: ModelState,
from_st: ProjectState,
to_st: ProjectState,
) -> None:
if self.plugin_installed(schema_editor):
schema_editor.execute("UNINSTALL PLUGIN %s" % self.name)
Expand Down Expand Up @@ -63,24 +63,24 @@ class InstallSOName(Operation):
def __init__(self, soname: str) -> None:
self.soname = soname

def state_forwards(self, app_label: str, state: ModelState) -> None:
def state_forwards(self, app_label: str, state: ProjectState) -> None:
pass # pragma: no cover

def database_forwards(
self,
app_label: str,
schema_editor: BaseDatabaseSchemaEditor,
from_st: ModelState,
to_st: ModelState,
from_st: ProjectState,
to_st: ProjectState,
) -> None:
schema_editor.execute("INSTALL SONAME %s", (self.soname,))

def database_backwards(
self,
app_label: str,
schema_editor: BaseDatabaseSchemaEditor,
from_st: ModelState,
to_st: ModelState,
from_st: ProjectState,
to_st: ProjectState,
) -> None:
schema_editor.execute("UNINSTALL SONAME %s", (self.soname,))

Expand All @@ -100,24 +100,24 @@ def __init__(
def reversible(self) -> bool:
return self.from_engine is not None

def state_forwards(self, app_label: str, state: ModelState) -> None:
def state_forwards(self, app_label: str, state: ProjectState) -> None:
pass

def database_forwards(
self,
app_label: str,
schema_editor: BaseDatabaseSchemaEditor,
from_state: ModelState,
to_state: ModelState,
from_state: ProjectState,
to_state: ProjectState,
) -> None:
self._change_engine(app_label, schema_editor, to_state, engine=self.engine)

def database_backwards(
self,
app_label: str,
schema_editor: BaseDatabaseSchemaEditor,
from_state: ModelState,
to_state: ModelState,
from_state: ProjectState,
to_state: ProjectState,
) -> None:
if self.from_engine is None:
raise NotImplementedError("You cannot reverse this operation")
Expand All @@ -128,7 +128,7 @@ def _change_engine(
self,
app_label: str,
schema_editor: BaseDatabaseSchemaEditor,
to_state: ModelState,
to_state: ProjectState,
engine: str,
) -> None:
new_model = to_state.apps.get_model(app_label, self.name)
Expand Down
6 changes: 3 additions & 3 deletions src/django_mysql/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from django.db import connections
from django.db.backends.utils import CursorWrapper
from django.db.utils import DEFAULT_DB_ALIAS
from django.utils.functional import SimpleLazyObject

from django_mysql.exceptions import TimeoutError

Expand All @@ -16,6 +15,7 @@ class BaseStatus:
Base class for the status classes
"""

__slots__ = ("db",)
query = ""

def __init__(self, using: str | None = None) -> None:
Expand Down Expand Up @@ -127,5 +127,5 @@ class SessionStatus(BaseStatus):
query = "SHOW SESSION STATUS"


global_status = SimpleLazyObject(GlobalStatus)
session_status = SimpleLazyObject(SessionStatus)
global_status = GlobalStatus()
session_status = SessionStatus()
3 changes: 2 additions & 1 deletion tests/testapp/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any

import django
from django.core import checks
from django.db import connection
from django.db.models import (
CASCADE,
Expand Down Expand Up @@ -119,7 +120,7 @@ class DynamicModel(Model):
)

@classmethod
def check(cls, **kwargs):
def check(cls, **kwargs: Any) -> list[checks.CheckMessage]:
# Disable the checks on MySQL so that checks tests don't fail
if not connection.mysql_is_mariadb:
return []
Expand Down
2 changes: 1 addition & 1 deletion tests/testapp/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def test_sha2_bad_hash_len(self):

class InformationFunctionTests(TestCase):

databases = ["default", "other"]
databases = {"default", "other"}

def test_last_insert_id(self):
Alphabet.objects.create(a=7891)
Expand Down
4 changes: 2 additions & 2 deletions tests/testapp/test_locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

class LockTests(TestCase):

databases = ["default", "other"]
databases = {"default", "other"}

@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -216,7 +216,7 @@ def test_acquire_release(self):

class TableLockTests(TransactionTestCase):

databases = ["default", "other"]
databases = {"default", "other"}

def tearDown(self):
Alphabet.objects.all().delete()
Expand Down

0 comments on commit 27fffbf

Please sign in to comment.