diff --git a/asyncdb/drivers/pg.py b/asyncdb/drivers/pg.py index 1a3fa8d1..de4ede42 100644 --- a/asyncdb/drivers/pg.py +++ b/asyncdb/drivers/pg.py @@ -37,6 +37,7 @@ QueryCanceledError ) from asyncpg.pgproto import pgproto +from pgvector.asyncpg import register_vector from ..exceptions import ( ConnectionTimeout, DriverError, @@ -51,7 +52,6 @@ from ..models import Model from ..utils.encoders import DefaultEncoder from ..utils.types import Entity - from .base import BasePool from .sql import SQLCursor, SQLDriver @@ -115,6 +115,7 @@ def __init__( self._custom_record: bool = custom_record self._record_class_ = kwargs.get("record_class", pgRecord) self._cache_size: int = kwargs.get("cache_size", 36000) + self._enable_vector: bool = kwargs.get("enable_vector", False) # max_inactive_connection_lifetime self._max_inactive_timeout = kwargs.pop("max_inactive_timeout", 360000) if "server_settings" in kwargs: @@ -195,12 +196,27 @@ def _encoder(value): def _decoder(value): return self._encoder.loads(value) # pylint: disable=E1120 - await connection.set_type_codec("json", encoder=_encoder, decoder=_decoder, schema="pg_catalog") - await connection.set_type_codec("jsonb", encoder=_encoder, decoder=_decoder, schema="pg_catalog") - try: - await connection.set_builtin_type_codec("hstore", codec_name="pg_contrib.hstore") - except Exception: - pass + await connection.set_type_codec( + "json", + encoder=_encoder, + decoder=_decoder, + schema="pg_catalog" + ) + await connection.set_type_codec( + "jsonb", + encoder=_encoder, + decoder=_decoder, + schema="pg_catalog" + ) + + if self._enable_vector: + await register_vector(connection) + + with contextlib.suppress(Exception): + await connection.set_builtin_type_codec( + "hstore", + codec_name="pg_contrib.hstore" + ) def _uuid_encoder(value): if isinstance(value, uuid.UUID): @@ -211,7 +227,7 @@ def _uuid_encoder(value): val = b"" return val - try: + with contextlib.suppress(Exception): await connection.set_type_codec( "uuid", encoder=_uuid_encoder, @@ -219,8 +235,6 @@ def _uuid_encoder(value): schema="pg_catalog", format="binary", ) - except Exception: - pass if self._connection_config and isinstance(self._connection_config, dict): for key, value in self._connection_config.items(): config = f"SELECT set_config('{key}', '{value}', false);" @@ -467,6 +481,7 @@ def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params self._custom_record: bool = kwargs.get("custom_record", False) self._record_class_ = kwargs.get("record_class", pgRecord) self._cache_size: int = kwargs.get("cache_size", 36000) + self._enable_vector: bool = kwargs.get("enable_vector", False) # max_inactive_connection_lifetime self._max_inactive_timeout = kwargs.pop("max_inactive_timeout", 360000) DBCursorBackend.__init__(self) @@ -555,6 +570,55 @@ def is_connected(self): return not (self._connection.is_closed()) return self._connected + async def _register_codecs(self, conn, schema: str = 'pg_catalog'): + """ + Register codecs for the connection. + """ + # Setup jsonb encoder/decoder + def _encoder(value): + return self._encoder.dumps(value) # pylint: disable=E1120 + + def _decoder(value): + return self._encoder.loads(value) # pylint: disable=E1120 + + await conn.set_type_codec( + "json", encoder=_encoder, decoder=_decoder, schema=schema + ) + await conn.set_type_codec( + "jsonb", encoder=_encoder, decoder=_decoder, schema=schema + ) + + with contextlib.suppress(Exception): + await conn.set_builtin_type_codec( + "hstore", + codec_name="pg_contrib.hstore" + ) + + def _uuid_encoder(value): + if isinstance(value, uuid.UUID): + val = value.bytes + elif value is not None: + val = uuid.UUID(bytes=value) + else: + val = b"" + return val + + def _uuid_decoder(value): + return b"" if value is None else uuid.UUID(bytes=value) + + with contextlib.suppress(Exception): + await conn.set_type_codec( + "uuid", + encoder=_uuid_encoder, + decoder=_uuid_decoder, + schema=schema, + format="binary", + ) + + if self._enable_vector: + await register_vector(conn) + + async def connection(self): """connection. @@ -565,24 +629,14 @@ async def connection(self): return self self._connection = None self._connected = False - # Setup jsonb encoder/decoder - - def _encoder(value): - return self._encoder.dumps(value) # pylint: disable=E1120 - - def _decoder(value): - return self._encoder.loads(value) # pylint: disable=E1120 server_settings = { "application_name": self.application_name, "idle_session_timeout": "120min", - # "tcp_keepalives_idle": "36000", - # "max_parallel_workers": "512", + "max_parallel_workers": "512", } server_settings = {**server_settings, **self._server_settings} - _ssl = {} - if self.ssl: - _ssl = {"ssl": self.sslctx} + _ssl = {"ssl": self.sslctx} if self.ssl else {} custom_class = {} if self._custom_record: custom_class = {"record_class": self._record_class_} @@ -590,10 +644,6 @@ def _decoder(value): if self._pool and not self._connection: self._connection = await self._pool.pool().acquire() else: - # try: - # loop = self._loop or asyncio.get_running_loop() - # except RuntimeError: - # loop = asyncio.get_event_loop() self._connection = await asyncpg.connect( dsn=self._dsn, timeout=self._timeout, @@ -602,45 +652,12 @@ def _decoder(value): # connection_class=NAVConnection, max_cached_statement_lifetime=600, max_cacheable_statement_size=1024 * 30, - # loop=loop, **custom_class, **_ssl, ) - await self._connection.set_type_codec( - "json", encoder=_encoder, decoder=_decoder, schema="pg_catalog" - ) - await self._connection.set_type_codec( - "jsonb", encoder=_encoder, decoder=_decoder, schema="pg_catalog" - ) - with contextlib.suppress(Exception): - await self._connection.set_builtin_type_codec( - "hstore", - codec_name="pg_contrib.hstore" - ) - - def _uuid_encoder(value): - if isinstance(value, uuid.UUID): - val = value.bytes - elif value is not None: - val = uuid.UUID(bytes=value) - else: - val = b"" - return val - - def _uuid_decoder(value): - return b"" if value is None else uuid.UUID(bytes=value) - try: - await self._connection.set_type_codec( - "uuid", - encoder=_uuid_encoder, - decoder=_uuid_decoder, - schema="pg_catalog", - format="binary", - ) - except Exception: - pass if self._connection: + await self._register_codecs(self._connection) self._connected = True if self._connection_config and isinstance(self._connection_config, dict): for key, value in self._connection_config.items():