Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Auto split row concat/hash validations when many columns #1233

Merged
merged 24 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
f8d0df3
fix: Prototype auto splitting row concat/hash validations when many c…
nj1973 Aug 2, 2024
0c1c4a7
fix: Prototype auto splitting row concat/hash validations when many c…
nj1973 Aug 2, 2024
8e1a23a
fix: Auto split row concat/hash validations when many columns
nj1973 Aug 6, 2024
b7dd487
fix: Auto split row concat/hash validations when many columns
nj1973 Aug 6, 2024
440df6e
Merge remote-tracking branch 'origin/develop' into 1216-tables-with-m…
nj1973 Aug 6, 2024
e999fcd
test: Add tests for BigQuery, Oracle, PostgreSQL and Teradata
nj1973 Aug 7, 2024
4ffefd7
test: Add tests for BigQuery, Oracle, PostgreSQL and Teradata
nj1973 Aug 7, 2024
bfd70a6
test: Add tests for more engines
nj1973 Aug 9, 2024
fad862f
fix: Auto split custom-query row concat/hash validations when many co…
nj1973 Aug 13, 2024
b19e70f
fix: Auto split custom-query row concat/hash validations when many co…
nj1973 Aug 13, 2024
18f256c
test: Add tests for more engines
nj1973 Aug 19, 2024
d09cdd9
Merge remote-tracking branch 'origin/develop' into 1216-tables-with-m…
nj1973 Aug 19, 2024
891fe93
chore: Fix type hint
nj1973 Aug 19, 2024
20d0811
test: Fix test assertions
nj1973 Aug 19, 2024
6bc17a0
chore: Fix typos
nj1973 Aug 19, 2024
4549c8e
test: Add tests for more engines
nj1973 Aug 19, 2024
3ae2294
test: Add tests for more engines
nj1973 Aug 19, 2024
93c2145
test: Add tests for more engines
nj1973 Aug 19, 2024
6b8d214
test: Add tests for more engines
nj1973 Aug 19, 2024
671aec2
test: Change Hive tests to use pytest.skip for disabled tests
nj1973 Aug 19, 2024
2b259c1
test: Change Hive tests to use pytest.skip for disabled tests
nj1973 Aug 19, 2024
375d71c
Merge remote-tracking branch 'origin/develop' into 1216-tables-with-m…
nj1973 Aug 20, 2024
02d0ae8
Merge branch 'develop' into 1216-tables-with-many-columns
nj1973 Aug 20, 2024
7753979
chore: PR comments
nj1973 Aug 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 24 additions & 19 deletions data_validation/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,19 +160,27 @@ def get_aggregate_config(args, config_manager: ConfigManager):
return aggregate_configs


def get_calculated_config(args, config_manager) -> List[dict]:
def get_calculated_config(args, config_manager: ConfigManager) -> List[dict]:
"""Return list of formatted calculated objects.

Args:
config_manager(ConfigManager): Validation config manager instance.
"""
calculated_configs = []
fields = []
if args.hash:
col_list = None if args.hash == "*" else cli_tools.get_arg_list(args.hash)
if config_manager.hash:
col_list = (
None
if config_manager.hash == "*"
else cli_tools.get_arg_list(config_manager.hash)
)
fields = config_manager.build_dependent_aliases("hash", col_list)
elif args.concat:
col_list = None if args.concat == "*" else cli_tools.get_arg_list(args.concat)
elif config_manager.concat:
col_list = (
None
if config_manager.concat == "*"
else cli_tools.get_arg_list(config_manager.concat)
)
fields = config_manager.build_dependent_aliases("concat", col_list)

if len(fields) > 0:
Expand All @@ -190,13 +198,13 @@ def get_calculated_config(args, config_manager) -> List[dict]:
custom_params=field.get("calc_params"),
)
)
if args.hash:
if config_manager.hash:
config_manager.append_comparison_fields(
config_manager.build_config_comparison_fields(
["hash__all"], depth=max_depth
)
)
elif args.concat:
elif config_manager.concat:
config_manager.append_comparison_fields(
config_manager.build_config_comparison_fields(
["concat__all"], depth=max_depth
Expand All @@ -212,6 +220,7 @@ def build_config_from_args(args: Namespace, config_manager: ConfigManager):
args (Namespace): User specified Arguments
config_manager (ConfigManager): Validation config manager instance.
"""

# Append SCHEMA_VALIDATION configs
if config_manager.validation_type == consts.SCHEMA_VALIDATION:
if args.exclusion_columns is not None:
Expand All @@ -226,22 +235,18 @@ def build_config_from_args(args: Namespace, config_manager: ConfigManager):
config_manager.append_custom_query_type(args.custom_query_type)

# Get source sql query from source sql file or inline query
if args.source_query:
source_query_str = config_manager.get_query_from_inline(args.source_query)
else:
source_query_str = config_manager.get_query_from_file(
args.source_query_file
config_manager.append_source_query(
cli_tools.get_query_from_query_args(
args.source_query, args.source_query_file
)
config_manager.append_source_query(source_query_str)
)

# Get target sql query from target sql file or inline query
if args.target_query:
target_query_str = config_manager.get_query_from_inline(args.target_query)
else:
target_query_str = config_manager.get_query_from_file(
args.target_query_file
config_manager.append_target_query(
cli_tools.get_query_from_query_args(
args.target_query, args.target_query_file
)
config_manager.append_target_query(target_query_str)
)

# For custom-query column command
if args.custom_query_type == consts.COLUMN_VALIDATION.lower():
Expand Down
144 changes: 142 additions & 2 deletions data_validation/cli_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
"""

import argparse
import copy
import csv
import json
import logging
Expand All @@ -52,10 +53,11 @@
import os
import math
from argparse import Namespace
from typing import Dict, List
from typing import Dict, List, Optional
from yaml import Dumper, Loader, dump, load

from data_validation import clients, consts, state_manager, gcs_helper
from data_validation.validation_builder import list_to_sublists

CONNECTION_SOURCE_FIELDS = {
"BigQuery": [
Expand Down Expand Up @@ -463,6 +465,16 @@ def _configure_row_parser(
"Performs a case insensitive match by adding an UPPER() before comparison."
),
)
optional_arguments.add_argument(
"--max-concat-columns",
"-mcc",
type=int,
help=(
"The maximum number of columns accepted by a --hash or --concat validation. When there are "
"more columns than this the validation will implicitly be split into multiple validations. "
"This option has engine specific defaults."
),
)
# Generate-table-partitions and custom-query does not support random row
if not (is_generate_partitions or is_custom_query):
optional_arguments.add_argument(
Expand Down Expand Up @@ -1192,9 +1204,102 @@ def split_table(table_ref, schema_required=True):
return schema.strip(), table.strip()


def get_query_from_file(filename):
    """Return the SQL query text contained in the given file.

    Args:
        filename (str): Path (local or GCS, via gcs_helper) of the file
            holding the query.

    Returns:
        str: The query text with any trailing ";"/newline terminators removed.

    Raises:
        ValueError: If the file could not be read, is empty, or contains
            only whitespace.
    """
    query = ""
    try:
        query = gcs_helper.read_file(filename, download_as_text=True)
        # Strip a trailing statement terminator so the query can be embedded
        # in a larger expression later on.
        query = query.rstrip(";\n")
    except IOError:
        # Lazy %-style args so the filename is actually included in the log
        # record (a trailing comma-arg without a %s placeholder is dropped).
        logging.error("Cannot read query file: %s", filename)

    if not query or query.isspace():
        raise ValueError(
            "Expected file with sql query, got empty file or file with white spaces. "
            f"input file: {filename}"
        )
    return query


def get_query_from_inline(inline_query):
    """Return the SQL text of an inline query argument.

    Leading/trailing whitespace is stripped and any trailing ";"/newline
    terminators are removed. Raises ValueError when nothing meaningful
    remains after cleaning.
    """
    cleaned = inline_query.strip().rstrip(";\n")
    if cleaned and not cleaned.isspace():
        return cleaned
    raise ValueError(
        "Expected arg with sql query, got empty arg or arg with white "
        f"spaces. input query: '{inline_query}'"
    )


def get_query_from_query_args(query_str_arg, query_file_arg) -> str:
    """Return SQL text from either an inline query arg or, when that is
    unset, from a query file arg."""
    return (
        get_query_from_inline(query_str_arg)
        if query_str_arg
        else get_query_from_file(query_file_arg)
    )


def _max_concat_columns(
    max_concat_columns_option: int, source_client, target_client
) -> Optional[int]:
    """Determine any upper limit on number of columns allowed into concat() operation.

    An explicit user-supplied limit always wins. Otherwise fall back to the
    engine specific defaults, taking the smaller when both engines define
    one. Returns None when no limit applies.
    """
    if max_concat_columns_option:
        return max_concat_columns_option
    engine_limits = [
        consts.MAX_CONCAT_COLUMNS_DEFAULTS.get(client.name)
        for client in (source_client, target_client)
    ]
    defined_limits = [limit for limit in engine_limits if limit]
    if defined_limits:
        return min(defined_limits)
    return engine_limits[0] or engine_limits[1]


def _concat_column_count_configs(
    cols: list,
    pre_build_configs: dict,
    arg_to_override: str,
    max_col_count: int,
) -> list:
    """
    Ensure we don't have too many columns for the engines involved.
    https://github.com/GoogleCloudPlatform/professional-services-data-validator/issues/1216
    """
    if not max_col_count or len(cols) <= max_col_count:
        # Within limits, the config passes through untouched.
        return [pre_build_configs]
    # Too many columns: emit one config per column chunk, each carrying its
    # own CSV of column names in place of the original hash/concat argument.
    chunked_configs = []
    for chunk in list_to_sublists(cols, max_col_count):
        config_copy = copy.copy(pre_build_configs)
        config_copy[arg_to_override] = ",".join(chunk)
        chunked_configs.append(config_copy)
    return chunked_configs


def get_pre_build_configs(args: Namespace, validate_cmd: str) -> List[Dict]:
"""Return a dict of configurations to build ConfigManager object"""

def cols_from_arg(concat_arg: str, client, table_obj: dict, query_str: str) -> list:
if concat_arg == "*":
# If validating with "*" then we need to expand to count the columns.
if table_obj:
return clients.get_ibis_table_schema(
client,
table_obj["schema_name"],
table_obj["table_name"],
).names
else:
return clients.get_ibis_query_schema(
client,
query_str,
).names
else:
return get_arg_list(concat_arg)

# Since `generate-table-partitions` defaults to `validate_cmd=row`,
# `validate_cmd` is passed along while calling this method
if validate_cmd is None:
Expand Down Expand Up @@ -1255,10 +1360,12 @@ def get_pre_build_configs(args: Namespace, validate_cmd: str) -> List[Dict]:

# Get table list. Not supported in case of custom query validation
is_filesystem = source_client._source_type == "FileSystem"
query_str = None
if config_type == consts.CUSTOM_QUERY:
tables_list = get_tables_list(
None, default_value=[{}], is_filesystem=is_filesystem
)
query_str = get_query_from_query_args(args.source_query, args.source_query_file)
else:
tables_list = get_tables_list(
args.tables_list, default_value=[{}], is_filesystem=is_filesystem
Expand Down Expand Up @@ -1293,8 +1400,41 @@ def get_pre_build_configs(args: Namespace, validate_cmd: str) -> List[Dict]:
"filter_status": filter_status,
"trim_string_pks": getattr(args, "trim_string_pks", False),
"case_insensitive_match": getattr(args, "case_insensitive_match", False),
consts.CONFIG_ROW_CONCAT: getattr(args, consts.CONFIG_ROW_CONCAT, None),
consts.CONFIG_ROW_HASH: getattr(args, consts.CONFIG_ROW_HASH, None),
"verbose": args.verbose,
}
pre_build_configs_list.append(pre_build_configs)
if (
pre_build_configs[consts.CONFIG_ROW_CONCAT]
or pre_build_configs[consts.CONFIG_ROW_HASH]
):
# Ensure we don't have too many columns for the engines involved.
cols = cols_from_arg(
pre_build_configs[consts.CONFIG_ROW_HASH]
or pre_build_configs[consts.CONFIG_ROW_CONCAT],
source_client,
table_obj,
query_str,
)
additional_pre_build_configs = _concat_column_count_configs(
cols,
pre_build_configs,
consts.CONFIG_ROW_HASH if args.hash else consts.CONFIG_ROW_CONCAT,
_max_concat_columns(
args.max_concat_columns, source_client, target_client
),
)
if len(additional_pre_build_configs) > 1:
helensilva14 marked this conversation as resolved.
Show resolved Hide resolved
message_type = (
f'{table_obj["schema_name"]}.{table_obj["table_name"]}'
if table_obj
else "custom query"
)
logging.info(
f"Splitting validation into {len(additional_pre_build_configs)} queries for {message_type}"
)
pre_build_configs_list.extend(additional_pre_build_configs)
else:
pre_build_configs_list.append(pre_build_configs)

return pre_build_configs_list
40 changes: 29 additions & 11 deletions data_validation/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import copy
import logging
from typing import TYPE_CHECKING
import warnings

import google.oauth2.service_account
Expand All @@ -29,6 +30,11 @@
from third_party.ibis.ibis_mssql.api import mssql_connect
from third_party.ibis.ibis_redshift.api import redshift_connect

if TYPE_CHECKING:
import ibis.expr.schema as sch
import ibis.expr.types as ir


ibis.options.sql.default_limit = None

# Filter Ibis MySQL error when loading client.table()
Expand All @@ -38,6 +44,17 @@
)


# Ibis backend names routed through the SQLAlchemy-style schema lookup in
# get_ibis_table_schema (client.table(...).schema()); other backends use
# client.get_schema()/_get_schema_using_query instead.
IBIS_ALCHEMY_BACKENDS = [
    "mysql",
    "oracle",
    "postgres",
    "db2",
    "mssql",
    "redshift",
    "snowflake",
]


def _raise_missing_client_error(msg):
def get_client_call(*args, **kwargs):
raise Exception(msg)
Expand Down Expand Up @@ -146,7 +163,7 @@ def get_ibis_table(client, schema_name, table_name, database_name=None):
return client.table(table_name, database=schema_name)


def get_ibis_query(client, query):
def get_ibis_query(client, query) -> "ir.Table":
"""Return Ibis Table from query expression for Supplied Client."""
iq = client.sql(query)
# Normalise all columns in the query to lower case.
Expand All @@ -155,28 +172,29 @@ def get_ibis_query(client, query):
return iq


def get_ibis_table_schema(client, schema_name, table_name):
def get_ibis_table_schema(client, schema_name: str, table_name: str) -> "sch.Schema":
"""Return Ibis Table Schema for Supplied Client.

client (IbisClient): Client to use for table
schema_name (str): Schema name of table object, may not need this since Backend uses database
table_name (str): Table name of table object
database_name (str): Database name (generally default is used)
"""
if client.name in [
"mysql",
"oracle",
"postgres",
"db2",
"mssql",
"redshift",
"snowflake",
]:
if client.name in IBIS_ALCHEMY_BACKENDS:
return client.table(table_name, schema=schema_name).schema()
else:
return client.get_schema(table_name, schema_name)


def get_ibis_query_schema(client, query_str) -> "sch.Schema":
    """Return the Ibis Schema of the supplied query string for this client."""
    if client.name not in IBIS_ALCHEMY_BACKENDS:
        # NJ: I'm not happy about calling a private method but don't see how I can avoid it.
        return client._get_schema_using_query(query_str)
    return get_ibis_query(client, query_str).schema()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can discuss more about that!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ibis does not expose a public method like it does for get_schema(). I didn't feel it was right for me to add a public method to Ibis objects, I know we already patch the private method but adding new methods seemed a step too far. Happy to be influenced in another direction though.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, since you added a comment about this current implementation I'd be okay to move forward as it is so we can unblock the customer. It seems like something we can analyze further later on.



def list_schemas(client):
"""Return a list of schemas in the DB."""
if hasattr(client, "list_databases"):
Expand Down
Loading