Skip to content

Commit

Permalink
Fix mypy errors in backends/base (#2894)
Browse files Browse the repository at this point in the history
  • Loading branch information
datapythonista authored Aug 8, 2021
1 parent 75c9ca9 commit 0f596c6
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 23 deletions.
4 changes: 2 additions & 2 deletions ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
import warnings
from typing import Any, Callable, List
from typing import Any, Callable, List, Type

import ibis.expr.operations as ops
import ibis.expr.schema as sch
Expand All @@ -21,7 +21,7 @@ class BaseBackend(abc.ABC):
"""

database_class = Database
table_class = ops.DatabaseTable
table_class: Type[ops.DatabaseTable] = ops.DatabaseTable

@property
@abc.abstractmethod
Expand Down
8 changes: 6 additions & 2 deletions ibis/backends/base/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations

from typing import List

import ibis.expr.types as ir


Expand All @@ -21,7 +25,7 @@ def __repr__(self) -> str:
"""Return type name and the name of the database."""
return '{}({!r})'.format(type(self).__name__, self.name)

def __dir__(self) -> set:
def __dir__(self) -> List[str]:
"""Return a set of attributes and tables available for the database.
Returns
Expand Down Expand Up @@ -133,5 +137,5 @@ def list_tables(self, like: str = None) -> list:
like=self._qualify_like(like), database=self.name
)

def _qualify_like(self, like: str) -> str:
def _qualify_like(self, like: str | None) -> str | None:
return like
18 changes: 9 additions & 9 deletions ibis/backends/base/sql/alchemy/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib
import warnings
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

import pandas as pd
import sqlalchemy as sa
Expand Down Expand Up @@ -49,7 +49,7 @@ def __init__(self, con: sa.engine.Engine) -> None:
self.con = con
self.meta = sa.MetaData(bind=con)
self._inspector = sa.inspect(con)
self._schemas = {}
self._schemas: Dict[str, sch.Schema] = {}

@property
def inspector(self):
Expand Down Expand Up @@ -95,7 +95,7 @@ def begin(self):
yield bind

def create_table(self, name, expr=None, schema=None, database=None):
if database == self.database_name:
if database == self.current_database:
# avoid fully qualified name
database = None

Expand Down Expand Up @@ -149,7 +149,7 @@ def drop_table(
database: Optional[str] = None,
force: bool = False,
) -> None:
if database == self.database_name:
if database == self.current_database:
# avoid fully qualified name
database = None

Expand Down Expand Up @@ -199,7 +199,7 @@ def load_data(
Loading data to a table from a different database is not
yet implemented
"""
if database == self.database_name:
if database == self.current_database:
# avoid fully qualified name
database = None

Expand All @@ -213,7 +213,7 @@ def load_data(
if self.has_attachment:
# for database with attachment
# see: https://github.com/ibis-project/ibis/issues/1930
params['schema'] = self.database_name
params['schema'] = self.current_database

data.to_sql(
table_name,
Expand Down Expand Up @@ -296,7 +296,7 @@ def list_tables(
names = [x for x in names if like in x]
return sorted(names)

def raw_sql(self, query: str):
def raw_sql(self, query: str, results=False):
return _AutoCloseCursor(super().raw_sql(query))

def _log(self, sql):
Expand Down Expand Up @@ -351,7 +351,7 @@ def insert(
"""

if database == self.database_name:
if database == self.current_database:
# avoid fully qualified name
database = None

Expand All @@ -365,7 +365,7 @@ def insert(
if self.has_attachment:
# for database with attachment
# see: https://github.com/ibis-project/ibis/issues/1930
params['schema'] = self.database_name
params['schema'] = self.current_database

if isinstance(obj, pd.DataFrame):
obj.to_sql(
Expand Down
3 changes: 2 additions & 1 deletion ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import operator
from typing import Any, Dict

import sqlalchemy as sa
import sqlalchemy.sql as sql
Expand Down Expand Up @@ -405,7 +406,7 @@ def _sort_key(t, expr):
return sort_direction(t.translate(by))


sqlalchemy_operation_registry = {
sqlalchemy_operation_registry: Dict[Any, Any] = {
ops.And: fixed_arity(sql.and_, 2),
ops.Or: fixed_arity(sql.or_, 2),
ops.Not: unary(sa.not_),
Expand Down
5 changes: 4 additions & 1 deletion ibis/backends/base/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,10 @@ def raw_sql(self, query: str, results=False):
"""
# TODO results is unused, it can be removed
# (requires updating Impala tests)
cursor = self.con.execute(query)
# TODO `self.con` is assumed to be defined in subclasses, but there
# is nothing that enforces it. We should find a way to make sure
# `self.con` is always a DBAPI2 connection, or raise an error
cursor = self.con.execute(query) # type: ignore
if cursor:
return cursor
cursor.release()
Expand Down
9 changes: 8 additions & 1 deletion ibis/backends/base/sql/compiler/query_builder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from io import StringIO

import toolz
Expand Down Expand Up @@ -513,8 +515,13 @@ def flatten_union(table: ir.TableExpr):
"""
op = table.op()
if isinstance(op, ops.Union):
# For some reason mypy considers `op.left` and `op.right`
# of `Argument` type, and fails the validation. While in
# `flatten` types are the same, and it works
return toolz.concatv(
flatten_union(op.left), [op.distinct], flatten_union(op.right)
flatten_union(op.left), # type: ignore
[op.distinct],
flatten_union(op.right), # type: ignore
)
return [table]

Expand Down
3 changes: 2 additions & 1 deletion ibis/backends/base/sql/compiler/translator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import operator
from typing import Callable, Dict

import ibis
import ibis.common.exceptions as com
Expand Down Expand Up @@ -201,7 +202,7 @@ class ExprTranslator:
"""

_registry = operation_registry
_rewrites = {}
_rewrites: Dict[ops.Node, Callable] = {}

def __init__(self, expr, context, named=False, permit_subquery=False):
self.expr = expr
Expand Down
13 changes: 7 additions & 6 deletions ibis/backends/base/sql/registry/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.expr.signature import Argument

_map_interval_to_microseconds = {
'W': 604800000000,
Expand Down Expand Up @@ -43,7 +44,7 @@

def _replace_interval_with_scalar(
expr: Union[ir.Expr, dt.Interval, float]
) -> Union[ir.FloatingScalar, float]:
) -> Union[ir.Expr, float, Argument]:
"""
Good old Depth-First Search to identify the Interval and IntervalValue
components of the expression and return a comparable scalar expression.
Expand All @@ -57,9 +58,9 @@ def _replace_interval_with_scalar(
-------
preceding : float or ir.FloatingScalar, depending upon the expr
"""
try:
if isinstance(expr, ir.Expr):
expr_op = expr.op()
except AttributeError:
else:
expr_op = None

if not isinstance(expr, (dt.Interval, ir.IntervalValue)):
Expand All @@ -80,13 +81,13 @@ def _replace_interval_with_scalar(
)
elif expr_op.args and isinstance(expr, ir.IntervalValue):
if len(expr_op.args) > 2:
raise com.NotImplementedError(
"'preceding' argument cannot be parsed."
)
raise NotImplementedError("'preceding' argument cannot be parsed.")
left_arg = _replace_interval_with_scalar(expr_op.args[0])
right_arg = _replace_interval_with_scalar(expr_op.args[1])
method = _map_interval_op_to_op[type(expr_op)]
return method(left_arg, right_arg)
else:
raise TypeError(f'expr has unknown type {type(expr).__name__}')


def cumulative_to_window(translator, expr, window):
Expand Down

0 comments on commit 0f596c6

Please sign in to comment.