From ba6706cec7f2140c14aad08d60d4460e60ba1936 Mon Sep 17 00:00:00 2001 From: Florian Valeye Date: Fri, 31 Dec 2021 15:31:48 +0100 Subject: [PATCH] Improve the documentation and add MySQL external source (#11) * Improve the README and add examples * Add the MySQL external source in the Python binding --- Cargo.lock | 2 +- README.adoc | 9 +- python/Cargo.toml | 2 +- python/docs/source/installation.rst | 4 +- python/docs/source/usage.rst | 42 +++++-- python/examples/README.md | 3 + ...scan_external_sources_custom_data_rules.py | 87 +++++++++++++ .../scan_external_sources_database.py | 107 ++++++++++++++++ .../scan_external_sources_database_async.py | 93 ++++++++++++++ .../examples/scan_external_sources_table.py | 92 ++++++++++++++ python/examples/scan_local_sources.py | 59 +++++++++ python/metadata_guardian/conf.py | 1 + python/metadata_guardian/data_rules.py | 6 +- python/metadata_guardian/report.py | 4 +- python/metadata_guardian/scanner.py | 10 ++ python/metadata_guardian/source/__init__.py | 1 + .../source/external/aws_source.py | 26 ++-- .../source/external/deltatable_source.py | 20 ++- .../external/external_metadata_source.py | 33 ++++- .../source/external/gcp_source.py | 23 ++-- .../external/kafka_schema_registry_source.py | 26 ++-- .../source/external/mysql_source.py | 119 ++++++++++++++++++ .../source/external/snowflake_source.py | 19 +-- .../source/local/avro_schema_source.py | 7 +- .../source/local/avro_source.py | 7 +- .../source/local/local_metadata_source.py | 8 +- .../source/local/orc_source.py | 6 +- .../source/local/parquet_source.py | 3 +- .../source/metadata_source.py | 1 + python/pyproject.toml | 3 +- .../tests/external/test_deltatable_source.py | 5 +- python/tests/external/test_gcp_source.py | 20 +-- python/tests/external/test_mysql_source.py | 80 ++++++++++++ .../tests/external/test_snowflake_source.py | 5 +- 34 files changed, 847 insertions(+), 86 deletions(-) create mode 100644 python/examples/README.md create mode 100644 python/examples/scan_external_sources_custom_data_rules.py create mode 100644 python/examples/scan_external_sources_database.py create mode 100644 python/examples/scan_external_sources_database_async.py create mode 100644 python/examples/scan_external_sources_table.py create mode 100644 python/examples/scan_local_sources.py create mode 100644 python/metadata_guardian/source/external/mysql_source.py create mode 100644 python/tests/external/test_mysql_source.py diff --git a/Cargo.lock b/Cargo.lock index 9cd8f41..74c4e97 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -636,7 +636,7 @@ dependencies = [ [[package]] name = "metadata_guardian-python" -version = "0.1.0" +version = "0.1.1" dependencies = [ "env_logger", "metadata_guardian", diff --git a/README.adoc b/README.adoc index da32d23..23350b2 100644 --- a/README.adoc +++ b/README.adoc @@ -23,9 +23,10 @@ Using Rust, it makes blazing fast multi-regex matching. - Deltalake - GCP: BigQuery - Snowflake +- Kafka Schema Registry == Data Rules -The available data rules are: *https://github.com/fvaleye/metadata-guardian/blob/main/python/metadata_guardian/rules/pii_rules.yaml[PII]* and *https://github.com/fvaleye/metadata-guardian/blob/main/python/metadata_guardian/rules/inclusion_rules.yaml[INCLUSION]*. But it aims to be extended with custom data rules that could serve multiple purposes (for example: detect data that may contain IA biais, detect credentials...). +The available data rules are here: *https://github.com/fvaleye/metadata-guardian/blob/main/python/metadata_guardian/rules/pii_rules.yaml[PII]* and *https://github.com/fvaleye/metadata-guardian/blob/main/python/metadata_guardian/rules/inclusion_rules.yaml[INCLUSION]*. But it aims to be extended with custom data rules that could serve multiple purposes (for example: detect data that may contain IA biais, detect credentials...). == Where to get it @@ -35,12 +36,12 @@ pip install 'metadata_guardian[all]' ``` ```sh -# Install with one data source -pip install 'metadata_guardian[snowflake,avro,aws,gcp,deltalake]' +# Install with one metadata source in the list +pip install 'metadata_guardian[snowflake,avro,aws,gcp,deltalake,kafka_schema_registry]' ``` == Licence https://raw.githubusercontent.com/fvaleye/metadata-guardian/main/LICENSE.txt[Apache License 2.0] == Documentation -The documentation is hosted here: https://fvaleye.github.io/metadata-guardian/python/ \ No newline at end of file +The documentation is hosted here: https://fvaleye.github.io/metadata-guardian/python/ diff --git a/python/Cargo.toml b/python/Cargo.toml index 2d4f19c..7390b89 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "metadata_guardian-python" -version = "0.1.0" +version = "0.1.1" authors = ["Florian Valeye "] homepage = "https://fvaleye.github.io/metadata-guardian/python" license = "Apache-2.0" diff --git a/python/docs/source/installation.rst b/python/docs/source/installation.rst index 0133bb6..df04deb 100644 --- a/python/docs/source/installation.rst +++ b/python/docs/source/installation.rst @@ -8,5 +8,5 @@ Using Pip # Install all the metadata sources pip install 'metadata_guardian[all]' - # Install one metadata source in the list - pip install 'metadata_guardian[snowflake,avro,aws,gcp,deltalake,devel]' \ No newline at end of file + # Install with one metadata source in the list + pip install 'metadata_guardian[snowflake,avro,aws,gcp,deltalake,kafka_schema_registry]' \ No newline at end of file diff --git a/python/docs/source/usage.rst b/python/docs/source/usage.rst index 13a6d8f..a86326a 100644 --- a/python/docs/source/usage.rst +++ b/python/docs/source/usage.rst @@ -4,16 +4,15 @@ Usage Metadata Guardian ----------------- -Scan the column names of a local source: +**Workflow:** ->>> from metadata_guardian import DataRules, ColumnScanner, AvailableCategory ->>> from metadata_guardian.source import ParquetSource ->>> ->>> data_rules = DataRules.from_available_category(category=AvailableCategory.PII) ->>> source = ParquetSource("file.parquet") ->>> column_scanner = ColumnScanner(data_rules=data_rules) ->>> report = column_scanner.scan_local(source) ->>> report.to_console() +1. Create the Data Rules +2. Create the Metadata Source +3. Scan the Metadata Source +4. Analyze the reports + +Scan an external Metadata Source +-------------------------------- Scan the column names of a external source on a table: @@ -23,7 +22,8 @@ Scan the column names of a external source on a table: >>> data_rules = DataRules.from_available_category(category=AvailableCategory.PII) >>> source = SnowflakeSource(sf_account="account", sf_user="sf_user", sf_password="sf_password", warehouse="warehouse", schema_name="schema_name") >>> column_scanner = ColumnScanner(data_rules=data_rules) ->>> report = column_scanner.scan_external(source, database_name="database_name", table_name="table_name", include_comment=True) +>>> with source: +>>> report = column_scanner.scan_external(source, database_name="database_name", table_name="table_name", include_comment=True) >>> report.to_console() Scan the column names of a external source on database: @@ -34,7 +34,8 @@ Scan the column names of a external source on database: >>> data_rules = DataRules.from_available_category(category=AvailableCategory.PII) >>> source = SnowflakeSource(sf_account="account", sf_user="sf_user", sf_password="sf_password", warehouse="warehouse", schema_name="schema_name") >>> column_scanner = ColumnScanner(data_rules=data_rules) ->>> report = column_scanner.scan_external(source, database_name="database_name", include_comment=True) +>>> with source: +>>> report = column_scanner.scan_external(source, database_name="database_name", include_comment=True) >>> report.to_console() Scan the column names of an external source for a database asynchronously with asyncio: @@ -46,7 +47,23 @@ Scan the column names of an external source for a database asynchronously with a >>> data_rules = DataRules.from_available_category(category=AvailableCategory.PII) >>> source = SnowflakeSource(sf_account="account", sf_user="sf_user", sf_password="sf_password", warehouse="warehouse", schema_name="schema_name") >>> column_scanner = ColumnScanner(data_rules=data_rules) ->>> report = asyncio.run(column_scanner.scan_external_async(source, database_name="database_name", include_comment=True)) +>>> with source: +>>> report = asyncio.run(column_scanner.scan_external_async(source, database_name="database_name", include_comment=True)) +>>> report.to_console() + + +Scan an internal Metadata Source +-------------------------------- + +Scan the column names of a local source: + +>>> from metadata_guardian import DataRules, ColumnScanner, AvailableCategory +>>> from metadata_guardian.source import ParquetSource +>>> +>>> data_rules = DataRules.from_available_category(category=AvailableCategory.PII) +>>> column_scanner = ColumnScanner(data_rules=data_rules) +>>> with ParquetSource("file.parquet") as source: +>>> report = column_scanner.scan_local(source) >>> report.to_console() Scan the column names of a local source: @@ -57,6 +74,7 @@ Scan the column names of a local source: >>> data_rules = DataRules.from_available_category(category=AvailableCategory.PII) >>> column_scanner = ColumnScanner(data_rules=data_rules) >>> report = MetadataGuardianReport() +>>> paths = ["first_path", "second_path"] >>> for path in paths: >>> source = ParquetSource(path) >>> report.append(column_scanner.scan_local(source)) diff --git a/python/examples/README.md b/python/examples/README.md new file mode 100644 index 0000000..a50b9ea --- /dev/null +++ b/python/examples/README.md @@ -0,0 +1,3 @@ +Examples +This directory contains various examples of the Metadata Guardian features. +Make sure Metadata Guardian is installed and run the examples using the command line with python. diff --git a/python/examples/scan_external_sources_custom_data_rules.py b/python/examples/scan_external_sources_custom_data_rules.py new file mode 100644 index 0000000..4f4b562 --- /dev/null +++ b/python/examples/scan_external_sources_custom_data_rules.py @@ -0,0 +1,87 @@ +import argparse +import os + +from metadata_guardian import ( + AvailableCategory, + ColumnScanner, + DataRules, + ExternalMetadataSource, +) +from metadata_guardian.source import ( + AthenaSource, + BigQuerySource, + DeltaTableSource, + GlueSource, + KafkaSchemaRegistrySource, + SnowflakeSource, +) + + +def get_snowflake() -> ExternalMetadataSource: + return SnowflakeSource( + sf_account=os.environ["SNOWFLAKE_ACCOUNT"], + sf_user=os.environ["SNOWFLAKE_USER"], + sf_password=os.environ["SNOWFLAKE_PASSWORD"], + warehouse=os.environ["SNOWFLAKE_WAREHOUSE"], + schema_name=os.environ["SNOWFLAKE_SCHEMA_NAME"], + ) + + +def get_gcp_bigquery() -> ExternalMetadataSource: + return BigQuerySource( + service_account_json_path=os.environ["BIGQUERY_SERVICE_ACCOUNT"], + project=os.environ["BIGQUERY_PROJECT"], + location=os.environ["BIGQUERY_LOCATION"], + ) + + +def get_kafka_schema_registry() -> ExternalMetadataSource: + return KafkaSchemaRegistrySource(url=os.environ["KAFKA_SCHEMA_REGISTRY_URL"]) + + +def get_delta_table() -> ExternalMetadataSource: + return DeltaTableSource(uri=os.environ["DELTA_TABLE_URI"]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data-rules-path", + required=True, + help="The Data Rules specification yaml file path to use for creating the Data Rules", + ) + parser.add_argument( + "--external-source", + choices=["Snowflake", "GCP BigQuery", "Kafka Schema Registry", "Delta Table"], + required=True, + help="The External Metadata Source to use", + ) + parser.add_argument( + "--scanner", choices=["ColumnScanner"], help="The scanner to use" + ) + parser.add_argument( + "--database_name", required=True, help="The database name to scan" + ) + parser.add_argument( + "--include_comments", default=True, help="Include the comments in the scan" + ) + args = parser.parse_args() + data_rules = DataRules(path=args.data_rules_path) + column_scanner = ColumnScanner(data_rules=data_rules) + + if args.external_source == "Snowflake": + source = get_snowflake() + elif args.external_source == "GCP BigQuery": + source = get_gcp_bigquery() + elif args.external_source == "Kafka Schema Registry": + source = get_kafka_schema_registry() + elif args.external_source == "Delta Table": + source = get_delta_table() + + with source: + report = column_scanner.scan_external( + source, + database_name=args.database_name, + include_comment=args.include_comments, + ) + report.to_console() diff --git a/python/examples/scan_external_sources_database.py b/python/examples/scan_external_sources_database.py new file mode 100644 index 0000000..22e4019 --- /dev/null +++ b/python/examples/scan_external_sources_database.py @@ -0,0 +1,107 @@ +import argparse +import os + +from metadata_guardian import ( + AvailableCategory, + ColumnScanner, + DataRules, + ExternalMetadataSource, +) +from metadata_guardian.source import ( + AthenaSource, + BigQuerySource, + DeltaTableSource, + GlueSource, + KafkaSchemaRegistrySource, + MySQLSource, + SnowflakeSource, +) + + +def get_snowflake() -> ExternalMetadataSource: + return SnowflakeSource( + sf_account=os.environ["SNOWFLAKE_ACCOUNT"], + sf_user=os.environ["SNOWFLAKE_USER"], + sf_password=os.environ["SNOWFLAKE_PASSWORD"], + warehouse=os.environ["SNOWFLAKE_WAREHOUSE"], + schema_name=os.environ["SNOWFLAKE_SCHEMA_NAME"], + ) + + +def get_gcp_bigquery() -> ExternalMetadataSource: + return BigQuerySource( + service_account_json_path=os.environ["BIGQUERY_SERVICE_ACCOUNT"], + project=os.environ["BIGQUERY_PROJECT"], + location=os.environ["BIGQUERY_LOCATION"], + ) + + +def get_kafka_schema_registry() -> ExternalMetadataSource: + return KafkaSchemaRegistrySource(url=os.environ["KAFKA_SCHEMA_REGISTRY_URL"]) + + +def get_delta_table() -> ExternalMetadataSource: + return DeltaTableSource(uri=os.environ["DELTA_TABLE_URI"]) + + +def get_mysql() -> ExternalMetadataSource: + return MySQLSource( + user=os.environ["MYSQL_USER"], + password=os.environ["MYSQL_PASSWORD"], + host=os.environ["MYSQL_HOST"], + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data-rules", + choices=["PII", "INCLUSION"], + default="PII", + help="The Data Rules to use", + ) + parser.add_argument( + "--external-source", + choices=[ + "Snowflake", + "GCP BigQuery", + "Kafka Schema Registry", + "Delta Table", + "MySQL", + ], + required=True, + help="The External Metadata Source to use", + ) + parser.add_argument( + "--scanner", choices=["ColumnScanner"], help="The scanner to use" + ) + parser.add_argument( + "--database_name", required=True, help="The database name to scan" + ) + parser.add_argument( + "--include_comments", default=True, help="Include the comments in the scan" + ) + args = parser.parse_args() + data_rules = DataRules.from_available_category( + category=AvailableCategory[args.data_rules] + ) + column_scanner = ColumnScanner(data_rules=data_rules) + + if args.external_source == "Snowflake": + source = get_snowflake() + elif args.external_source == "GCP BigQuery": + source = get_gcp_bigquery() + elif args.external_source == "Kafka Schema Registry": + source = get_kafka_schema_registry() + elif args.external_source == "Delta Table": + source = get_delta_table() + elif args.external_source == "MySQL": + source = get_mysql() + + with source: + report = column_scanner.scan_external( + source, + database_name=args.database_name, + include_comment=args.include_comments, + ) + report.to_console() diff --git a/python/examples/scan_external_sources_database_async.py b/python/examples/scan_external_sources_database_async.py new file mode 100644 index 0000000..30c1beb --- /dev/null +++ b/python/examples/scan_external_sources_database_async.py @@ -0,0 +1,93 @@ +import argparse +import asyncio +import os + +from metadata_guardian import ( + AvailableCategory, + ColumnScanner, + DataRules, + ExternalMetadataSource, +) +from metadata_guardian.source import ( + AthenaSource, + BigQuerySource, + DeltaTableSource, + GlueSource, + KafkaSchemaRegistrySource, + SnowflakeSource, +) + + +def get_snowflake() -> ExternalMetadataSource: + return SnowflakeSource( + sf_account=os.environ["SNOWFLAKE_ACCOUNT"], + sf_user=os.environ["SNOWFLAKE_USER"], + sf_password=os.environ["SNOWFLAKE_PASSWORD"], + warehouse=os.environ["SNOWFLAKE_WAREHOUSE"], + schema_name=os.environ["SNOWFLAKE_SCHEMA_NAME"], + ) + + +def get_gcp_bigquery() -> ExternalMetadataSource: + return BigQuerySource( + service_account_json_path=os.environ["BIGQUERY_SERVICE_ACCOUNT"], + project=os.environ["BIGQUERY_PROJECT"], + location=os.environ["BIGQUERY_LOCATION"], + ) + + +def get_kafka_schema_registry() -> ExternalMetadataSource: + return KafkaSchemaRegistrySource(url=os.environ["KAFKA_SCHEMA_REGISTRY_URL"]) + + +def get_delta_table() -> ExternalMetadataSource: + return DeltaTableSource(uri=os.environ["DELTA_TABLE_URI"]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data-rules", + choices=["PII", "INCLUSION"], + default="PII", + help="The Data Rules to use", + ) + parser.add_argument( + "--external-source", + choices=["Snowflake", "GCP BigQuery", "Kafka Schema Registry", "Delta Table"], + required=True, + help="The External Metadata Source to use", + ) + parser.add_argument( + "--scanner", choices=["ColumnScanner"], help="The scanner to use" + ) + parser.add_argument( + "--database_name", required=True, help="The database name to scan" + ) + parser.add_argument( + "--include_comments", default=True, help="Include the comments in the scan" + ) + args = parser.parse_args() + data_rules = DataRules.from_available_category( + category=AvailableCategory[args.data_rules] + ) + column_scanner = ColumnScanner(data_rules=data_rules) + + if args.external_source == "Snowflake": + source = get_snowflake() + elif args.external_source == "GCP BigQuery": + source = get_gcp_bigquery() + elif args.external_source == "Kafka Schema Registry": + source = get_kafka_schema_registry() + elif args.external_source == "Delta Table": + source = get_delta_table() + + with source: + report = asyncio.run( + column_scanner.scan_external_async( + source, + database_name=args.database_name, + include_comment=args.include_comments, + ) + ) + report.to_console() diff --git a/python/examples/scan_external_sources_table.py b/python/examples/scan_external_sources_table.py new file mode 100644 index 0000000..8074d18 --- /dev/null +++ b/python/examples/scan_external_sources_table.py @@ -0,0 +1,92 @@ +import argparse +import os + +from metadata_guardian import ( + AvailableCategory, + ColumnScanner, + DataRules, + ExternalMetadataSource, +) +from metadata_guardian.source import ( + AthenaSource, + BigQuerySource, + DeltaTableSource, + GlueSource, + KafkaSchemaRegistrySource, + SnowflakeSource, +) + + +def get_snowflake() -> ExternalMetadataSource: + return SnowflakeSource( + sf_account=os.environ["SNOWFLAKE_ACCOUNT"], + sf_user=os.environ["SNOWFLAKE_USER"], + sf_password=os.environ["SNOWFLAKE_PASSWORD"], + warehouse=os.environ["SNOWFLAKE_WAREHOUSE"], + schema_name=os.environ["SNOWFLAKE_SCHEMA_NAME"], + ) + + +def get_gcp_bigquery() -> ExternalMetadataSource: + return BigQuerySource( + service_account_json_path=os.environ["BIGQUERY_SERVICE_ACCOUNT"], + project=os.environ["BIGQUERY_PROJECT"], + location=os.environ["BIGQUERY_LOCATION"], + ) + + +def get_kafka_schema_registry() -> ExternalMetadataSource: + return KafkaSchemaRegistrySource(url=os.environ["KAFKA_SCHEMA_REGISTRY_URL"]) + + +def get_delta_table() -> ExternalMetadataSource: + return DeltaTableSource(uri=os.environ["DELTA_TABLE_URI"]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data-rules", + choices=["PII", "INCLUSION"], + default="PII", + help="The Data Rules to use", + ) + parser.add_argument( + "--external-source", + choices=["Snowflake", "GCP BigQuery", "Kafka Schema Registry", "Delta Table"], + required=True, + help="The External Metadata Source to use", + ) + parser.add_argument( + "--scanner", choices=["ColumnScanner"], help="The scanner to use" + ) + parser.add_argument( + "--database_name", required=True, help="The database name to scan" + ) + parser.add_argument("--table_name", required=True, help="The table name to scan") + parser.add_argument( + "--include_comments", default=True, help="Include the comments in the scan" + ) + args = parser.parse_args() + data_rules = DataRules.from_available_category( + category=AvailableCategory[args.data_rules] + ) + column_scanner = ColumnScanner(data_rules=data_rules) + + if args.external_source == "Snowflake": + source = get_snowflake() + elif args.external_source == "GCP BigQuery": + source = get_gcp_bigquery() + elif args.external_source == "Kafka Schema Registry": + source = get_kafka_schema_registry() + elif args.external_source == "Delta Table": + source = get_delta_table() + + with source: + report = column_scanner.scan_external( + source, + database_name=args.database_name, + table_name=args.table_name, + include_comment=args.include_comments, + ) + report.to_console() diff --git a/python/examples/scan_local_sources.py b/python/examples/scan_local_sources.py new file mode 100644 index 0000000..f4845e2 --- /dev/null +++ b/python/examples/scan_local_sources.py @@ -0,0 +1,59 @@ +import argparse +import os + +from metadata_guardian import AvailableCategory, ColumnScanner, DataRules +from metadata_guardian.source import AvroSource, ORCSource, ParquetSource + + +def get_gcp_bigquery() -> ExternalMetadataSource: + return BigQuerySource( + service_account_json_path=os.environ["BIGQUERY_SERVICE_ACCOUNT"], + project=os.environ["BIGQUERY_PROJECT"], + location=os.environ["BIGQUERY_LOCATION"], + ) + + +def get_kafka_schema_registry() -> ExternalMetadataSource: + return KafkaSchemaRegistrySource(url=os.environ["KAFKA_SCHEMA_REGISTRY_URL"]) + + +def get_delta_table() -> ExternalMetadataSource: + return DeltaTableSource(uri=os.environ["DELTA_TABLE_URI"]) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data-rules", + choices=["PII", "INCLUSION"], + default="PII", + help="The Data Rules to use", + ) + parser.add_argument( + "--local-source", + choices=["Avro", "Parquet", "Orc"], + required=True, + help="The Local Metadata Source to use", + ) + parser.add_argument( + "--scanner", choices=["ColumnScanner"], help="The scanner to use" + ) + parser.add_argument("--path", required=True, help="The path of the file to scan") + parser.add_argument( + "--include_comments", default=True, help="Include the comments in the scan" + ) + args = parser.parse_args() + data_rules = DataRules.from_available_category( + category=AvailableCategory[args.data_rules] + ) + column_scanner = ColumnScanner(data_rules=data_rules) + + if args.local_source == "Avro": + source = AvroSource(local_path=args.path) + elif args.local_source == "Parquet": + source = ParquetSource(local_path=args.path) + elif args.local_source == "Orc": + source = ORCSource(local_path=args.path) + + report = column_scanner.scan_local(source) + report.to_console() diff --git a/python/metadata_guardian/conf.py b/python/metadata_guardian/conf.py index 87b44a2..bc9ba96 100644 --- a/python/metadata_guardian/conf.py +++ b/python/metadata_guardian/conf.py @@ -6,6 +6,7 @@ def configure_logger() -> None: """ Configure the loguru configuration with Rich. + :return: """ if "LOGURU_LEVEL" not in os.environ: diff --git a/python/metadata_guardian/data_rules.py b/python/metadata_guardian/data_rules.py index 464818c..5850c01 100644 --- a/python/metadata_guardian/data_rules.py +++ b/python/metadata_guardian/data_rules.py @@ -44,6 +44,7 @@ def __init__(self, path: str) -> None: def from_available_category(cls, category: AvailableCategory) -> "DataRules": """ Get Data Rules from an available category. + :param category: the available category of the data rules :return: the Data Rules instance """ @@ -59,6 +60,7 @@ def from_available_category(cls, category: AvailableCategory) -> "DataRules": def validate_word(self, word: str) -> MetadataGuardianResults: """ Validate a word with the data rules defined. + :param word: the word to validate :return: the metadata guardian results """ @@ -79,7 +81,8 @@ def validate_word(self, word: str) -> MetadataGuardianResults: def validate_words(self, words: List[str]) -> List[MetadataGuardianResults]: """ - Validate a list of words with the data rules defined.:param word: the word to validate + Validate a list of words with the data rules defined.:param word: the word to validate. + :param words: the words to validate :return: the metadata guardian results """ @@ -104,6 +107,7 @@ def validate_words(self, words: List[str]) -> List[MetadataGuardianResults]: def validate_file(self, path: str) -> List[MetadataGuardianResults]: """ Validate a file content with the data rules defined. + :param path: the file path :return: the metadata guardian results """ diff --git a/python/metadata_guardian/report.py b/python/metadata_guardian/report.py index fca97f2..88a9b4a 100644 --- a/python/metadata_guardian/report.py +++ b/python/metadata_guardian/report.py @@ -94,6 +94,7 @@ class MetadataGuardianReport: def append(self, other_report: "MetadataGuardianReport") -> None: """ Concat the results before making the report. + :param other_report: other report to append :return: """ @@ -102,13 +103,14 @@ def append(self, other_report: "MetadataGuardianReport") -> None: def to_console(self) -> None: """ Display the metadata guardian results to the console. + :return: """ _console = Console() _table = Table( title=":magnifying_glass_tilted_right: Metadata Guardian report", show_header=True, - header_style="bold dim", + header_style="bold", show_lines=True, ) _table.add_column("Category", style="yellow", no_wrap=True) diff --git a/python/metadata_guardian/scanner.py b/python/metadata_guardian/scanner.py index ea7af86..27c0822 100644 --- a/python/metadata_guardian/scanner.py +++ b/python/metadata_guardian/scanner.py @@ -21,6 +21,7 @@ class Scanner(ABC): def scan_local(self, source: LocalMetadataSource) -> MetadataGuardianReport: """ Scan the column names from the local source. + :param source: the LocalMetadataSource to scan :return: a Metadata Guardian report """ @@ -36,6 +37,7 @@ def scan_external( ) -> MetadataGuardianReport: """ Scan the column names from the external source. + :param source: the ExternalMetadataSource to scan :param database_name: the name of the database :param table_name: the name of the table @@ -55,6 +57,7 @@ async def scan_external_async( ) -> MetadataGuardianReport: """ Scan the column names from the external source asynchronously. + :param source: the ExternalMetadataSource to scan :param database_name: the name of the database :param tasks_limit: the limit of the tasks to run in parallel @@ -75,6 +78,7 @@ class ColumnScanner(Scanner): def scan_local(self, source: LocalMetadataSource) -> MetadataGuardianReport: """ Scan the column names from the local source. + :param source: the MetadataSource to scan :return: a Metadata Guardian report """ @@ -82,6 +86,11 @@ def scan_local(self, source: LocalMetadataSource) -> MetadataGuardianReport: f"[blue]Launch the metadata scanning of the local provider {source.type}" ) with ProgressionBar(disable=self.progression_bar_disable) as progression_bar: + progression_bar.add_task_with_item( + item_name=source.local_path, + source_type=source.type, + total=1, + ) report = MetadataGuardianReport( report_results=[ ReportResults( @@ -104,6 +113,7 @@ def scan_external( ) -> MetadataGuardianReport: """ Scan the column names from the external source using a table name or a database name. + :param source: the ExternalMetadataSource to scan :param database_name: the name of the database :param table_name: the name of the table diff --git a/python/metadata_guardian/source/__init__.py b/python/metadata_guardian/source/__init__.py index a444342..669d044 100644 --- a/python/metadata_guardian/source/__init__.py +++ b/python/metadata_guardian/source/__init__.py @@ -3,6 +3,7 @@ from .external.external_metadata_source import * from .external.gcp_source import * from .external.kafka_schema_registry_source import * +from .external.mysql_source import * from .external.snowflake_source import * from .local.avro_schema_source import * from .local.avro_source import * diff --git a/python/metadata_guardian/source/external/aws_source.py b/python/metadata_guardian/source/external/aws_source.py index 0ae97f1..cf61ebc 100644 --- a/python/metadata_guardian/source/external/aws_source.py +++ b/python/metadata_guardian/source/external/aws_source.py @@ -32,9 +32,9 @@ class AthenaSource(ExternalMetadataSource): aws_access_key_id: Optional[str] = None aws_secret_access_key: Optional[str] = None - def get_connection(self) -> None: + def create_connection(self) -> None: """ - Get Athena connection. + Create Athena connection. :return: """ self.connection = boto3.client( @@ -44,6 +44,9 @@ def get_connection(self) -> None: aws_secret_access_key=self.aws_secret_access_key, ) + def close_connection(self) -> None: + pass + def get_column_names( self, database_name: str, table_name: str, include_comment: bool = False ) -> List[str]: @@ -56,7 +59,7 @@ def get_column_names( """ try: if not self.connection: - self.get_connection() + self.create_connection() response = self.connection.get_table_metadata( CatalogName=self.catalog_name, DatabaseName=database_name, @@ -83,7 +86,7 @@ def get_table_names_list(self, database_name: str) -> List[str]: """ try: if not self.connection: - self.get_connection() + self.create_connection() table_names_list = list() response = self.connection.list_table_metadata( CatalogName=self.catalog_name, @@ -122,9 +125,9 @@ class GlueSource(ExternalMetadataSource): aws_access_key_id: Optional[str] = None aws_secret_access_key: Optional[str] = None - def get_connection(self) -> None: + def create_connection(self) -> None: """ - Get the Glue connection + Create the Glue connection :return: """ self.connection = boto3.client( @@ -134,6 +137,9 @@ def get_connection(self) -> None: aws_secret_access_key=self.aws_secret_access_key, ) + def close_connection(self) -> None: + pass + def get_column_names( self, database_name: str, table_name: str, include_comment: bool = False ) -> List[str]: @@ -146,16 +152,16 @@ def get_column_names( """ try: if not self.connection: - self.get_connection() + self.create_connection() response = self.connection.get_table( DatabaseName=database_name, Name=table_name ) columns = list() for row in response["Table"]["StorageDescriptor"]["Columns"]: - columns.append(row["Name"].lower()) + columns.append(row["Name"]) if include_comment: if "Comment" in row: - columns.append(row["Comment"].lower()) + columns.append(row["Comment"]) return columns except botocore.exceptions.ClientError as exception: logger.exception( @@ -171,7 +177,7 @@ def get_table_names_list(self, database_name: str) -> List[str]: """ try: if not self.connection: - self.get_connection() + self.create_connection() table_names_list = list() response = self.connection.get_tables( DatabaseName=database_name, diff --git a/python/metadata_guardian/source/external/deltatable_source.py b/python/metadata_guardian/source/external/deltatable_source.py index bb4240a..d9ac25b 100644 --- a/python/metadata_guardian/source/external/deltatable_source.py +++ b/python/metadata_guardian/source/external/deltatable_source.py @@ -22,14 +22,18 @@ class DeltaTableSource(ExternalMetadataSource): uri: str data_catalog: DataCatalog = DataCatalog.AWS + external_data_catalog_disable: bool = True - def get_connection(self) -> None: + def create_connection(self) -> None: """ - Get the DeltaTable instance. + Create the DeltaTable instance. :return: """ self.connection = DeltaTable(self.uri) + def close_connection(self) -> None: + pass + def get_column_names( self, database_name: Optional[str] = None, @@ -44,20 +48,24 @@ def get_column_names( :return: the list of the column names """ try: - if database_name and table_name: + if ( + not self.external_data_catalog_disable + and database_name + and table_name + ): self.connection = DeltaTable.from_data_catalog( data_catalog=self.data_catalog, database_name=database_name, table_name=table_name, ) elif not self.connection: - self.get_connection() + self.create_connection() schema = self.connection.schema() columns = list() for field in schema.fields: - columns.append(field.name.lower()) + columns.append(field.name) if include_comment and field.metadata: - columns.append(str(field.metadata).lower()) + columns.append(str(field.metadata)) return columns except Exception as exception: logger.exception( diff --git a/python/metadata_guardian/source/external/external_metadata_source.py b/python/metadata_guardian/source/external/external_metadata_source.py index ac22245..c0ac42c 100644 --- a/python/metadata_guardian/source/external/external_metadata_source.py +++ b/python/metadata_guardian/source/external/external_metadata_source.py @@ -1,6 +1,8 @@ from abc import abstractmethod from typing import Any, List, Optional +from loguru import logger + from ...exceptions import MetadataGuardianException from ..metadata_source import MetadataSource @@ -10,6 +12,26 @@ class ExternalMetadataSource(MetadataSource): connection: Optional[Any] = None + def __enter__(self) -> "ExternalMetadataSource": + try: + self.create_connection() + except Exception as exception: + logger.exception( + "Error raised while opening the Metadata Source connection" + ) + raise exception + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: # type: ignore + try: + self.close_connection() + except Exception as exception: + logger.exception( + "Error raised while closing the Metadata Source connection" + ) + raise exception + return self + @abstractmethod def get_column_names( self, @@ -36,9 +58,16 @@ def get_table_names_list(self, database_name: str) -> List[str]: pass @abstractmethod - def get_connection(self) -> None: + def create_connection(self) -> None: + """ + Create the connection of the source. + :return: + """ + pass + + def close_connection(self) -> None: """ - Get the connection of the source. + Close the connection of the source. :return: """ pass diff --git a/python/metadata_guardian/source/external/gcp_source.py b/python/metadata_guardian/source/external/gcp_source.py index 3954eab..41081bb 100644 --- a/python/metadata_guardian/source/external/gcp_source.py +++ b/python/metadata_guardian/source/external/gcp_source.py @@ -26,7 +26,7 @@ class BigQuerySource(ExternalMetadataSource): project: Optional[str] = None location: Optional[str] = None - def get_connection(self) -> None: + def create_connection(self) -> None: """ Get the Big Query connection. :return: @@ -54,16 +54,17 @@ def get_column_names( try: if not self.connection: - self.get_connection() - query_job = self.connection.query( - f'SELECT column_name, description FROM `{database_name}.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS` WHERE table_name = "{table_name}"' - ) - results = query_job.result() + self.create_connection() + + table_reference = self.connection.dataset( + database_name, project=self.project + ).table(table_name) + table = self.connection.get_table(table_reference) columns = list() - for row in results: - columns.append(row.column_name.lower()) - if include_comment and row.description: - columns.append(row.description.lower()) + for column in table.schema: + columns.append(column.name.lower()) + if include_comment and column.description: + columns.append(column.description.lower()) return columns except Exception as exception: logger.exception( @@ -80,7 +81,7 @@ def get_table_names_list(self, database_name: str) -> List[str]: try: if not self.connection: - self.get_connection() + self.create_connection() query_job = self.connection.query( f"SELECT table_name FROM `{database_name}.INFORMATION_SCHEMA.TABLES`" ) diff --git a/python/metadata_guardian/source/external/kafka_schema_registry_source.py b/python/metadata_guardian/source/external/kafka_schema_registry_source.py index feb3d1d..99a4b1a 100644 --- a/python/metadata_guardian/source/external/kafka_schema_registry_source.py +++ b/python/metadata_guardian/source/external/kafka_schema_registry_source.py @@ -32,15 +32,16 @@ class KafkaSchemaRegistrySource(ExternalMetadataSource): url: str ssl_certificate_location: Optional[str] = None ssl_key_location: Optional[str] = None - connection: Optional[Any] = None + connection: Optional[SchemaRegistryClient] = None authenticator: Optional[ KafkaSchemaRegistryAuthentication ] = KafkaSchemaRegistryAuthentication.USER_PWD comment_field_name: str = "doc" - def get_connection(self) -> None: + def create_connection(self) -> None: """ - Get the connection of the Kafka Schema Registry. + Create the connection of the Kafka Schema Registry. + :return: """ if self.authenticator == KafkaSchemaRegistryAuthentication.USER_PWD: @@ -52,11 +53,20 @@ def get_connection(self) -> None: else: raise NotImplementedError() + def close_connection(self) -> None: + """ + Close the Kafka Schema Registry connection. + + :return: + """ + self.connection.__exit__() + def get_column_names( self, database_name: str, table_name: str, include_comment: bool = False ) -> List[str]: """ Get the column names from the subject. + :param database_name: not relevant :param table_name: the subject name :param include_comment: include the comment @@ -64,13 +74,13 @@ def get_column_names( """ try: if not self.connection: - self.get_connection() + self.create_connection() registered_schema = self.connection.get_latest_version(table_name) columns = list() for field in json.loads(registered_schema.schema.schema_str)["fields"]: - columns.append(field["name"].lower()) + columns.append(field["name"]) if include_comment and self.comment_field_name in field: - columns.append(field[self.comment_field_name].lower()) + columns.append(field[self.comment_field_name]) return columns except Exception as exception: logger.exception( @@ -81,12 +91,13 @@ def get_column_names( def get_table_names_list(self, database_name: str) -> List[str]: """ Get all the subjects from the Schema Registry. + :param database_name: not relevant in that case :return: the list of the table names of the database """ try: if not self.connection: - self.get_connection() + self.create_connection() all_subjects = self.connection.get_subjects() return all_subjects except Exception as exception: @@ -99,6 +110,7 @@ def get_table_names_list(self, database_name: str) -> List[str]: def type(self) -> str: """ The type of the source. + :return: the name of the source. """ return "Kafka Schema Registry" diff --git a/python/metadata_guardian/source/external/mysql_source.py b/python/metadata_guardian/source/external/mysql_source.py new file mode 100644 index 0000000..704f0a4 --- /dev/null +++ b/python/metadata_guardian/source/external/mysql_source.py @@ -0,0 +1,119 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Any, List, Optional + +from loguru import logger + +from .external_metadata_source import ( + ExternalMetadataSource, + ExternalMetadataSourceException, +) + +try: + import pymysql + + MYSQL_INSTALLED = True +except ImportError: + logger.debug("MySQL optional dependency is not installed.") + MYSQL_INSTALLED = False + +if MYSQL_INSTALLED: + + class MySQLAuthenticator(Enum): + """Authentication method for MySQL source.""" + + USER_PWD = 1 + + @dataclass + class MySQLSource(ExternalMetadataSource): + """Instance of a MySQL source.""" + + user: str + password: str + host: str + database: Optional[str] = None + authenticator: MySQLAuthenticator = MySQLAuthenticator.USER_PWD + + def create_connection(self) -> None: + """ + Create a MySQL connection based on the MySQLAuthenticator. + + :return: + """ + if self.authenticator == MySQLAuthenticator.USER_PWD: + self.connection = pymysql.connect( + host=self.host, + user=self.user, + password=self.password, + database=self.database, + cursorclass=pymysql.cursors.DictCursor, + ) + + def get_column_names( + self, database_name: str, table_name: str, include_comment: bool = False + ) -> List[str]: + """ + Get column names from the table. + + :param database_name: the database name + :param table_name: the table name + :param include_comment: include the comment + :return: the list of the column names + """ + try: + if not self.connection or not self.connection.open: + self.create_connection() + cursor = self.connection.cursor() + cursor.execute(f"SHOW FULL COLUMNS FROM {database_name}.{table_name}") + rows = cursor.fetchall() + columns = list() + for row in rows: + column_name = row["Field"] + columns.append(column_name) + if include_comment: + column_comment = row["Comment"] + if column_comment: + columns.append(column_comment) + return columns + except Exception as exception: + logger.exception( + f"Error in getting columns name from MySQL {database_name}.{table_name}" + ) + raise exception + finally: + cursor.close() + + def get_table_names_list(self, database_name: str) -> List[str]: + """ + Get the table names list from the MySQL database. + + :param database_name: the database name + :return: the list of the table names of the database + """ + try: + if not self.connection or not self.connection.open: + self.create_connection() + cursor = self.connection.cursor() + cursor.execute(f"SHOW TABLES IN {database_name}") + rows = cursor.fetchall() + table_list = list() + for row in rows: + table_name = list(row.values())[0] + table_list.append(table_name) + return table_list + except Exception as exception: + logger.exception( + f"Error in getting table names from the database {database_name} in MySQL" + ) + raise ExternalMetadataSourceException(exception) + finally: + cursor.close() + + @property + def type(self) -> str: + """ + The type of the source. + + :return: the name of the source. + """ + return "MySQL" diff --git a/python/metadata_guardian/source/external/snowflake_source.py b/python/metadata_guardian/source/external/snowflake_source.py index cb86bb0..4c59956 100644 --- a/python/metadata_guardian/source/external/snowflake_source.py +++ b/python/metadata_guardian/source/external/snowflake_source.py @@ -2,7 +2,6 @@ from enum import Enum from typing import Any, List, Optional -import snowflake.connector from loguru import logger from .external_metadata_source import ( @@ -43,9 +42,10 @@ class SnowflakeSource(ExternalMetadataSource): oauth_host: Optional[str] = None authenticator: SnowflakeAuthenticator = SnowflakeAuthenticator.USER_PWD - def get_connection(self) -> None: + def create_connection(self) -> None: """ - Get a Snowflake connection based on the SnowflakeAuthenticator. + Create a Snowflake connection based on the SnowflakeAuthenticator. + :return: """ if self.authenticator == SnowflakeAuthenticator.USER_PWD: @@ -81,6 +81,7 @@ def get_column_names( ) -> List[str]: """ Get column names from the table. + :param database_name: the database name :param table_name: the table name :param include_comment: include the comment @@ -88,7 +89,7 @@ def get_column_names( """ try: if not self.connection or self.connection.is_closed(): - self.get_connection() + self.create_connection() cursor = self.connection.cursor() cursor.execute( f'SHOW COLUMNS IN "{database_name}"."{self.schema_name}"."{table_name}"' @@ -97,14 +98,14 @@ def get_column_names( columns = list() for row in rows: column_name = row[2] - columns.append(column_name.lower()) + columns.append(column_name) if include_comment: column_comment = row[8] - columns.append(column_comment.lower()) + columns.append(column_comment) return columns except Exception as exception: logger.exception( - f"Error in getting columns name from Snowflake {self.schema_name}.{database_name}.{table_name}" + f"Error in getting columns name from Snowflake {database_name}.{self.schema_name}.{table_name}" ) raise exception finally: @@ -113,12 +114,13 @@ def get_column_names( def get_table_names_list(self, database_name: str) -> List[str]: """ Get the table names list from the Snowflake database. + :param database_name: the database name :return: the list of the table names of the database """ try: if not self.connection or self.connection.is_closed(): - self.get_connection() + self.create_connection() cursor = self.connection.cursor() cursor.execute(f'SHOW TABLES IN DATABASE "{database_name}"') rows = cursor.fetchall() @@ -139,6 +141,7 @@ def get_table_names_list(self, database_name: str) -> List[str]: def type(self) -> str: """ The type of the source. + :return: the name of the source. """ return "Snowflake" diff --git a/python/metadata_guardian/source/local/avro_schema_source.py b/python/metadata_guardian/source/local/avro_schema_source.py index bb0b5fe..c0ce6a1 100644 --- a/python/metadata_guardian/source/local/avro_schema_source.py +++ b/python/metadata_guardian/source/local/avro_schema_source.py @@ -17,6 +17,7 @@ def read(self) -> Union[Text, bytes]: def schema(self) -> Dict[str, Any]: """ Get the AVRO schema. + :return: the schema """ content = self.read() @@ -25,6 +26,7 @@ def schema(self) -> Dict[str, Any]: def get_field_attribute(self, attribute_name: str) -> Optional[List[str]]: """ Get the specific attribute from the AVRO Schema file. + :param attribute_name: the attribute name to get :return: the list of attributes in the fields """ @@ -36,6 +38,7 @@ def get_field_attribute(self, attribute_name: str) -> Optional[List[str]]: def get_column_names(self) -> List[str]: """ Get column names from the AVRO Schema file. + :return: the list of the column names """ return [field["name"] for field in self.schema()["fields"]] @@ -44,6 +47,7 @@ def get_column_names(self) -> List[str]: def namespace(self) -> str: """ Namespace of the AVRO schema. + :return: the namespace """ return self.schema()["namespace"] @@ -52,6 +56,7 @@ def namespace(self) -> str: def type(self) -> str: """ The type of the source. + :return: the name o of the source. """ - return "LocalAvroSchema" + return "AvroSchema" diff --git a/python/metadata_guardian/source/local/avro_source.py b/python/metadata_guardian/source/local/avro_source.py index df855df..93848c4 100644 --- a/python/metadata_guardian/source/local/avro_source.py +++ b/python/metadata_guardian/source/local/avro_source.py @@ -30,6 +30,7 @@ def read(self) -> DataFileReader: def schema(self) -> Dict[str, Any]: """ Get the AVRO schema. + :return: the schema """ reader = self.read() @@ -38,6 +39,7 @@ def schema(self) -> Dict[str, Any]: def get_field_attribute(self, attribute_name: str) -> Optional[List[str]]: """ Get the specific attribute from the AVRO Schema. + :param attribute_name: the attribute name to get :return: the list of attributes in the fields """ @@ -49,6 +51,7 @@ def get_field_attribute(self, attribute_name: str) -> Optional[List[str]]: def get_column_names(self) -> List[str]: """ Get column names from the AVRO file. + :return: the list of the column names """ return [field["name"] for field in self.schema()["fields"]] @@ -57,6 +60,7 @@ def get_column_names(self) -> List[str]: def namespace(self) -> str: """ Namespace of the AVRO schema. + :return: the namespace """ return self.schema()["namespace"] @@ -65,6 +69,7 @@ def namespace(self) -> str: def type(self) -> str: """ The type of the source. + :return: the name o of the source. """ - return "LocalAvroSource" + return "Avro" diff --git a/python/metadata_guardian/source/local/local_metadata_source.py b/python/metadata_guardian/source/local/local_metadata_source.py index 0f4d276..f23993a 100644 --- a/python/metadata_guardian/source/local/local_metadata_source.py +++ b/python/metadata_guardian/source/local/local_metadata_source.py @@ -19,13 +19,15 @@ class LocalMetadataSource(MetadataSource): def read(self) -> Dataset: """ Read the source local file. + :return: the file content """ return pyarrow.dataset.dataset(self.local_path, filesystem=self.fs) def schema(self) -> Schema: """ - Get the source schema + Get the source schema. + :return: the file schema """ return self.read().schema @@ -33,9 +35,7 @@ def schema(self) -> Schema: def get_column_names(self) -> List[str]: """ Get the column names from the schema. + :return: the list of the column names """ return [column for column in self.schema().names] - - def type(self) -> str: - return "LocalFile" diff --git a/python/metadata_guardian/source/local/orc_source.py b/python/metadata_guardian/source/local/orc_source.py index 3c43344..5b3494b 100644 --- a/python/metadata_guardian/source/local/orc_source.py +++ b/python/metadata_guardian/source/local/orc_source.py @@ -14,6 +14,7 @@ class ORCSource(LocalMetadataSource): def read(self) -> ORCFile: """ Read the ORC file. + :return: """ return ORCFile(self.local_path) @@ -21,6 +22,7 @@ def read(self) -> ORCFile: def schema(self) -> pyarrow.Schema: """ Get the ORC File. + :return: the orc schema """ return self.read().schema @@ -28,6 +30,7 @@ def schema(self) -> pyarrow.Schema: def get_column_names(self) -> List[str]: """ Get the column names from the schema. + :return: the list of the column names """ return self.schema().names @@ -36,6 +39,7 @@ def get_column_names(self) -> List[str]: def type(self) -> str: """ The type of the source. + :return: the name of the source. """ - return "LocalORC" + return "ORC" diff --git a/python/metadata_guardian/source/local/parquet_source.py b/python/metadata_guardian/source/local/parquet_source.py index fae1a79..adf38c5 100644 --- a/python/metadata_guardian/source/local/parquet_source.py +++ b/python/metadata_guardian/source/local/parquet_source.py @@ -11,6 +11,7 @@ class ParquetSource(LocalMetadataSource): def type(self) -> str: """ The type of the source. + :return: the name of the source. """ - return "LocalParquet" + return "Parquet" diff --git a/python/metadata_guardian/source/metadata_source.py b/python/metadata_guardian/source/metadata_source.py index 9ec0530..8ed4de1 100644 --- a/python/metadata_guardian/source/metadata_source.py +++ b/python/metadata_guardian/source/metadata_source.py @@ -9,6 +9,7 @@ class MetadataSource(ABC): def type(self) -> str: """ The type of the source. + :return: the name o of the source. """ pass diff --git a/python/pyproject.toml b/python/pyproject.toml index d3d33de..005d188 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -23,13 +23,14 @@ dependencies = [ ] [project.optional-dependencies] -all = ["avro", "snowflake-connector-python", "boto3", "boto3-stubs[athena,glue]", "deltalake", "google-cloud-bigquery", "confluent-kafka"] +all = ["avro", "snowflake-connector-python", "boto3", "boto3-stubs[athena,glue]", "deltalake", "google-cloud-bigquery", "confluent-kafka", "PyMySQL"] snowflake = [ "snowflake-connector-python" ] avro = [ "avro" ] aws = [ "boto3", "boto3-stubs[athena,glue]" ] gcp = [ "google-cloud-bigquery"] deltalake = [ "deltalake" ] kafka_schema_registry = [ "confluent-kafka" ] +mysql = ["PyMySQL"] devel = [ "mypy", "black", diff --git a/python/tests/external/test_deltatable_source.py b/python/tests/external/test_deltatable_source.py index d97da5d..77bc6ad 100644 --- a/python/tests/external/test_deltatable_source.py +++ b/python/tests/external/test_deltatable_source.py @@ -45,6 +45,7 @@ def test_deltatable_source_get_column_names_from_database_and_table(mock_connect uri = "s3://test_table" database_name = "database_name" table_name = "table_name" + external_data_catalog_disable = False schema = Schema( fields=[ Field( @@ -71,7 +72,9 @@ def test_deltatable_source_get_column_names_from_database_and_table(mock_connect "{'comment': 'comment2'}", ] - column_names = DeltaTableSource(uri=uri).get_column_names( + column_names = DeltaTableSource( + uri=uri, external_data_catalog_disable=external_data_catalog_disable + ).get_column_names( database_name=database_name, table_name=table_name, include_comment=True ) diff --git a/python/tests/external/test_gcp_source.py b/python/tests/external/test_gcp_source.py index c676424..907bcc4 100644 --- a/python/tests/external/test_gcp_source.py +++ b/python/tests/external/test_gcp_source.py @@ -1,6 +1,8 @@ from types import SimpleNamespace from unittest.mock import Mock, patch +from google.cloud import bigquery + from metadata_guardian.source import BigQuerySource @@ -9,14 +11,18 @@ def test_big_query_source_get_column_names(mock_connection): service_account_json_path = "" dataset_name = "test_dataset" table_name = "test_table" - results = [ - SimpleNamespace(column_name="timestamp", description="description1"), - SimpleNamespace(column_name="address_id", description="description2"), - ] + results = SimpleNamespace( + schema=[ + bigquery.SchemaField( + "timestamp", "STRING", mode="REQUIRED", description="description1" + ), + bigquery.SchemaField( + "address_id", "STRING", mode="REQUIRED", description="description2" + ), + ] + ) mock_connection.return_value = mock_connection - response = Mock() - response.result.return_value = results - mock_connection.query.return_value = response + mock_connection.get_table.return_value = results expected = ["timestamp", "description1", "address_id", "description2"] column_names = BigQuerySource( diff --git a/python/tests/external/test_mysql_source.py b/python/tests/external/test_mysql_source.py new file mode 100644 index 0000000..7761236 --- /dev/null +++ b/python/tests/external/test_mysql_source.py @@ -0,0 +1,80 @@ +from unittest.mock import patch + +from metadata_guardian.source import MySQLAuthenticator, MySQLSource + + +@patch("pymysql.connect") +def test_mysql_source_get_column_names(mock_connection): + database_name = "test" + table_name = "test_table" + user = "user" + host = "localhost" + password = "password" + mock_connection.cursor.return_value = mock_connection + mock_connection.fetchall.return_value = [ + { + "Field": "words", + "Type": "varchar(45)", + "Collation": "utf8_general_ci", + "Null": "YES", + "Key": "", + "Default": None, + "Extra": "", + "Privileges": "select,insert,update,references", + "Comment": "Use column to contain words", + }, + { + "Field": "name", + "Type": "varchar(45)", + "Collation": "utf8_general_ci", + "Null": "YES", + "Key": "", + "Default": None, + "Extra": "", + "Privileges": "select,insert,update,references", + "Comment": "", + }, + ] + mock_connection.execute.call_args == f"SHOW FULL COLUMNS FROM {database_name}.{table_name}" + expected = ["words", "Use column to contain words", "name"] + + source = MySQLSource( + host=host, + user=user, + password=password, + ) + source.connection = mock_connection + + column_names = source.get_column_names( + database_name=database_name, table_name=table_name, include_comment=True + ) + + assert column_names == expected + assert source.authenticator == MySQLAuthenticator.USER_PWD + + +@patch("pymysql.connect") +def test_mysql_source_get_table_names_list(mock_connection): + database_name = "test" + user = "user" + host = "localhost" + password = "password" + mock_connection.cursor.return_value = mock_connection + mock_connection.fetchall.return_value = [ + {"Tables_in_test": "t1"}, + {"Tables_in_test": "t2"}, + ] + mock_connection.execute.call_args == f"SHOW TABLES IN {database_name}" + expected = ["t1", "t2"] + + source = MySQLSource( + host=host, + user=user, + password=password, + ) + source.connection = mock_connection + + table_names = source.get_table_names_list(database_name=database_name) + + assert table_names == expected + assert source.authenticator == MySQLAuthenticator.USER_PWD diff --git a/python/tests/external/test_snowflake_source.py b/python/tests/external/test_snowflake_source.py index acdb0d4..e89fedf 100644 --- a/python/tests/external/test_snowflake_source.py +++ b/python/tests/external/test_snowflake_source.py @@ -46,7 +46,6 @@ def test_snowflake_source_get_table_names_list(mock_connection): sf_password = "sf_password" warehouse = "warehouse" mocked_cursor_one = mock_connection.connect().cursor.return_value - mocked_cursor_one.description = [["name"], ["phone"]] mocked_cursor_one.fetchall.return_value = [ (database_name, "TEST_TABLE"), (database_name, "TEST_TABLE2"), @@ -62,7 +61,7 @@ def test_snowflake_source_get_table_names_list(mock_connection): schema_name=schema_name, ) - column_names = source.get_table_names_list(database_name=database_name) + talbe_names = source.get_table_names_list(database_name=database_name) - assert column_names == expected + assert talbe_names == expected assert source.authenticator == SnowflakeAuthenticator.USER_PWD