Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix mypy errors in backends/base #2894

Merged
merged 4 commits into from
Aug 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ibis/backends/base/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
import warnings
from typing import Any, Callable, List
from typing import Any, Callable, List, Type

import ibis.expr.operations as ops
import ibis.expr.schema as sch
Expand All @@ -21,7 +21,7 @@ class BaseBackend(abc.ABC):
"""

database_class = Database
table_class = ops.DatabaseTable
table_class: Type[ops.DatabaseTable] = ops.DatabaseTable

@property
@abc.abstractmethod
Expand Down
8 changes: 6 additions & 2 deletions ibis/backends/base/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from __future__ import annotations

from typing import List

import ibis.expr.types as ir


Expand All @@ -21,7 +25,7 @@ def __repr__(self) -> str:
"""Return type name and the name of the database."""
return '{}({!r})'.format(type(self).__name__, self.name)

def __dir__(self) -> set:
def __dir__(self) -> List[str]:
"""Return a set of attributes and tables available for the database.
Returns
Expand Down Expand Up @@ -133,5 +137,5 @@ def list_tables(self, like: str = None) -> list:
like=self._qualify_like(like), database=self.name
)

def _qualify_like(self, like: str) -> str:
def _qualify_like(self, like: str | None) -> str | None:
return like
18 changes: 9 additions & 9 deletions ibis/backends/base/sql/alchemy/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib
import warnings
from typing import List, Optional, Union
from typing import Dict, List, Optional, Union

import pandas as pd
import sqlalchemy as sa
Expand Down Expand Up @@ -49,7 +49,7 @@ def __init__(self, con: sa.engine.Engine) -> None:
self.con = con
self.meta = sa.MetaData(bind=con)
self._inspector = sa.inspect(con)
self._schemas = {}
self._schemas: Dict[str, sch.Schema] = {}

@property
def inspector(self):
Expand Down Expand Up @@ -95,7 +95,7 @@ def begin(self):
yield bind

def create_table(self, name, expr=None, schema=None, database=None):
if database == self.database_name:
if database == self.current_database:
# avoid fully qualified name
database = None

Expand Down Expand Up @@ -149,7 +149,7 @@ def drop_table(
database: Optional[str] = None,
force: bool = False,
) -> None:
if database == self.database_name:
if database == self.current_database:
# avoid fully qualified name
database = None

Expand Down Expand Up @@ -199,7 +199,7 @@ def load_data(
Loading data to a table from a different database is not
yet implemented
"""
if database == self.database_name:
if database == self.current_database:
# avoid fully qualified name
database = None

Expand All @@ -213,7 +213,7 @@ def load_data(
if self.has_attachment:
# for database with attachment
# see: https://github.com/ibis-project/ibis/issues/1930
params['schema'] = self.database_name
params['schema'] = self.current_database

data.to_sql(
table_name,
Expand Down Expand Up @@ -296,7 +296,7 @@ def list_tables(
names = [x for x in names if like in x]
return sorted(names)

def raw_sql(self, query: str):
def raw_sql(self, query: str, results=False):
return _AutoCloseCursor(super().raw_sql(query))

def _log(self, sql):
Expand Down Expand Up @@ -351,7 +351,7 @@ def insert(
"""

if database == self.database_name:
if database == self.current_database:
# avoid fully qualified name
database = None

Expand All @@ -365,7 +365,7 @@ def insert(
if self.has_attachment:
# for database with attachment
# see: https://github.com/ibis-project/ibis/issues/1930
params['schema'] = self.database_name
params['schema'] = self.current_database

if isinstance(obj, pd.DataFrame):
obj.to_sql(
Expand Down
3 changes: 2 additions & 1 deletion ibis/backends/base/sql/alchemy/registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import operator
from typing import Any, Dict

import sqlalchemy as sa
import sqlalchemy.sql as sql
Expand Down Expand Up @@ -405,7 +406,7 @@ def _sort_key(t, expr):
return sort_direction(t.translate(by))


sqlalchemy_operation_registry = {
sqlalchemy_operation_registry: Dict[Any, Any] = {
ops.And: fixed_arity(sql.and_, 2),
ops.Or: fixed_arity(sql.or_, 2),
ops.Not: unary(sa.not_),
Expand Down
5 changes: 4 additions & 1 deletion ibis/backends/base/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,10 @@ def raw_sql(self, query: str, results=False):
"""
# TODO results is unused, it can be removed
# (requires updating Impala tests)
cursor = self.con.execute(query)
# TODO `self.con` is assumed to be defined in subclasses, but there
# is nothing that enforces it. We should find a way to make sure
# `self.con` is always a DBAPI2 connection, or raise an error
cursor = self.con.execute(query) # type: ignore
if cursor:
return cursor
cursor.release()
Expand Down
9 changes: 8 additions & 1 deletion ibis/backends/base/sql/compiler/query_builder.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from io import StringIO

import toolz
Expand Down Expand Up @@ -513,8 +515,13 @@ def flatten_union(table: ir.TableExpr):
"""
op = table.op()
if isinstance(op, ops.Union):
# For some reason mypy considers `op.left` and `op.right`
# of `Argument` type, and fails the validation. While in
# `flatten` types are the same, and it works
return toolz.concatv(
flatten_union(op.left), [op.distinct], flatten_union(op.right)
flatten_union(op.left), # type: ignore
[op.distinct],
flatten_union(op.right), # type: ignore
)
return [table]

Expand Down
3 changes: 2 additions & 1 deletion ibis/backends/base/sql/compiler/translator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import operator
from typing import Callable, Dict

import ibis
import ibis.common.exceptions as com
Expand Down Expand Up @@ -201,7 +202,7 @@ class ExprTranslator:
"""

_registry = operation_registry
_rewrites = {}
_rewrites: Dict[ops.Node, Callable] = {}

def __init__(self, expr, context, named=False, permit_subquery=False):
self.expr = expr
Expand Down
13 changes: 7 additions & 6 deletions ibis/backends/base/sql/registry/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.types as ir
from ibis.expr.signature import Argument

_map_interval_to_microseconds = {
'W': 604800000000,
Expand Down Expand Up @@ -43,7 +44,7 @@

def _replace_interval_with_scalar(
expr: Union[ir.Expr, dt.Interval, float]
) -> Union[ir.FloatingScalar, float]:
) -> Union[ir.Expr, float, Argument]:
"""
Good old Depth-First Search to identify the Interval and IntervalValue
components of the expression and return a comparable scalar expression.
Expand All @@ -57,9 +58,9 @@ def _replace_interval_with_scalar(
-------
preceding : float or ir.FloatingScalar, depending upon the expr
"""
try:
if isinstance(expr, ir.Expr):
expr_op = expr.op()
except AttributeError:
else:
expr_op = None

if not isinstance(expr, (dt.Interval, ir.IntervalValue)):
Expand All @@ -80,13 +81,13 @@ def _replace_interval_with_scalar(
)
elif expr_op.args and isinstance(expr, ir.IntervalValue):
if len(expr_op.args) > 2:
raise com.NotImplementedError(
"'preceding' argument cannot be parsed."
)
raise NotImplementedError("'preceding' argument cannot be parsed.")
left_arg = _replace_interval_with_scalar(expr_op.args[0])
right_arg = _replace_interval_with_scalar(expr_op.args[1])
method = _map_interval_op_to_op[type(expr_op)]
return method(left_arg, right_arg)
else:
raise TypeError(f'expr has unknown type {type(expr).__name__}')


def cumulative_to_window(translator, expr, window):
Expand Down