Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transaction support #16

Merged
merged 8 commits into from
Jul 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions src/mayim/executor/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,3 @@ async def _run_sql(
return None
raw = await getattr(cursor, method_name)()
return raw

def _get_method(self, as_list: bool):
return "fetchall" if as_list else "fetchone"
3 changes: 0 additions & 3 deletions src/mayim/executor/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,3 @@ async def _run_sql(
cursor.row_factory = dict_row
raw = await getattr(cursor, method_name)()
return raw

def _get_method(self, as_list: bool):
return "fetchall" if as_list else "fetchone"
31 changes: 29 additions & 2 deletions src/mayim/executor/sql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from contextlib import asynccontextmanager
from functools import wraps
from inspect import getmembers, isawaitable, isfunction, signature
from pathlib import Path
Expand Down Expand Up @@ -111,8 +112,34 @@ async def _run_sql(
):
...

def _get_method(self, as_list: bool):
...
async def rollback(self) -> None:
existing = self.pool.existing_connection()
transaction = self.pool.in_transaction()
if not existing or not transaction:
raise MayimError("Cannot rollback non-existing transaction")
await self._rollback(existing)

async def _rollback(self, existing) -> None:
self.pool._commit.set(False)
await existing.rollback()

def _get_method(self, as_list: bool) -> str:
return "fetchall" if as_list else "fetchone"

@asynccontextmanager
async def transaction(self):
self.pool._transaction.set(True)
async with self.pool.connection() as conn:
self.pool._connection.set(conn)
try:
yield
except Exception:
await self.rollback()
raise
finally:
self.pool._connection.set(None)
self.pool._transaction.set(False)
self.pool._commit.set(True)

@classmethod
def _load(cls) -> None:
Expand Down
19 changes: 18 additions & 1 deletion src/mayim/interface/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

from abc import ABC, abstractmethod
from collections import namedtuple
from typing import Optional, Set, Type
from contextvars import ContextVar
from typing import Any, Optional, Set, Type
from urllib.parse import urlparse

from mayim.exception import MayimError
Expand Down Expand Up @@ -92,6 +93,13 @@ def __init__(
self._password = password
self._db = db
self._full_dsn: Optional[str] = None
self._connection: ContextVar[Any] = ContextVar(
"connection", default=None
)
self._transaction: ContextVar[bool] = ContextVar(
"transaction", default=False
)
self._commit: ContextVar[bool] = ContextVar("commit", default=True)

self._populate_connection_args()
self._populate_dsn()
Expand Down Expand Up @@ -156,3 +164,12 @@ def db(self):
@property
def full_dsn(self):
return self._full_dsn

def existing_connection(self):
return self._connection.get()

def in_transaction(self) -> bool:
return self._transaction.get()

def do_commit(self) -> bool:
return self._commit.get()
21 changes: 17 additions & 4 deletions src/mayim/interface/mysql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import AsyncContextManager, Optional
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional

import asyncmy

Expand All @@ -24,7 +25,19 @@ async def close(self):
self._pool.close()
await self._pool.wait_closed()

def connection(
@asynccontextmanager
async def connection(
self, timeout: Optional[float] = None
) -> AsyncContextManager[asyncmy.contexts._PoolAcquireContextManager]:
return self._pool.acquire()
) -> AsyncIterator[asyncmy.Connection]:
existing = self.existing_connection()
if existing:
yield existing
else:
transaction = self.in_transaction()
async with self._pool.acquire() as conn:
if transaction:
await conn.begin()
yield conn
if transaction:
if self.do_commit():
await conn.commit()
15 changes: 11 additions & 4 deletions src/mayim/interface/postgres.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import AsyncContextManager, Optional
from contextlib import asynccontextmanager
from typing import AsyncIterator, Optional

from psycopg import AsyncConnection
from psycopg_pool import AsyncConnectionPool
Expand All @@ -18,7 +19,13 @@ async def open(self):
async def close(self):
await self._pool.close()

def connection(
@asynccontextmanager
async def connection(
self, timeout: Optional[float] = None
) -> AsyncContextManager[AsyncConnection]:
return self._pool.connection(timeout=timeout)
) -> AsyncIterator[AsyncConnection]:
existing = self._connection.get(None)
if existing:
yield existing
else:
async with self._pool.connection(timeout=timeout) as conn:
yield conn