diff --git a/test/test_core.py b/test/test_core.py index c2a1dc9..9175b7c 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -216,6 +216,21 @@ def test_select_types(self, connection): row = connection.execute(sa.select(tb)).fetchone() assert row == (1, "Hello World!", 3.5, True, now, today) + def test_integer_types(self, connection): + stmt = sa.Select( + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint8", 8, types.UInt8))), + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint16", 16, types.UInt16))), + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint32", 32, types.UInt32))), + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint64", 64, types.UInt64))), + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_int8", -8, types.Int8))), + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_int16", -16, types.Int16))), + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_int32", -32, types.Int32))), + sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_int64", -64, types.Int64))), + ) + + result = connection.execute(stmt).fetchone() + assert result == (b"Uint8", b"Uint16", b"Uint32", b"Uint64", b"Int8", b"Int16", b"Int32", b"Int64") + class TestWithClause(TablesTest): __backend__ = True diff --git a/test/test_suite.py b/test/test_suite.py index d853aaf..e504058 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -21,13 +21,11 @@ requirements, select, testing, + union, ) from sqlalchemy.testing.suite.test_ddl import ( LongNameBlowoutTest as _LongNameBlowoutTest, ) -from sqlalchemy.testing.suite.test_deprecations import ( - DeprecatedCompoundSelectTest as _DeprecatedCompoundSelectTest, -) from sqlalchemy.testing.suite.test_dialect import ( DifficultParametersTest as _DifficultParametersTest, ) @@ -50,9 +48,6 @@ QuotedNameArgumentTest as _QuotedNameArgumentTest, ) from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest -from sqlalchemy.testing.suite.test_select import ( - CompoundSelectTest as _CompoundSelectTest, -) from sqlalchemy.testing.suite.test_select import ExistsTest as _ExistsTest from sqlalchemy.testing.suite.test_select import ( FetchLimitOffsetTest as _FetchLimitOffsetTest, @@ -325,20 +320,6 @@ def test_not_regexp_match(self): self._test(~col.regexp_match("a.cde"), {2, 3, 4, 7, 8, 10, 11}) -class CompoundSelectTest(_CompoundSelectTest): - @pytest.mark.skip("limit don't work") - def test_distinct_selectable_in_unions(self): - pass - - @pytest.mark.skip("limit don't work") - def test_limit_offset_in_unions_from_alias(self): - pass - - @pytest.mark.skip("limit don't work") - def test_limit_offset_aliased_selectable_in_unions(self): - pass - - class EscapingTest(_EscapingTest): @provide_metadata def test_percent_sign_round_trip(self): @@ -395,45 +376,23 @@ def test_group_by_composed(self): class FetchLimitOffsetTest(_FetchLimitOffsetTest): - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_bound_limit(self, connection): - pass - - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_bound_limit_offset(self, connection): - pass - - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_bound_offset(self, connection): - pass - - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_expr_limit_simple_offset(self, connection): - pass - - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") def test_limit_render_multiple_times(self, connection): - pass - - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_simple_limit(self, connection): - pass - - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_simple_limit_offset(self, connection): - pass - - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_simple_offset(self, connection): - pass + """ + YQL does not support scalar subquery, so test was refiled with simple subquery + """ + table = self.tables.some_table + stmt = select(table.c.id).limit(1).subquery() - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_simple_offset_zero(self, connection): - pass + u = union(select(stmt), select(stmt)).subquery().select() - @pytest.mark.skip("Failed to convert type: Int64 to Uint64") - def test_simple_limit_expr_offset(self, connection): - pass + self._assert_result( + connection, + u, + [ + (1,), + (1,), + ], + ) class InsertBehaviorTest(_InsertBehaviorTest): @@ -570,8 +529,3 @@ class RowFetchTest(_RowFetchTest): @pytest.mark.skip("scalar subquery unsupported") def test_row_w_scalar_select(self, connection): pass - - -@pytest.mark.skip("TODO: try it after limit/offset tests would fixed") -class DeprecatedCompoundSelectTest(_DeprecatedCompoundSelectTest): - pass diff --git a/ydb_sqlalchemy/dbapi/errors.py b/ydb_sqlalchemy/dbapi/errors.py index 70b55eb..79faba8 100644 --- a/ydb_sqlalchemy/dbapi/errors.py +++ b/ydb_sqlalchemy/dbapi/errors.py @@ -1,4 +1,4 @@ -from typing import Optional, List +from typing import List, Optional import ydb from google.protobuf.message import Message diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index a3cb0e4..78bb7de 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -4,7 +4,7 @@ """ import collections import collections.abc -from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union import sqlalchemy as sa import ydb @@ -87,15 +87,30 @@ def visit_FLOAT(self, type_: sa.FLOAT, **kw): def visit_BOOLEAN(self, type_: sa.BOOLEAN, **kw): return "BOOL" + def visit_uint64(self, type_: types.UInt64, **kw): + return "UInt64" + def visit_uint32(self, type_: types.UInt32, **kw): return "UInt32" - def visit_uint64(self, type_: types.UInt64, **kw): - return "UInt64" + def visit_uint16(self, type_: types.UInt16, **kw): + return "UInt16" def visit_uint8(self, type_: types.UInt8, **kw): return "UInt8" + def visit_int64(self, type_: types.Int64, **kw): + return "Int64" + + def visit_int32(self, type_: types.Int32, **kw): + return "Int32" + + def visit_int16(self, type_: types.Int16, **kw): + return "Int16" + + def visit_int8(self, type_: types.Int8, **kw): + return "Int8" + def visit_INTEGER(self, type_: sa.INTEGER, **kw): return "Int64" @@ -134,8 +149,28 @@ def get_ydb_type( if isinstance(type_, (sa.Text, sa.String, sa.Uuid)): ydb_type = ydb.PrimitiveType.Utf8 + + # Integers + elif isinstance(type_, types.UInt64): + ydb_type = ydb.PrimitiveType.Uint64 + elif isinstance(type_, types.UInt32): + ydb_type = ydb.PrimitiveType.Uint32 + elif isinstance(type_, types.UInt16): + ydb_type = ydb.PrimitiveType.Uint16 + elif isinstance(type_, types.UInt8): + ydb_type = ydb.PrimitiveType.Uint8 + elif isinstance(type_, types.Int64): + ydb_type = ydb.PrimitiveType.Int64 + elif isinstance(type_, types.Int32): + ydb_type = ydb.PrimitiveType.Int32 + elif isinstance(type_, types.Int16): + ydb_type = ydb.PrimitiveType.Int16 + elif isinstance(type_, types.Int8): + ydb_type = ydb.PrimitiveType.Int8 elif isinstance(type_, sa.Integer): ydb_type = ydb.PrimitiveType.Int64 + # Integers + elif isinstance(type_, sa.JSON): ydb_type = ydb.PrimitiveType.Json elif isinstance(type_, sa.DateTime): @@ -188,6 +223,36 @@ def group_by_clause(self, select, **kw): kw.update(within_columns_clause=True) return super(YqlCompiler, self).group_by_clause(select, **kw) + def limit_clause(self, select, **kw): + text = "" + if select._limit_clause is not None: + limit_clause = self._maybe_cast( + select._limit_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8) + ) + text += "\n LIMIT " + self.process(limit_clause, **kw) + if select._offset_clause is not None: + offset_clause = self._maybe_cast( + select._offset_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8) + ) + if select._limit_clause is None: + text += "\n LIMIT NULL" + text += " OFFSET " + self.process(offset_clause, **kw) + return text + + def _maybe_cast( + self, + element: Any, + cast_to: Type[sa.types.TypeEngine], + skip_types: Optional[Tuple[Type[sa.types.TypeEngine], ...]] = None, + ) -> Any: + if not skip_types: + skip_types = (cast_to,) + if cast_to not in skip_types: + skip_types = (*skip_types, cast_to) + if not hasattr(element, "type") or not isinstance(element.type, skip_types): + return sa.Cast(element, cast_to) + return element + def render_literal_value(self, value, type_): if isinstance(value, str): value = "".join(STR_QUOTE_MAP.get(x, x) for x in value) @@ -277,16 +342,14 @@ def _is_bound_to_nullable_column(self, bind_name: str) -> bool: def _guess_bound_variable_type_by_parameters( self, bind: sa.BindParameter, post_compile_bind_values: list ) -> Optional[sa.types.TypeEngine]: - if not bind.expanding: - if isinstance(bind.type, sa.types.NullType): - return None - bind_type = bind.type - else: + bind_type = bind.type + if bind.expanding or (isinstance(bind.type, sa.types.NullType) and post_compile_bind_values): not_null_values = [v for v in post_compile_bind_values if v is not None] if not_null_values: bind_type = sa.BindParameter("", not_null_values[0]).type - else: - return None + + if isinstance(bind_type, sa.types.NullType): + return None return bind_type diff --git a/ydb_sqlalchemy/sqlalchemy/types.py b/ydb_sqlalchemy/sqlalchemy/types.py index 8570e9a..94f957b 100644 --- a/ydb_sqlalchemy/sqlalchemy/types.py +++ b/ydb_sqlalchemy/sqlalchemy/types.py @@ -4,18 +4,38 @@ from sqlalchemy.sql import type_api +class UInt64(types.Integer): + __visit_name__ = "uint64" + + class UInt32(types.Integer): __visit_name__ = "uint32" -class UInt64(types.Integer): - __visit_name__ = "uint64" +class UInt16(types.Integer): + __visit_name__ = "uint16" class UInt8(types.Integer): __visit_name__ = "uint8" +class Int64(types.Integer): + __visit_name__ = "int64" + + +class Int32(types.Integer): + __visit_name__ = "int32" + + +class Int16(types.Integer): + __visit_name__ = "int32" + + +class Int8(types.Integer): + __visit_name__ = "int8" + + class ListType(ARRAY): __visit_name__ = "list_type"