Skip to content

Commit

Permalink
fix: Support expressions and improve handling of column alias
Browse files Browse the repository at this point in the history
Support expressions in target list
Support column alias in WITH clause and column aliases in general

Add source to visit_dml_query and parser sdk.

Check for semantic errors like mismatched source and target columns

Scope binding to a specific source
Fix bugs and improve error handling in get_source and get_schema
  • Loading branch information
vrajat committed Jun 26, 2021
1 parent d3ebab7 commit 3684a08
Show file tree
Hide file tree
Showing 14 changed files with 480 additions and 154 deletions.
119 changes: 112 additions & 7 deletions data_lineage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@
from data_lineage.graph import LineageGraph


class SourceNotFound(Exception):
    """Raised when no source with the requested name exists in the catalog."""


class SchemaNotFound(Exception):
    """Raised when a schema is not found in the catalog for a given source."""


class TableNotFound(Exception):
    """Raised when a table is not found in the catalog."""

Expand All @@ -25,6 +33,18 @@ class ParseError(Exception):
"""Parser Error"""


class SemanticError(Exception):
    """Raised on a mismatch between parsed query data and catalog data."""


class NoResultFound(Exception):
    """Raised when a lookup that requires a result returns no results."""


class MultipleResultsFound(Exception):
    """Raised when at most one result was expected but several were found."""


class Graph:
def __init__(self, url: str):
self._base_url = furl(url) / "api/main"
Expand Down Expand Up @@ -145,6 +165,24 @@ def _get(self, path: str, obj_id: int) -> Dict[Any, Any]:
raise RuntimeError(json_response["error"])
return json_response["data"]

@staticmethod
def _one(response):
    """Extract exactly one item from a JSON:API search response.

    Reads ``meta.total`` to decide how many results came back and
    returns the first (only) element of ``data``.

    :param response: an HTTP response whose body is a JSON:API document.
    :raises NoResultFound: if the search matched nothing.
    :raises MultipleResultsFound: if the search matched more than one row.
    """
    body = response.json()
    logging.debug(body)
    total = body["meta"]["total"]
    if total > 1:
        raise MultipleResultsFound
    if total == 0:
        raise NoResultFound
    return body["data"][0]

def _search_one(self, path: str, filters):
    """Run a filtered search against *path* and return the single match.

    The filter list is serialized into the ``filter[objects]`` query
    parameter understood by the REST API; the response is reduced to
    exactly one item via :meth:`_one`.

    :raises NoResultFound / MultipleResultsFound: via ``_one``.
    """
    response = self._session.get(
        self._build_url(path),
        params={"filter[objects]": json.dumps(filters)},
    )
    response.raise_for_status()
    return Catalog._one(response)

def _search(self, path: str, search_string: str, clazz: Type[BaseModel]):
filters = [dict(name="name", op="like", val="%{}%".format(search_string))]
params = {"filter[objects]": json.dumps(filters)}
Expand Down Expand Up @@ -252,11 +290,40 @@ def get_column_lineage(self, job_ids: List[int]) -> List[ColumnLineage]:
for item in response.json()["data"]
]

def get_source_by_name(self, name):
    """Search sources by substring (LIKE) match on *name*.

    Delegates to the generic ``_search`` helper; for an exact,
    single-result lookup use :meth:`get_source` instead.
    """
    return self._search("sources", name, Source)
def get_source(self, name) -> Source:
    """Fetch the single source with exactly the given name.

    Unlike :meth:`get_source_by_name` this is an equality match, not a
    LIKE search, and returns one object rather than a list.

    :param name: exact source name to match.
    :raises SourceNotFound: if no source with that name exists.
    :raises MultipleResultsFound: if more than one source matches.
    """
    filters = [dict(name="name", op="eq", val="{}".format(name))]
    try:
        payload = self._search_one("sources", filters)
    except NoResultFound:
        # Translate the internal signal into a domain-specific error and
        # suppress chaining so callers see a single clean traceback.
        raise SourceNotFound(
            "Source not found: source_name={}".format(name)
        ) from None
    return Source(
        session=self._session,
        attributes=payload["attributes"],
        obj_id=payload["id"],
        relationships=payload["relationships"],
    )

def get_schema_by_name(self, name):
    """Search schemata by substring (LIKE) match on *name*.

    Delegates to the generic ``_search`` helper; for an exact lookup
    scoped to a source use :meth:`get_schema` instead.
    """
    return self._search("schemata", name, Schema)
def get_schema(self, source_name: str, schema_name: str) -> Schema:
    """Fetch the single schema *schema_name* belonging to *source_name*.

    Builds an AND filter (name equality plus a ``has`` relationship
    filter on the owning source) so the lookup is scoped to one source.

    :raises SchemaNotFound: if no matching schema exists.
    :raises MultipleResultsFound: if more than one schema matches.
    """
    name_filter = dict(name="name", op="eq", val="{}".format(schema_name))
    source_filter = dict(
        name="source", op="has", val=dict(name="name", op="eq", val=source_name)
    )
    filters = {"and": [name_filter, source_filter]}
    logging.debug(filters)
    try:
        payload = self._search_one("schemata", [filters])
    except NoResultFound:
        # Translate the internal signal into a domain-specific error and
        # suppress chaining so callers see a single clean traceback.
        raise SchemaNotFound(
            "Schema not found, (source_name={}, schema_name={})".format(
                source_name, schema_name
            )
        ) from None
    return Schema(
        session=self._session,
        attributes=payload["attributes"],
        obj_id=payload["id"],
        relationships=payload["relationships"],
    )

def get_table_by_name(self, name):
    """Search tables by substring (LIKE) match on *name*.

    Delegates to the generic ``_search`` helper.
    """
    return self._search("tables", name, Table)
Expand Down Expand Up @@ -351,6 +418,43 @@ def scan_source(self, source: Source) -> bool:
)
return response.status_code == 200

def add_schema(self, name: str, source: Source) -> Schema:
    """Create a schema named *name* under *source* via the REST API."""
    payload = self._post(
        path="schemata",
        data={"name": name, "source_id": source.id},
        type="schemata",
    )
    return Schema(
        session=self._session,
        obj_id=payload["id"],
        attributes=payload["attributes"],
        relationships=payload["relationships"],
    )

def add_table(self, name: str, schema: Schema) -> Table:
    """Create a table named *name* inside *schema* via the REST API."""
    payload = self._post(
        path="tables",
        data={"name": name, "schema_id": schema.id},
        type="tables",
    )
    return Table(
        session=self._session,
        obj_id=payload["id"],
        attributes=payload["attributes"],
        relationships=payload["relationships"],
    )

def add_column(
    self, name: str, data_type: str, sort_order: int, table: Table
) -> Column:
    """Create a column on *table* with the given type and position.

    :param sort_order: ordinal position of the column within the table.
    """
    body = {
        "name": name,
        "table_id": table.id,
        "data_type": data_type,
        "sort_order": sort_order,
    }
    payload = self._post(path="columns", data=body, type="columns")
    return Column(
        session=self._session,
        obj_id=payload["id"],
        attributes=payload["attributes"],
        relationships=payload["relationships"],
    )

def add_job(self, name: str, context: Dict[Any, Any]) -> Job:
data = {"name": name, "context": context}
payload = self._post(path="jobs", data=data, type="jobs")
Expand Down Expand Up @@ -410,9 +514,10 @@ def __init__(self, url: str):
self._base_url = furl(url) / "api/v1/parser"
self._session = requests.Session()

def parse(self, query: str, name: str = None) -> JobExecution:
def parse(self, query: str, source: Source, name: str = None) -> JobExecution:
response = self._session.post(
self._base_url, params={"query": query, "name": name}
self._base_url,
params={"query": query, "name": name, "source_id": source.id},
)
if response.status_code == 441:
raise TableNotFound(response.json()["message"])
Expand All @@ -421,8 +526,8 @@ def parse(self, query: str, name: str = None) -> JobExecution:
elif response.status_code == 422:
raise ParseError(response.json()["message"])

response.raise_for_status()
logging.debug(response.json())
response.raise_for_status()
payload = response.json()["data"]
return JobExecution(
session=self._session,
Expand Down
2 changes: 2 additions & 0 deletions data_lineage/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ def main(
is_production,
):
logging.basicConfig(level=getattr(logging, log_level.upper()))
sqlalchemy_logger = logging.getLogger("sqlalchemy.engine")
sqlalchemy_logger.setLevel(getattr(logging, log_level.upper()))
catalog = {
"user": catalog_user,
"password": catalog_password,
Expand Down
34 changes: 22 additions & 12 deletions data_lineage/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import List, Optional

from dbcat.catalog import Catalog
from dbcat.catalog.models import JobExecution, JobExecutionStatus
from dbcat.catalog.models import CatSource, JobExecution, JobExecutionStatus
from pglast.parser import ParseError

from data_lineage.parser.dml_visitor import (
Expand All @@ -13,6 +13,7 @@
SelectSourceVisitor,
)
from data_lineage.parser.node import Parsed, parse
from data_lineage.parser.visitor import ExprVisitor, RedshiftExprVisitor


def parse_queries(queries: List[str]) -> List[Parsed]:
Expand All @@ -27,16 +28,24 @@ def parse_queries(queries: List[str]) -> List[Parsed]:
return parsed


def visit_dml_query(catalog: Catalog, parsed: Parsed) -> Optional[DmlVisitor]:
select_source_visitor: DmlVisitor = SelectSourceVisitor(parsed.name)
select_into_visitor: DmlVisitor = SelectIntoVisitor(parsed.name)
copy_from_visitor: DmlVisitor = CopyFromVisitor(parsed.name)
def visit_dml_query(
    catalog: Catalog, parsed: Parsed, source: CatSource,
) -> Optional[DmlVisitor]:
    """Run the DML visitors over a parsed query and bind the first match.

    Picks a dialect-specific expression visitor (Redshift gets its own),
    then tries each DML visitor in turn. The first visitor that finds
    both source tables and an insert target is bound against the catalog
    for *source* and returned; ``None`` means no visitor matched.
    """
    # Redshift queries need dialect-aware expression handling.
    expr_visitor_clazz = (
        RedshiftExprVisitor if source.source_type == "redshift" else ExprVisitor
    )

    candidates = [
        SelectSourceVisitor(parsed.name, expr_visitor_clazz),
        SelectIntoVisitor(parsed.name, expr_visitor_clazz),
        CopyFromVisitor(parsed.name, expr_visitor_clazz),
    ]

    for candidate in candidates:
        parsed.node.accept(candidate)
        if candidate.select_tables and candidate.insert_table is not None:
            candidate.bind(catalog, source)
            return candidate
    return None


Expand All @@ -50,7 +59,8 @@ def extract_lineage(
for source, target in zip(
visited_query.source_columns, visited_query.target_columns
):
edge = catalog.add_column_lineage(source, target, job_execution.id, {})
logging.debug("Added {}".format(edge))
for column in source.columns:
edge = catalog.add_column_lineage(column, target, job_execution.id, {})
logging.debug("Added {}".format(edge))

return job_execution
Loading

0 comments on commit 3684a08

Please sign in to comment.