Commit 2d9edd1

Add KafkaSchemaRegistrySource in the external providers in the Python binding (#10)
fvaleye authored Dec 30, 2021
1 parent db32826 commit 2d9edd1
Showing 8 changed files with 227 additions and 13 deletions.
14 changes: 9 additions & 5 deletions python/metadata_guardian/report.py
@@ -2,6 +2,7 @@
from typing import List, NamedTuple, Optional, Tuple

from rich.console import Console
from rich.markup import escape
from rich.progress import (
BarColumn,
Progress,
@@ -23,13 +24,14 @@ class ProgressionBar(Progress):

task_id: Optional[TaskID] = None

def __init__(self) -> None:
def __init__(self, disable: bool) -> None:
super().__init__(
SpinnerColumn(),
"[progress.description]{task.description}: [red]{task.fields[current_item]}",
BarColumn(),
"[progress.percentage]{task.percentage:>3.0f}% ({task.completed}/{task.total})-",
TimeRemainingColumn(),
disable=disable,
)

def __enter__(self) -> "ProgressionBar":
@@ -43,10 +45,10 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:  # type: ignore

def add_task_with_item(
self,
item_name: Optional[str],
item_name: str,
source_type: str,
total: int,
current_item: str = "",
current_item: str = "Starting",
) -> None:
"""
Add task in the Progression Bar.
@@ -56,10 +58,12 @@ def add_task_with_item(
:param total: the total number of tables
:return: None, the created task id is stored on the instance
"""
task_details = f"[{item_name}]" if item_name else ""
task_description = f"[bold cyan]Searching in the {escape(source_type)} metadata source{escape(task_details)}"
task_id = super().add_task(
f"[bold cyan]Searching in {item_name} for the {source_type} metadata source",
description=task_description,
total=total,
current_item=current_item,
current_item=escape(current_item),
)
self.task_id = task_id

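The new disable flag is passed straight through to rich's Progress, so callers can turn the progress rendering off entirely (useful in CI logs or when piping output). A minimal sketch of the changed API, assuming ProgressionBar is used directly as a context manager; the item values are placeholders:

from metadata_guardian.report import ProgressionBar

# disable=True silences the rich output; the task bookkeeping still runs
with ProgressionBar(disable=True) as progression_bar:
    progression_bar.add_task_with_item(
        item_name="my_database",   # placeholder item name
        source_type="Snowflake",
        total=10,
    )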
16 changes: 10 additions & 6 deletions python/metadata_guardian/scanner.py
@@ -13,7 +13,9 @@


class Scanner(ABC):
"""Scanner interface."""
"""
Scanner Interface.
"""

@abstractmethod
def scan_local(self, source: LocalMetadataSource) -> MetadataGuardianReport:
@@ -68,6 +70,7 @@ class ColumnScanner(Scanner):
"""Column Scanner instance."""

data_rules: DataRules
progression_bar_disable: bool = False

def scan_local(self, source: LocalMetadataSource) -> MetadataGuardianReport:
"""
@@ -78,7 +81,7 @@ def scan_local(self, source: LocalMetadataSource) -> MetadataGuardianReport:
logger.debug(
f"[blue]Launch the metadata scanning of the local provider {source.type}"
)
with ProgressionBar() as progression_bar:
with ProgressionBar(disable=self.progression_bar_disable) as progression_bar:
report = MetadataGuardianReport(
report_results=[
ReportResults(
@@ -108,9 +111,9 @@ def scan_external(
:return: a Metadata Guardian report
"""
logger.debug(
f"[blue]Launch the metadata scanning of the external provider {source.type} for the database {database_name}"
f"[blue]Launch the metadata scanning of the external provider {source.type} for {database_name}"
)
with ProgressionBar() as progression_bar:
with ProgressionBar(disable=self.progression_bar_disable) as progression_bar:
if table_name:
progression_bar.add_task_with_item(
item_name=database_name,
@@ -206,7 +209,7 @@ async def async_validate_words(
results=self.data_rules.validate_words(words=words),
)

with ProgressionBar() as progression_bar:
with ProgressionBar(disable=self.progression_bar_disable) as progression_bar:
if table_name:
tasks = [
async_validate_words(
@@ -239,6 +242,7 @@ class ContentFilesScanner:
"""Content Files Scanner instance."""

data_rules: DataRules
progression_bar_disable: bool = False

def scan_local_file(self, path: str) -> MetadataGuardianReport:
"""
@@ -250,7 +254,7 @@ def scan_local_file(self, path: str) -> MetadataGuardianReport:
f"[blue]Launch the metadata scanning the content of the file {path}"
)
progression_bar: ProgressionBar
with ProgressionBar() as progression_bar:
with ProgressionBar(disable=self.progression_bar_disable) as progression_bar:
progression_bar.add_task_with_item(
item_name=path, source_type="files", total=1
)
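Both ColumnScanner and ContentFilesScanner expose the flag as a dataclass field defaulting to False, so existing callers are unaffected. A hedged sketch of the quiet mode, assuming keyword construction as the dataclass fields suggest; the data_rules and source arguments are left abstract since their construction is not part of this diff:

from metadata_guardian.scanner import ColumnScanner

def scan_quietly(data_rules, source, database_name):
    # progression_bar_disable=True propagates to every ProgressionBar the scanner opens
    scanner = ColumnScanner(data_rules=data_rules, progression_bar_disable=True)
    return scanner.scan_external(source=source, database_name=database_name)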
1 change: 1 addition & 0 deletions python/metadata_guardian/source/__init__.py
@@ -2,6 +2,7 @@
from .external.deltatable_source import *
from .external.external_metadata_source import *
from .external.gcp_source import *
from .external.kafka_schema_registry_source import *
from .external.snowflake_source import *
from .local.avro_schema_source import *
from .local.avro_source import *
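The wildcard re-export is what makes the flat import used by the new tests resolve:

from metadata_guardian.source import KafkaSchemaRegistrySource  # re-exported via external.kafka_schema_registry_source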
2 changes: 1 addition & 1 deletion python/metadata_guardian/source/external/aws_source.py
@@ -48,7 +48,7 @@ def get_column_names(
self, database_name: str, table_name: str, include_comment: bool = False
) -> List[str]:
"""
Get column names from the table.
Get the column names from the table.
:param database_name: the database name
:param table_name: the table name
:param include_comment: include the comment
104 changes: 104 additions & 0 deletions python/metadata_guardian/source/external/kafka_schema_registry_source.py
@@ -0,0 +1,104 @@
import json
from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Optional

from loguru import logger

from .external_metadata_source import (
ExternalMetadataSource,
ExternalMetadataSourceException,
)

try:
from confluent_kafka.schema_registry import SchemaRegistryClient

KAFKA_SCHEMA_REGISTRY_INSTALLED = True
except ImportError:
logger.debug("Kafka Schema Registry optional dependency is not installed.")
KAFKA_SCHEMA_REGISTRY_INSTALLED = False

if KAFKA_SCHEMA_REGISTRY_INSTALLED:

class KafkaSchemaRegistryAuthentication(Enum):
"""Authentication method for Kafka Schema Registry source."""

USER_PWD = 1

@dataclass
class KafkaSchemaRegistrySource(ExternalMetadataSource):
"""Instance of a Kafka Schema Registry source."""

url: str
ssl_certificate_location: Optional[str] = None
ssl_key_location: Optional[str] = None
connection: Optional[Any] = None
authenticator: Optional[
KafkaSchemaRegistryAuthentication
] = KafkaSchemaRegistryAuthentication.USER_PWD
comment_field_name: str = "doc"

def get_connection(self) -> None:
"""
Get the connection to the Kafka Schema Registry.
:return:
"""
if self.authenticator == KafkaSchemaRegistryAuthentication.USER_PWD:
self.connection = SchemaRegistryClient(
{
"url": self.url,
}
)
else:
raise NotImplementedError()

def get_column_names(
self, database_name: str, table_name: str, include_comment: bool = False
) -> List[str]:
"""
Get the column names from the subject.
:param database_name: not relevant
:param table_name: the subject name
:param include_comment: include the comment
:return: the list of the column names
"""
try:
if not self.connection:
self.get_connection()
registered_schema = self.connection.get_latest_version(table_name)
columns = list()
for field in json.loads(registered_schema.schema.schema_str)["fields"]:
columns.append(field["name"].lower())
if include_comment and self.comment_field_name in field:
columns.append(field[self.comment_field_name].lower())
return columns
except Exception as exception:
logger.exception(
f"Error in getting columns name from the Kafka Schema Registry {table_name}"
)
raise exception

def get_table_names_list(self, database_name: str) -> List[str]:
"""
Get all the subjects from the Schema Registry.
:param database_name: not relevant in that case
:return: the list of the table names of the database
"""
try:
if not self.connection:
self.get_connection()
all_subjects = self.connection.get_subjects()
return all_subjects
except Exception as exception:
logger.exception(
f"Error all the subjects from the subject in the Kafka Schema Registry"
)
raise ExternalMetadataSourceException(exception)

@property
def type(self) -> str:
"""
The type of the source.
:return: the name of the source.
"""
return "Kafka Schema Registry"
3 changes: 2 additions & 1 deletion python/pyproject.toml
@@ -23,12 +23,13 @@ dependencies = [
]

[project.optional-dependencies]
all = ["avro", "snowflake-connector-python", "boto3", "boto3-stubs[athena,glue]", "deltalake", "google-cloud-bigquery"]
all = ["avro", "snowflake-connector-python", "boto3", "boto3-stubs[athena,glue]", "deltalake", "google-cloud-bigquery", "confluent-kafka"]
snowflake = [ "snowflake-connector-python" ]
avro = [ "avro" ]
aws = [ "boto3", "boto3-stubs[athena,glue]" ]
gcp = [ "google-cloud-bigquery"]
deltalake = [ "deltalake" ]
kafka_schema_registry = [ "confluent-kafka" ]
devel = [
"mypy",
"black",
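With this change confluent-kafka becomes an optional dependency, installable through the new kafka_schema_registry extra or through all. Assuming the project name declared in pyproject.toml is metadata_guardian, the install would look like pip install 'metadata_guardian[kafka_schema_registry]'.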
69 changes: 69 additions & 0 deletions python/tests/external/test_kafka_schema_registry.py
@@ -0,0 +1,69 @@
from unittest.mock import patch

from confluent_kafka.schema_registry import RegisteredSchema, Schema

from metadata_guardian.source import (
KafkaSchemaRegistryAuthentication,
KafkaSchemaRegistrySource,
)


@patch("confluent_kafka.schema_registry.SchemaRegistryClient")
def test_kafka_schema_registry_source_get_column_names(mock_connection):
url = "url"
subject_name = "subject_name"
expected = ["key", "value", "doc"]

source = KafkaSchemaRegistrySource(
url=url,
)
schema_id = "schema_id"
schema_str = """{
"fields": [
{
"name": "key",
"type": "string"
},
{
"name": "value",
"type": "string",
"doc": "doc"
}
],
"name": "test_one",
"namespace": "test.one",
"type": "record"
}"""
schema = RegisteredSchema(
schema_id=schema_id,
schema=Schema(schema_str, "AVRO", []),
subject=subject_name,
version=1,
)
mock_connection.get_latest_version.return_value = schema
source.connection = mock_connection

column_names = source.get_column_names(
database_name=None, table_name=subject_name, include_comment=True
)

assert column_names == expected
assert source.authenticator == KafkaSchemaRegistryAuthentication.USER_PWD


@patch("confluent_kafka.schema_registry.SchemaRegistryClient")
def test_kafka_schema_registry_source_get_table_names_list(mock_connection):
url = "url"
expected = ["subject1", "subject2"]

source = KafkaSchemaRegistrySource(
url=url,
)
subjects = ["subject1", "subject2"]
mock_connection.get_subjects.return_value = subjects
source.connection = mock_connection

subjects_list = source.get_table_names_list(database_name=None)

assert subjects_list == expected
assert source.authenticator == KafkaSchemaRegistryAuthentication.USER_PWD
31 changes: 31 additions & 0 deletions python/tests/external/test_snowflake_source.py
@@ -35,3 +35,34 @@ def test_snowflake_source_get_column_names(mock_connection):

assert column_names == expected
assert source.authenticator == SnowflakeAuthenticator.USER_PWD


@patch("snowflake.connector")
def test_snowflake_source_get_table_names_list(mock_connection):
database_name = "test_database"
schema_name = "PUBLIC"
sf_account = "sf_account"
sf_user = "sf_user"
sf_password = "sf_password"
warehouse = "warehouse"
mocked_cursor_one = mock_connection.connect().cursor.return_value
mocked_cursor_one.description = [["name"], ["phone"]]
mocked_cursor_one.fetchall.return_value = [
(database_name, "TEST_TABLE"),
(database_name, "TEST_TABLE2"),
]
# the source is expected to execute: SHOW TABLES IN DATABASE "test_database"
expected = ["TEST_TABLE", "TEST_TABLE2"]

source = SnowflakeSource(
sf_account=sf_account,
sf_user=sf_user,
sf_password=sf_password,
warehouse=warehouse,
schema_name=schema_name,
)

table_names = source.get_table_names_list(database_name=database_name)

assert table_names == expected
assert source.authenticator == SnowflakeAuthenticator.USER_PWD
