Skip to content

Commit

Permalink
Merge pull request #21 from LuckySting/declare-clause-compilation
Browse files Browse the repository at this point in the history
Thanks for the work :)
  • Loading branch information
rekby authored Dec 15, 2023
2 parents cf9a436 + 00c618c commit d413ef3
Show file tree
Hide file tree
Showing 9 changed files with 419 additions and 274 deletions.
2 changes: 1 addition & 1 deletion docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
version: "3.3"
services:
ydb:
image: cr.yandex/yc/yandex-docker-local-ydb:latest
image: cr.yandex/yc/yandex-docker-local-ydb:trunk
restart: always
ports:
- "2136:2136"
Expand Down
12 changes: 9 additions & 3 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ def test_sa_text(self, connection):
assert rs.fetchone() == (1,)

rs = connection.execute(
sa.text("SELECT x, y FROM AS_TABLE(:data)"), [{"data": [{"x": 2, "y": 1}, {"x": 3, "y": 2}]}]
sa.text(
"""
DECLARE :data AS List<Struct<x:Int64, y:Int64>>;
SELECT x, y FROM AS_TABLE(:data)
"""
),
[{"data": [{"x": 2, "y": 1}, {"x": 3, "y": 2}]}],
)
assert set(rs.fetchall()) == {(2, 1), (3, 2)}

Expand Down Expand Up @@ -184,7 +190,7 @@ def test_select_types(self, connection):
id=1,
# bin=b"abc",
str="Hello World!",
num=3.1415,
num=3.5,
bl=True,
ts=now,
date=today,
Expand All @@ -193,4 +199,4 @@ def test_select_types(self, connection):
connection.execute(stm)

row = connection.execute(sa.select(tb)).fetchone()
assert row == (1, "Hello World!", 3.1415, True, now, today)
assert row == (1, "Hello World!", 3.5, True, now, today)
99 changes: 98 additions & 1 deletion test/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlalchemy.testing.suite import * # noqa: F401, F403

from sqlalchemy.testing import is_true, is_false
from sqlalchemy.testing.suite import eq_, testing, inspect, provide_metadata, config, requirements
from sqlalchemy.testing.suite import eq_, testing, inspect, provide_metadata, config, requirements, fixtures
from sqlalchemy.testing.suite import func, column, literal_column, select, exists
from sqlalchemy.testing.suite import MetaData, Column, Table, Integer, String

Expand All @@ -32,6 +32,10 @@
NativeUUIDTest as _NativeUUIDTest,
TimeMicrosecondsTest as _TimeMicrosecondsTest,
DateTimeCoercedToDateTimeTest as _DateTimeCoercedToDateTimeTest,
DateTest as _DateTest,
DateTimeMicrosecondsTest as _DateTimeMicrosecondsTest,
DateTimeTest as _DateTimeTest,
TimestampMicrosecondsTest as _TimestampMicrosecondsTest,
)
from sqlalchemy.testing.suite.test_dialect import (
EscapingTest as _EscapingTest,
Expand All @@ -47,6 +51,7 @@
from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest
from sqlalchemy.testing.suite.test_deprecations import DeprecatedCompoundSelectTest as _DeprecatedCompoundSelectTest

from ydb_sqlalchemy.sqlalchemy import types as ydb_sa_types

test_types_suite = sqlalchemy.testing.suite.test_types
col_creator = test_types_suite.Column
Expand Down Expand Up @@ -259,6 +264,13 @@ def test_truediv_numeric(self):
# SqlAlchemy maybe eat Decimal and throw Double
pass

@testing.combinations(("6.25", "2.5", 2.5), argnames="left, right, expected")
def test_truediv_float(self, connection, left, right, expected):
eq_(
connection.scalar(select(literal_column(left, type_=sa.Float()) / literal_column(right, type_=sa.Float()))),
expected,
)


class ExistsTest(_ExistsTest):
"""
Expand Down Expand Up @@ -402,6 +414,26 @@ def test_insert_from_select_autoinc(self, connection):
def test_insert_from_select_autoinc_no_rows(self, connection):
pass

@pytest.mark.skip("implicit PK values unsupported")
def test_no_results_for_non_returning_insert(self, connection):
pass


class DateTest(_DateTest):
run_dispose_bind = "once"


class DateTimeMicrosecondsTest(_DateTimeMicrosecondsTest):
run_dispose_bind = "once"


class DateTimeTest(_DateTimeTest):
run_dispose_bind = "once"


class TimestampMicrosecondsTest(_TimestampMicrosecondsTest):
run_dispose_bind = "once"


@pytest.mark.skip("unsupported Time data type")
class TimeTest(_TimeTest):
Expand All @@ -418,6 +450,71 @@ def test_nolength_string(self):
foo.drop(config.db)


class ContainerTypesTest(fixtures.TablesTest):
@classmethod
def define_tables(cls, metadata):
Table(
"container_types_test",
metadata,
Column("id", Integer),
sa.PrimaryKeyConstraint("id"),
schema=None,
test_needs_fk=True,
)

def test_ARRAY_bind_variable(self, connection):
table = self.tables.container_types_test

connection.execute(sa.insert(table).values([{"id": 1}, {"id": 2}, {"id": 3}]))

stmt = select(table.c.id).where(table.c.id.in_(sa.bindparam("id", type_=sa.ARRAY(sa.Integer))))

eq_(connection.execute(stmt, {"id": [1, 2]}).fetchall(), [(1,), (2,)])

def test_list_type_bind_variable(self, connection):
table = self.tables.container_types_test

connection.execute(sa.insert(table).values([{"id": 1}, {"id": 2}, {"id": 3}]))

stmt = select(table.c.id).where(table.c.id.in_(sa.bindparam("id", type_=ydb_sa_types.ListType(sa.Integer))))

eq_(connection.execute(stmt, {"id": [1, 2]}).fetchall(), [(1,), (2,)])

def test_struct_type_bind_variable(self, connection):
table = self.tables.container_types_test

connection.execute(sa.insert(table).values([{"id": 1}, {"id": 2}, {"id": 3}]))

stmt = select(table.c.id).where(
table.c.id
== sa.text(":struct.id").bindparams(
sa.bindparam("struct", type_=ydb_sa_types.StructType({"id": sa.Integer})),
)
)

eq_(connection.scalar(stmt, {"struct": {"id": 1}}), 1)

def test_from_as_table(self, connection):
table = self.tables.container_types_test

connection.execute(
sa.insert(table).from_select(
["id"],
sa.select(sa.column("id")).select_from(
sa.func.as_table(
sa.bindparam(
"data",
value=[{"id": 1}, {"id": 2}, {"id": 3}],
type_=ydb_sa_types.ListType(ydb_sa_types.StructType({"id": sa.Integer})),
)
)
),
)
)

eq_(connection.execute(sa.select(table)).fetchall(), [(1,), (2,), (3,)])


@pytest.mark.skip("uuid unsupported for columns")
class NativeUUIDTest(_NativeUUIDTest):
pass
Expand Down
104 changes: 31 additions & 73 deletions test_dbapi/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,101 +12,59 @@ def test_connection(connection):

cur = connection.cursor()
with suppress(dbapi.DatabaseError):
cur.execute("DROP TABLE foo", context={"isddl": True})
cur.execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True))

assert not connection.check_exists("/local/foo")
with pytest.raises(dbapi.ProgrammingError):
connection.describe("/local/foo")

cur.execute("CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))", context={"isddl": True})
cur.execute(dbapi.YdbQuery("CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))", is_ddl=True))

assert connection.check_exists("/local/foo")

col = connection.describe("/local/foo").columns[0]
assert col.name == "id"
assert col.type == ydb.PrimitiveType.Int64

cur.execute("DROP TABLE foo", context={"isddl": True})
cur.execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True))
cur.close()


def test_cursor(connection):
def test_cursor_raw_query(connection):
cur = connection.cursor()
assert cur

with suppress(dbapi.DatabaseError):
cur.execute("DROP TABLE test", context={"isddl": True})
cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True))

cur.execute(
"CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))",
context={"isddl": True},
)

cur.execute('INSERT INTO test(id, text) VALUES (1, "foo")')

cur.execute("SELECT id, text FROM test")
assert cur.fetchone() == (1, "foo"), "fetchone is ok"

cur.execute("SELECT id, text FROM test WHERE id = %(id)s", {"id": 1})
assert cur.fetchone() == (1, "foo"), "parametrized query is ok"
cur.execute(dbapi.YdbQuery("CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))", is_ddl=True))

cur.execute(
"INSERT INTO test(id, text) VALUES (%(id1)s, %(text1)s), (%(id2)s, %(text2)s)",
{"id1": 2, "text1": "", "id2": 3, "text2": "bar"},
)

cur.execute("UPDATE test SET text = %(t)s WHERE id = %(id)s", {"id": 2, "t": "foo2"})

cur.execute("SELECT id FROM test")
assert set(cur.fetchall()) == {(1,), (2,), (3,)}, "fetchall is ok"

cur.execute("SELECT id FROM test ORDER BY id DESC")
assert cur.fetchmany(2) == [(3,), (2,)], "fetchmany is ok"
assert cur.fetchmany(1) == [(1,)]

cur.execute("SELECT id FROM test ORDER BY id LIMIT 2")
assert cur.fetchall() == [(1,), (2,)], "limit clause without params is ok"

# TODO: Failed to convert type: Int64 to Uint64
# cur.execute("SELECT id FROM test ORDER BY id LIMIT %(limit)s", {"limit": 2})
# assert cur.fetchall() == [(1,), (2,)], "limit clause with params is ok"

cur2 = connection.cursor()
cur2.execute("INSERT INTO test(id) VALUES (%(id1)s), (%(id2)s)", {"id1": 5, "id2": 6})

cur.execute("SELECT id FROM test ORDER BY id")
assert cur.fetchall() == [(1,), (2,), (3,), (5,), (6,)], "cursor2 commit changes"

cur.execute("SELECT text FROM test WHERE id > %(min_id)s", {"min_id": 3})
assert cur.fetchall() == [(None,), (None,)], "NULL returns as None"

cur.execute("SELECT id, text FROM test WHERE text LIKE %(p)s", {"p": "foo%"})
assert set(cur.fetchall()) == {(1, "foo"), (2, "foo2")}, "like clause works"

cur.execute(
# DECLARE statement (DECLARE $data AS List<Struct<id:Int64,text:Utf8>>)
# will generate automatically
"""INSERT INTO test SELECT id, text FROM AS_TABLE($data);""",
dbapi.YdbQuery(
"""
DECLARE $data AS List<Struct<id:Int64, text: Utf8>>;
INSERT INTO test SELECT id, text FROM AS_TABLE($data);
""",
parameters_types={
"$data": ydb.ListType(
ydb.StructType()
.add_member("id", ydb.PrimitiveType.Int64)
.add_member("text", ydb.PrimitiveType.Utf8)
)
},
),
{
"data": [
"$data": [
{"id": 17, "text": "seventeen"},
{"id": 21, "text": "twenty one"},
]
},
)

cur.execute("SELECT id FROM test ORDER BY id")
assert cur.rowcount == 7, "rowcount ok"
assert cur.fetchall() == [(1,), (2,), (3,), (5,), (6,), (17,), (21,)], "ok"

cur.execute("INSERT INTO test(id) VALUES (37)")
cur.execute("SELECT id FROM test WHERE text IS %(p)s", {"p": None})
assert not cur.fetchone() == (37,)

cur.execute("DROP TABLE test", context={"isddl": True})
cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True))

cur.close()
cur2.close()


def test_errors(connection):
Expand All @@ -116,25 +74,25 @@ def test_errors(connection):
cur = connection.cursor()

with suppress(dbapi.DatabaseError):
cur.execute("DROP TABLE test", context={"isddl": True})
cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True))

with pytest.raises(dbapi.DataError):
cur.execute("SELECT 18446744073709551616")
cur.execute(dbapi.YdbQuery("SELECT 18446744073709551616"))

with pytest.raises(dbapi.DataError):
cur.execute("SELECT * FROM 拉屎")
cur.execute(dbapi.YdbQuery("SELECT * FROM 拉屎"))

with pytest.raises(dbapi.DataError):
cur.execute("SELECT floor(5 / 2)")
cur.execute(dbapi.YdbQuery("SELECT floor(5 / 2)"))

with pytest.raises(dbapi.ProgrammingError):
cur.execute("SELECT * FROM test")
cur.execute(dbapi.YdbQuery("SELECT * FROM test"))

cur.execute("CREATE TABLE test(id Int64, PRIMARY KEY (id))", context={"isddl": True})
cur.execute(dbapi.YdbQuery("CREATE TABLE test(id Int64, PRIMARY KEY (id))", is_ddl=True))

cur.execute("INSERT INTO test(id) VALUES(1)")
cur.execute(dbapi.YdbQuery("INSERT INTO test(id) VALUES(1)"))
with pytest.raises(dbapi.IntegrityError):
cur.execute("INSERT INTO test(id) VALUES(1)")
cur.execute(dbapi.YdbQuery("INSERT INTO test(id) VALUES(1)"))

cur.execute("DROP TABLE test", context={"isddl": True})
cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True))
cur.close()
1 change: 1 addition & 0 deletions ydb_sqlalchemy/dbapi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .connection import Connection
from .cursor import Cursor, YdbQuery # noqa: F401
from .errors import (
Warning,
Error,
Expand Down
Loading

0 comments on commit d413ef3

Please sign in to comment.