diff --git a/README.md b/README.md index a2968fecc..50f958464 100644 --- a/README.md +++ b/README.md @@ -156,17 +156,16 @@ In addition, please note that SHA256 is not a supported function on Teradata sys If you wish to perform this comparison on Teradata you will need to [deploy a UDF to perform the conversion](https://github.com/akuroda/teradata-udf-sha2/blob/master/src/sha256.c).) -Below is the command syntax for row validations. In order to run row level -validations you need to pass a `--primary-key` flag which defines what field(s) -the validation will be compared on, as well as either the `--comparison-fields` flag -or the `--hash` flag. See *Primary Keys* section +Below is the command syntax for row validations. In order to run row level validations we require +unique columns to join row sets, which are either inferred from the source/target table or provided +via the `--primary-keys` flag, and either the `--hash`, `--concat` or `--comparison-fields` flags. +See *Primary Keys* section. The `--comparison-fields` flag specifies the values (e.g. columns) whose raw values will be compared based on the primary key join. The `--hash` flag will run a checksum across specified columns in the table. This will include casting to string, sanitizing the data (ifnull, rtrim, upper), concatenating, and finally hashing the row. - Under the hood, row validation uses [Calculated Fields](https://github.com/GoogleCloudPlatform/professional-services-data-validator#calculated-fields) to apply functions such as IFNULL() or RTRIM(). These can be edited in the YAML or JSON config file to customize your row validation. @@ -188,13 +187,14 @@ data-validation Comma separated list of tables in the form schema.table=target_schema.target_table Target schema name and table name are optional. i.e 'bigquery-public-data.new_york_citibike.citibike_trips' - --primary-keys or -pk PRIMARY_KEYS - Comma separated list of columns to use as primary keys. See *Primary Keys* section --comparison-fields or -comp-fields FIELDS Comma separated list of columns to compare. Can either be a physical column or an alias See: *Calculated Fields* section for details --hash COLUMNS Comma separated list of columns to hash or * for all columns --concat COLUMNS Comma separated list of columns to concatenate or * for all columns (use if a common hash function is not available between databases) + [--primary-keys PRIMARY_KEYS, -pk PRIMARY_KEYS] + Comma separated list of primary key columns, when not specified the value will be inferred + from the source or target table if available. See *Primary Keys* section [--exclude-columns or -ec] Flag to indicate the list of columns provided should be excluded from hash or concat instead of included. [--bq-result-handler or -bqrh PROJECT_ID.DATASET.TABLE] @@ -262,8 +262,6 @@ data-validation Either --tables-list or --source-query (or file) and --target-query (or file) must be provided --target-query-file TARGET_QUERY_FILE, -tqf TARGET_QUERY_FILE File containing the target sql command. Supports GCS and local paths. - --primary-keys PRIMARY_KEYS, -pk PRIMARY_KEYS - Comma separated list of primary key columns 'col_a,col_b'. See *Primary Keys* section --comparison-fields or -comp-fields FIELDS Comma separated list of columns to compare. Can either be a physical column or an alias See: *Calculated Fields* section for details @@ -277,6 +275,9 @@ data-validation --partition-num INT, -pn INT Number of partitions into which the table should be split, e.g. 
1000 or 10000 In case this value exceeds the row count of the source/target table, it will be decreased to max(source_row_count, target_row_count) + [--primary-keys PRIMARY_KEYS, -pk PRIMARY_KEYS] + Comma separated list of primary key columns, when not specified the value will be inferred + from the source or target table if available. See *Primary Keys* section [--bq-result-handler or -bqrh PROJECT_ID.DATASET.TABLE] BigQuery destination for validation results. Defaults to stdout. See: *Validation Reports* section @@ -448,8 +449,8 @@ data-validation --hash '*' '*' to hash all columns. --concat COLUMNS Comma separated list of columns to concatenate or * for all columns (use if a common hash function is not available between databases) - --primary-key or -pk JOIN_KEY - Common column between source and target tables for join + [--primary-keys PRIMARY_KEYS, -pk PRIMARY_KEYS] + Common column between source and target queries for join [--exclude-columns or -ec] Flag to indicate the list of columns provided should be excluded from hash or concat instead of included. [--bq-result-handler or -bqrh PROJECT_ID.DATASET.TABLE] @@ -679,6 +680,8 @@ In many cases, validations (e.g. count, min, max etc) produce one row per table. and target table is to compare the value for each column in the source with the value of the column in the target. `grouped-columns` validation and `validate row` produce multiple rows per table. Data Validation Tool needs one or more columns to uniquely identify each row so the source and target can be compared. Data Validation Tool refers to these columns as primary keys. These do not need to be primary keys in the table. The only requirement is that the keys uniquely identify the row in the results. +These columns are inferred, where possible, from the source/target table or can be provided via the `--primary-keys` flag. + ### Grouped Columns Grouped Columns contain the fields you want your aggregations to be broken out diff --git a/data_validation/__main__.py b/data_validation/__main__.py index 0cb8d48c8..4a9535d72 100644 --- a/data_validation/__main__.py +++ b/data_validation/__main__.py @@ -219,7 +219,9 @@ def _get_calculated_config(args, config_manager: ConfigManager) -> List[dict]: return calculated_configs -def _get_comparison_config(args, config_manager: ConfigManager) -> List[dict]: +def _get_comparison_config( + args, config_manager: ConfigManager, primary_keys: list +) -> List[dict]: col_list = ( None if args.comparison_fields == "*" @@ -230,11 +232,7 @@ def _get_comparison_config(args, config_manager: ConfigManager) -> List[dict]: args.exclude_columns, ) # We can't have the PK columns in the comparison SQL twice therefore filter them out here if included. 
- comparison_fields = [ - _ - for _ in comparison_fields - if _ not in cli_tools.get_arg_list(args.primary_keys.casefold()) - ] + comparison_fields = [_ for _ in comparison_fields if _ not in primary_keys] # As per #1190, add rstrip for Teradata string comparison fields if ( @@ -314,18 +312,25 @@ def build_config_from_args(args: Namespace, config_manager: ConfigManager): _get_calculated_config(args, config_manager) ) - # Append Comparison fields - if args.comparison_fields: - config_manager.append_comparison_fields( - _get_comparison_config(args, config_manager) - ) - # Append primary_keys primary_keys = cli_tools.get_arg_list(args.primary_keys) + if not primary_keys and config_manager.validation_type != consts.CUSTOM_QUERY: + primary_keys = config_manager.auto_list_primary_keys() + if not primary_keys: + raise ValueError( + "No primary keys were provided and neither the source or target tables have primary keys. Please include --primary-keys argument" + ) + primary_keys = [_.casefold() for _ in primary_keys] config_manager.append_primary_keys( config_manager.build_column_configs(primary_keys) ) + # Append Comparison fields + if args.comparison_fields: + config_manager.append_comparison_fields( + _get_comparison_config(args, config_manager, primary_keys) + ) + return config_manager diff --git a/data_validation/cli_tools.py b/data_validation/cli_tools.py index e44bb93fc..bd5b43f63 100644 --- a/data_validation/cli_tools.py +++ b/data_validation/cli_tools.py @@ -170,6 +170,12 @@ ], } +VALIDATE_HELP_TEXT = "Run a validation and optionally store to config" +VALIDATE_COLUMN_HELP_TEXT = "Run a column validation" +VALIDATE_ROW_HELP_TEXT = "Run a row validation" +VALIDATE_SCHEMA_HELP_TEXT = "Run a schema validation" +VALIDATE_CUSTOM_QUERY_HELP_TEXT = "Run a custom query validation" + def _check_custom_query_args(parser: argparse.ArgumentParser, parsed_args: Namespace): # This is where we make additional checks if the arguments provided are what we expect @@ -471,9 +477,7 @@ def _configure_database_specific_parsers(parser): def _configure_validate_parser(subparsers): """Configure arguments to run validations.""" - validate_parser = subparsers.add_parser( - "validate", help="Run a validation and optionally store to config" - ) + validate_parser = subparsers.add_parser("validate", help=VALIDATE_HELP_TEXT) validate_parser.add_argument( "--dry-run", @@ -485,22 +489,22 @@ def _configure_validate_parser(subparsers): validate_subparsers = validate_parser.add_subparsers(dest="validate_cmd") column_parser = validate_subparsers.add_parser( - "column", help="Run a column validation" + "column", help=VALIDATE_COLUMN_HELP_TEXT ) _configure_column_parser(column_parser) - row_parser = validate_subparsers.add_parser("row", help="Run a row validation") + row_parser = validate_subparsers.add_parser("row", help=VALIDATE_ROW_HELP_TEXT) optional_arguments = row_parser.add_argument_group("optional arguments") required_arguments = row_parser.add_argument_group("required arguments") _configure_row_parser(row_parser, optional_arguments, required_arguments) schema_parser = validate_subparsers.add_parser( - "schema", help="Run a schema validation" + "schema", help=VALIDATE_SCHEMA_HELP_TEXT ) _configure_schema_parser(schema_parser) custom_query_parser = validate_subparsers.add_parser( - "custom-query", help="Run a custom query validation" + "custom-query", help=VALIDATE_CUSTOM_QUERY_HELP_TEXT ) _configure_custom_query_parser(custom_query_parser) @@ -514,6 +518,15 @@ def _configure_row_parser( ): """Configure arguments to run 
row level validations.""" # Group optional arguments + optional_arguments.add_argument( + "--primary-keys", + "-pk", + help=( + "Comma separated list of primary key columns 'col_a,col_b', " + "when not specified the value will be inferred from the source or target table if available" + ), + ) + optional_arguments.add_argument( "--threshold", "-th", @@ -586,14 +599,6 @@ def _configure_row_parser( help="Comma separated tables list in the form 'schema.table=target_schema.target_table'", ) - # Group required arguments - required_arguments.add_argument( - "--primary-keys", - "-pk", - required=True, - help="Comma separated list of primary key columns 'col_a,col_b'", - ) - # Group for mutually exclusive required arguments. Either must be supplied mutually_exclusive_arguments = required_arguments.add_mutually_exclusive_group( required=True diff --git a/data_validation/config_manager.py b/data_validation/config_manager.py index ed831b033..2e300966d 100644 --- a/data_validation/config_manager.py +++ b/data_validation/config_manager.py @@ -1177,3 +1177,20 @@ def build_comp_fields(self, col_list: list, exclude_cols: bool = False) -> dict: ) return casefold_source_columns + + def auto_list_primary_keys(self) -> list: + """Returns a list of primary key columns based on the source/target table. + + If neither source nor target systems have a primary key defined then [] is returned. + """ + assert ( + self.validation_type != consts.CUSTOM_QUERY + ), "Custom query validations should not be able to reach this method" + primary_keys = self.source_client.list_primary_key_columns( + self.source_schema, self.source_table + ) + if not primary_keys: + primary_keys = self.target_client.list_primary_key_columns( + self.target_schema, self.target_table + ) + return primary_keys or [] diff --git a/tests/resources/snowflake_test_tables.sql b/tests/resources/snowflake_test_tables.sql index 8802bb8a4..a122a948f 100644 --- a/tests/resources/snowflake_test_tables.sql +++ b/tests/resources/snowflake_test_tables.sql @@ -13,7 +13,7 @@ -- limitations under the License. CREATE OR REPLACE TABLE PSO_DATA_VALIDATOR.PUBLIC.DVT_CORE_TYPES ( - ID INT NOT NULL, + ID INT NOT NULL PRIMARY KEY, COL_INT8 TINYINT, COL_INT16 SMALLINT, COL_INT32 INT, diff --git a/tests/system/data_sources/common_functions.py b/tests/system/data_sources/common_functions.py index 9342716bf..53f246424 100644 --- a/tests/system/data_sources/common_functions.py +++ b/tests/system/data_sources/common_functions.py @@ -292,7 +292,7 @@ def row_validation_test( f"-tc={tc}", f"-tbls={tables}", f"--filters={filters}", - f"--primary-keys={primary_keys}", + f"--primary-keys={primary_keys}" if primary_keys else None, "--filter-status=fail", f"--comparison-fields={comp_fields}" if comp_fields else f"--hash={hash}", "--use-random-row" if use_randow_row else None, diff --git a/tests/system/data_sources/test_bigquery.py b/tests/system/data_sources/test_bigquery.py index 3610b1d13..47cfe1a97 100644 --- a/tests/system/data_sources/test_bigquery.py +++ b/tests/system/data_sources/test_bigquery.py @@ -1245,6 +1245,24 @@ def test_row_validation_core_types(mock_conn): ) +@mock.patch( + "data_validation.state_manager.StateManager.get_connection_config", + return_value=BQ_CONN, +) +def test_row_validation_core_types_auto_pks(mock_conn): + """Test auto population of -pks from BigQuery - expect an exception. 
+ + Expects: + ValueError: --primary-keys argument is required for this validation + """ + with pytest.raises(ValueError): + row_validation_test( + tc="mock-conn", + hash="col_int8,col_int16", + primary_keys=None, + ) + + @mock.patch( "data_validation.state_manager.StateManager.get_connection_config", return_value=BQ_CONN, diff --git a/tests/system/data_sources/test_db2.py b/tests/system/data_sources/test_db2.py index b2cf860c8..9a605b074 100644 --- a/tests/system/data_sources/test_db2.py +++ b/tests/system/data_sources/test_db2.py @@ -152,6 +152,20 @@ def test_row_validation_core_types(): ) +@mock.patch( + "data_validation.state_manager.StateManager.get_connection_config", + new=mock_get_connection_config, +) +def test_row_validation_core_types_auto_pks(): + """Test auto population of -pks from DB2 defined constraint.""" + row_validation_test( + tables="db2inst1.dvt_core_types", + tc="mock-conn", + hash="col_string", + primary_keys=None, + ) + + @mock.patch( "data_validation.state_manager.StateManager.get_connection_config", new=mock_get_connection_config, diff --git a/tests/system/data_sources/test_mysql.py b/tests/system/data_sources/test_mysql.py index cb73beb9d..e13f45aaf 100644 --- a/tests/system/data_sources/test_mysql.py +++ b/tests/system/data_sources/test_mysql.py @@ -258,6 +258,19 @@ def test_row_validation_core_types(): ) +@mock.patch( + "data_validation.state_manager.StateManager.get_connection_config", + new=mock_get_connection_config, +) +def test_row_validation_core_types_auto_pks(): + """Test auto population of -pks from MySQL defined constraint.""" + row_validation_test( + tc="mock-conn", + hash="col_int8,col_int16", + primary_keys=None, + ) + + @mock.patch( "data_validation.state_manager.StateManager.get_connection_config", new=mock_get_connection_config, diff --git a/tests/system/data_sources/test_oracle.py b/tests/system/data_sources/test_oracle.py index e75cdbeb9..049917f03 100644 --- a/tests/system/data_sources/test_oracle.py +++ b/tests/system/data_sources/test_oracle.py @@ -325,6 +325,19 @@ def test_row_validation_core_types(): ) +@mock.patch( + "data_validation.state_manager.StateManager.get_connection_config", + new=mock_get_connection_config, +) +def test_row_validation_core_types_auto_pks(): + """Test auto population of -pks from Oracle defined constraint.""" + row_validation_test( + tc="mock-conn", + hash="col_int8,col_int16", + primary_keys=None, + ) + + @mock.patch( "data_validation.state_manager.StateManager.get_connection_config", new=mock_get_connection_config, diff --git a/tests/system/data_sources/test_postgres.py b/tests/system/data_sources/test_postgres.py index e256f6842..4acf64b72 100644 --- a/tests/system/data_sources/test_postgres.py +++ b/tests/system/data_sources/test_postgres.py @@ -682,12 +682,26 @@ def test_row_validation_pg_types(): ) +@mock.patch( + "data_validation.state_manager.StateManager.get_connection_config", + new=mock_get_connection_config, +) +def test_row_validation_core_types_auto_pks(): + """Test auto population of -pks from PostgreSQL defined constraint.""" + row_validation_test( + tables="pso_data_validator.dvt_core_types", + tc="mock-conn", + hash="col_int8,col_int16", + primary_keys=None, + ) + + @mock.patch( "data_validation.state_manager.StateManager.get_connection_config", new=mock_get_connection_config, ) def test_row_validation_comp_fields_pg_types(): - """PostgreSQL to PostgreSQL dvt_core_types row validation with --comp-fields""" + """PostgreSQL to PostgreSQL dvt_pg_types row validation --comp-fields""" 
row_validation_test( tables="pso_data_validator.dvt_pg_types", tc="mock-conn", diff --git a/tests/system/data_sources/test_snowflake.py b/tests/system/data_sources/test_snowflake.py index a71bd97e1..89e22f5be 100644 --- a/tests/system/data_sources/test_snowflake.py +++ b/tests/system/data_sources/test_snowflake.py @@ -264,6 +264,20 @@ def test_row_validation_core_types(): ) +@mock.patch( + "data_validation.state_manager.StateManager.get_connection_config", + new=mock_get_connection_config, +) +def test_row_validation_core_types_auto_pks(): + """Test auto population of -pks from Snowflake defined constraint.""" + row_validation_test( + tables="PSO_DATA_VALIDATOR.PUBLIC.DVT_CORE_TYPES", + tc="mock-conn", + hash="col_int8,col_int16", + primary_keys=None, + ) + + @mock.patch( "data_validation.state_manager.StateManager.get_connection_config", new=mock_get_connection_config, diff --git a/tests/system/data_sources/test_sql_server.py b/tests/system/data_sources/test_sql_server.py index b23cf0d43..6e3a8d20a 100644 --- a/tests/system/data_sources/test_sql_server.py +++ b/tests/system/data_sources/test_sql_server.py @@ -351,6 +351,19 @@ def test_row_validation_core_types(): ) +@mock.patch( + "data_validation.state_manager.StateManager.get_connection_config", + new=mock_get_connection_config, +) +def test_row_validation_core_types_auto_pks(): + """Test auto population of -pks from SQL Server defined constraint.""" + row_validation_test( + tc="mock-conn", + hash="col_int8,col_int16", + primary_keys=None, + ) + + @mock.patch( "data_validation.state_manager.StateManager.get_connection_config", new=mock_get_connection_config, diff --git a/tests/system/data_sources/test_teradata.py b/tests/system/data_sources/test_teradata.py index 5cafaf063..677684d20 100644 --- a/tests/system/data_sources/test_teradata.py +++ b/tests/system/data_sources/test_teradata.py @@ -334,6 +334,22 @@ def test_row_validation_core_types(): ) +@mock.patch( + "data_validation.state_manager.StateManager.get_connection_config", + new=mock_get_connection_config, +) +def test_row_validation_core_types_auto_pks(): + """Test auto population of -pks from Teradata defined constraint. 
+ + Tests this with comp-fields, some other engines test with hash validation.""" + row_validation_test( + tables="udf.dvt_core_types", + tc="mock-conn", + comp_fields="col_int8,col_int16", + primary_keys=None, + ) + + # Expected result from partitioning table on 3 keys EXPECTED_PARTITION_FILTER = [ [ diff --git a/tests/unit/test_cli_tools.py b/tests/unit/test_cli_tools.py index 5188528a8..e1a2544ef 100644 --- a/tests/unit/test_cli_tools.py +++ b/tests/unit/test_cli_tools.py @@ -608,3 +608,79 @@ def test_get_query_from_inline(test_input: str, expect_exception: bool): else: query = cli_tools.get_query_from_inline(test_input) assert query in test_input + + +def test_arg_parser_help(capsys): + """Test --help arg.""" + parser = cli_tools.configure_arg_parser() + with pytest.raises(SystemExit): + _ = parser.parse_args(["--help"]) + captured = capsys.readouterr() + assert cli_tools.VALIDATE_HELP_TEXT in captured.out + + +def test_arg_parser_validate_help(capsys): + """Test validate --help arg.""" + parser = cli_tools.configure_arg_parser() + with pytest.raises(SystemExit): + _ = parser.parse_args(["validate", "--help"]) + captured = capsys.readouterr() + assert cli_tools.VALIDATE_COLUMN_HELP_TEXT in captured.out + assert cli_tools.VALIDATE_ROW_HELP_TEXT in captured.out + assert cli_tools.VALIDATE_SCHEMA_HELP_TEXT in captured.out + assert cli_tools.VALIDATE_CUSTOM_QUERY_HELP_TEXT in captured.out + + +def test_arg_parser_validate_column_help(capsys): + """Test validate column --help arg.""" + parser = cli_tools.configure_arg_parser() + with pytest.raises(SystemExit): + _ = parser.parse_args(["validate", "column", "--help"]) + captured = capsys.readouterr() + assert "--sum" in captured.out + assert "--hash" not in captured.out + assert "--source-query" not in captured.out + assert "--primary-keys" not in captured.out + + +def test_arg_parser_validate_row_help(capsys): + """Test validate row --help arg.""" + parser = cli_tools.configure_arg_parser() + with pytest.raises(SystemExit): + _ = parser.parse_args(["validate", "row", "--help"]) + captured = capsys.readouterr() + assert "--hash" in captured.out + assert "--source-query" not in captured.out + assert "--primary-keys" in captured.out + + +def test_arg_parser_validate_schema_help(capsys): + """Test validate column --help arg.""" + parser = cli_tools.configure_arg_parser() + with pytest.raises(SystemExit): + _ = parser.parse_args(["validate", "column", "--help"]) + captured = capsys.readouterr() + assert "--sum" in captured.out + assert "--hash" not in captured.out + assert "--source-query" not in captured.out + assert "--primary-keys" not in captured.out + + +def test_arg_parser_validate_custom_query_row_help(capsys): + """Test validate custom-query row --help arg.""" + parser = cli_tools.configure_arg_parser() + with pytest.raises(SystemExit): + _ = parser.parse_args(["validate", "custom-query", "row", "--help"]) + captured = capsys.readouterr() + assert "--hash" in captured.out + assert "--source-query" in captured.out + assert "--primary-keys" in captured.out + + +def test_arg_parser_generate_table_partitions_help(capsys): + """Test generate-table-partitions --help arg.""" + parser = cli_tools.configure_arg_parser() + with pytest.raises(SystemExit): + _ = parser.parse_args(["generate-table-partitions", "--help"]) + captured = capsys.readouterr() + assert "--partition-num" in captured.out diff --git a/third_party/ibis/ibis_addon/operations.py b/third_party/ibis/ibis_addon/operations.py index 3189543e8..733fd1da2 100644 --- 
a/third_party/ibis/ibis_addon/operations.py +++ b/third_party/ibis/ibis_addon/operations.py @@ -69,6 +69,7 @@ from ibis.expr.types import BinaryValue, NumericValue, TemporalValue # Do not remove these lines, they trigger patching of Ibis code. +import third_party.ibis.ibis_biquery.api # noqa import third_party.ibis.ibis_mysql.compiler # noqa from third_party.ibis.ibis_mssql.registry import mssql_table_column import third_party.ibis.ibis_postgres.client # noqa diff --git a/third_party/ibis/ibis_biquery/__init__.py b/third_party/ibis/ibis_biquery/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/third_party/ibis/ibis_biquery/api.py b/third_party/ibis/ibis_biquery/api.py new file mode 100644 index 000000000..bf9b48c14 --- /dev/null +++ b/third_party/ibis/ibis_biquery/api.py @@ -0,0 +1,10 @@ +from ibis.backends.bigquery import Backend as BigQueryBackend + + +def _list_primary_key_columns(self, database: str, table: str) -> list: + """Return a list of primary key column names.""" + # TODO: Related to issue-1253, it's not clear if this is possible, we should revisit if it becomes a requirement. + return None + + +BigQueryBackend.list_primary_key_columns = _list_primary_key_columns diff --git a/third_party/ibis/ibis_cloud_spanner/__init__.py b/third_party/ibis/ibis_cloud_spanner/__init__.py index 7a9f7bdd8..08dd25f1f 100644 --- a/third_party/ibis/ibis_cloud_spanner/__init__.py +++ b/third_party/ibis/ibis_cloud_spanner/__init__.py @@ -214,6 +214,11 @@ def drop_view(): def fetch_from_cursor(): pass + def list_primary_key_columns(self, database: str, table: str) -> list: + """Return a list of primary key column names.""" + # TODO: Related to issue-1253, it's not clear if this is possible, we should revisit if it becomes a requirement. + return None + def parse_instance_and_dataset( instance: str, dataset: Optional[str] = None diff --git a/third_party/ibis/ibis_db2/__init__.py b/third_party/ibis/ibis_db2/__init__.py index f0a4650ce..107204b10 100644 --- a/third_party/ibis/ibis_db2/__init__.py +++ b/third_party/ibis/ibis_db2/__init__.py @@ -84,3 +84,20 @@ def _metadata(self, query) -> Iterable[Tuple[str, dt.DataType]]: (column[0].lower(), _get_type(column[1])) for column in cursor.description ) + + def list_primary_key_columns(self, database: str, table: str) -> list: + """Return a list of primary key column names.""" + list_pk_col_sql = """ + SELECT key.colname + FROM syscat.tables tab + INNER JOIN syscat.tabconst const ON const.tabschema = tab.tabschema AND const.tabname = tab.tabname and const.type = 'P' + INNER JOIN syscat.keycoluse key ON const.tabschema = key.tabschema AND const.tabname = key.tabname AND const.constname = key.constname + WHERE tab.type = 'T' + AND tab.tabschema = ? + AND tab.tabname = ? 
+ ORDER BY key.colseq""" + with self.begin() as con: + result = con.exec_driver_sql( + list_pk_col_sql, parameters=(database.upper(), table.upper()) + ) + return [_[0] for _ in result.cursor.fetchall()] diff --git a/third_party/ibis/ibis_impala/api.py b/third_party/ibis/ibis_impala/api.py index 81bb6e0ee..e432a3968 100644 --- a/third_party/ibis/ibis_impala/api.py +++ b/third_party/ibis/ibis_impala/api.py @@ -230,11 +230,17 @@ def _get_schema_using_query(self, query): return sch.Schema(ibis_fields) +def _list_primary_key_columns(self, database: str, table: str) -> list: + """No primary keys in Hadoop.""" + return None + + udf.parse_type = parse_type ibis.backends.impala._chunks_to_pandas_array = _chunks_to_pandas_array ImpalaBackend.get_schema = get_schema ImpalaBackend._get_schema_using_query = _get_schema_using_query ImpalaBackend.do_connect = do_connect +ImpalaBackend.list_primary_key_columns = _list_primary_key_columns def impala_connect( diff --git a/third_party/ibis/ibis_mssql/__init__.py b/third_party/ibis/ibis_mssql/__init__.py index 7433ee80c..156817de8 100644 --- a/third_party/ibis/ibis_mssql/__init__.py +++ b/third_party/ibis/ibis_mssql/__init__.py @@ -93,3 +93,19 @@ def _metadata(self, query): with self.begin() as bind: for column in bind.execute(query).mappings(): yield column["name"], _type_from_result_set_info(column) + + def list_primary_key_columns(self, database: str, table: str) -> list: + """Return a list of primary key column names.""" + list_pk_col_sql = """ + SELECT COL_NAME(ic.object_id, ic.column_id) AS column_name + FROM sys.tables t + INNER JOIN sys.indexes i ON (t.object_id = i.object_id) + INNER JOIN sys.index_columns ic ON (i.object_id = ic.object_id AND i.index_id = ic.index_id) + INNER JOIN sys.schemas s ON (t.schema_id = s.schema_id) + WHERE s.name = ? + AND t.name = ? + AND i.is_primary_key = 1 + ORDER BY ic.column_id""" + with self.begin() as con: + result = con.exec_driver_sql(list_pk_col_sql, parameters=(database, table)) + return [_[0] for _ in result.cursor.fetchall()] diff --git a/third_party/ibis/ibis_mysql/__init__.py b/third_party/ibis/ibis_mysql/__init__.py index e69de29bb..4f2b68793 100644 --- a/third_party/ibis/ibis_mysql/__init__.py +++ b/third_party/ibis/ibis_mysql/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from ibis.backends.mysql import Backend as MySQLBackend + + +def _list_primary_key_columns(self, database: str, table: str) -> list: + """Return a list of primary key column names.""" + # No true binds in MySQL: + # https://dev.mysql.com/doc/connector-python/en/connector-python-api-mysqlcursor-execute.html + list_pk_col_sql = """ + SELECT COLUMN_NAME + FROM INFORMATION_SCHEMA.COLUMNS + WHERE TABLE_SCHEMA = %s + AND TABLE_NAME = %s + AND COLUMN_KEY = 'PRI' + """ + with self.begin() as con: + result = con.exec_driver_sql(list_pk_col_sql, parameters=(database, table)) + return [_[0] for _ in result.cursor.fetchall()] + + +MySQLBackend.list_primary_key_columns = _list_primary_key_columns diff --git a/third_party/ibis/ibis_oracle/__init__.py b/third_party/ibis/ibis_oracle/__init__.py index d6dd3055b..b184b2355 100644 --- a/third_party/ibis/ibis_oracle/__init__.py +++ b/third_party/ibis/ibis_oracle/__init__.py @@ -146,3 +146,20 @@ def _metadata(self, query) -> Iterable[Tuple[str, dt.DataType]]: result = con.exec_driver_sql(f"SELECT * FROM {query} t0 WHERE ROWNUM <= 1") cursor = result.cursor yield from ((column[0], _get_type(column)) for column in cursor.description) + + def list_primary_key_columns(self, database: str, table: str) -> list: + """Return a list of primary key column names.""" + list_pk_col_sql = """ + SELECT cc.column_name + FROM all_cons_columns cc + INNER JOIN all_constraints c ON (cc.owner = c.owner AND cc.constraint_name = c.constraint_name AND cc.table_name = c.table_name) + WHERE c.owner = :1 + AND c.table_name = :2 + AND c.constraint_type = 'P' + ORDER BY cc.position + """ + with self.begin() as con: + result = con.exec_driver_sql( + list_pk_col_sql, parameters=(database.upper(), table.upper()) + ) + return [_[0] for _ in result.cursor.fetchall()] diff --git a/third_party/ibis/ibis_postgres/client.py b/third_party/ibis/ibis_postgres/client.py index c02e97cf3..31c21d500 100644 --- a/third_party/ibis/ibis_postgres/client.py +++ b/third_party/ibis/ibis_postgres/client.py @@ -126,6 +126,25 @@ def list_schemas(self, like=None): return self._filter_with_like(schemas, like) +def _list_primary_key_columns(self, database: str, table: str) -> list: + """Return a list of primary key column names.""" + # From https://wiki.postgresql.org/wiki/Retrieve_primary_key_columns + list_pk_col_sql = """ + SELECT a.attname + FROM pg_index i + JOIN pg_attribute a ON a.attrelid = i.indrelid + AND a.attnum = ANY(i.indkey) + WHERE i.indrelid = CAST(:raw_name AS regclass) + AND i.indisprimary + """ + with self.begin() as con: + result = con.execute( + sa.text(list_pk_col_sql).bindparams(raw_name=f"{database}.{table}") + ) + return [_[0] for _ in result.cursor.fetchall()] + + PostgresBackend._metadata = _metadata PostgresBackend.list_databases = list_schemas PostgresBackend.do_connect = do_connect +PostgresBackend.list_primary_key_columns = _list_primary_key_columns diff --git a/third_party/ibis/ibis_redshift/__init__.py b/third_party/ibis/ibis_redshift/__init__.py index 33b0416ca..af5949dff 100644 --- a/third_party/ibis/ibis_redshift/__init__.py +++ b/third_party/ibis/ibis_redshift/__init__.py @@ -105,6 +105,11 @@ def _get_temp_view_definition( ) -> str: yield f"CREATE OR REPLACE TEMPORARY VIEW {name} AS {definition}" + def list_primary_key_columns(self, database: str, table: str) -> list: + """Return a list of primary key column names.""" + # TODO: Related to issue-1253, it's not clear if this is possible, we should revisit if it becomes a requirement. 
+ return None + def _get_type(typestr: str) -> dt.DataType: is_array = typestr.endswith(_BRACKETS) diff --git a/third_party/ibis/ibis_snowflake/datatypes.py b/third_party/ibis/ibis_snowflake/datatypes.py index 47e736b30..764c92bc6 100644 --- a/third_party/ibis/ibis_snowflake/datatypes.py +++ b/third_party/ibis/ibis_snowflake/datatypes.py @@ -50,4 +50,15 @@ def _metadata(self, query: str) -> Iterable[Tuple[str, dt.DataType]]: yield name, typ +def _list_primary_key_columns(self, database: str, table: str) -> list: + """Return a list of primary key column names.""" + # From https://docs.snowflake.com/en/sql-reference/sql/show-primary-keys + # Column name is 5th field in output. + list_pk_col_sql = f"SHOW PRIMARY KEYS IN {database}.{table};" + with self.begin() as con: + result = con.exec_driver_sql(list_pk_col_sql) + return [_[4] for _ in result.cursor.fetchall()] + + SnowflakeBackend._metadata = _metadata +SnowflakeBackend.list_primary_key_columns = _list_primary_key_columns diff --git a/third_party/ibis/ibis_teradata/__init__.py b/third_party/ibis/ibis_teradata/__init__.py index 42bbe7d8c..b37c6e3cf 100644 --- a/third_party/ibis/ibis_teradata/__init__.py +++ b/third_party/ibis/ibis_teradata/__init__.py @@ -175,14 +175,14 @@ def _adapt_types(self, descr): adapted_types.append(typename) return dict(zip(names, adapted_types)) - def _execute(self, sql, results=False): + def _execute(self, sql, results=False, params=None): if self.use_no_lock_tables and sql.strip().startswith("SELECT"): sql = self.NO_LOCK_SQL + sql with warnings.catch_warnings(): # Suppress pandas warning of SQLAlchemy connectable DB support warnings.simplefilter("ignore") - df = pandas.read_sql(sql, self.client) + df = pandas.read_sql(sql, self.client, params=params) if results: return df @@ -259,6 +259,21 @@ def execute( return df + def list_primary_key_columns(self, database: str, table: str): + """Return a list of primary key column names.""" + list_pk_col_sql = """ + SELECT ColumnName + FROM DBC.IndicesV + WHERE DatabaseName = ? + AND TableName = ? + AND IndexType = 'K' + ORDER BY ColumnPosition + """ + tables_df = self._execute( + list_pk_col_sql, results=True, params=(database, table) + ) + return list(tables_df.ColumnName.str.rstrip()) + # Methods we need to implement for BaseSQLBackend def create_table(self): pass
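For reviewers, the net effect of the `__main__.py` and `config_manager.py` hunks above is a three-step fallback for primary keys: use `--primary-keys` if supplied, otherwise infer from the source table's constraint metadata, otherwise from the target's, and fail with a `ValueError` if nothing is found. Below is a minimal, self-contained Python sketch of that fallback for illustration only; `FakeBackend` and `resolve_primary_keys` are hypothetical stand-ins, not part of the project's API — the real behaviour lives in `build_config_from_args()`, `ConfigManager.auto_list_primary_keys()` and the per-backend `list_primary_key_columns()` patches.

```python
# Illustrative sketch (not project code) of the primary-key resolution order
# introduced by this change: CLI flag -> source PKs -> target PKs -> error.
from typing import List, Optional


class FakeBackend:
    """Stand-in for an Ibis backend patched with list_primary_key_columns()."""

    def __init__(self, pk_columns: Optional[List[str]] = None):
        self._pk_columns = pk_columns

    def list_primary_key_columns(self, database: str, table: str) -> Optional[List[str]]:
        # Engines without PK metadata (the BigQuery/Impala/Redshift patches
        # above) simply return None.
        return self._pk_columns


def resolve_primary_keys(
    cli_primary_keys: Optional[str],
    source: FakeBackend,
    target: FakeBackend,
    schema: str,
    table: str,
) -> List[str]:
    """Mirror the fallback chain added in build_config_from_args()."""
    if cli_primary_keys:
        primary_keys = [c.strip() for c in cli_primary_keys.split(",")]
    else:
        # Source is consulted first, then the target, as in auto_list_primary_keys().
        primary_keys = source.list_primary_key_columns(schema, table) or []
        if not primary_keys:
            primary_keys = target.list_primary_key_columns(schema, table) or []
    if not primary_keys:
        raise ValueError(
            "No primary keys were provided and neither the source or target "
            "tables have primary keys. Please include --primary-keys argument"
        )
    # The patch casefolds the names before building column configs.
    return [c.casefold() for c in primary_keys]


if __name__ == "__main__":
    oracle_like = FakeBackend(pk_columns=["ID"])
    bigquery_like = FakeBackend(pk_columns=None)
    # Inferred from the source engine's constraint metadata -> ['id']
    print(resolve_primary_keys(None, oracle_like, bigquery_like, "pso_data_validator", "dvt_core_types"))
    # Raises ValueError: neither engine exposes a primary key.
    print(resolve_primary_keys(None, bigquery_like, FakeBackend(), "pso_data_validator", "dvt_core_types"))
```

This also explains the shape of the new test coverage: engines whose `list_primary_key_columns()` returns `None` (BigQuery) exercise the `ValueError` path, while engines with real constraint metadata (DB2, MySQL, Oracle, PostgreSQL, Snowflake, SQL Server, Teradata) exercise successful inference with `primary_keys=None`.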