Skip to content

Commit

Permalink
infra: add min version testing to pr test flow (langchain-ai#24358)
Browse files Browse the repository at this point in the history
xfailing some sql tests that do not currently work on sqlalchemy v1

langchain-ai#22207 was very much not sqlalchemy v1 compatible. 

Moving forward, implementations should be compatible with both to pass
CI
  • Loading branch information
efriis authored and olgamurraft committed Aug 16, 2024
1 parent f31d79a commit 068c5f4
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 19 deletions.
7 changes: 6 additions & 1 deletion .github/scripts/get_min_versions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import sys

import tomllib
if sys.version_info >= (3, 11):
import tomllib
else:
# for python 3.10 and below, which doesnt have stdlib tomllib
import tomli as tomllib

from packaging.version import parse as parse_version
import re

Expand Down
18 changes: 18 additions & 0 deletions .github/workflows/_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,21 @@ jobs:
# grep will exit non-zero if the target message isn't found,
# and `set -e` above will cause the step to fail.
echo "$STATUS" | grep 'nothing to commit, working tree clean'
- name: Get minimum versions
working-directory: ${{ inputs.working-directory }}
id: min-version
run: |
poetry run pip install packaging tomli
min_versions="$(poetry run python $GITHUB_WORKSPACE/.github/scripts/get_min_versions.py pyproject.toml)"
echo "min-versions=$min_versions" >> "$GITHUB_OUTPUT"
echo "min-versions=$min_versions"
- name: Run unit tests with minimum dependency versions
if: ${{ steps.min-version.outputs.min-versions != '' }}
env:
MIN_VERSIONS: ${{ steps.min-version.outputs.min-versions }}
run: |
poetry run pip install --force-reinstall $MIN_VERSIONS --editable .
make tests
working-directory: ${{ inputs.working-directory }}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import (
Expand All @@ -44,6 +43,12 @@
sessionmaker,
)

try:
from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
# dummy for sqlalchemy < 2
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore

logger = logging.getLogger(__name__)


Expand Down
61 changes: 45 additions & 16 deletions libs/community/langchain_community/storage/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,46 +17,72 @@

from langchain_core.stores import BaseStore
from sqlalchemy import (
Engine,
LargeBinary,
Text,
and_,
create_engine,
delete,
select,
)
from sqlalchemy.engine.base import Engine
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import (
Mapped,
Session,
declarative_base,
mapped_column,
sessionmaker,
)

try:
from sqlalchemy.ext.asyncio import async_sessionmaker
except ImportError:
# dummy for sqlalchemy < 2
async_sessionmaker = type("async_sessionmaker", (type,), {}) # type: ignore

Base = declarative_base()

try:
from sqlalchemy.orm import mapped_column

def items_equal(x: Any, y: Any) -> bool:
return x == y
class LangchainKeyValueStores(Base): # type: ignore[valid-type,misc]
"""Table used to save values."""

# ATTENTION:
# Prior to modifying this table, please determine whether
# we should create migrations for this table to make sure
# users do not experience data loss.
__tablename__ = "langchain_key_value_stores"

namespace: Mapped[str] = mapped_column(
primary_key=True, index=True, nullable=False
)
key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
value = mapped_column(LargeBinary, index=False, nullable=False)

except ImportError:
# dummy for sqlalchemy < 2
from sqlalchemy import Column

class LangchainKeyValueStores(Base): # type: ignore[valid-type,misc,no-redef]
"""Table used to save values."""

# ATTENTION:
# Prior to modifying this table, please determine whether
# we should create migrations for this table to make sure
# users do not experience data loss.
__tablename__ = "langchain_key_value_stores"

class LangchainKeyValueStores(Base): # type: ignore[valid-type,misc]
"""Table used to save values."""
namespace = Column(Text(), primary_key=True, index=True, nullable=False)
key = Column(Text(), primary_key=True, index=True, nullable=False)
value = Column(LargeBinary, index=False, nullable=False)

# ATTENTION:
# Prior to modifying this table, please determine whether
# we should create migrations for this table to make sure
# users do not experience data loss.
__tablename__ = "langchain_key_value_stores"

namespace: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
value = mapped_column(LargeBinary, index=False, nullable=False)
def items_equal(x: Any, y: Any) -> bool:
return x == y


# This is a fix of original SQLStore.
Expand Down Expand Up @@ -135,7 +161,10 @@ def __init__(
self.namespace = namespace

def create_schema(self) -> None:
Base.metadata.create_all(self.engine)
Base.metadata.create_all(self.engine) # problem in sqlalchemy v1
# sqlalchemy.exc.CompileError: (in table 'langchain_key_value_stores',
# column 'namespace'): Can't generate DDL for NullType(); did you forget
# to specify a type on this Column?

async def acreate_schema(self) -> None:
assert isinstance(self.engine, AsyncEngine)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
continue

# Ignore JSON datatyped columns
for k, v in table.columns.items():
for k, v in table.columns.items(): # AttributeError: items in sqlalchemy v1
if type(v.type) is NullType:
table._columns.remove(v)

Expand Down
7 changes: 7 additions & 0 deletions libs/community/tests/unit_tests/storage/test_sql.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from typing import AsyncGenerator, Generator, cast

import pytest
import sqlalchemy as sa
from langchain.storage._lc_store import create_kv_docstore, create_lc_store
from langchain_core.documents import Document
from langchain_core.stores import BaseStore
from packaging import version

from langchain_community.storage.sql import SQLStore

is_sqlalchemy_v1 = version.parse(sa.__version__).major == 1


@pytest.fixture
def sql_store() -> Generator[SQLStore, None, None]:
Expand All @@ -22,6 +26,7 @@ async def async_sql_store() -> AsyncGenerator[SQLStore, None]:
yield store


@pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues")
def test_create_lc_store(sql_store: SQLStore) -> None:
"""Test that a docstore is created from a base store."""
docstore: BaseStore[str, Document] = cast(
Expand All @@ -34,6 +39,7 @@ def test_create_lc_store(sql_store: SQLStore) -> None:
assert fetched_doc.metadata == {"key": "value"}


@pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues")
def test_create_kv_store(sql_store: SQLStore) -> None:
"""Test that a docstore is created from a base store."""
docstore = create_kv_docstore(sql_store)
Expand All @@ -57,6 +63,7 @@ async def test_async_create_kv_store(async_sql_store: SQLStore) -> None:
assert fetched_doc.metadata == {"key": "value"}


@pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues")
def test_sample_sql_docstore(sql_store: SQLStore) -> None:
# Set values for keys
sql_store.mset([("key1", b"value1"), ("key2", b"value2")])
Expand Down
3 changes: 3 additions & 0 deletions libs/community/tests/unit_tests/test_sql_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def db_lazy_reflection(engine: Engine) -> SQLDatabase:
return SQLDatabase(engine, lazy_table_reflection=True)


@pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues")
def test_table_info(db: SQLDatabase) -> None:
"""Test that table info is constructed properly."""
output = db.table_info
Expand Down Expand Up @@ -85,6 +86,7 @@ def test_table_info(db: SQLDatabase) -> None:
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))


@pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues")
def test_table_info_lazy_reflection(db_lazy_reflection: SQLDatabase) -> None:
"""Test that table info with lazy reflection"""
assert len(db_lazy_reflection._metadata.sorted_tables) == 0
Expand All @@ -111,6 +113,7 @@ def test_table_info_lazy_reflection(db_lazy_reflection: SQLDatabase) -> None:
assert db_lazy_reflection._metadata.sorted_tables[1].name == "user"


@pytest.mark.xfail(is_sqlalchemy_v1, reason="SQLAlchemy 1.x issues")
def test_table_info_w_sample_rows(db: SQLDatabase) -> None:
"""Test that table info is constructed properly."""

Expand Down
9 changes: 9 additions & 0 deletions libs/community/tests/unit_tests/test_sql_database_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
insert,
schema,
)
import sqlalchemy as sa

from packaging import version

from langchain_community.utilities.sql_database import SQLDatabase

Expand All @@ -43,6 +46,9 @@
)


@pytest.mark.xfail(
version.parse(sa.__version__).major == 1, reason="SQLAlchemy 1.x issues"
)
def test_table_info() -> None:
"""Test that table info is constructed properly."""
engine = create_engine("duckdb:///:memory:")
Expand All @@ -65,6 +71,9 @@ def test_table_info() -> None:
assert sorted(" ".join(output.split())) == sorted(" ".join(expected_output.split()))


@pytest.mark.xfail(
version.parse(sa.__version__).major == 1, reason="SQLAlchemy 1.x issues"
)
def test_sql_database_run() -> None:
"""Test that commands can be run successfully and returned in correct format."""
engine = create_engine("duckdb:///:memory:")
Expand Down

0 comments on commit 068c5f4

Please sign in to comment.