diff --git a/python-sdk/src/astro/databases/base.py b/python-sdk/src/astro/databases/base.py
index a9d9cf8f6..dfb200a53 100644
--- a/python-sdk/src/astro/databases/base.py
+++ b/python-sdk/src/astro/databases/base.py
@@ -3,19 +3,13 @@
 import logging
 import warnings
 from abc import ABC
-from typing import TYPE_CHECKING, Any, Callable, Mapping
+from typing import Any, Callable, Mapping
 
 import pandas as pd
 import sqlalchemy
 from airflow.hooks.dbapi import DbApiHook
 from pandas.io.sql import SQLDatabase
 from sqlalchemy import column, insert, select
-
-from astro.dataframes.pandas import PandasDataframe
-
-if TYPE_CHECKING:  # pragma: no cover
-    from sqlalchemy.engine.cursor import CursorResult
-    from sqlalchemy.sql import ClauseElement
 from sqlalchemy.sql.elements import ColumnClause
 from sqlalchemy.sql.schema import Table as SqlaTable
 
@@ -29,6 +23,7 @@
     LoadExistStrategy,
     MergeConflictStrategy,
 )
+from astro.dataframes.pandas import PandasDataframe
 from astro.exceptions import DatabaseCustomError, NonExistentTableException
 from astro.files import File, resolve_file_path_pattern
 from astro.files.types import create_file_type
@@ -63,8 +58,6 @@ class BaseDatabase(ABC):
     # illegal_column_name_chars[0] will be replaced by value in illegal_column_name_chars_replacement[0]
     illegal_column_name_chars: list[str] = []
     illegal_column_name_chars_replacement: list[str] = []
-    # In run_raw_sql operator decides if we want to return results directly or process them by handler provided
-    IGNORE_HANDLER_IN_RUN_RAW_SQL: bool = False
     NATIVE_PATHS: dict[Any, Any] = {}
     DEFAULT_SCHEMA = SCHEMA
     NATIVE_LOAD_EXCEPTIONS: Any = DatabaseCustomError
@@ -107,8 +100,9 @@ def run_sql(
         self,
         sql: str | ClauseElement = "",
         parameters: dict | None = None,
+        handler: Callable | None = None,
         **kwargs,
-    ) -> CursorResult:
+    ) -> Any:
         """
         Return the results to running a SQL statement.
 
@@ -118,6 +112,7 @@
         :param sql: Contains SQL query to be run against database
         :param parameters: Optional parameters to be used to render the query
        :param autocommit: Optional autocommit flag
+        :param handler: Function that takes in a cursor as an argument.
         """
         if parameters is None:
             parameters = {}
@@ -139,7 +134,9 @@
             )
         else:
             result = self.connection.execute(sql, parameters)
-        return result
+        if handler:
+            return handler(result)
+        return None
 
     def columns_exist(self, table: BaseTable, columns: list[str]) -> bool:
         """
@@ -407,7 +404,7 @@ def create_schema_and_table_if_needed(
             use_native_support=use_native_support,
         )
 
-    def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> list:
+    def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> Any:
         """
         Fetches all rows for a table and returns as a list.
         This is needed because some databases have different cursors that require different methods to fetch rows
@@ -419,8 +416,8 @@ def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> list:
         statement = f"SELECT * FROM {self.get_table_qualified_name(table)}"
         if row_limit > -1:
             statement = statement + f" LIMIT {row_limit}"
-        response = self.run_sql(statement)
-        return response.fetchall()  # type: ignore
+        response: list = self.run_sql(statement, handler=lambda x: x.fetchall())
+        return response
 
     def load_file_to_table(
         self,
@@ -777,8 +774,9 @@ def row_count(self, table: BaseTable):
         :return: The number of rows in the table
         """
         result = self.run_sql(
-            f"select count(*) from {self.get_table_qualified_name(table)}"  # skipcq: BAN-B608
-        ).scalar()
+            f"select count(*) from {self.get_table_qualified_name(table)}",  # skipcq: BAN-B608
+            handler=lambda x: x.scalar(),
+        )
         return result
 
     def parameterize_variable(self, variable: str):
diff --git a/python-sdk/src/astro/databases/databricks/delta.py b/python-sdk/src/astro/databases/databricks/delta.py
index 58aad79d9..c2bfa173d 100644
--- a/python-sdk/src/astro/databases/databricks/delta.py
+++ b/python-sdk/src/astro/databases/databricks/delta.py
@@ -4,6 +4,7 @@
 import uuid
 import warnings
 from textwrap import dedent
+from typing import Any, Callable
 
 import pandas as pd
 from airflow.providers.databricks.hooks.databricks import DatabricksHook
@@ -25,9 +26,6 @@ class DeltaDatabase(BaseDatabase):
     LOAD_OPTIONS_CLASS_NAME = "DeltaLoadOptions"
 
-    # In run_raw_sql operator decides if we want to return results directly or process them by handler provided
-    # For delta tables we ignore the handler
-    IGNORE_HANDLER_IN_RUN_RAW_SQL: bool = True
     _create_table_statement: str = "CREATE TABLE IF NOT EXISTS {} USING DELTA AS {} "
 
     def __init__(self, conn_id: str, table: BaseTable | None = None, load_options: LoadOptions | None = None):
@@ -197,9 +195,9 @@ def run_sql(
         self,
         sql: str | ClauseElement = "",
         parameters: dict | None = None,
-        handler=None,
+        handler: Callable | None = None,
         **kwargs,
-    ):
+    ) -> Any:
         """
         Run SQL against a delta table using spark SQL.
 
diff --git a/python-sdk/src/astro/databases/mssql.py b/python-sdk/src/astro/databases/mssql.py
index 4180c5d01..cc8bd73b4 100644
--- a/python-sdk/src/astro/databases/mssql.py
+++ b/python-sdk/src/astro/databases/mssql.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import warnings
-from typing import TYPE_CHECKING
+from typing import Any, Callable
 
 import pandas as pd
 import sqlalchemy
@@ -17,8 +17,6 @@
 from astro.utils.compat.functools import cached_property
 
 DEFAULT_CONN_ID = MsSqlHook.default_conn_name
-if TYPE_CHECKING:  # pragma: no cover
-    from sqlalchemy.engine.cursor import CursorResult
 
 
 class MssqlDatabase(BaseDatabase):
@@ -145,8 +143,9 @@ def run_sql(
         self,
         sql: str | ClauseElement = "",
         parameters: dict | None = None,
+        handler: Callable | None = None,
         **kwargs,
-    ) -> CursorResult:
+    ) -> Any:
         """
         Return the results to running a SQL statement.
         Whenever possible, this method should be implemented using Airflow Hooks,
@@ -154,6 +153,7 @@
 
         :param sql: Contains SQL query to be run against database
         :param parameters: Optional parameters to be used to render the query
+        :param handler: Function that takes in a cursor as an argument.
""" if parameters is None: parameters = {} @@ -177,11 +177,12 @@ def run_sql( result = self.connection.execute( sqlalchemy.text(sql).execution_options(autocommit=autocommit), parameters ) - return result else: # this is used for append result = self.connection.execute(sql, parameters) - return result + if handler: + return handler(result) + return None def create_schema_if_needed(self, schema: str | None) -> None: """ @@ -226,7 +227,7 @@ def drop_table(self, table: BaseTable) -> None: statement = self._drop_table_statement.format(self.get_table_qualified_name(table)) self.run_sql(statement, autocommit=True) - def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> list: + def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> Any: """ Fetches all rows for a table and returns as a list. This is needed because some databases have different cursors that require different methods to fetch rows @@ -238,8 +239,8 @@ def fetch_all_rows(self, table: BaseTable, row_limit: int = -1) -> list: statement = f"SELECT * FROM {self.get_table_qualified_name(table)}" # skipcq: BAN-B608 if row_limit > -1: statement = f"SELECT TOP {row_limit} * FROM {self.get_table_qualified_name(table)}" - response = self.run_sql(statement) - return response.fetchall() # type: ignore + response: list = self.run_sql(statement, handler=lambda x: x.fetchall()) + return response def load_pandas_dataframe_to_table( self, diff --git a/python-sdk/src/astro/sql/operators/raw_sql.py b/python-sdk/src/astro/sql/operators/raw_sql.py index 1ee535a8a..c8ff0b884 100644 --- a/python-sdk/src/astro/sql/operators/raw_sql.py +++ b/python-sdk/src/astro/sql/operators/raw_sql.py @@ -51,6 +51,7 @@ def __init__( def execute(self, context: Context) -> Any: super().execute(context) + self.handler = self.get_handler() result = self.database_impl.run_sql(sql=self.sql, parameters=self.parameters, handler=self.handler) if self.response_size == -1 and not settings.IS_CUSTOM_XCOM_BACKEND: logging.warning( @@ -60,22 +61,8 @@ def execute(self, context: Context) -> Any: "backend." ) - # ToDo: Currently, the handler param in run_sql() method is only used in databricks all other databases are - # not using it. Which leads to different response types since handler is processed within `run_sql()` for - # databricks and not for other databases. Also the signature of `run_sql()` in databricks deviates from base. - # We need to standardise and when we do, we can remove below check as well. 
-        if self.database_impl.IGNORE_HANDLER_IN_RUN_RAW_SQL:
-            return result
-
-        self.handler = self.get_handler()
-        if self.handler:
-            self.handler = self.get_wrapped_handler(
-                fail_on_empty=self.fail_on_empty, conversion_func=self.handler
-            )
-        # otherwise, call the handler and convert the result to a list
-        response = self.handler(result)
-        response = self.make_row_serializable(response)
+        response = self.make_row_serializable(result)
         if 0 <= self.response_limit < len(response):
             raise IllegalLoadToDatabaseException()  # pragma: no cover
         if self.response_size >= 0:
diff --git a/python-sdk/tests/sql/operators/test_run_raw_sql.py b/python-sdk/tests/sql/operators/test_run_raw_sql.py
index cecc44722..aff18f3b6 100644
--- a/python-sdk/tests/sql/operators/test_run_raw_sql.py
+++ b/python-sdk/tests/sql/operators/test_run_raw_sql.py
@@ -23,10 +23,10 @@ def test_make_row_serializable(rows):
 
 @mock.patch("astro.sql.operators.raw_sql.RawSQLOperator.results_as_list")
-@mock.patch("astro.databases.base.BaseDatabase.run_sql")
+@mock.patch("astro.databases.base.BaseDatabase.connection")
 def test_run_sql_calls_list_handler(run_sql, results_as_list, sample_dag):
     results_as_list.return_value = []
-    run_sql.return_value = []
+    run_sql.execute.return_value = []
 
     with sample_dag:
 
         @aql.run_raw_sql(results_format="list", conn_id="sqlite_default")
@@ -40,10 +40,10 @@ def dummy_method():
 
 @mock.patch("astro.sql.operators.raw_sql.RawSQLOperator.results_as_pandas_dataframe")
-@mock.patch("astro.databases.base.BaseDatabase.run_sql")
+@mock.patch("astro.databases.base.BaseDatabase.connection")
 def test_run_sql_calls_pandas_dataframe_handler(run_sql, results_as_pandas_dataframe, sample_dag):
     results_as_pandas_dataframe.return_value = []
-    run_sql.return_value = []
+    run_sql.execute.return_value = []
 
     with sample_dag:
 
         @aql.run_raw_sql(results_format="pandas_dataframe", conn_id="sqlite_default")
@@ -57,13 +57,13 @@ def dummy_method():
 
 @mock.patch("astro.sql.operators.raw_sql.RawSQLOperator.results_as_pandas_dataframe")
-@mock.patch("astro.databases.base.BaseDatabase.run_sql")
+@mock.patch("astro.databases.base.BaseDatabase.connection")
 def test_run_sql_gives_priority_to_pandas_dataframe_handler(run_sql, results_as_pandas_dataframe, sample_dag):
     """
     Test that run_sql calls `results_format` specified handler over handler passed in decorator.
""" results_as_pandas_dataframe.return_value = [] - run_sql.return_value = [] + run_sql.execute.return_value = [] with sample_dag: @aql.run_raw_sql( @@ -103,14 +103,11 @@ def dummy_method(): @mock.patch("astro.sql.operators.raw_sql.RawSQLOperator.results_as_pandas_dataframe") -@mock.patch("astro.databases.base.BaseDatabase.run_sql") -def test_run_sql_should_raise_exception(run_sql, results_as_pandas_dataframe, sample_dag): +def test_run_sql_should_raise_exception(results_as_pandas_dataframe, sample_dag): """ Test that run_sql should raise an exception when fail_on_empty=False """ results_as_pandas_dataframe.return_value = [] - return_value = [1, 2, 3] - run_sql.return_value = return_value def raise_exception(result): raise ValueError("dummy exception") @@ -156,7 +153,7 @@ def test_handlers(): class Val: def __init__(self, val): - self.value = [val] + self.value: list = [val] def values(self) -> list: return self.value diff --git a/python-sdk/tests_integration/databases/test_bigquery.py b/python-sdk/tests_integration/databases/test_bigquery.py index 645102023..98e64b064 100644 --- a/python-sdk/tests_integration/databases/test_bigquery.py +++ b/python-sdk/tests_integration/databases/test_bigquery.py @@ -38,8 +38,8 @@ def test_bigquery_run_sql(): """Test run_sql against bigquery database""" statement = "SELECT 1 + 1;" database = BigqueryDatabase(conn_id=DEFAULT_CONN_ID) - response = database.run_sql(statement) - assert response.first()[0] == 2 + response = database.run_sql(statement, handler=lambda x: x.first()) + assert response[0] == 2 @pytest.mark.integration @@ -77,12 +77,12 @@ def test_bigquery_create_table_with_columns(database_table_fixture): f"SELECT TABLE_CATALOG, TABLE_SCHEMA, TABLE_NAME, COLUMN_NAME, DATA_TYPE " f"FROM {table.metadata.schema}.INFORMATION_SCHEMA.COLUMNS WHERE table_name='{table.name}'" ) - response = database.run_sql(statement) - assert response.first() is None + response = database.run_sql(statement, handler=lambda x: x.first()) + assert response is None database.create_table(table) - response = database.run_sql(statement) - rows = response.fetchall() + response = database.run_sql(statement, handler=lambda x: x.fetchall()) + rows = response assert len(rows) == 2 assert rows[0] == ( "astronomer-dag-authoring", @@ -121,9 +121,9 @@ def test_load_pandas_dataframe_to_table(database_table_fixture): database.load_pandas_dataframe_to_table(pandas_dataframe, table) statement = f"SELECT * FROM {database.get_table_qualified_name(table)};" - response = database.run_sql(statement) + response = database.run_sql(statement, handler=lambda x: x.fetchall()) - rows = response.fetchall() + rows = response assert len(rows) == 2 assert rows[0] == (1,) assert rows[1] == (2,) diff --git a/python-sdk/tests_integration/databases/test_mssql.py b/python-sdk/tests_integration/databases/test_mssql.py index e0ecd6cac..85a0a6bcb 100644 --- a/python-sdk/tests_integration/databases/test_mssql.py +++ b/python-sdk/tests_integration/databases/test_mssql.py @@ -53,8 +53,8 @@ def test_mssql_run_sql(): """Test run_sql against mssql database""" statement = "SELECT 1 + 1;" database = MssqlDatabase(conn_id=CUSTOM_CONN_ID) - response = database.run_sql(statement) - assert response.first()[0] == 2 + response = database.run_sql(statement, handler=lambda x: x.first()) + assert response[0] == 2 @pytest.mark.integration @@ -88,12 +88,12 @@ def test_mssql_create_table_with_columns(database_table_fixture): database, table = database_table_fixture statement = f"SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE 
-    response = database.run_sql(statement)
-    assert response.first() is None
+    response = database.run_sql(statement, handler=lambda x: x.first())
+    assert response is None
 
     database.create_table(table)
-    response = database.run_sql(statement)
-    rows = response.fetchall()
+    response = database.run_sql(statement, handler=lambda x: x.fetchall())
+    rows = response
     assert len(rows) == 2
     assert rows[0][0:4] == (
         "astrodb",
@@ -126,9 +126,9 @@ def test_load_pandas_dataframe_to_table(database_table_fixture):
     database.load_pandas_dataframe_to_table(pandas_dataframe, table)
 
     statement = f"SELECT * FROM {database.get_table_qualified_name(table)};"
-    response = database.run_sql(statement)
+    response = database.run_sql(statement, handler=lambda x: x.fetchall())
 
-    rows = response.fetchall()
+    rows = response
     assert len(rows) == 2
     assert rows[0] == (1,)
     assert rows[1] == (2,)
diff --git a/python-sdk/tests_integration/databases/test_postgres.py b/python-sdk/tests_integration/databases/test_postgres.py
index 2b5746654..37bae614f 100644
--- a/python-sdk/tests_integration/databases/test_postgres.py
+++ b/python-sdk/tests_integration/databases/test_postgres.py
@@ -55,8 +55,8 @@ def test_postgres_run_sql():
     """Test run_sql against postgres database"""
     statement = "SELECT 1 + 1;"
     database = PostgresDatabase(conn_id=CUSTOM_CONN_ID)
-    response = database.run_sql(statement)
-    assert response.first()[0] == 2
+    response = database.run_sql(statement, handler=lambda x: x.first())
+    assert response[0] == 2
 
 
 @pytest.mark.integration
@@ -90,12 +90,12 @@ def test_postgres_create_table_with_columns(database_table_fixture):
     database, table = database_table_fixture
 
     statement = f"SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name='{table.name}'"
-    response = database.run_sql(statement)
-    assert response.first() is None
+    response = database.run_sql(statement, handler=lambda x: x.first())
+    assert response is None
 
     database.create_table(table)
-    response = database.run_sql(statement)
-    rows = response.fetchall()
+    response = database.run_sql(statement, handler=lambda x: x.fetchall())
+    rows = response
     assert len(rows) == 2
     assert rows[0][0:4] == (
         "postgres",
@@ -128,9 +128,9 @@ def test_load_pandas_dataframe_to_table(database_table_fixture):
     database.load_pandas_dataframe_to_table(pandas_dataframe, table)
 
     statement = f"SELECT * FROM {database.get_table_qualified_name(table)};"
-    response = database.run_sql(statement)
+    response = database.run_sql(statement, handler=lambda x: x.fetchall())
 
-    rows = response.fetchall()
+    rows = response
     assert len(rows) == 2
     assert rows[0] == (1,)
     assert rows[1] == (2,)
diff --git a/python-sdk/tests_integration/databases/test_redshift.py b/python-sdk/tests_integration/databases/test_redshift.py
index f463e4613..9242f7f3c 100644
--- a/python-sdk/tests_integration/databases/test_redshift.py
+++ b/python-sdk/tests_integration/databases/test_redshift.py
@@ -38,8 +38,8 @@ def test_redshift_run_sql():
     """Test run_sql against redshift database"""
     statement = "SELECT 1 + 1;"
     database = RedshiftDatabase(conn_id=CUSTOM_CONN_ID)
-    response = database.run_sql(statement)
-    assert response.first()[0] == 2
+    response = database.run_sql(statement, handler=lambda x: x.first())
+    assert response[0] == 2
 
 
 @pytest.mark.integration
@@ -73,12 +73,12 @@ def test_redshift_create_table_with_columns(database_table_fixture):
     database, table = database_table_fixture
 
     statement = f"SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE table_name='{table.name}'"
-    response = database.run_sql(statement)
-    assert response.first() is None
+    response = database.run_sql(statement, handler=lambda x: x.first())
+    assert response is None
 
     database.create_table(table)
-    response = database.run_sql(statement)
-    rows = response.fetchall()
+    response = database.run_sql(statement, handler=lambda x: x.fetchall())
+    rows = response
     assert len(rows) == 2
     assert rows[0][0:4] == (
         "dev",
@@ -111,9 +111,9 @@ def test_load_pandas_dataframe_to_table(database_table_fixture):
     database.load_pandas_dataframe_to_table(pandas_dataframe, table)
 
     statement = f"SELECT * FROM {database.get_table_qualified_name(table)};"
-    response = database.run_sql(statement)
+    response = database.run_sql(statement, handler=lambda x: x.fetchall())
 
-    rows = response.fetchall()
+    rows = response
     assert len(rows) == 2
     assert rows[0] == (1,)
     assert rows[1] == (2,)
diff --git a/python-sdk/tests_integration/databases/test_snowflake.py b/python-sdk/tests_integration/databases/test_snowflake.py
index 0ca5af3fb..8a1321125 100644
--- a/python-sdk/tests_integration/databases/test_snowflake.py
+++ b/python-sdk/tests_integration/databases/test_snowflake.py
@@ -39,8 +39,8 @@ def test_snowflake_run_sql():
     """Test run_sql against snowflake database"""
     statement = "SELECT 1 + 1;"
     database = SnowflakeDatabase(conn_id=CUSTOM_CONN_ID)
-    response = database.run_sql(statement)
-    assert response.first()[0] == 2
+    response = database.run_sql(statement, handler=lambda x: x.first())
+    assert response[0] == 2
 
 
 @pytest.mark.integration
@@ -79,8 +79,8 @@ def test_snowflake_create_table_with_columns(database_table_fixture):
         assert e.match("does not exist or not authorized")
 
     database.create_table(table)
-    response = database.run_sql(statement)
-    rows = response.fetchall()
+    response = database.run_sql(statement, handler=lambda x: x.fetchall())
+    rows = response
     assert len(rows) == 2
     assert rows[0] == (
         "ID",
@@ -137,11 +137,11 @@ def test_snowflake_create_table_using_native_schema_autodetection(
     file = File("s3://astro-sdk/sample.parquet", conn_id="aws_conn")
     database.create_table(table, file)
-    response = database.run_sql(statement)
-    rows = response.fetchall()
+    response = database.run_sql(statement, handler=lambda x: x.fetchall())
+    rows = response
     assert len(rows) == 2
 
     statement = f"SELECT COUNT(*) FROM {database.get_table_qualified_name(table)}"
-    count = database.run_sql(statement).scalar()
+    count = database.run_sql(statement, handler=lambda x: x.scalar())
     assert count == 0
 
 
@@ -165,9 +165,9 @@ def test_load_pandas_dataframe_to_table(database_table_fixture):
     database.load_pandas_dataframe_to_table(pandas_dataframe, table)
 
     statement = f"SELECT * FROM {database.get_table_qualified_name(table)}"
-    response = database.run_sql(statement)
+    response = database.run_sql(statement, handler=lambda x: x.fetchall())
 
-    rows = response.fetchall()
+    rows = response
     assert len(rows) == 2
     assert rows[0] == (1,)
     assert rows[1] == (2,)
diff --git a/python-sdk/tests_integration/databases/test_sqlite.py b/python-sdk/tests_integration/databases/test_sqlite.py
index 04782a0d6..1f6c005ea 100644
--- a/python-sdk/tests_integration/databases/test_sqlite.py
+++ b/python-sdk/tests_integration/databases/test_sqlite.py
@@ -56,8 +56,8 @@ def test_sqlite_run_sql_with_sqlalchemy_text():
     """Run a SQL statement using SQLAlchemy text"""
     statement = sqlalchemy.text("SELECT 1 + 1;")
     database = SqliteDatabase()
-    response = database.run_sql(statement)
-    assert response.first()[0] == 2
+    response = database.run_sql(statement, handler=lambda x: x.first())
+    assert response[0] == 2
 
 
 @pytest.mark.integration
@@ -65,8 +65,8 @@ def test_sqlite_run_sql():
     """Run a SQL statement using plain string."""
     statement = "SELECT 1 + 1;"
     database = SqliteDatabase()
-    response = database.run_sql(statement)
-    assert response.first()[0] == 2
+    response = database.run_sql(statement, handler=lambda x: x.first())
+    assert response[0] == 2
 
 
 @pytest.mark.integration
@@ -74,8 +74,8 @@ def test_sqlite_run_sql_with_parameters():
     """Test running a SQL query using SQLAlchemy templating engine"""
     statement = "SELECT 1 + :value;"
     database = SqliteDatabase()
-    response = database.run_sql(statement, parameters={"value": 1})
-    assert response.first()[0] == 2
+    response = database.run_sql(statement, parameters={"value": 1}, handler=lambda x: x.first())
+    assert response[0] == 2
 
 
 @pytest.mark.integration
@@ -107,12 +107,12 @@ def test_sqlite_create_table_with_columns(database_table_fixture):
     database, table = database_table_fixture
 
     statement = f"PRAGMA table_info({table.name});"
-    response = database.run_sql(statement)
-    assert response.first() is None
+    response = database.run_sql(statement, handler=lambda x: x.first())
+    assert response is None
 
     database.create_table(table)
-    response = database.run_sql(statement)
-    rows = response.fetchall()
+    response = database.run_sql(statement, handler=lambda x: x.fetchall())
+    rows = response
     assert len(rows) == 2
     assert rows[0] == (0, "id", "INTEGER", 1, None, 1)
     assert rows[1] == (1, "name", "VARCHAR(60)", 1, None, 0)
@@ -130,13 +130,13 @@ def test_sqlite_create_table_autodetection_with_file(database_table_fixture):
     database, table = database_table_fixture
 
     statement = f"PRAGMA table_info({table.name});"
-    response = database.run_sql(statement)
-    assert response.first() is None
+    response = database.run_sql(statement, handler=lambda x: x.first())
+    assert response is None
 
     filepath = str(pathlib.Path(CWD.parent, "data/sample.csv"))
     database.create_table(table, File(filepath))
-    response = database.run_sql(statement)
-    rows = response.fetchall()
+    response = database.run_sql(statement, handler=lambda x: x.fetchall())
+    rows = response
     assert len(rows) == 2
     assert rows[0] == (0, "id", "BIGINT", 0, None, 0)
     assert rows[1] == (1, "name", "TEXT", 0, None, 0)
@@ -154,8 +154,8 @@ def test_sqlite_create_table_autodetection_without_file(database_table_fixture):
     database, table = database_table_fixture
 
     statement = f"PRAGMA table_info({table.name});"
-    response = database.run_sql(statement)
-    assert response.first() is None
+    response = database.run_sql(statement, handler=lambda x: x.first())
+    assert response is None
 
     with pytest.raises(ValueError) as exc_info:
         database.create_table(table)
@@ -179,9 +179,9 @@ def test_load_pandas_dataframe_to_table(database_table_fixture):
     database.load_pandas_dataframe_to_table(pandas_dataframe, table)
 
     statement = f"SELECT * FROM {table.name};"
-    response = database.run_sql(statement)
+    response = database.run_sql(statement, handler=lambda x: x.fetchall())
 
-    rows = response.fetchall()
+    rows = response
    assert len(rows) == 2
     assert rows[0] == (1,)
     assert rows[1] == (2,)
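
Reviewer note (not part of the patch): after this change every `run_sql()` implementation accepts an optional `handler` callable, applies it to the raw result while the cursor is still live, and returns `None` when no handler is passed, rather than leaking a `CursorResult` to the caller. A minimal sketch of the standardized contract, using `SqliteDatabase` as the integration tests above do; the table name is hypothetical:

```python
from astro.databases.sqlite import SqliteDatabase

database = SqliteDatabase()

# Scalar result: the handler runs inside run_sql(), mirroring the new row_count().
count = database.run_sql(
    "SELECT COUNT(*) FROM my_table",  # hypothetical table
    handler=lambda cursor: cursor.scalar(),
)

# Full result set: the same call shape as the new fetch_all_rows() implementation.
rows = database.run_sql(
    "SELECT * FROM my_table",  # hypothetical table
    handler=lambda cursor: cursor.fetchall(),
)

# Without a handler the statement still executes, but nothing is returned.
assert database.run_sql("SELECT 1") is None
```

Executing the handler inside `run_sql()` is what lets Databricks, whose result comes from Spark SQL rather than a SQLAlchemy cursor, share one code path with the other databases; that is why `IGNORE_HANDLER_IN_RUN_RAW_SQL` and the special-casing in `RawSQLOperator.execute()` can be dropped.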
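At the DAG level the call shape is unchanged: the handler given to `@aql.run_raw_sql` (or derived from `results_format`) is now simply forwarded into `run_sql()` and executed there. A hedged sketch reusing the `sqlite_default` connection and `results_format="list"` option from the unit tests above; the query target is hypothetical:

```python
from astro import sql as aql

@aql.run_raw_sql(conn_id="sqlite_default", results_format="list")
def select_rows():
    return "SELECT * FROM my_table"  # hypothetical table
```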