From 27429d7176534595de71b8199c53a33fe2c2c5bc Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Tue, 31 Jan 2023 00:31:11 +0530 Subject: [PATCH 1/3] Add fix and testcase for run_raw_sql --- .../src/astro/sql/operators/base_decorator.py | 20 +++++----- .../sql/operators/test_raw_sql.py | 40 +++++++++++++++++-- 2 files changed, 47 insertions(+), 13 deletions(-) diff --git a/python-sdk/src/astro/sql/operators/base_decorator.py b/python-sdk/src/astro/sql/operators/base_decorator.py index b12c46e5c..3d3ed2940 100644 --- a/python-sdk/src/astro/sql/operators/base_decorator.py +++ b/python-sdk/src/astro/sql/operators/base_decorator.py @@ -88,14 +88,12 @@ def execute(self, context: Context) -> None: # Find and load dataframes from op_arg and op_kwarg into Table self.create_output_table_if_needed() self.op_args = load_op_arg_dataframes_into_sql( # type: ignore - conn_id=self.conn_id, - op_args=self.op_args, # type: ignore - target_table=self.output_table.create_similar_table(), + conn_id=self.conn_id, op_args=self.op_args, output_table=self.output_table # type: ignore ) self.op_kwargs = load_op_kwarg_dataframes_into_sql( conn_id=self.conn_id, op_kwargs=self.op_kwargs, - target_table=self.output_table.create_similar_table(), + output_table=self.output_table, ) # The transform decorator doesn't explicitly pass output_table as a @@ -290,19 +288,20 @@ def get_source_code(self, py_callable: Callable) -> str | None: return None -def load_op_arg_dataframes_into_sql(conn_id: str, op_args: tuple, target_table: BaseTable) -> tuple: +def load_op_arg_dataframes_into_sql(conn_id: str, op_args: tuple, output_table: BaseTable) -> tuple: """ Identify dataframes in op_args and load them to the table. :param conn_id: Connection identifier to be used to load content to the target_table :param op_args: user-defined decorator's kwargs - :param target_table: Table where the dataframe content will be written to + :param output_table: Similar table where the dataframe content will be written to :return: New op_args, in which dataframes are replaced by tables """ - final_args = [] + final_args: list[Table | BaseTable] = [] database = create_database(conn_id=conn_id) for arg in op_args: if isinstance(arg, pd.DataFrame): + target_table = output_table.create_similar_table() database.load_pandas_dataframe_to_table(source_dataframe=arg, target_table=target_table) final_args.append(target_table) elif isinstance(arg, BaseTable): @@ -313,19 +312,20 @@ def load_op_arg_dataframes_into_sql(conn_id: str, op_args: tuple, target_table: return tuple(final_args) -def load_op_kwarg_dataframes_into_sql(conn_id: str, op_kwargs: dict, target_table: BaseTable) -> dict: +def load_op_kwarg_dataframes_into_sql(conn_id: str, op_kwargs: dict, output_table: BaseTable) -> dict: """ Identify dataframes in op_kwargs and load them to a table. :param conn_id: Connection identifier to be used to load content to the target_table :param op_kwargs: user-defined decorator's kwargs - :param target_table: Table where the dataframe content will be written to + :param output_table: Similar table where the dataframe content will be written to :return: New op_kwargs, in which dataframes are replaced by tables """ final_kwargs = {} - database = create_database(conn_id=conn_id, table=target_table) + database = create_database(conn_id=conn_id, table=output_table) for key, value in op_kwargs.items(): if isinstance(value, pd.DataFrame): + target_table = output_table.create_similar_table() df_table = cast(BaseTable, target_table.create_similar_table()) database.load_pandas_dataframe_to_table(source_dataframe=value, target_table=df_table) final_kwargs[key] = df_table diff --git a/python-sdk/tests_integration/sql/operators/test_raw_sql.py b/python-sdk/tests_integration/sql/operators/test_raw_sql.py index b415b03b3..8b32b5091 100644 --- a/python-sdk/tests_integration/sql/operators/test_raw_sql.py +++ b/python-sdk/tests_integration/sql/operators/test_raw_sql.py @@ -1,13 +1,14 @@ import logging import pathlib -import pandas +import pandas as pd import pytest from airflow.decorators import task from astro import sql as aql from astro.constants import Database from astro.files import File +from astro.table import BaseTable from ..operators import utils as test_utils @@ -166,8 +167,8 @@ def raw_sql_query(input_table): @task def assert_num_rows(result): - assert isinstance(result, pandas.DataFrame) - assert result.equals(pandas.read_csv(DATA_FILEPATH)) + assert isinstance(result, pd.DataFrame) + assert result.equals(pd.read_csv(DATA_FILEPATH)) assert result.shape == (3, 2) with sample_dag: @@ -208,5 +209,38 @@ def assert_num_rows(result): with sample_dag: results = raw_sql_query(input_table=test_table) assert_num_rows(results) + test_utils.run_dag(sample_dag) + +@pytest.mark.integration +@pytest.mark.parametrize( + "database_table_fixture", + [{"database": Database.SQLITE, "file": File(path=str(DATA_FILEPATH))}], + indirect=True, + ids=["sqlite"], +) +def test_run_raw_sql_handle_multiple_tables(sample_dag, database_table_fixture): + """ + Handle the case when we are passing multiple dataframe to run_raw_sql() operator + and all the dataframes are converted to different tables. + """ + _, test_table = database_table_fixture + + @aql.run_raw_sql(handler=lambda x: pd.DataFrame(x.fetchall(), columns=x.keys())) + def raw_sql_query_1(input_table: BaseTable): + return "SELECT * from {{input_table}}" + + @aql.run_raw_sql(handler=lambda x: pd.DataFrame(x.fetchall(), columns=x.keys())) + def raw_sql_query_2(input_table: BaseTable): + return "SELECT * from {{input_table}}" + + @aql.run_raw_sql(handler=lambda x: pd.DataFrame(x.fetchall(), columns=x.keys()), conn_id="sqlite_default") + def raw_sql_query_3(table_1: BaseTable, table_2: BaseTable): + assert table_1.name != table_2.name + return "SELECT 1 + 1" + + with sample_dag: + results_1 = raw_sql_query_1(input_table=test_table) + results_2 = raw_sql_query_2(input_table=test_table) + _ = raw_sql_query_3(table_1=results_1, table_2=results_2) test_utils.run_dag(sample_dag) From 90f898ad7334697c3de8254d85dd0cd313d36ae0 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Thu, 2 Feb 2023 19:46:20 +0530 Subject: [PATCH 2/3] Add unit testcases --- .../sql/operators/test_base_decorator.py | 40 ++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) diff --git a/python-sdk/tests/sql/operators/test_base_decorator.py b/python-sdk/tests/sql/operators/test_base_decorator.py index 1d7e56ba7..8c876a100 100644 --- a/python-sdk/tests/sql/operators/test_base_decorator.py +++ b/python-sdk/tests/sql/operators/test_base_decorator.py @@ -1,9 +1,15 @@ from unittest import mock +import pandas as pd import pytest from astro.sql import RawSQLOperator -from astro.sql.operators.base_decorator import BaseSQLDecoratedOperator +from astro.sql.operators.base_decorator import ( + BaseSQLDecoratedOperator, + load_op_arg_dataframes_into_sql, + load_op_kwarg_dataframes_into_sql, +) +from astro.table import BaseTable, Table def test_base_sql_decorated_operator_template_fields_with_parameters(): @@ -22,3 +28,35 @@ def test_get_source_code_handle_exception(mock_getsource, exception): RawSQLOperator(task_id="test", sql="select * from 1", python_callable=lambda: 1).get_source_code( py_callable=None ) + + +def test_load_op_arg_dataframes_into_sql(): + df_1 = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}) + df_2 = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}) + op_args = (df_1, df_2, Table(conn_id="sqlite_default"), "str") + results = load_op_arg_dataframes_into_sql( + conn_id="sqlite_default", op_args=op_args, output_table=Table(conn_id="sqlite_default") + ) + + assert isinstance(results[0], BaseTable) + assert isinstance(results[1], BaseTable) + assert results[0].name != results[1].name + + assert isinstance(results[2], BaseTable) + assert isinstance(results[3], str) + + +def test_load_op_kwarg_dataframes_into_sql(): + df_1 = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}) + df_2 = pd.DataFrame({"col1": [1, 2], "col2": [3, 4]}) + op_kwargs = {"df_1": df_1, "df_2": df_2, "table": Table(conn_id="sqlite_default"), "some_str": "str"} + results = load_op_kwarg_dataframes_into_sql( + conn_id="sqlite_default", op_kwargs=op_kwargs, output_table=Table(conn_id="sqlite_default") + ) + + assert isinstance(results["df_1"], BaseTable) + assert isinstance(results["df_2"], BaseTable) + assert results["df_1"].name != results["df_2"].name + + assert isinstance(results["table"], BaseTable) + assert isinstance(results["some_str"], str) From 709e02418dac68b8f5790995da4619139aae1cd0 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Fri, 3 Feb 2023 01:08:36 +0530 Subject: [PATCH 3/3] Used PandasDataframe instead of native pandas --- .../tests_integration/sql/operators/test_raw_sql.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/python-sdk/tests_integration/sql/operators/test_raw_sql.py b/python-sdk/tests_integration/sql/operators/test_raw_sql.py index 8b32b5091..42bd8331a 100644 --- a/python-sdk/tests_integration/sql/operators/test_raw_sql.py +++ b/python-sdk/tests_integration/sql/operators/test_raw_sql.py @@ -7,6 +7,7 @@ from astro import sql as aql from astro.constants import Database +from astro.dataframes.pandas import PandasDataframe from astro.files import File from astro.table import BaseTable @@ -226,15 +227,17 @@ def test_run_raw_sql_handle_multiple_tables(sample_dag, database_table_fixture): """ _, test_table = database_table_fixture - @aql.run_raw_sql(handler=lambda x: pd.DataFrame(x.fetchall(), columns=x.keys())) + @aql.run_raw_sql(handler=lambda x: PandasDataframe(x.fetchall(), columns=x.keys())) def raw_sql_query_1(input_table: BaseTable): return "SELECT * from {{input_table}}" - @aql.run_raw_sql(handler=lambda x: pd.DataFrame(x.fetchall(), columns=x.keys())) + @aql.run_raw_sql(handler=lambda x: PandasDataframe(x.fetchall(), columns=x.keys())) def raw_sql_query_2(input_table: BaseTable): return "SELECT * from {{input_table}}" - @aql.run_raw_sql(handler=lambda x: pd.DataFrame(x.fetchall(), columns=x.keys()), conn_id="sqlite_default") + @aql.run_raw_sql( + handler=lambda x: PandasDataframe(x.fetchall(), columns=x.keys()), conn_id="sqlite_default" + ) def raw_sql_query_3(table_1: BaseTable, table_2: BaseTable): assert table_1.name != table_2.name return "SELECT 1 + 1"