diff --git a/pyproject.toml b/pyproject.toml index e402727150..9561cd2761 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,6 +50,9 @@ fastapi = "^0.68.1" requests = "^2.26.0" autoflake = "^1.4" isort = "^5.9.3" +testcontainers = "^3.7.1" +psycopg2-binary = "^2.9.7" +asyncpg = "^0.28.0" [build-system] requires = ["poetry-core"] diff --git a/sqlmodel/ext/asyncio/__init__.py b/sqlmodel/ext/asyncio/__init__.py index e69de29bb2..0af81880e6 100644 --- a/sqlmodel/ext/asyncio/__init__.py +++ b/sqlmodel/ext/asyncio/__init__.py @@ -0,0 +1,2 @@ +from .engine import create_async_engine as create_async_engine +from .session import AsyncSession as AsyncSession diff --git a/sqlmodel/ext/asyncio/engine.py b/sqlmodel/ext/asyncio/engine.py new file mode 100644 index 0000000000..92c0dff377 --- /dev/null +++ b/sqlmodel/ext/asyncio/engine.py @@ -0,0 +1,10 @@ +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncEngine +from sqlalchemy.ext.asyncio import create_async_engine as _create_async_engine + + +# create_async_engine by default already has future set to be true. +# Porting this over to sqlmodel to make it easier to use. +def create_async_engine(*args: Any, **kwargs: Any) -> AsyncEngine: + return _create_async_engine(*args, **kwargs) diff --git a/sqlmodel/ext/asyncio/session.py b/sqlmodel/ext/asyncio/session.py index 80267b25e5..79dae568a6 100644 --- a/sqlmodel/ext/asyncio/session.py +++ b/sqlmodel/ext/asyncio/session.py @@ -9,7 +9,7 @@ from ...engine.result import ScalarResult from ...orm.session import Session -from ...sql.expression import Select +from ...sql.expression import Select, SelectOfScalar _T = TypeVar("_T") @@ -42,7 +42,7 @@ def __init__( async def exec( self, - statement: Union[Select[_T], Executable[_T]], + statement: Union[Select[_T], SelectOfScalar[_T], Executable[_T]], params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, execution_options: Mapping[Any, Any] = util.EMPTY_DICT, bind_arguments: Optional[Mapping[str, Any]] = None, diff --git a/tests/test_async.py b/tests/test_async.py new file mode 100644 index 0000000000..64b7a10357 --- /dev/null +++ b/tests/test_async.py @@ -0,0 +1,52 @@ +import asyncio +from typing import Generator, Optional + +import pytest +from sqlmodel import Field, SQLModel, select +from sqlmodel.ext.asyncio import AsyncSession, create_async_engine +from testcontainers.postgres import PostgresContainer + + +# The first time this test is run, it will download the postgres image which can take +# a while. Subsequent runs will be much faster. +@pytest.fixture(scope="module") +def postgres_container_url() -> Generator[str, None, None]: + with PostgresContainer("postgres:13") as postgres: + postgres.driver = "asyncpg" + yield postgres.get_connection_url() + + +async def _test_async_create(postgres_container_url: str) -> None: + class Hero(SQLModel, table=True): + # SQLModel.metadata is a singleton and the Hero Class has already been defined. + # If I flush the metadata during this test, it will cause test_enum to fail + # because in that file, the model isn't defined within a function. For now, the + # workaround is to set extend_existing to True. In the future, test setup and + # teardown should be refactored to avoid this issue. + __table_args__ = {"extend_existing": True} + + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str + age: Optional[int] = None + + hero_create = Hero(name="Deadpond", secret_name="Dive Wilson") + + engine = create_async_engine(postgres_container_url) + async with engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.create_all) + + async with AsyncSession(engine) as session: + session.add(hero_create) + await session.commit() + await session.refresh(hero_create) + + async with AsyncSession(engine) as session: + statement = select(Hero).where(Hero.name == "Deadpond") + results = await session.exec(statement) + hero_query = results.one() + assert hero_create == hero_query + + +def test_async_create(postgres_container_url: str) -> None: + asyncio.run(_test_async_create(postgres_container_url))