diff --git a/test/test_core.py b/test/test_core.py index 8c662ba..fa4e326 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -1,14 +1,15 @@ -from decimal import Decimal from datetime import date, datetime +from decimal import Decimal +from typing import NamedTuple import pytest import sqlalchemy as sa -from sqlalchemy import Table, Column, Integer, Unicode -from sqlalchemy.testing.fixtures import TestBase, TablesTest - import ydb +from sqlalchemy import Table, Column, Integer, Unicode +from sqlalchemy.testing.fixtures import TestBase, TablesTest, config from ydb._grpc.v4.protos import ydb_common_pb2 +from ydb_sqlalchemy import dbapi, IsolationLevel from ydb_sqlalchemy.sqlalchemy import types @@ -219,9 +220,9 @@ def _create_table_and_get_desc(connection, metadata, **kwargs): ) table.create(connection) - session: ydb.Session = connection.connection.driver_connection.pool.acquire() + session: ydb.Session = connection.connection.driver_connection.session_pool.acquire() table_description = session.describe_table("/local/" + table.name) - session.delete() + connection.connection.driver_connection.session_pool.release(session) return table_description @pytest.mark.parametrize( @@ -367,3 +368,174 @@ def test_several_keys(self, connection, metadata): assert desc.partitioning_settings.partitioning_by_load == 1 assert desc.partitioning_settings.min_partitions_count == 3 assert desc.partitioning_settings.max_partitions_count == 5 + + +class TestTransaction(TablesTest): + @classmethod + def define_tables(cls, metadata: sa.MetaData): + Table( + "test", + metadata, + Column("id", Integer, primary_key=True), + ) + + def test_rollback(self, connection_no_trans: sa.Connection, connection: sa.Connection): + table = self.tables.test + + connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE) + with connection_no_trans.begin(): + stm1 = table.insert().values(id=1) + connection_no_trans.execute(stm1) + stm2 = table.insert().values(id=2) + connection_no_trans.execute(stm2) + connection_no_trans.rollback() + + cursor = connection.execute(sa.select(table)) + result = cursor.fetchall() + assert result == [] + + def test_commit(self, connection_no_trans: sa.Connection, connection: sa.Connection): + table = self.tables.test + + connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE) + with connection_no_trans.begin(): + stm1 = table.insert().values(id=3) + connection_no_trans.execute(stm1) + stm2 = table.insert().values(id=4) + connection_no_trans.execute(stm2) + + cursor = connection.execute(sa.select(table)) + result = cursor.fetchall() + assert set(result) == {(3,), (4,)} + + @pytest.mark.parametrize("isolation_level", (IsolationLevel.SERIALIZABLE, IsolationLevel.SNAPSHOT_READONLY)) + def test_interactive_transaction( + self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level + ): + table = self.tables.test + dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection + + stm1 = table.insert().values([{"id": 5}, {"id": 6}]) + connection.execute(stm1) + + connection_no_trans.execution_options(isolation_level=isolation_level) + with connection_no_trans.begin(): + tx_id = dbapi_connection.tx_context.tx_id + assert tx_id is not None + cursor1 = connection_no_trans.execute(sa.select(table)) + cursor2 = connection_no_trans.execute(sa.select(table)) + assert dbapi_connection.tx_context.tx_id == tx_id + + assert set(cursor1.fetchall()) == {(5,), (6,)} + assert set(cursor2.fetchall()) == {(5,), (6,)} + + @pytest.mark.parametrize( + "isolation_level", + ( + IsolationLevel.ONLINE_READONLY, + IsolationLevel.ONLINE_READONLY_INCONSISTENT, + IsolationLevel.STALE_READONLY, + IsolationLevel.AUTOCOMMIT, + ), + ) + def test_not_interactive_transaction( + self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level + ): + table = self.tables.test + dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection + + stm1 = table.insert().values([{"id": 7}, {"id": 8}]) + connection.execute(stm1) + + connection_no_trans.execution_options(isolation_level=isolation_level) + with connection_no_trans.begin(): + assert dbapi_connection.tx_context is None + cursor1 = connection_no_trans.execute(sa.select(table)) + cursor2 = connection_no_trans.execute(sa.select(table)) + assert dbapi_connection.tx_context is None + + assert set(cursor1.fetchall()) == {(7,), (8,)} + assert set(cursor2.fetchall()) == {(7,), (8,)} + + +class TestTransactionIsolationLevel(TestBase): + class IsolationSettings(NamedTuple): + ydb_mode: ydb.AbstractTransactionModeBuilder + interactive: bool + + YDB_ISOLATION_SETTINGS_MAP = { + IsolationLevel.AUTOCOMMIT: IsolationSettings(ydb.SerializableReadWrite().name, False), + IsolationLevel.SERIALIZABLE: IsolationSettings(ydb.SerializableReadWrite().name, True), + IsolationLevel.ONLINE_READONLY: IsolationSettings(ydb.OnlineReadOnly().name, False), + IsolationLevel.ONLINE_READONLY_INCONSISTENT: IsolationSettings( + ydb.OnlineReadOnly().with_allow_inconsistent_reads().name, False + ), + IsolationLevel.STALE_READONLY: IsolationSettings(ydb.StaleReadOnly().name, False), + IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.SnapshotReadOnly().name, True), + } + + def test_connection_set(self, connection_no_trans: sa.Connection): + dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection + + for sa_isolation_level, ydb_isolation_settings in self.YDB_ISOLATION_SETTINGS_MAP.items(): + connection_no_trans.execution_options(isolation_level=sa_isolation_level) + with connection_no_trans.begin(): + assert dbapi_connection.tx_mode.name == ydb_isolation_settings[0] + assert dbapi_connection.interactive_transaction is ydb_isolation_settings[1] + if dbapi_connection.interactive_transaction: + assert dbapi_connection.tx_context is not None + assert dbapi_connection.tx_context.tx_id is not None + else: + assert dbapi_connection.tx_context is None + + +class TestEngine(TestBase): + @pytest.fixture(scope="module") + def ydb_driver(self): + url = config.db_url + driver = ydb.Driver(endpoint=f"grpc://{url.host}:{url.port}", database=url.database) + try: + driver.wait(timeout=5, fail_fast=True) + yield driver + finally: + driver.stop() + + driver.stop() + + @pytest.fixture(scope="module") + def ydb_pool(self, ydb_driver): + session_pool = ydb.SessionPool(ydb_driver, size=5, workers_threads_count=1) + + yield session_pool + + session_pool.stop() + + def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool): + engine1 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool}) + engine2 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool}) + + with engine1.connect() as conn1, engine2.connect() as conn2: + dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection + dbapi_conn2: dbapi.Connection = conn2.connection.dbapi_connection + + assert dbapi_conn1.session_pool is dbapi_conn2.session_pool + assert dbapi_conn1.driver is dbapi_conn2.driver + + engine1.dispose() + engine2.dispose() + assert not ydb_driver._stopped + + def test_sa_null_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool): + engine1 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool}) + engine2 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool}) + + with engine1.connect() as conn1, engine2.connect() as conn2: + dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection + dbapi_conn2: dbapi.Connection = conn2.connection.dbapi_connection + + assert dbapi_conn1.session_pool is dbapi_conn2.session_pool + assert dbapi_conn1.driver is dbapi_conn2.driver + + engine1.dispose() + engine2.dispose() + assert not ydb_driver._stopped diff --git a/test_dbapi/conftest.py b/test_dbapi/conftest.py index 92f6610..7a9f5a3 100644 --- a/test_dbapi/conftest.py +++ b/test_dbapi/conftest.py @@ -1,9 +1,10 @@ import pytest + import ydb_sqlalchemy.dbapi as dbapi @pytest.fixture(scope="module") def connection(): - conn = dbapi.connect("localhost:2136", database="/local") + conn = dbapi.connect(host="localhost", port="2136", database="/local") yield conn conn.close() diff --git a/ydb_sqlalchemy/__init__.py b/ydb_sqlalchemy/__init__.py index e69de29..2e5fbab 100644 --- a/ydb_sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/__init__.py @@ -0,0 +1 @@ +from .dbapi import IsolationLevel # noqa: F401 diff --git a/ydb_sqlalchemy/dbapi/__init__.py b/ydb_sqlalchemy/dbapi/__init__.py index 075925c..f06e15f 100644 --- a/ydb_sqlalchemy/dbapi/__init__.py +++ b/ydb_sqlalchemy/dbapi/__init__.py @@ -1,4 +1,4 @@ -from .connection import Connection +from .connection import Connection, IsolationLevel # noqa: F401 from .cursor import Cursor, YdbQuery # noqa: F401 from .errors import ( Warning, diff --git a/ydb_sqlalchemy/dbapi/connection.py b/ydb_sqlalchemy/dbapi/connection.py index e8fbfc7..43e6273 100644 --- a/ydb_sqlalchemy/dbapi/connection.py +++ b/ydb_sqlalchemy/dbapi/connection.py @@ -1,24 +1,54 @@ import posixpath +from typing import Optional, NamedTuple, Any import ydb + from .cursor import Cursor -from .errors import InterfaceError, ProgrammingError, DatabaseError +from .errors import InterfaceError, ProgrammingError, DatabaseError, InternalError, NotSupportedError + + +class IsolationLevel: + SERIALIZABLE = "SERIALIZABLE" + ONLINE_READONLY = "ONLINE READONLY" + ONLINE_READONLY_INCONSISTENT = "ONLINE READONLY INCONSISTENT" + STALE_READONLY = "STALE READONLY" + SNAPSHOT_READONLY = "SNAPSHOT READONLY" + AUTOCOMMIT = "AUTOCOMMIT" class Connection: - def __init__(self, endpoint=None, host=None, port=None, database=None, **conn_kwargs): - self.endpoint = endpoint or f"grpc://{host}:{port}" + def __init__( + self, + host: str = "", + port: str = "", + database: str = "", + **conn_kwargs: Any, + ): + self.endpoint = f"grpc://{host}:{port}" self.database = database - self.driver = self._create_driver(self.endpoint, self.database, **conn_kwargs) - self.pool = ydb.SessionPool(self.driver) + self.conn_kwargs = conn_kwargs + + if "ydb_session_pool" in self.conn_kwargs: # Use session pool managed manually + self._shared_session_pool = True + self.session_pool: ydb.SessionPool = self.conn_kwargs.pop("ydb_session_pool") + self.driver = self.session_pool._pool_impl._driver + self.driver.table_client = ydb.TableClient(self.driver, self._get_table_client_settings()) + else: + self._shared_session_pool = False + self.driver = self._create_driver() + self.session_pool = ydb.SessionPool(self.driver, size=5, workers_threads_count=1) + + self.interactive_transaction: bool = False # AUTOCOMMIT + self.tx_mode: ydb.AbstractTransactionModeBuilder = ydb.SerializableReadWrite() + self.tx_context: Optional[ydb.TxContext] = None def cursor(self): - return Cursor(self) + return Cursor(self.session_pool, self.tx_context) def describe(self, table_path): full_path = posixpath.join(self.database, table_path) try: - return self.pool.retry_operation_sync(lambda cli: cli.describe_table(full_path)) + return self.session_pool.retry_operation_sync(lambda session: session.describe_table(full_path)) except ydb.issues.SchemeError as e: raise ProgrammingError(e.message, e.issues, e.status) from e except ydb.Error as e: @@ -33,31 +63,85 @@ def check_exists(self, table_path): except ydb.SchemeError: return False + def set_isolation_level(self, isolation_level: str): + class IsolationSettings(NamedTuple): + ydb_mode: ydb.AbstractTransactionModeBuilder + interactive: bool + + ydb_isolation_settings_map = { + IsolationLevel.AUTOCOMMIT: IsolationSettings(ydb.SerializableReadWrite(), interactive=False), + IsolationLevel.SERIALIZABLE: IsolationSettings(ydb.SerializableReadWrite(), interactive=True), + IsolationLevel.ONLINE_READONLY: IsolationSettings(ydb.OnlineReadOnly(), interactive=False), + IsolationLevel.ONLINE_READONLY_INCONSISTENT: IsolationSettings( + ydb.OnlineReadOnly().with_allow_inconsistent_reads(), interactive=False + ), + IsolationLevel.STALE_READONLY: IsolationSettings(ydb.StaleReadOnly(), interactive=False), + IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.SnapshotReadOnly(), interactive=True), + } + ydb_isolation_settings = ydb_isolation_settings_map[isolation_level] + if self.tx_context and self.tx_context.tx_id: + raise InternalError("Failed to set transaction mode: transaction is already began") + self.tx_mode = ydb_isolation_settings.ydb_mode + self.interactive_transaction = ydb_isolation_settings.interactive + + def get_isolation_level(self) -> str: + if self.tx_mode.name == ydb.SerializableReadWrite().name: + if self.interactive_transaction: + return IsolationLevel.SERIALIZABLE + else: + return IsolationLevel.AUTOCOMMIT + elif self.tx_mode.name == ydb.OnlineReadOnly().name: + if self.tx_mode.settings.allow_inconsistent_reads: + return IsolationLevel.ONLINE_READONLY_INCONSISTENT + else: + return IsolationLevel.ONLINE_READONLY + elif self.tx_mode.name == ydb.StaleReadOnly().name: + return IsolationLevel.STALE_READONLY + elif self.tx_mode.name == ydb.SnapshotReadOnly().name: + return IsolationLevel.SNAPSHOT_READONLY + else: + raise NotSupportedError(f"{self.tx_mode.name} is not supported") + + def begin(self): + self.tx_context = None + if self.interactive_transaction: + session = self.session_pool.acquire(blocking=True) + self.tx_context = session.transaction(self.tx_mode) + self.tx_context.begin() + def commit(self): - pass + if self.tx_context and self.tx_context.tx_id: + self.tx_context.commit() + self.session_pool.release(self.tx_context.session) + self.tx_context = None def rollback(self): - pass + if self.tx_context and self.tx_context.tx_id: + self.tx_context.rollback() + self.session_pool.release(self.tx_context.session) + self.tx_context = None def close(self): - if self.pool: - self.pool.stop() - if self.driver: - self.driver.stop() - - @staticmethod - def _create_driver(endpoint, database, **conn_kwargs): - # TODO: add cache for initialized drivers/pools? - driver_config = ydb.DriverConfig( - endpoint, - database=database, - table_client_settings=ydb.TableClientSettings() + self.rollback() + if not self._shared_session_pool: + self.session_pool.stop() + self._stop_driver() + + def _get_table_client_settings(self) -> ydb.TableClientSettings: + return ( + ydb.TableClientSettings() .with_native_date_in_result_sets(True) .with_native_datetime_in_result_sets(True) .with_native_timestamp_in_result_sets(True) .with_native_interval_in_result_sets(True) - .with_native_json_in_result_sets(True), - **conn_kwargs, + .with_native_json_in_result_sets(True) + ) + + def _create_driver(self): + driver_config = ydb.DriverConfig( + endpoint=self.endpoint, + database=self.database, + table_client_settings=self._get_table_client_settings(), ) driver = ydb.Driver(driver_config) try: @@ -68,3 +152,6 @@ def _create_driver(endpoint, database, **conn_kwargs): driver.stop() raise InterfaceError(f"Failed to connect to YDB, details {driver.discovery_debug_details()}") from e return driver + + def _stop_driver(self): + self.driver.stop() diff --git a/ydb_sqlalchemy/dbapi/cursor.py b/ydb_sqlalchemy/dbapi/cursor.py index 7573e1f..4ae9565 100644 --- a/ydb_sqlalchemy/dbapi/cursor.py +++ b/ydb_sqlalchemy/dbapi/cursor.py @@ -1,10 +1,10 @@ import dataclasses import itertools import logging - -from typing import Any, Mapping, Optional, Sequence, Union, Dict +from typing import Any, Mapping, Optional, Sequence, Union, Dict, Callable import ydb + from .errors import ( InternalError, IntegrityError, @@ -15,7 +15,6 @@ NotSupportedError, ) - logger = logging.getLogger(__name__) @@ -33,16 +32,19 @@ class YdbQuery: class Cursor(object): - def __init__(self, connection): - self.connection = connection + def __init__( + self, + session_pool: ydb.SessionPool, + tx_context: Optional[ydb.BaseTxContext] = None, + ): + self.session_pool = session_pool + self.tx_context = tx_context self.description = None self.arraysize = 1 self.rows = None self._rows_prefetched = None def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] = None): - self.description = None - if operation.is_ddl or not operation.parameters_types: query = operation.yql_text is_ddl = operation.is_ddl @@ -51,45 +53,14 @@ def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] = is_ddl = operation.is_ddl logger.info("execute sql: %s, params: %s", query, parameters) + if is_ddl: + chunks = self.session_pool.retry_operation_sync(self._execute_ddl, None, query) + else: + if self.tx_context: + chunks = self._execute_dml(self.tx_context.session, query, parameters, self.tx_context) + else: + chunks = self.session_pool.retry_operation_sync(self._execute_dml, None, query, parameters) - def _execute_in_pool(cli: ydb.Session): - try: - if is_ddl: - return cli.execute_scheme(query) - - prepared_query = query - if isinstance(query, str) and parameters: - prepared_query = cli.prepare(query) - - return cli.transaction().execute(prepared_query, parameters, commit_tx=True) - except (ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed) as e: - raise IntegrityError(e.message, e.issues, e.status) from e - except (ydb.issues.Unsupported, ydb.issues.Unimplemented) as e: - raise NotSupportedError(e.message, e.issues, e.status) from e - except (ydb.issues.BadRequest, ydb.issues.SchemeError) as e: - raise ProgrammingError(e.message, e.issues, e.status) from e - except ( - ydb.issues.TruncatedResponseError, - ydb.issues.ConnectionError, - ydb.issues.Aborted, - ydb.issues.Unavailable, - ydb.issues.Overloaded, - ydb.issues.Undetermined, - ydb.issues.Timeout, - ydb.issues.Cancelled, - ydb.issues.SessionBusy, - ydb.issues.SessionExpired, - ydb.issues.SessionPoolEmpty, - ) as e: - raise OperationalError(e.message, e.issues, e.status) from e - except ydb.issues.GenericError as e: - raise DataError(e.message, e.issues, e.status) from e - except ydb.issues.InternalError as e: - raise InternalError(e.message, e.issues, e.status) from e - except ydb.Error as e: - raise DatabaseError(e.message, e.issues, e.status) from e - - chunks = self.connection.pool.retry_operation_sync(_execute_in_pool) rows = self._rows_iterable(chunks) # Prefetch the description: try: @@ -103,7 +74,59 @@ def _execute_in_pool(cli: ydb.Session): self.rows = rows - def _rows_iterable(self, chunks_iterable): + @classmethod + def _execute_dml( + cls, + session: ydb.Session, + query: ydb.DataQuery, + parameters: Optional[Mapping[str, Any]] = None, + tx_context: Optional[ydb.BaseTxContext] = None, + ) -> ydb.convert.ResultSets: + prepared_query = query + if isinstance(query, str) and parameters: + prepared_query = session.prepare(query) + + if tx_context: + return cls._handle_ydb_errors(tx_context.execute, prepared_query, parameters) + + return cls._handle_ydb_errors(session.transaction().execute, prepared_query, parameters, commit_tx=True) + + @classmethod + def _execute_ddl(cls, session: ydb.Session, query: str) -> ydb.convert.ResultSets: + return cls._handle_ydb_errors(session.execute_scheme, query) + + @staticmethod + def _handle_ydb_errors(callee: Callable, *args, **kwargs) -> Any: + try: + return callee(*args, **kwargs) + except (ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed) as e: + raise IntegrityError(e.message, e.issues, e.status) from e + except (ydb.issues.Unsupported, ydb.issues.Unimplemented) as e: + raise NotSupportedError(e.message, e.issues, e.status) from e + except (ydb.issues.BadRequest, ydb.issues.SchemeError) as e: + raise ProgrammingError(e.message, e.issues, e.status) from e + except ( + ydb.issues.TruncatedResponseError, + ydb.issues.ConnectionError, + ydb.issues.Aborted, + ydb.issues.Unavailable, + ydb.issues.Overloaded, + ydb.issues.Undetermined, + ydb.issues.Timeout, + ydb.issues.Cancelled, + ydb.issues.SessionBusy, + ydb.issues.SessionExpired, + ydb.issues.SessionPoolEmpty, + ) as e: + raise OperationalError(e.message, e.issues, e.status) from e + except ydb.issues.GenericError as e: + raise DataError(e.message, e.issues, e.status) from e + except ydb.issues.InternalError as e: + raise InternalError(e.message, e.issues, e.status) from e + except ydb.Error as e: + raise DatabaseError(e.message, e.issues, e.status) from e + + def _rows_iterable(self, chunks_iterable: ydb.convert.ResultSets): try: for chunk in chunks_iterable: self.description = [ diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 9b6b6ef..52d6991 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -538,9 +538,23 @@ def get_indexes(self, connection, table_name, schema=None, **kwargs): # TODO: implement me return [] - def do_commit(self, dbapi_connection) -> None: - # TODO: needs to implement? - pass + def set_isolation_level(self, dbapi_connection: dbapi.Connection, level: str) -> None: + dbapi_connection.set_isolation_level(level) + + def get_default_isolation_level(self, dbapi_conn: dbapi.Connection) -> str: + return dbapi.IsolationLevel.AUTOCOMMIT + + def get_isolation_level(self, dbapi_connection: dbapi.Connection) -> str: + return dbapi_connection.get_isolation_level() + + def do_begin(self, dbapi_connection: dbapi.Connection) -> None: + dbapi_connection.begin() + + def do_rollback(self, dbapi_connection: dbapi.Connection) -> None: + dbapi_connection.rollback() + + def do_commit(self, dbapi_connection: dbapi.Connection) -> None: + dbapi_connection.commit() def _format_variables( self,