Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add SQL adapter #779

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ jobs:
coverage report
env:
# Provide test suite with a PostgreSQL database to use.
TILED_TEST_POSTGRESQL_URI: postgresql+asyncpg://postgres:secret@localhost:5432
TILED_TEST_POSTGRESQL_URI: postgresql://postgres:secret@localhost:5432
# Opt in to LDAPAuthenticator tests.
TILED_TEST_LDAP: 1

Expand Down
6 changes: 6 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ tiled = "tiled.commandline.main:main"

# This is the union of all optional dependencies.
all = [
"adbc_driver_manager",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This section is used when tiled is installed with `pip install "tiled[all]"`. These three should also be added to the `server` section, below, so that they are included when tiled is installed with `pip install "tiled[server]"` (server only).

"adbc_driver_sqlite",
"adbc_driver_postgresql",
"aiofiles",
"aiosqlite",
"alembic",
Expand Down Expand Up @@ -225,6 +228,9 @@ minimal-server = [
]
# This is the "kitchen sink" fully-featured server dependency set.
server = [
"adbc_driver_manager",
"adbc_driver_sqlite",
"adbc_driver_postgresql",
"aiofiles",
"aiosqlite",
"alembic",
Expand Down
126 changes: 126 additions & 0 deletions tiled/_tests/adapters/test_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import os
import tempfile

import adbc_driver_sqlite
import pyarrow as pa
import pytest

from tiled.adapters.sql import SQLAdapter
from tiled.structures.table import TableStructure

# Column names shared by every record batch used in these tests.
names = ["f0", "f1", "f2"]

# Three sets of columns (integer, string, boolean) with 5, 7 and 2 rows
# respectively; several of the columns contain nulls.
data0 = [
    pa.array([1, 2, 3, 4, 5]),
    pa.array(["foo0", "bar0", "baz0", None, "goo0"]),
    pa.array([True, None, False, True, None]),
]
data1 = [
    pa.array([6, 7, 8, 9, 10, 11, 12]),
    pa.array(["foo1", "bar1", None, "baz1", "biz", None, "goo"]),
    pa.array([None, True, True, False, False, None, True]),
]
data2 = [
    pa.array([13, 14]),
    pa.array(["foo2", "baz2"]),
    pa.array([False, None]),
]

# One record batch per column set, all using the same column names.
batch0, batch1, batch2 = (
    pa.record_batch(columns, names=names) for columns in (data0, data1, data2)
)


def test_invalid_uri() -> None:
    """Constructing a SQLAdapter from a URI without a SQL scheme must fail.

    Storage is initialized first with a plain path; building the adapter from
    the resulting asset's ``data_uri`` is expected to raise ``ValueError``.
    """
    bad_uri = "/some_random_uri/test.db"
    structure = TableStructure.from_arrow_table(
        pa.Table.from_arrays(data0, names), npartitions=1
    )
    asset = SQLAdapter.init_storage(bad_uri, structure=structure)
    expected_message = (
        "The database uri must start with either `sqlite://` or `postgresql://` "
    )
    with pytest.raises(ValueError, match=expected_message):
        SQLAdapter(asset.data_uri, structure=structure)


def test_invalid_structure() -> None:
    """Initializing storage with more than one partition must raise ValueError."""
    uri = "/some_random_uri/test.db"
    multi_partition_structure = TableStructure.from_arrow_table(
        pa.Table.from_arrays(data0, names), npartitions=3
    )
    expected_message = "The SQL adapter must have only 1 partition"
    with pytest.raises(ValueError, match=expected_message):
        SQLAdapter.init_storage(uri, structure=multi_partition_structure)


@pytest.fixture
def adapter_sql() -> SQLAdapter:
    """Provide a SQLAdapter whose URI points at ``test.db`` in the temp directory.

    The structure is derived from ``data0`` with a single partition.
    """
    db_uri = f"sqlite://file://localhost{tempfile.gettempdir()}/test.db"
    structure = TableStructure.from_arrow_table(
        pa.Table.from_arrays(data0, names), npartitions=1
    )
    storage_asset = SQLAdapter.init_storage(db_uri, structure=structure)
    return SQLAdapter(storage_asset.data_uri, structure=structure)


def test_attributes(adapter_sql: SQLAdapter) -> None:
    """Check the column names, partition count and connection type of the adapter."""
    columns = adapter_sql.structure().columns
    npartitions = adapter_sql.structure().npartitions
    connection = adapter_sql.conn

    assert columns == names
    assert npartitions == 1
    assert isinstance(connection, adbc_driver_sqlite.dbapi.AdbcSqliteConnection)


def _read_with_boolean_f2(adapter: SQLAdapter) -> pa.Table:
    """Read everything from *adapter* and return it as an Arrow table.

    The pandas dataframe gives the last column ("f2") as 0 and 1 since SQL
    does not save boolean, so that column is explicitly converted to the
    pandas "boolean" dtype before the dataframe is converted to Arrow.
    """
    frame = adapter.read()
    frame["f2"] = frame["f2"].astype("boolean")
    return pa.Table.from_pandas(frame)


def test_write_read(adapter_sql: SQLAdapter) -> None:
    """Round-trip record batches through ``write``, ``append`` and ``read``.

    The assertions require that each ``write`` leaves only the batches passed
    to it, and that each ``append`` adds its batches after the existing rows.
    """
    # Write a single batch and read it back.
    adapter_sql.write(batch0)
    assert pa.Table.from_arrays(data0, names) == _read_with_boolean_f2(adapter_sql)

    # Write a list of two batches.
    adapter_sql.write([batch0, batch1])
    assert pa.Table.from_batches([batch0, batch1]) == _read_with_boolean_f2(
        adapter_sql
    )

    # Write a list of three batches.
    adapter_sql.write([batch0, batch1, batch2])
    assert pa.Table.from_batches([batch0, batch1, batch2]) == _read_with_boolean_f2(
        adapter_sql
    )

    # Write, append twice, then read everything back in insertion order.
    adapter_sql.write([batch0, batch1, batch2])
    adapter_sql.append([batch2, batch0, batch1])
    adapter_sql.append([batch1, batch2, batch0])
    assert pa.Table.from_batches(
        [batch0, batch1, batch2, batch2, batch0, batch1, batch1, batch2, batch0]
    ) == _read_with_boolean_f2(adapter_sql)


@pytest.fixture
def postgres_uri() -> str:
    """Return the PostgreSQL URI from ``TILED_TEST_POSTGRESQL_URI``.

    Skips the requesting test when the environment variable is not set.
    """
    uri = os.getenv("TILED_TEST_POSTGRESQL_URI")
    if uri is None:
        # pytest.skip() raises, so nothing after this call executes.
        pytest.skip("TILED_TEST_POSTGRESQL_URI is not set")
    return uri


@pytest.fixture
def adapter_psql(postgres_uri: str) -> SQLAdapter:
    """Provide a SQLAdapter initialized against the URI given by ``postgres_uri``.

    The structure is derived from ``data0`` with a single partition.
    """
    structure = TableStructure.from_arrow_table(
        pa.Table.from_arrays(data0, names), npartitions=1
    )
    storage_asset = SQLAdapter.init_storage(postgres_uri, structure=structure)
    return SQLAdapter(storage_asset.data_uri, structure=structure)


def test_psql(postgres_uri: str, adapter_psql: SQLAdapter) -> None:
    """Check the column names and partition count of the PostgreSQL adapter."""
    columns = adapter_psql.structure().columns
    npartitions = adapter_psql.structure().npartitions

    assert columns == names
    assert npartitions == 1
    # TODO: also assert the type of adapter_psql.conn. A previously disabled
    # check referenced adbc_driver_postgresql.dbapi.AdbcSqliteConnection;
    # confirm the correct connection class before enabling it.
3 changes: 2 additions & 1 deletion tiled/_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ..catalog import from_uri, in_memory
from ..client.base import BaseClient
from ..server.settings import get_settings
from ..utils import ensure_specified_sql_driver
from .utils import enter_password as utils_enter_password
from .utils import temp_postgres

Expand Down Expand Up @@ -152,7 +153,7 @@ async def postgresql_with_example_data_adapter(request, tmpdir):
if uri.endswith("/"):
uri = uri[:-1]
uri_with_database_name = f"{uri}/{DATABASE_NAME}"
engine = create_async_engine(uri_with_database_name)
engine = create_async_engine(ensure_specified_sql_driver(uri_with_database_name))
try:
async with engine.connect():
pass
Expand Down
61 changes: 61 additions & 0 deletions tiled/_tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from ..utils import ensure_specified_sql_driver


def test_ensure_specified_sql_driver():
    """Check each (input URI, expected URI) pair for ensure_specified_sql_driver."""
    cases = [
        # Postgres
        # Default driver is added if missing.
        (
            "postgresql://user:password@localhost:5432/database",
            "postgresql+asyncpg://user:password@localhost:5432/database",
        ),
        # Default driver passes through if specified.
        (
            "postgresql+asyncpg://user:password@localhost:5432/database",
            "postgresql+asyncpg://user:password@localhost:5432/database",
        ),
        # Do not override user-provided.
        (
            "postgresql+custom://user:password@localhost:5432/database",
            "postgresql+custom://user:password@localhost:5432/database",
        ),
        # SQLite
        # Default driver is added if missing.
        ("sqlite:////test.db", "sqlite+aiosqlite:////test.db"),
        # Default driver passes through if specified.
        ("sqlite+aiosqlite:////test.db", "sqlite+aiosqlite:////test.db"),
        # Do not override user-provided.
        ("sqlite+custom:////test.db", "sqlite+custom:////test.db"),
        # Handle SQLite :memory: URIs
        ("sqlite+aiosqlite://:memory:", "sqlite+aiosqlite://:memory:"),
        ("sqlite://:memory:", "sqlite+aiosqlite://:memory:"),
        # Handle SQLite relative URIs
        ("sqlite+aiosqlite:///test.db", "sqlite+aiosqlite:///test.db"),
        ("sqlite:///test.db", "sqlite+aiosqlite:///test.db"),
    ]
    for uri, expected in cases:
        assert ensure_specified_sql_driver(uri) == expected
3 changes: 2 additions & 1 deletion tiled/_tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from ..client import context
from ..client.base import BaseClient
from ..utils import ensure_specified_sql_driver

if sys.version_info < (3, 9):
import importlib_resources as resources
Expand All @@ -33,7 +34,7 @@ async def temp_postgres(uri):
if uri.endswith("/"):
uri = uri[:-1]
# Create a fresh database.
engine = create_async_engine(uri)
engine = create_async_engine(ensure_specified_sql_driver(uri))
database_name = f"tiled_test_disposable_{uuid.uuid4().hex}"
async with engine.connect() as connection:
await connection.execute(
Expand Down
Loading
Loading