Skip to content

Commit

Permalink
Correct types of source expression functions
Browse files Browse the repository at this point in the history
  • Loading branch information
adamchainz committed Sep 9, 2024
1 parent 07f44d0 commit de983de
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/django_mysql/models/expressions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from __future__ import annotations

from typing import Any
from typing import Iterable
from typing import Sequence

from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import F
from django.db.models import Value
from django.db.models.expressions import BaseExpression
from django.db.models.expressions import Combinable
from django.db.models.expressions import Expression
from django.db.models.sql.compiler import SQLCompiler

from django_mysql.utils import collapse_spaces
Expand All @@ -18,10 +20,10 @@ def __init__(self, lhs: BaseExpression, rhs: BaseExpression) -> None:
self.lhs = lhs
self.rhs = rhs

def get_source_expressions(self) -> list[BaseExpression]:
def get_source_expressions(self) -> list[Expression]:
return [self.lhs, self.rhs]

def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None:
def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None:
self.lhs, self.rhs = exprs


Expand Down Expand Up @@ -138,10 +140,10 @@ def __init__(self, lhs: BaseExpression) -> None:
super().__init__()
self.lhs = lhs

def get_source_expressions(self) -> list[BaseExpression]:
def get_source_expressions(self) -> list[Expression]:
return [self.lhs]

def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None:
def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None:
(self.lhs,) = exprs

def as_sql(
Expand Down Expand Up @@ -170,10 +172,10 @@ def __init__(self, lhs: BaseExpression) -> None:
super().__init__()
self.lhs = lhs

def get_source_expressions(self) -> list[BaseExpression]:
def get_source_expressions(self) -> list[Expression]:
return [self.lhs]

def set_source_expressions(self, exprs: Iterable[BaseExpression]) -> None:
def set_source_expressions(self, exprs: Sequence[Combinable | Expression]) -> None:
(self.lhs,) = exprs

def as_sql(
Expand Down

0 comments on commit de983de

Please sign in to comment.