diff --git a/databuilder/loader/file_system_neo4j_csv_loader.py b/databuilder/loader/file_system_neo4j_csv_loader.py index 3c0fa1aaa..6d2f7c467 100644 --- a/databuilder/loader/file_system_neo4j_csv_loader.py +++ b/databuilder/loader/file_system_neo4j_csv_loader.py @@ -12,10 +12,9 @@ from databuilder.job.base_job import Job from databuilder.loader.base_loader import Loader -from databuilder.models.neo4j_csv_serde import NODE_LABEL, \ - RELATION_START_LABEL, RELATION_END_LABEL, RELATION_TYPE -from databuilder.models.neo4j_csv_serde import Neo4jCsvSerializable +from databuilder.models.graph_serializable import GraphSerializable from databuilder.utils.closer import Closer +from databuilder.serializers import neo4_serializer LOGGER = logging.getLogger(__name__) @@ -90,7 +89,7 @@ def _delete_dir() -> None: # Directory should be deleted after publish is finished Job.closer.register(_delete_dir) - def load(self, csv_serializable: Neo4jCsvSerializable) -> None: + def load(self, csv_serializable: GraphSerializable) -> None: """ Writes Neo4jCsvSerializable into CSV files. There are multiple CSV files that this method writes. @@ -107,9 +106,10 @@ def load(self, csv_serializable: Neo4jCsvSerializable) -> None: :return: """ - node_dict = csv_serializable.next_node() - while node_dict: - key = (node_dict[NODE_LABEL], len(node_dict)) + node = csv_serializable.next_node() + while node: + node_dict = neo4_serializer.serialize_node(node) + key = (node.label, len(node_dict)) file_suffix = '{}_{}'.format(*key) node_writer = self._get_writer(node_dict, self._node_file_mapping, @@ -117,13 +117,14 @@ def load(self, csv_serializable: Neo4jCsvSerializable) -> None: self._node_dir, file_suffix) node_writer.writerow(node_dict) - node_dict = csv_serializable.next_node() - - relation_dict = csv_serializable.next_relation() - while relation_dict: - key2 = (relation_dict[RELATION_START_LABEL], - relation_dict[RELATION_END_LABEL], - relation_dict[RELATION_TYPE], + node = csv_serializable.next_node() + + relation = csv_serializable.next_relation() + while relation: + relation_dict = neo4_serializer.serialize_relationship(relation) + key2 = (relation.start_label, + relation.end_label, + relation.type, len(relation_dict)) file_suffix = '{}_{}_{}'.format(key2[0], key2[1], key2[2]) @@ -133,7 +134,7 @@ def load(self, csv_serializable: Neo4jCsvSerializable) -> None: self._relation_dir, file_suffix) relation_writer.writerow(relation_dict) - relation_dict = csv_serializable.next_relation() + relation = csv_serializable.next_relation() def _get_writer(self, csv_record_dict: Dict[str, Any], diff --git a/databuilder/models/application.py b/databuilder/models/application.py index ba2c126dc..d86b0e2ce 100644 --- a/databuilder/models/application.py +++ b/databuilder/models/application.py @@ -1,16 +1,16 @@ # Copyright Contributors to the Amundsen project. # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Union +from typing import List, Union -from databuilder.models.neo4j_csv_serde import Neo4jCsvSerializable, NODE_KEY, \ - NODE_LABEL, RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.models.graph_serializable import GraphSerializable from databuilder.models.table_metadata import TableMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship -class Application(Neo4jCsvSerializable): +class Application(GraphSerializable): """ Application-table matching model (Airflow task and table) """ @@ -48,14 +48,14 @@ def __init__(self, self._node_iter = iter(self.create_nodes()) self._relation_iter = iter(self.create_relation()) - def create_next_node(self) -> Union[Dict[str, Any], None]: + def create_next_node(self) -> Union[GraphNode, None]: # creates new node try: return next(self._node_iter) except StopIteration: return None - def create_next_relation(self) -> Union[Dict[str, Any], None]: + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._relation_iter) except StopIteration: @@ -74,40 +74,47 @@ def get_application_model_key(self) -> str: dag=self.dag, task=self.task) - def create_nodes(self) -> List[Dict[str, Any]]: + def create_nodes(self) -> List[GraphNode]: """ Create a list of Neo4j node records :return: """ results = [] - - results.append({ - NODE_KEY: self.get_application_model_key(), - NODE_LABEL: Application.APPLICATION_LABEL, - Application.APPLICATION_URL_NAME: self.application_url, - Application.APPLICATION_NAME: Application.APPLICATION_TYPE, - Application.APPLICATION_DESCRIPTION: - '{app_type} with id {id}'.format(app_type=Application.APPLICATION_TYPE, - id=Application.APPLICATION_ID_FORMAT.format(dag_id=self.dag, - task_id=self.task)), - Application.APPLICATION_ID: Application.APPLICATION_ID_FORMAT.format(dag_id=self.dag, - task_id=self.task) - }) + application_description = '{app_type} with id {id}'.format( + app_type=Application.APPLICATION_TYPE, + id=Application.APPLICATION_ID_FORMAT.format(dag_id=self.dag, task_id=self.task) + ) + application_id = Application.APPLICATION_ID_FORMAT.format( + dag_id=self.dag, + task_id=self.task + ) + application_node = GraphNode( + key=self.get_application_model_key(), + label=Application.APPLICATION_LABEL, + attributes={ + Application.APPLICATION_URL_NAME: self.application_url, + Application.APPLICATION_NAME: Application.APPLICATION_TYPE, + Application.APPLICATION_DESCRIPTION: application_description, + Application.APPLICATION_ID: application_id + } + ) + results.append(application_node) return results - def create_relation(self) -> List[Dict[str, Any]]: + def create_relation(self) -> List[GraphRelationship]: """ Create a list of relations between application and table nodes :return: """ - results = [{ - RELATION_START_KEY: self.get_table_model_key(), - RELATION_START_LABEL: TableMetadata.TABLE_NODE_LABEL, - RELATION_END_KEY: self.get_application_model_key(), - RELATION_END_LABEL: Application.APPLICATION_LABEL, - RELATION_TYPE: Application.TABLE_APPLICATION_RELATION_TYPE, - RELATION_REVERSE_TYPE: Application.APPLICATION_TABLE_RELATION_TYPE - }] - + graph_relationship = GraphRelationship( + start_key=self.get_table_model_key(), + start_label=TableMetadata.TABLE_NODE_LABEL, + end_key=self.get_application_model_key(), + end_label=Application.APPLICATION_LABEL, + type=Application.TABLE_APPLICATION_RELATION_TYPE, + reverse_type=Application.APPLICATION_TABLE_RELATION_TYPE, + attributes={} + ) + results = [graph_relationship] return results diff --git a/databuilder/models/badge.py b/databuilder/models/badge.py index 9322b82ba..f1ef2fc12 100644 --- a/databuilder/models/badge.py +++ b/databuilder/models/badge.py @@ -1,12 +1,12 @@ # Copyright Contributors to the Amundsen project. # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import List, Optional import re -from databuilder.models.neo4j_csv_serde import Neo4jCsvSerializable, NODE_KEY, \ - NODE_LABEL, RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship class Badge: @@ -19,7 +19,7 @@ def __repr__(self) -> str: self.category) -class BadgeMetadata(Neo4jCsvSerializable): +class BadgeMetadata(GraphSerializable): """ Badge model. """ @@ -62,14 +62,14 @@ def __repr__(self) -> str: return 'BadgeMetadata({!r}, {!r})'.format(self.start_label, self.start_key) - def create_next_node(self) -> Optional[Dict[str, Any]]: + def create_next_node(self) -> Optional[GraphNode]: # return the string representation of the data try: return next(self._node_iter) except StopIteration: return None - def create_next_relation(self) -> Optional[Dict[str, Any]]: + def create_next_relation(self) -> Optional[GraphRelationship]: try: return next(self._relation_iter) except StopIteration: @@ -84,7 +84,7 @@ def get_badge_key(name: str) -> str: def get_metadata_model_key(self) -> str: return self.start_key - def create_nodes(self) -> List[Dict[str, Any]]: + def create_nodes(self) -> List[GraphNode]: """ Create a list of Neo4j node records :return: @@ -92,22 +92,27 @@ def create_nodes(self) -> List[Dict[str, Any]]: results = [] for badge in self.badges: if badge: - results.append({ - NODE_KEY: self.get_badge_key(badge.name), - NODE_LABEL: self.BADGE_NODE_LABEL, - self.BADGE_CATEGORY: badge.category - }) + node = GraphNode( + key=self.get_badge_key(badge.name), + label=self.BADGE_NODE_LABEL, + attributes={ + self.BADGE_CATEGORY: badge.category + } + ) + results.append(node) return results - def create_relation(self) -> List[Dict[str, Any]]: + def create_relation(self) -> List[GraphRelationship]: results = [] for badge in self.badges: - results.append({ - RELATION_START_LABEL: self.start_label, - RELATION_END_LABEL: self.BADGE_NODE_LABEL, - RELATION_START_KEY: self.start_key, - RELATION_END_KEY: self.get_badge_key(badge.name), - RELATION_TYPE: self.BADGE_RELATION_TYPE, - RELATION_REVERSE_TYPE: self.INVERSE_BADGE_RELATION_TYPE, - }) + relation = GraphRelationship( + start_label=self.start_label, + end_label=self.BADGE_NODE_LABEL, + start_key=self.start_key, + end_key=self.get_badge_key(badge.name), + type=self.BADGE_RELATION_TYPE, + reverse_type=self.INVERSE_BADGE_RELATION_TYPE, + attributes={} + ) + results.append(relation) return results diff --git a/databuilder/models/column_usage_model.py b/databuilder/models/column_usage_model.py index 6d1a1e039..d37babdd9 100644 --- a/databuilder/models/column_usage_model.py +++ b/databuilder/models/column_usage_model.py @@ -1,20 +1,19 @@ # Copyright Contributors to the Amundsen project. # SPDX-License-Identifier: Apache-2.0 -from typing import Union, Dict, Any, Iterable, List +from typing import Union, Iterable, List -from databuilder.models.neo4j_csv_serde import ( - Neo4jCsvSerializable, RELATION_START_KEY, RELATION_END_KEY, - RELATION_START_LABEL, RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE -) +from databuilder.models.graph_serializable import GraphSerializable from databuilder.models.usage.usage_constants import ( READ_RELATION_TYPE, READ_REVERSE_RELATION_TYPE, READ_RELATION_COUNT_PROPERTY ) from databuilder.models.table_metadata import TableMetadata from databuilder.models.user import User +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship -class ColumnUsageModel(Neo4jCsvSerializable): +class ColumnUsageModel(GraphSerializable): """ A model represents user <--> column graph model @@ -49,14 +48,14 @@ def __init__(self, self._node_iter = iter(self.create_nodes()) self._relation_iter = iter(self.create_relation()) - def create_next_node(self) -> Union[Dict[str, Any], None]: + def create_next_node(self) -> Union[GraphNode, None]: try: return next(self._node_iter) except StopIteration: return None - def create_nodes(self) -> List[Dict[str, Any]]: + def create_nodes(self) -> List[GraphNode]: """ Create a list of Neo4j node records :return: @@ -64,23 +63,25 @@ def create_nodes(self) -> List[Dict[str, Any]]: return User(email=self.user_email).create_nodes() - def create_next_relation(self) -> Union[Dict[str, Any], None]: - + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._relation_iter) except StopIteration: return None - def create_relation(self) -> Iterable[Any]: - return [{ - RELATION_START_LABEL: TableMetadata.TABLE_NODE_LABEL, - RELATION_END_LABEL: User.USER_NODE_LABEL, - RELATION_START_KEY: self._get_table_key(), - RELATION_END_KEY: self._get_user_key(self.user_email), - RELATION_TYPE: ColumnUsageModel.TABLE_USER_RELATION_TYPE, - RELATION_REVERSE_TYPE: ColumnUsageModel.USER_TABLE_RELATION_TYPE, - ColumnUsageModel.READ_RELATION_COUNT: self.read_count - }] + def create_relation(self) -> Iterable[GraphRelationship]: + relationship = GraphRelationship( + start_key=self._get_table_key(), + start_label=TableMetadata.TABLE_NODE_LABEL, + end_key=self._get_user_key(self.user_email), + end_label=User.USER_NODE_LABEL, + type=ColumnUsageModel.TABLE_USER_RELATION_TYPE, + reverse_type=ColumnUsageModel.USER_TABLE_RELATION_TYPE, + attributes={ + ColumnUsageModel.READ_RELATION_COUNT: self.read_count + } + ) + return [relationship] def _get_table_key(self) -> str: return TableMetadata.TABLE_KEY_FORMAT.format(db=self.database, diff --git a/databuilder/models/dashboard/dashboard_chart.py b/databuilder/models/dashboard/dashboard_chart.py index d11ff36e6..2185ce551 100644 --- a/databuilder/models/dashboard/dashboard_chart.py +++ b/databuilder/models/dashboard/dashboard_chart.py @@ -2,18 +2,20 @@ # SPDX-License-Identifier: Apache-2.0 import logging +from typing import Optional, Any, Union, Iterator -from typing import Optional, Dict, Any, Union, Iterator from databuilder.models.dashboard.dashboard_query import DashboardQuery -from databuilder.models.neo4j_csv_serde import ( - Neo4jCsvSerializable, NODE_LABEL, NODE_KEY, RELATION_START_KEY, RELATION_END_KEY, RELATION_START_LABEL, - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE) +from databuilder.models.graph_serializable import ( + GraphSerializable) + +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship LOGGER = logging.getLogger(__name__) -class DashboardChart(Neo4jCsvSerializable): +class DashboardChart(GraphSerializable): """ A model that encapsulate Dashboard's charts """ @@ -47,51 +49,56 @@ def __init__(self, self._node_iterator = self._create_node_iterator() self._relation_iterator = self._create_relation_iterator() - def create_next_node(self) -> Union[Dict[str, Any], None]: + def create_next_node(self) -> Union[GraphNode, None]: try: return next(self._node_iterator) except StopIteration: return None - def _create_node_iterator(self) -> Iterator[Dict[str, Any]]: # noqa: C901 - node = { - NODE_LABEL: DashboardChart.DASHBOARD_CHART_LABEL, - NODE_KEY: self._get_chart_node_key(), + def _create_node_iterator(self) -> Iterator[GraphNode]: + node_attributes = { 'id': self._chart_id } if self._chart_name: - node['name'] = self._chart_name + node_attributes['name'] = self._chart_name if self._chart_type: - node['type'] = self._chart_type + node_attributes['type'] = self._chart_type if self._chart_url: - node['url'] = self._chart_url + node_attributes['url'] = self._chart_url + node = GraphNode( + key=self._get_chart_node_key(), + label=DashboardChart.DASHBOARD_CHART_LABEL, + attributes=node_attributes + ) yield node - def create_next_relation(self) -> Union[Dict[str, Any], None]: + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._relation_iterator) except StopIteration: return None - def _create_relation_iterator(self) -> Iterator[Dict[str, Any]]: - yield { - RELATION_START_LABEL: DashboardQuery.DASHBOARD_QUERY_LABEL, - RELATION_END_LABEL: DashboardChart.DASHBOARD_CHART_LABEL, - RELATION_START_KEY: DashboardQuery.DASHBOARD_QUERY_KEY_FORMAT.format( + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + relationship = GraphRelationship( + start_label=DashboardQuery.DASHBOARD_QUERY_LABEL, + start_key=DashboardQuery.DASHBOARD_QUERY_KEY_FORMAT.format( product=self._product, cluster=self._cluster, dashboard_group_id=self._dashboard_group_id, dashboard_id=self._dashboard_id, query_id=self._query_id ), - RELATION_END_KEY: self._get_chart_node_key(), - RELATION_TYPE: DashboardChart.CHART_RELATION_TYPE, - RELATION_REVERSE_TYPE: DashboardChart.CHART_REVERSE_RELATION_TYPE - } + end_label=DashboardChart.DASHBOARD_CHART_LABEL, + end_key=self._get_chart_node_key(), + type=DashboardChart.CHART_RELATION_TYPE, + reverse_type=DashboardChart.CHART_REVERSE_RELATION_TYPE, + attributes={} + ) + yield relationship def _get_chart_node_key(self) -> str: return DashboardChart.DASHBOARD_CHART_KEY_FORMAT.format( diff --git a/databuilder/models/dashboard/dashboard_execution.py b/databuilder/models/dashboard/dashboard_execution.py index 4d74328a5..6aa5a04df 100644 --- a/databuilder/models/dashboard/dashboard_execution.py +++ b/databuilder/models/dashboard/dashboard_execution.py @@ -3,17 +3,18 @@ import logging -from typing import Optional, Dict, Any, Union, Iterator +from typing import Optional, Any, Union, Iterator from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata -from databuilder.models.neo4j_csv_serde import ( - Neo4jCsvSerializable, NODE_LABEL, NODE_KEY, RELATION_START_KEY, RELATION_END_KEY, RELATION_START_LABEL, - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE) +from databuilder.models.graph_serializable import (GraphSerializable) + +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship LOGGER = logging.getLogger(__name__) -class DashboardExecution(Neo4jCsvSerializable): +class DashboardExecution(GraphSerializable): """ A model that encapsulate Dashboard's execution timestamp in epoch and execution state """ @@ -46,40 +47,45 @@ def __init__(self, self._node_iterator = self._create_node_iterator() self._relation_iterator = self._create_relation_iterator() - def create_next_node(self) -> Union[Dict[str, Any], None]: + def create_next_node(self) -> Union[GraphNode, None]: try: return next(self._node_iterator) except StopIteration: return None - def _create_node_iterator(self) -> Iterator[Dict[str, Any]]: # noqa: C901 - yield { - NODE_LABEL: DashboardExecution.DASHBOARD_EXECUTION_LABEL, - NODE_KEY: self._get_last_execution_node_key(), - 'timestamp': self._execution_timestamp, - 'state': self._execution_state - } + def _create_node_iterator(self) -> Iterator[GraphNode]: + node = GraphNode( + key=self._get_last_execution_node_key(), + label=DashboardExecution.DASHBOARD_EXECUTION_LABEL, + attributes={ + 'timestamp': self._execution_timestamp, + 'state': self._execution_state + } + ) + yield node - def create_next_relation(self) -> Union[Dict[str, Any], None]: + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._relation_iterator) except StopIteration: return None - def _create_relation_iterator(self) -> Iterator[Dict[str, Any]]: - yield { - RELATION_START_LABEL: DashboardMetadata.DASHBOARD_NODE_LABEL, - RELATION_END_LABEL: DashboardExecution.DASHBOARD_EXECUTION_LABEL, - RELATION_START_KEY: DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + relationship = GraphRelationship( + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + start_key=DashboardMetadata.DASHBOARD_KEY_FORMAT.format( product=self._product, cluster=self._cluster, dashboard_group=self._dashboard_group_id, dashboard_name=self._dashboard_id ), - RELATION_END_KEY: self._get_last_execution_node_key(), - RELATION_TYPE: DashboardExecution.DASHBOARD_EXECUTION_RELATION_TYPE, - RELATION_REVERSE_TYPE: DashboardExecution.EXECUTION_DASHBOARD_RELATION_TYPE - } + end_label=DashboardExecution.DASHBOARD_EXECUTION_LABEL, + end_key=self._get_last_execution_node_key(), + type=DashboardExecution.DASHBOARD_EXECUTION_RELATION_TYPE, + reverse_type=DashboardExecution.EXECUTION_DASHBOARD_RELATION_TYPE, + attributes={} + ) + yield relationship def _get_last_execution_node_key(self) -> str: return DashboardExecution.DASHBOARD_EXECUTION_KEY_FORMAT.format( diff --git a/databuilder/models/dashboard/dashboard_last_modified.py b/databuilder/models/dashboard/dashboard_last_modified.py index 4b7d9d2b5..916248d60 100644 --- a/databuilder/models/dashboard/dashboard_last_modified.py +++ b/databuilder/models/dashboard/dashboard_last_modified.py @@ -3,18 +3,20 @@ import logging -from typing import Optional, Dict, Any, Union, Iterator +from typing import Optional, Any, Union, Iterator from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata -from databuilder.models.neo4j_csv_serde import ( - Neo4jCsvSerializable, NODE_LABEL, NODE_KEY, RELATION_START_KEY, RELATION_END_KEY, RELATION_START_LABEL, - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE) +from databuilder.models.graph_serializable import ( + GraphSerializable) from databuilder.models.timestamp import timestamp_constants +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship + LOGGER = logging.getLogger(__name__) -class DashboardLastModifiedTimestamp(Neo4jCsvSerializable): +class DashboardLastModifiedTimestamp(GraphSerializable): """ A model that encapsulate Dashboard's last modified timestamp in epoch """ @@ -38,40 +40,46 @@ def __init__(self, self._node_iterator = self._create_node_iterator() self._relation_iterator = self._create_relation_iterator() - def create_next_node(self) -> Union[Dict[str, Any], None]: + def create_next_node(self) -> Union[GraphNode, None]: try: return next(self._node_iterator) except StopIteration: return None - def _create_node_iterator(self) -> Iterator[Dict[str, Any]]: # noqa: C901 - yield { - NODE_LABEL: timestamp_constants.NODE_LABEL, - NODE_KEY: self._get_last_modified_node_key(), + def _create_node_iterator(self) -> Iterator[GraphNode]: + node_attributes = { timestamp_constants.TIMESTAMP_PROPERTY: self._last_modified_timestamp, - timestamp_constants.TIMESTAMP_NAME_PROPERTY: timestamp_constants.TimestampName.last_updated_timestamp.name, + timestamp_constants.TIMESTAMP_NAME_PROPERTY: timestamp_constants.TimestampName.last_updated_timestamp.name } + node = GraphNode( + key=self._get_last_modified_node_key(), + label=timestamp_constants.NODE_LABEL, + attributes=node_attributes + ) + yield node - def create_next_relation(self) -> Union[Dict[str, Any], None]: + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._relation_iterator) except StopIteration: return None - def _create_relation_iterator(self) -> Iterator[Dict[str, Any]]: - yield { - RELATION_START_LABEL: DashboardMetadata.DASHBOARD_NODE_LABEL, - RELATION_END_LABEL: timestamp_constants.NODE_LABEL, - RELATION_START_KEY: DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + relationship = GraphRelationship( + start_key=DashboardMetadata.DASHBOARD_KEY_FORMAT.format( product=self._product, cluster=self._cluster, dashboard_group=self._dashboard_group_id, dashboard_name=self._dashboard_id ), - RELATION_END_KEY: self._get_last_modified_node_key(), - RELATION_TYPE: timestamp_constants.LASTUPDATED_RELATION_TYPE, - RELATION_REVERSE_TYPE: timestamp_constants.LASTUPDATED_REVERSE_RELATION_TYPE - } + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + end_key=self._get_last_modified_node_key(), + end_label=timestamp_constants.NODE_LABEL, + type=timestamp_constants.LASTUPDATED_RELATION_TYPE, + reverse_type=timestamp_constants.LASTUPDATED_REVERSE_RELATION_TYPE, + attributes={} + ) + yield relationship def _get_last_modified_node_key(self) -> str: return DashboardLastModifiedTimestamp.DASHBOARD_LAST_MODIFIED_KEY_FORMAT.format( diff --git a/databuilder/models/dashboard/dashboard_metadata.py b/databuilder/models/dashboard/dashboard_metadata.py index 41244e4ae..740b1e568 100644 --- a/databuilder/models/dashboard/dashboard_metadata.py +++ b/databuilder/models/dashboard/dashboard_metadata.py @@ -1,22 +1,20 @@ # Copyright Contributors to the Amundsen project. # SPDX-License-Identifier: Apache-2.0 -from collections import namedtuple - -from typing import Any, Dict, Iterator, List, Optional, Set, Union +from typing import Any, Iterator, List, Optional, Set, Union, Dict from databuilder.models.cluster import cluster_constants -from databuilder.models.neo4j_csv_serde import ( - Neo4jCsvSerializable, NODE_LABEL, NODE_KEY, RELATION_START_KEY, RELATION_END_KEY, RELATION_START_LABEL, - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE) +from databuilder.models.graph_serializable import ( + GraphSerializable +) # TODO: We could separate TagMetadata from table_metadata to own module from databuilder.models.table_metadata import TagMetadata -NodeTuple = namedtuple('KeyName', ['key', 'name', 'label']) -RelTuple = namedtuple('RelKeys', ['start_label', 'end_label', 'start_key', 'end_key', 'type', 'reverse_type']) +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship -class DashboardMetadata(Neo4jCsvSerializable): +class DashboardMetadata(GraphSerializable): """ Dashboard metadata including dashboard group name, dashboardgroup description, dashboard description, and tags. @@ -132,127 +130,162 @@ def _get_dashboard_group_key(self) -> str: cluster=self.cluster, product=self.product) - def create_next_node(self) -> Union[Dict[str, Any], None]: + def create_next_node(self) -> Union[GraphNode, None]: try: return next(self._node_iterator) except StopIteration: return None - def _create_next_node(self) -> Iterator[Any]: + def _create_next_node(self) -> Iterator[GraphNode]: # Cluster node if not self._get_cluster_key() in self._processed_cluster: self._processed_cluster.add(self._get_cluster_key()) - yield { - NODE_LABEL: cluster_constants.CLUSTER_NODE_LABEL, - NODE_KEY: self._get_cluster_key(), - cluster_constants.CLUSTER_NAME_PROP_KEY: self.cluster - } + cluster_node = GraphNode( + key=self._get_cluster_key(), + label=cluster_constants.CLUSTER_NODE_LABEL, + attributes={ + cluster_constants.CLUSTER_NAME_PROP_KEY: self.cluster + } + ) + yield cluster_node - # Dashboard node - dashboard_node: Dict[str, Any] = { - NODE_LABEL: DashboardMetadata.DASHBOARD_NODE_LABEL, - NODE_KEY: self._get_dashboard_key(), + # Dashboard node attributes + dashboard_node_attributes: Dict[str, Any] = { DashboardMetadata.DASHBOARD_NAME: self.dashboard_name, } if self.created_timestamp: - dashboard_node[DashboardMetadata.DASHBOARD_CREATED_TIME_STAMP] = self.created_timestamp + dashboard_node_attributes[DashboardMetadata.DASHBOARD_CREATED_TIME_STAMP] = self.created_timestamp if self.dashboard_url: - dashboard_node[DashboardMetadata.DASHBOARD_URL] = self.dashboard_url + dashboard_node_attributes[DashboardMetadata.DASHBOARD_URL] = self.dashboard_url + + dashboard_node = GraphNode( + key=self._get_dashboard_key(), + label=DashboardMetadata.DASHBOARD_NODE_LABEL, + attributes=dashboard_node_attributes + ) yield dashboard_node # Dashboard group if self.dashboard_group and not self._get_dashboard_group_key() in self._processed_dashboard_group: self._processed_dashboard_group.add(self._get_dashboard_group_key()) - dashboard_group_node = { - NODE_LABEL: DashboardMetadata.DASHBOARD_GROUP_NODE_LABEL, - NODE_KEY: self._get_dashboard_group_key(), + dashboard_group_node_attributes = { DashboardMetadata.DASHBOARD_NAME: self.dashboard_group, } if self.dashboard_group_url: - dashboard_group_node[DashboardMetadata.DASHBOARD_GROUP_URL] = self.dashboard_group_url + dashboard_group_node_attributes[DashboardMetadata.DASHBOARD_GROUP_URL] = self.dashboard_group_url + + dashboard_group_node = GraphNode( + key=self._get_dashboard_group_key(), + label=DashboardMetadata.DASHBOARD_GROUP_NODE_LABEL, + attributes=dashboard_group_node_attributes + ) yield dashboard_group_node # Dashboard group description if self.dashboard_group_description: - yield {NODE_LABEL: DashboardMetadata.DASHBOARD_DESCRIPTION_NODE_LABEL, - NODE_KEY: self._get_dashboard_group_description_key(), - DashboardMetadata.DASHBOARD_DESCRIPTION: self.dashboard_group_description} + dashboard_group_description_node = GraphNode( + key=self._get_dashboard_group_description_key(), + label=DashboardMetadata.DASHBOARD_DESCRIPTION_NODE_LABEL, + attributes={ + DashboardMetadata.DASHBOARD_DESCRIPTION: self.dashboard_group_description + } + ) + yield dashboard_group_description_node # Dashboard description node if self.description: - yield {NODE_LABEL: DashboardMetadata.DASHBOARD_DESCRIPTION_NODE_LABEL, - NODE_KEY: self._get_dashboard_description_key(), - DashboardMetadata.DASHBOARD_DESCRIPTION: self.description} + dashboard_description_node = GraphNode( + key=self._get_dashboard_description_key(), + label=DashboardMetadata.DASHBOARD_DESCRIPTION_NODE_LABEL, + attributes={ + DashboardMetadata.DASHBOARD_DESCRIPTION: self.description + } + ) + yield dashboard_description_node # Dashboard tag node if self.tags: for tag in self.tags: - yield {NODE_LABEL: TagMetadata.TAG_NODE_LABEL, - NODE_KEY: TagMetadata.get_tag_key(tag), - TagMetadata.TAG_TYPE: 'dashboard'} - - def create_next_relation(self) -> Union[Dict[str, Any], None]: + dashboard_tag_node = GraphNode( + key=TagMetadata.get_tag_key(tag), + label=TagMetadata.TAG_NODE_LABEL, + attributes={ + TagMetadata.TAG_TYPE: 'dashboard' + } + ) + yield dashboard_tag_node + + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._relation_iterator) except StopIteration: return None - def _create_next_relation(self) -> Iterator[Any]: - + def _create_next_relation(self) -> Iterator[GraphRelationship]: # Cluster <-> Dashboard group - yield { - RELATION_START_LABEL: cluster_constants.CLUSTER_NODE_LABEL, - RELATION_END_LABEL: DashboardMetadata.DASHBOARD_GROUP_NODE_LABEL, - RELATION_START_KEY: self._get_cluster_key(), - RELATION_END_KEY: self._get_dashboard_group_key(), - RELATION_TYPE: DashboardMetadata.CLUSTER_DASHBOARD_GROUP_RELATION_TYPE, - RELATION_REVERSE_TYPE: DashboardMetadata.DASHBOARD_GROUP_CLUSTER_RELATION_TYPE - } + cluster_dashboard_group_relationship = GraphRelationship( + start_label=cluster_constants.CLUSTER_NODE_LABEL, + start_key=self._get_cluster_key(), + end_label=DashboardMetadata.DASHBOARD_GROUP_NODE_LABEL, + end_key=self._get_dashboard_group_key(), + type=DashboardMetadata.CLUSTER_DASHBOARD_GROUP_RELATION_TYPE, + reverse_type=DashboardMetadata.DASHBOARD_GROUP_CLUSTER_RELATION_TYPE, + attributes={} + ) + yield cluster_dashboard_group_relationship # Dashboard group > Dashboard group description relation if self.dashboard_group_description: - yield { - RELATION_START_LABEL: DashboardMetadata.DASHBOARD_GROUP_NODE_LABEL, - RELATION_END_LABEL: DashboardMetadata.DASHBOARD_DESCRIPTION_NODE_LABEL, - RELATION_START_KEY: self._get_dashboard_group_key(), - RELATION_END_KEY: self._get_dashboard_group_description_key(), - RELATION_TYPE: DashboardMetadata.DASHBOARD_DESCRIPTION_RELATION_TYPE, - RELATION_REVERSE_TYPE: DashboardMetadata.DESCRIPTION_DASHBOARD_RELATION_TYPE - } + dashboard_group_description_relationship = GraphRelationship( + start_label=DashboardMetadata.DASHBOARD_GROUP_NODE_LABEL, + start_key=self._get_dashboard_group_key(), + end_label=DashboardMetadata.DASHBOARD_DESCRIPTION_NODE_LABEL, + end_key=self._get_dashboard_group_description_key(), + type=DashboardMetadata.DASHBOARD_DESCRIPTION_RELATION_TYPE, + reverse_type=DashboardMetadata.DESCRIPTION_DASHBOARD_RELATION_TYPE, + attributes={} + ) + yield dashboard_group_description_relationship # Dashboard group > Dashboard relation - yield { - RELATION_START_LABEL: DashboardMetadata.DASHBOARD_NODE_LABEL, - RELATION_END_LABEL: DashboardMetadata.DASHBOARD_GROUP_NODE_LABEL, - RELATION_START_KEY: self._get_dashboard_key(), - RELATION_END_KEY: self._get_dashboard_group_key(), - RELATION_TYPE: DashboardMetadata.DASHBOARD_DASHBOARD_GROUP_RELATION_TYPE, - RELATION_REVERSE_TYPE: DashboardMetadata.DASHBOARD_GROUP_DASHBOARD_RELATION_TYPE - } + dashboard_group_dashboard_relationship = GraphRelationship( + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + end_label=DashboardMetadata.DASHBOARD_GROUP_NODE_LABEL, + start_key=self._get_dashboard_key(), + end_key=self._get_dashboard_group_key(), + type=DashboardMetadata.DASHBOARD_DASHBOARD_GROUP_RELATION_TYPE, + reverse_type=DashboardMetadata.DASHBOARD_GROUP_DASHBOARD_RELATION_TYPE, + attributes={} + ) + yield dashboard_group_dashboard_relationship # Dashboard > Dashboard description relation if self.description: - yield { - RELATION_START_LABEL: DashboardMetadata.DASHBOARD_NODE_LABEL, - RELATION_END_LABEL: DashboardMetadata.DASHBOARD_DESCRIPTION_NODE_LABEL, - RELATION_START_KEY: self._get_dashboard_key(), - RELATION_END_KEY: self._get_dashboard_description_key(), - RELATION_TYPE: DashboardMetadata.DASHBOARD_DESCRIPTION_RELATION_TYPE, - RELATION_REVERSE_TYPE: DashboardMetadata.DESCRIPTION_DASHBOARD_RELATION_TYPE - } + dashboard_description_relationship = GraphRelationship( + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + end_label=DashboardMetadata.DASHBOARD_DESCRIPTION_NODE_LABEL, + start_key=self._get_dashboard_key(), + end_key=self._get_dashboard_description_key(), + type=DashboardMetadata.DASHBOARD_DESCRIPTION_RELATION_TYPE, + reverse_type=DashboardMetadata.DESCRIPTION_DASHBOARD_RELATION_TYPE, + attributes={} + ) + yield dashboard_description_relationship # Dashboard > Dashboard tag relation if self.tags: for tag in self.tags: - yield { - RELATION_START_LABEL: DashboardMetadata.DASHBOARD_NODE_LABEL, - RELATION_END_LABEL: TagMetadata.TAG_NODE_LABEL, - RELATION_START_KEY: self._get_dashboard_key(), - RELATION_END_KEY: TagMetadata.get_tag_key(tag), - RELATION_TYPE: DashboardMetadata.DASHBOARD_TAG_RELATION_TYPE, - RELATION_REVERSE_TYPE: DashboardMetadata.TAG_DASHBOARD_RELATION_TYPE - } + dashboard_tag_relationship = GraphRelationship( + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + end_label=TagMetadata.TAG_NODE_LABEL, + start_key=self._get_dashboard_key(), + end_key=TagMetadata.get_tag_key(tag), + type=DashboardMetadata.DASHBOARD_TAG_RELATION_TYPE, + reverse_type=DashboardMetadata.TAG_DASHBOARD_RELATION_TYPE, + attributes={} + ) + yield dashboard_tag_relationship diff --git a/databuilder/models/dashboard/dashboard_owner.py b/databuilder/models/dashboard/dashboard_owner.py index bade5f348..eda0ee0ac 100644 --- a/databuilder/models/dashboard/dashboard_owner.py +++ b/databuilder/models/dashboard/dashboard_owner.py @@ -3,20 +3,21 @@ import logging -from typing import Optional, Dict, Any, Union, Iterator +from typing import Optional, Any, Union, Iterator from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata -from databuilder.models.neo4j_csv_serde import ( - Neo4jCsvSerializable, RELATION_START_KEY, RELATION_END_KEY, RELATION_START_LABEL, - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE) +from databuilder.models.graph_serializable import ( + GraphSerializable) from databuilder.models.owner_constants import OWNER_OF_OBJECT_RELATION_TYPE, OWNER_RELATION_TYPE from databuilder.models.user import User +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship LOGGER = logging.getLogger(__name__) -class DashboardOwner(Neo4jCsvSerializable): +class DashboardOwner(GraphSerializable): """ A model that encapsulate Dashboard's owner. Note that it does not create new user as it has insufficient information about user but it builds relation @@ -42,29 +43,31 @@ def __init__(self, self._relation_iterator = self._create_relation_iterator() - def create_next_node(self) -> Union[Dict[str, Any], None]: + def create_next_node(self) -> Union[GraphNode, None]: return None - def create_next_relation(self) -> Union[Dict[str, Any], None]: + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._relation_iterator) except StopIteration: return None - def _create_relation_iterator(self) -> Iterator[Dict[str, Any]]: - yield { - RELATION_START_LABEL: DashboardMetadata.DASHBOARD_NODE_LABEL, - RELATION_END_LABEL: User.USER_NODE_LABEL, - RELATION_START_KEY: DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + relationship = GraphRelationship( + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + end_label=User.USER_NODE_LABEL, + start_key=DashboardMetadata.DASHBOARD_KEY_FORMAT.format( product=self._product, cluster=self._cluster, dashboard_group=self._dashboard_group_id, dashboard_name=self._dashboard_id ), - RELATION_END_KEY: User.get_user_model_key(email=self._email), - RELATION_TYPE: OWNER_RELATION_TYPE, - RELATION_REVERSE_TYPE: OWNER_OF_OBJECT_RELATION_TYPE - } + end_key=User.get_user_model_key(email=self._email), + type=OWNER_RELATION_TYPE, + reverse_type=OWNER_OF_OBJECT_RELATION_TYPE, + attributes={} + ) + yield relationship def __repr__(self) -> str: return 'DashboardOwner({!r}, {!r}, {!r}, {!r}, {!r})'.format( diff --git a/databuilder/models/dashboard/dashboard_query.py b/databuilder/models/dashboard/dashboard_query.py index 1f57deb2f..61cd39620 100644 --- a/databuilder/models/dashboard/dashboard_query.py +++ b/databuilder/models/dashboard/dashboard_query.py @@ -3,17 +3,19 @@ import logging -from typing import Optional, Dict, Any, Union, Iterator +from typing import Optional, Any, Union, Iterator from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata -from databuilder.models.neo4j_csv_serde import ( - Neo4jCsvSerializable, NODE_LABEL, NODE_KEY, RELATION_START_KEY, RELATION_END_KEY, RELATION_START_LABEL, - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE) +from databuilder.models.graph_serializable import ( + GraphSerializable) + +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship LOGGER = logging.getLogger(__name__) -class DashboardQuery(Neo4jCsvSerializable): +class DashboardQuery(GraphSerializable): """ A model that encapsulate Dashboard's query name """ @@ -45,48 +47,54 @@ def __init__(self, self._node_iterator = self._create_node_iterator() self._relation_iterator = self._create_relation_iterator() - def create_next_node(self) -> Union[Dict[str, Any], None]: + def create_next_node(self) -> Union[GraphNode, None]: try: return next(self._node_iterator) except StopIteration: return None - def _create_node_iterator(self) -> Iterator[Dict[str, Any]]: # noqa: C901 - node = { - NODE_LABEL: DashboardQuery.DASHBOARD_QUERY_LABEL, - NODE_KEY: self._get_query_node_key(), + def _create_node_iterator(self) -> Iterator[GraphNode]: + node_attributes = { 'id': self._query_id, 'name': self._query_name, } if self._url: - node['url'] = self._url + node_attributes['url'] = self._url if self._query_text: - node['query_text'] = self._query_text + node_attributes['query_text'] = self._query_text + + node = GraphNode( + key=self._get_query_node_key(), + label=DashboardQuery.DASHBOARD_QUERY_LABEL, + attributes=node_attributes + ) yield node - def create_next_relation(self) -> Union[Dict[str, Any], None]: + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._relation_iterator) except StopIteration: return None - def _create_relation_iterator(self) -> Iterator[Dict[str, Any]]: - yield { - RELATION_START_LABEL: DashboardMetadata.DASHBOARD_NODE_LABEL, - RELATION_END_LABEL: DashboardQuery.DASHBOARD_QUERY_LABEL, - RELATION_START_KEY: DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + relationship = GraphRelationship( + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + end_label=DashboardQuery.DASHBOARD_QUERY_LABEL, + start_key=DashboardMetadata.DASHBOARD_KEY_FORMAT.format( product=self._product, cluster=self._cluster, dashboard_group=self._dashboard_group_id, dashboard_name=self._dashboard_id ), - RELATION_END_KEY: self._get_query_node_key(), - RELATION_TYPE: DashboardQuery.DASHBOARD_QUERY_RELATION_TYPE, - RELATION_REVERSE_TYPE: DashboardQuery.QUERY_DASHBOARD_RELATION_TYPE - } + end_key=self._get_query_node_key(), + type=DashboardQuery.DASHBOARD_QUERY_RELATION_TYPE, + reverse_type=DashboardQuery.QUERY_DASHBOARD_RELATION_TYPE, + attributes={} + ) + yield relationship def _get_query_node_key(self) -> str: return DashboardQuery.DASHBOARD_QUERY_KEY_FORMAT.format( diff --git a/databuilder/models/dashboard/dashboard_table.py b/databuilder/models/dashboard/dashboard_table.py index 2bd210202..cd8acb142 100644 --- a/databuilder/models/dashboard/dashboard_table.py +++ b/databuilder/models/dashboard/dashboard_table.py @@ -4,19 +4,19 @@ import logging import re -from typing import Optional, Dict, Any, List, Union, Iterator +from typing import Optional, Any, List, Union, Iterator from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata -from databuilder.models.neo4j_csv_serde import ( - Neo4jCsvSerializable, RELATION_START_KEY, RELATION_END_KEY, RELATION_START_LABEL, - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE) +from databuilder.models.graph_serializable import ( + GraphSerializable) from databuilder.models.table_metadata import TableMetadata - +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship LOGGER = logging.getLogger(__name__) -class DashboardTable(Neo4jCsvSerializable): +class DashboardTable(GraphSerializable): """ A model that link Dashboard with the tables used in various charts of the dashboard. Note that it does not create new dashboard, table as it has insufficient information but it builds relation @@ -43,40 +43,41 @@ def __init__(self, self._relation_iterator = self._create_relation_iterator() - def create_next_node(self) -> Union[Dict[str, Any], None]: + def create_next_node(self) -> Union[GraphNode, None]: return None - def create_next_relation(self) -> Union[Dict[str, Any], None]: + def create_next_relation(self) -> Union[GraphRelationship, None]: if self._relation_iterator is None: return None - try: return next(self._relation_iterator) except StopIteration: return None - def _create_relation_iterator(self) -> Optional[Iterator[Dict[str, Any]]]: + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: for table_id in self._table_ids: m = re.match('([^./]+)://([^./]+)\.([^./]+)\/([^./]+)', table_id) if m: - yield { - RELATION_START_LABEL: DashboardMetadata.DASHBOARD_NODE_LABEL, - RELATION_END_LABEL: TableMetadata.TABLE_NODE_LABEL, - RELATION_START_KEY: DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + relationship = GraphRelationship( + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + end_label=TableMetadata.TABLE_NODE_LABEL, + start_key=DashboardMetadata.DASHBOARD_KEY_FORMAT.format( product=self._product, cluster=self._cluster, dashboard_group=self._dashboard_group_id, dashboard_name=self._dashboard_id ), - RELATION_END_KEY: TableMetadata.TABLE_KEY_FORMAT.format( + end_key=TableMetadata.TABLE_KEY_FORMAT.format( db=m.group(1), cluster=m.group(2), schema=m.group(3), tbl=m.group(4) ), - RELATION_TYPE: DashboardTable.DASHBOARD_TABLE_RELATION_TYPE, - RELATION_REVERSE_TYPE: DashboardTable.TABLE_DASHBOARD_RELATION_TYPE - } + type=DashboardTable.DASHBOARD_TABLE_RELATION_TYPE, + reverse_type=DashboardTable.TABLE_DASHBOARD_RELATION_TYPE, + attributes={} + ) + yield relationship def __repr__(self) -> str: return 'DashboardTable({!r}, {!r}, {!r}, {!r}, ({!r}))'.format( diff --git a/databuilder/models/dashboard/dashboard_usage.py b/databuilder/models/dashboard/dashboard_usage.py index 31dd1bfb0..5a30b1fa8 100644 --- a/databuilder/models/dashboard/dashboard_usage.py +++ b/databuilder/models/dashboard/dashboard_usage.py @@ -3,21 +3,23 @@ import logging -from typing import Optional, Dict, Any, Union, Iterator +from typing import Optional, Any, Union, Iterator from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata -from databuilder.models.neo4j_csv_serde import ( - Neo4jCsvSerializable, RELATION_START_KEY, RELATION_END_KEY, RELATION_START_LABEL, - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE) +from databuilder.models.graph_serializable import ( + GraphSerializable +) from databuilder.models.usage.usage_constants import ( READ_RELATION_TYPE, READ_REVERSE_RELATION_TYPE, READ_RELATION_COUNT_PROPERTY ) from databuilder.models.user import User +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship LOGGER = logging.getLogger(__name__) -class DashboardUsage(Neo4jCsvSerializable): +class DashboardUsage(GraphSerializable): """ A model that encapsulate Dashboard usage between Dashboard and User """ @@ -56,33 +58,36 @@ def __init__(self, self._should_create_user_node = bool(should_create_user_node) self._relation_iterator = self._create_relation_iterator() - def create_next_node(self) -> Union[Dict[str, Any], None]: + def create_next_node(self) -> Union[GraphNode, None]: if self._should_create_user_node: return self._user_model.create_next_node() return None - def create_next_relation(self) -> Union[Dict[str, Any], None]: + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._relation_iterator) except StopIteration: return None - def _create_relation_iterator(self) -> Iterator[Dict[str, Any]]: - yield { - RELATION_START_LABEL: DashboardMetadata.DASHBOARD_NODE_LABEL, - RELATION_END_LABEL: User.USER_NODE_LABEL, - RELATION_START_KEY: DashboardMetadata.DASHBOARD_KEY_FORMAT.format( + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: + relationship = GraphRelationship( + start_label=DashboardMetadata.DASHBOARD_NODE_LABEL, + end_label=User.USER_NODE_LABEL, + start_key=DashboardMetadata.DASHBOARD_KEY_FORMAT.format( product=self._product, cluster=self._cluster, dashboard_group=self._dashboard_group_id, dashboard_name=self._dashboard_id ), - RELATION_END_KEY: User.get_user_model_key(email=self._email), - RELATION_TYPE: READ_REVERSE_RELATION_TYPE, - RELATION_REVERSE_TYPE: READ_RELATION_TYPE, - READ_RELATION_COUNT_PROPERTY: self._view_count - } + end_key=User.get_user_model_key(email=self._email), + type=READ_REVERSE_RELATION_TYPE, + reverse_type=READ_RELATION_TYPE, + attributes={ + READ_RELATION_COUNT_PROPERTY: self._view_count + } + ) + yield relationship def __repr__(self) -> str: return 'DashboardUsage({!r}, {!r}, {!r}, {!r}, {!r}, {!r}, {!r})'.format( diff --git a/databuilder/models/graph_node.py b/databuilder/models/graph_node.py new file mode 100644 index 000000000..52c9ed6f6 --- /dev/null +++ b/databuilder/models/graph_node.py @@ -0,0 +1,13 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from collections import namedtuple + +GraphNode = namedtuple( + 'GraphNode', + [ + 'key', + 'label', + 'attributes' + ] +) diff --git a/databuilder/models/graph_relationship.py b/databuilder/models/graph_relationship.py new file mode 100644 index 000000000..868c963a2 --- /dev/null +++ b/databuilder/models/graph_relationship.py @@ -0,0 +1,17 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from collections import namedtuple + +GraphRelationship = namedtuple( + 'GraphRelationship', + [ + 'start_label', + 'end_label', + 'start_key', + 'end_key', + 'type', + 'reverse_type', + 'attributes' + ] +) diff --git a/databuilder/models/graph_serializable.py b/databuilder/models/graph_serializable.py new file mode 100644 index 000000000..7c50c57d1 --- /dev/null +++ b/databuilder/models/graph_serializable.py @@ -0,0 +1,90 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +import abc + +from typing import Union # noqa: F401 +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship + +NODE_KEY = 'KEY' +NODE_LABEL = 'LABEL' + +RELATION_START_KEY = 'START_KEY' +RELATION_START_LABEL = 'START_LABEL' +RELATION_END_KEY = 'END_KEY' +RELATION_END_LABEL = 'END_LABEL' +RELATION_TYPE = 'TYPE' +RELATION_REVERSE_TYPE = 'REVERSE_TYPE' + + +class GraphSerializable(object, metaclass=abc.ABCMeta): + """ + A Serializable abstract class asks subclass to implement next node or + next relation in dict form so that it can be serialized to CSV file. + + Any model class that needs to be pushed to a graph database should inherit this class. + """ + def __init__(self) -> None: + pass + + @abc.abstractmethod + def create_next_node(self) -> Union[GraphNode, None]: + """ + Creates GraphNode the process that consumes this class takes the output + serializes to the desired graph database. + + :return: a GraphNode or None if no more records to serialize + """ + raise NotImplementedError + + @abc.abstractmethod + def create_next_relation(self) -> Union[GraphRelationship, None]: + """ + Creates GraphRelationship the process that consumes this class takes the output + serializes to the desired graph database. + + :return: a GraphRelationship or None if no more record to serialize + """ + raise NotImplementedError + + def next_node(self) -> Union[GraphNode, None]: + node_dict = self.create_next_node() + if not node_dict: + return None + + self._validate_node(node_dict) + return node_dict + + def next_relation(self) -> Union[GraphRelationship, None]: + relation_dict = self.create_next_relation() + if not relation_dict: + return None + + self._validate_relation(relation_dict) + return relation_dict + + def _validate_node(self, node: GraphNode) -> None: + node_id, node_label, _ = node + + if node_id is None: + RuntimeError('Required header missing. Required attributes id and label , Missing: id') + + if node_label is None: + RuntimeError('Required header missing. Required attributes id and label , Missing: label') + + self._validate_label_value(node_label) + + def _validate_relation(self, relation: GraphRelationship) -> None: + self._validate_label_value(relation.start_label) + self._validate_label_value(relation.end_label) + self._validate_relation_type_value(relation.type) + self._validate_relation_type_value(relation.reverse_type) + + def _validate_relation_type_value(self, value: str) -> None: + if not value == value.upper(): + raise RuntimeError('TYPE needs to be upper case: {}'.format(value)) + + def _validate_label_value(self, value: str) -> None: + if not value.istitle(): + raise RuntimeError('LABEL should only have upper case character on its first one: {}'.format(value)) diff --git a/databuilder/models/metric_metadata.py b/databuilder/models/metric_metadata.py index 572692273..fa25a1875 100644 --- a/databuilder/models/metric_metadata.py +++ b/databuilder/models/metric_metadata.py @@ -1,22 +1,19 @@ # Copyright Contributors to the Amundsen project. # SPDX-License-Identifier: Apache-2.0 -from collections import namedtuple - -from typing import Any, Iterator, Dict, List, Set, Union +from typing import Any, Iterator, List, Union, Set # TODO: We could separate TagMetadata from table_metadata to own module from databuilder.models.table_metadata import TagMetadata -from databuilder.models.neo4j_csv_serde import ( - Neo4jCsvSerializable, NODE_LABEL, NODE_KEY, RELATION_START_KEY, RELATION_END_KEY, RELATION_START_LABEL, - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE) - +from databuilder.models.graph_serializable import ( + GraphSerializable +) -NodeTuple = namedtuple('KeyName', ['key', 'name', 'label']) -RelTuple = namedtuple('RelKeys', ['start_label', 'end_label', 'start_key', 'end_key', 'type', 'reverse_type']) +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship -class MetricMetadata(Neo4jCsvSerializable): +class MetricMetadata(GraphSerializable): """ Table metadata that contains columns. It implements Neo4jCsvSerializable so that it can be serialized to produce Table, Column and relation of those along with relationship with table and schema. Additionally, it will create @@ -73,7 +70,6 @@ def __init__(self, type: str, tags: List, ) -> None: - self.dashboard_group = dashboard_group self.dashboard_name = dashboard_name self.name = name @@ -108,39 +104,57 @@ def _get_dashboard_key(self) -> str: def _get_metric_description_key(self) -> str: return MetricMetadata.METRIC_DESCRIPTION_FORMAT.format(name=self.name) - def create_next_node(self) -> Union[Dict[str, Any], None]: + def create_next_node(self) -> Union[GraphNode, None]: try: return next(self._node_iterator) except StopIteration: return None - def _create_next_node(self) -> Iterator[Any]: - + def _create_next_node(self) -> Iterator[GraphNode]: # Metric node - yield {NODE_LABEL: MetricMetadata.METRIC_NODE_LABEL, - NODE_KEY: self._get_metric_key(), - MetricMetadata.METRIC_NAME: self.name, - MetricMetadata.METRIC_EXPRESSION_VALUE: self.expression - } + metric_node = GraphNode( + key=self._get_metric_key(), + label=MetricMetadata.METRIC_NODE_LABEL, + attributes={ + MetricMetadata.METRIC_NAME: self.name, + MetricMetadata.METRIC_EXPRESSION_VALUE: self.expression + } + ) + yield metric_node # Description node if self.description: - yield {NODE_LABEL: MetricMetadata.DESCRIPTION_NODE_LABEL, - NODE_KEY: self._get_metric_description_key(), - MetricMetadata.METRIC_DESCRIPTION: self.description} + description_node = GraphNode( + key=self._get_metric_description_key(), + label=MetricMetadata.DESCRIPTION_NODE_LABEL, + attributes={ + MetricMetadata.METRIC_DESCRIPTION: self.description + } + ) + yield description_node # Metric tag node if self.tags: for tag in self.tags: - yield {NODE_LABEL: TagMetadata.TAG_NODE_LABEL, - NODE_KEY: TagMetadata.get_tag_key(tag), - TagMetadata.TAG_TYPE: 'metric'} + tag_node = GraphNode( + key=TagMetadata.get_tag_key(tag), + label=TagMetadata.TAG_NODE_LABEL, + attributes={ + TagMetadata.TAG_TYPE: 'metric' + } + ) + yield tag_node # Metric type node if self.type: - yield {NODE_LABEL: MetricMetadata.METRIC_TYPE_NODE_LABEL, - NODE_KEY: self._get_metric_type_key(), - 'name': self.type} + type_node = GraphNode( + key=self._get_metric_type_key(), + label=MetricMetadata.METRIC_TYPE_NODE_LABEL, + attributes={ + 'name': self.type + } + ) + yield type_node # FIXME: this logic is wrong and does nothing presently others: List[Any] = [] @@ -148,63 +162,67 @@ def _create_next_node(self) -> Iterator[Any]: for node_tuple in others: if node_tuple not in MetricMetadata.serialized_nodes: MetricMetadata.serialized_nodes.add(node_tuple) - yield { - NODE_LABEL: node_tuple.label, - NODE_KEY: node_tuple.key, - 'name': node_tuple.name - } + yield node_tuple - def create_next_relation(self) -> Union[Dict[str, Any], None]: + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._relation_iterator) except StopIteration: return None - def _create_next_relation(self) -> Iterator[Any]: + def _create_next_relation(self) -> Iterator[GraphRelationship]: # Dashboard > Metric relation - yield { - RELATION_START_LABEL: MetricMetadata.METRIC_NODE_LABEL, - RELATION_END_LABEL: MetricMetadata.DASHBOARD_NODE_LABEL, - RELATION_START_KEY: self._get_metric_key(), - RELATION_END_KEY: self._get_dashboard_key(), - RELATION_TYPE: MetricMetadata.METRIC_DASHBOARD_RELATION_TYPE, - RELATION_REVERSE_TYPE: MetricMetadata.DASHBOARD_METRIC_RELATION_TYPE - } + dashboard_metric_relation = GraphRelationship( + start_label=MetricMetadata.METRIC_NODE_LABEL, + start_key=self._get_metric_key(), + end_label=MetricMetadata.DASHBOARD_NODE_LABEL, + end_key=self._get_dashboard_key(), + type=MetricMetadata.METRIC_DASHBOARD_RELATION_TYPE, + reverse_type=MetricMetadata.DASHBOARD_METRIC_RELATION_TYPE, + attributes={} + ) + yield dashboard_metric_relation # Metric > Metric description relation if self.description: - yield { - RELATION_START_LABEL: MetricMetadata.METRIC_NODE_LABEL, - RELATION_END_LABEL: MetricMetadata.DESCRIPTION_NODE_LABEL, - RELATION_START_KEY: self._get_metric_key(), - RELATION_END_KEY: self._get_metric_description_key(), - RELATION_TYPE: MetricMetadata.METRIC_DESCRIPTION_RELATION_TYPE, - RELATION_REVERSE_TYPE: MetricMetadata.DESCRIPTION_METRIC_RELATION_TYPE - } + metric_description_relation = GraphRelationship( + start_label=MetricMetadata.METRIC_NODE_LABEL, + start_key=self._get_metric_key(), + end_label=MetricMetadata.DESCRIPTION_NODE_LABEL, + end_key=self._get_metric_description_key(), + type=MetricMetadata.METRIC_DESCRIPTION_RELATION_TYPE, + reverse_type=MetricMetadata.DESCRIPTION_METRIC_RELATION_TYPE, + attributes={} + ) + yield metric_description_relation # Metric > Metric tag relation if self.tags: for tag in self.tags: - yield { - RELATION_START_LABEL: MetricMetadata.METRIC_NODE_LABEL, - RELATION_END_LABEL: TagMetadata.TAG_NODE_LABEL, - RELATION_START_KEY: self._get_metric_key(), - RELATION_END_KEY: TagMetadata.get_tag_key(tag), - RELATION_TYPE: MetricMetadata.METRIC_TAG_RELATION_TYPE, - RELATION_REVERSE_TYPE: MetricMetadata.TAG_METRIC_RELATION_TYPE - } + tag_relation = GraphRelationship( + start_label=MetricMetadata.METRIC_NODE_LABEL, + start_key=self._get_metric_key(), + end_label=TagMetadata.TAG_NODE_LABEL, + end_key=TagMetadata.get_tag_key(tag), + type=MetricMetadata.METRIC_TAG_RELATION_TYPE, + reverse_type=MetricMetadata.TAG_METRIC_RELATION_TYPE, + attributes={} + ) + yield tag_relation # Metric > Metric type relation if self.type: - yield { - RELATION_START_LABEL: MetricMetadata.METRIC_NODE_LABEL, - RELATION_END_LABEL: MetricMetadata.METRIC_TYPE_NODE_LABEL, - RELATION_START_KEY: self._get_metric_key(), - RELATION_END_KEY: self._get_metric_type_key(), - RELATION_TYPE: MetricMetadata.METRIC_METRIC_TYPE_RELATION_TYPE, - RELATION_REVERSE_TYPE: MetricMetadata.METRIC_TYPE_METRIC_RELATION_TYPE - } + type_relation = GraphRelationship( + start_label=MetricMetadata.METRIC_NODE_LABEL, + start_key=self._get_metric_key(), + end_label=MetricMetadata.METRIC_TYPE_NODE_LABEL, + end_key=self._get_metric_type_key(), + type=MetricMetadata.METRIC_METRIC_TYPE_RELATION_TYPE, + reverse_type=MetricMetadata.METRIC_TYPE_METRIC_RELATION_TYPE, + attributes={} + ) + yield type_relation # FIXME: this logic is wrong and does nothing presently others: List[Any] = [] @@ -212,11 +230,4 @@ def _create_next_relation(self) -> Iterator[Any]: for rel_tuple in others: if rel_tuple not in MetricMetadata.serialized_rels: MetricMetadata.serialized_rels.add(rel_tuple) - yield { - RELATION_START_LABEL: rel_tuple.start_label, - RELATION_END_LABEL: rel_tuple.end_label, - RELATION_START_KEY: rel_tuple.start_key, - RELATION_END_KEY: rel_tuple.end_key, - RELATION_TYPE: rel_tuple.type, - RELATION_REVERSE_TYPE: rel_tuple.reverse_type - } + yield rel_tuple diff --git a/databuilder/models/neo4j_csv_serde.py b/databuilder/models/neo4j_csv_serde.py deleted file mode 100644 index 6c7f7e320..000000000 --- a/databuilder/models/neo4j_csv_serde.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright Contributors to the Amundsen project. -# SPDX-License-Identifier: Apache-2.0 - -import abc - -from typing import Dict, Set, Any, Union - -NODE_KEY = 'KEY' -NODE_LABEL = 'LABEL' -NODE_REQUIRED_HEADERS = {NODE_LABEL, NODE_KEY} - -RELATION_START_KEY = 'START_KEY' -RELATION_START_LABEL = 'START_LABEL' -RELATION_END_KEY = 'END_KEY' -RELATION_END_LABEL = 'END_LABEL' -RELATION_TYPE = 'TYPE' -RELATION_REVERSE_TYPE = 'REVERSE_TYPE' -RELATION_REQUIRED_HEADERS = {RELATION_START_KEY, RELATION_START_LABEL, - RELATION_END_KEY, RELATION_END_LABEL, - RELATION_TYPE, RELATION_REVERSE_TYPE} - -LABELS = {NODE_LABEL, RELATION_START_LABEL, RELATION_END_LABEL} -TYPES = {RELATION_TYPE, RELATION_REVERSE_TYPE} - - -class Neo4jCsvSerializable(object, metaclass=abc.ABCMeta): - """ - A Serializable abstract class asks subclass to implement next node or - next relation in dict form so that it can be serialized to CSV file. - - Any model class that needs to be pushed to Neo4j should inherit this class. - """ - - def __init__(self) -> None: - pass - - @abc.abstractmethod - def create_next_node(self) -> Union[Dict[str, Any], None]: - """ - Creates dict where keys represent header in CSV and value represents - row in CSV file. Should the class could have different types of - nodes that it needs to serialize, it just needs to provide dict with - different header -- the one who consumes this class figures it out and - serialize to different file. - - Node is Neo4j's term of Vertex in Graph. More information on - https://neo4j.com/docs/developer-manual/current/introduction/ - graphdb-concepts/ - :return: a dict or None if no more record to serialize - """ - raise NotImplementedError - - @abc.abstractmethod - def create_next_relation(self) -> Union[Dict[str, Any], None]: - """ - Creates dict where keys represent header in CSV and value represents - row in CSV file. Should the class could have different types of - relations that it needs to serialize, it just needs to provide dict - with different header -- the one who consumes this class figures it - out and serialize to different file. - - Relationship is Neo4j's term of Edge in Graph. More information on - https://neo4j.com/docs/developer-manual/current/introduction/ - graphdb-concepts/ - :return: a dict or None if no more record to serialize - """ - raise NotImplementedError - - def next_node(self) -> Union[Dict[str, Any], None]: - """ - Provides node(vertex) in dict form. - Note that subsequent call can create different header (dict.keys()) - which implicitly mean that it needs to be serialized in different - CSV file (as CSV is in fixed header) - :return: Non-nested dict where key is CSV header and each value - is a column - """ - node_dict = self.create_next_node() - if not node_dict: - return None - - self._validate(NODE_REQUIRED_HEADERS, node_dict) - return node_dict - - def next_relation(self) -> Union[Dict[str, Any], None]: - """ - Provides relation(edge) in dict form. - Note that subsequent call can create different header (dict.keys()) - which implicitly mean that it needs to be serialized in different - CSV file (as CSV is in fixed header) - :return: Non-nested dict where key is CSV header and each value - is a column - """ - relation_dict = self.create_next_relation() - if not relation_dict: - return None - - self._validate(RELATION_REQUIRED_HEADERS, relation_dict) - return relation_dict - - def _validate(self, - required_set: Set[str], - val_dict: Dict[str, Any]) -> None: - """ - Validates dict that represents CSV header and a row. - - Checks if it has required headers for either Node or Relation - - Checks value of LABEL if only first character is upper case - - Checks value of TYPE if it's all upper case characters - - :param required_set: - :param val_dict: - :return: - """ - required_count = 0 - for header_col, val_col in \ - ((header_col, val_col) for header_col, val_col - in val_dict.items() if header_col in required_set): - required_count += 1 - - if header_col in LABELS: - if not val_col.istitle(): - raise RuntimeError( - 'LABEL should only have upper case character on its ' - 'first one: {}'.format(val_col)) - elif header_col in TYPES: - if not val_col == val_col.upper(): - raise RuntimeError( - 'TYPE needs to be upper case: {}'.format(val_col)) - - if required_count != len(required_set): - raise RuntimeError( - 'Required header missing. Required: {} , Header: {}'.format( - required_set, val_dict.keys())) diff --git a/databuilder/models/neo4j_es_last_updated.py b/databuilder/models/neo4j_es_last_updated.py index c687a62f9..76039face 100644 --- a/databuilder/models/neo4j_es_last_updated.py +++ b/databuilder/models/neo4j_es_last_updated.py @@ -1,12 +1,14 @@ # Copyright Contributors to the Amundsen project. # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Union +from typing import List, Union -from databuilder.models.neo4j_csv_serde import Neo4jCsvSerializable, NODE_KEY, NODE_LABEL +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_node import GraphNode -class Neo4jESLastUpdated(Neo4jCsvSerializable): +class Neo4jESLastUpdated(GraphSerializable): """ Data model to keep track the last updated timestamp for neo4j and es. @@ -26,32 +28,33 @@ def __init__(self, self._node_iter = iter(self.create_nodes()) self._rel_iter = iter(self.create_relation()) - def create_next_node(self) -> Union[Dict[str, Any], None]: + def create_next_node(self) -> Union[GraphNode, None]: """ Will create an orphan node for last updated timestamp. - :return: """ try: return next(self._node_iter) except StopIteration: return None - def create_nodes(self) -> List[Dict[str, Any]]: + def create_nodes(self) -> List[GraphNode]: """ Create a list of Neo4j node records. - :return: """ - return [{ - NODE_KEY: Neo4jESLastUpdated.KEY, - NODE_LABEL: Neo4jESLastUpdated.LABEL, - Neo4jESLastUpdated.LATEST_TIMESTAMP: self.timestamp - }] - - def create_next_relation(self) -> Union[Dict[str, Any], None]: + node = GraphNode( + key=Neo4jESLastUpdated.KEY, + label=Neo4jESLastUpdated.LABEL, + attributes={ + Neo4jESLastUpdated.LATEST_TIMESTAMP: self.timestamp + } + ) + return [node] + + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._rel_iter) except StopIteration: return None - def create_relation(self) -> List[Dict[str, Any]]: + def create_relation(self) -> List[GraphRelationship]: return [] diff --git a/databuilder/models/schema/schema.py b/databuilder/models/schema/schema.py index b7532a9b1..483abce36 100644 --- a/databuilder/models/schema/schema.py +++ b/databuilder/models/schema/schema.py @@ -1,15 +1,16 @@ # Copyright Contributors to the Amundsen project. # SPDX-License-Identifier: Apache-2.0 -from typing import Dict, Any, Union, Iterator +from typing import Any, Union, Iterator -from databuilder.models.neo4j_csv_serde import ( - Neo4jCsvSerializable, NODE_LABEL, NODE_KEY) +from databuilder.models.graph_serializable import (GraphSerializable) from databuilder.models.schema.schema_constant import SCHEMA_NODE_LABEL, SCHEMA_NAME_ATTR from databuilder.models.table_metadata import DescriptionMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship -class SchemaModel(Neo4jCsvSerializable): +class SchemaModel(GraphSerializable): def __init__(self, schema_key: str, @@ -26,23 +27,26 @@ def __init__(self, self._node_iterator = self._create_node_iterator() self._relation_iterator = self._create_relation_iterator() - def create_next_node(self) -> Union[Dict[str, Any], None]: + def create_next_node(self) -> Union[GraphNode, None]: try: return next(self._node_iterator) except StopIteration: return None - def _create_node_iterator(self) -> Iterator[Dict[str, Any]]: - yield { - NODE_LABEL: SCHEMA_NODE_LABEL, - NODE_KEY: self._schema_key, - SCHEMA_NAME_ATTR: self._schema, - } + def _create_node_iterator(self) -> Iterator[GraphNode]: + node = GraphNode( + key=self._schema_key, + label=SCHEMA_NODE_LABEL, + attributes={ + SCHEMA_NAME_ATTR: self._schema, + } + ) + yield node if self._description: - yield self._description.get_node_dict(self._get_description_node_key()) + yield self._description.get_node(self._get_description_node_key()) - def create_next_relation(self) -> Union[Dict[str, Any], None]: + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._relation_iterator) except StopIteration: @@ -52,7 +56,7 @@ def _get_description_node_key(self) -> str: desc = self._description.get_description_id() if self._description is not None else '' return '{}/{}'.format(self._schema_key, desc) - def _create_relation_iterator(self) -> Iterator[Dict[str, Any]]: + def _create_relation_iterator(self) -> Iterator[GraphRelationship]: if self._description: yield self._description.get_relation(start_node=SCHEMA_NODE_LABEL, start_key=self._schema_key, diff --git a/databuilder/models/table_column_usage.py b/databuilder/models/table_column_usage.py index 4aef0a6a0..cc62ecd9b 100644 --- a/databuilder/models/table_column_usage.py +++ b/databuilder/models/table_column_usage.py @@ -1,15 +1,15 @@ # Copyright Contributors to the Amundsen project. # SPDX-License-Identifier: Apache-2.0 -from typing import Iterable, Union, Dict, Any, Iterator +from typing import Iterable, Union, Iterator -from databuilder.models.neo4j_csv_serde import ( - Neo4jCsvSerializable, RELATION_START_KEY, RELATION_END_KEY, - RELATION_START_LABEL, RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.models.graph_serializable import ( + GraphSerializable ) from databuilder.models.table_metadata import TableMetadata from databuilder.models.user import User -from databuilder.publisher.neo4j_csv_publisher import UNQUOTED_SUFFIX +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship class ColumnReader(object): @@ -40,7 +40,7 @@ def __repr__(self) -> str: .format(self.database, self.cluster, self.schema, self.table, self.column, self.user_email, self.read_count) -class TableColumnUsage(Neo4jCsvSerializable): +class TableColumnUsage(GraphSerializable): """ A model represents user <--> column graph model Currently it only support to serialize to table level @@ -52,7 +52,7 @@ class TableColumnUsage(Neo4jCsvSerializable): TABLE_USER_RELATION_TYPE = 'READ_BY' # Property key for relationship read, readby relationship - READ_RELATION_COUNT = 'read_count{}'.format(UNQUOTED_SUFFIX) + READ_RELATION_COUNT = 'read_count' def __init__(self, col_readers: Iterable[ColumnReader], @@ -65,37 +65,38 @@ def __init__(self, self._node_iterator = self._create_node_iterator() self._rel_iter = self._create_rel_iterator() - def create_next_node(self) -> Union[Dict[str, Any], None]: - + def create_next_node(self) -> Union[GraphNode, None]: try: return next(self._node_iterator) except StopIteration: return None - def _create_node_iterator(self) -> Iterator[Any]: + def _create_node_iterator(self) -> Iterator[GraphNode]: for col_reader in self.col_readers: if col_reader.column == '*': # using yield for better memory efficiency yield User(email=col_reader.user_email).create_nodes()[0] - def create_next_relation(self) -> Union[Dict[str, Any], None]: - + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._rel_iter) except StopIteration: return None - def _create_rel_iterator(self) -> Iterator[Any]: + def _create_rel_iterator(self) -> Iterator[GraphRelationship]: for col_reader in self.col_readers: - yield { - RELATION_START_LABEL: TableMetadata.TABLE_NODE_LABEL, - RELATION_END_LABEL: User.USER_NODE_LABEL, - RELATION_START_KEY: self._get_table_key(col_reader), - RELATION_END_KEY: self._get_user_key(col_reader.user_email), - RELATION_TYPE: TableColumnUsage.TABLE_USER_RELATION_TYPE, - RELATION_REVERSE_TYPE: TableColumnUsage.USER_TABLE_RELATION_TYPE, - TableColumnUsage.READ_RELATION_COUNT: col_reader.read_count - } + relationship = GraphRelationship( + start_label=TableMetadata.TABLE_NODE_LABEL, + start_key=self._get_table_key(col_reader), + end_label=User.USER_NODE_LABEL, + end_key=self._get_user_key(col_reader.user_email), + type=TableColumnUsage.TABLE_USER_RELATION_TYPE, + reverse_type=TableColumnUsage.USER_TABLE_RELATION_TYPE, + attributes={ + TableColumnUsage.READ_RELATION_COUNT: col_reader.read_count + } + ) + yield relationship def _get_table_key(self, col_reader: ColumnReader) -> str: return TableMetadata.TABLE_KEY_FORMAT.format(db=col_reader.database, diff --git a/databuilder/models/table_last_updated.py b/databuilder/models/table_last_updated.py index f4082d597..b70385c8f 100644 --- a/databuilder/models/table_last_updated.py +++ b/databuilder/models/table_last_updated.py @@ -1,17 +1,17 @@ # Copyright Contributors to the Amundsen project. # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Union +from typing import List, Union -from databuilder.models.neo4j_csv_serde import Neo4jCsvSerializable, NODE_KEY, \ - NODE_LABEL, RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.models.graph_serializable import GraphSerializable from databuilder.models.table_metadata import TableMetadata from databuilder.models.timestamp import timestamp_constants +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship -class TableLastUpdated(Neo4jCsvSerializable): +class TableLastUpdated(GraphSerializable): # constants LAST_UPDATED_NODE_LABEL = timestamp_constants.NODE_LABEL LAST_UPDATED_KEY_FORMAT = '{db}://{cluster}.{schema}/{tbl}/timestamp' @@ -42,14 +42,14 @@ def __repr__(self) -> str: """TableLastUpdated(table_name={!r}, last_updated_time={!r}, schema={!r}, db={!r}, cluster={!r})"""\ .format(self.table_name, self.last_updated_time, self.schema, self.db, self.cluster) - def create_next_node(self) -> Union[Dict[str, Any], None]: + def create_next_node(self) -> Union[GraphNode, None]: # creates new node try: return next(self._node_iter) except StopIteration: return None - def create_next_relation(self) -> Union[Dict[str, Any], None]: + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._relation_iter) except StopIteration: @@ -69,35 +69,41 @@ def get_last_updated_model_key(self) -> str: schema=self.schema, tbl=self.table_name) - def create_nodes(self) -> List[Dict[str, Any]]: + def create_nodes(self) -> List[GraphNode]: """ Create a list of Neo4j node records :return: """ results = [] - results.append({ - NODE_KEY: self.get_last_updated_model_key(), - NODE_LABEL: TableLastUpdated.LAST_UPDATED_NODE_LABEL, - TableLastUpdated.TIMESTAMP_PROPERTY: self.last_updated_time, - timestamp_constants.TIMESTAMP_PROPERTY: self.last_updated_time, - TableLastUpdated.TIMESTAMP_NAME_PROPERTY: timestamp_constants.TimestampName.last_updated_timestamp.name - }) + node = GraphNode( + key=self.get_last_updated_model_key(), + label=TableLastUpdated.LAST_UPDATED_NODE_LABEL, + attributes={ + TableLastUpdated.TIMESTAMP_PROPERTY: self.last_updated_time, + timestamp_constants.TIMESTAMP_PROPERTY: self.last_updated_time, + TableLastUpdated.TIMESTAMP_NAME_PROPERTY: timestamp_constants.TimestampName.last_updated_timestamp.name + } + ) + + results.append(node) return results - def create_relation(self) -> List[Dict[str, Any]]: + def create_relation(self) -> List[GraphRelationship]: """ Create a list of relations mapping last updated node with table node :return: """ - results = [{ - RELATION_START_KEY: self.get_table_model_key(), - RELATION_START_LABEL: TableMetadata.TABLE_NODE_LABEL, - RELATION_END_KEY: self.get_last_updated_model_key(), - RELATION_END_LABEL: TableLastUpdated.LAST_UPDATED_NODE_LABEL, - RELATION_TYPE: TableLastUpdated.TABLE_LASTUPDATED_RELATION_TYPE, - RELATION_REVERSE_TYPE: TableLastUpdated.LASTUPDATED_TABLE_RELATION_TYPE - }] + relationship = GraphRelationship( + start_label=TableMetadata.TABLE_NODE_LABEL, + start_key=self.get_table_model_key(), + end_label=TableLastUpdated.LAST_UPDATED_NODE_LABEL, + end_key=self.get_last_updated_model_key(), + type=TableLastUpdated.TABLE_LASTUPDATED_RELATION_TYPE, + reverse_type=TableLastUpdated.LASTUPDATED_TABLE_RELATION_TYPE, + attributes={} + ) + results = [relationship] return results diff --git a/databuilder/models/table_lineage.py b/databuilder/models/table_lineage.py index ff35d455e..fda8e1155 100644 --- a/databuilder/models/table_lineage.py +++ b/databuilder/models/table_lineage.py @@ -2,16 +2,16 @@ # SPDX-License-Identifier: Apache-2.0 import re -from typing import Any, Dict, List, Union +from typing import List, Union -from databuilder.models.neo4j_csv_serde import Neo4jCsvSerializable, \ - RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.models.graph_serializable import GraphSerializable from databuilder.models.table_metadata import TableMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship -class TableLineage(Neo4jCsvSerializable): +class TableLineage(GraphSerializable): """ Table Lineage Model. It won't create nodes but create upstream/downstream rels. """ @@ -25,7 +25,7 @@ def __init__(self, schema: str, table_name: str, cluster: str, - downstream_deps: List=None, + downstream_deps: List = None, ) -> None: self.db = db_name self.schema = schema @@ -37,14 +37,14 @@ def __init__(self, self._node_iter = iter(self.create_nodes()) self._relation_iter = iter(self.create_relation()) - def create_next_node(self) -> Union[Dict[str, Any], None]: + def create_next_node(self) -> Union[GraphNode, None]: # return the string representation of the data try: return next(self._node_iter) except StopIteration: return None - def create_next_relation(self) -> Union[Dict[str, Any], None]: + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._relation_iter) except StopIteration: @@ -61,14 +61,14 @@ def get_table_model_key(self, schema=schema, table=table) - def create_nodes(self) -> List[Union[Dict[str, Any], None]]: + def create_nodes(self) -> List[Union[GraphNode, None]]: """ It won't create any node for this model :return: """ return [] - def create_relation(self) -> List[Dict[str, Any]]: + def create_relation(self) -> List[GraphRelationship]: """ Create a list of relation between source table and all the downstream tables :return: @@ -80,20 +80,26 @@ def create_relation(self) -> List[Dict[str, Any]]: m = re.match('(\w+)://(\w+)\.(\w+)\/(\w+)', downstream_tab) if m: # if not match, skip those records - results.append({ - RELATION_START_KEY: self.get_table_model_key(db=self.db, - cluster=self.cluster, - schema=self.schema, - table=self.table), - RELATION_START_LABEL: TableMetadata.TABLE_NODE_LABEL, - RELATION_END_KEY: self.get_table_model_key(db=m.group(1), - cluster=m.group(2), - schema=m.group(3), - table=m.group(4)), - RELATION_END_LABEL: TableMetadata.TABLE_NODE_LABEL, - RELATION_TYPE: TableLineage.ORIGIN_DEPENDENCY_RELATION_TYPE, - RELATION_REVERSE_TYPE: TableLineage.DEPENDENCY_ORIGIN_RELATION_TYPE - }) + relationship = GraphRelationship( + start_key=self.get_table_model_key( + db=self.db, + cluster=self.cluster, + schema=self.schema, + table=self.table + ), + start_label=TableMetadata.TABLE_NODE_LABEL, + end_label=TableMetadata.TABLE_NODE_LABEL, + end_key=self.get_table_model_key( + db=m.group(1), + cluster=m.group(2), + schema=m.group(3), + table=m.group(4) + ), + type=TableLineage.ORIGIN_DEPENDENCY_RELATION_TYPE, + reverse_type=TableLineage.DEPENDENCY_ORIGIN_RELATION_TYPE, + attributes={} + ) + results.append(relationship) return results def __repr__(self) -> str: diff --git a/databuilder/models/table_metadata.py b/databuilder/models/table_metadata.py index 94b285622..6441fe102 100644 --- a/databuilder/models/table_metadata.py +++ b/databuilder/models/table_metadata.py @@ -2,23 +2,22 @@ # SPDX-License-Identifier: Apache-2.0 import copy -from collections import namedtuple from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Union from databuilder.models.cluster import cluster_constants -from databuilder.models.neo4j_csv_serde import ( - Neo4jCsvSerializable, NODE_LABEL, NODE_KEY, RELATION_START_KEY, RELATION_END_KEY, RELATION_START_LABEL, - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE) -from databuilder.publisher.neo4j_csv_publisher import UNQUOTED_SUFFIX +from databuilder.models.graph_serializable import GraphSerializable from databuilder.models.schema import schema_constant from databuilder.models.badge import BadgeMetadata, Badge +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship + DESCRIPTION_NODE_LABEL_VAL = 'Description' DESCRIPTION_NODE_LABEL = DESCRIPTION_NODE_LABEL_VAL -class TagMetadata(Neo4jCsvSerializable): +class TagMetadata(GraphSerializable): TAG_NODE_LABEL = 'Tag' TAG_KEY_FORMAT = '{tag}' TAG_TYPE = 'tag_type' @@ -34,7 +33,7 @@ def __init__(self, self._name = name self._tag_type = tag_type self._nodes = iter([self.create_tag_node(self._name, self._tag_type)]) - self._relations: Iterator[Dict[str, Any]] = iter([]) + self._relations: Iterator[GraphRelationship] = iter([]) @staticmethod def get_tag_key(name: str) -> str: @@ -43,21 +42,24 @@ def get_tag_key(name: str) -> str: return TagMetadata.TAG_KEY_FORMAT.format(tag=name) @staticmethod - def create_tag_node(name: str, - tag_type: str = DEFAULT_TYPE - ) -> Dict[str, str]: - return {NODE_LABEL: TagMetadata.TAG_NODE_LABEL, - NODE_KEY: TagMetadata.get_tag_key(name), - TagMetadata.TAG_TYPE: tag_type} - - def create_next_node(self) -> Optional[Dict[str, Any]]: + def create_tag_node(name: str, tag_type: str = DEFAULT_TYPE) -> GraphNode: + node = GraphNode( + key=TagMetadata.get_tag_key(name), + label=TagMetadata.TAG_NODE_LABEL, + attributes={ + TagMetadata.TAG_TYPE: tag_type + } + ) + return node + + def create_next_node(self) -> Optional[GraphNode]: # return the string representation of the data try: return next(self._nodes) except StopIteration: return None - def create_next_relation(self) -> Optional[Dict[str, Any]]: + def create_next_relation(self) -> Optional[GraphRelationship]: # We don't emit any relations for Tag ingestion try: return next(self._relations) @@ -117,29 +119,28 @@ def get_description_id(self) -> str: def __repr__(self) -> str: return 'DescriptionMetadata({!r}, {!r})'.format(self._source, self._text) - def get_node_dict(self, - node_key: str - ) -> Dict[str, str]: - return { - NODE_LABEL: self._label, - NODE_KEY: node_key, - DescriptionMetadata.DESCRIPTION_SOURCE: self._source, - DescriptionMetadata.DESCRIPTION_TEXT: self._text or '', - } - - def get_relation(self, - start_node: str, - start_key: str, - end_key: str - ) -> Dict[str, str]: - return { - RELATION_START_LABEL: start_node, - RELATION_END_LABEL: self._label, - RELATION_START_KEY: start_key, - RELATION_END_KEY: end_key, - RELATION_TYPE: DescriptionMetadata.DESCRIPTION_RELATION_TYPE, - RELATION_REVERSE_TYPE: DescriptionMetadata.INVERSE_DESCRIPTION_RELATION_TYPE - } + def get_node(self, node_key: str) -> GraphNode: + node = GraphNode( + key=node_key, + label=self._label, + attributes={ + DescriptionMetadata.DESCRIPTION_SOURCE: self._source, + DescriptionMetadata.DESCRIPTION_TEXT: self._text + } + ) + return node + + def get_relation(self, start_node: str, start_key: Any, end_key: Any) -> GraphRelationship: + relationship = GraphRelationship( + start_label=start_node, + start_key=start_key, + end_label=self._label, + end_key=end_key, + type=DescriptionMetadata.DESCRIPTION_RELATION_TYPE, + reverse_type=DescriptionMetadata.INVERSE_DESCRIPTION_RELATION_TYPE, + attributes={} + ) + return relationship class ColumnMetadata: @@ -147,7 +148,7 @@ class ColumnMetadata: COLUMN_KEY_FORMAT = '{db}://{cluster}.{schema}/{tbl}/{col}' COLUMN_NAME = 'name' COLUMN_TYPE = 'type' - COLUMN_ORDER = 'sort_order{}'.format(UNQUOTED_SUFFIX) # int value needs to be unquoted when publish to neo4j + COLUMN_ORDER = 'sort_order' COLUMN_DESCRIPTION = 'description' COLUMN_DESCRIPTION_FORMAT = '{db}://{cluster}.{schema}/{tbl}/{col}/{description_id}' @@ -183,12 +184,7 @@ def __repr__(self) -> str: self.badges) -# Tuples for de-dupe purpose on Database, Cluster, Schema. See TableMetadata docstring for more information -NodeTuple = namedtuple('KeyName', ['key', 'name', 'label']) -RelTuple = namedtuple('RelKeys', ['start_label', 'end_label', 'start_key', 'end_key', 'type', 'reverse_type']) - - -class TableMetadata(Neo4jCsvSerializable): +class TableMetadata(GraphSerializable): """ Table metadata that contains columns. It implements Neo4jCsvSerializable so that it can be serialized to produce Table, Column and relation of those along with relationship with table and schema. Additionally, it will create @@ -202,7 +198,7 @@ class TableMetadata(Neo4jCsvSerializable): TABLE_NODE_LABEL = 'Table' TABLE_KEY_FORMAT = '{db}://{cluster}.{schema}/{tbl}' TABLE_NAME = 'name' - IS_VIEW = 'is_view{}'.format(UNQUOTED_SUFFIX) # bool value needs to be unquoted when publish to neo4j + IS_VIEW = 'is_view' TABLE_DESCRIPTION_FORMAT = '{db}://{cluster}.{schema}/{tbl}/{description_id}' @@ -228,8 +224,8 @@ class TableMetadata(Neo4jCsvSerializable): TAG_TABLE_RELATION_TYPE = 'TAG' # Only for deduping database, cluster, and schema (table and column will be always processed) - serialized_nodes: Set[Any] = set() - serialized_rels: Set[Any] = set() + serialized_nodes_keys: Set[Any] = set() + serialized_rels_keys: Set[Any] = set() def __init__(self, database: str, @@ -289,8 +285,7 @@ def _get_table_key(self) -> str: schema=self.schema, tbl=self.name) - def _get_table_description_key(self, - description: DescriptionMetadata) -> str: + def _get_table_description_key(self, description: DescriptionMetadata) -> str: return TableMetadata.TABLE_DESCRIPTION_FORMAT.format(db=self.database, cluster=self.cluster, schema=self.schema, @@ -338,27 +333,18 @@ def format_tags(tags: Union[List, str, None]) -> List: return tags - def create_next_node(self) -> Union[Dict[str, Any], None]: + def create_next_node(self) -> Union[GraphNode, None]: try: return next(self._node_iterator) except StopIteration: return None - def _create_next_node(self) -> Iterator[Any]: # noqa: C901 - - table_node = {NODE_LABEL: TableMetadata.TABLE_NODE_LABEL, - NODE_KEY: self._get_table_key(), - TableMetadata.TABLE_NAME: self.name, - TableMetadata.IS_VIEW: self.is_view} - if self.attrs: - for k, v in self.attrs.items(): - if k not in table_node: - table_node[k] = v - yield table_node + def _create_next_node(self) -> Iterator[GraphNode]: + yield self._create_table_node() if self.description: node_key = self._get_table_description_key(self.description) - yield self.description.get_node_dict(node_key) + yield self.description.get_node(node_key) # Create the table tag node if self.tags: @@ -366,16 +352,20 @@ def _create_next_node(self) -> Iterator[Any]: # noqa: C901 yield TagMetadata.create_tag_node(tag) for col in self.columns: - yield { - NODE_LABEL: ColumnMetadata.COLUMN_NODE_LABEL, - NODE_KEY: self._get_col_key(col), - ColumnMetadata.COLUMN_NAME: col.name, - ColumnMetadata.COLUMN_TYPE: col.type, - ColumnMetadata.COLUMN_ORDER: col.sort_order} + column_node = GraphNode( + key=self._get_col_key(col), + label=ColumnMetadata.COLUMN_NODE_LABEL, + attributes={ + ColumnMetadata.COLUMN_NAME: col.name, + ColumnMetadata.COLUMN_TYPE: col.type, + ColumnMetadata.COLUMN_ORDER: col.sort_order + } + ) + yield column_node if col.description: node_key = self._get_col_description_key(col, col.description) - yield col.description.get_node_dict(node_key) + yield col.description.get_node(node_key) if col.badges: badge_metadata = BadgeMetadata(start_label=ColumnMetadata.COLUMN_NODE_LABEL, @@ -386,42 +376,68 @@ def _create_next_node(self) -> Iterator[Any]: # noqa: C901 yield node # Database, cluster, schema - others = [NodeTuple(key=self._get_database_key(), - name=self.database, - label=TableMetadata.DATABASE_NODE_LABEL), - NodeTuple(key=self._get_cluster_key(), - name=self.cluster, - label=TableMetadata.CLUSTER_NODE_LABEL), - NodeTuple(key=self._get_schema_key(), - name=self.schema, - label=TableMetadata.SCHEMA_NODE_LABEL) - ] + others = [ + GraphNode( + key=self._get_database_key(), + label=TableMetadata.DATABASE_NODE_LABEL, + attributes={ + 'name': self.database + } + ), + GraphNode( + key=self._get_cluster_key(), + label=TableMetadata.CLUSTER_NODE_LABEL, + attributes={ + 'name': self.cluster + } + ), + GraphNode( + key=self._get_schema_key(), + label=TableMetadata.SCHEMA_NODE_LABEL, + attributes={ + 'name': self.schema + } + ) + ] for node_tuple in others: - if node_tuple not in TableMetadata.serialized_nodes: - TableMetadata.serialized_nodes.add(node_tuple) - yield { - NODE_LABEL: node_tuple.label, - NODE_KEY: node_tuple.key, - 'name': node_tuple.name - } + if node_tuple.key not in TableMetadata.serialized_nodes_keys: + TableMetadata.serialized_nodes_keys.add(node_tuple.key) + yield node_tuple + + def _create_table_node(self) -> GraphNode: + table_attributes = { + TableMetadata.TABLE_NAME: self.name, + TableMetadata.IS_VIEW: self.is_view + } + if self.attrs: + for k, v in self.attrs.items(): + if k not in table_attributes: + table_attributes[k] = v - def create_next_relation(self) -> Union[Dict[str, Any], None]: + return GraphNode( + key=self._get_table_key(), + label=TableMetadata.TABLE_NODE_LABEL, + attributes=table_attributes + ) + + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._relation_iterator) except StopIteration: return None - def _create_next_relation(self) -> Iterator[Any]: - - yield { - RELATION_START_LABEL: TableMetadata.SCHEMA_NODE_LABEL, - RELATION_END_LABEL: TableMetadata.TABLE_NODE_LABEL, - RELATION_START_KEY: self._get_schema_key(), - RELATION_END_KEY: self._get_table_key(), - RELATION_TYPE: TableMetadata.SCHEMA_TABLE_RELATION_TYPE, - RELATION_REVERSE_TYPE: TableMetadata.TABLE_SCHEMA_RELATION_TYPE - } + def _create_next_relation(self) -> Iterator[GraphRelationship]: + schema_table_relationship = GraphRelationship( + start_key=self._get_schema_key(), + start_label=TableMetadata.SCHEMA_NODE_LABEL, + end_key=self._get_table_key(), + end_label=TableMetadata.TABLE_NODE_LABEL, + type=TableMetadata.SCHEMA_TABLE_RELATION_TYPE, + reverse_type=TableMetadata.TABLE_SCHEMA_RELATION_TYPE, + attributes={} + ) + yield schema_table_relationship if self.description: yield self.description.get_relation(TableMetadata.TABLE_NODE_LABEL, @@ -430,29 +446,36 @@ def _create_next_relation(self) -> Iterator[Any]: if self.tags: for tag in self.tags: - yield { - RELATION_START_LABEL: TableMetadata.TABLE_NODE_LABEL, - RELATION_END_LABEL: TagMetadata.TAG_NODE_LABEL, - RELATION_START_KEY: self._get_table_key(), - RELATION_END_KEY: TagMetadata.get_tag_key(tag), - RELATION_TYPE: TableMetadata.TABLE_TAG_RELATION_TYPE, - RELATION_REVERSE_TYPE: TableMetadata.TAG_TABLE_RELATION_TYPE, - } + tag_relationship = GraphRelationship( + start_label=TableMetadata.TABLE_NODE_LABEL, + start_key=self._get_table_key(), + end_label=TagMetadata.TAG_NODE_LABEL, + end_key=TagMetadata.get_tag_key(tag), + type=TableMetadata.TABLE_TAG_RELATION_TYPE, + reverse_type=TableMetadata.TAG_TABLE_RELATION_TYPE, + attributes={} + ) + yield tag_relationship for col in self.columns: - yield { - RELATION_START_LABEL: TableMetadata.TABLE_NODE_LABEL, - RELATION_END_LABEL: ColumnMetadata.COLUMN_NODE_LABEL, - RELATION_START_KEY: self._get_table_key(), - RELATION_END_KEY: self._get_col_key(col), - RELATION_TYPE: TableMetadata.TABLE_COL_RELATION_TYPE, - RELATION_REVERSE_TYPE: TableMetadata.COL_TABLE_RELATION_TYPE - } + column_relationship = GraphRelationship( + start_label=TableMetadata.TABLE_NODE_LABEL, + start_key=self._get_table_key(), + end_label=ColumnMetadata.COLUMN_NODE_LABEL, + end_key=self._get_col_key(col), + type=TableMetadata.TABLE_COL_RELATION_TYPE, + reverse_type=TableMetadata.COL_TABLE_RELATION_TYPE, + attributes={} + ) + yield column_relationship if col.description: - yield col.description.get_relation(ColumnMetadata.COLUMN_NODE_LABEL, - self._get_col_key(col), - self._get_col_description_key(col, col.description)) + yield col.description.get_relation( + ColumnMetadata.COLUMN_NODE_LABEL, + self._get_col_key(col), + self._get_col_description_key(col, col.description) + ) + if col.badges: badge_metadata = BadgeMetadata(start_label=ColumnMetadata.COLUMN_NODE_LABEL, start_key=self._get_col_key(col), @@ -462,28 +485,27 @@ def _create_next_relation(self) -> Iterator[Any]: yield relation others = [ - RelTuple(start_label=TableMetadata.DATABASE_NODE_LABEL, - end_label=TableMetadata.CLUSTER_NODE_LABEL, - start_key=self._get_database_key(), - end_key=self._get_cluster_key(), - type=TableMetadata.DATABASE_CLUSTER_RELATION_TYPE, - reverse_type=TableMetadata.CLUSTER_DATABASE_RELATION_TYPE), - RelTuple(start_label=TableMetadata.CLUSTER_NODE_LABEL, - end_label=TableMetadata.SCHEMA_NODE_LABEL, - start_key=self._get_cluster_key(), - end_key=self._get_schema_key(), - type=TableMetadata.CLUSTER_SCHEMA_RELATION_TYPE, - reverse_type=TableMetadata.SCHEMA_CLUSTER_RELATION_TYPE) + GraphRelationship( + start_label=TableMetadata.DATABASE_NODE_LABEL, + end_label=TableMetadata.CLUSTER_NODE_LABEL, + start_key=self._get_database_key(), + end_key=self._get_cluster_key(), + type=TableMetadata.DATABASE_CLUSTER_RELATION_TYPE, + reverse_type=TableMetadata.CLUSTER_DATABASE_RELATION_TYPE, + attributes={} + ), + GraphRelationship( + start_label=TableMetadata.CLUSTER_NODE_LABEL, + end_label=TableMetadata.SCHEMA_NODE_LABEL, + start_key=self._get_cluster_key(), + end_key=self._get_schema_key(), + type=TableMetadata.CLUSTER_SCHEMA_RELATION_TYPE, + reverse_type=TableMetadata.SCHEMA_CLUSTER_RELATION_TYPE, + attributes={} + ) ] for rel_tuple in others: - if rel_tuple not in TableMetadata.serialized_rels: - TableMetadata.serialized_rels.add(rel_tuple) - yield { - RELATION_START_LABEL: rel_tuple.start_label, - RELATION_END_LABEL: rel_tuple.end_label, - RELATION_START_KEY: rel_tuple.start_key, - RELATION_END_KEY: rel_tuple.end_key, - RELATION_TYPE: rel_tuple.type, - RELATION_REVERSE_TYPE: rel_tuple.reverse_type - } + if (rel_tuple.start_key, rel_tuple.end_key, rel_tuple.type) not in TableMetadata.serialized_rels_keys: + TableMetadata.serialized_rels_keys.add((rel_tuple.start_key, rel_tuple.end_key, rel_tuple.type)) + yield rel_tuple diff --git a/databuilder/models/table_owner.py b/databuilder/models/table_owner.py index 066f49ce7..c9b2503fb 100644 --- a/databuilder/models/table_owner.py +++ b/databuilder/models/table_owner.py @@ -1,16 +1,16 @@ # Copyright Contributors to the Amundsen project. # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional, Union +from typing import List, Optional, Union -from databuilder.models.neo4j_csv_serde import Neo4jCsvSerializable, NODE_KEY, \ - NODE_LABEL, RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.models.graph_serializable import GraphSerializable from databuilder.models.owner_constants import OWNER_RELATION_TYPE, OWNER_OF_OBJECT_RELATION_TYPE from databuilder.models.user import User +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship -class TableOwner(Neo4jCsvSerializable): +class TableOwner(GraphSerializable): """ Hive table owner model. """ @@ -35,14 +35,14 @@ def __init__(self, self._node_iter = iter(self.create_nodes()) self._relation_iter = iter(self.create_relation()) - def create_next_node(self) -> Optional[Dict[str, Any]]: + def create_next_node(self) -> Optional[GraphNode]: # return the string representation of the data try: return next(self._node_iter) except StopIteration: return None - def create_next_relation(self) -> Optional[Dict[str, Any]]: + def create_next_relation(self) -> Optional[GraphRelationship]: try: return next(self._relation_iter) except StopIteration: @@ -57,7 +57,7 @@ def get_metadata_model_key(self) -> str: schema=self.schema, table=self.table) - def create_nodes(self) -> List[Dict[str, Any]]: + def create_nodes(self) -> List[GraphNode]: """ Create a list of Neo4j node records :return: @@ -65,28 +65,33 @@ def create_nodes(self) -> List[Dict[str, Any]]: results = [] for owner in self.owners: if owner: - results.append({ - NODE_KEY: self.get_owner_model_key(owner), - NODE_LABEL: User.USER_NODE_LABEL, - User.USER_NODE_EMAIL: owner - }) + node = GraphNode( + key=self.get_owner_model_key(owner), + label=User.USER_NODE_LABEL, + attributes={ + User.USER_NODE_EMAIL: owner + } + ) + results.append(node) return results - def create_relation(self) -> List[Dict[str, Any]]: + def create_relation(self) -> List[GraphRelationship]: """ Create a list of relation map between owner record with original hive table :return: """ results = [] for owner in self.owners: - results.append({ - RELATION_START_KEY: self.get_owner_model_key(owner), - RELATION_START_LABEL: User.USER_NODE_LABEL, - RELATION_END_KEY: self.get_metadata_model_key(), - RELATION_END_LABEL: 'Table', - RELATION_TYPE: TableOwner.OWNER_TABLE_RELATION_TYPE, - RELATION_REVERSE_TYPE: TableOwner.TABLE_OWNER_RELATION_TYPE - }) + relationship = GraphRelationship( + start_key=self.get_owner_model_key(owner), + start_label=User.USER_NODE_LABEL, + end_key=self.get_metadata_model_key(), + end_label='Table', + type=TableOwner.OWNER_TABLE_RELATION_TYPE, + reverse_type=TableOwner.TABLE_OWNER_RELATION_TYPE, + attributes={} + ) + results.append(relationship) return results diff --git a/databuilder/models/table_source.py b/databuilder/models/table_source.py index f24a1c359..0cfaa0711 100644 --- a/databuilder/models/table_source.py +++ b/databuilder/models/table_source.py @@ -1,16 +1,16 @@ # Copyright Contributors to the Amundsen project. # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional +from typing import List, Optional -from databuilder.models.neo4j_csv_serde import Neo4jCsvSerializable, NODE_KEY, \ - NODE_LABEL, RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.models.graph_serializable import GraphSerializable from databuilder.models.table_metadata import TableMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship -class TableSource(Neo4jCsvSerializable): +class TableSource(GraphSerializable): """ Hive table source model. """ @@ -38,14 +38,14 @@ def __init__(self, self._node_iter = iter(self.create_nodes()) self._relation_iter = iter(self.create_relation()) - def create_next_node(self) -> Optional[Dict[str, Any]]: + def create_next_node(self) -> Optional[GraphNode]: # return the string representation of the data try: return next(self._node_iter) except StopIteration: return None - def create_next_relation(self) -> Optional[Dict[str, Any]]: + def create_next_relation(self) -> Optional[GraphRelationship]: try: return next(self._relation_iter) except StopIteration: @@ -63,32 +63,37 @@ def get_metadata_model_key(self) -> str: schema=self.schema, table=self.table) - def create_nodes(self) -> List[Dict[str, Any]]: + def create_nodes(self) -> List[GraphNode]: """ Create a list of Neo4j node records :return: """ - results = [{ - NODE_KEY: self.get_source_model_key(), - NODE_LABEL: TableSource.LABEL, - 'source': self.source, - 'source_type': self.source_type - }] + node = GraphNode( + key=self.get_source_model_key(), + label=TableSource.LABEL, + attributes={ + 'source': self.source, + 'source_type': self.source_type + } + ) + results = [node] return results - def create_relation(self) -> List[Dict[str, Any]]: + def create_relation(self) -> List[GraphRelationship]: """ Create a list of relation map between owner record with original hive table :return: """ - results = [{ - RELATION_START_KEY: self.get_source_model_key(), - RELATION_START_LABEL: TableSource.LABEL, - RELATION_END_KEY: self.get_metadata_model_key(), - RELATION_END_LABEL: TableMetadata.TABLE_NODE_LABEL, - RELATION_TYPE: TableSource.SOURCE_TABLE_RELATION_TYPE, - RELATION_REVERSE_TYPE: TableSource.TABLE_SOURCE_RELATION_TYPE - }] + relationship = GraphRelationship( + start_label=TableSource.LABEL, + start_key=self.get_source_model_key(), + end_label=TableMetadata.TABLE_NODE_LABEL, + end_key=self.get_metadata_model_key(), + type=TableSource.SOURCE_TABLE_RELATION_TYPE, + reverse_type=TableSource.TABLE_SOURCE_RELATION_TYPE, + attributes={} + ) + results = [relationship] return results def __repr__(self) -> str: diff --git a/databuilder/models/table_stats.py b/databuilder/models/table_stats.py index 6cacef671..f536f81ac 100644 --- a/databuilder/models/table_stats.py +++ b/databuilder/models/table_stats.py @@ -1,15 +1,15 @@ # Copyright Contributors to the Amundsen project. # SPDX-License-Identifier: Apache-2.0 +import ast +from typing import List, Optional -from typing import Any, Dict, List, Optional - -from databuilder.models.neo4j_csv_serde import Neo4jCsvSerializable, NODE_KEY, \ - NODE_LABEL, RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.models.graph_serializable import GraphSerializable from databuilder.models.table_metadata import ColumnMetadata +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship -class TableColumnStats(Neo4jCsvSerializable): +class TableColumnStats(GraphSerializable): """ Hive table stats model. Each instance represents one row of hive watermark result. @@ -27,9 +27,9 @@ def __init__(self, stat_val: str, start_epoch: str, end_epoch: str, - db: str='hive', - cluster: str='gold', - schema: str=None + db: str = 'hive', + cluster: str = 'gold', + schema: str = None ) -> None: if schema is None: self.schema, self.table = table_name.split('.') @@ -42,18 +42,22 @@ def __init__(self, self.end_epoch = end_epoch self.cluster = cluster self.stat_name = stat_name + try: + stat_val = ast.literal_eval(stat_val) + except ValueError: + stat_val = stat_val self.stat_val = stat_val self._node_iter = iter(self.create_nodes()) self._relation_iter = iter(self.create_relation()) - def create_next_node(self) -> Optional[Dict[str, Any]]: + def create_next_node(self) -> Optional[GraphNode]: # return the string representation of the data try: return next(self._node_iter) except StopIteration: return None - def create_next_relation(self) -> Optional[Dict[str, Any]]: + def create_next_relation(self) -> Optional[GraphRelationship]: try: return next(self._relation_iter) except StopIteration: @@ -75,32 +79,37 @@ def get_col_key(self) -> str: tbl=self.table, col=self.col_name) - def create_nodes(self) -> List[Dict[str, Any]]: + def create_nodes(self) -> List[GraphNode]: """ Create a list of Neo4j node records :return: """ - results = [{ - NODE_KEY: self.get_table_stat_model_key(), - NODE_LABEL: TableColumnStats.LABEL, - 'stat_val:UNQUOTED': self.stat_val, - 'stat_name': self.stat_name, - 'start_epoch': self.start_epoch, - 'end_epoch': self.end_epoch, - }] + node = GraphNode( + key=self.get_table_stat_model_key(), + label=TableColumnStats.LABEL, + attributes={ + 'stat_val': self.stat_val, + 'stat_name': self.stat_name, + 'start_epoch': self.start_epoch, + 'end_epoch': self.end_epoch, + } + ) + results = [node] return results - def create_relation(self) -> List[Dict[str, Any]]: + def create_relation(self) -> List[GraphRelationship]: """ Create a list of relation map between table stat record with original hive table :return: """ - results = [{ - RELATION_START_KEY: self.get_table_stat_model_key(), - RELATION_START_LABEL: TableColumnStats.LABEL, - RELATION_END_KEY: self.get_col_key(), - RELATION_END_LABEL: ColumnMetadata.COLUMN_NODE_LABEL, - RELATION_TYPE: TableColumnStats.STAT_Column_RELATION_TYPE, - RELATION_REVERSE_TYPE: TableColumnStats.Column_STAT_RELATION_TYPE - }] + relationship = GraphRelationship( + start_key=self.get_table_stat_model_key(), + start_label=TableColumnStats.LABEL, + end_key=self.get_col_key(), + end_label=ColumnMetadata.COLUMN_NODE_LABEL, + type=TableColumnStats.STAT_Column_RELATION_TYPE, + reverse_type=TableColumnStats.Column_STAT_RELATION_TYPE, + attributes={} + ) + results = [relationship] return results diff --git a/databuilder/models/usage/usage_constants.py b/databuilder/models/usage/usage_constants.py index 7a45bc420..f9dd962ba 100644 --- a/databuilder/models/usage/usage_constants.py +++ b/databuilder/models/usage/usage_constants.py @@ -1,10 +1,7 @@ # Copyright Contributors to the Amundsen project. # SPDX-License-Identifier: Apache-2.0 -from databuilder.publisher.neo4j_csv_publisher import UNQUOTED_SUFFIX - - READ_RELATION_TYPE = 'READ' READ_REVERSE_RELATION_TYPE = 'READ_BY' -READ_RELATION_COUNT_PROPERTY = 'read_count{}'.format(UNQUOTED_SUFFIX) +READ_RELATION_COUNT_PROPERTY = 'read_count' diff --git a/databuilder/models/user.py b/databuilder/models/user.py index cd22857eb..f48385186 100644 --- a/databuilder/models/user.py +++ b/databuilder/models/user.py @@ -2,15 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 import copy -from typing import Any, List, Dict, Optional -from databuilder.models.neo4j_csv_serde import Neo4jCsvSerializable, NODE_KEY, \ - NODE_LABEL, RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE -from databuilder.publisher.neo4j_csv_publisher import UNQUOTED_SUFFIX +from typing import Any, List, Optional +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship -class User(Neo4jCsvSerializable): + +class User(GraphSerializable): """ User model. This model doesn't define any relationship. """ @@ -25,7 +25,7 @@ class User(Neo4jCsvSerializable): USER_NODE_EMPLOYEE_TYPE = 'employee_type' USER_NODE_MANAGER_EMAIL = 'manager_email' USER_NODE_SLACK_ID = 'slack_id' - USER_NODE_IS_ACTIVE = 'is_active{}'.format(UNQUOTED_SUFFIX) # bool value needs to be unquoted when publish to neo4j + USER_NODE_IS_ACTIVE = 'is_active' # bool value needs to be unquoted when publish to neo4j USER_NODE_UPDATED_AT = 'updated_at' USER_NODE_ROLE_NAME = 'role_name' @@ -34,18 +34,18 @@ class User(Neo4jCsvSerializable): def __init__(self, email: str, - first_name: str='', - last_name: str='', - name: str='', - github_username: str='', - team_name: str='', - employee_type: str='', - manager_email: str='', - slack_id: str='', - is_active: bool=True, - updated_at: int=0, - role_name: str='', - do_not_update_empty_attribute: bool=False, + first_name: str = '', + last_name: str = '', + name: str = '', + github_username: str = '', + team_name: str = '', + employee_type: str = '', + manager_email: str = '', + slack_id: str = '', + is_active: bool = True, + updated_at: int = 0, + role_name: str = '', + do_not_update_empty_attribute: bool = False, **kwargs: Any ) -> None: """ @@ -91,14 +91,14 @@ def __init__(self, self._node_iter = iter(self.create_nodes()) self._rel_iter = iter(self.create_relation()) - def create_next_node(self) -> Optional[Dict[str, Any]]: + def create_next_node(self) -> Optional[GraphNode]: # return the string representation of the data try: return next(self._node_iter) except StopIteration: return None - def create_next_relation(self) -> Optional[Dict[str, Any]]: + def create_next_relation(self) -> Optional[GraphRelationship]: """ :return: """ @@ -115,55 +115,62 @@ def get_user_model_key(cls, return '' return User.USER_NODE_KEY_FORMAT.format(email=email) - def create_nodes(self) -> List[Dict[str, Any]]: + def create_nodes(self) -> List[GraphNode]: """ Create a list of Neo4j node records :return: """ - result_node = { - NODE_KEY: User.get_user_model_key(email=self.email), - NODE_LABEL: User.USER_NODE_LABEL, + + node_attributes = { User.USER_NODE_EMAIL: self.email, User.USER_NODE_IS_ACTIVE: self.is_active, } - result_node[User.USER_NODE_FIRST_NAME] = self.first_name if self.first_name else '' - result_node[User.USER_NODE_LAST_NAME] = self.last_name if self.last_name else '' - result_node[User.USER_NODE_FULL_NAME] = self.name if self.name else '' - result_node[User.USER_NODE_GITHUB_NAME] = self.github_username if self.github_username else '' - result_node[User.USER_NODE_TEAM] = self.team_name if self.team_name else '' - result_node[User.USER_NODE_EMPLOYEE_TYPE] = self.employee_type if self.employee_type else '' - result_node[User.USER_NODE_SLACK_ID] = self.slack_id if self.slack_id else '' - result_node[User.USER_NODE_ROLE_NAME] = self.role_name if self.role_name else '' + node_attributes[User.USER_NODE_FIRST_NAME] = self.first_name if self.first_name else '' + node_attributes[User.USER_NODE_LAST_NAME] = self.last_name if self.last_name else '' + node_attributes[User.USER_NODE_FULL_NAME] = self.name if self.name else '' + node_attributes[User.USER_NODE_GITHUB_NAME] = self.github_username if self.github_username else '' + node_attributes[User.USER_NODE_TEAM] = self.team_name if self.team_name else '' + node_attributes[User.USER_NODE_EMPLOYEE_TYPE] = self.employee_type if self.employee_type else '' + node_attributes[User.USER_NODE_SLACK_ID] = self.slack_id if self.slack_id else '' + node_attributes[User.USER_NODE_ROLE_NAME] = self.role_name if self.role_name else '' if self.updated_at: - result_node[User.USER_NODE_UPDATED_AT] = self.updated_at + node_attributes[User.USER_NODE_UPDATED_AT] = self.updated_at elif not self.do_not_update_empty_attribute: - result_node[User.USER_NODE_UPDATED_AT] = 0 + node_attributes[User.USER_NODE_UPDATED_AT] = 0 if self.attrs: for k, v in self.attrs.items(): - if k not in result_node: - result_node[k] = v + if k not in node_attributes: + node_attributes[k] = v if self.do_not_update_empty_attribute: - for k, v in list(result_node.items()): + for k, v in list(node_attributes.items()): if not v: - del result_node[k] + del node_attributes[k] + + node = GraphNode( + key=User.get_user_model_key(email=self.email), + label=User.USER_NODE_LABEL, + attributes=node_attributes + ) - return [result_node] + return [node] - def create_relation(self) -> List[Dict[str, Any]]: + def create_relation(self) -> List[GraphRelationship]: if self.manager_email: # only create the relation if the manager exists - return [{ - RELATION_START_KEY: User.get_user_model_key(email=self.email), - RELATION_START_LABEL: User.USER_NODE_LABEL, - RELATION_END_KEY: self.get_user_model_key(email=self.manager_email), - RELATION_END_LABEL: User.USER_NODE_LABEL, - RELATION_TYPE: User.USER_MANAGER_RELATION_TYPE, - RELATION_REVERSE_TYPE: User.MANAGER_USER_RELATION_TYPE - }] + relationship = GraphRelationship( + start_key=User.get_user_model_key(email=self.email), + start_label=User.USER_NODE_LABEL, + end_label=User.USER_NODE_LABEL, + end_key=self.get_user_model_key(email=self.manager_email), + type=User.USER_MANAGER_RELATION_TYPE, + reverse_type=User.MANAGER_USER_RELATION_TYPE, + attributes={} + ) + return [relationship] return [] def __repr__(self) -> str: diff --git a/databuilder/models/watermark.py b/databuilder/models/watermark.py index 8b71d8e8e..53472bc2f 100644 --- a/databuilder/models/watermark.py +++ b/databuilder/models/watermark.py @@ -1,14 +1,14 @@ # Copyright Contributors to the Amundsen project. # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Dict, List, Optional, Tuple +from typing import List, Union, Tuple -from databuilder.models.neo4j_csv_serde import Neo4jCsvSerializable, NODE_KEY, \ - NODE_LABEL, RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship -class Watermark(Neo4jCsvSerializable): +class Watermark(GraphSerializable): """ Table watermark result model. Each instance represents one row of table watermark result. @@ -46,14 +46,14 @@ def __init__(self, self._node_iter = iter(self.create_nodes()) self._relation_iter = iter(self.create_relation()) - def create_next_node(self) -> Optional[Dict[str, Any]]: + def create_next_node(self) -> Union[GraphNode, None]: # return the string representation of the data try: return next(self._node_iter) except StopIteration: return None - def create_next_relation(self) -> Optional[Dict[str, Any]]: + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._relation_iter) except StopIteration: @@ -72,33 +72,38 @@ def get_metadata_model_key(self) -> str: schema=self.schema, table=self.table) - def create_nodes(self) -> List[Dict[str, Any]]: + def create_nodes(self) -> List[GraphNode]: """ Create a list of Neo4j node records :return: """ results = [] for part in self.parts: - results.append({ - NODE_KEY: self.get_watermark_model_key(), - NODE_LABEL: Watermark.LABEL, - 'partition_key': part[0], - 'partition_value': part[1], - 'create_time': self.create_time - }) + part_node = GraphNode( + key=self.get_watermark_model_key(), + label=Watermark.LABEL, + attributes={ + 'partition_key': part[0], + 'partition_value': part[1], + 'create_time': self.create_time + } + ) + results.append(part_node) return results - def create_relation(self) -> List[Dict[str, Any]]: + def create_relation(self) -> List[GraphRelationship]: """ Create a list of relation map between watermark record with original table :return: """ - results = [{ - RELATION_START_KEY: self.get_watermark_model_key(), - RELATION_START_LABEL: Watermark.LABEL, - RELATION_END_KEY: self.get_metadata_model_key(), - RELATION_END_LABEL: 'Table', - RELATION_TYPE: Watermark.WATERMARK_TABLE_RELATION_TYPE, - RELATION_REVERSE_TYPE: Watermark.TABLE_WATERMARK_RELATION_TYPE - }] + relation = GraphRelationship( + start_key=self.get_watermark_model_key(), + start_label=Watermark.LABEL, + end_key=self.get_metadata_model_key(), + end_label='Table', + type=Watermark.WATERMARK_TABLE_RELATION_TYPE, + reverse_type=Watermark.TABLE_WATERMARK_RELATION_TYPE, + attributes={} + ) + results = [relation] return results diff --git a/databuilder/serializers/__init__.py b/databuilder/serializers/__init__.py new file mode 100644 index 000000000..f3145d75b --- /dev/null +++ b/databuilder/serializers/__init__.py @@ -0,0 +1,2 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 diff --git a/databuilder/serializers/neo4_serializer.py b/databuilder/serializers/neo4_serializer.py new file mode 100644 index 000000000..baab36c0d --- /dev/null +++ b/databuilder/serializers/neo4_serializer.py @@ -0,0 +1,69 @@ +# Copyright Contributors to the Amundsen project. +# SPDX-License-Identifier: Apache-2.0 + +from typing import Dict, Any, Optional + +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_serializable import ( + NODE_LABEL, + NODE_KEY, + RELATION_END_KEY, + RELATION_END_LABEL, + RELATION_REVERSE_TYPE, + RELATION_START_KEY, + RELATION_START_LABEL, + RELATION_TYPE +) +from databuilder.publisher.neo4j_csv_publisher import UNQUOTED_SUFFIX + + +def serialize_node(node: Optional[GraphNode]) -> Dict[str, Any]: + if node is None: + return {} + + node_dict = { + NODE_LABEL: node.label, + NODE_KEY: node.key + } + for key, value in node.attributes.items(): + key_suffix = _get_neo4j_suffix_value(value) + formatted_key = "{key}{suffix}".format( + key=key, + suffix=key_suffix + ) + node_dict[formatted_key] = value + return node_dict + + +def serialize_relationship(relationship: Optional[GraphRelationship]) -> Dict[str, Any]: + if relationship is None: + return {} + + relationship_dict = { + RELATION_START_KEY: relationship.start_key, + RELATION_START_LABEL: relationship.start_label, + RELATION_END_KEY: relationship.end_key, + RELATION_END_LABEL: relationship.end_label, + RELATION_TYPE: relationship.type, + RELATION_REVERSE_TYPE: relationship.reverse_type, + } + for key, value in relationship.attributes.items(): + key_suffix = _get_neo4j_suffix_value(value) + formatted_key = "{key}{suffix}".format( + key=key, + suffix=key_suffix + ) + relationship_dict[formatted_key] = value + + return relationship_dict + + +def _get_neo4j_suffix_value(value: Any) -> str: + if isinstance(value, int): + return UNQUOTED_SUFFIX + + if isinstance(value, bool): + return UNQUOTED_SUFFIX + + return '' diff --git a/tests/unit/loader/test_fs_neo4j_csv_loader.py b/tests/unit/loader/test_fs_neo4j_csv_loader.py index abe9dec7b..511ae3962 100644 --- a/tests/unit/loader/test_fs_neo4j_csv_loader.py +++ b/tests/unit/loader/test_fs_neo4j_csv_loader.py @@ -14,7 +14,7 @@ from databuilder.job.base_job import Job from databuilder.loader.file_system_neo4j_csv_loader import FsNeo4jCSVLoader -from tests.unit.models.test_neo4j_csv_serde import Movie, Actor, City +from tests.unit.models.test_graph_serializable import Movie, Actor, City from operator import itemgetter diff --git a/tests/unit/models/dashboard/test_dashboard_chart.py b/tests/unit/models/dashboard/test_dashboard_chart.py index 34a603b89..73b24d1c1 100644 --- a/tests/unit/models/dashboard/test_dashboard_chart.py +++ b/tests/unit/models/dashboard/test_dashboard_chart.py @@ -6,8 +6,9 @@ from typing import Any, Dict from databuilder.models.dashboard.dashboard_chart import DashboardChart -from databuilder.models.neo4j_csv_serde import RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ +from databuilder.models.graph_serializable import RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.serializers import neo4_serializer class TestDashboardChart(unittest.TestCase): @@ -24,6 +25,7 @@ def test_create_nodes(self) -> None: ) actual = dashboard_chart.create_next_node() + actual_serialized = neo4_serializer.serialize_node(actual) expected: Dict[str, Any] = { 'name': 'c_name', 'type': 'bar', @@ -34,7 +36,7 @@ def test_create_nodes(self) -> None: } assert actual is not None - self.assertDictEqual(expected, actual) + self.assertDictEqual(expected, actual_serialized) self.assertIsNone(dashboard_chart.create_next_node()) dashboard_chart = DashboardChart(dashboard_group_id='dg_id', @@ -45,6 +47,7 @@ def test_create_nodes(self) -> None: ) actual2 = dashboard_chart.create_next_node() + actual2_serialized = neo4_serializer.serialize_node(actual2) expected2: Dict[str, Any] = { 'id': 'c_id', 'KEY': '_dashboard://gold.dg_id/d_id/query/q_id/chart/c_id', @@ -52,7 +55,7 @@ def test_create_nodes(self) -> None: 'url': 'http://gold.foo.bar/' } assert actual2 is not None - self.assertDictEqual(expected2, actual2) + self.assertDictEqual(expected2, actual2_serialized) def test_create_relation(self) -> None: dashboard_chart = DashboardChart(dashboard_group_id='dg_id', @@ -64,6 +67,7 @@ def test_create_relation(self) -> None: ) actual = dashboard_chart.create_next_relation() + actual_serialized = neo4_serializer.serialize_relationship(actual) expected: Dict[str, Any] = { RELATION_END_KEY: '_dashboard://gold.dg_id/d_id/query/q_id/chart/c_id', RELATION_START_LABEL: 'Query', RELATION_END_LABEL: 'Chart', @@ -72,5 +76,5 @@ def test_create_relation(self) -> None: } assert actual is not None - self.assertEqual(expected, actual) + self.assertEqual(expected, actual_serialized) self.assertIsNone(dashboard_chart.create_next_relation()) diff --git a/tests/unit/models/dashboard/test_dashboard_last_modified.py b/tests/unit/models/dashboard/test_dashboard_last_modified.py index e8596a5f6..58b7e96dd 100644 --- a/tests/unit/models/dashboard/test_dashboard_last_modified.py +++ b/tests/unit/models/dashboard/test_dashboard_last_modified.py @@ -6,8 +6,9 @@ from typing import Any, Dict from databuilder.models.dashboard.dashboard_last_modified import DashboardLastModifiedTimestamp -from databuilder.models.neo4j_csv_serde import RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ +from databuilder.models.graph_serializable import RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.serializers import neo4_serializer class TestDashboardLastModifiedTimestamp(unittest.TestCase): @@ -20,15 +21,16 @@ def test_dashboard_timestamp_nodes(self) -> None: dashboard_group_id='dashboard_group_id') actual = dashboard_last_modified.create_next_node() + actual_serialized = neo4_serializer.serialize_node(actual) expected: Dict[str, Any] = { - 'timestamp': 123456789, + 'timestamp:UNQUOTED': 123456789, 'name': 'last_updated_timestamp', 'KEY': 'product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id/_last_modified_timestamp', 'LABEL': 'Timestamp' } assert actual is not None - self.assertDictEqual(actual, expected) + self.assertDictEqual(actual_serialized, expected) self.assertIsNone(dashboard_last_modified.create_next_node()) def test_dashboard_owner_relations(self) -> None: @@ -39,6 +41,7 @@ def test_dashboard_owner_relations(self) -> None: dashboard_group_id='dashboard_group_id') actual = dashboard_last_modified.create_next_relation() + actual_serialized = neo4_serializer.serialize_relationship(actual) expected: Dict[str, Any] = { RELATION_END_KEY: 'product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id' '/_last_modified_timestamp', @@ -50,5 +53,5 @@ def test_dashboard_owner_relations(self) -> None: } assert actual is not None - self.assertDictEqual(actual, expected) + self.assertDictEqual(actual_serialized, expected) self.assertIsNone(dashboard_last_modified.create_next_relation()) diff --git a/tests/unit/models/dashboard/test_dashboard_metadata.py b/tests/unit/models/dashboard/test_dashboard_metadata.py index 56439a6ee..ba2e816b6 100644 --- a/tests/unit/models/dashboard/test_dashboard_metadata.py +++ b/tests/unit/models/dashboard/test_dashboard_metadata.py @@ -5,6 +5,7 @@ import unittest from databuilder.models.dashboard.dashboard_metadata import DashboardMetadata +from databuilder.serializers import neo4_serializer class TestDashboardMetadata(unittest.TestCase): @@ -41,10 +42,17 @@ def setUp(self) -> None: ) self.expected_nodes_deduped = [ - {'KEY': '_dashboard://gold', 'LABEL': 'Cluster', 'name': 'gold'}, - {'created_timestamp': 123456789, 'name': 'Agent', 'KEY': '_dashboard://gold.Product - Jobs.cz/Agent', - 'LABEL': 'Dashboard', - 'dashboard_url': 'https://foo.bar/dashboard_group/foo/dashboard/bar'}, + { + 'KEY': '_dashboard://gold', + 'LABEL': 'Cluster', 'name': 'gold' + }, + { + 'created_timestamp:UNQUOTED': 123456789, + 'name': 'Agent', + 'KEY': '_dashboard://gold.Product - Jobs.cz/Agent', + 'LABEL': 'Dashboard', + 'dashboard_url': 'https://foo.bar/dashboard_group/foo/dashboard/bar' + }, {'name': 'Product - Jobs.cz', 'KEY': '_dashboard://gold.Product - Jobs.cz', 'LABEL': 'Dashboardgroup', 'dashboard_group_url': 'https://foo.bar/dashboard_group/foo'}, {'KEY': '_dashboard://gold.Product - Jobs.cz/_description', 'LABEL': 'Description', @@ -138,7 +146,8 @@ def test_serialize(self) -> None: node_row = self.dashboard_metadata.next_node() actual = [] while node_row: - actual.append(node_row) + node_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_serialized) node_row = self.dashboard_metadata.next_node() self.assertEqual(self.expected_nodes, actual) @@ -146,7 +155,8 @@ def test_serialize(self) -> None: relation_row = self.dashboard_metadata.next_relation() actual = [] while relation_row: - actual.append(relation_row) + relation_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_serialized) relation_row = self.dashboard_metadata.next_relation() self.assertEqual(self.expected_rels, actual) @@ -155,7 +165,8 @@ def test_serialize(self) -> None: node_row = self.dashboard_metadata2.next_node() actual = [] while node_row: - actual.append(node_row) + node_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_serialized) node_row = self.dashboard_metadata2.next_node() self.assertEqual(self.expected_nodes_deduped2, actual) @@ -163,7 +174,8 @@ def test_serialize(self) -> None: relation_row = self.dashboard_metadata2.next_relation() actual = [] while relation_row: - actual.append(relation_row) + relation_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_serialized) relation_row = self.dashboard_metadata2.next_relation() self.assertEqual(self.expected_rels_deduped2, actual) @@ -172,7 +184,8 @@ def test_serialize(self) -> None: node_row = self.dashboard_metadata3.next_node() actual = [] while node_row: - actual.append(node_row) + node_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_serialized) node_row = self.dashboard_metadata3.next_node() self.assertEqual(self.expected_nodes_deduped3, actual) @@ -180,7 +193,8 @@ def test_serialize(self) -> None: relation_row = self.dashboard_metadata3.next_relation() actual = [] while relation_row: - actual.append(relation_row) + relation_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_serialized) relation_row = self.dashboard_metadata3.next_relation() self.assertEqual(self.expected_rels_deduped3, actual) diff --git a/tests/unit/models/dashboard/test_dashboard_owner.py b/tests/unit/models/dashboard/test_dashboard_owner.py index 895d6be98..191ea00f2 100644 --- a/tests/unit/models/dashboard/test_dashboard_owner.py +++ b/tests/unit/models/dashboard/test_dashboard_owner.py @@ -4,8 +4,9 @@ import unittest from databuilder.models.dashboard.dashboard_owner import DashboardOwner -from databuilder.models.neo4j_csv_serde import RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ +from databuilder.models.graph_serializable import RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.serializers import neo4_serializer class TestDashboardOwner(unittest.TestCase): @@ -22,9 +23,10 @@ def test_dashboard_owner_relations(self) -> None: dashboard_id='dashboard_id', dashboard_group_id='dashboard_group_id') actual = dashboard_owner.create_next_relation() + actual_serialized = neo4_serializer.serialize_relationship(actual) expected = {RELATION_END_KEY: 'foo@bar.com', RELATION_START_LABEL: 'Dashboard', RELATION_END_LABEL: 'User', RELATION_START_KEY: 'product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', RELATION_TYPE: 'OWNER', RELATION_REVERSE_TYPE: 'OWNER_OF'} assert actual is not None - self.assertDictEqual(actual, expected) + self.assertDictEqual(actual_serialized, expected) diff --git a/tests/unit/models/dashboard/test_dashboard_query.py b/tests/unit/models/dashboard/test_dashboard_query.py index cc709badf..29bb0f26b 100644 --- a/tests/unit/models/dashboard/test_dashboard_query.py +++ b/tests/unit/models/dashboard/test_dashboard_query.py @@ -4,9 +4,10 @@ import unittest from databuilder.models.dashboard.dashboard_query import DashboardQuery -from databuilder.models.neo4j_csv_serde import NODE_KEY, \ +from databuilder.models.graph_serializable import NODE_KEY, \ NODE_LABEL, RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.serializers import neo4_serializer class TestDashboardQuery(unittest.TestCase): @@ -21,12 +22,13 @@ def test_create_nodes(self) -> None: query_text='SELECT * FROM foo.bar') actual = dashboard_query.create_next_node() + actual_serialized = neo4_serializer.serialize_node(actual) expected = {'url': 'http://foo.bar/query/baz', 'name': 'q_name', 'id': 'q_id', 'query_text': 'SELECT * FROM foo.bar', NODE_KEY: '_dashboard://gold.dg_id/d_id/query/q_id', NODE_LABEL: DashboardQuery.DASHBOARD_QUERY_LABEL} - self.assertEqual(expected, actual) + self.assertEqual(expected, actual_serialized) def test_create_relation(self) -> None: dashboard_query = DashboardQuery(dashboard_group_id='dg_id', @@ -35,9 +37,10 @@ def test_create_relation(self) -> None: query_name='q_name') actual = dashboard_query.create_next_relation() + actual_serialized = neo4_serializer.serialize_relationship(actual) expected = {RELATION_END_KEY: '_dashboard://gold.dg_id/d_id/query/q_id', RELATION_START_LABEL: 'Dashboard', RELATION_END_LABEL: DashboardQuery.DASHBOARD_QUERY_LABEL, RELATION_START_KEY: '_dashboard://gold.dg_id/d_id', RELATION_TYPE: 'HAS_QUERY', RELATION_REVERSE_TYPE: 'QUERY_OF'} - self.assertEqual(expected, actual) + self.assertEqual(expected, actual_serialized) diff --git a/tests/unit/models/dashboard/test_dashboard_table.py b/tests/unit/models/dashboard/test_dashboard_table.py index fc836d4df..a65d8301a 100644 --- a/tests/unit/models/dashboard/test_dashboard_table.py +++ b/tests/unit/models/dashboard/test_dashboard_table.py @@ -4,8 +4,9 @@ import unittest from databuilder.models.dashboard.dashboard_table import DashboardTable -from databuilder.models.neo4j_csv_serde import RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ +from databuilder.models.graph_serializable import RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.serializers import neo4_serializer class TestDashboardTable(unittest.TestCase): @@ -24,26 +25,28 @@ def test_dashboard_table_relations(self) -> None: dashboard_id='dashboard_id', dashboard_group_id='dashboard_group_id') actual = dashboard_table.create_next_relation() + actual_serialized = neo4_serializer.serialize_relationship(actual) expected = {RELATION_END_KEY: 'hive://gold.schema/table1', RELATION_START_LABEL: 'Dashboard', RELATION_END_LABEL: 'Table', RELATION_START_KEY: 'product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', RELATION_TYPE: 'DASHBOARD_WITH_TABLE', RELATION_REVERSE_TYPE: 'TABLE_OF_DASHBOARD'} assert actual is not None - self.assertDictEqual(actual, expected) + self.assertDictEqual(actual_serialized, expected) def test_dashboard_table_without_dot_as_name(self) -> None: dashboard_table = DashboardTable(table_ids=['bq-name://project-id.schema-name/table-name'], cluster='cluster_id', product='product_id', dashboard_id='dashboard_id', dashboard_group_id='dashboard_group_id') actual = dashboard_table.create_next_relation() + actual_serialized = neo4_serializer.serialize_relationship(actual) expected = {RELATION_END_KEY: 'bq-name://project-id.schema-name/table-name', RELATION_START_LABEL: 'Dashboard', RELATION_END_LABEL: 'Table', RELATION_START_KEY: 'product_id_dashboard://cluster_id.dashboard_group_id/dashboard_id', RELATION_TYPE: 'DASHBOARD_WITH_TABLE', RELATION_REVERSE_TYPE: 'TABLE_OF_DASHBOARD'} assert actual is not None - self.assertDictEqual(actual, expected) + self.assertDictEqual(actual_serialized, expected) def test_dashboard_table_with_dot_as_name(self) -> None: dashboard_table = DashboardTable(table_ids=['bq-name://project.id.schema-name/table-name'], diff --git a/tests/unit/models/dashboard/test_dashboard_usage.py b/tests/unit/models/dashboard/test_dashboard_usage.py index 413e18d82..2c7c04bf5 100644 --- a/tests/unit/models/dashboard/test_dashboard_usage.py +++ b/tests/unit/models/dashboard/test_dashboard_usage.py @@ -6,8 +6,9 @@ from typing import Any, Dict from databuilder.models.dashboard.dashboard_usage import DashboardUsage -from databuilder.models.neo4j_csv_serde import RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ +from databuilder.models.graph_serializable import RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.serializers import neo4_serializer class TestDashboardOwner(unittest.TestCase): @@ -18,13 +19,14 @@ def test_dashboard_usage_user_nodes(self) -> None: product='product_id', should_create_user_node=True) actual = dashboard_usage.create_next_node() + actual_serialized = neo4_serializer.serialize_node(actual) expected: Dict[str, Any] = { 'is_active:UNQUOTED': True, 'last_name': '', 'full_name': '', 'employee_type': '', 'first_name': '', - 'updated_at': 0, + 'updated_at:UNQUOTED': 0, 'LABEL': 'User', 'slack_id': '', 'KEY': 'foo@bar.com', @@ -35,7 +37,7 @@ def test_dashboard_usage_user_nodes(self) -> None: } assert actual is not None - self.assertDictEqual(expected, actual) + self.assertDictEqual(expected, actual_serialized) self.assertIsNone(dashboard_usage.create_next_node()) def test_dashboard_usage_no_user_nodes(self) -> None: @@ -52,6 +54,7 @@ def test_dashboard_owner_relations(self) -> None: product='product_id') actual = dashboard_usage.create_next_relation() + actual_serialized = neo4_serializer.serialize_relationship(actual) expected: Dict[str, Any] = { 'read_count:UNQUOTED': 123, RELATION_END_KEY: 'foo@bar.com', @@ -63,5 +66,5 @@ def test_dashboard_owner_relations(self) -> None: } assert actual is not None - self.assertDictEqual(expected, actual) + self.assertDictEqual(expected, actual_serialized) self.assertIsNone(dashboard_usage.create_next_relation()) diff --git a/tests/unit/models/schema/test_schema.py b/tests/unit/models/schema/test_schema.py index ff2aa3d9b..b482288bb 100644 --- a/tests/unit/models/schema/test_schema.py +++ b/tests/unit/models/schema/test_schema.py @@ -4,6 +4,7 @@ import unittest from databuilder.models.schema.schema import SchemaModel +from databuilder.serializers import neo4_serializer class TestSchemaDescription(unittest.TestCase): @@ -13,12 +14,18 @@ def test_create_nodes(self) -> None: schema = SchemaModel(schema_key='db://cluster.schema', schema='schema_name', description='foo') - - self.assertDictEqual(schema.create_next_node() or {}, - {'name': 'schema_name', 'KEY': 'db://cluster.schema', 'LABEL': 'Schema'}) - self.assertDictEqual(schema.create_next_node() or {}, + schema_node = schema.create_next_node() + serialized_schema_node = neo4_serializer.serialize_node(schema_node) + schema_desc_node = schema.create_next_node() + serialized_schema_desc_node = neo4_serializer.serialize_node(schema_desc_node) + self.assertDictEqual( + serialized_schema_node, + {'name': 'schema_name', 'KEY': 'db://cluster.schema', 'LABEL': 'Schema'} + ) + self.assertDictEqual(serialized_schema_desc_node, {'description_source': 'description', 'description': 'foo', - 'KEY': 'db://cluster.schema/_description', 'LABEL': 'Description'}) + 'KEY': 'db://cluster.schema/_description', 'LABEL': 'Description'} + ) self.assertIsNone(schema.create_next_node()) def test_create_nodes_no_description(self) -> None: @@ -26,7 +33,10 @@ def test_create_nodes_no_description(self) -> None: schema = SchemaModel(schema_key='db://cluster.schema', schema='schema_name') - self.assertDictEqual(schema.create_next_node() or {}, + schema_node = schema.create_next_node() + serialized_schema_node = neo4_serializer.serialize_node(schema_node) + + self.assertDictEqual(serialized_schema_node, {'name': 'schema_name', 'KEY': 'db://cluster.schema', 'LABEL': 'Schema'}) self.assertIsNone(schema.create_next_node()) @@ -37,9 +47,14 @@ def test_create_nodes_programmatic_description(self) -> None: description='foo', description_source='bar') - self.assertDictEqual(schema.create_next_node() or {}, + schema_node = schema.create_next_node() + serialized_schema_node = neo4_serializer.serialize_node(schema_node) + schema_desc_node = schema.create_next_node() + serialized_schema_prod_desc_node = neo4_serializer.serialize_node(schema_desc_node) + + self.assertDictEqual(serialized_schema_node, {'name': 'schema_name', 'KEY': 'db://cluster.schema', 'LABEL': 'Schema'}) - self.assertDictEqual(schema.create_next_node() or {}, + self.assertDictEqual(serialized_schema_prod_desc_node, {'description_source': 'bar', 'description': 'foo', 'KEY': 'db://cluster.schema/_bar_description', 'LABEL': 'Programmatic_Description'}) self.assertIsNone(schema.create_next_node()) @@ -50,10 +65,11 @@ def test_create_relation(self) -> None: description='foo') actual = schema.create_next_relation() + serialized_actual = neo4_serializer.serialize_relationship(actual) expected = {'END_KEY': 'db://cluster.schema/_description', 'START_LABEL': 'Schema', 'END_LABEL': 'Description', 'START_KEY': 'db://cluster.schema', 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF'} - self.assertEqual(expected, actual) + self.assertEqual(expected, serialized_actual) self.assertIsNone(schema.create_next_relation()) def test_create_relation_no_description(self) -> None: @@ -69,11 +85,12 @@ def test_create_relation_programmatic_description(self) -> None: description_source='bar') actual = schema.create_next_relation() + serialized_actual = neo4_serializer.serialize_relationship(actual) expected = { 'END_KEY': 'db://cluster.schema/_bar_description', 'START_LABEL': 'Schema', 'END_LABEL': 'Programmatic_Description', 'START_KEY': 'db://cluster.schema', 'TYPE': 'DESCRIPTION', 'REVERSE_TYPE': 'DESCRIPTION_OF' } - self.assertEqual(expected, actual) + self.assertEqual(expected, serialized_actual) self.assertIsNone(schema.create_next_relation()) diff --git a/tests/unit/models/test_application.py b/tests/unit/models/test_application.py index fc7eef641..c3f7b1a50 100644 --- a/tests/unit/models/test_application.py +++ b/tests/unit/models/test_application.py @@ -4,11 +4,12 @@ import unittest from databuilder.models.application import Application -from databuilder.models.neo4j_csv_serde import NODE_KEY, \ +from databuilder.models.graph_serializable import NODE_KEY, \ NODE_LABEL, RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE from databuilder.models.table_metadata import TableMetadata +from databuilder.serializers import neo4_serializer class TestApplication(unittest.TestCase): @@ -42,11 +43,13 @@ def setUp(self) -> None: def test_create_next_node(self) -> None: next_node = self.application.create_next_node() - self.assertEqual(next_node, self.expected_node_result) + serialized_next_node = neo4_serializer.serialize_node(next_node) + self.assertEquals(serialized_next_node, self.expected_node_result) def test_create_next_relation(self) -> None: next_relation = self.application.create_next_relation() - self.assertEqual(next_relation, self.expected_relation_result) + serialized_next_relation = neo4_serializer.serialize_relationship(next_relation) + self.assertEquals(serialized_next_relation, self.expected_relation_result) def test_get_table_model_key(self) -> None: table = self.application.get_table_model_key() @@ -58,10 +61,11 @@ def test_get_application_model_key(self) -> None: def test_create_nodes(self) -> None: nodes = self.application.create_nodes() - self.assertEqual(len(nodes), 1) - self.assertEqual(nodes[0], self.expected_node_result) + self.assertEquals(len(nodes), 1) + serialized_next_node = neo4_serializer.serialize_node(nodes[0]) + self.assertEquals(serialized_next_node, self.expected_node_result) def test_create_relation(self) -> None: relation = self.application.create_relation() - self.assertEqual(len(relation), 1) - self.assertEqual(relation[0], self.expected_relation_result) + self.assertEquals(len(relation), 1) + self.assertEquals(neo4_serializer.serialize_relationship(relation[0]), self.expected_relation_result) diff --git a/tests/unit/models/test_badge.py b/tests/unit/models/test_badge.py index 5b43e1ebc..c2b3afd22 100644 --- a/tests/unit/models/test_badge.py +++ b/tests/unit/models/test_badge.py @@ -3,10 +3,9 @@ import unittest from databuilder.models.badge import Badge, BadgeMetadata - -from databuilder.models.neo4j_csv_serde import NODE_KEY, NODE_LABEL, \ - RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.serializers import neo4_serializer +from databuilder.models.graph_serializable import RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ + RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE, NODE_KEY, NODE_LABEL db = 'hive' SCHEMA = 'BASE' @@ -41,9 +40,13 @@ def test_create_nodes(self) -> None: NODE_LABEL: BadgeMetadata.BADGE_NODE_LABEL, BadgeMetadata.BADGE_CATEGORY: badge2.category } + serialized_nodes = [ + neo4_serializer.serialize_node(node) + for node in nodes + ] - self.assertTrue(node1 in nodes) - self.assertTrue(node2 in nodes) + self.assertTrue(node1 in serialized_nodes) + self.assertTrue(node2 in serialized_nodes) def test_bad_key_entity_match(self) -> None: column_label = 'Column' @@ -66,6 +69,10 @@ def test_bad_entity_label(self) -> None: def test_create_relation(self) -> None: relations = self.badge_metada.create_relation() + serialized_relations = [ + neo4_serializer.serialize_relationship(relation) + for relation in relations + ] self.assertEqual(len(relations), 2) relation1 = { @@ -85,5 +92,5 @@ def test_create_relation(self) -> None: RELATION_REVERSE_TYPE: BadgeMetadata.INVERSE_BADGE_RELATION_TYPE, } - self.assertTrue(relation1 in relations) - self.assertTrue(relation2 in relations) + self.assertTrue(relation1 in serialized_relations) + self.assertTrue(relation2 in serialized_relations) diff --git a/tests/unit/models/test_neo4j_csv_serde.py b/tests/unit/models/test_graph_serializable.py similarity index 58% rename from tests/unit/models/test_neo4j_csv_serde.py rename to tests/unit/models/test_graph_serializable.py index 7e621af5d..f46e5c169 100644 --- a/tests/unit/models/test_neo4j_csv_serde.py +++ b/tests/unit/models/test_graph_serializable.py @@ -3,13 +3,12 @@ import unittest -from typing import Union, Dict, Any, Iterable +from typing import Union, Iterable -from databuilder.models.neo4j_csv_serde import ( - NODE_KEY, NODE_LABEL, RELATION_START_KEY, RELATION_START_LABEL, - RELATION_END_KEY, RELATION_END_LABEL, RELATION_TYPE, - RELATION_REVERSE_TYPE) -from databuilder.models.neo4j_csv_serde import Neo4jCsvSerializable +from databuilder.models.graph_serializable import GraphSerializable +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.models.graph_node import GraphNode +from databuilder.serializers import neo4_serializer class TestSerialize(unittest.TestCase): @@ -22,7 +21,7 @@ def test_serialize(self) -> None: actual = [] node_row = movie.next_node() while node_row: - actual.append(node_row) + actual.append(neo4_serializer.serialize_node(node_row)) node_row = movie.next_node() expected = [ @@ -37,7 +36,7 @@ def test_serialize(self) -> None: actual = [] relation_row = movie.next_relation() while relation_row: - actual.append(relation_row) + actual.append(neo4_serializer.serialize_relationship(relation_row)) relation_row = movie.next_relation() expected = [ @@ -73,7 +72,7 @@ def __init__(self, name: str) -> None: self.name = name -class Movie(Neo4jCsvSerializable): +class Movie(GraphSerializable): LABEL = 'Movie' KEY_FORMAT = 'movie://{}' MOVIE_ACTOR_RELATION_TYPE = 'ACTOR' @@ -91,59 +90,73 @@ def __init__(self, self._node_iter = iter(self.create_nodes()) self._relation_iter = iter(self.create_relation()) - def create_next_node(self) -> Union[Dict[str, Any], None]: + def create_next_node(self) -> Union[GraphNode, None]: try: return next(self._node_iter) except StopIteration: return None - def create_next_relation(self) -> Union[Dict[str, Any], None]: + def create_next_relation(self) -> Union[GraphRelationship, None]: try: return next(self._relation_iter) except StopIteration: return None - def create_nodes(self) -> Iterable[Dict[str, Any]]: - result = [{NODE_KEY: Movie.KEY_FORMAT.format(self._name), - NODE_LABEL: Movie.LABEL, - 'name': self._name}] + def create_nodes(self) -> Iterable[GraphNode]: + result = [GraphNode( + key=Movie.KEY_FORMAT.format(self._name), + label=Movie.LABEL, + attributes={ + 'name': self._name + } + )] for actor in self._actors: - result.append({NODE_KEY: Actor.KEY_FORMAT.format(actor.name), - NODE_LABEL: Actor.LABEL, - 'name': self._name}) + actor_node = GraphNode( + key=Actor.KEY_FORMAT.format(actor.name), + label=Actor.LABEL, + attributes={ + 'name': self._name + } + ) + result.append(actor_node) for city in self._cities: - result.append({NODE_KEY: City.KEY_FORMAT.format(city.name), - NODE_LABEL: City.LABEL, - 'name': self._name}) + city_node = GraphNode( + key=City.KEY_FORMAT.format(city.name), + label=City.LABEL, + attributes={ + 'name': self._name + } + ) + result.append(city_node) return result - def create_relation(self) -> Iterable[Dict[str, Any]]: + def create_relation(self) -> Iterable[GraphRelationship]: result = [] for actor in self._actors: - result.append({RELATION_START_KEY: - Movie.KEY_FORMAT.format(self._name), - RELATION_START_LABEL: Movie.LABEL, - RELATION_END_KEY: - Actor.KEY_FORMAT.format(actor.name), - RELATION_END_LABEL: Actor.LABEL, - RELATION_TYPE: Movie.MOVIE_ACTOR_RELATION_TYPE, - RELATION_REVERSE_TYPE: - Movie.ACTOR_MOVIE_RELATION_TYPE - }) + movie_actor_relation = GraphRelationship( + start_key=Movie.KEY_FORMAT.format(self._name), + end_key=Actor.KEY_FORMAT.format(actor.name), + start_label=Movie.LABEL, + end_label=Actor.LABEL, + type=Movie.MOVIE_ACTOR_RELATION_TYPE, + reverse_type=Movie.ACTOR_MOVIE_RELATION_TYPE, + attributes={} + ) + result.append(movie_actor_relation) for city in self._cities: - result.append({RELATION_START_KEY: - City.KEY_FORMAT.format(self._name), - RELATION_START_LABEL: Movie.LABEL, - RELATION_END_KEY: - City.KEY_FORMAT.format(city.name), - RELATION_END_LABEL: City.LABEL, - RELATION_TYPE: Movie.MOVIE_CITY_RELATION_TYPE, - RELATION_REVERSE_TYPE: - Movie.CITY_MOVIE_RELATION_TYPE - }) + city_movie_relation = GraphRelationship( + start_key=City.KEY_FORMAT.format(self._name), + end_key=City.KEY_FORMAT.format(city.name), + start_label=Movie.LABEL, + end_label=City.LABEL, + type=Movie.MOVIE_CITY_RELATION_TYPE, + reverse_type=Movie.CITY_MOVIE_RELATION_TYPE, + attributes={} + ) + result.append(city_movie_relation) return result diff --git a/tests/unit/models/test_metric_metadata.py b/tests/unit/models/test_metric_metadata.py index 2f2bd5ee5..b1d0040b8 100644 --- a/tests/unit/models/test_metric_metadata.py +++ b/tests/unit/models/test_metric_metadata.py @@ -5,6 +5,7 @@ import unittest from databuilder.models.metric_metadata import MetricMetadata +from databuilder.serializers import neo4_serializer class TestMetricMetadata(unittest.TestCase): @@ -108,7 +109,8 @@ def test_serialize(self) -> None: node_row = self.metric_metadata.next_node() actual = [] while node_row: - actual.append(node_row) + serialized_node = neo4_serializer.serialize_node(node_row) + actual.append(serialized_node) node_row = self.metric_metadata.next_node() self.assertEqual(self.expected_nodes, actual) @@ -116,7 +118,8 @@ def test_serialize(self) -> None: relation_row = self.metric_metadata.next_relation() actual = [] while relation_row: - actual.append(relation_row) + serialized_relation = neo4_serializer.serialize_relationship(relation_row) + actual.append(serialized_relation) relation_row = self.metric_metadata.next_relation() self.assertEqual(self.expected_rels, actual) @@ -125,7 +128,8 @@ def test_serialize(self) -> None: node_row = self.metric_metadata2.next_node() actual = [] while node_row: - actual.append(node_row) + serialized_node = neo4_serializer.serialize_node(node_row) + actual.append(serialized_node) node_row = self.metric_metadata2.next_node() self.assertEqual(self.expected_nodes_deduped2, actual) @@ -133,7 +137,8 @@ def test_serialize(self) -> None: relation_row = self.metric_metadata2.next_relation() actual = [] while relation_row: - actual.append(relation_row) + serialized_relation = neo4_serializer.serialize_relationship(relation_row) + actual.append(serialized_relation) relation_row = self.metric_metadata2.next_relation() self.assertEqual(self.expected_rels_deduped2, actual) @@ -142,7 +147,8 @@ def test_serialize(self) -> None: node_row = self.metric_metadata3.next_node() actual = [] while node_row: - actual.append(node_row) + serialized_node = neo4_serializer.serialize_node(node_row) + actual.append(serialized_node) node_row = self.metric_metadata3.next_node() self.assertEqual(self.expected_nodes_deduped3, actual) @@ -150,7 +156,8 @@ def test_serialize(self) -> None: relation_row = self.metric_metadata3.next_relation() actual = [] while relation_row: - actual.append(relation_row) + serialized_relation = neo4_serializer.serialize_relationship(relation_row) + actual.append(serialized_relation) relation_row = self.metric_metadata3.next_relation() self.assertEqual(self.expected_rels_deduped3, actual) diff --git a/tests/unit/models/test_neo4j_es_last_updated.py b/tests/unit/models/test_neo4j_es_last_updated.py index 1b9b2cce0..f40b9a1b1 100644 --- a/tests/unit/models/test_neo4j_es_last_updated.py +++ b/tests/unit/models/test_neo4j_es_last_updated.py @@ -4,8 +4,9 @@ import unittest from databuilder.models.neo4j_es_last_updated import Neo4jESLastUpdated -from databuilder.models.neo4j_csv_serde import NODE_KEY, \ +from databuilder.models.graph_serializable import NODE_KEY, \ NODE_LABEL +from databuilder.serializers import neo4_serializer class TestNeo4jESLastUpdated(unittest.TestCase): @@ -17,17 +18,18 @@ def setUp(self) -> None: self.expected_node_result = { NODE_KEY: 'amundsen_updated_timestamp', NODE_LABEL: 'Updatedtimestamp', - 'latest_timestmap': 100, + 'latest_timestmap:UNQUOTED': 100, } def test_create_nodes(self) -> None: nodes = self.neo4j_es_last_updated.create_nodes() - self.assertEqual(len(nodes), 1) - self.assertEqual(nodes[0], self.expected_node_result) + self.assertEquals(len(nodes), 1) + serialized_node = neo4_serializer.serialize_node(nodes[0]) + self.assertEquals(serialized_node, self.expected_node_result) def test_create_next_node(self) -> None: next_node = self.neo4j_es_last_updated.create_next_node() - self.assertEqual(next_node, self.expected_node_result) + self.assertEquals(neo4_serializer.serialize_node(next_node), self.expected_node_result) def test_create_next_relation(self) -> None: self.assertIs(self.neo4j_es_last_updated.create_next_relation(), None) diff --git a/tests/unit/models/test_table_column_usage.py b/tests/unit/models/test_table_column_usage.py index 925041af4..78da735f4 100644 --- a/tests/unit/models/test_table_column_usage.py +++ b/tests/unit/models/test_table_column_usage.py @@ -5,6 +5,7 @@ from databuilder.models.table_column_usage import ColumnReader, TableColumnUsage from typing import no_type_check +from databuilder.serializers import neo4_serializer class TestTableColumnUsage(unittest.TestCase): @@ -19,7 +20,8 @@ def test_serialize(self) -> None: node_row = table_col_usage.next_node() actual = [] while node_row: - actual.append(node_row) + + actual.append(neo4_serializer.serialize_node(node_row)) node_row = table_col_usage.next_node() expected = [{'first_name': '', @@ -27,7 +29,7 @@ def test_serialize(self) -> None: 'full_name': '', 'employee_type': '', 'is_active:UNQUOTED': True, - 'updated_at': 0, + 'updated_at:UNQUOTED': 0, 'LABEL': 'User', 'slack_id': '', 'KEY': 'john@example.com', @@ -40,7 +42,7 @@ def test_serialize(self) -> None: rel_row = table_col_usage.next_relation() actual = [] while rel_row: - actual.append(rel_row) + actual.append(neo4_serializer.serialize_relationship(rel_row)) rel_row = table_col_usage.next_relation() expected = [{'read_count:UNQUOTED': 1, 'END_KEY': 'john@example.com', 'START_LABEL': 'Table', diff --git a/tests/unit/models/test_table_last_updated.py b/tests/unit/models/test_table_last_updated.py index 88a8ad571..cc732339b 100644 --- a/tests/unit/models/test_table_last_updated.py +++ b/tests/unit/models/test_table_last_updated.py @@ -3,11 +3,12 @@ import unittest -from databuilder.models.neo4j_csv_serde import NODE_KEY, \ +from databuilder.models.graph_serializable import NODE_KEY, \ NODE_LABEL, RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE from databuilder.models.table_last_updated import TableLastUpdated from databuilder.models.timestamp import timestamp_constants +from databuilder.serializers import neo4_serializer class TestTableLastUpdated(unittest.TestCase): @@ -22,8 +23,8 @@ def setUp(self) -> None: self.expected_node_result = { NODE_KEY: 'hive://gold.default/test_table/timestamp', NODE_LABEL: 'Timestamp', - 'last_updated_timestamp': 25195665, - timestamp_constants.TIMESTAMP_PROPERTY: 25195665, + 'last_updated_timestamp:UNQUOTED': 25195665, + timestamp_constants.TIMESTAMP_PROPERTY + ":UNQUOTED": 25195665, 'name': 'last_updated_timestamp' } @@ -38,11 +39,13 @@ def setUp(self) -> None: def test_create_next_node(self) -> None: next_node = self.tableLastUpdated.create_next_node() - self.assertEqual(next_node, self.expected_node_result) + next_node_serialized = neo4_serializer.serialize_node(next_node) + self.assertEqual(next_node_serialized, self.expected_node_result) def test_create_next_relation(self) -> None: next_relation = self.tableLastUpdated.create_next_relation() - self.assertEqual(next_relation, self.expected_relation_result) + next_relation_serialized = neo4_serializer.serialize_relationship(next_relation) + self.assertEqual(next_relation_serialized, self.expected_relation_result) def test_get_table_model_key(self) -> None: table = self.tableLastUpdated.get_table_model_key() @@ -54,10 +57,12 @@ def test_get_last_updated_model_key(self) -> None: def test_create_nodes(self) -> None: nodes = self.tableLastUpdated.create_nodes() - self.assertEqual(len(nodes), 1) - self.assertEqual(nodes[0], self.expected_node_result) + self.assertEquals(len(nodes), 1) + serialize_node = neo4_serializer.serialize_node(nodes[0]) + self.assertEquals(serialize_node, self.expected_node_result) def test_create_relation(self) -> None: relation = self.tableLastUpdated.create_relation() - self.assertEqual(len(relation), 1) - self.assertEqual(relation[0], self.expected_relation_result) + self.assertEquals(len(relation), 1) + serialized_relation = neo4_serializer.serialize_relationship(relation[0]) + self.assertEquals(serialized_relation, self.expected_relation_result) diff --git a/tests/unit/models/test_table_lineage.py b/tests/unit/models/test_table_lineage.py index 80f814cb5..2119d518a 100644 --- a/tests/unit/models/test_table_lineage.py +++ b/tests/unit/models/test_table_lineage.py @@ -4,8 +4,9 @@ import unittest from databuilder.models.table_lineage import TableLineage -from databuilder.models.neo4j_csv_serde import RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ +from databuilder.models.graph_serializable import RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.serializers import neo4_serializer DB = 'hive' @@ -57,5 +58,9 @@ def test_create_relation(self) -> None: RELATION_TYPE: TableLineage.ORIGIN_DEPENDENCY_RELATION_TYPE, RELATION_REVERSE_TYPE: TableLineage.DEPENDENCY_ORIGIN_RELATION_TYPE } + actual_relations = [ + neo4_serializer.serialize_relationship(relation) + for relation in relations + ] self.assertTrue(len(relations), 2) - self.assertTrue(relation in relations) + self.assertTrue(relation in actual_relations) diff --git a/tests/unit/models/test_table_metadata.py b/tests/unit/models/test_table_metadata.py index 1c9bdfbea..12998aca6 100644 --- a/tests/unit/models/test_table_metadata.py +++ b/tests/unit/models/test_table_metadata.py @@ -5,13 +5,14 @@ import unittest from databuilder.models.table_metadata import ColumnMetadata, TableMetadata +from databuilder.serializers import neo4_serializer class TestTableMetadata(unittest.TestCase): def setUp(self) -> None: super(TestTableMetadata, self).setUp() - TableMetadata.serialized_nodes = set() - TableMetadata.serialized_rels = set() + TableMetadata.serialized_nodes_keys = set() + TableMetadata.serialized_rels_keys = set() def test_serialize(self) -> None: self.table_metadata = TableMetadata('hive', 'gold', 'test_schema1', 'test_table1', 'test_table1', [ @@ -110,7 +111,8 @@ def test_serialize(self) -> None: node_row = self.table_metadata.next_node() actual = [] while node_row: - actual.append(node_row) + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) node_row = self.table_metadata.next_node() for i in range(0, len(self.expected_nodes)): self.assertEqual(actual[i], self.expected_nodes[i]) @@ -118,7 +120,8 @@ def test_serialize(self) -> None: relation_row = self.table_metadata.next_relation() actual = [] while relation_row: - actual.append(relation_row) + relation_row_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_row_serialized) relation_row = self.table_metadata.next_relation() for i in range(0, len(self.expected_rels)): self.assertEqual(actual[i], self.expected_rels[i]) @@ -127,7 +130,8 @@ def test_serialize(self) -> None: node_row = self.table_metadata2.next_node() actual = [] while node_row: - actual.append(node_row) + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) node_row = self.table_metadata2.next_node() self.assertEqual(self.expected_nodes_deduped, actual) @@ -135,7 +139,8 @@ def test_serialize(self) -> None: relation_row = self.table_metadata2.next_relation() actual = [] while relation_row: - actual.append(relation_row) + relation_row_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_row_serialized) relation_row = self.table_metadata2.next_relation() self.assertEqual(self.expected_rels_deduped, actual) @@ -152,7 +157,8 @@ def test_table_attributes(self) -> None: node_row = self.table_metadata3.next_node() actual = [] while node_row: - actual.append(node_row) + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) node_row = self.table_metadata3.next_node() self.assertEqual(actual[0].get('attr1'), 'uri') @@ -171,7 +177,8 @@ def test_z_custom_sources(self) -> None: node_row = self.custom_source.next_node() actual = [] while node_row: - actual.append(node_row) + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) node_row = self.custom_source.next_node() expected = {'LABEL': 'Programmatic_Description', 'KEY': 'hive://gold.test_schema3/test_table4/_custom_description', @@ -186,7 +193,8 @@ def test_tags_field(self) -> None: node_row = self.table_metadata4.next_node() actual = [] while node_row: - actual.append(node_row) + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) node_row = self.table_metadata4.next_node() self.assertEqual(actual[0].get('attr1'), 'uri') @@ -199,7 +207,8 @@ def test_tags_field(self) -> None: relation_row = self.table_metadata4.next_relation() actual = [] while relation_row: - actual.append(relation_row) + relation_row_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_row_serialized) relation_row = self.table_metadata4.next_relation() # Table tag relationship @@ -221,7 +230,8 @@ def test_col_badge_field(self) -> None: node_row = self.table_metadata4.next_node() actual = [] while node_row: - actual.append(node_row) + serialized_node_row = neo4_serializer.serialize_node(node_row) + actual.append(serialized_node_row) node_row = self.table_metadata4.next_node() self.assertEqual(actual[4].get('KEY'), 'col-badge1') @@ -230,7 +240,8 @@ def test_col_badge_field(self) -> None: relation_row = self.table_metadata4.next_relation() actual = [] while relation_row: - actual.append(relation_row) + serialized_relation_row = neo4_serializer.serialize_relationship(relation_row) + actual.append(serialized_relation_row) relation_row = self.table_metadata4.next_relation() expected_col_badge_rel1 = {'END_KEY': 'col-badge1', 'START_LABEL': 'Column', @@ -253,7 +264,8 @@ def test_tags_populated_from_str(self) -> None: node_row = self.table_metadata5.next_node() actual = [] while node_row: - actual.append(node_row) + node_row_serialized = neo4_serializer.serialize_node(node_row) + actual.append(node_row_serialized) node_row = self.table_metadata5.next_node() self.assertEqual(actual[2].get('LABEL'), 'Tag') @@ -263,7 +275,8 @@ def test_tags_populated_from_str(self) -> None: relation_row = self.table_metadata5.next_relation() actual = [] while relation_row: - actual.append(relation_row) + relation_row_serialized = neo4_serializer.serialize_relationship(relation_row) + actual.append(relation_row_serialized) relation_row = self.table_metadata5.next_relation() # Table tag relationship @@ -286,13 +299,15 @@ def test_tags_arent_populated_from_empty_list_and_str(self) -> None: # Test table tag fields are not populated from empty List node_row = self.table_metadata6.next_node() while node_row: - self.assertNotEqual(node_row.get('LABEL'), 'Tag') + node_row_serialized = neo4_serializer.serialize_node(node_row) + self.assertNotEqual(node_row_serialized.get('LABEL'), 'Tag') node_row = self.table_metadata6.next_node() # Test table tag fields are not populated from empty str node_row = self.table_metadata7.next_node() while node_row: - self.assertNotEqual(node_row.get('LABEL'), 'Tag') + node_row_serialized = neo4_serializer.serialize_node(node_row) + self.assertNotEqual(node_row_serialized.get('LABEL'), 'Tag') node_row = self.table_metadata7.next_node() diff --git a/tests/unit/models/test_table_owner.py b/tests/unit/models/test_table_owner.py index 98d58c519..841a6da29 100644 --- a/tests/unit/models/test_table_owner.py +++ b/tests/unit/models/test_table_owner.py @@ -6,9 +6,10 @@ from databuilder.models.table_owner import TableOwner -from databuilder.models.neo4j_csv_serde import NODE_KEY, NODE_LABEL, \ +from databuilder.models.graph_serializable import NODE_KEY, NODE_LABEL, \ RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.serializers import neo4_serializer db = 'hive' @@ -41,25 +42,29 @@ def test_create_nodes(self) -> None: nodes = self.table_owner.create_nodes() self.assertEqual(len(nodes), 2) - node1 = { + expected_node1 = { NODE_KEY: User.USER_NODE_KEY_FORMAT.format(email=owner1), NODE_LABEL: User.USER_NODE_LABEL, User.USER_NODE_EMAIL: owner1 } - node2 = { + expected_node2 = { NODE_KEY: User.USER_NODE_KEY_FORMAT.format(email=owner2), NODE_LABEL: User.USER_NODE_LABEL, User.USER_NODE_EMAIL: owner2 } + actual_nodes = [ + neo4_serializer.serialize_node(node) + for node in nodes + ] - self.assertTrue(node1 in nodes) - self.assertTrue(node2 in nodes) + self.assertTrue(expected_node1 in actual_nodes) + self.assertTrue(expected_node2 in actual_nodes) def test_create_relation(self) -> None: relations = self.table_owner.create_relation() self.assertEqual(len(relations), 2) - relation1 = { + expected_relation1 = { RELATION_START_KEY: owner1, RELATION_START_LABEL: User.USER_NODE_LABEL, RELATION_END_KEY: self.table_owner.get_metadata_model_key(), @@ -67,7 +72,7 @@ def test_create_relation(self) -> None: RELATION_TYPE: TableOwner.OWNER_TABLE_RELATION_TYPE, RELATION_REVERSE_TYPE: TableOwner.TABLE_OWNER_RELATION_TYPE } - relation2 = { + expected_relation2 = { RELATION_START_KEY: owner2, RELATION_START_LABEL: User.USER_NODE_LABEL, RELATION_END_KEY: self.table_owner.get_metadata_model_key(), @@ -76,8 +81,13 @@ def test_create_relation(self) -> None: RELATION_REVERSE_TYPE: TableOwner.TABLE_OWNER_RELATION_TYPE } - self.assertTrue(relation1 in relations) - self.assertTrue(relation2 in relations) + actual_relations = [ + neo4_serializer.serialize_relationship(relation) + for relation in relations + ] + + self.assertTrue(expected_relation1 in actual_relations) + self.assertTrue(expected_relation2 in actual_relations) def test_create_nodes_with_owners_list(self) -> None: self.table_owner_list = TableOwner(db_name='hive', @@ -88,16 +98,20 @@ def test_create_nodes_with_owners_list(self) -> None: nodes = self.table_owner_list.create_nodes() self.assertEqual(len(nodes), 2) - node1 = { + expected_node1 = { NODE_KEY: User.USER_NODE_KEY_FORMAT.format(email=owner1), NODE_LABEL: User.USER_NODE_LABEL, User.USER_NODE_EMAIL: owner1 } - node2 = { + expected_node2 = { NODE_KEY: User.USER_NODE_KEY_FORMAT.format(email=owner2), NODE_LABEL: User.USER_NODE_LABEL, User.USER_NODE_EMAIL: owner2 } + actual_nodes = [ + neo4_serializer.serialize_node(node) + for node in nodes + ] - self.assertTrue(node1 in nodes) - self.assertTrue(node2 in nodes) + self.assertTrue(expected_node1 in actual_nodes) + self.assertTrue(expected_node2 in actual_nodes) diff --git a/tests/unit/models/test_table_source.py b/tests/unit/models/test_table_source.py index 3b7f6ca72..754856adc 100644 --- a/tests/unit/models/test_table_source.py +++ b/tests/unit/models/test_table_source.py @@ -4,8 +4,9 @@ import unittest from databuilder.models.table_source import TableSource -from databuilder.models.neo4j_csv_serde import RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ +from databuilder.models.graph_serializable import RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.serializers import neo4_serializer DB = 'hive' @@ -43,7 +44,8 @@ def test_create_nodes(self) -> None: def test_create_relation(self) -> None: relations = self.table_source.create_relation() - self.assertEqual(len(relations), 1) + self.assertEquals(len(relations), 1) + serialized_relation = neo4_serializer.serialize_relationship(relations[0]) start_key = '{db}://{cluster}.{schema}/{tbl}/_source'.format(db=DB, schema=SCHEMA, @@ -54,7 +56,7 @@ def test_create_relation(self) -> None: tbl=TABLE, cluster=CLUSTER) - relation = { + expected_relation = { RELATION_START_KEY: start_key, RELATION_START_LABEL: TableSource.LABEL, RELATION_END_KEY: end_key, @@ -63,4 +65,4 @@ def test_create_relation(self) -> None: RELATION_REVERSE_TYPE: TableSource.TABLE_SOURCE_RELATION_TYPE } - self.assertTrue(relation in relations) + self.assertDictEqual(expected_relation, serialized_relation) diff --git a/tests/unit/models/test_table_stats.py b/tests/unit/models/test_table_stats.py index 4d5ba9652..8a67c8e9a 100644 --- a/tests/unit/models/test_table_stats.py +++ b/tests/unit/models/test_table_stats.py @@ -4,9 +4,10 @@ import unittest from databuilder.models.table_stats import TableColumnStats -from databuilder.models.neo4j_csv_serde import NODE_KEY, \ +from databuilder.models.graph_serializable import NODE_KEY, \ NODE_LABEL, RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE +from databuilder.serializers import neo4_serializer class TestTableStats(unittest.TestCase): @@ -23,7 +24,7 @@ def setUp(self) -> None: self.expected_node_result = { NODE_KEY: 'hive://gold.base/test/col/avg/', NODE_LABEL: 'Stat', - 'stat_val:UNQUOTED': '1', + 'stat_val:UNQUOTED': 1, 'stat_name': 'avg', 'start_epoch': '1', 'end_epoch': '2', @@ -48,18 +49,23 @@ def test_get_col_key(self) -> None: def test_create_nodes(self) -> None: nodes = self.table_stats.create_nodes() - self.assertEqual(len(nodes), 1) - self.assertEqual(nodes[0], self.expected_node_result) + self.assertEquals(len(nodes), 1) + serialized_node = neo4_serializer.serialize_node(nodes[0]) + self.assertEquals(serialized_node, self.expected_node_result) def test_create_relation(self) -> None: relation = self.table_stats.create_relation() - self.assertEqual(len(relation), 1) - self.assertEqual(relation[0], self.expected_relation_result) + + self.assertEquals(len(relation), 1) + serialized_relation = neo4_serializer.serialize_relationship(relation[0]) + self.assertEquals(serialized_relation, self.expected_relation_result) def test_create_next_node(self) -> None: next_node = self.table_stats.create_next_node() - self.assertEqual(next_node, self.expected_node_result) + serialized_node = neo4_serializer.serialize_node(next_node) + self.assertEquals(serialized_node, self.expected_node_result) def test_create_next_relation(self) -> None: next_relation = self.table_stats.create_next_relation() - self.assertEqual(next_relation, self.expected_relation_result) + serialized_relation = neo4_serializer.serialize_relationship(next_relation) + self.assertEquals(serialized_relation, self.expected_relation_result) diff --git a/tests/unit/models/test_user.py b/tests/unit/models/test_user.py index 8dda34d87..f974cc1d6 100644 --- a/tests/unit/models/test_user.py +++ b/tests/unit/models/test_user.py @@ -3,9 +3,10 @@ import unittest -from databuilder.models.neo4j_csv_serde import RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ +from databuilder.models.graph_serializable import RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE from databuilder.models.user import User +from databuilder.serializers import neo4_serializer class TestUser(unittest.TestCase): @@ -48,9 +49,10 @@ def test_create_node_additional_attr(self) -> None: role_name='swe', enable_notify=True) nodes = test_user.create_nodes() - self.assertEqual(nodes[0]['email'], 'test@email.com') - self.assertEqual(nodes[0]['role_name'], 'swe') - self.assertTrue(nodes[0]['enable_notify']) + serialized_node = neo4_serializer.serialize_node(nodes[0]) + self.assertEqual(serialized_node['email'], 'test@email.com') + self.assertEqual(serialized_node['role_name'], 'swe') + self.assertTrue(serialized_node['enable_notify:UNQUOTED']) def test_create_relation(self) -> None: relations = self.user.create_relation() @@ -59,7 +61,7 @@ def test_create_relation(self) -> None: start_key = '{email}'.format(email='test@email.com') end_key = '{email}'.format(email='test_manager@email.com') - relation = { + expected_relation = { RELATION_START_KEY: start_key, RELATION_START_LABEL: User.USER_NODE_LABEL, RELATION_END_KEY: end_key, @@ -68,22 +70,22 @@ def test_create_relation(self) -> None: RELATION_REVERSE_TYPE: User.MANAGER_USER_RELATION_TYPE } - self.assertTrue(relation in relations) + self.assertTrue(expected_relation, neo4_serializer.serialize_relationship(relations[0])) def test_not_including_empty_attribute(self) -> None: test_user = User(email='test@email.com', foo='bar') - self.assertDictEqual(test_user.create_next_node() or {}, + self.assertDictEqual(neo4_serializer.serialize_node(test_user.create_next_node()), {'KEY': 'test@email.com', 'LABEL': 'User', 'email': 'test@email.com', 'is_active:UNQUOTED': True, 'first_name': '', 'last_name': '', 'full_name': '', 'github_username': '', 'team_name': '', 'employee_type': '', 'slack_id': '', - 'role_name': '', 'updated_at': 0, 'foo': 'bar'}) + 'role_name': '', 'updated_at:UNQUOTED': 0, 'foo': 'bar'}) test_user2 = User(email='test@email.com', foo='bar', is_active=False, do_not_update_empty_attribute=True) - self.assertDictEqual(test_user2.create_next_node() or {}, + self.assertDictEqual(neo4_serializer.serialize_node(test_user2.create_next_node()), {'KEY': 'test@email.com', 'LABEL': 'User', 'email': 'test@email.com', 'foo': 'bar'}) diff --git a/tests/unit/models/test_watermark.py b/tests/unit/models/test_watermark.py index 7fcfe7c1c..01e6ed248 100644 --- a/tests/unit/models/test_watermark.py +++ b/tests/unit/models/test_watermark.py @@ -4,10 +4,19 @@ import unittest from databuilder.models.watermark import Watermark -from databuilder.models.neo4j_csv_serde import NODE_KEY, \ - NODE_LABEL, RELATION_START_KEY, RELATION_START_LABEL, RELATION_END_KEY, \ - RELATION_END_LABEL, RELATION_TYPE, RELATION_REVERSE_TYPE - +from databuilder.models.graph_serializable import ( + NODE_KEY, + NODE_LABEL, + RELATION_START_KEY, + RELATION_START_LABEL, + RELATION_END_KEY, + RELATION_END_LABEL, + RELATION_TYPE, + RELATION_REVERSE_TYPE +) +from databuilder.models.graph_node import GraphNode +from databuilder.models.graph_relationship import GraphRelationship +from databuilder.serializers import neo4_serializer CREATE_TIME = '2017-09-18T00:00:00' DATABASE = 'DYNAMO' @@ -22,43 +31,60 @@ class TestWatermark(unittest.TestCase): def setUp(self) -> None: super(TestWatermark, self).setUp() - self.watermark = Watermark(create_time='2017-09-18T00:00:00', - database=DATABASE, - schema=SCHEMA, - table_name=TABLE, - cluster=CLUSTER, - part_type=PART_TYPE, - part_name=NESTED_PART) - - self.expected_node_result = { - NODE_KEY: '{database}://{cluster}.{schema}/{table}/{part_type}/' - .format( - database=DATABASE, - cluster=CLUSTER, - schema=SCHEMA, - table=TABLE, - part_type=PART_TYPE), + self.watermark = Watermark( + create_time='2017-09-18T00:00:00', + database=DATABASE, + schema=SCHEMA, + table_name=TABLE, + cluster=CLUSTER, + part_type=PART_TYPE, + part_name=NESTED_PART + ) + start_key = '{database}://{cluster}.{schema}/{table}/{part_type}/'.format( + database=DATABASE, + cluster=CLUSTER, + schema=SCHEMA, + table=TABLE, + part_type=PART_TYPE + ) + end_key = '{database}://{cluster}.{schema}/{table}'.format( + database=DATABASE, + cluster=CLUSTER, + schema=SCHEMA, + table=TABLE + ) + self.expected_node_result = GraphNode( + key=start_key, + label='Watermark', + attributes={ + 'partition_key': 'ds', + 'partition_value': '2017-09-18/feature_id=9', + 'create_time': '2017-09-18T00:00:00' + } + ) + + self.expected_serialized_node_result = { + NODE_KEY: start_key, NODE_LABEL: 'Watermark', 'partition_key': 'ds', 'partition_value': '2017-09-18/feature_id=9', 'create_time': '2017-09-18T00:00:00' } - self.expected_relation_result = { - RELATION_START_KEY: '{database}://{cluster}.{schema}/{table}/{part_type}/' - .format( - database=DATABASE, - cluster=CLUSTER, - schema=SCHEMA, - table=TABLE, - part_type=PART_TYPE), + self.expected_relation_result = GraphRelationship( + start_label='Watermark', + end_label='Table', + start_key=start_key, + end_key=end_key, + type='BELONG_TO_TABLE', + reverse_type='WATERMARK', + attributes={} + ) + + self.expected_serialized_relation_result = { + RELATION_START_KEY: start_key, RELATION_START_LABEL: 'Watermark', - RELATION_END_KEY: '{database}://{cluster}.{schema}/{table}' - .format( - database=DATABASE, - cluster=CLUSTER, - schema=SCHEMA, - table=TABLE), + RELATION_END_KEY: end_key, RELATION_END_LABEL: 'Table', RELATION_TYPE: 'BELONG_TO_TABLE', RELATION_REVERSE_TYPE: 'WATERMARK' @@ -84,18 +110,24 @@ def test_get_metadata_model_key(self) -> None: def test_create_nodes(self) -> None: nodes = self.watermark.create_nodes() - self.assertEqual(len(nodes), 1) - self.assertEqual(nodes[0], self.expected_node_result) + self.assertEquals(len(nodes), 1) + + self.assertEquals(nodes[0], self.expected_node_result) + self.assertEqual(neo4_serializer.serialize_node(nodes[0]), self.expected_serialized_node_result) def test_create_relation(self) -> None: relation = self.watermark.create_relation() - self.assertEqual(len(relation), 1) - self.assertEqual(relation[0], self.expected_relation_result) + self.assertEquals(len(relation), 1) + self.assertEquals(relation[0], self.expected_relation_result) + self.assertEqual(neo4_serializer.serialize_relationship(relation[0]), self.expected_serialized_relation_result) def test_create_next_node(self) -> None: next_node = self.watermark.create_next_node() - self.assertEqual(next_node, self.expected_node_result) + self.assertEquals(neo4_serializer.serialize_node(next_node), self.expected_serialized_node_result) def test_create_next_relation(self) -> None: next_relation = self.watermark.create_next_relation() - self.assertEqual(next_relation, self.expected_relation_result) + self.assertEquals( + neo4_serializer.serialize_relationship(next_relation), + self.expected_serialized_relation_result + )