Skip to content

Commit

Permalink
postgres
Browse files Browse the repository at this point in the history
  • Loading branch information
phenobarbital committed Feb 19, 2025
1 parent dc87f43 commit 3a10b65
Showing 1 changed file with 77 additions and 60 deletions.
137 changes: 77 additions & 60 deletions asyncdb/drivers/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
QueryCanceledError
)
from asyncpg.pgproto import pgproto
from pgvector.asyncpg import register_vector
from ..exceptions import (
ConnectionTimeout,
DriverError,
Expand All @@ -51,7 +52,6 @@
from ..models import Model
from ..utils.encoders import DefaultEncoder
from ..utils.types import Entity

from .base import BasePool
from .sql import SQLCursor, SQLDriver

Expand Down Expand Up @@ -115,6 +115,7 @@ def __init__(
self._custom_record: bool = custom_record
self._record_class_ = kwargs.get("record_class", pgRecord)
self._cache_size: int = kwargs.get("cache_size", 36000)
self._enable_vector: bool = kwargs.get("enable_vector", False)
# max_inactive_connection_lifetime
self._max_inactive_timeout = kwargs.pop("max_inactive_timeout", 360000)
if "server_settings" in kwargs:
Expand Down Expand Up @@ -195,12 +196,27 @@ def _encoder(value):
def _decoder(value):
return self._encoder.loads(value) # pylint: disable=E1120

await connection.set_type_codec("json", encoder=_encoder, decoder=_decoder, schema="pg_catalog")
await connection.set_type_codec("jsonb", encoder=_encoder, decoder=_decoder, schema="pg_catalog")
try:
await connection.set_builtin_type_codec("hstore", codec_name="pg_contrib.hstore")
except Exception:
pass
await connection.set_type_codec(
"json",
encoder=_encoder,
decoder=_decoder,
schema="pg_catalog"
)
await connection.set_type_codec(
"jsonb",
encoder=_encoder,
decoder=_decoder,
schema="pg_catalog"
)

if self._enable_vector:
await register_vector(connection)

with contextlib.suppress(Exception):
await connection.set_builtin_type_codec(
"hstore",
codec_name="pg_contrib.hstore"
)

def _uuid_encoder(value):
if isinstance(value, uuid.UUID):
Expand All @@ -211,16 +227,14 @@ def _uuid_encoder(value):
val = b""
return val

try:
with contextlib.suppress(Exception):
await connection.set_type_codec(
"uuid",
encoder=_uuid_encoder,
decoder=lambda u: pgproto.UUID(u), # pylint: disable=I1101,W0108
schema="pg_catalog",
format="binary",
)
except Exception:
pass
if self._connection_config and isinstance(self._connection_config, dict):
for key, value in self._connection_config.items():
config = f"SELECT set_config('{key}', '{value}', false);"
Expand Down Expand Up @@ -467,6 +481,7 @@ def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params
self._custom_record: bool = kwargs.get("custom_record", False)
self._record_class_ = kwargs.get("record_class", pgRecord)
self._cache_size: int = kwargs.get("cache_size", 36000)
self._enable_vector: bool = kwargs.get("enable_vector", False)
# max_inactive_connection_lifetime
self._max_inactive_timeout = kwargs.pop("max_inactive_timeout", 360000)
DBCursorBackend.__init__(self)
Expand Down Expand Up @@ -555,6 +570,55 @@ def is_connected(self):
return not (self._connection.is_closed())
return self._connected

async def _register_codecs(self, conn, schema: str = 'pg_catalog'):
"""
Register codecs for the connection.
"""
# Setup jsonb encoder/decoder
def _encoder(value):
return self._encoder.dumps(value) # pylint: disable=E1120

def _decoder(value):
return self._encoder.loads(value) # pylint: disable=E1120

await conn.set_type_codec(
"json", encoder=_encoder, decoder=_decoder, schema=schema
)
await conn.set_type_codec(
"jsonb", encoder=_encoder, decoder=_decoder, schema=schema
)

with contextlib.suppress(Exception):
await conn.set_builtin_type_codec(
"hstore",
codec_name="pg_contrib.hstore"
)

def _uuid_encoder(value):
if isinstance(value, uuid.UUID):
val = value.bytes
elif value is not None:
val = uuid.UUID(bytes=value)
else:
val = b""
return val

def _uuid_decoder(value):
return b"" if value is None else uuid.UUID(bytes=value)

with contextlib.suppress(Exception):
await conn.set_type_codec(
"uuid",
encoder=_uuid_encoder,
decoder=_uuid_decoder,
schema=schema,
format="binary",
)

if self._enable_vector:
await register_vector(conn)


async def connection(self):
"""connection.
Expand All @@ -565,35 +629,21 @@ async def connection(self):
return self
self._connection = None
self._connected = False
# Setup jsonb encoder/decoder

def _encoder(value):
return self._encoder.dumps(value) # pylint: disable=E1120

def _decoder(value):
return self._encoder.loads(value) # pylint: disable=E1120

server_settings = {
"application_name": self.application_name,
"idle_session_timeout": "120min",
# "tcp_keepalives_idle": "36000",
# "max_parallel_workers": "512",
"max_parallel_workers": "512",
}
server_settings = {**server_settings, **self._server_settings}
_ssl = {}
if self.ssl:
_ssl = {"ssl": self.sslctx}
_ssl = {"ssl": self.sslctx} if self.ssl else {}
custom_class = {}
if self._custom_record:
custom_class = {"record_class": self._record_class_}
try:
if self._pool and not self._connection:
self._connection = await self._pool.pool().acquire()
else:
# try:
# loop = self._loop or asyncio.get_running_loop()
# except RuntimeError:
# loop = asyncio.get_event_loop()
self._connection = await asyncpg.connect(
dsn=self._dsn,
timeout=self._timeout,
Expand All @@ -602,45 +652,12 @@ def _decoder(value):
# connection_class=NAVConnection,
max_cached_statement_lifetime=600,
max_cacheable_statement_size=1024 * 30,
# loop=loop,
**custom_class,
**_ssl,
)
await self._connection.set_type_codec(
"json", encoder=_encoder, decoder=_decoder, schema="pg_catalog"
)
await self._connection.set_type_codec(
"jsonb", encoder=_encoder, decoder=_decoder, schema="pg_catalog"
)
with contextlib.suppress(Exception):
await self._connection.set_builtin_type_codec(
"hstore",
codec_name="pg_contrib.hstore"
)

def _uuid_encoder(value):
if isinstance(value, uuid.UUID):
val = value.bytes
elif value is not None:
val = uuid.UUID(bytes=value)
else:
val = b""
return val

def _uuid_decoder(value):
return b"" if value is None else uuid.UUID(bytes=value)

try:
await self._connection.set_type_codec(
"uuid",
encoder=_uuid_encoder,
decoder=_uuid_decoder,
schema="pg_catalog",
format="binary",
)
except Exception:
pass
if self._connection:
await self._register_codecs(self._connection)
self._connected = True
if self._connection_config and isinstance(self._connection_config, dict):
for key, value in self._connection_config.items():
Expand Down

0 comments on commit 3a10b65

Please sign in to comment.