From ddb8ea434a01e8a66219e7155d387774f69fb696 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Mon, 14 Nov 2022 02:48:39 +0530 Subject: [PATCH 01/34] Saving work --- .../data_validations/ColumnCheckOperator.py | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py new file mode 100644 index 000000000..d5cc1122e --- /dev/null +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -0,0 +1,146 @@ +from typing import Any, Dict, Optional, Union + +import pandas +from airflow.providers.common.sql.operators.sql import SQLColumnCheckOperator + +from astro.databases import create_database +from astro.files import File +from astro.table import BaseTable +from astro.utils.typing_compat import Context + + +class ColumnCheckOperator(SQLColumnCheckOperator): + """ + Performs one or more of the templated checks in the column_checks dictionary. + Checks are performed on a per-column basis specified by the column_mapping. + Each check can take one or more of the following options: + - equal_to: an exact value to equal, cannot be used with other comparison options + - greater_than: value that result should be strictly greater than + - less_than: value that results should be strictly less than + - geq_to: value that results should be greater than or equal to + - leq_to: value that results should be less than or equal to + - tolerance: the percentage that the result may be off from the expected value + + :param table: the table to run checks on + :param column_mapping: the dictionary of columns and their associated checks, e.g. + + .. 
code-block:: python + + { + "col_name": { + "null_check": { + "equal_to": 0, + }, + "min": { + "greater_than": 5, + "leq_to": 10, + "tolerance": 0.2, + }, + "max": {"less_than": 1000, "geq_to": 10, "tolerance": 0.01}, + } + } + """ + + def __init__( + self, + dataset: Union[BaseTable, pandas.DataFrame, File], + column_mapping: Dict[str, Dict[str, Any]], + partition_clause: Optional[str] = None, + **kwargs, + ): + for checks in column_mapping.values(): + for check, check_values in checks.items(): + self._column_mapping_validation(check, check_values) + + self.dataset = dataset + self.column_mapping = column_mapping + self.partition_clause = partition_clause + self.kwargs = kwargs + self.df = None + if type(dataset) == BaseTable: + db = create_database(conn_id=self.dataset.conn_id) + super().__init__( + table=db.get_table_qualified_name(table=self.dataset), + column_mapping=self.column_mapping, + partition_clause=self.partition_clause, + conn_id=dataset.conn_id, + database=dataset.metadata.database, + ) + + def execute(self, context: "Context"): + if type(self.dataset) == BaseTable: + return super().execute(context=context) + elif type(self.dataset) == File: + self.df = self.dataset.export_to_dataframe() + elif type(self.dataset) == pandas.DataFrame: + self.df = self.dataset + else: + raise ValueError("dataset can only be of type File | pandas.dataframe | Table object") + + self.process_checks() + + def process_checks(self): + column_checks = { + "null_check": self.col_null_check, + "distinct_check": self.col_distinct_check, + "unique_check": self.col_unique_check, + "min": self.col_max, + "max": self.col_min, + } + failed_tests = [] + for column in self.column_mapping: + checks = self.column_mapping[column] + for check in checks: + tolerance = self.column_mapping[column][check].get("tolerance") + result = column_checks[check](column_name=column) + self.column_mapping[column][check]["result"] = result + self.column_mapping[column][check]["success"] = self._get_match( + self.column_mapping[column][check], result, tolerance + ) + failed_tests.extend(_get_failed_checks(self.column_mapping[column], column)) + if failed_tests: + pass + # raise AirflowException( + # f"Test failed.\nResults:\n{records!s}\n" + # "The following tests have failed:" + # f"\n{''.join(failed_tests)}" + # ) + + def col_null_check(self, column_name: str) -> list: + if self.df and self.df[column_name]: + return self.df[column_name].isnull().values.any() + return [] + + def col_distinct_check(self, column_name: str) -> list: + if self.df and self.df[column_name]: + return self.df[column_name].unique() + return [] + + def col_unique_check(self, column_name: str) -> Optional[bool]: + if self.df and self.df[column_name]: + return len(self.df[column_name].unique()) == 1 + return None + + def col_max(self, column_name: str) -> Optional[float]: + if self.df and self.df[column_name]: + return self.df[column_name].max() + return None + + def col_min(self, column_name: str) -> Optional[float]: + if self.df and self.df[column_name]: + return self.df[column_name].min() + return None + + +def _get_failed_checks(checks, col=None): + if col: + return [ + f"Column: {col}\nCheck: {check},\nCheck Values: {check_values}\n" + for check, check_values in checks.items() + if not check_values["success"] + ] + return [ + f"\tCheck: {check},\n\tCheck Values: {check_values}\n" + for check, check_values in checks.items() + if not check_values["success"] + ] From 8996d25c5630d6ae2ebc88b12626edf27f09512a Mon Sep 17 00:00:00 2001 From: utkarsh sharma 
Date: Mon, 5 Dec 2022 15:35:35 +0530 Subject: [PATCH 02/34] Added test case --- python-sdk/src/astro/sql/__init__.py | 1 + .../data_validations/ColumnCheckOperator.py | 61 ++++++++++++++++--- .../sql/operators/test_ColumnCheckOperator.py | 19 ++++++ 3 files changed, 73 insertions(+), 8 deletions(-) create mode 100644 python-sdk/tests/sql/operators/test_ColumnCheckOperator.py diff --git a/python-sdk/src/astro/sql/__init__.py b/python-sdk/src/astro/sql/__init__.py index eb832334c..df8173221 100644 --- a/python-sdk/src/astro/sql/__init__.py +++ b/python-sdk/src/astro/sql/__init__.py @@ -4,6 +4,7 @@ from astro.sql.operators.append import AppendOperator, append from astro.sql.operators.cleanup import CleanupOperator, cleanup +from astro.sql.operators.data_validations.ColumnCheckOperator import ColumnCheckOperator, column_check from astro.sql.operators.dataframe import DataframeOperator, dataframe from astro.sql.operators.drop import DropTableOperator, drop_table from astro.sql.operators.export_file import ExportFileOperator, export_file diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index d5cc1122e..620372f97 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Optional, Union import pandas +from airflow.models.xcom_arg import XComArg from airflow.providers.common.sql.operators.sql import SQLColumnCheckOperator from astro.databases import create_database @@ -58,7 +59,7 @@ def __init__( self.kwargs = kwargs self.df = None if type(dataset) == BaseTable: - db = create_database(conn_id=self.dataset.conn_id) + db = create_database(conn_id=self.dataset.conn_id) # type: ignore super().__init__( table=db.get_table_qualified_name(table=self.dataset), column_mapping=self.column_mapping, @@ -70,8 +71,8 @@ def __init__( def execute(self, context: "Context"): if type(self.dataset) == BaseTable: return super().execute(context=context) - elif type(self.dataset) == File: - self.df = self.dataset.export_to_dataframe() + # elif type(self.dataset) == File: + # self.df = self.dataset.export_to_dataframe() elif type(self.dataset) == pandas.DataFrame: self.df = self.dataset else: @@ -107,27 +108,27 @@ def process_checks(self): # ) def col_null_check(self, column_name: str) -> list: - if self.df and self.df[column_name]: + if self.df is not None and column_name in self.df.columns: return self.df[column_name].isnull().values.any() return [] def col_distinct_check(self, column_name: str) -> list: - if self.df and self.df[column_name]: + if self.df is not None and column_name in self.df.columns: return self.df[column_name].unique() return [] def col_unique_check(self, column_name: str) -> Optional[bool]: - if self.df and self.df[column_name]: + if self.df is not None and column_name in self.df.columns: return len(self.df[column_name].unique()) == 1 return None def col_max(self, column_name: str) -> Optional[float]: - if self.df and self.df[column_name]: + if self.df is not None and column_name in self.df.columns: return self.df[column_name].max() return None def col_min(self, column_name: str) -> Optional[float]: - if self.df and self.df[column_name]: + if self.df is not None and column_name in self.df.columns: return self.df[column_name].min() return None @@ -144,3 +145,47 @@ def _get_failed_checks(checks, col=None): for check, check_values in 
checks.items() if not check_values["success"] ] + + +def column_check( + dataset: Union[BaseTable, pandas.DataFrame, File], + column_mapping: Dict[str, Dict[str, Any]], + partition_clause: Optional[str] = None, + **kwargs, +) -> XComArg: + """ + Performs one or more of the templated checks in the column_checks dictionary. + Checks are performed on a per-column basis specified by the column_mapping. + Each check can take one or more of the following options: + - equal_to: an exact value to equal, cannot be used with other comparison options + - greater_than: value that result should be strictly greater than + - less_than: value that results should be strictly less than + - geq_to: value that results should be greater than or equal to + - leq_to: value that results should be less than or equal to + - tolerance: the percentage that the result may be off from the expected value + + :param table: the table to run checks on + :param column_mapping: the dictionary of columns and their associated checks, e.g. + + .. code-block:: python + + { + "col_name": { + "null_check": { + "equal_to": 0, + }, + "min": { + "greater_than": 5, + "leq_to": 10, + "tolerance": 0.2, + }, + "max": {"less_than": 1000, "geq_to": 10, "tolerance": 0.01}, + } + } + """ + return ColumnCheckOperator( + dataset=dataset, + column_mapping=column_mapping, + partition_clause=partition_clause, + kwargs=kwargs, + ).output diff --git a/python-sdk/tests/sql/operators/test_ColumnCheckOperator.py b/python-sdk/tests/sql/operators/test_ColumnCheckOperator.py new file mode 100644 index 000000000..e8e81ba80 --- /dev/null +++ b/python-sdk/tests/sql/operators/test_ColumnCheckOperator.py @@ -0,0 +1,19 @@ +import pathlib + +from astro import sql as aql +from astro.files import File + +CWD = pathlib.Path(__file__).parent + + +def test_column_check_operator(sample_dag): + aql.ColumnCheckOperator( + dataset=File(path=str(CWD) + "/../../data/homes2.csv"), + column_mapping={ + "sell": { + "null_check": { + "equal_to": 1, + } + } + }, + ).execute({}) From 595bd02c219fbd104e4bf58589c16e44bd7c2915 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Wed, 7 Dec 2022 15:53:25 +0530 Subject: [PATCH 03/34] Add testcases for ColumnCheckOperator --- .../data_validations/ColumnCheckOperator.py | 74 ++++-- .../sql/operators/test_ColumnCheckOperator.py | 237 +++++++++++++++++- 2 files changed, 279 insertions(+), 32 deletions(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index 620372f97..85d04ac2d 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -1,11 +1,11 @@ from typing import Any, Dict, Optional, Union import pandas +from airflow import AirflowException from airflow.models.xcom_arg import XComArg from airflow.providers.common.sql.operators.sql import SQLColumnCheckOperator from astro.databases import create_database -from astro.files import File from astro.table import BaseTable from astro.utils.typing_compat import Context @@ -22,7 +22,7 @@ class ColumnCheckOperator(SQLColumnCheckOperator): - leq_to: value that results should be less than or equal to - tolerance: the percentage that the result may be off from the expected value - :param table: the table to run checks on + :param dataset: the table or dataframe to run checks on :param column_mapping: the dictionary of columns and their associated checks, e.g. .. 
code-block:: python @@ -44,7 +44,7 @@ class ColumnCheckOperator(SQLColumnCheckOperator): def __init__( self, - dataset: Union[BaseTable, pandas.DataFrame, File], + dataset: Union[BaseTable, pandas.DataFrame], column_mapping: Dict[str, Dict[str, Any]], partition_clause: Optional[str] = None, **kwargs, @@ -71,8 +71,6 @@ def __init__( def execute(self, context: "Context"): if type(self.dataset) == BaseTable: return super().execute(context=context) - # elif type(self.dataset) == File: - # self.df = self.dataset.export_to_dataframe() elif type(self.dataset) == pandas.DataFrame: self.df = self.dataset else: @@ -80,46 +78,54 @@ def execute(self, context: "Context"): self.process_checks() - def process_checks(self): + def get_check_method(self, check_name: str, column_name: str): column_checks = { "null_check": self.col_null_check, "distinct_check": self.col_distinct_check, "unique_check": self.col_unique_check, - "min": self.col_max, - "max": self.col_min, + "min": self.col_min, + "max": self.col_max, } + return column_checks[check_name](column_name=column_name) + + def process_checks(self): + failed_tests = [] + passed_tests = [] + + # Iterating over columns for column in self.column_mapping: checks = self.column_mapping[column] + + # Iterating over checks for check in checks: tolerance = self.column_mapping[column][check].get("tolerance") - result = column_checks[check](column_name=column) + result = self.get_check_method(check, column_name=column) self.column_mapping[column][check]["result"] = result self.column_mapping[column][check]["success"] = self._get_match( self.column_mapping[column][check], result, tolerance ) failed_tests.extend(_get_failed_checks(self.column_mapping[column], column)) - if failed_tests: - pass - # raise AirflowException( - # f"Test failed.\nResults:\n{records!s}\n" - # "The following tests have failed:" - # f"\n{''.join(failed_tests)}" - # ) - - def col_null_check(self, column_name: str) -> list: + passed_tests.extend(_get_success_checks(self.column_mapping[column], column)) + + if len(failed_tests) > 0: + raise AirflowException(f"The following tests have failed:" f"\n{''.join(failed_tests)}") + if len(passed_tests) > 0: + print(f"The following tests have passed:" f"\n{''.join(passed_tests)}") + + def col_null_check(self, column_name: str) -> Optional[int]: if self.df is not None and column_name in self.df.columns: - return self.df[column_name].isnull().values.any() - return [] + return list(self.df[column_name].isnull().values).count(True) + return None - def col_distinct_check(self, column_name: str) -> list: + def col_distinct_check(self, column_name: str) -> Optional[int]: if self.df is not None and column_name in self.df.columns: - return self.df[column_name].unique() - return [] + return len(self.df[column_name].unique()) + return None - def col_unique_check(self, column_name: str) -> Optional[bool]: + def col_unique_check(self, column_name: str) -> Optional[int]: if self.df is not None and column_name in self.df.columns: - return len(self.df[column_name].unique()) == 1 + return len(self.df[column_name]) - self.col_distinct_check(column_name=column_name) return None def col_max(self, column_name: str) -> Optional[float]: @@ -147,8 +153,22 @@ def _get_failed_checks(checks, col=None): ] +def _get_success_checks(checks, col=None): + if col: + return [ + f"Column: {col}\nCheck: {check},\nCheck Values: {check_values}\n" + for check, check_values in checks.items() + if check_values["success"] + ] + return [ + f"\tCheck: {check},\n\tCheck Values: {check_values}\n" + 
for check, check_values in checks.items() + if check_values["success"] + ] + + def column_check( - dataset: Union[BaseTable, pandas.DataFrame, File], + dataset: Union[BaseTable, pandas.DataFrame], column_mapping: Dict[str, Dict[str, Any]], partition_clause: Optional[str] = None, **kwargs, @@ -164,7 +184,7 @@ def column_check( - leq_to: value that results should be less than or equal to - tolerance: the percentage that the result may be off from the expected value - :param table: the table to run checks on + :param dataset: dataframe or BaseTable that has to be validated :param column_mapping: the dictionary of columns and their associated checks, e.g. .. code-block:: python diff --git a/python-sdk/tests/sql/operators/test_ColumnCheckOperator.py b/python-sdk/tests/sql/operators/test_ColumnCheckOperator.py index e8e81ba80..3efd838f7 100644 --- a/python-sdk/tests/sql/operators/test_ColumnCheckOperator.py +++ b/python-sdk/tests/sql/operators/test_ColumnCheckOperator.py @@ -1,19 +1,246 @@ import pathlib +import pandas as pd +import pytest +from airflow import AirflowException + from astro import sql as aql -from astro.files import File CWD = pathlib.Path(__file__).parent +df = pd.DataFrame( + data={ + "name": ["Dwight Schrute", "Michael Scott", "Jim Halpert"], + "age": [30, None, None], + "city": [None, "LA", "California City"], + "emp_id": [10, 1, 35], + } +) + -def test_column_check_operator(sample_dag): +def test_column_check_operator_with_null_checks(sample_dag): + """ + Test column_check_operator for null_check case + """ aql.ColumnCheckOperator( - dataset=File(path=str(CWD) + "/../../data/homes2.csv"), + dataset=df, column_mapping={ - "sell": { + "name": {"null_check": {"geq_to": 0, "leq_to": 1}}, + "city": { + "null_check": { + "equal_to": 1, + }, + }, + "age": { "null_check": { "equal_to": 1, + "tolerance": 1, # Tolerance is + and - the value provided. Acceptable values is 0 to 2. + }, + }, + }, + ).execute({}) + + +def test_failure_of_column_check_operator_with_null_checks__equal_to(sample_dag): + """ + Test that failure column_check_operator for null_check + """ + with pytest.raises(AirflowException) as e: + aql.ColumnCheckOperator( + dataset=df, + column_mapping={ + "city": { + "null_check": { + "equal_to": 0, + }, + }, + }, + ).execute({}) + assert "Check Values: {'equal_to': 0, 'result': 1, 'success': False}" in str(e.value) + + +def test_failure_of_column_check_operator_with_null_checks__geq_to_and_leq_to(sample_dag): + """ + Test that failure column_check_operator for null_check with geq_to and leq_to + """ + with pytest.raises(AirflowException) as e: + aql.ColumnCheckOperator( + dataset=df, + column_mapping={"name": {"null_check": {"geq_to": 1, "leq_to": 2}}}, + ).execute({}) + assert "Check Values: {'geq_to': 1, 'leq_to': 2, 'result': 0, 'success': False}" in str(e.value) + + +def test_failure_of_column_check_operator_with_null_checks__equal_to_with_tolerance(sample_dag): + """ + Test that failure column_check_operator for null_check with equal_to and tolerance + """ + with pytest.raises(AirflowException) as e: + aql.ColumnCheckOperator( + dataset=df, + column_mapping={ + "age": { + "null_check": { + "equal_to": 0, + "tolerance": 1, # Tolerance is + and - the value provided. Acceptable values is 0 to 0. 
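+                        # tolerance seems to scale the expected value: the accepted range
+                        # is equal_to * (1 - tolerance) to equal_to * (1 + tolerance), so
+                        # with equal_to=0 the range stays 0 to 0 and the two nulls in
+                        # "age" make this check fail, as the test expects.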
+ }, + } + }, + ).execute({}) + assert "Check Values: {'equal_to': 0, 'tolerance': 1, 'result': 2, 'success': False}" in str(e.value) + + +def test_column_check_operator_with_distinct_checks(sample_dag): + """ + Test column_check_operator for distinct_check case + """ + aql.ColumnCheckOperator( + dataset=df, + column_mapping={ + "name": { + "distinct_check": { + "equal_to": 3, + } + }, + "city": { + "distinct_check": {"geq_to": 2, "leq_to": 3}, # Nulls are treated as values + }, + "age": { + "distinct_check": { + "equal_to": 1, + "tolerance": 1, # Tolerance is + and - the value provided. Acceptable values is 0 to 2. + }, + }, + }, + ).execute({}) + + +def test_failure_of_column_check_operator_with_distinct_checks__equal_to(sample_dag): + """ + Test that failure column_check_operator for distinct_check + """ + with pytest.raises(AirflowException) as e: + aql.ColumnCheckOperator( + dataset=df, + column_mapping={ + "city": { + "distinct_check": { + "equal_to": 0, + }, + }, + }, + ).execute({}) + assert "Check Values: {'equal_to': 0, 'result': 3, 'success': False}" in str(e.value) + + +def test_failure_of_column_check_operator_with_distinct_checks__geq_to_and_leq_to(sample_dag): + """ + Test that failure column_check_operator for distinct_check with geq_to and leq_to + """ + with pytest.raises(AirflowException) as e: + aql.ColumnCheckOperator( + dataset=df, + column_mapping={"name": {"distinct_check": {"geq_to": 1, "leq_to": 2}}}, + ).execute({}) + assert "Check Values: {'geq_to': 1, 'leq_to': 2, 'result': 3, 'success': False}" in str(e.value) + + +def test_failure_of_column_check_operator_with_distinct_check__equal_to_with_tolerance(sample_dag): + """ + Test that failure column_check_operator for distinct_check with equal_to and tolerance + """ + with pytest.raises(AirflowException) as e: + aql.ColumnCheckOperator( + dataset=df, + column_mapping={ + "age": { + "distinct_check": { + "equal_to": 0, + "tolerance": 1, # Tolerance is + and - the value provided. Acceptable values is 0 to 0. + }, + } + }, + ).execute({}) + assert "Check Values: {'equal_to': 0, 'tolerance': 1, 'result': 2, 'success': False}" in str(e.value) + + +def test_column_check_operator_with_unique_check(sample_dag): + """ + Test column_check_operator for unique_check case + """ + aql.ColumnCheckOperator( + dataset=df, + column_mapping={ + "name": { + "unique_check": { + "equal_to": 0, + } + }, + "city": { + "unique_check": {"geq_to": 0, "leq_to": 1}, # Nulls are treated as values + }, + "age": { + "unique_check": { + "equal_to": 1, + "tolerance": 1, # Tolerance is + and - the value provided. Acceptable values is 0 to 2. 
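+                    # unique_check counts duplicates (row count minus distinct count);
+                    # pandas keeps one NaN as a distinct value, so for "age"
+                    # [30, None, None] the result should be 3 - 2 = 1, within 0 to 2.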
+ }, + }, + }, + ).execute({}) + + +def test_column_check_operator_with_max_min_check(sample_dag): + """ + Test column_check_operator for max_min_check + """ + aql.ColumnCheckOperator( + dataset=df, + column_mapping={ + "emp_id": { + "min": { + "geq_to": 1, } - } + }, + "age": { + "max": { + "leq_to": 100, + }, + }, }, ).execute({}) + + +def test_failure_of_column_check_operator_with_max_check(sample_dag): + """ + Test that failure column_check_operator for max_check + """ + with pytest.raises(AirflowException) as e: + aql.ColumnCheckOperator( + dataset=df, + column_mapping={ + "age": { + "max": { + "leq_to": 20, + }, + } + }, + ).execute({}) + assert "Check Values: {'leq_to': 20, 'result': 30.0, 'success': False}" in str(e.value) + + +def test_failure_of_column_check_operator_with_min_check(sample_dag): + """ + Test that failure column_check_operator for min_check + """ + with pytest.raises(AirflowException) as e: + aql.ColumnCheckOperator( + dataset=df, + column_mapping={ + "age": { + "min": { + "geq_to": 50, + }, + } + }, + ).execute({}) + assert "Check Values: {'geq_to': 50, 'result': 30.0, 'success': False}" in str(e.value) From 8640e5e1f917d46b326995ec8e30db73e47443ed Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Wed, 7 Dec 2022 16:01:39 +0530 Subject: [PATCH 04/34] Update data types supported by ColumnCheckOperator --- .../astro/sql/operators/data_validations/ColumnCheckOperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index 85d04ac2d..47a4903bf 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -74,7 +74,7 @@ def execute(self, context: "Context"): elif type(self.dataset) == pandas.DataFrame: self.df = self.dataset else: - raise ValueError("dataset can only be of type File | pandas.dataframe | Table object") + raise ValueError("dataset can only be of type pandas.dataframe | Table object") self.process_checks() From d335c3caf44d56e6050d560db458f9ee56e05cfc Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Wed, 7 Dec 2022 17:55:12 +0530 Subject: [PATCH 05/34] Add task_id to operator --- .../operators/data_validations/ColumnCheckOperator.py | 11 +++++++++-- python-sdk/tests/data/data_validation.csv | 0 2 files changed, 9 insertions(+), 2 deletions(-) create mode 100644 python-sdk/tests/data/data_validation.csv diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index 47a4903bf..69f17a766 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -2,6 +2,7 @@ import pandas from airflow import AirflowException +from airflow.decorators.base import get_unique_task_id from airflow.models.xcom_arg import XComArg from airflow.providers.common.sql.operators.sql import SQLColumnCheckOperator @@ -47,6 +48,7 @@ def __init__( dataset: Union[BaseTable, pandas.DataFrame], column_mapping: Dict[str, Dict[str, Any]], partition_clause: Optional[str] = None, + task_id: Optional[str] = None, **kwargs, ): for checks in column_mapping.values(): @@ -58,18 +60,21 @@ def __init__( self.partition_clause = partition_clause self.kwargs = kwargs self.df = None - if type(dataset) == 
BaseTable: + + if isinstance(dataset, BaseTable): db = create_database(conn_id=self.dataset.conn_id) # type: ignore + self.conn_id = self.dataset.conn_id super().__init__( table=db.get_table_qualified_name(table=self.dataset), column_mapping=self.column_mapping, partition_clause=self.partition_clause, conn_id=dataset.conn_id, database=dataset.metadata.database, + task_id=task_id if task_id is not None else get_unique_task_id("column_check"), ) def execute(self, context: "Context"): - if type(self.dataset) == BaseTable: + if isinstance(self.dataset, BaseTable): return super().execute(context=context) elif type(self.dataset) == pandas.DataFrame: self.df = self.dataset @@ -171,6 +176,7 @@ def column_check( dataset: Union[BaseTable, pandas.DataFrame], column_mapping: Dict[str, Dict[str, Any]], partition_clause: Optional[str] = None, + task_id: Optional[str] = None, **kwargs, ) -> XComArg: """ @@ -208,4 +214,5 @@ def column_check( column_mapping=column_mapping, partition_clause=partition_clause, kwargs=kwargs, + task_id=task_id, ).output diff --git a/python-sdk/tests/data/data_validation.csv b/python-sdk/tests/data/data_validation.csv new file mode 100644 index 000000000..e69de29bb From 43d67bd1fcc5c137c524717e51615591ea9f0212 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Wed, 7 Dec 2022 17:56:58 +0530 Subject: [PATCH 06/34] Add testcase for table dataset --- python-sdk/tests/data/data_validation.csv | 4 ++ .../sql/operators/test_ColumnCheckOperator.py | 62 +++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/python-sdk/tests/data/data_validation.csv b/python-sdk/tests/data/data_validation.csv index e69de29bb..d47e006be 100644 --- a/python-sdk/tests/data/data_validation.csv +++ b/python-sdk/tests/data/data_validation.csv @@ -0,0 +1,4 @@ +name,age,city,emp_id +Dwight Schrute,30,,10 +Michael Scott,,LA,1 +Jim Halpert,,California City,35 diff --git a/python-sdk/tests/sql/operators/test_ColumnCheckOperator.py b/python-sdk/tests/sql/operators/test_ColumnCheckOperator.py index 3efd838f7..ba42aee01 100644 --- a/python-sdk/tests/sql/operators/test_ColumnCheckOperator.py +++ b/python-sdk/tests/sql/operators/test_ColumnCheckOperator.py @@ -5,6 +5,8 @@ from airflow import AirflowException from astro import sql as aql +from astro.constants import Database +from astro.files import File CWD = pathlib.Path(__file__).parent @@ -244,3 +246,63 @@ def test_failure_of_column_check_operator_with_min_check(sample_dag): }, ).execute({}) assert "Check Values: {'geq_to': 50, 'result': 30.0, 'success': False}" in str(e.value) + + +@pytest.mark.parametrize( + "database_table_fixture", + [ + { + "database": Database.SNOWFLAKE, + "file": File(path=str(CWD) + "/../../data/data_validation.csv"), + }, + { + "database": Database.BIGQUERY, + "file": File(path=str(CWD) + "/../../data/data_validation.csv"), + }, + { + "database": Database.POSTGRES, + "file": File(path=str(CWD) + "/../../data/data_validation.csv"), + }, + { + "database": Database.SQLITE, + "file": File(path=str(CWD) + "/../../data/data_validation.csv"), + }, + { + "database": Database.REDSHIFT, + "file": File(path=str(CWD) + "/../../data/data_validation.csv"), + }, + ], + indirect=True, + ids=["snowflake", "bigquery", "postgresql", "sqlite", "redshift"], +) +def test_column_check_operator_with_table_dataset(sample_dag, database_table_fixture): + """ + Test column_check_operator with table dataset for all checks types and make sure the generated sql is working for + all the database we support. 
+ """ + db, test_table = database_table_fixture + + aql.ColumnCheckOperator( + dataset=test_table, + column_mapping={ + "name": { + "null_check": {"geq_to": 0, "leq_to": 1}, + "unique_check": { + "equal_to": 0, + }, + }, + "city": { + "distinct_check": {"geq_to": 2, "leq_to": 3}, # Nulls are treated as values + }, + "age": { + "max": { + "leq_to": 100, + }, + }, + "emp_id": { + "min": { + "geq_to": 1, + } + }, + }, + ).execute({}) From cce456725e13ef347bc7961d8715488d988cba84 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Wed, 7 Dec 2022 18:22:37 +0530 Subject: [PATCH 07/34] Add doc string to functions --- .../data_validations/ColumnCheckOperator.py | 22 ++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index 69f17a766..500ef0cbe 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -84,6 +84,9 @@ def execute(self, context: "Context"): self.process_checks() def get_check_method(self, check_name: str, column_name: str): + """ + Get the method ref that will validate the dataframe + """ column_checks = { "null_check": self.col_null_check, "distinct_check": self.col_distinct_check, @@ -94,7 +97,9 @@ def get_check_method(self, check_name: str, column_name: str): return column_checks[check_name](column_name=column_name) def process_checks(self): - + """ + Process all the checks and print the result or raise an exception in the event of failed checks + """ failed_tests = [] passed_tests = [] @@ -119,26 +124,41 @@ def process_checks(self): print(f"The following tests have passed:" f"\n{''.join(passed_tests)}") def col_null_check(self, column_name: str) -> Optional[int]: + """ + Count the total null values in a dataframe column + """ if self.df is not None and column_name in self.df.columns: return list(self.df[column_name].isnull().values).count(True) return None def col_distinct_check(self, column_name: str) -> Optional[int]: + """ + Count the distinct value in a dataframe column + """ if self.df is not None and column_name in self.df.columns: return len(self.df[column_name].unique()) return None def col_unique_check(self, column_name: str) -> Optional[int]: + """ + Count the unique value in a dataframe column + """ if self.df is not None and column_name in self.df.columns: return len(self.df[column_name]) - self.col_distinct_check(column_name=column_name) return None def col_max(self, column_name: str) -> Optional[float]: + """ + Get the max value in dataframe column + """ if self.df is not None and column_name in self.df.columns: return self.df[column_name].max() return None def col_min(self, column_name: str) -> Optional[float]: + """ + Get the min value in dataframe column + """ if self.df is not None and column_name in self.df.columns: return self.df[column_name].min() return None From 488a6688062034cb095aef8eec52c1653f446596 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Fri, 9 Dec 2022 16:36:45 +0530 Subject: [PATCH 08/34] Moved the test_ColumnCheckOperator.py to data_validation.py/test_ColumnCheckOperator.py --- .../{ => data_validation}/test_ColumnCheckOperator.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) rename python-sdk/tests/sql/operators/{ => data_validation}/test_ColumnCheckOperator.py (95%) diff --git 
a/python-sdk/tests/sql/operators/test_ColumnCheckOperator.py b/python-sdk/tests/sql/operators/data_validation/test_ColumnCheckOperator.py similarity index 95% rename from python-sdk/tests/sql/operators/test_ColumnCheckOperator.py rename to python-sdk/tests/sql/operators/data_validation/test_ColumnCheckOperator.py index ba42aee01..a814a3f69 100644 --- a/python-sdk/tests/sql/operators/test_ColumnCheckOperator.py +++ b/python-sdk/tests/sql/operators/data_validation/test_ColumnCheckOperator.py @@ -253,23 +253,23 @@ def test_failure_of_column_check_operator_with_min_check(sample_dag): [ { "database": Database.SNOWFLAKE, - "file": File(path=str(CWD) + "/../../data/data_validation.csv"), + "file": File(path=str(CWD) + "/../../../data/data_validation.csv"), }, { "database": Database.BIGQUERY, - "file": File(path=str(CWD) + "/../../data/data_validation.csv"), + "file": File(path=str(CWD) + "/../../../data/data_validation.csv"), }, { "database": Database.POSTGRES, - "file": File(path=str(CWD) + "/../../data/data_validation.csv"), + "file": File(path=str(CWD) + "/../../../data/data_validation.csv"), }, { "database": Database.SQLITE, - "file": File(path=str(CWD) + "/../../data/data_validation.csv"), + "file": File(path=str(CWD) + "/../../../data/data_validation.csv"), }, { "database": Database.REDSHIFT, - "file": File(path=str(CWD) + "/../../data/data_validation.csv"), + "file": File(path=str(CWD) + "/../../../data/data_validation.csv"), }, ], indirect=True, From 518cad36eed513f012c88f460111500b1d023416 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Fri, 9 Dec 2022 16:38:37 +0530 Subject: [PATCH 09/34] Add SQLCheckOperator to validate tables via sql --- python-sdk/src/astro/sql/__init__.py | 1 + .../data_validations/SQLCheckOperator.py | 96 +++++++++++++++++++ .../sql/operators/data_validation/__init__.py | 0 .../data_validation/test_SQLCheckOperator.py | 51 ++++++++++ 4 files changed, 148 insertions(+) create mode 100644 python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py create mode 100644 python-sdk/tests/sql/operators/data_validation/__init__.py create mode 100644 python-sdk/tests/sql/operators/data_validation/test_SQLCheckOperator.py diff --git a/python-sdk/src/astro/sql/__init__.py b/python-sdk/src/astro/sql/__init__.py index df8173221..9b1a2ec68 100644 --- a/python-sdk/src/astro/sql/__init__.py +++ b/python-sdk/src/astro/sql/__init__.py @@ -5,6 +5,7 @@ from astro.sql.operators.append import AppendOperator, append from astro.sql.operators.cleanup import CleanupOperator, cleanup from astro.sql.operators.data_validations.ColumnCheckOperator import ColumnCheckOperator, column_check +from astro.sql.operators.data_validations.SQLCheckOperator import SQLCheckOperator, sql_check from astro.sql.operators.dataframe import DataframeOperator, dataframe from astro.sql.operators.drop import DropTableOperator, drop_table from astro.sql.operators.export_file import ExportFileOperator, export_file diff --git a/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py new file mode 100644 index 000000000..ae0c3907a --- /dev/null +++ b/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py @@ -0,0 +1,96 @@ +from typing import Any, Dict, Optional + +from airflow.decorators.base import get_unique_task_id +from airflow.models.xcom_arg import XComArg +from airflow.providers.common.sql.operators.sql import SQLTableCheckOperator + +from astro.databases import create_database +from 
astro.table import BaseTable +from astro.utils.typing_compat import Context + + +class SQLCheckOperator(SQLTableCheckOperator): + """ + Performs one or more of the checks provided in the checks dictionary. + Checks should be written to return a boolean result. + + :param dataset: the table to run checks on + :param checks: the dictionary of checks, e.g.: + + .. code-block:: python + + { + "row_count_check": {"check_statement": "COUNT(*) = 1000"}, + "column_sum_check": {"check_statement": "col_a + col_b < col_c"}, + } + + + :param partition_clause: a partial SQL statement that is added to a WHERE clause in the query built by + the operator that creates partition_clauses for the checks to run on, e.g. + + .. code-block:: python + + "date = '1970-01-01'" + """ + + template_fields = ("partition_clause",) + + def __init__( + self, + *, + dataset: BaseTable, + checks: Dict[str, Dict[str, Any]], + partition_clause: Optional[str] = None, + task_id: Optional[str] = None, + **kwargs, + ): + + db = create_database(dataset.conn_id) + super().__init__( + table=db.get_table_qualified_name(dataset), + checks=checks, + partition_clause=partition_clause, + conn_id=dataset.conn_id, + task_id=task_id if task_id is not None else get_unique_task_id("sql_check"), + ) + + def execute(self, context: "Context"): + return super().execute(context=context) + + +def sql_check( + dataset: BaseTable, + checks: Dict[str, Dict[str, Any]], + partition_clause: Optional[str] = None, + task_id: Optional[str] = None, + **kwargs, +) -> XComArg: + """ + Performs one or more of the checks provided in the checks dictionary. + Checks should be written to return a boolean result. + + :param dataset: the table to run checks on + :param checks: the dictionary of checks, e.g.: + + .. code-block:: python + + { + "row_count_check": {"check_statement": "COUNT(*) = 1000"}, + "column_sum_check": {"check_statement": "col_a + col_b < col_c"}, + } + + + :param partition_clause: a partial SQL statement that is added to a WHERE clause in the query built by + the operator that creates partition_clauses for the checks to run on, e.g. + + .. 
code-block:: python + + "date = '1970-01-01'" + """ + return SQLCheckOperator( + dataset=dataset, + checks=checks, + partition_clause=partition_clause, + kwargs=kwargs, + task_id=task_id, + ).output diff --git a/python-sdk/tests/sql/operators/data_validation/__init__.py b/python-sdk/tests/sql/operators/data_validation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python-sdk/tests/sql/operators/data_validation/test_SQLCheckOperator.py b/python-sdk/tests/sql/operators/data_validation/test_SQLCheckOperator.py new file mode 100644 index 000000000..519749782 --- /dev/null +++ b/python-sdk/tests/sql/operators/data_validation/test_SQLCheckOperator.py @@ -0,0 +1,51 @@ +import pathlib + +import pytest + +from astro import sql as aql +from astro.constants import Database +from astro.files import File + +CWD = pathlib.Path(__file__).parent + + +@pytest.mark.parametrize( + "database_table_fixture", + [ + { + "database": Database.SNOWFLAKE, + "file": File(path=str(CWD) + "/../../../data/homes_main.csv"), + }, + { + "database": Database.BIGQUERY, + "file": File(path=str(CWD) + "/../../../data/homes_main.csv"), + }, + { + "database": Database.POSTGRES, + "file": File(path=str(CWD) + "/../../../data/homes_main.csv"), + }, + { + "database": Database.SQLITE, + "file": File(path=str(CWD) + "/../../../data/homes_main.csv"), + }, + { + "database": Database.REDSHIFT, + "file": File(path=str(CWD) + "/../../../data/homes_main.csv"), + }, + ], + indirect=True, + ids=["snowflake", "bigquery", "postgresql", "sqlite", "redshift"], +) +def test_column_check_operator_with_table_dataset(sample_dag, database_table_fixture): + """ + Test column_check_operator with table dataset for all checks types and make sure the generated sql is working for + all the database we support. 
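+
+    The "sell <= list" statement is a per-row boolean condition; SQLTableCheckOperator
+    appears to fold it into a single aggregate query, so the check only passes when the
+    condition holds for every row.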
+ """ + db, test_table = database_table_fixture + + aql.SQLCheckOperator( + dataset=test_table, + checks={ + "sell_list": {"check_statement": "sell <= list"}, + }, + ).execute({}) From 57a5a00993ac721db1a052545a7b2acf7ad68ba7 Mon Sep 17 00:00:00 2001 From: Utkarsh Sharma Date: Wed, 14 Dec 2022 19:21:20 +0530 Subject: [PATCH 10/34] Update python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py Co-authored-by: Felix Uellendall --- .../astro/sql/operators/data_validations/SQLCheckOperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py index ae0c3907a..d158db838 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py @@ -51,7 +51,7 @@ def __init__( checks=checks, partition_clause=partition_clause, conn_id=dataset.conn_id, - task_id=task_id if task_id is not None else get_unique_task_id("sql_check"), + task_id=task_id or get_unique_task_id("sql_check"), ) def execute(self, context: "Context"): From d2f4dcf24554c17a5745562bdf2b678b646307f2 Mon Sep 17 00:00:00 2001 From: Utkarsh Sharma Date: Wed, 14 Dec 2022 19:22:59 +0530 Subject: [PATCH 11/34] Update python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py Co-authored-by: Felix Uellendall --- .../astro/sql/operators/data_validations/SQLCheckOperator.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py index d158db838..aa489681a 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py @@ -54,8 +54,6 @@ def __init__( task_id=task_id or get_unique_task_id("sql_check"), ) - def execute(self, context: "Context"): - return super().execute(context=context) def sql_check( From f0140c0eb86939d47fd70cdc6b81aac8544a1637 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 14 Dec 2022 13:53:37 +0000 Subject: [PATCH 12/34] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../src/astro/sql/operators/data_validations/SQLCheckOperator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py index aa489681a..a6aa2599b 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py @@ -55,7 +55,6 @@ def __init__( ) - def sql_check( dataset: BaseTable, checks: Dict[str, Dict[str, Any]], From a2be757f1a0726cb0b4a2c47a85bd9e1c9ccb411 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Wed, 14 Dec 2022 19:44:27 +0530 Subject: [PATCH 13/34] Update the dataframe check method --- .../data_validations/ColumnCheckOperator.py | 62 +++++++++---------- 1 file changed, 31 insertions(+), 31 deletions(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index 500ef0cbe..0cca5928e 100644 --- 
a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -83,18 +83,23 @@ def execute(self, context: "Context"): self.process_checks() - def get_check_method(self, check_name: str, column_name: str): + def get_check_result(self, check_name: str, column_name: str, df: pandas.DataFrame): """ - Get the method ref that will validate the dataframe + Get the check method results post validating the dataframe """ - column_checks = { - "null_check": self.col_null_check, - "distinct_check": self.col_distinct_check, - "unique_check": self.col_unique_check, - "min": self.col_min, - "max": self.col_max, - } - return column_checks[check_name](column_name=column_name) + if df is not None and column_name in df.columns: + column_checks = { + "null_check": self.col_null_check, + "distinct_check": self.col_distinct_check, + "unique_check": self.col_unique_check, + "min": self.col_min, + "max": self.col_max, + } + return column_checks[check_name](column_name=column_name, df=df) + elif df is None: + raise ValueError("Dataframe is None") + else: + raise ValueError(f"Dataframe is don't have column {column_name}") def process_checks(self): """ @@ -110,7 +115,7 @@ def process_checks(self): # Iterating over checks for check in checks: tolerance = self.column_mapping[column][check].get("tolerance") - result = self.get_check_method(check, column_name=column) + result = self.get_check_result(check, column_name=column, df=self.df) self.column_mapping[column][check]["result"] = result self.column_mapping[column][check]["success"] = self._get_match( self.column_mapping[column][check], result, tolerance @@ -123,45 +128,40 @@ def process_checks(self): if len(passed_tests) > 0: print(f"The following tests have passed:" f"\n{''.join(passed_tests)}") - def col_null_check(self, column_name: str) -> Optional[int]: + @staticmethod + def col_null_check(column_name: str, df: pandas.DataFrame) -> Optional[int]: """ Count the total null values in a dataframe column """ - if self.df is not None and column_name in self.df.columns: - return list(self.df[column_name].isnull().values).count(True) - return None + return list(df[column_name].isnull().values).count(True) - def col_distinct_check(self, column_name: str) -> Optional[int]: + @staticmethod + def col_distinct_check(column_name: str, df: pandas.DataFrame) -> Optional[int]: """ Count the distinct value in a dataframe column """ - if self.df is not None and column_name in self.df.columns: - return len(self.df[column_name].unique()) - return None + return len(df[column_name].unique()) - def col_unique_check(self, column_name: str) -> Optional[int]: + @staticmethod + def col_unique_check(column_name: str, df: pandas.DataFrame) -> Optional[int]: """ Count the unique value in a dataframe column """ - if self.df is not None and column_name in self.df.columns: - return len(self.df[column_name]) - self.col_distinct_check(column_name=column_name) - return None + return len(df[column_name]) - len(df[column_name].unique()) - def col_max(self, column_name: str) -> Optional[float]: + @staticmethod + def col_max(column_name: str, df: pandas.DataFrame) -> Optional[float]: """ Get the max value in dataframe column """ - if self.df is not None and column_name in self.df.columns: - return self.df[column_name].max() - return None + return df[column_name].max() - def col_min(self, column_name: str) -> Optional[float]: + @staticmethod + def col_min(column_name: str, df: pandas.DataFrame) -> 
Optional[float]: """ Get the min value in dataframe column """ - if self.df is not None and column_name in self.df.columns: - return self.df[column_name].min() - return None + return df[column_name].min() def _get_failed_checks(checks, col=None): From 016461bcad265f493e829b9fb9921c0308170dcc Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Wed, 14 Dec 2022 20:30:58 +0530 Subject: [PATCH 14/34] Change test execution method to run_dag() from operator.execute() --- .../data_validations/ColumnCheckOperator.py | 26 +- python-sdk/tests/data/data_validation.csv | 2 +- .../test_ColumnCheckOperator.py | 235 +++++++----------- .../sql/operators/data_validation/__init__.py | 0 .../test_ColumnCheckOperator.py | 71 ++++++ .../data_validation/test_SQLCheckOperator.py | 16 +- 6 files changed, 191 insertions(+), 159 deletions(-) create mode 100644 python-sdk/tests_integration/sql/operators/data_validation/__init__.py create mode 100644 python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py rename python-sdk/{tests => tests_integration}/sql/operators/data_validation/test_SQLCheckOperator.py (82%) diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index 0cca5928e..ff5ae9d8e 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -3,7 +3,6 @@ import pandas from airflow import AirflowException from airflow.decorators.base import get_unique_task_id -from airflow.models.xcom_arg import XComArg from airflow.providers.common.sql.operators.sql import SQLColumnCheckOperator from astro.databases import create_database @@ -61,17 +60,22 @@ def __init__( self.kwargs = kwargs self.df = None + dataset_qualified_name = "" + dataset_conn_id = "" + if isinstance(dataset, BaseTable): db = create_database(conn_id=self.dataset.conn_id) # type: ignore self.conn_id = self.dataset.conn_id - super().__init__( - table=db.get_table_qualified_name(table=self.dataset), - column_mapping=self.column_mapping, - partition_clause=self.partition_clause, - conn_id=dataset.conn_id, - database=dataset.metadata.database, - task_id=task_id if task_id is not None else get_unique_task_id("column_check"), - ) + dataset_qualified_name = db.get_table_qualified_name(table=self.dataset) + dataset_conn_id = dataset.conn_id + + super().__init__( + table=dataset_qualified_name, + column_mapping=self.column_mapping, + partition_clause=self.partition_clause, + conn_id=dataset_conn_id, + task_id=task_id if task_id is not None else get_unique_task_id("column_check"), + ) def execute(self, context: "Context"): if isinstance(self.dataset, BaseTable): @@ -198,7 +202,7 @@ def column_check( partition_clause: Optional[str] = None, task_id: Optional[str] = None, **kwargs, -) -> XComArg: +) -> ColumnCheckOperator: """ Performs one or more of the templated checks in the column_checks dictionary. Checks are performed on a per-column basis specified by the column_mapping. 
@@ -235,4 +239,4 @@ def column_check( partition_clause=partition_clause, kwargs=kwargs, task_id=task_id, - ).output + ) diff --git a/python-sdk/tests/data/data_validation.csv b/python-sdk/tests/data/data_validation.csv index d47e006be..e71650807 100644 --- a/python-sdk/tests/data/data_validation.csv +++ b/python-sdk/tests/data/data_validation.csv @@ -1,4 +1,4 @@ name,age,city,emp_id -Dwight Schrute,30,,10 +Dwight Schrute,30.0,,10 Michael Scott,,LA,1 Jim Halpert,,California City,35 diff --git a/python-sdk/tests/sql/operators/data_validation/test_ColumnCheckOperator.py b/python-sdk/tests/sql/operators/data_validation/test_ColumnCheckOperator.py index a814a3f69..57d194119 100644 --- a/python-sdk/tests/sql/operators/data_validation/test_ColumnCheckOperator.py +++ b/python-sdk/tests/sql/operators/data_validation/test_ColumnCheckOperator.py @@ -5,8 +5,7 @@ from airflow import AirflowException from astro import sql as aql -from astro.constants import Database -from astro.files import File +from tests.sql.operators import utils as test_utils CWD = pathlib.Path(__file__).parent @@ -24,31 +23,33 @@ def test_column_check_operator_with_null_checks(sample_dag): """ Test column_check_operator for null_check case """ - aql.ColumnCheckOperator( - dataset=df, - column_mapping={ - "name": {"null_check": {"geq_to": 0, "leq_to": 1}}, - "city": { - "null_check": { - "equal_to": 1, + with sample_dag: + aql.column_check( + dataset=df, + column_mapping={ + "name": {"null_check": {"geq_to": 0, "leq_to": 1}}, + "city": { + "null_check": { + "equal_to": 1, + }, }, - }, - "age": { - "null_check": { - "equal_to": 1, - "tolerance": 1, # Tolerance is + and - the value provided. Acceptable values is 0 to 2. + "age": { + "null_check": { + "equal_to": 1, + "tolerance": 1, # Tolerance is + and - the value provided. Acceptable values is 0 to 2. 
+ }, }, }, - }, - ).execute({}) + ) + test_utils.run_dag(sample_dag) def test_failure_of_column_check_operator_with_null_checks__equal_to(sample_dag): """ Test that failure column_check_operator for null_check """ - with pytest.raises(AirflowException) as e: - aql.ColumnCheckOperator( + with sample_dag, pytest.raises(AirflowException) as e: + aql.column_check( dataset=df, column_mapping={ "city": { @@ -57,7 +58,8 @@ def test_failure_of_column_check_operator_with_null_checks__equal_to(sample_dag) }, }, }, - ).execute({}) + ) + test_utils.run_dag(sample_dag) assert "Check Values: {'equal_to': 0, 'result': 1, 'success': False}" in str(e.value) @@ -65,11 +67,12 @@ def test_failure_of_column_check_operator_with_null_checks__geq_to_and_leq_to(sa """ Test that failure column_check_operator for null_check with geq_to and leq_to """ - with pytest.raises(AirflowException) as e: + with sample_dag, pytest.raises(AirflowException) as e: aql.ColumnCheckOperator( dataset=df, column_mapping={"name": {"null_check": {"geq_to": 1, "leq_to": 2}}}, - ).execute({}) + ) + test_utils.run_dag(sample_dag) assert "Check Values: {'geq_to': 1, 'leq_to': 2, 'result': 0, 'success': False}" in str(e.value) @@ -77,7 +80,7 @@ def test_failure_of_column_check_operator_with_null_checks__equal_to_with_tolera """ Test that failure column_check_operator for null_check with equal_to and tolerance """ - with pytest.raises(AirflowException) as e: + with sample_dag, pytest.raises(AirflowException) as e: aql.ColumnCheckOperator( dataset=df, column_mapping={ @@ -88,7 +91,8 @@ def test_failure_of_column_check_operator_with_null_checks__equal_to_with_tolera }, } }, - ).execute({}) + ) + test_utils.run_dag(sample_dag) assert "Check Values: {'equal_to': 0, 'tolerance': 1, 'result': 2, 'success': False}" in str(e.value) @@ -96,32 +100,34 @@ def test_column_check_operator_with_distinct_checks(sample_dag): """ Test column_check_operator for distinct_check case """ - aql.ColumnCheckOperator( - dataset=df, - column_mapping={ - "name": { - "distinct_check": { - "equal_to": 3, - } - }, - "city": { - "distinct_check": {"geq_to": 2, "leq_to": 3}, # Nulls are treated as values - }, - "age": { - "distinct_check": { - "equal_to": 1, - "tolerance": 1, # Tolerance is + and - the value provided. Acceptable values is 0 to 2. + with sample_dag: + aql.ColumnCheckOperator( + dataset=df, + column_mapping={ + "name": { + "distinct_check": { + "equal_to": 3, + } + }, + "city": { + "distinct_check": {"geq_to": 2, "leq_to": 3}, # Nulls are treated as values + }, + "age": { + "distinct_check": { + "equal_to": 1, + "tolerance": 1, # Tolerance is + and - the value provided. Acceptable values is 0 to 2. 
+ }, }, }, - }, - ).execute({}) + ) + test_utils.run_dag(sample_dag) def test_failure_of_column_check_operator_with_distinct_checks__equal_to(sample_dag): """ Test that failure column_check_operator for distinct_check """ - with pytest.raises(AirflowException) as e: + with sample_dag, pytest.raises(AirflowException) as e: aql.ColumnCheckOperator( dataset=df, column_mapping={ @@ -131,7 +137,8 @@ def test_failure_of_column_check_operator_with_distinct_checks__equal_to(sample_ }, }, }, - ).execute({}) + ) + test_utils.run_dag(sample_dag) assert "Check Values: {'equal_to': 0, 'result': 3, 'success': False}" in str(e.value) @@ -139,11 +146,12 @@ def test_failure_of_column_check_operator_with_distinct_checks__geq_to_and_leq_t """ Test that failure column_check_operator for distinct_check with geq_to and leq_to """ - with pytest.raises(AirflowException) as e: + with sample_dag, pytest.raises(AirflowException) as e: aql.ColumnCheckOperator( dataset=df, column_mapping={"name": {"distinct_check": {"geq_to": 1, "leq_to": 2}}}, - ).execute({}) + ) + test_utils.run_dag(sample_dag) assert "Check Values: {'geq_to': 1, 'leq_to': 2, 'result': 3, 'success': False}" in str(e.value) @@ -151,7 +159,7 @@ def test_failure_of_column_check_operator_with_distinct_check__equal_to_with_tol """ Test that failure column_check_operator for distinct_check with equal_to and tolerance """ - with pytest.raises(AirflowException) as e: + with sample_dag, pytest.raises(AirflowException) as e: aql.ColumnCheckOperator( dataset=df, column_mapping={ @@ -162,7 +170,8 @@ def test_failure_of_column_check_operator_with_distinct_check__equal_to_with_tol }, } }, - ).execute({}) + ) + test_utils.run_dag(sample_dag) assert "Check Values: {'equal_to': 0, 'tolerance': 1, 'result': 2, 'success': False}" in str(e.value) @@ -170,53 +179,57 @@ def test_column_check_operator_with_unique_check(sample_dag): """ Test column_check_operator for unique_check case """ - aql.ColumnCheckOperator( - dataset=df, - column_mapping={ - "name": { - "unique_check": { - "equal_to": 0, - } - }, - "city": { - "unique_check": {"geq_to": 0, "leq_to": 1}, # Nulls are treated as values - }, - "age": { - "unique_check": { - "equal_to": 1, - "tolerance": 1, # Tolerance is + and - the value provided. Acceptable values is 0 to 2. + with sample_dag: + aql.ColumnCheckOperator( + dataset=df, + column_mapping={ + "name": { + "unique_check": { + "equal_to": 0, + } + }, + "city": { + "unique_check": {"geq_to": 0, "leq_to": 1}, # Nulls are treated as values + }, + "age": { + "unique_check": { + "equal_to": 1, + "tolerance": 1, # Tolerance is + and - the value provided. Acceptable values is 0 to 2. 
+ }, }, }, - }, - ).execute({}) + ) + test_utils.run_dag(sample_dag) def test_column_check_operator_with_max_min_check(sample_dag): """ Test column_check_operator for max_min_check """ - aql.ColumnCheckOperator( - dataset=df, - column_mapping={ - "emp_id": { - "min": { - "geq_to": 1, - } - }, - "age": { - "max": { - "leq_to": 100, + with sample_dag: + aql.ColumnCheckOperator( + dataset=df, + column_mapping={ + "emp_id": { + "min": { + "geq_to": 1, + } + }, + "age": { + "max": { + "leq_to": 100, + }, }, }, - }, - ).execute({}) + ) + test_utils.run_dag(sample_dag) def test_failure_of_column_check_operator_with_max_check(sample_dag): """ Test that failure column_check_operator for max_check """ - with pytest.raises(AirflowException) as e: + with sample_dag, pytest.raises(AirflowException) as e: aql.ColumnCheckOperator( dataset=df, column_mapping={ @@ -226,7 +239,8 @@ def test_failure_of_column_check_operator_with_max_check(sample_dag): }, } }, - ).execute({}) + ) + test_utils.run_dag(sample_dag) assert "Check Values: {'leq_to': 20, 'result': 30.0, 'success': False}" in str(e.value) @@ -234,7 +248,7 @@ def test_failure_of_column_check_operator_with_min_check(sample_dag): """ Test that failure column_check_operator for min_check """ - with pytest.raises(AirflowException) as e: + with sample_dag, pytest.raises(AirflowException) as e: aql.ColumnCheckOperator( dataset=df, column_mapping={ @@ -244,65 +258,6 @@ def test_failure_of_column_check_operator_with_min_check(sample_dag): }, } }, - ).execute({}) + ) + test_utils.run_dag(sample_dag) assert "Check Values: {'geq_to': 50, 'result': 30.0, 'success': False}" in str(e.value) - - -@pytest.mark.parametrize( - "database_table_fixture", - [ - { - "database": Database.SNOWFLAKE, - "file": File(path=str(CWD) + "/../../../data/data_validation.csv"), - }, - { - "database": Database.BIGQUERY, - "file": File(path=str(CWD) + "/../../../data/data_validation.csv"), - }, - { - "database": Database.POSTGRES, - "file": File(path=str(CWD) + "/../../../data/data_validation.csv"), - }, - { - "database": Database.SQLITE, - "file": File(path=str(CWD) + "/../../../data/data_validation.csv"), - }, - { - "database": Database.REDSHIFT, - "file": File(path=str(CWD) + "/../../../data/data_validation.csv"), - }, - ], - indirect=True, - ids=["snowflake", "bigquery", "postgresql", "sqlite", "redshift"], -) -def test_column_check_operator_with_table_dataset(sample_dag, database_table_fixture): - """ - Test column_check_operator with table dataset for all checks types and make sure the generated sql is working for - all the database we support. 
- """ - db, test_table = database_table_fixture - - aql.ColumnCheckOperator( - dataset=test_table, - column_mapping={ - "name": { - "null_check": {"geq_to": 0, "leq_to": 1}, - "unique_check": { - "equal_to": 0, - }, - }, - "city": { - "distinct_check": {"geq_to": 2, "leq_to": 3}, # Nulls are treated as values - }, - "age": { - "max": { - "leq_to": 100, - }, - }, - "emp_id": { - "min": { - "geq_to": 1, - } - }, - }, - ).execute({}) diff --git a/python-sdk/tests_integration/sql/operators/data_validation/__init__.py b/python-sdk/tests_integration/sql/operators/data_validation/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py b/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py new file mode 100644 index 000000000..676e795f3 --- /dev/null +++ b/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py @@ -0,0 +1,71 @@ +import pathlib + +import pytest + +from astro import sql as aql +from astro.constants import Database +from astro.files import File +from tests.sql.operators import utils as test_utils + +CWD = pathlib.Path(__file__).parent + + +@pytest.mark.parametrize( + "database_table_fixture", + [ + { + "database": Database.SNOWFLAKE, + "file": File(path=str(CWD) + "/../../../data/data_validation.csv"), + }, + { + "database": Database.BIGQUERY, + "file": File(path=str(CWD) + "/../../../data/data_validation.csv"), + }, + { + "database": Database.POSTGRES, + "file": File(path=str(CWD) + "/../../../data/data_validation.csv"), + }, + { + "database": Database.SQLITE, + "file": File(path=str(CWD) + "/../../../data/data_validation.csv"), + }, + { + "database": Database.REDSHIFT, + "file": File(path=str(CWD) + "/../../../data/data_validation.csv"), + }, + ], + indirect=True, + ids=["snowflake", "bigquery", "postgresql", "sqlite", "redshift"], +) +def test_column_check_operator_with_table_dataset(sample_dag, database_table_fixture): + """ + Test column_check_operator with table dataset for all checks types and make sure the generated sql is working for + all the database we support. 
+ """ + db, test_table = database_table_fixture + with sample_dag: + aql.ColumnCheckOperator( + dataset=test_table, + column_mapping={ + "name": { + "null_check": {"geq_to": 0, "leq_to": 1}, + "unique_check": { + "equal_to": 0, + }, + }, + "city": { + "distinct_check": {"geq_to": 2, "leq_to": 3}, # Nulls are treated as values + }, + "age": { + "max": { + "leq_to": 100, + }, + }, + "emp_id": { + "min": { + "geq_to": 1, + } + }, + }, + ) + test_utils.run_dag(sample_dag) diff --git a/python-sdk/tests/sql/operators/data_validation/test_SQLCheckOperator.py b/python-sdk/tests_integration/sql/operators/data_validation/test_SQLCheckOperator.py similarity index 82% rename from python-sdk/tests/sql/operators/data_validation/test_SQLCheckOperator.py rename to python-sdk/tests_integration/sql/operators/data_validation/test_SQLCheckOperator.py index 519749782..b373efdfb 100644 --- a/python-sdk/tests/sql/operators/data_validation/test_SQLCheckOperator.py +++ b/python-sdk/tests_integration/sql/operators/data_validation/test_SQLCheckOperator.py @@ -5,6 +5,7 @@ from astro import sql as aql from astro.constants import Database from astro.files import File +from tests.sql.operators import utils as test_utils CWD = pathlib.Path(__file__).parent @@ -42,10 +43,11 @@ def test_column_check_operator_with_table_dataset(sample_dag, database_table_fix all the database we support. """ db, test_table = database_table_fixture - - aql.SQLCheckOperator( - dataset=test_table, - checks={ - "sell_list": {"check_statement": "sell <= list"}, - }, - ).execute({}) + with sample_dag: + aql.SQLCheckOperator( + dataset=test_table, + checks={ + "sell_list": {"check_statement": "sell <= list"}, + }, + ) + test_utils.run_dag(sample_dag) From 122088376cf92c2920463c39862254664f81868e Mon Sep 17 00:00:00 2001 From: Utkarsh Sharma Date: Tue, 20 Dec 2022 08:23:07 +0530 Subject: [PATCH 15/34] Override GoogleBaseHook with BigqueryHook (#1442) # Description ## What is the current behavior? Because of below issue: ``` airflow.exceptions.AirflowException: You are trying to use `common-sql` with GoogleBaseHook, but its provider does not support it. Please upgrade the provider to a version that supports `common-sql`. The hook class should be a subclass of `airflow.providers.common.sql.hooks.sql.DbApiHook`. 
Got GoogleBaseHook Hook with class hierarchy: [<class 'airflow.providers.google.common.hooks.base_google.GoogleBaseHook'>, <class 'airflow.hooks.base.BaseHook'>, <class 'airflow.utils.log.logging_mixin.LoggingMixin'>, <class 'object'>]
```

As a workaround, we use the BigQuery hook directly.
---
 python-sdk/conftest.py                        |  4 ++++
 python-sdk/src/astro/databases/__init__.py    |  7 ++-----
 python-sdk/src/astro/databases/base.py        |  2 +-
 .../src/astro/databases/google/bigquery.py    |  7 +++++--
 python-sdk/src/astro/databases/postgres.py    |  4 +++-
 python-sdk/src/astro/databases/snowflake.py   |  4 +++-
 python-sdk/src/astro/databases/sqlite.py      |  4 +++-
 .../data_validations/ColumnCheckOperator.py   | 21 +++++++++++++++++++
 .../test_ColumnCheckOperator.py               |  3 +++
 .../data_validation/test_SQLCheckOperator.py  |  2 ++
 10 files changed, 47 insertions(+), 11 deletions(-)

diff --git a/python-sdk/conftest.py b/python-sdk/conftest.py
index 3da328b26..1da7fd02d 100644
--- a/python-sdk/conftest.py
+++ b/python-sdk/conftest.py
@@ -66,7 +66,11 @@ def database_table_fixture(request):
     params = deepcopy(request.param)
 
     database_name = params["database"]
+    user_table = params.get("table", None)
     conn_id = DATABASE_NAME_TO_CONN_ID[database_name]
+    if user_table and user_table.conn_id:
+        conn_id = user_table.conn_id
+
     database = create_database(conn_id)
     table = params.get("table", Table(conn_id=database.conn_id, metadata=database.default_metadata))
     if not isinstance(table, TempTable):
diff --git a/python-sdk/src/astro/databases/__init__.py b/python-sdk/src/astro/databases/__init__.py
index 49107be73..aeed965ab 100644
--- a/python-sdk/src/astro/databases/__init__.py
+++ b/python-sdk/src/astro/databases/__init__.py
@@ -24,10 +24,7 @@ SUPPORTED_DATABASES = set(DEFAULT_CONN_TYPE_TO_MODULE_PATH.keys())
 
 
-def create_database(
-    conn_id: str,
-    table: BaseTable | None = None,
-) -> BaseDatabase:
+def create_database(conn_id: str, table: BaseTable | None = None, region: str | None = None) -> BaseDatabase:
     """
     Given a conn_id, return the associated Database class.
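A note on the conftest change above: the fixture now lets a parametrized test supply its own `Table` with an explicit `conn_id`, which takes precedence over the default per-database mapping. A minimal, self-contained sketch of that resolution logic (the mapping value and the `FakeTable` stand-in are illustrative, not the real conftest objects):

```
# Standalone sketch of the conn_id resolution performed by the fixture above.
# The DATABASE_NAME_TO_CONN_ID value here is a placeholder; the real mapping
# lives in python-sdk/conftest.py.
from dataclasses import dataclass
from typing import Optional

DATABASE_NAME_TO_CONN_ID = {"bigquery": "google_cloud_default"}  # illustrative


@dataclass
class FakeTable:
    conn_id: Optional[str] = None


def resolve_conn_id(database_name: str, user_table: Optional[FakeTable] = None) -> str:
    conn_id = DATABASE_NAME_TO_CONN_ID[database_name]
    # A user-supplied table that carries its own conn_id wins over the default.
    if user_table and user_table.conn_id:
        conn_id = user_table.conn_id
    return conn_id


assert resolve_conn_id("bigquery") == "google_cloud_default"
assert resolve_conn_id("bigquery", FakeTable(conn_id="gcp_conn_project")) == "gcp_conn_project"
```

This is what allows the test patches further down to pin BigQuery runs to a specific connection simply by passing `"table": Table(conn_id=...)` in the fixture parameters.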
@@ -40,5 +37,5 @@ def create_database( module_path = CONN_TYPE_TO_MODULE_PATH[conn_type] module = importlib.import_module(module_path) class_name = get_class_name(module_ref=module, suffix="Database") - database: BaseDatabase = getattr(module, class_name)(conn_id, table) + database: BaseDatabase = getattr(module, class_name)(conn_id, table, region) return database diff --git a/python-sdk/src/astro/databases/base.py b/python-sdk/src/astro/databases/base.py index 106c2fb49..418d770de 100644 --- a/python-sdk/src/astro/databases/base.py +++ b/python-sdk/src/astro/databases/base.py @@ -68,7 +68,7 @@ class BaseDatabase(ABC): NATIVE_AUTODETECT_SCHEMA_CONFIG: Mapping[FileLocation, Mapping[str, list[FileType] | Callable]] = {} FILE_PATTERN_BASED_AUTODETECT_SCHEMA_SUPPORTED: set[FileLocation] = set() - def __init__(self, conn_id: str): + def __init__(self, conn_id: str, table: BaseTable | None = None, region: str | None = None): self.conn_id = conn_id self.sql: str | ClauseElement = "" diff --git a/python-sdk/src/astro/databases/google/bigquery.py b/python-sdk/src/astro/databases/google/bigquery.py index a96b3122f..e869c4d77 100644 --- a/python-sdk/src/astro/databases/google/bigquery.py +++ b/python-sdk/src/astro/databases/google/bigquery.py @@ -104,9 +104,12 @@ class BigqueryDatabase(BaseDatabase): _create_schema_statement: str = "CREATE SCHEMA IF NOT EXISTS {} OPTIONS (location='{}')" - def __init__(self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None): + def __init__( + self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None, region: str | None = None + ): super().__init__(conn_id) self.table = table + self.region = region @property def sql_type(self) -> str: @@ -115,7 +118,7 @@ def sql_type(self) -> str: @property def hook(self) -> BigQueryHook: """Retrieve Airflow hook to interface with the BigQuery database.""" - return BigQueryHook(gcp_conn_id=self.conn_id, use_legacy_sql=False) + return BigQueryHook(gcp_conn_id=self.conn_id, use_legacy_sql=False, location=self.region) @property def sqlalchemy_engine(self) -> Engine: diff --git a/python-sdk/src/astro/databases/postgres.py b/python-sdk/src/astro/databases/postgres.py index 07e47e5e3..51c58b486 100644 --- a/python-sdk/src/astro/databases/postgres.py +++ b/python-sdk/src/astro/databases/postgres.py @@ -28,7 +28,9 @@ class PostgresDatabase(BaseDatabase): illegal_column_name_chars: list[str] = ["."] illegal_column_name_chars_replacement: list[str] = ["_"] - def __init__(self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None): + def __init__( + self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None, region: str | None = None + ): super().__init__(conn_id) self.table = table diff --git a/python-sdk/src/astro/databases/snowflake.py b/python-sdk/src/astro/databases/snowflake.py index dd6bc23bc..c3d69cd04 100644 --- a/python-sdk/src/astro/databases/snowflake.py +++ b/python-sdk/src/astro/databases/snowflake.py @@ -252,7 +252,9 @@ class SnowflakeDatabase(BaseDatabase): ) DEFAULT_SCHEMA = SNOWFLAKE_SCHEMA - def __init__(self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None): + def __init__( + self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None, region: str | None = None + ): super().__init__(conn_id) self.table = table diff --git a/python-sdk/src/astro/databases/sqlite.py b/python-sdk/src/astro/databases/sqlite.py index 36f9ae43e..42f913b08 100644 --- a/python-sdk/src/astro/databases/sqlite.py +++ b/python-sdk/src/astro/databases/sqlite.py @@ -20,7 +20,9 @@ class 
SqliteDatabase(BaseDatabase): logic in other parts of our code-base. """ - def __init__(self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None): + def __init__( + self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None, region: str | None = None + ): super().__init__(conn_id) self.table = table diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index ff5ae9d8e..8c2032d20 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -3,9 +3,11 @@ import pandas from airflow import AirflowException from airflow.decorators.base import get_unique_task_id +from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.common.sql.operators.sql import SQLColumnCheckOperator from astro.databases import create_database +from astro.settings import BIGQUERY_SCHEMA_LOCATION from astro.table import BaseTable from astro.utils.typing_compat import Context @@ -77,8 +79,27 @@ def __init__( task_id=task_id if task_id is not None else get_unique_task_id("column_check"), ) + def get_db_hook(self) -> DbApiHook: + """ + Get the database hook for the connection. + + :return: the database hook object. + """ + db = create_database( + conn_id=self.conn_id, region=self.dataset.metadata.region or BIGQUERY_SCHEMA_LOCATION + ) + if db.sql_type == "bigquery": + return db.hook + return super().get_db_hook() + def execute(self, context: "Context"): if isinstance(self.dataset, BaseTable): + # Work around for GoogleBaseHook not inheriting from DBApi + # db = create_database( + # conn_id=self.conn_id, region=self.dataset.metadata.region or BIGQUERY_SCHEMA_LOCATION + # ) + # if db.sql_type == "bigquery": + # self._hook = db.hook return super().execute(context=context) elif type(self.dataset) == pandas.DataFrame: self.df = self.dataset diff --git a/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py b/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py index 676e795f3..047cc858e 100644 --- a/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py +++ b/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py @@ -5,6 +5,7 @@ from astro import sql as aql from astro.constants import Database from astro.files import File +from astro.table import Table from tests.sql.operators import utils as test_utils CWD = pathlib.Path(__file__).parent @@ -20,6 +21,7 @@ { "database": Database.BIGQUERY, "file": File(path=str(CWD) + "/../../../data/data_validation.csv"), + "table": Table(conn_id="bigquery"), }, { "database": Database.POSTGRES, @@ -43,6 +45,7 @@ def test_column_check_operator_with_table_dataset(sample_dag, database_table_fix all the database we support. 
""" db, test_table = database_table_fixture + test_table.conn_id = "gcp_conn_project" with sample_dag: aql.ColumnCheckOperator( dataset=test_table, diff --git a/python-sdk/tests_integration/sql/operators/data_validation/test_SQLCheckOperator.py b/python-sdk/tests_integration/sql/operators/data_validation/test_SQLCheckOperator.py index b373efdfb..8ec3d598e 100644 --- a/python-sdk/tests_integration/sql/operators/data_validation/test_SQLCheckOperator.py +++ b/python-sdk/tests_integration/sql/operators/data_validation/test_SQLCheckOperator.py @@ -5,6 +5,7 @@ from astro import sql as aql from astro.constants import Database from astro.files import File +from astro.table import Table from tests.sql.operators import utils as test_utils CWD = pathlib.Path(__file__).parent @@ -20,6 +21,7 @@ { "database": Database.BIGQUERY, "file": File(path=str(CWD) + "/../../../data/homes_main.csv"), + "table": Table(conn_id="bigquery"), }, { "database": Database.POSTGRES, From 3db2eab6f8d61922f4a194b400f55f9793ab6241 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Tue, 20 Dec 2022 14:50:54 +0530 Subject: [PATCH 16/34] Revoke region changes --- python-sdk/src/astro/databases/__init__.py | 4 ++-- python-sdk/src/astro/databases/base.py | 2 +- python-sdk/src/astro/databases/google/bigquery.py | 7 ++----- python-sdk/src/astro/databases/postgres.py | 4 +--- python-sdk/src/astro/databases/snowflake.py | 4 +--- python-sdk/src/astro/databases/sqlite.py | 4 +--- 6 files changed, 8 insertions(+), 17 deletions(-) diff --git a/python-sdk/src/astro/databases/__init__.py b/python-sdk/src/astro/databases/__init__.py index aeed965ab..0b501d8ea 100644 --- a/python-sdk/src/astro/databases/__init__.py +++ b/python-sdk/src/astro/databases/__init__.py @@ -24,7 +24,7 @@ SUPPORTED_DATABASES = set(DEFAULT_CONN_TYPE_TO_MODULE_PATH.keys()) -def create_database(conn_id: str, table: BaseTable | None = None, region: str | None = None) -> BaseDatabase: +def create_database(conn_id: str, table: BaseTable | None = None) -> BaseDatabase: """ Given a conn_id, return the associated Database class. 
@@ -37,5 +37,5 @@ def create_database(conn_id: str, table: BaseTable | None = None, region: str | module_path = CONN_TYPE_TO_MODULE_PATH[conn_type] module = importlib.import_module(module_path) class_name = get_class_name(module_ref=module, suffix="Database") - database: BaseDatabase = getattr(module, class_name)(conn_id, table, region) + database: BaseDatabase = getattr(module, class_name)(conn_id, table) return database diff --git a/python-sdk/src/astro/databases/base.py b/python-sdk/src/astro/databases/base.py index 418d770de..7f139beeb 100644 --- a/python-sdk/src/astro/databases/base.py +++ b/python-sdk/src/astro/databases/base.py @@ -68,7 +68,7 @@ class BaseDatabase(ABC): NATIVE_AUTODETECT_SCHEMA_CONFIG: Mapping[FileLocation, Mapping[str, list[FileType] | Callable]] = {} FILE_PATTERN_BASED_AUTODETECT_SCHEMA_SUPPORTED: set[FileLocation] = set() - def __init__(self, conn_id: str, table: BaseTable | None = None, region: str | None = None): + def __init__(self, conn_id: str, table: BaseTable | None = None): self.conn_id = conn_id self.sql: str | ClauseElement = "" diff --git a/python-sdk/src/astro/databases/google/bigquery.py b/python-sdk/src/astro/databases/google/bigquery.py index e869c4d77..a96b3122f 100644 --- a/python-sdk/src/astro/databases/google/bigquery.py +++ b/python-sdk/src/astro/databases/google/bigquery.py @@ -104,12 +104,9 @@ class BigqueryDatabase(BaseDatabase): _create_schema_statement: str = "CREATE SCHEMA IF NOT EXISTS {} OPTIONS (location='{}')" - def __init__( - self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None, region: str | None = None - ): + def __init__(self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None): super().__init__(conn_id) self.table = table - self.region = region @property def sql_type(self) -> str: @@ -118,7 +115,7 @@ def sql_type(self) -> str: @property def hook(self) -> BigQueryHook: """Retrieve Airflow hook to interface with the BigQuery database.""" - return BigQueryHook(gcp_conn_id=self.conn_id, use_legacy_sql=False, location=self.region) + return BigQueryHook(gcp_conn_id=self.conn_id, use_legacy_sql=False) @property def sqlalchemy_engine(self) -> Engine: diff --git a/python-sdk/src/astro/databases/postgres.py b/python-sdk/src/astro/databases/postgres.py index 51c58b486..07e47e5e3 100644 --- a/python-sdk/src/astro/databases/postgres.py +++ b/python-sdk/src/astro/databases/postgres.py @@ -28,9 +28,7 @@ class PostgresDatabase(BaseDatabase): illegal_column_name_chars: list[str] = ["."] illegal_column_name_chars_replacement: list[str] = ["_"] - def __init__( - self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None, region: str | None = None - ): + def __init__(self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None): super().__init__(conn_id) self.table = table diff --git a/python-sdk/src/astro/databases/snowflake.py b/python-sdk/src/astro/databases/snowflake.py index c3d69cd04..dd6bc23bc 100644 --- a/python-sdk/src/astro/databases/snowflake.py +++ b/python-sdk/src/astro/databases/snowflake.py @@ -252,9 +252,7 @@ class SnowflakeDatabase(BaseDatabase): ) DEFAULT_SCHEMA = SNOWFLAKE_SCHEMA - def __init__( - self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None, region: str | None = None - ): + def __init__(self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None): super().__init__(conn_id) self.table = table diff --git a/python-sdk/src/astro/databases/sqlite.py b/python-sdk/src/astro/databases/sqlite.py index 42f913b08..36f9ae43e 100644 --- 
a/python-sdk/src/astro/databases/sqlite.py +++ b/python-sdk/src/astro/databases/sqlite.py @@ -20,9 +20,7 @@ class SqliteDatabase(BaseDatabase): logic in other parts of our code-base. """ - def __init__( - self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None, region: str | None = None - ): + def __init__(self, conn_id: str = DEFAULT_CONN_ID, table: BaseTable | None = None): super().__init__(conn_id) self.table = table From 12a9334f08506f682ace062080c33dc254321d52 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Tue, 20 Dec 2022 15:12:29 +0530 Subject: [PATCH 17/34] Add apache-airflow-providers-common-sql as dependency --- python-sdk/pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python-sdk/pyproject.toml b/python-sdk/pyproject.toml index 6ea67b356..5e2e338aa 100644 --- a/python-sdk/pyproject.toml +++ b/python-sdk/pyproject.toml @@ -24,6 +24,7 @@ dependencies = [ "python-frontmatter", "smart-open", "SQLAlchemy>=1.3.18", + "apache-airflow-providers-common-sql" ] keywords = ["airflow", "provider", "astronomer", "sql", "decorator", "task flow", "elt", "etl", "dag"] @@ -95,7 +96,7 @@ all = [ "protobuf<=3.20", # Google bigquery client require protobuf <= 3.20.0. We can remove the limitation when this limitation is removed "openlineage-airflow>=0.17.0", "apache-airflow-providers-microsoft-azure", - "azure-storage-blob", + "azure-storage-blob" ] doc = [ "myst-parser>=0.17", From bc037cb0debc285f4dbf7a5988a04487a31143a2 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Tue, 20 Dec 2022 15:13:04 +0530 Subject: [PATCH 18/34] Remove unwanted code --- .../sql/operators/data_validation/test_ColumnCheckOperator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py b/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py index 047cc858e..ad2b599e4 100644 --- a/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py +++ b/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py @@ -45,7 +45,6 @@ def test_column_check_operator_with_table_dataset(sample_dag, database_table_fix all the database we support. """ db, test_table = database_table_fixture - test_table.conn_id = "gcp_conn_project" with sample_dag: aql.ColumnCheckOperator( dataset=test_table, From 04d527e94bc92fa09af2c9d1cc66cc9462315ef7 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Tue, 20 Dec 2022 15:29:47 +0530 Subject: [PATCH 19/34] Remove unwanted codes --- .../sql/operators/data_validations/ColumnCheckOperator.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index 8c2032d20..c2222e974 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -7,7 +7,6 @@ from airflow.providers.common.sql.operators.sql import SQLColumnCheckOperator from astro.databases import create_database -from astro.settings import BIGQUERY_SCHEMA_LOCATION from astro.table import BaseTable from astro.utils.typing_compat import Context @@ -85,9 +84,7 @@ def get_db_hook(self) -> DbApiHook: :return: the database hook object. 
""" - db = create_database( - conn_id=self.conn_id, region=self.dataset.metadata.region or BIGQUERY_SCHEMA_LOCATION - ) + db = create_database(conn_id=self.conn_id) if db.sql_type == "bigquery": return db.hook return super().get_db_hook() From 52e382172e2e0cd58ac61a9005806f808ce65a56 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Tue, 20 Dec 2022 16:09:37 +0530 Subject: [PATCH 20/34] Add location --- python-sdk/src/astro/databases/google/bigquery.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-sdk/src/astro/databases/google/bigquery.py b/python-sdk/src/astro/databases/google/bigquery.py index a96b3122f..4e0f2f63e 100644 --- a/python-sdk/src/astro/databases/google/bigquery.py +++ b/python-sdk/src/astro/databases/google/bigquery.py @@ -115,7 +115,7 @@ def sql_type(self) -> str: @property def hook(self) -> BigQueryHook: """Retrieve Airflow hook to interface with the BigQuery database.""" - return BigQueryHook(gcp_conn_id=self.conn_id, use_legacy_sql=False) + return BigQueryHook(gcp_conn_id=self.conn_id, use_legacy_sql=False, location=BIGQUERY_SCHEMA_LOCATION) @property def sqlalchemy_engine(self) -> Engine: From a18e8d85746f1c291e16693c5b37b319dd46e1de Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Tue, 20 Dec 2022 16:12:03 +0530 Subject: [PATCH 21/34] Add google_cloud_platform connection --- .github/ci-test-connections.yaml | 6 ++++++ .../operators/data_validation/test_ColumnCheckOperator.py | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/.github/ci-test-connections.yaml b/.github/ci-test-connections.yaml index b2386bf8f..1e1369192 100644 --- a/.github/ci-test-connections.yaml +++ b/.github/ci-test-connections.yaml @@ -99,3 +99,9 @@ connections: description: null extra: connection_string: $AZURE_WASB_CONN_STRING + - conn_id: gcp_conn_project + conn_type: google_cloud_platform + description: null + extra: + project: "astronomer-dag-authoring" + project_id: "astronomer-dag-authoring" diff --git a/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py b/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py index ad2b599e4..54bf1b91e 100644 --- a/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py +++ b/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py @@ -21,7 +21,7 @@ { "database": Database.BIGQUERY, "file": File(path=str(CWD) + "/../../../data/data_validation.csv"), - "table": Table(conn_id="bigquery"), + "table": Table(conn_id="gcp_conn_project"), }, { "database": Database.POSTGRES, From 39051534b62d02a0200779957d8400a9e81e9ad7 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Tue, 20 Dec 2022 16:31:08 +0530 Subject: [PATCH 22/34] Update conn_id --- .../sql/operators/data_validation/test_SQLCheckOperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-sdk/tests_integration/sql/operators/data_validation/test_SQLCheckOperator.py b/python-sdk/tests_integration/sql/operators/data_validation/test_SQLCheckOperator.py index 8ec3d598e..fb8491bc8 100644 --- a/python-sdk/tests_integration/sql/operators/data_validation/test_SQLCheckOperator.py +++ b/python-sdk/tests_integration/sql/operators/data_validation/test_SQLCheckOperator.py @@ -21,7 +21,7 @@ { "database": Database.BIGQUERY, "file": File(path=str(CWD) + "/../../../data/homes_main.csv"), - "table": Table(conn_id="bigquery"), + "table": Table(conn_id="gcp_conn_project"), }, { "database": Database.POSTGRES, From 
a0e677312f6773f6fc8b91e5c4502d3d71bbe73a Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Tue, 20 Dec 2022 16:33:18 +0530 Subject: [PATCH 23/34] Change return type --- .../sql/operators/data_validations/ColumnCheckOperator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index c2222e974..5b6816574 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -176,14 +176,14 @@ def col_max(column_name: str, df: pandas.DataFrame) -> Optional[float]: """ Get the max value in dataframe column """ - return df[column_name].max() + return float(df[column_name].max()) @staticmethod def col_min(column_name: str, df: pandas.DataFrame) -> Optional[float]: """ Get the min value in dataframe column """ - return df[column_name].min() + return float(df[column_name].min()) def _get_failed_checks(checks, col=None): From 686a09d9eca314951351a60fd5fa79b65390d04c Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Tue, 20 Dec 2022 16:46:31 +0530 Subject: [PATCH 24/34] Updated hook --- .../operators/data_validations/SQLCheckOperator.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py index a6aa2599b..5fa70dd26 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/SQLCheckOperator.py @@ -2,11 +2,11 @@ from airflow.decorators.base import get_unique_task_id from airflow.models.xcom_arg import XComArg +from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.providers.common.sql.operators.sql import SQLTableCheckOperator from astro.databases import create_database from astro.table import BaseTable -from astro.utils.typing_compat import Context class SQLCheckOperator(SQLTableCheckOperator): @@ -54,6 +54,17 @@ def __init__( task_id=task_id or get_unique_task_id("sql_check"), ) + def get_db_hook(self) -> DbApiHook: + """ + Get the database hook for the connection. + + :return: the database hook object. 
+ """ + db = create_database(conn_id=self.conn_id) + if db.sql_type == "bigquery": + return db.hook + return super().get_db_hook() + def sql_check( dataset: BaseTable, From 2b91128ba42c52e2d46492e1ae5024dcb97e06c1 Mon Sep 17 00:00:00 2001 From: Utkarsh Sharma Date: Tue, 20 Dec 2022 16:47:17 +0530 Subject: [PATCH 25/34] Update python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py Co-authored-by: Felix Uellendall --- .../astro/sql/operators/data_validations/ColumnCheckOperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index 5b6816574..b18e14401 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -155,7 +155,7 @@ def col_null_check(column_name: str, df: pandas.DataFrame) -> Optional[int]: """ Count the total null values in a dataframe column """ - return list(df[column_name].isnull().values).count(True) + return df[column_name].isna().sum() @staticmethod def col_distinct_check(column_name: str, df: pandas.DataFrame) -> Optional[int]: From ec3b4fd4e8c77cefeae789c9351e83361aba6389 Mon Sep 17 00:00:00 2001 From: Utkarsh Sharma Date: Tue, 20 Dec 2022 16:47:23 +0530 Subject: [PATCH 26/34] Update python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py Co-authored-by: Felix Uellendall --- .../astro/sql/operators/data_validations/ColumnCheckOperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index b18e14401..f83532b35 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -162,7 +162,7 @@ def col_distinct_check(column_name: str, df: pandas.DataFrame) -> Optional[int]: """ Count the distinct value in a dataframe column """ - return len(df[column_name].unique()) + return df[column_name].nunique() @staticmethod def col_unique_check(column_name: str, df: pandas.DataFrame) -> Optional[int]: From cd828d7f5595902ac427b1036b34a7d1a8985ada Mon Sep 17 00:00:00 2001 From: Utkarsh Sharma Date: Tue, 20 Dec 2022 16:48:01 +0530 Subject: [PATCH 27/34] Update python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py Co-authored-by: Felix Uellendall --- .../astro/sql/operators/data_validations/ColumnCheckOperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index f83532b35..ef7d2aeb8 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -151,7 +151,7 @@ def process_checks(self): print(f"The following tests have passed:" f"\n{''.join(passed_tests)}") @staticmethod - def col_null_check(column_name: str, df: pandas.DataFrame) -> Optional[int]: + def col_null_check(column_name: str, df: pandas.DataFrame) -> int: """ Count the total null values in a dataframe column """ From d1d5c28b649a49439f9890440daae364e7981e71 Mon Sep 17 00:00:00 2001 From: Utkarsh Sharma Date: Tue, 20 Dec 2022 16:48:28 
+0530 Subject: [PATCH 28/34] Update python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py Co-authored-by: Felix Uellendall --- .../sql/operators/data_validations/ColumnCheckOperator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index ef7d2aeb8..dab8e9303 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -118,10 +118,10 @@ def get_check_result(self, check_name: str, column_name: str, df: pandas.DataFra "max": self.col_max, } return column_checks[check_name](column_name=column_name, df=df) - elif df is None: + if df is None: raise ValueError("Dataframe is None") - else: - raise ValueError(f"Dataframe is don't have column {column_name}") + if column_name not in df.columns: + raise ValueError(f"Dataframe doesn't have column {column_name}") def process_checks(self): """ From a0543ad3e777cf567ba770fd6e1fe37d9ac9445e Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Tue, 20 Dec 2022 17:07:16 +0530 Subject: [PATCH 29/34] Revert changes to check since nunique() don't count None --- .../astro/sql/operators/data_validations/ColumnCheckOperator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index dab8e9303..88a6097a3 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -162,7 +162,7 @@ def col_distinct_check(column_name: str, df: pandas.DataFrame) -> Optional[int]: """ Count the distinct value in a dataframe column """ - return df[column_name].nunique() + return len(df[column_name].unique()) @staticmethod def col_unique_check(column_name: str, df: pandas.DataFrame) -> Optional[int]: From 764e08a93198ddaf883ee92744109ea532784bfa Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Tue, 20 Dec 2022 17:34:21 +0530 Subject: [PATCH 30/34] Fix Deep Source --- python-sdk/src/astro/databases/base.py | 2 +- python-sdk/src/astro/sql/__init__.py | 10 ++++++++-- .../operators/data_validations/ColumnCheckOperator.py | 10 ++-------- .../data_validation/test_ColumnCheckOperator.py | 2 +- .../operators/data_validation/test_SQLCheckOperator.py | 2 +- 5 files changed, 13 insertions(+), 13 deletions(-) diff --git a/python-sdk/src/astro/databases/base.py b/python-sdk/src/astro/databases/base.py index 7f139beeb..4c12d345a 100644 --- a/python-sdk/src/astro/databases/base.py +++ b/python-sdk/src/astro/databases/base.py @@ -68,7 +68,7 @@ class BaseDatabase(ABC): NATIVE_AUTODETECT_SCHEMA_CONFIG: Mapping[FileLocation, Mapping[str, list[FileType] | Callable]] = {} FILE_PATTERN_BASED_AUTODETECT_SCHEMA_SUPPORTED: set[FileLocation] = set() - def __init__(self, conn_id: str, table: BaseTable | None = None): + def __init__(self, conn_id: str, table: BaseTable | None = None): # skipcq: PYL-W0613 self.conn_id = conn_id self.sql: str | ClauseElement = "" diff --git a/python-sdk/src/astro/sql/__init__.py b/python-sdk/src/astro/sql/__init__.py index 9b1a2ec68..f3f3b1cc2 100644 --- a/python-sdk/src/astro/sql/__init__.py +++ b/python-sdk/src/astro/sql/__init__.py @@ -4,8 +4,14 @@ from astro.sql.operators.append import AppendOperator, 
append from astro.sql.operators.cleanup import CleanupOperator, cleanup -from astro.sql.operators.data_validations.ColumnCheckOperator import ColumnCheckOperator, column_check -from astro.sql.operators.data_validations.SQLCheckOperator import SQLCheckOperator, sql_check +from astro.sql.operators.data_validations.ColumnCheckOperator import ( # skipcq: PY-W2000 + ColumnCheckOperator, + column_check, +) +from astro.sql.operators.data_validations.SQLCheckOperator import ( # skipcq: PY-W2000 + SQLCheckOperator, + sql_check, +) from astro.sql.operators.dataframe import DataframeOperator, dataframe from astro.sql.operators.drop import DropTableOperator, drop_table from astro.sql.operators.export_file import ExportFileOperator, export_file diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index 88a6097a3..6ba16351d 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -91,14 +91,8 @@ def get_db_hook(self) -> DbApiHook: def execute(self, context: "Context"): if isinstance(self.dataset, BaseTable): - # Work around for GoogleBaseHook not inheriting from DBApi - # db = create_database( - # conn_id=self.conn_id, region=self.dataset.metadata.region or BIGQUERY_SCHEMA_LOCATION - # ) - # if db.sql_type == "bigquery": - # self._hook = db.hook return super().execute(context=context) - elif type(self.dataset) == pandas.DataFrame: + elif type(self.dataset) is pandas.DataFrame: self.df = self.dataset else: raise ValueError("dataset can only be of type pandas.dataframe | Table object") @@ -155,7 +149,7 @@ def col_null_check(column_name: str, df: pandas.DataFrame) -> int: """ Count the total null values in a dataframe column """ - return df[column_name].isna().sum() + return int(df[column_name].isna().sum()) @staticmethod def col_distinct_check(column_name: str, df: pandas.DataFrame) -> Optional[int]: diff --git a/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py b/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py index 54bf1b91e..6f0609cd1 100644 --- a/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py +++ b/python-sdk/tests_integration/sql/operators/data_validation/test_ColumnCheckOperator.py @@ -44,7 +44,7 @@ def test_column_check_operator_with_table_dataset(sample_dag, database_table_fix Test column_check_operator with table dataset for all checks types and make sure the generated sql is working for all the database we support. """ - db, test_table = database_table_fixture + _, test_table = database_table_fixture with sample_dag: aql.ColumnCheckOperator( dataset=test_table, diff --git a/python-sdk/tests_integration/sql/operators/data_validation/test_SQLCheckOperator.py b/python-sdk/tests_integration/sql/operators/data_validation/test_SQLCheckOperator.py index fb8491bc8..d3401b6e1 100644 --- a/python-sdk/tests_integration/sql/operators/data_validation/test_SQLCheckOperator.py +++ b/python-sdk/tests_integration/sql/operators/data_validation/test_SQLCheckOperator.py @@ -44,7 +44,7 @@ def test_column_check_operator_with_table_dataset(sample_dag, database_table_fix Test column_check_operator with table dataset for all checks types and make sure the generated sql is working for all the database we support. 
""" - db, test_table = database_table_fixture + _, test_table = database_table_fixture with sample_dag: aql.SQLCheckOperator( dataset=test_table, From 9673af70e73c0a511c4b1274cd87d91ea81686fa Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Tue, 20 Dec 2022 18:21:18 +0530 Subject: [PATCH 31/34] Refactored code to remove duplication --- .../data_validations/ColumnCheckOperator.py | 22 +++++++------------ 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index 6ba16351d..a89975771 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -181,33 +181,27 @@ def col_min(column_name: str, df: pandas.DataFrame) -> Optional[float]: def _get_failed_checks(checks, col=None): - if col: - return [ - f"Column: {col}\nCheck: {check},\nCheck Values: {check_values}\n" - for check, check_values in checks.items() - if not check_values["success"] - ] return [ - f"\tCheck: {check},\n\tCheck Values: {check_values}\n" + f"{get_checks_string(checks, col)} {check_values}\n" for check, check_values in checks.items() if not check_values["success"] ] def _get_success_checks(checks, col=None): - if col: - return [ - f"Column: {col}\nCheck: {check},\nCheck Values: {check_values}\n" - for check, check_values in checks.items() - if check_values["success"] - ] return [ - f"\tCheck: {check},\n\tCheck Values: {check_values}\n" + f"{get_checks_string(checks, col)} {check_values}\n" for check, check_values in checks.items() if check_values["success"] ] +def get_checks_string(check, col): + if col: + return f"Column: {col}\nCheck: {check},\nCheck Values:" + return f"\tCheck: {check},\n\tCheck Values:" + + def column_check( dataset: Union[BaseTable, pandas.DataFrame], column_mapping: Dict[str, Dict[str, Any]], From 5864a073d54bba9b96c54ff17068e7791d134321 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Tue, 20 Dec 2022 18:32:24 +0530 Subject: [PATCH 32/34] Refactored code --- .../data_validations/ColumnCheckOperator.py | 48 +++---------------- 1 file changed, 7 insertions(+), 41 deletions(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index a89975771..05d39a928 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -105,13 +105,14 @@ def get_check_result(self, check_name: str, column_name: str, df: pandas.DataFra """ if df is not None and column_name in df.columns: column_checks = { - "null_check": self.col_null_check, - "distinct_check": self.col_distinct_check, - "unique_check": self.col_unique_check, - "min": self.col_min, - "max": self.col_max, + "null_check": lambda column_name, dataframe: int(df[column_name].isna().sum()), + "distinct_check": lambda column_name, dataframe: len(df[column_name].unique()), + "unique_check": lambda column_name, dataframe: len(df[column_name]) + - len(df[column_name].unique()), + "min": lambda column_name, dataframe: float(df[column_name].max()), + "max": lambda column_name, dataframe: float(df[column_name].min()), } - return column_checks[check_name](column_name=column_name, df=df) + return column_checks[check_name](column_name=column_name, dataframe=df) if df is None: raise 
ValueError("Dataframe is None") if column_name not in df.columns: @@ -144,41 +145,6 @@ def process_checks(self): if len(passed_tests) > 0: print(f"The following tests have passed:" f"\n{''.join(passed_tests)}") - @staticmethod - def col_null_check(column_name: str, df: pandas.DataFrame) -> int: - """ - Count the total null values in a dataframe column - """ - return int(df[column_name].isna().sum()) - - @staticmethod - def col_distinct_check(column_name: str, df: pandas.DataFrame) -> Optional[int]: - """ - Count the distinct value in a dataframe column - """ - return len(df[column_name].unique()) - - @staticmethod - def col_unique_check(column_name: str, df: pandas.DataFrame) -> Optional[int]: - """ - Count the unique value in a dataframe column - """ - return len(df[column_name]) - len(df[column_name].unique()) - - @staticmethod - def col_max(column_name: str, df: pandas.DataFrame) -> Optional[float]: - """ - Get the max value in dataframe column - """ - return float(df[column_name].max()) - - @staticmethod - def col_min(column_name: str, df: pandas.DataFrame) -> Optional[float]: - """ - Get the min value in dataframe column - """ - return float(df[column_name].min()) - def _get_failed_checks(checks, col=None): return [ From 493d65115b0bda1918b3d0549454fd0f6e885606 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Tue, 20 Dec 2022 18:50:50 +0530 Subject: [PATCH 33/34] Code refactor --- .../data_validations/ColumnCheckOperator.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index 05d39a928..1eb46ac29 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -92,7 +92,7 @@ def get_db_hook(self) -> DbApiHook: def execute(self, context: "Context"): if isinstance(self.dataset, BaseTable): return super().execute(context=context) - elif type(self.dataset) is pandas.DataFrame: + elif isinstance(self.dataset, pandas.DataFrame): self.df = self.dataset else: raise ValueError("dataset can only be of type pandas.dataframe | Table object") @@ -105,14 +105,13 @@ def get_check_result(self, check_name: str, column_name: str, df: pandas.DataFra """ if df is not None and column_name in df.columns: column_checks = { - "null_check": lambda column_name, dataframe: int(df[column_name].isna().sum()), - "distinct_check": lambda column_name, dataframe: len(df[column_name].unique()), - "unique_check": lambda column_name, dataframe: len(df[column_name]) - - len(df[column_name].unique()), - "min": lambda column_name, dataframe: float(df[column_name].max()), - "max": lambda column_name, dataframe: float(df[column_name].min()), + "null_check": lambda column_name: int(df[column_name].isna().sum()), + "distinct_check": lambda column_name: len(df[column_name].unique()), + "unique_check": lambda column_name: len(df[column_name]) - len(df[column_name].unique()), + "min": lambda column_name: float(df[column_name].min()), + "max": lambda column_name: float(df[column_name].max()), } - return column_checks[check_name](column_name=column_name, dataframe=df) + return column_checks[check_name](column_name=column_name) if df is None: raise ValueError("Dataframe is None") if column_name not in df.columns: From 8c71f394a35cf0229bca783f50ae98e86a5a4d34 Mon Sep 17 00:00:00 2001 From: utkarsh sharma Date: Tue, 20 Dec 2022 19:23:59 +0530 
Subject: [PATCH 34/34] Refactored ColumnCheckOperator operator --- .../data_validations/ColumnCheckOperator.py | 36 +++++++++---------- 1 file changed, 17 insertions(+), 19 deletions(-) diff --git a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py index 1eb46ac29..e258d526f 100644 --- a/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py +++ b/python-sdk/src/astro/sql/operators/data_validations/ColumnCheckOperator.py @@ -99,22 +99,22 @@ def execute(self, context: "Context"): self.process_checks() - def get_check_result(self, check_name: str, column_name: str, df: pandas.DataFrame): + def get_check_result(self, check_name: str, column_name: str): """ Get the check method results post validating the dataframe """ - if df is not None and column_name in df.columns: + if self.df is not None and column_name in self.df.columns: column_checks = { - "null_check": lambda column_name: int(df[column_name].isna().sum()), - "distinct_check": lambda column_name: len(df[column_name].unique()), - "unique_check": lambda column_name: len(df[column_name]) - len(df[column_name].unique()), - "min": lambda column_name: float(df[column_name].min()), - "max": lambda column_name: float(df[column_name].max()), + "null_check": lambda column: column.isna().sum(), + "distinct_check": lambda column: len(column.unique()), + "unique_check": lambda column: len(column) - len(column.unique()), + "min": lambda column: column.min(), + "max": lambda column: column.max(), } - return column_checks[check_name](column_name=column_name) - if df is None: + return column_checks[check_name](column=self.df[column_name]) + if self.df is None: raise ValueError("Dataframe is None") - if column_name not in df.columns: + if column_name not in self.df.columns: raise ValueError(f"Dataframe doesn't have column {column_name}") def process_checks(self): @@ -129,15 +129,13 @@ def process_checks(self): checks = self.column_mapping[column] # Iterating over checks - for check in checks: - tolerance = self.column_mapping[column][check].get("tolerance") - result = self.get_check_result(check, column_name=column, df=self.df) - self.column_mapping[column][check]["result"] = result - self.column_mapping[column][check]["success"] = self._get_match( - self.column_mapping[column][check], result, tolerance - ) - failed_tests.extend(_get_failed_checks(self.column_mapping[column], column)) - passed_tests.extend(_get_success_checks(self.column_mapping[column], column)) + for check_key, check_val in checks.items(): + tolerance = check_val.get("tolerance") + result = self.get_check_result(check_key, column_name=column) + check_val["result"] = result + check_val["success"] = self._get_match(check_val, result, tolerance) + failed_tests.extend(_get_failed_checks(checks, column)) + passed_tests.extend(_get_success_checks(checks, column)) if len(failed_tests) > 0: raise AirflowException(f"The following tests have failed:" f"\n{''.join(failed_tests)}")
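To make the dataframe code path concrete after the final refactor, here is a small worked example of the five check lambdas as they stand in patch 34. The sample data is made up; note that `unique()` keeps NaN/None, which is why the tests above comment that nulls are treated as values in `distinct_check`:

```
# Worked example of the final check lambdas against a made-up dataframe.
import pandas as pd

df = pd.DataFrame(
    {
        "name": ["mike", "theo", "gabe"],
        "city": [None, "SF", "SF"],
        "emp_id": [1, 2, 3],
    }
)

column_checks = {
    "null_check": lambda column: column.isna().sum(),
    "distinct_check": lambda column: len(column.unique()),  # NaN counts as a value
    "unique_check": lambda column: len(column) - len(column.unique()),
    "min": lambda column: column.min(),
    "max": lambda column: column.max(),
}

assert column_checks["null_check"](df["city"]) == 1  # one missing city
assert column_checks["distinct_check"](df["city"]) == 2  # {None, "SF"}
assert column_checks["unique_check"](df["name"]) == 0  # no duplicate names
assert column_checks["min"](df["emp_id"]) == 1
assert column_checks["max"](df["emp_id"]) == 3
```

One subtlety when reading the test expectations earlier in the series: pass/fail is still decided by the inherited `SQLColumnCheckOperator._get_match`, which, as implemented upstream, applies `tolerance` multiplicatively — a result satisfies `equal_to=v` with `tolerance=t` when it falls within `v * (1 - t)` to `v * (1 + t)`. For `equal_to=1, tolerance=1` that is the 0-to-2 window the inline test comments describe; the "plus and minus the value provided" reading coincides with the multiplicative one only because the expected value there is 1.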