diff --git a/projects/pgai/tests/vectorizer/extensions/test_inheritance.py b/projects/pgai/tests/vectorizer/extensions/test_inheritance.py index 1b0bb6b77..916732583 100644 --- a/projects/pgai/tests/vectorizer/extensions/test_inheritance.py +++ b/projects/pgai/tests/vectorizer/extensions/test_inheritance.py @@ -8,7 +8,7 @@ from testcontainers.postgres import PostgresContainer # type: ignore from pgai.sqlalchemy import vectorizer_relationship -from tests.vectorizer.extensions.utils import run_vectorizer_worker +from tests.vectorizer.utils import run_vectorizer_worker class BaseModel(DeclarativeBase): diff --git a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy.py b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy.py index 0fe91dd19..ed2cca9e6 100644 --- a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy.py +++ b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy.py @@ -6,7 +6,7 @@ from testcontainers.postgres import PostgresContainer # type: ignore from pgai.sqlalchemy import vectorizer_relationship -from tests.vectorizer.extensions.utils import run_vectorizer_worker +from tests.vectorizer.utils import run_vectorizer_worker def test_sqlalchemy( diff --git a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_composite_primary.py b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_composite_primary.py index 8fe343a9f..232664a4e 100644 --- a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_composite_primary.py +++ b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_composite_primary.py @@ -7,7 +7,7 @@ from testcontainers.postgres import PostgresContainer # type: ignore from pgai.sqlalchemy import vectorizer_relationship -from tests.vectorizer.extensions.utils import run_vectorizer_worker +from tests.vectorizer.utils import run_vectorizer_worker class Base(DeclarativeBase): diff --git a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_large_embeddings.py b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_large_embeddings.py index 41d31a907..765e7dfb6 100644 --- a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_large_embeddings.py +++ b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_large_embeddings.py @@ -7,7 +7,7 @@ from testcontainers.postgres import PostgresContainer # type: ignore from pgai.sqlalchemy import vectorizer_relationship -from tests.vectorizer.extensions.utils import run_vectorizer_worker +from tests.vectorizer.utils import run_vectorizer_worker class Base(DeclarativeBase): diff --git a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_lazy_strategies.py b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_lazy_strategies.py index ad65478b8..bee2a85d8 100644 --- a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_lazy_strategies.py +++ b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_lazy_strategies.py @@ -7,7 +7,7 @@ from testcontainers.postgres import PostgresContainer # type: ignore from pgai.sqlalchemy import vectorizer_relationship -from tests.vectorizer.extensions.utils import run_vectorizer_worker +from tests.vectorizer.utils import run_vectorizer_worker class Base(DeclarativeBase): diff --git a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_relationship.py b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_relationship.py index 476354871..f58eaef60 100644 --- a/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_relationship.py +++ b/projects/pgai/tests/vectorizer/extensions/test_sqlalchemy_relationship.py @@ -7,9 +7,7 @@ from testcontainers.postgres import PostgresContainer # type: ignore from pgai.sqlalchemy import vectorizer_relationship -from tests.vectorizer.extensions.utils import ( - run_vectorizer_worker, -) +from tests.vectorizer.utils import run_vectorizer_worker class Base(DeclarativeBase): diff --git a/projects/pgai/tests/vectorizer/extensions/utils.py b/projects/pgai/tests/vectorizer/extensions/utils.py index 0596d831f..1f2ec72d3 100644 --- a/projects/pgai/tests/vectorizer/extensions/utils.py +++ b/projects/pgai/tests/vectorizer/extensions/utils.py @@ -1,29 +1,11 @@ from pathlib import Path from typing import Any -from click.testing import CliRunner from sqlalchemy import Column -from pgai.cli import vectorizer_worker from tests.vectorizer.extensions.conftest import load_template -def run_vectorizer_worker(db_url: str, vectorizer_id: int) -> None: - CliRunner().invoke( - vectorizer_worker, - [ - "--db-url", - db_url, - "--once", - "--vectorizer-id", - str(vectorizer_id), - "--concurrency", - "1", - ], - catch_exceptions=False, - ) - - def create_vectorizer_migration( migrations_dir: Path, table_name: str, diff --git a/projects/pgai/tests/vectorizer/test_vectorizer_cli.py b/projects/pgai/tests/vectorizer/test_vectorizer_cli.py index d5536bdda..b8159a794 100644 --- a/projects/pgai/tests/vectorizer/test_vectorizer_cli.py +++ b/projects/pgai/tests/vectorizer/test_vectorizer_cli.py @@ -11,14 +11,13 @@ import openai import psycopg import pytest -from click.testing import CliRunner from psycopg import Connection, sql from psycopg.rows import dict_row from testcontainers.ollama import OllamaContainer # type: ignore from testcontainers.postgres import PostgresContainer # type: ignore -from pgai.cli import vectorizer_worker from tests.vectorizer import expected +from tests.vectorizer.utils import run_vectorizer_worker count = 10000 @@ -88,31 +87,29 @@ def cli_db_url(cli_db: tuple[TestDatabase, Connection]) -> str: def test_worker_no_tasks(cli_db_url: str): """Test that worker handles no tasks gracefully""" - result = CliRunner().invoke(vectorizer_worker, ["--db-url", cli_db_url, "--once"]) + result = run_vectorizer_worker(cli_db_url) # It exits successfully assert result.exit_code == 0 assert "no vectorizers found" in result.output.lower() -@pytest.fixture -def source_table( - cli_db: tuple[TestDatabase, Connection], test_params: tuple[int, int, int, str, str] -) -> str: - _, conn = cli_db - num_items = test_params[0] +def setup_source_table( + connection: Connection, + number_of_rows: int, +): table_name = "blog" - with conn.cursor(row_factory=dict_row) as cur: + with connection.cursor(row_factory=dict_row) as cur: # Create source table cur.execute(f""" - CREATE TABLE {table_name} ( - id INT NOT NULL PRIMARY KEY, - id2 INT NOT NULL, - content TEXT NOT NULL - ) - """) + CREATE TABLE {table_name} ( + id INT NOT NULL PRIMARY KEY, + id2 INT NOT NULL, + content TEXT NOT NULL + ) + """) # Insert test data - values = [(i, i, f"post_{i}") for i in range(1, num_items + 1)] + values = [(i, i, f"post_{i}") for i in range(1, number_of_rows + 1)] cur.executemany( "INSERT INTO blog (id, id2, content) VALUES (%s, %s, %s)", values ) @@ -120,7 +117,7 @@ def source_table( @pytest.fixture(scope="session") -def ollama_connection_url(): +def ollama_url(): # If the OLLAMA_HOST environment variable is set, we assume that the user # has an Ollama container running and we don't need to start a new one. if "OLLAMA_HOST" in os.environ: @@ -137,14 +134,14 @@ def ollama_connection_url(): def configure_vectorizer( source_table: str, - cli_db: tuple[TestDatabase, Connection], - test_params: tuple[int, int, int, str, str], - embedding: str, + connection: Connection, + concurrency: int = 1, + batch_size: int = 1, + chunking: str = "chunking_character_text_splitter('content')", + formatting: str = "formatting_python_template('$chunk')", + embedding: str = "embedding_openai('text-embedding-ada-002', 1536)", ): - _, concurrency, batch_size, chunking, formatting = test_params - _, conn = cli_db - - with conn.cursor(row_factory=dict_row) as cur: + with connection.cursor(row_factory=dict_row) as cur: # Create vectorizer cur.execute(f""" SELECT ai.create_vectorizer( @@ -161,99 +158,86 @@ def configure_vectorizer( return vectorizer_id -@pytest.fixture -def configured_openai_vectorizer_id( - source_table: str, - cli_db: tuple[TestDatabase, Connection], - test_params: tuple[int, int, int, str, str], - openai_proxy_url: str | None, +def configure_openai_vectorizer( + connection: Connection, + openai_proxy_url: str | None = None, + number_of_rows: int = 1, + concurrency: int = 1, + batch_size: int = 1, + chunking: str = "chunking_character_text_splitter('content')", + formatting: str = "formatting_python_template('$chunk')", ) -> int: """Creates and configures a vectorizer for testing""" - + table_name = setup_source_table(connection, number_of_rows) base_url = ( f", base_url => '{openai_proxy_url}'" if openai_proxy_url is not None else "" ) embedding = f"embedding_openai('text-embedding-ada-002', 1536{base_url})" - return configure_vectorizer( - source_table, - cli_db, - test_params, - embedding, + table_name, + connection, + concurrency=concurrency, + batch_size=batch_size, + chunking=chunking, + formatting=formatting, + embedding=embedding, ) -@pytest.fixture -def configured_ollama_vectorizer_id( - source_table: str, - cli_db: tuple[TestDatabase, Connection], - test_params: tuple[int, int, int, str, str], - ollama_connection_url: str, +def configure_ollama_vectorizer( + connection: Connection, + ollama_url: str, + number_of_rows: int = 1, + concurrency: int = 1, + batch_size: int = 1, + chunking: str = "chunking_character_text_splitter('content')", + formatting: str = "formatting_python_template('$chunk')", ) -> int: """Creates and configures an ollama vectorizer for testing""" + + table_name = setup_source_table(connection, number_of_rows) return configure_vectorizer( - source_table, - cli_db, - test_params, - f"embedding_ollama('nomic-embed-text', 768, base_url => '{ollama_connection_url}')", # noqa: E501 Line too long + table_name, + connection, + concurrency=concurrency, + batch_size=batch_size, + chunking=chunking, + formatting=formatting, + embedding=f"embedding_ollama('nomic-embed-text'," + f"768, base_url => '{ollama_url}')", ) -@pytest.fixture -def configured_voyageai_vectorizer_id( - source_table: str, - cli_db: tuple[TestDatabase, Connection], - test_params: tuple[int, int, int, str, str], +def configure_voyageai_vectorizer_id( + connection: Connection, + number_of_rows: int = 1, + concurrency: int = 1, + batch_size: int = 1, + chunking: str = "chunking_character_text_splitter('content')", + formatting: str = "formatting_python_template('$chunk')", ) -> int: """Creates and configures a VoyageAI vectorizer for testing""" + + table_name = setup_source_table(connection, number_of_rows) return configure_vectorizer( - source_table, cli_db, test_params, "embedding_voyageai('voyage-3-lite', 512)" + table_name, + connection, + concurrency=concurrency, + batch_size=batch_size, + chunking=chunking, + formatting=formatting, + embedding="embedding_voyageai('voyage-3-lite', 512)", ) -@pytest.fixture -def test_params(request: pytest.FixtureRequest) -> tuple[int, int, int, str, str]: - """Parameters for test variations: - (num_items, concurrency, batch_size, chunking, formatting)""" - return request.param - - class TestWithOpenAiVectorizer: @pytest.mark.parametrize( - "test_params,openai_proxy_url", + "num_items,concurrency,batch_size,openai_proxy_url", [ - ( - ( - 1, - 1, - 1, - "chunking_character_text_splitter('content')", - "formatting_python_template('$chunk')", - ), - None, # No base_url is set. Use default (https://api.openai.com/v1) - ), - ( - ( - 1, - 1, - 1, - "chunking_character_text_splitter('content')", - "formatting_python_template('$chunk')", - ), - # Same test as before but with a custom base_url - 8000, - ), - ( - ( - 4, - 2, - 2, - "chunking_character_text_splitter('content')", - "formatting_python_template('$chunk')", - ), - None, # No base_url is set. Use default (https://api.openai.com/v1) - ), + (1, 1, 1, None), + (1, 1, 1, "8000"), + (4, 2, 2, None), ], indirect=["openai_proxy_url"], ) @@ -261,14 +245,21 @@ def test_process_vectorizer( self, cli_db: tuple[TestDatabase, Connection], cli_db_url: str, - configured_openai_vectorizer_id: int, vcr_: Any, - test_params: tuple[int, int, int, str, str], + num_items: int, + concurrency: int, + batch_size: int, openai_proxy_url: str | None, ): """Test successful processing of vectorizer tasks""" - num_items, concurrency, batch_size, _, _ = test_params _, conn = cli_db + vectorizer_id = configure_openai_vectorizer( + cli_db[1], + openai_proxy_url, + number_of_rows=num_items, + concurrency=concurrency, + batch_size=batch_size, + ) # Insert pre-existing embedding for first item with conn.cursor() as cur: cur.execute(""" @@ -291,19 +282,7 @@ def test_process_vectorizer( logging.getLogger("vcr").setLevel(logging.DEBUG) with vcr_.use_cassette(cassette): - result = CliRunner().invoke( - vectorizer_worker, - [ - "--db-url", - cli_db_url, - "--once", - "--vectorizer-id", - str(configured_openai_vectorizer_id), - "--concurrency", - str(concurrency), - ], - catch_exceptions=False, - ) + result = run_vectorizer_worker(cli_db_url, vectorizer_id, concurrency) assert not result.exception assert result.exit_code == 0 @@ -312,29 +291,15 @@ def test_process_vectorizer( cur.execute("SELECT count(*) as count FROM blog_embedding_store;") assert cur.fetchone()["count"] == num_items # type: ignore - @pytest.mark.parametrize( - "test_params", - [ - ( - 1, - 1, - 1, - "chunking_character_text_splitter('content')", - "formatting_python_template('$chunk')", - ), - ], - ) @pytest.mark.postgres_params(load_openai_key=False) def test_vectorizer_without_secrets_fails( self, cli_db: tuple[TestDatabase, Connection], cli_db_url: str, - configured_openai_vectorizer_id: int, vcr_: Any, - test_params: tuple[int, int, int, str, str], ): - num_items, concurrency, batch_size, _, _ = test_params _, conn = cli_db + vectorizer_id = configure_openai_vectorizer(cli_db[1]) # Insert pre-existing embedding for first item with conn.cursor() as cur: cur.execute(""" @@ -347,51 +312,31 @@ def test_vectorizer_without_secrets_fails( del os.environ["OPENAI_API_KEY"] cassette = ( - f"openai-character_text_splitter-chunk_value-" - f"items={num_items}-batch_size={batch_size}.yaml" + "openai-character_text_splitter-chunk_value-" "items=1-batch_size=1.yaml" ) logging.getLogger("vcr").setLevel(logging.DEBUG) with vcr_.use_cassette(cassette): - result = CliRunner().invoke( - vectorizer_worker, - [ - "--db-url", - cli_db_url, - "--once", - "--vectorizer-id", - str(configured_openai_vectorizer_id), - "--concurrency", - str(concurrency), - ], - catch_exceptions=False, - ) + result = run_vectorizer_worker(cli_db_url, vectorizer_id) assert result.exit_code == 1 assert "ApiKeyNotFoundError" in result.output - @pytest.mark.parametrize( - "test_params", - [ - ( - 2, - 1, - 2, - "chunking_recursive_character_text_splitter('content', 128, 10," - " separators => array[E'\n\n'])", - "formatting_python_template('$chunk')", - ) - ], - ) def test_document_exceeds_model_context_length( self, cli_db: tuple[TestDatabase, Connection], cli_db_url: str, - configured_openai_vectorizer_id: int, vcr_: Any, ): """Test handling of documents that exceed the model's token limit""" _, conn = cli_db # Given a vectorizer configuration + vectorizer_id = configure_openai_vectorizer( + cli_db[1], + number_of_rows=2, + batch_size=2, + chunking="chunking_recursive_character_text_splitter('content', 128, 10," + " separators => array[E'\n\n'])", + ) with conn.cursor(row_factory=dict_row) as cur: long_content = "AGI" * 5000 cur.execute( @@ -400,17 +345,7 @@ def test_document_exceeds_model_context_length( # When running the worker with vcr_.use_cassette("test_document_in_batch_too_long.yaml"): - result = CliRunner().invoke( - vectorizer_worker, - [ - "--db-url", - cli_db_url, - "--once", - "--vectorizer-id", - str(configured_openai_vectorizer_id), - ], - catch_exceptions=False, - ) + result = run_vectorizer_worker(cli_db_url, vectorizer_id) assert result.exit_code == 0 @@ -442,44 +377,28 @@ def test_document_exceeds_model_context_length( " model context length of 8192 tokens" ) - @pytest.mark.parametrize( - "test_params", - [ - ( - 2, - 1, - 2, - "chunking_recursive_character_text_splitter('content')", - "formatting_python_template('$chunk')", - ) - ], - ) def test_invalid_api_key_error( self, cli_db: tuple[TestDatabase, Connection], cli_db_url: str, - configured_openai_vectorizer_id: int, - test_params: Any, # noqa vcr_: Any, ): """Test that worker handles invalid API key appropriately""" _, conn = cli_db + vectorizer_id = configure_openai_vectorizer( + cli_db[1], + number_of_rows=2, + batch_size=2, + chunking="chunking_recursive_character_text_splitter('content')", + ) + os.environ["OPENAI_API_KEY"] = "invalid" # When running the worker and getting an invalid api key response with vcr_.use_cassette("test_invalid_api_key_error.yaml"): try: - CliRunner().invoke( - vectorizer_worker, - [ - "--db-url", - cli_db_url, - "--once", - "--vectorizer-id", - str(configured_openai_vectorizer_id), - ], - ) + run_vectorizer_worker(cli_db_url, vectorizer_id) except openai.AuthenticationError as e: assert e.code == 401 @@ -489,7 +408,7 @@ def test_invalid_api_key_error( records = cur.fetchall() assert len(records) == 1 error = records[0] - assert error["id"] == configured_openai_vectorizer_id + assert error["id"] == vectorizer_id assert error["message"] == "embedding provider failed" assert error["details"] == { "provider": "openai", @@ -501,27 +420,13 @@ def test_invalid_api_key_error( " 'code': 'invalid_api_key'}}", } - @pytest.mark.parametrize( - "test_params", - [ - ( - 1, - 1, - 1, - "chunking_character_text_splitter('content')", - "formatting_python_template('$chunk')", - ) - ], - ) def test_invalid_function_arguments( - self, - cli_db: tuple[TestDatabase, Connection], - cli_db_url: str, - configured_openai_vectorizer_id: int, + self, cli_db: tuple[TestDatabase, Connection], cli_db_url: str ): """Test that worker handles invalid embedding model arguments appropriately""" _, conn = cli_db + vectorizer_id = configure_openai_vectorizer(cli_db[1]) # And a vectorizer with invalid embedding dimensions with conn.cursor() as cur: cur.execute( @@ -534,21 +439,12 @@ def test_invalid_function_arguments( ) WHERE id = %s """, - (configured_openai_vectorizer_id,), + (vectorizer_id,), ) # When running the worker try: - CliRunner().invoke( - vectorizer_worker, - [ - "--db-url", - cli_db_url, - "--once", - "--vectorizer-id", - str(configured_openai_vectorizer_id), - ], - ) + run_vectorizer_worker(cli_db_url, vectorizer_id) except ValueError as e: assert str(e) == "dimensions must be 1536 for text-embedding-ada-002" @@ -558,7 +454,7 @@ def test_invalid_function_arguments( records = cur.fetchall() assert len(records) == 1 error = records[0] - assert error["id"] == configured_openai_vectorizer_id + assert error["id"] == vectorizer_id assert error["message"] == "embedding provider failed" assert error["details"] == { "provider": "openai", @@ -567,33 +463,30 @@ def test_invalid_function_arguments( @pytest.mark.parametrize( - "test_params", + "num_items,concurrency,batch_size", [ - ( - 1, - 1, - 1, - "chunking_character_text_splitter('content')", - "formatting_python_template('$chunk')", - ), - ( - 4, - 2, - 2, - "chunking_character_text_splitter('content')", - "formatting_python_template('$chunk')", - ), + (1, 1, 1), + (4, 2, 2), ], ) def test_ollama_vectorizer( cli_db: tuple[TestDatabase, Connection], cli_db_url: str, - configured_ollama_vectorizer_id: int, - test_params: tuple[int, int, int, str, str], + ollama_url: str, + num_items: int, + concurrency: int, + batch_size: int, ): """Test successful processing of vectorizer tasks""" - num_items, concurrency, _, _, _ = test_params _, conn = cli_db + + vectorizer_id = configure_ollama_vectorizer( + cli_db[1], + ollama_url, + number_of_rows=num_items, + concurrency=concurrency, + batch_size=batch_size, + ) # Insert pre-existing embedding for first item with conn.cursor() as cur: cur.execute(""" @@ -602,20 +495,7 @@ def test_ollama_vectorizer( VALUES (gen_random_uuid(), 1, 1, 'post_1', array_fill(0, ARRAY[768])::vector) """) - - result = CliRunner().invoke( - vectorizer_worker, - [ - "--db-url", - cli_db_url, - "--once", - "--vectorizer-id", - str(configured_ollama_vectorizer_id), - "--concurrency", - str(concurrency), - ], - catch_exceptions=False, - ) + result = run_vectorizer_worker(cli_db_url, vectorizer_id, concurrency) assert not result.exception assert result.exit_code == 0 @@ -626,36 +506,30 @@ def test_ollama_vectorizer( @pytest.mark.parametrize( - "test_params", + "num_items,concurrency,batch_size", [ - ( - 1, - 1, - 1, - "chunking_character_text_splitter('content')", - "formatting_python_template('$chunk')", - ), - ( - 4, - 2, - 2, - "chunking_character_text_splitter('content')", - "formatting_python_template('$chunk')", - ), + (1, 1, 1), + (4, 2, 2), ], ) def test_voyageai_vectorizer( cli_db: tuple[TestDatabase, Connection], cli_db_url: str, - configured_voyageai_vectorizer_id: int, vcr_: Any, - test_params: tuple[int, int, int, str, str], + num_items: int, + concurrency: int, + batch_size: int, ): """Test successful processing of vectorizer tasks""" if "VOYAGE_API_KEY" not in os.environ: os.environ["VOYAGE_API_KEY"] = "A FAKE KEY" - num_items, concurrency, batch_size, _, _ = test_params _, conn = cli_db + vectorizer_id = configure_voyageai_vectorizer_id( + cli_db[1], + number_of_rows=num_items, + concurrency=concurrency, + batch_size=batch_size, + ) # Insert pre-existing embedding for first item with conn.cursor() as cur: cur.execute(""" @@ -672,19 +546,7 @@ def test_voyageai_vectorizer( ) logging.getLogger("vcr").setLevel(logging.DEBUG) with vcr_.use_cassette(cassette): - result = CliRunner().invoke( - vectorizer_worker, - [ - "--db-url", - cli_db_url, - "--once", - "--vectorizer-id", - str(configured_voyageai_vectorizer_id), - "--concurrency", - str(concurrency), - ], - catch_exceptions=False, - ) + result = run_vectorizer_worker(cli_db_url, vectorizer_id, concurrency) assert not result.exception assert result.exit_code == 0 @@ -717,15 +579,8 @@ def test_voyageai_vectorizer_fails_when_api_key_is_not_set( chunking => ai.chunking_character_text_splitter('content') )""") # noqa cur.execute("INSERT INTO blog (id, content) VALUES(1, repeat('1', 100000))") - result = CliRunner().invoke( - vectorizer_worker, - [ - "--db-url", - cli_db_url, - "--once", - ], - catch_exceptions=False, - ) + + result = run_vectorizer_worker(cli_db_url) assert result.exit_code == 1 assert "ApiKeyNotFoundError" in result.output @@ -736,10 +591,7 @@ def test_vectorizer_exits_with_error_when_no_ai_extension( self, postgres_container: PostgresContainer, ): - result = CliRunner().invoke( - vectorizer_worker, - ["--db-url", postgres_container.get_connection_url(), "--once"], - ) + result = run_vectorizer_worker(postgres_container.get_connection_url()) assert result.exit_code == 1 assert "the pgai extension is not installed" in result.output.lower() @@ -747,18 +599,7 @@ def test_vectorizer_exits_with_error_when_no_ai_extension( def test_vectorizer_exits_with_error_when_vectorizers_specified_but_missing( self, cli_db_url: str ): - result = CliRunner().invoke( - vectorizer_worker, - [ - "--db-url", - cli_db_url, - "--poll-interval", - "0.1s", - "--vectorizer-id", - "0", - "--once", - ], - ) + result = run_vectorizer_worker(cli_db_url, vectorizer_id=0) assert result.exit_code != 0 assert "invalid vectorizers, wanted: [0], got: []" in result.output @@ -768,14 +609,9 @@ def test_vectorizer_does_not_exit_with_error_when_no_ai_extension( self, postgres_container: PostgresContainer, ): - result = CliRunner().invoke( - vectorizer_worker, - [ - "--db-url", - postgres_container.get_connection_url(), - "--once", - "--exit-on-error=false", - ], + result = run_vectorizer_worker( + postgres_container.get_connection_url(), + extra_params=["--exit-on-error=false"], ) assert result.exit_code == 0 @@ -784,18 +620,8 @@ def test_vectorizer_does_not_exit_with_error_when_no_ai_extension( def test_vectorizer_does_not_exit_with_error_when_vectorizers_specified_but_missing( self, cli_db_url: str ): - result = CliRunner().invoke( - vectorizer_worker, - [ - "--db-url", - cli_db_url, - "--poll-interval", - "0.1s", - "--vectorizer-id", - "0", - "--once", - "--exit-on-error=false", - ], + result = run_vectorizer_worker( + cli_db_url, vectorizer_id=0, extra_params=["--exit-on-error=false"] ) assert result.exit_code == 0 assert "invalid vectorizers, wanted: [0], got: []" in result.output @@ -867,28 +693,21 @@ def test_vectorizer_picks_up_new_vectorizer( process.terminate() -@pytest.mark.parametrize( - "test_params", - [ - ( - 2, - 1, - 2, - "chunking_recursive_character_text_splitter('content', 100, 20," - " separators => array[E'\\n\\n', E'\\n', ' '])", - "formatting_python_template('$chunk')", - ) - ], -) def test_recursive_character_splitting( cli_db: tuple[PostgresContainer, Connection], cli_db_url: str, - configured_openai_vectorizer_id: int, vcr_: Any, ): """Test that recursive character splitting correctly chunks content based on natural boundaries""" _, conn = cli_db + vectorizer_id = configure_openai_vectorizer( + cli_db[1], + number_of_rows=2, + batch_size=2, + chunking="chunking_recursive_character_text_splitter('content', 100, 20," + " separators => array[E'\\n\\n', E'\\n', ' '])", + ) # Given content with natural splitting points sample_content = """Introduction to Machine Learning @@ -913,17 +732,7 @@ def test_recursive_character_splitting( # When running the worker with vcr_.use_cassette("test_recursive_character_splitting.yaml"): - result = CliRunner().invoke( - vectorizer_worker, - [ - "--db-url", - cli_db_url, - "--once", - "--vectorizer-id", - str(configured_openai_vectorizer_id), - ], - catch_exceptions=False, - ) + result = run_vectorizer_worker(cli_db_url, vectorizer_id) assert result.exit_code == 0 @@ -969,47 +778,25 @@ def test_recursive_character_splitting( @pytest.mark.parametrize( - "test_params", + "chunking", [ - ( - 1, - 1, - 1, - "chunking_character_text_splitter('content')", - "formatting_python_template('$chunk')", - ), - ( - 1, - 1, - 1, - "chunking_recursive_character_text_splitter('content')", - "formatting_python_template('$chunk')", - ), + "chunking_character_text_splitter('content')", + "chunking_recursive_character_text_splitter('content')", ], ) def test_vectorization_successful_with_null_contents( cli_db: tuple[PostgresContainer, Connection], cli_db_url: str, - configured_ollama_vectorizer_id: int, - test_params: tuple[int, int, int, str, str], # noqa: ARG001 + ollama_url: str, + chunking: str, ): _, conn = cli_db - + vectorizer_id = configure_ollama_vectorizer(conn, ollama_url, chunking=chunking) with conn.cursor(row_factory=dict_row) as cur: cur.execute("ALTER TABLE blog ALTER COLUMN content DROP NOT NULL;") cur.execute("UPDATE blog SET content = null;") - result = CliRunner().invoke( - vectorizer_worker, - [ - "--db-url", - cli_db_url, - "--once", - "--vectorizer-id", - str(configured_ollama_vectorizer_id), - ], - catch_exceptions=False, - ) + result = run_vectorizer_worker(cli_db_url, vectorizer_id) assert not result.exception assert result.exit_code == 0 @@ -1022,22 +809,10 @@ def test_vectorization_successful_with_null_contents( @pytest.mark.parametrize( - "test_params", + "num_items, concurrency, batch_size", [ - ( - 1, - 1, - 1, - "chunking_character_text_splitter('content')", - "formatting_python_template('$chunk')", - ), - ( - 4, - 2, - 2, - "chunking_character_text_splitter('content')", - "formatting_python_template('$chunk')", - ), + (1, 1, 1), + (4, 2, 2), ], ) @pytest.mark.parametrize( @@ -1065,12 +840,13 @@ def test_vectorization_successful_with_null_contents( ], ) def test_litellm_vectorizer( - source_table: str, cli_db: tuple[TestDatabase, Connection], cli_db_url: str, embedding: tuple[str, int, dict[str, Any], str], + num_items: int, + concurrency: int, + batch_size: int, vcr_: Any, - test_params: tuple[int, int, int, str, str], ): model, dimensions, extra_options, api_key_name = embedding function = "embedding_litellm" @@ -1083,11 +859,15 @@ def test_litellm_vectorizer( embedding_str = f"{function}('{model}', {dimensions}, extra_options => '{json.dumps(extra_options)}'::jsonb)" # noqa: E501 Line too long + source_table = setup_source_table(cli_db[1], num_items) vectorizer_id = configure_vectorizer( - source_table, cli_db, test_params, embedding_str + source_table, + connection=cli_db[1], + concurrency=concurrency, + batch_size=batch_size, + embedding=embedding_str, ) - num_items, concurrency, batch_size, _, _ = test_params _, conn = cli_db # Insert pre-existing embedding for first item with conn.cursor() as cur: @@ -1106,24 +886,7 @@ def test_litellm_vectorizer( cassette = f"{function}_{stripped_model}_{dimensions}_items_{num_items}_batch_size_{batch_size}.yaml" # noqa: E501 Line too long logging.getLogger("vcr").setLevel(logging.DEBUG) with vcr_.use_cassette(cassette): - result = CliRunner().invoke( - vectorizer_worker, - [ - "--db-url", - cli_db_url, - "--once", - "--vectorizer-id", - str(vectorizer_id), - "--concurrency", - str(concurrency), - "--log-level", - "debug", - ], - catch_exceptions=False, - ) - - if result.exception: - print(result.output) + result = run_vectorizer_worker(cli_db_url, vectorizer_id, concurrency) assert not result.exception assert result.exit_code == 0 diff --git a/projects/pgai/tests/vectorizer/utils.py b/projects/pgai/tests/vectorizer/utils.py new file mode 100644 index 000000000..e6ae8e1c1 --- /dev/null +++ b/projects/pgai/tests/vectorizer/utils.py @@ -0,0 +1,28 @@ +from click.testing import CliRunner, Result + +from pgai.cli import vectorizer_worker + + +def run_vectorizer_worker( + db_url: str, + vectorizer_id: int | None = None, + concurrency: int = 1, + extra_params: list[str] | None = None, +) -> Result: + args = [ + "--db-url", + db_url, + "--once", + "--concurrency", + str(concurrency), + ] + if vectorizer_id is not None: + args.extend(["--vectorizer-id", str(vectorizer_id)]) + if extra_params: + args.extend(extra_params) + + return CliRunner().invoke( + vectorizer_worker, + args, + catch_exceptions=False, + ) diff --git a/projects/pgai/uv.lock b/projects/pgai/uv.lock index 8fbb29064..36e8aecda 100644 --- a/projects/pgai/uv.lock +++ b/projects/pgai/uv.lock @@ -2053,7 +2053,6 @@ wheels = [ [[package]] name = "pgai" -version = "0.5.0" source = { editable = "." } dependencies = [ { name = "boto3" },