Skip to content

Commit

Permalink
Merge pull request #24 from LuckySting/refactor-dbapi-connection Add …
Browse files Browse the repository at this point in the history
…transactions support
  • Loading branch information
rekby authored Jan 16, 2024
2 parents a3144d1 + a85232c commit dc7381a
Show file tree
Hide file tree
Showing 7 changed files with 378 additions and 80 deletions.
184 changes: 178 additions & 6 deletions test/test_core.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
from decimal import Decimal
from datetime import date, datetime
from decimal import Decimal
from typing import NamedTuple

import pytest
import sqlalchemy as sa
from sqlalchemy import Table, Column, Integer, Unicode
from sqlalchemy.testing.fixtures import TestBase, TablesTest

import ydb
from sqlalchemy import Table, Column, Integer, Unicode
from sqlalchemy.testing.fixtures import TestBase, TablesTest, config
from ydb._grpc.v4.protos import ydb_common_pb2

from ydb_sqlalchemy import dbapi, IsolationLevel
from ydb_sqlalchemy.sqlalchemy import types


Expand Down Expand Up @@ -219,9 +220,9 @@ def _create_table_and_get_desc(connection, metadata, **kwargs):
)
table.create(connection)

session: ydb.Session = connection.connection.driver_connection.pool.acquire()
session: ydb.Session = connection.connection.driver_connection.session_pool.acquire()
table_description = session.describe_table("/local/" + table.name)
session.delete()
connection.connection.driver_connection.session_pool.release(session)
return table_description

@pytest.mark.parametrize(
Expand Down Expand Up @@ -367,3 +368,174 @@ def test_several_keys(self, connection, metadata):
assert desc.partitioning_settings.partitioning_by_load == 1
assert desc.partitioning_settings.min_partitions_count == 3
assert desc.partitioning_settings.max_partitions_count == 5


class TestTransaction(TablesTest):
@classmethod
def define_tables(cls, metadata: sa.MetaData):
Table(
"test",
metadata,
Column("id", Integer, primary_key=True),
)

def test_rollback(self, connection_no_trans: sa.Connection, connection: sa.Connection):
table = self.tables.test

connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE)
with connection_no_trans.begin():
stm1 = table.insert().values(id=1)
connection_no_trans.execute(stm1)
stm2 = table.insert().values(id=2)
connection_no_trans.execute(stm2)
connection_no_trans.rollback()

cursor = connection.execute(sa.select(table))
result = cursor.fetchall()
assert result == []

def test_commit(self, connection_no_trans: sa.Connection, connection: sa.Connection):
table = self.tables.test

connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE)
with connection_no_trans.begin():
stm1 = table.insert().values(id=3)
connection_no_trans.execute(stm1)
stm2 = table.insert().values(id=4)
connection_no_trans.execute(stm2)

cursor = connection.execute(sa.select(table))
result = cursor.fetchall()
assert set(result) == {(3,), (4,)}

@pytest.mark.parametrize("isolation_level", (IsolationLevel.SERIALIZABLE, IsolationLevel.SNAPSHOT_READONLY))
def test_interactive_transaction(
self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level
):
table = self.tables.test
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection

stm1 = table.insert().values([{"id": 5}, {"id": 6}])
connection.execute(stm1)

connection_no_trans.execution_options(isolation_level=isolation_level)
with connection_no_trans.begin():
tx_id = dbapi_connection.tx_context.tx_id
assert tx_id is not None
cursor1 = connection_no_trans.execute(sa.select(table))
cursor2 = connection_no_trans.execute(sa.select(table))
assert dbapi_connection.tx_context.tx_id == tx_id

assert set(cursor1.fetchall()) == {(5,), (6,)}
assert set(cursor2.fetchall()) == {(5,), (6,)}

@pytest.mark.parametrize(
"isolation_level",
(
IsolationLevel.ONLINE_READONLY,
IsolationLevel.ONLINE_READONLY_INCONSISTENT,
IsolationLevel.STALE_READONLY,
IsolationLevel.AUTOCOMMIT,
),
)
def test_not_interactive_transaction(
self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level
):
table = self.tables.test
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection

stm1 = table.insert().values([{"id": 7}, {"id": 8}])
connection.execute(stm1)

connection_no_trans.execution_options(isolation_level=isolation_level)
with connection_no_trans.begin():
assert dbapi_connection.tx_context is None
cursor1 = connection_no_trans.execute(sa.select(table))
cursor2 = connection_no_trans.execute(sa.select(table))
assert dbapi_connection.tx_context is None

assert set(cursor1.fetchall()) == {(7,), (8,)}
assert set(cursor2.fetchall()) == {(7,), (8,)}


class TestTransactionIsolationLevel(TestBase):
class IsolationSettings(NamedTuple):
ydb_mode: ydb.AbstractTransactionModeBuilder
interactive: bool

YDB_ISOLATION_SETTINGS_MAP = {
IsolationLevel.AUTOCOMMIT: IsolationSettings(ydb.SerializableReadWrite().name, False),
IsolationLevel.SERIALIZABLE: IsolationSettings(ydb.SerializableReadWrite().name, True),
IsolationLevel.ONLINE_READONLY: IsolationSettings(ydb.OnlineReadOnly().name, False),
IsolationLevel.ONLINE_READONLY_INCONSISTENT: IsolationSettings(
ydb.OnlineReadOnly().with_allow_inconsistent_reads().name, False
),
IsolationLevel.STALE_READONLY: IsolationSettings(ydb.StaleReadOnly().name, False),
IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.SnapshotReadOnly().name, True),
}

def test_connection_set(self, connection_no_trans: sa.Connection):
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection

for sa_isolation_level, ydb_isolation_settings in self.YDB_ISOLATION_SETTINGS_MAP.items():
connection_no_trans.execution_options(isolation_level=sa_isolation_level)
with connection_no_trans.begin():
assert dbapi_connection.tx_mode.name == ydb_isolation_settings[0]
assert dbapi_connection.interactive_transaction is ydb_isolation_settings[1]
if dbapi_connection.interactive_transaction:
assert dbapi_connection.tx_context is not None
assert dbapi_connection.tx_context.tx_id is not None
else:
assert dbapi_connection.tx_context is None


class TestEngine(TestBase):
@pytest.fixture(scope="module")
def ydb_driver(self):
url = config.db_url
driver = ydb.Driver(endpoint=f"grpc://{url.host}:{url.port}", database=url.database)
try:
driver.wait(timeout=5, fail_fast=True)
yield driver
finally:
driver.stop()

driver.stop()

@pytest.fixture(scope="module")
def ydb_pool(self, ydb_driver):
session_pool = ydb.SessionPool(ydb_driver, size=5, workers_threads_count=1)

yield session_pool

session_pool.stop()

def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
engine1 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool})
engine2 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool})

with engine1.connect() as conn1, engine2.connect() as conn2:
dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection
dbapi_conn2: dbapi.Connection = conn2.connection.dbapi_connection

assert dbapi_conn1.session_pool is dbapi_conn2.session_pool
assert dbapi_conn1.driver is dbapi_conn2.driver

engine1.dispose()
engine2.dispose()
assert not ydb_driver._stopped

def test_sa_null_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
engine1 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool})
engine2 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool})

with engine1.connect() as conn1, engine2.connect() as conn2:
dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection
dbapi_conn2: dbapi.Connection = conn2.connection.dbapi_connection

assert dbapi_conn1.session_pool is dbapi_conn2.session_pool
assert dbapi_conn1.driver is dbapi_conn2.driver

engine1.dispose()
engine2.dispose()
assert not ydb_driver._stopped
3 changes: 2 additions & 1 deletion test_dbapi/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import pytest

import ydb_sqlalchemy.dbapi as dbapi


@pytest.fixture(scope="module")
def connection():
conn = dbapi.connect("localhost:2136", database="/local")
conn = dbapi.connect(host="localhost", port="2136", database="/local")
yield conn
conn.close()
1 change: 1 addition & 0 deletions ydb_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .dbapi import IsolationLevel # noqa: F401
2 changes: 1 addition & 1 deletion ydb_sqlalchemy/dbapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .connection import Connection
from .connection import Connection, IsolationLevel # noqa: F401
from .cursor import Cursor, YdbQuery # noqa: F401
from .errors import (
Warning,
Expand Down
133 changes: 110 additions & 23 deletions ydb_sqlalchemy/dbapi/connection.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,54 @@
import posixpath
from typing import Optional, NamedTuple, Any

import ydb

from .cursor import Cursor
from .errors import InterfaceError, ProgrammingError, DatabaseError
from .errors import InterfaceError, ProgrammingError, DatabaseError, InternalError, NotSupportedError


class IsolationLevel:
SERIALIZABLE = "SERIALIZABLE"
ONLINE_READONLY = "ONLINE READONLY"
ONLINE_READONLY_INCONSISTENT = "ONLINE READONLY INCONSISTENT"
STALE_READONLY = "STALE READONLY"
SNAPSHOT_READONLY = "SNAPSHOT READONLY"
AUTOCOMMIT = "AUTOCOMMIT"


class Connection:
def __init__(self, endpoint=None, host=None, port=None, database=None, **conn_kwargs):
self.endpoint = endpoint or f"grpc://{host}:{port}"
def __init__(
self,
host: str = "",
port: str = "",
database: str = "",
**conn_kwargs: Any,
):
self.endpoint = f"grpc://{host}:{port}"
self.database = database
self.driver = self._create_driver(self.endpoint, self.database, **conn_kwargs)
self.pool = ydb.SessionPool(self.driver)
self.conn_kwargs = conn_kwargs

if "ydb_session_pool" in self.conn_kwargs: # Use session pool managed manually
self._shared_session_pool = True
self.session_pool: ydb.SessionPool = self.conn_kwargs.pop("ydb_session_pool")
self.driver = self.session_pool._pool_impl._driver
self.driver.table_client = ydb.TableClient(self.driver, self._get_table_client_settings())
else:
self._shared_session_pool = False
self.driver = self._create_driver()
self.session_pool = ydb.SessionPool(self.driver, size=5, workers_threads_count=1)

self.interactive_transaction: bool = False # AUTOCOMMIT
self.tx_mode: ydb.AbstractTransactionModeBuilder = ydb.SerializableReadWrite()
self.tx_context: Optional[ydb.TxContext] = None

def cursor(self):
return Cursor(self)
return Cursor(self.session_pool, self.tx_context)

def describe(self, table_path):
full_path = posixpath.join(self.database, table_path)
try:
return self.pool.retry_operation_sync(lambda cli: cli.describe_table(full_path))
return self.session_pool.retry_operation_sync(lambda session: session.describe_table(full_path))
except ydb.issues.SchemeError as e:
raise ProgrammingError(e.message, e.issues, e.status) from e
except ydb.Error as e:
Expand All @@ -33,31 +63,85 @@ def check_exists(self, table_path):
except ydb.SchemeError:
return False

def set_isolation_level(self, isolation_level: str):
class IsolationSettings(NamedTuple):
ydb_mode: ydb.AbstractTransactionModeBuilder
interactive: bool

ydb_isolation_settings_map = {
IsolationLevel.AUTOCOMMIT: IsolationSettings(ydb.SerializableReadWrite(), interactive=False),
IsolationLevel.SERIALIZABLE: IsolationSettings(ydb.SerializableReadWrite(), interactive=True),
IsolationLevel.ONLINE_READONLY: IsolationSettings(ydb.OnlineReadOnly(), interactive=False),
IsolationLevel.ONLINE_READONLY_INCONSISTENT: IsolationSettings(
ydb.OnlineReadOnly().with_allow_inconsistent_reads(), interactive=False
),
IsolationLevel.STALE_READONLY: IsolationSettings(ydb.StaleReadOnly(), interactive=False),
IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.SnapshotReadOnly(), interactive=True),
}
ydb_isolation_settings = ydb_isolation_settings_map[isolation_level]
if self.tx_context and self.tx_context.tx_id:
raise InternalError("Failed to set transaction mode: transaction is already began")
self.tx_mode = ydb_isolation_settings.ydb_mode
self.interactive_transaction = ydb_isolation_settings.interactive

def get_isolation_level(self) -> str:
if self.tx_mode.name == ydb.SerializableReadWrite().name:
if self.interactive_transaction:
return IsolationLevel.SERIALIZABLE
else:
return IsolationLevel.AUTOCOMMIT
elif self.tx_mode.name == ydb.OnlineReadOnly().name:
if self.tx_mode.settings.allow_inconsistent_reads:
return IsolationLevel.ONLINE_READONLY_INCONSISTENT
else:
return IsolationLevel.ONLINE_READONLY
elif self.tx_mode.name == ydb.StaleReadOnly().name:
return IsolationLevel.STALE_READONLY
elif self.tx_mode.name == ydb.SnapshotReadOnly().name:
return IsolationLevel.SNAPSHOT_READONLY
else:
raise NotSupportedError(f"{self.tx_mode.name} is not supported")

def begin(self):
self.tx_context = None
if self.interactive_transaction:
session = self.session_pool.acquire(blocking=True)
self.tx_context = session.transaction(self.tx_mode)
self.tx_context.begin()

def commit(self):
pass
if self.tx_context and self.tx_context.tx_id:
self.tx_context.commit()
self.session_pool.release(self.tx_context.session)
self.tx_context = None

def rollback(self):
pass
if self.tx_context and self.tx_context.tx_id:
self.tx_context.rollback()
self.session_pool.release(self.tx_context.session)
self.tx_context = None

def close(self):
if self.pool:
self.pool.stop()
if self.driver:
self.driver.stop()

@staticmethod
def _create_driver(endpoint, database, **conn_kwargs):
# TODO: add cache for initialized drivers/pools?
driver_config = ydb.DriverConfig(
endpoint,
database=database,
table_client_settings=ydb.TableClientSettings()
self.rollback()
if not self._shared_session_pool:
self.session_pool.stop()
self._stop_driver()

def _get_table_client_settings(self) -> ydb.TableClientSettings:
return (
ydb.TableClientSettings()
.with_native_date_in_result_sets(True)
.with_native_datetime_in_result_sets(True)
.with_native_timestamp_in_result_sets(True)
.with_native_interval_in_result_sets(True)
.with_native_json_in_result_sets(True),
**conn_kwargs,
.with_native_json_in_result_sets(True)
)

def _create_driver(self):
driver_config = ydb.DriverConfig(
endpoint=self.endpoint,
database=self.database,
table_client_settings=self._get_table_client_settings(),
)
driver = ydb.Driver(driver_config)
try:
Expand All @@ -68,3 +152,6 @@ def _create_driver(endpoint, database, **conn_kwargs):
driver.stop()
raise InterfaceError(f"Failed to connect to YDB, details {driver.discovery_debug_details()}") from e
return driver

def _stop_driver(self):
self.driver.stop()
Loading

0 comments on commit dc7381a

Please sign in to comment.