Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specific integer types and limit-offset support #27

Merged
merged 14 commits into from
Feb 6, 2024
15 changes: 15 additions & 0 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,21 @@ def test_select_types(self, connection):
row = connection.execute(sa.select(tb)).fetchone()
assert row == (1, "Hello World!", 3.5, True, now, today)

def test_integer_types(self, connection):
stmt = sa.Select(
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint8", 8, types.UInt8))),
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint16", 16, types.UInt16))),
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint32", 32, types.UInt32))),
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint64", 64, types.UInt64))),
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_int8", -8, types.Int8))),
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_int16", -16, types.Int16))),
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_int32", -32, types.Int32))),
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_int64", -64, types.Int64))),
)

result = connection.execute(stmt).fetchone()
assert result == (b"Uint8", b"Uint16", b"Uint32", b"Uint64", b"Int8", b"Int16", b"Int32", b"Int64")


class TestWithClause(TablesTest):
__backend__ = True
Expand Down
76 changes: 15 additions & 61 deletions test/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,11 @@
requirements,
select,
testing,
union,
)
from sqlalchemy.testing.suite.test_ddl import (
LongNameBlowoutTest as _LongNameBlowoutTest,
)
from sqlalchemy.testing.suite.test_deprecations import (
DeprecatedCompoundSelectTest as _DeprecatedCompoundSelectTest,
)
from sqlalchemy.testing.suite.test_dialect import (
DifficultParametersTest as _DifficultParametersTest,
)
Expand All @@ -50,9 +48,6 @@
QuotedNameArgumentTest as _QuotedNameArgumentTest,
)
from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest
from sqlalchemy.testing.suite.test_select import (
CompoundSelectTest as _CompoundSelectTest,
)
from sqlalchemy.testing.suite.test_select import ExistsTest as _ExistsTest
from sqlalchemy.testing.suite.test_select import (
FetchLimitOffsetTest as _FetchLimitOffsetTest,
Expand Down Expand Up @@ -325,20 +320,6 @@ def test_not_regexp_match(self):
self._test(~col.regexp_match("a.cde"), {2, 3, 4, 7, 8, 10, 11})


class CompoundSelectTest(_CompoundSelectTest):
@pytest.mark.skip("limit don't work")
def test_distinct_selectable_in_unions(self):
pass

@pytest.mark.skip("limit don't work")
def test_limit_offset_in_unions_from_alias(self):
pass

@pytest.mark.skip("limit don't work")
def test_limit_offset_aliased_selectable_in_unions(self):
pass


class EscapingTest(_EscapingTest):
@provide_metadata
def test_percent_sign_round_trip(self):
Expand Down Expand Up @@ -395,45 +376,23 @@ def test_group_by_composed(self):


class FetchLimitOffsetTest(_FetchLimitOffsetTest):
@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
def test_bound_limit(self, connection):
pass

@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
def test_bound_limit_offset(self, connection):
pass

@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
def test_bound_offset(self, connection):
pass

@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
def test_expr_limit_simple_offset(self, connection):
pass

@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
def test_limit_render_multiple_times(self, connection):
pass

@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
def test_simple_limit(self, connection):
pass

@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
def test_simple_limit_offset(self, connection):
pass

@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
def test_simple_offset(self, connection):
pass
"""
YQL does not support scalar subquery, so test was refiled with simple subquery
"""
table = self.tables.some_table
stmt = select(table.c.id).limit(1).subquery()

@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
def test_simple_offset_zero(self, connection):
pass
u = union(select(stmt), select(stmt)).subquery().select()

@pytest.mark.skip("Failed to convert type: Int64 to Uint64")
def test_simple_limit_expr_offset(self, connection):
pass
self._assert_result(
connection,
u,
[
(1,),
(1,),
],
)


class InsertBehaviorTest(_InsertBehaviorTest):
Expand Down Expand Up @@ -570,8 +529,3 @@ class RowFetchTest(_RowFetchTest):
@pytest.mark.skip("scalar subquery unsupported")
def test_row_w_scalar_select(self, connection):
pass


@pytest.mark.skip("TODO: try it after limit/offset tests would fixed")
class DeprecatedCompoundSelectTest(_DeprecatedCompoundSelectTest):
pass
2 changes: 1 addition & 1 deletion ydb_sqlalchemy/dbapi/errors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, List
from typing import List, Optional

import ydb
from google.protobuf.message import Message
Expand Down
83 changes: 73 additions & 10 deletions ydb_sqlalchemy/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""
import collections
import collections.abc
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union

import sqlalchemy as sa
import ydb
Expand Down Expand Up @@ -87,15 +87,30 @@ def visit_FLOAT(self, type_: sa.FLOAT, **kw):
def visit_BOOLEAN(self, type_: sa.BOOLEAN, **kw):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a doubt about visit_uuid and implicit convert uuid to UTF8: YDB has own uuid type. Now it can't stored to a table, but it can be (and will sometime) in future. Then convert from uuid to UTF8 will see strange.

I suggest to not convert uuid to UTF-8 into SDK and suggest users to the conversion in customers code. It is simple as str(var)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed in offline: will solved in a separate PR

return "BOOL"

def visit_uint64(self, type_: types.UInt64, **kw):
return "UInt64"

def visit_uint32(self, type_: types.UInt32, **kw):
return "UInt32"

def visit_uint64(self, type_: types.UInt64, **kw):
return "UInt64"
def visit_uint16(self, type_: types.UInt16, **kw):
return "UInt16"

def visit_uint8(self, type_: types.UInt8, **kw):
return "UInt8"

def visit_int64(self, type_: types.Int64, **kw):
return "Int64"

def visit_int32(self, type_: types.Int32, **kw):
return "Int32"

def visit_int16(self, type_: types.Int16, **kw):
return "Int16"

def visit_int8(self, type_: types.Int8, **kw):
return "Int8"

def visit_INTEGER(self, type_: sa.INTEGER, **kw):
return "Int64"

Expand Down Expand Up @@ -134,8 +149,28 @@ def get_ydb_type(

if isinstance(type_, (sa.Text, sa.String, sa.Uuid)):
ydb_type = ydb.PrimitiveType.Utf8

# Integers
elif isinstance(type_, types.UInt64):
ydb_type = ydb.PrimitiveType.Uint64
elif isinstance(type_, types.UInt32):
ydb_type = ydb.PrimitiveType.Uint32
elif isinstance(type_, types.UInt16):
ydb_type = ydb.PrimitiveType.Uint16
elif isinstance(type_, types.UInt8):
ydb_type = ydb.PrimitiveType.Uint8
elif isinstance(type_, types.Int64):
ydb_type = ydb.PrimitiveType.Int64
elif isinstance(type_, types.Int32):
ydb_type = ydb.PrimitiveType.Int32
elif isinstance(type_, types.Int16):
ydb_type = ydb.PrimitiveType.Int16
elif isinstance(type_, types.Int8):
ydb_type = ydb.PrimitiveType.Int8
elif isinstance(type_, sa.Integer):
ydb_type = ydb.PrimitiveType.Int64
# Integers

elif isinstance(type_, sa.JSON):
ydb_type = ydb.PrimitiveType.Json
elif isinstance(type_, sa.DateTime):
Expand Down Expand Up @@ -188,6 +223,36 @@ def group_by_clause(self, select, **kw):
kw.update(within_columns_clause=True)
return super(YqlCompiler, self).group_by_clause(select, **kw)

def limit_clause(self, select, **kw):
text = ""
if select._limit_clause is not None:
limit_clause = self._maybe_cast(
select._limit_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8)
)
text += "\n LIMIT " + self.process(limit_clause, **kw)
if select._offset_clause is not None:
offset_clause = self._maybe_cast(
select._offset_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8)
)
if select._limit_clause is None:
text += "\n LIMIT NULL"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need explicit LIMIT NULL?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed in offline: it is needed to use OFFSET without LIMIT

text += " OFFSET " + self.process(offset_clause, **kw)
return text

def _maybe_cast(
self,
element: Any,
cast_to: Type[sa.types.TypeEngine],
skip_types: Optional[Tuple[Type[sa.types.TypeEngine], ...]] = None,
) -> Any:
if not skip_types:
skip_types = (cast_to,)
if cast_to not in skip_types:
skip_types = (*skip_types, cast_to)
if not hasattr(element, "type") or not isinstance(element.type, skip_types):
return sa.Cast(element, cast_to)
return element

def render_literal_value(self, value, type_):
if isinstance(value, str):
value = "".join(STR_QUOTE_MAP.get(x, x) for x in value)
Expand Down Expand Up @@ -277,16 +342,14 @@ def _is_bound_to_nullable_column(self, bind_name: str) -> bool:
def _guess_bound_variable_type_by_parameters(
self, bind: sa.BindParameter, post_compile_bind_values: list
) -> Optional[sa.types.TypeEngine]:
if not bind.expanding:
if isinstance(bind.type, sa.types.NullType):
return None
bind_type = bind.type
else:
bind_type = bind.type
if bind.expanding or (isinstance(bind.type, sa.types.NullType) and post_compile_bind_values):
not_null_values = [v for v in post_compile_bind_values if v is not None]
if not_null_values:
bind_type = sa.BindParameter("", not_null_values[0]).type
else:
return None

if isinstance(bind_type, sa.types.NullType):
return None

return bind_type

Expand Down
24 changes: 22 additions & 2 deletions ydb_sqlalchemy/sqlalchemy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,38 @@
from sqlalchemy.sql import type_api


class UInt64(types.Integer):
__visit_name__ = "uint64"


class UInt32(types.Integer):
__visit_name__ = "uint32"


class UInt64(types.Integer):
__visit_name__ = "uint64"
class UInt16(types.Integer):
__visit_name__ = "uint16"


class UInt8(types.Integer):
__visit_name__ = "uint8"


class Int64(types.Integer):
__visit_name__ = "int64"


class Int32(types.Integer):
__visit_name__ = "int32"


class Int16(types.Integer):
__visit_name__ = "int32"


class Int8(types.Integer):
__visit_name__ = "int8"


class ListType(ARRAY):
__visit_name__ = "list_type"

Expand Down
Loading