diff --git a/moonstreamdb-v3/moonstreamdbv3/db.py b/moonstreamdb-v3/moonstreamdbv3/db.py index 9322b15d..92563115 100644 --- a/moonstreamdb-v3/moonstreamdbv3/db.py +++ b/moonstreamdb-v3/moonstreamdbv3/db.py @@ -163,6 +163,39 @@ def yield_db_read_only_session(self) -> Generator[Session, None, None]: session.close() +class MoonstreamCustomDBEngine(DBEngine): + def __init__(self, url: str, schema: Optional[str] = None) -> None: + super().__init__(url=url, schema=schema) + + logger.warning("Initialized custom database engine with specified URI") + + self._session_local = sessionmaker(bind=self.engine) + + self._yield_db_session_ctx = contextmanager(self.yield_db_session) + + @property + def session_local(self): + return self._session_local + + @property + def yield_db_session_ctx(self): + return self._yield_db_session_ctx + + def yield_db_session( + self, + ) -> Generator[Session, None, None]: + """ + Yields a database connection (created using environment variables). + As per FastAPI docs: + https://fastapi.tiangolo.com/tutorial/sql-databases/#create-a-dependency + """ + session = self._session_local() + try: + yield session + finally: + session.close() + + class MoonstreamDBIndexesEngine(DBEngine): def __init__(self, schema: Optional[str] = None) -> None: super().__init__(url=MOONSTREAM_DB_V3_INDEXES_URI, schema=schema) diff --git a/moonstreamdb-v3/moonstreamdbv3/version.txt b/moonstreamdb-v3/moonstreamdbv3/version.txt index c5d54ec3..7c1886bb 100644 --- a/moonstreamdb-v3/moonstreamdbv3/version.txt +++ b/moonstreamdb-v3/moonstreamdbv3/version.txt @@ -1 +1 @@ -0.0.9 +0.0.10