diff --git a/docker-compose.yml b/docker-compose.yml index eff8558..13badc8 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,7 +1,7 @@ version: "3.3" services: ydb: - image: cr.yandex/yc/yandex-docker-local-ydb:latest + image: cr.yandex/yc/yandex-docker-local-ydb:trunk restart: always ports: - "2136:2136" diff --git a/test/test_core.py b/test/test_core.py index a1cbfb5..52c61e1 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -18,7 +18,13 @@ def test_sa_text(self, connection): assert rs.fetchone() == (1,) rs = connection.execute( - sa.text("SELECT x, y FROM AS_TABLE(:data)"), [{"data": [{"x": 2, "y": 1}, {"x": 3, "y": 2}]}] + sa.text( + """ + DECLARE :data AS List<Struct<x:Int64, y:Int64>>; + SELECT x, y FROM AS_TABLE(:data) + """ + ), + [{"data": [{"x": 2, "y": 1}, {"x": 3, "y": 2}]}], ) assert set(rs.fetchall()) == {(2, 1), (3, 2)} @@ -184,7 +190,7 @@ def test_select_types(self, connection): id=1, # bin=b"abc", str="Hello World!", - num=3.1415, + num=3.5, bl=True, ts=now, date=today, @@ -193,4 +199,4 @@ def test_select_types(self, connection): connection.execute(stm) row = connection.execute(sa.select(tb)).fetchone() - assert row == (1, "Hello World!", 3.1415, True, now, today) + assert row == (1, "Hello World!", 3.5, True, now, today) diff --git a/test/test_suite.py b/test/test_suite.py index 42ce126..dc61109 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -5,7 +5,7 @@ from sqlalchemy.testing.suite import * # noqa: F401, F403 from sqlalchemy.testing import is_true, is_false -from sqlalchemy.testing.suite import eq_, testing, inspect, provide_metadata, config, requirements +from sqlalchemy.testing.suite import eq_, testing, inspect, provide_metadata, config, requirements, fixtures from sqlalchemy.testing.suite import func, column, literal_column, select, exists from sqlalchemy.testing.suite import MetaData, Column, Table, Integer, String @@ -32,6 +32,10 @@ NativeUUIDTest as _NativeUUIDTest, TimeMicrosecondsTest as _TimeMicrosecondsTest, 
DateTimeCoercedToDateTimeTest as _DateTimeCoercedToDateTimeTest, + DateTest as _DateTest, + DateTimeMicrosecondsTest as _DateTimeMicrosecondsTest, + DateTimeTest as _DateTimeTest, + TimestampMicrosecondsTest as _TimestampMicrosecondsTest, ) from sqlalchemy.testing.suite.test_dialect import ( EscapingTest as _EscapingTest, @@ -47,6 +51,7 @@ from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest from sqlalchemy.testing.suite.test_deprecations import DeprecatedCompoundSelectTest as _DeprecatedCompoundSelectTest +from ydb_sqlalchemy.sqlalchemy import types as ydb_sa_types test_types_suite = sqlalchemy.testing.suite.test_types col_creator = test_types_suite.Column @@ -259,6 +264,13 @@ def test_truediv_numeric(self): # SqlAlchemy maybe eat Decimal and throw Double pass + @testing.combinations(("6.25", "2.5", 2.5), argnames="left, right, expected") + def test_truediv_float(self, connection, left, right, expected): + eq_( + connection.scalar(select(literal_column(left, type_=sa.Float()) / literal_column(right, type_=sa.Float()))), + expected, + ) + class ExistsTest(_ExistsTest): """ @@ -402,6 +414,26 @@ def test_insert_from_select_autoinc(self, connection): def test_insert_from_select_autoinc_no_rows(self, connection): pass + @pytest.mark.skip("implicit PK values unsupported") + def test_no_results_for_non_returning_insert(self, connection): + pass + + +class DateTest(_DateTest): + run_dispose_bind = "once" + + +class DateTimeMicrosecondsTest(_DateTimeMicrosecondsTest): + run_dispose_bind = "once" + + +class DateTimeTest(_DateTimeTest): + run_dispose_bind = "once" + + +class TimestampMicrosecondsTest(_TimestampMicrosecondsTest): + run_dispose_bind = "once" + @pytest.mark.skip("unsupported Time data type") class TimeTest(_TimeTest): @@ -418,6 +450,71 @@ def test_nolength_string(self): foo.drop(config.db) +class ContainerTypesTest(fixtures.TablesTest): + @classmethod + def define_tables(cls, metadata): + Table( + "container_types_test", + metadata, + 
Column("id", Integer), + sa.PrimaryKeyConstraint("id"), + schema=None, + test_needs_fk=True, + ) + + def test_ARRAY_bind_variable(self, connection): + table = self.tables.container_types_test + + connection.execute(sa.insert(table).values([{"id": 1}, {"id": 2}, {"id": 3}])) + + stmt = select(table.c.id).where(table.c.id.in_(sa.bindparam("id", type_=sa.ARRAY(sa.Integer)))) + + eq_(connection.execute(stmt, {"id": [1, 2]}).fetchall(), [(1,), (2,)]) + + def test_list_type_bind_variable(self, connection): + table = self.tables.container_types_test + + connection.execute(sa.insert(table).values([{"id": 1}, {"id": 2}, {"id": 3}])) + + stmt = select(table.c.id).where(table.c.id.in_(sa.bindparam("id", type_=ydb_sa_types.ListType(sa.Integer)))) + + eq_(connection.execute(stmt, {"id": [1, 2]}).fetchall(), [(1,), (2,)]) + + def test_struct_type_bind_variable(self, connection): + table = self.tables.container_types_test + + connection.execute(sa.insert(table).values([{"id": 1}, {"id": 2}, {"id": 3}])) + + stmt = select(table.c.id).where( + table.c.id + == sa.text(":struct.id").bindparams( + sa.bindparam("struct", type_=ydb_sa_types.StructType({"id": sa.Integer})), + ) + ) + + eq_(connection.scalar(stmt, {"struct": {"id": 1}}), 1) + + def test_from_as_table(self, connection): + table = self.tables.container_types_test + + connection.execute( + sa.insert(table).from_select( + ["id"], + sa.select(sa.column("id")).select_from( + sa.func.as_table( + sa.bindparam( + "data", + value=[{"id": 1}, {"id": 2}, {"id": 3}], + type_=ydb_sa_types.ListType(ydb_sa_types.StructType({"id": sa.Integer})), + ) + ) + ), + ) + ) + + eq_(connection.execute(sa.select(table)).fetchall(), [(1,), (2,), (3,)]) + + @pytest.mark.skip("uuid unsupported for columns") class NativeUUIDTest(_NativeUUIDTest): pass diff --git a/test_dbapi/test_dbapi.py b/test_dbapi/test_dbapi.py index ea70fe8..ad354d5 100644 --- a/test_dbapi/test_dbapi.py +++ b/test_dbapi/test_dbapi.py @@ -12,13 +12,13 @@ def 
test_connection(connection): cur = connection.cursor() with suppress(dbapi.DatabaseError): - cur.execute("DROP TABLE foo", context={"isddl": True}) + cur.execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True)) assert not connection.check_exists("/local/foo") with pytest.raises(dbapi.ProgrammingError): connection.describe("/local/foo") - cur.execute("CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))", context={"isddl": True}) + cur.execute(dbapi.YdbQuery("CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))", is_ddl=True)) assert connection.check_exists("/local/foo") @@ -26,87 +26,45 @@ def test_connection(connection): assert col.name == "id" assert col.type == ydb.PrimitiveType.Int64 - cur.execute("DROP TABLE foo", context={"isddl": True}) + cur.execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True)) cur.close() -def test_cursor(connection): +def test_cursor_raw_query(connection): cur = connection.cursor() assert cur with suppress(dbapi.DatabaseError): - cur.execute("DROP TABLE test", context={"isddl": True}) + cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) - cur.execute( - "CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))", - context={"isddl": True}, - ) - - cur.execute('INSERT INTO test(id, text) VALUES (1, "foo")') - - cur.execute("SELECT id, text FROM test") - assert cur.fetchone() == (1, "foo"), "fetchone is ok" - - cur.execute("SELECT id, text FROM test WHERE id = %(id)s", {"id": 1}) - assert cur.fetchone() == (1, "foo"), "parametrized query is ok" + cur.execute(dbapi.YdbQuery("CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))", is_ddl=True)) cur.execute( - "INSERT INTO test(id, text) VALUES (%(id1)s, %(text1)s), (%(id2)s, %(text2)s)", - {"id1": 2, "text1": "", "id2": 3, "text2": "bar"}, - ) - - cur.execute("UPDATE test SET text = %(t)s WHERE id = %(id)s", {"id": 2, "t": "foo2"}) - - cur.execute("SELECT id FROM test") - assert set(cur.fetchall()) == {(1,), (2,), (3,)}, "fetchall is ok" - - cur.execute("SELECT 
id FROM test ORDER BY id DESC") - assert cur.fetchmany(2) == [(3,), (2,)], "fetchmany is ok" - assert cur.fetchmany(1) == [(1,)] - - cur.execute("SELECT id FROM test ORDER BY id LIMIT 2") - assert cur.fetchall() == [(1,), (2,)], "limit clause without params is ok" - - # TODO: Failed to convert type: Int64 to Uint64 - # cur.execute("SELECT id FROM test ORDER BY id LIMIT %(limit)s", {"limit": 2}) - # assert cur.fetchall() == [(1,), (2,)], "limit clause with params is ok" - - cur2 = connection.cursor() - cur2.execute("INSERT INTO test(id) VALUES (%(id1)s), (%(id2)s)", {"id1": 5, "id2": 6}) - - cur.execute("SELECT id FROM test ORDER BY id") - assert cur.fetchall() == [(1,), (2,), (3,), (5,), (6,)], "cursor2 commit changes" - - cur.execute("SELECT text FROM test WHERE id > %(min_id)s", {"min_id": 3}) - assert cur.fetchall() == [(None,), (None,)], "NULL returns as None" - - cur.execute("SELECT id, text FROM test WHERE text LIKE %(p)s", {"p": "foo%"}) - assert set(cur.fetchall()) == {(1, "foo"), (2, "foo2")}, "like clause works" - - cur.execute( - # DECLARE statement (DECLARE $data AS List<Struct<id: Int64, text: Utf8>>) - # will generate automatically - """INSERT INTO test SELECT id, text FROM AS_TABLE($data);""", + dbapi.YdbQuery( + """ + DECLARE $data AS List<Struct<id:Int64, text:Utf8>>; + + INSERT INTO test SELECT id, text FROM AS_TABLE($data); + """, + parameters_types={ + "$data": ydb.ListType( + ydb.StructType() + .add_member("id", ydb.PrimitiveType.Int64) + .add_member("text", ydb.PrimitiveType.Utf8) + ) + }, + ), { - "data": [ + "$data": [ {"id": 17, "text": "seventeen"}, {"id": 21, "text": "twenty one"}, ] }, ) - cur.execute("SELECT id FROM test ORDER BY id") - assert cur.rowcount == 7, "rowcount ok" - assert cur.fetchall() == [(1,), (2,), (3,), (5,), (6,), (17,), (21,)], "ok" - - cur.execute("INSERT INTO test(id) VALUES (37)") - cur.execute("SELECT id FROM test WHERE text IS %(p)s", {"p": None}) - assert not cur.fetchone() == (37,) - - cur.execute("DROP TABLE test", context={"isddl": True}) + 
cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) cur.close() - cur2.close() def test_errors(connection): @@ -116,25 +74,25 @@ def test_errors(connection): cur = connection.cursor() with suppress(dbapi.DatabaseError): - cur.execute("DROP TABLE test", context={"isddl": True}) + cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) with pytest.raises(dbapi.DataError): - cur.execute("SELECT 18446744073709551616") + cur.execute(dbapi.YdbQuery("SELECT 18446744073709551616")) with pytest.raises(dbapi.DataError): - cur.execute("SELECT * FROM 拉屎") + cur.execute(dbapi.YdbQuery("SELECT * FROM 拉屎")) with pytest.raises(dbapi.DataError): - cur.execute("SELECT floor(5 / 2)") + cur.execute(dbapi.YdbQuery("SELECT floor(5 / 2)")) with pytest.raises(dbapi.ProgrammingError): - cur.execute("SELECT * FROM test") + cur.execute(dbapi.YdbQuery("SELECT * FROM test")) - cur.execute("CREATE TABLE test(id Int64, PRIMARY KEY (id))", context={"isddl": True}) + cur.execute(dbapi.YdbQuery("CREATE TABLE test(id Int64, PRIMARY KEY (id))", is_ddl=True)) - cur.execute("INSERT INTO test(id) VALUES(1)") + cur.execute(dbapi.YdbQuery("INSERT INTO test(id) VALUES(1)")) with pytest.raises(dbapi.IntegrityError): - cur.execute("INSERT INTO test(id) VALUES(1)") + cur.execute(dbapi.YdbQuery("INSERT INTO test(id) VALUES(1)")) - cur.execute("DROP TABLE test", context={"isddl": True}) + cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) cur.close() diff --git a/ydb_sqlalchemy/dbapi/__init__.py b/ydb_sqlalchemy/dbapi/__init__.py index 8756b0f..075925c 100644 --- a/ydb_sqlalchemy/dbapi/__init__.py +++ b/ydb_sqlalchemy/dbapi/__init__.py @@ -1,4 +1,5 @@ from .connection import Connection +from .cursor import Cursor, YdbQuery # noqa: F401 from .errors import ( Warning, Error, diff --git a/ydb_sqlalchemy/dbapi/cursor.py b/ydb_sqlalchemy/dbapi/cursor.py index fc35fc3..7573e1f 100644 --- a/ydb_sqlalchemy/dbapi/cursor.py +++ b/ydb_sqlalchemy/dbapi/cursor.py @@ -1,10 +1,8 @@ -import datetime 
+import dataclasses import itertools import logging -import uuid -import decimal -import string -from typing import Optional, Dict, Any + +from typing import Any, Mapping, Optional, Sequence, Union, Dict import ydb from .errors import ( @@ -21,75 +19,17 @@ logger = logging.getLogger(__name__) -identifier_starts = {x for x in itertools.chain(string.ascii_letters, "_")} -valid_identifier_chars = {x for x in itertools.chain(identifier_starts, string.digits)} - - -def check_identifier_valid(idt: str) -> bool: - valid = idt and idt[0] in identifier_starts and all(c in valid_identifier_chars for c in idt) - if not valid: - raise ProgrammingError(f"Invalid identifier {idt}") - return valid - - def get_column_type(type_obj: Any) -> str: return str(ydb.convert.type_to_native(type_obj)) -def _generate_type_str(value: Any) -> str: - tvalue = type(value) - - stype = { - bool: "Bool", - bytes: "String", - str: "Utf8", - int: "Int64", - float: "Double", - decimal.Decimal: "Decimal(22, 9)", - datetime.date: "Date", - datetime.datetime: "Timestamp", - datetime.timedelta: "Interval", - uuid.UUID: "Uuid", - }.get(tvalue) - - if tvalue == dict: - types_lst = ", ".join(f"{k}: {_generate_type_str(v)}" for k, v in value.items()) - stype = f"Struct<{types_lst}>" - - elif tvalue == tuple: - types_lst = ", ".join(_generate_type_str(x) for x in value) - stype = f"Tuple<{types_lst}>" - - elif tvalue == list: - nested_type = _generate_type_str(value[0]) - stype = f"List<{nested_type}>" - - elif tvalue == set: - nested_type = _generate_type_str(next(iter(value))) - stype = f"Set<{nested_type}>" - - if stype is None: - raise ProgrammingError(f"Cannot translate value {value} (type {tvalue}) to ydb type.") - - return stype - - -def _generate_declare_stms(params: Dict[str, Any]) -> str: - return "".join(f"DECLARE {k} AS {_generate_type_str(t)}; " for k, t in params.items()) - - -def _generate_full_stm(sql: str, params: Optional[Dict[str, Any]] = None) -> (str, Optional[Dict[str, Any]]): - 
sql_params = None - - if params: - for name in params.keys(): - check_identifier_valid(name) - sql = sql % {k: f"${k}" if v is not None else "NULL" for k, v in params.items()} - sql_params = {f"${k}": v for k, v in params.items() if v is not None} - declare_stms = _generate_declare_stms(sql_params) - sql = f"{declare_stms}{sql}" - - return sql.replace("%%", "%"), sql_params +@dataclasses.dataclass +class YdbQuery: + yql_text: str + parameters_types: Dict[str, Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]] = dataclasses.field( + default_factory=dict + ) + is_ddl: bool = False class Cursor(object): @@ -100,19 +40,28 @@ def __init__(self, connection): self.rows = None self._rows_prefetched = None - def execute(self, sql, parameters=None, context=None): + def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] = None): self.description = None - sql, sql_params = _generate_full_stm(sql, parameters) - logger.info("execute sql: %s, params: %s", sql, sql_params) + if operation.is_ddl or not operation.parameters_types: + query = operation.yql_text + is_ddl = operation.is_ddl + else: + query = ydb.DataQuery(operation.yql_text, operation.parameters_types) + is_ddl = operation.is_ddl + + logger.info("execute sql: %s, params: %s", query, parameters) - def _execute_in_pool(cli): + def _execute_in_pool(cli: ydb.Session): try: - if context and context.get("isddl"): - return cli.execute_scheme(sql) - else: - prepared_query = cli.prepare(sql) - return cli.transaction().execute(prepared_query, sql_params, commit_tx=True) + if is_ddl: + return cli.execute_scheme(query) + + prepared_query = query + if isinstance(query, str) and parameters: + prepared_query = cli.prepare(query) + + return cli.transaction().execute(prepared_query, parameters, commit_tx=True) except (ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed) as e: raise IntegrityError(e.message, e.issues, e.status) from e except (ydb.issues.Unsupported, ydb.issues.Unimplemented) as e: @@ -182,9 
+131,9 @@ def _ensure_prefetched(self): self.rows = iter(self._rows_prefetched) return self._rows_prefetched - def executemany(self, sql, seq_of_parameters): + def executemany(self, operation: YdbQuery, seq_of_parameters: Optional[Sequence[Mapping[str, Any]]]): for parameters in seq_of_parameters: - self.execute(sql, parameters) + self.execute(operation, parameters) def executescript(self, script): return self.execute(script) diff --git a/ydb_sqlalchemy/dbapi/test_cursor.py b/ydb_sqlalchemy/dbapi/test_cursor.py deleted file mode 100644 index 32f28fd..0000000 --- a/ydb_sqlalchemy/dbapi/test_cursor.py +++ /dev/null @@ -1,84 +0,0 @@ -import pytest -import uuid -import decimal -from datetime import date, datetime, timedelta - -from .cursor import ( - _generate_type_str, - _generate_declare_stms, - _generate_full_stm, - check_identifier_valid, - ProgrammingError, -) - - -def test_check_identifier_valid(): - assert check_identifier_valid("id") - assert check_identifier_valid("_id") - assert check_identifier_valid("id0") - assert check_identifier_valid("foo_bar") - assert check_identifier_valid("foo_bar_1") - - with pytest.raises(ProgrammingError): - check_identifier_valid("") - - with pytest.raises(ProgrammingError): - check_identifier_valid("01") - - with pytest.raises(ProgrammingError): - check_identifier_valid("(a)") - - with pytest.raises(ProgrammingError): - check_identifier_valid("drop table") - - -def test_generate_type_str(): - assert _generate_type_str(True) == "Bool" - assert _generate_type_str(1) == "Int64" - assert _generate_type_str("foo") == "Utf8" - assert _generate_type_str(b"foo") == "String" - assert _generate_type_str(3.1415) == "Double" - assert _generate_type_str(uuid.uuid4()) == "Uuid" - assert _generate_type_str(decimal.Decimal("3.1415926535")) == "Decimal(22, 9)" - - assert _generate_type_str([1, 2, 3]) == "List<Int64>" - assert _generate_type_str((1, "2", False)) == "Tuple<Int64, Utf8, Bool>" - assert _generate_type_str({1, 2, 3}) == "Set<Int64>" - assert 
_generate_type_str({"foo": 1, "bar": 2, "kek": 3.14}) == "Struct<foo: Int64, bar: Int64, kek: Double>" - - assert _generate_type_str([[1], [2], [3]]) == "List<List<Int64>>" - assert _generate_type_str([{"a": 1, "b": 2}, {"a": 11, "b": 22}]) == "List<Struct<a: Int64, b: Int64>>" - assert _generate_type_str(("foo", [1], 3.14)) == "Tuple<Utf8, List<Int64>, Double>" - - assert _generate_type_str(datetime.now()) == "Timestamp" - assert _generate_type_str(date.today()) == "Date" - assert _generate_type_str(timedelta(days=2)) == "Interval" - - with pytest.raises(ProgrammingError): - assert _generate_type_str(None) - - with pytest.raises(ProgrammingError): - assert _generate_type_str(object()) - - -def test_generate_declare_stm(): - assert _generate_declare_stms({}) == "" - assert _generate_declare_stms({"$p1": 123}).strip() == "DECLARE $p1 AS Int64;" - assert _generate_declare_stms({"$p1": 123, "$p2": "foo"}).strip() == "DECLARE $p1 AS Int64; DECLARE $p2 AS Utf8;" - - assert _generate_declare_stms({"$foo": decimal.Decimal("3.14")}).strip() == "DECLARE $foo AS Decimal(22, 9);" - assert _generate_declare_stms({"$foo": [1, 2, 3]}).strip() == "DECLARE $foo AS List<Int64>;" - - -def test_generate_full_stm(): - assert _generate_full_stm("select 1") == ("select 1", None) - assert _generate_full_stm("select %(p1)s as value", {"p1": 1}) == ( - "DECLARE $p1 AS Int64; select $p1 as value", - {"$p1": 1}, - ) - assert _generate_full_stm("select %(p1)s as value1, %(P2)s as value2", {"p1": 1, "P2": "123"}) == ( - "DECLARE $p1 AS Int64; DECLARE $P2 AS Utf8; select $p1 as value1, $P2 as value2", - {"$p1": 1, "$P2": "123"}, - ) - - assert _generate_full_stm("select %(p1)s as value", {"p1": None}) == ("select NULL as value", {}) diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index adb286c..742e356 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -2,6 +2,7 @@ Experimental Work in progress, breaking changes are possible. 
""" +import collections import ydb import ydb_sqlalchemy.dbapi as dbapi from ydb_sqlalchemy.dbapi.constants import YDB_KEYWORDS @@ -18,12 +19,12 @@ ) from sqlalchemy.sql.elements import ClauseList from sqlalchemy.engine import reflection -from sqlalchemy.engine.default import StrCompileDialect +from sqlalchemy.engine.default import StrCompileDialect, DefaultExecutionContext from sqlalchemy.util.compat import inspect_getfullargspec -from typing import Any +from typing import Any, Union, Mapping, Sequence, Optional, Tuple, List, Dict -from .types import UInt32, UInt64 +from . import types STR_QUOTE_MAP = { "'": "\\'", @@ -60,55 +61,109 @@ def __init__(self, dialect): class YqlTypeCompiler(StrSQLTypeCompiler): - def visit_CHAR(self, type_, **kw): + def visit_CHAR(self, type_: sa.CHAR, **kw): return "UTF8" - def visit_VARCHAR(self, type_, **kw): + def visit_VARCHAR(self, type_: sa.VARCHAR, **kw): return "UTF8" - def visit_unicode(self, type_, **kw): + def visit_unicode(self, type_: sa.Unicode, **kw): return "UTF8" - def visit_uuid(self, type_, **kw): + def visit_uuid(self, type_: sa.Uuid, **kw): return "UTF8" - def visit_NVARCHAR(self, type_, **kw): + def visit_NVARCHAR(self, type_: sa.NVARCHAR, **kw): return "UTF8" - def visit_TEXT(self, type_, **kw): + def visit_TEXT(self, type_: sa.TEXT, **kw): return "UTF8" - def visit_FLOAT(self, type_, **kw): - return "DOUBLE" + def visit_FLOAT(self, type_: sa.FLOAT, **kw): + return "FLOAT" - def visit_BOOLEAN(self, type_, **kw): + def visit_BOOLEAN(self, type_: sa.BOOLEAN, **kw): return "BOOL" - def visit_uint32(self, type_, **kw): + def visit_uint32(self, type_: types.UInt32, **kw): return "UInt32" - def visit_uint64(self, type_, **kw): + def visit_uint64(self, type_: types.UInt64, **kw): return "UInt64" - def visit_uint8(self, type_, **kw): + def visit_uint8(self, type_: types.UInt8, **kw): return "UInt8" - def visit_INTEGER(self, type_, **kw): + def visit_INTEGER(self, type_: sa.INTEGER, **kw): return "Int64" - def 
visit_NUMERIC(self, type_, **kw): + def visit_NUMERIC(self, type_: sa.Numeric, **kw): """Only Decimal(22,9) is supported for table columns""" return f"Decimal({type_.precision}, {type_.scale})" - def visit_BINARY(self, type_, **kw): + def visit_BINARY(self, type_: sa.BINARY, **kw): return "String" - def visit_BLOB(self, type_, **kw): + def visit_BLOB(self, type_: sa.BLOB, **kw): return "String" - def visit_DATETIME(self, type_, **kw): + def visit_DATETIME(self, type_: sa.TIMESTAMP, **kw): return "Timestamp" + def visit_list_type(self, type_: types.ListType, **kw): + inner = self.process(type_.item_type, **kw) + return f"List<{inner}>" + + def visit_ARRAY(self, type_: sa.ARRAY, **kw): + inner = self.process(type_.item_type, **kw) + return f"List<{inner}>" + + def visit_struct_type(self, type_: types.StructType, **kw): + text = "Struct<" + for field, field_type in type_.fields_types: + text += f"{field}:{self.process(field_type, **kw)}" + return text + ">" + + def get_ydb_type( + self, type_: sa.types.TypeEngine, is_optional: bool + ) -> Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]: + if isinstance(type_, sa.TypeDecorator): + type_ = type_.impl + + if isinstance(type_, (sa.Text, sa.String, sa.Uuid)): + ydb_type = ydb.PrimitiveType.Utf8 + elif isinstance(type_, sa.Integer): + ydb_type = ydb.PrimitiveType.Int64 + elif isinstance(type_, sa.JSON): + ydb_type = ydb.PrimitiveType.Json + elif isinstance(type_, sa.DateTime): + ydb_type = ydb.PrimitiveType.Timestamp + elif isinstance(type_, sa.Date): + ydb_type = ydb.PrimitiveType.Date + elif isinstance(type_, sa.BINARY): + ydb_type = ydb.PrimitiveType.String + elif isinstance(type_, sa.Float): + ydb_type = ydb.PrimitiveType.Float + elif isinstance(type_, sa.Double): + ydb_type = ydb.PrimitiveType.Double + elif isinstance(type_, sa.Boolean): + ydb_type = ydb.PrimitiveType.Bool + elif isinstance(type_, sa.Numeric): + ydb_type = ydb.DecimalType(type_.precision, type_.scale) + elif isinstance(type_, (types.ListType, 
sa.ARRAY)): + ydb_type = ydb.ListType(self.get_ydb_type(type_.item_type, is_optional=False)) + elif isinstance(type_, types.StructType): + ydb_type = ydb.StructType() + for field, field_type in type_.fields_types.items(): + ydb_type.add_member(field, self.get_ydb_type(field_type(), is_optional=False)) + else: + raise dbapi.NotSupportedError(f"{type_} bind variables not supported") + + if is_optional: + return ydb.OptionalType(ydb_type) + + return ydb_type + class ParametrizedFunction(functions.Function): __visit_name__ = "parametrized_function" @@ -210,6 +265,82 @@ def visit_regexp_match_op_binary(self, binary, operator, **kw): def visit_not_regexp_match_op_binary(self, binary, operator, **kw): return self._generate_generic_binary(binary, " NOT REGEXP ", **kw) + def _is_bound_to_nullable_column(self, bind_name: str) -> bool: + if bind_name in self.column_keys and hasattr(self.compile_state, "dml_table"): + if bind_name in self.compile_state.dml_table.c: + column = self.compile_state.dml_table.c[bind_name] + return not column.primary_key + return False + + def _guess_bound_variable_type_by_parameters( + self, bind: sa.BindParameter, post_compile_bind_values: list + ) -> Optional[sa.types.TypeEngine]: + if not bind.expanding: + if isinstance(bind.type, sa.types.NullType): + return None + bind_type = bind.type + else: + not_null_values = [v for v in post_compile_bind_values if v is not None] + if not_null_values: + bind_type = sa.BindParameter("", not_null_values[0]).type + else: + return None + + return bind_type + + def _get_expanding_bind_names(self, bind_name: str, parameters_values: Mapping[str, List[Any]]) -> List[Any]: + expanding_bind_names = [] + for parameter_name in parameters_values: + parameter_bind_name = "_".join(parameter_name.split("_")[:-1]) + if parameter_bind_name == bind_name: + expanding_bind_names.append(parameter_name) + return expanding_bind_names + + def get_bind_types( + self, post_compile_parameters: Optional[Union[Sequence[Mapping[str, 
Any]], Mapping[str, Any]]] + ) -> Dict[str, Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]]: + """ + This method extracts information about bound variables from the table definition and parameters. + """ + if isinstance(post_compile_parameters, collections.Mapping): + post_compile_parameters = [post_compile_parameters] + + parameters_values = collections.defaultdict(list) + for parameters_entry in post_compile_parameters: + for parameter_name, parameter_value in parameters_entry.items(): + parameters_values[parameter_name].append(parameter_value) + + parameter_types = {} + for bind_name in self.bind_names.values(): + bind = self.binds[bind_name] + + if bind.literal_execute: + continue + + if not bind.expanding: + post_compile_bind_names = [bind_name] + post_compile_bind_values = parameters_values[bind_name] + else: + post_compile_bind_names = self._get_expanding_bind_names(bind_name, parameters_values) + post_compile_bind_values = [] + for parameter_name, parameter_values in parameters_values.items(): + if parameter_name in post_compile_bind_names: + post_compile_bind_values.extend(parameter_values) + + is_optional = self._is_bound_to_nullable_column(bind_name) + if not post_compile_bind_values or None in post_compile_bind_values: + is_optional = True + + bind_type = self._guess_bound_variable_type_by_parameters(bind, post_compile_bind_values) + + if bind_type: + for post_compile_bind_name in post_compile_bind_names: + parameter_types[post_compile_bind_name] = YqlTypeCompiler(self.dialect).get_ydb_type( + bind_type, is_optional + ) + + return parameter_types + class YqlDDLCompiler(DDLCompiler): pass @@ -226,8 +357,8 @@ def upsert(table): ydb.PrimitiveType.Int64: sa.INTEGER, ydb.PrimitiveType.Uint8: sa.INTEGER, ydb.PrimitiveType.Uint16: sa.INTEGER, - ydb.PrimitiveType.Uint32: UInt32, - ydb.PrimitiveType.Uint64: UInt64, + ydb.PrimitiveType.Uint32: types.UInt32, + ydb.PrimitiveType.Uint64: types.UInt64, ydb.PrimitiveType.Float: sa.FLOAT, ydb.PrimitiveType.Double: 
sa.FLOAT, ydb.PrimitiveType.String: sa.BINARY, @@ -364,8 +495,76 @@ def do_commit(self, dbapi_connection) -> None: # TODO: needs to implement? pass - def do_execute(self, cursor, statement, parameters, context=None) -> None: - c = None - if context is not None and context.isddl: - c = {"isddl": True} - cursor.execute(statement, parameters, c) + def _format_variables( + self, + statement: str, + parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]], + execute_many: bool, + ) -> Tuple[str, Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]]]: + formatted_statement = statement + formatted_parameters = None + + if parameters: + if execute_many: + parameters_sequence: Sequence[Mapping[str, Any]] = parameters + variable_names = set() + formatted_parameters = [] + for i in range(len(parameters_sequence)): + variable_names.update(set(parameters_sequence[i].keys())) + formatted_parameters.append({f"${k}": v for k, v in parameters_sequence[i].items()}) + else: + variable_names = set(parameters.keys()) + formatted_parameters = {f"${k}": v for k, v in parameters.items()} + + formatted_variable_names = {variable_name: f"${variable_name}" for variable_name in variable_names} + formatted_statement = formatted_statement % formatted_variable_names + + formatted_statement = formatted_statement.replace("%%", "%") + return formatted_statement, formatted_parameters + + def _make_ydb_operation( + self, + statement: str, + context: Optional[DefaultExecutionContext] = None, + parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]] = None, + execute_many: bool = False, + ) -> Tuple[dbapi.YdbQuery, Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]]]: + is_ddl = context.isddl if context is not None else False + + if not is_ddl and parameters: + parameters_types = context.compiled.get_bind_types(parameters) + parameters_types = {f"${k}": v for k, v in parameters_types.items()} + statement, parameters = 
self._format_variables(statement, parameters, execute_many) + return dbapi.YdbQuery(yql_text=statement, parameters_types=parameters_types, is_ddl=is_ddl), parameters + + statement, parameters = self._format_variables(statement, parameters, execute_many) + return dbapi.YdbQuery(yql_text=statement, is_ddl=is_ddl), parameters + + def do_ping(self, dbapi_connection: dbapi.Connection) -> bool: + cursor = dbapi_connection.cursor() + statement, _ = self._make_ydb_operation(self._dialect_specific_select_one) + try: + cursor.execute(statement) + finally: + cursor.close() + return True + + def do_executemany( + self, + cursor: dbapi.Cursor, + statement: str, + parameters: Optional[Sequence[Mapping[str, Any]]], + context: Optional[DefaultExecutionContext] = None, + ) -> None: + operation, parameters = self._make_ydb_operation(statement, context, parameters, execute_many=True) + cursor.executemany(operation, parameters) + + def do_execute( + self, + cursor: dbapi.Cursor, + statement: str, + parameters: Optional[Mapping[str, Any]] = None, + context: Optional[DefaultExecutionContext] = None, + ) -> None: + operation, parameters = self._make_ydb_operation(statement, context, parameters, execute_many=False) + cursor.execute(operation, parameters) diff --git a/ydb_sqlalchemy/sqlalchemy/types.py b/ydb_sqlalchemy/sqlalchemy/types.py index 5515be5..61fa5ca 100644 --- a/ydb_sqlalchemy/sqlalchemy/types.py +++ b/ydb_sqlalchemy/sqlalchemy/types.py @@ -1,19 +1,38 @@ -from sqlalchemy import exc, Integer, ColumnElement +from sqlalchemy import exc, ColumnElement, ARRAY, types from sqlalchemy.sql import type_api +from typing import Mapping, Any, Union, Type -class UInt32(Integer): +class UInt32(types.Integer): __visit_name__ = "uint32" -class UInt64(Integer): +class UInt64(types.Integer): __visit_name__ = "uint64" -class UInt8(Integer): +class UInt8(types.Integer): __visit_name__ = "uint8" +class ListType(ARRAY): + __visit_name__ = "list_type" + + +class 
StructType(types.TypeEngine[Mapping[str, Any]]): + __visit_name__ = "struct_type" + + def __init__(self, fields_types: Mapping[str, Union[Type[types.TypeEngine], Type[types.TypeDecorator]]]): + self.fields_types = fields_types + + @property + def python_type(self): + return dict + + def compare_values(self, x, y): + return x == y + + class Lambda(ColumnElement): __visit_name__ = "lambda"