fix(airflow): Fix for failing serialisation when Param was specified + support for external task sensor #5368

Merged Jul 12, 2022 · 3 commits
58 changes: 33 additions & 25 deletions in metadata-ingestion/src/datahub_provider/client/airflow_generator.py
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Dict, List, Optional, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Union, cast

from airflow.configuration import conf

@@ -87,6 +87,27 @@ def _get_dependencies(
if subdag_task_id in upstream_task._downstream_task_ids:
upstream_subdag_triggers.append(upstream_task_urn)

# If the operator is an ExternalTaskSensor then we set the remote task as upstream.
# It is possible to tie an external sensor to a DAG if external_task_id is omitted, but currently we can't tie
# a jobflow to another jobflow.
external_task_upstreams = []
if task.task_type == "ExternalTaskSensor":
from airflow.sensors.external_task_sensor import ExternalTaskSensor

task = cast(ExternalTaskSensor, task)
if hasattr(task, "external_task_id") and task.external_task_id is not None:
external_task_upstreams = [
DataJobUrn.create_from_ids(
job_id=task.external_task_id,
data_flow_urn=str(
DataFlowUrn.create_from_ids(
orchestrator=flow_urn.get_orchestrator_name(),
flow_id=task.external_dag_id,
env=flow_urn.get_env(),
)
),
)
]
# exclude subdag operator tasks since these are not emitted, resulting in empty metadata
upstream_tasks = (
[
@@ -96,6 +117,7 @@ def _get_dependencies(
]
+ upstream_subdag_task_urns
+ upstream_subdag_triggers
+ external_task_upstreams
)
return upstream_tasks
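
For context, a hedged sketch of the scenario this new branch covers: a DAG that waits on a task in another DAG via `ExternalTaskSensor`. The DAG ids, task ids, and the URN shape in the comments are illustrative assumptions, not taken from this PR.

```python
from datetime import datetime

from airflow import DAG
from airflow.sensors.external_task_sensor import ExternalTaskSensor

with DAG(
    dag_id="reporting",
    start_date=datetime(2022, 1, 1),
    schedule_interval=None,
) as dag:
    wait_for_load = ExternalTaskSensor(
        task_id="wait_for_load",
        external_dag_id="warehouse_load",    # DAG that owns the upstream task
        external_task_id="load_fact_table",  # task inside that DAG
    )

# With this change, _get_dependencies adds an upstream data job roughly of the form:
#   urn:li:dataJob:(urn:li:dataFlow:(airflow,warehouse_load,<env>),load_fact_table)
# If external_task_id is omitted (the sensor waits on the whole DAG), no cross-flow
# upstream is emitted, matching the jobflow-to-jobflow limitation noted above.
```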

@@ -114,22 +136,14 @@ def generate_dataflow(
:param capture_owner:
:return: DataFlow - Data generated dataflow
"""
from airflow.serialization.serialized_objects import SerializedDAG

id = dag.dag_id
orchestrator = "airflow"
description = f"{dag.description}\n\n{dag.doc_md or ''}"
data_flow = DataFlow(
cluster=cluster, id=id, orchestrator=orchestrator, description=description
)

flow_property_bag: Dict[str, str] = {
key: repr(value)
for (key, value) in SerializedDAG.serialize_dag(dag).items()
}
for key in dag.get_serialized_fields():
if key not in flow_property_bag:
flow_property_bag[key] = repr(getattr(dag, key))
flow_property_bag: Dict[str, str] = {}

allowed_flow_keys = [
"_access_control",
@@ -142,9 +156,10 @@
"tags",
"timezone",
]
flow_property_bag = {
k: v for (k, v) in flow_property_bag.items() if k in allowed_flow_keys
}

for key in allowed_flow_keys:
if hasattr(dag, key):
flow_property_bag[key] = repr(getattr(dag, key))

data_flow.properties = flow_property_bag
base_url = conf.get("webserver", "base_url")
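
The dataflow change drops the `SerializedDAG.serialize_dag()` round-trip, which is what failed once a DAG declared `Param` values, and instead builds the property bag by `repr()`-ing an allow-list of DAG attributes directly. A minimal sketch of that pattern in isolation; the helper name is made up, and only the allow-list keys visible in this diff are shown (the rest are collapsed above).

```python
from typing import Any, Dict


def build_flow_properties(dag: Any) -> Dict[str, str]:
    # Illustrative subset of the allow-list; the full list is collapsed in the hunk above.
    allowed_flow_keys = ["_access_control", "tags", "timezone"]
    # repr() on plain attributes never goes through Airflow's DAG serializer,
    # so a DAG using Param can no longer break property collection here.
    return {
        key: repr(getattr(dag, key))
        for key in allowed_flow_keys
        if hasattr(dag, key)
    }
```

The `hasattr` guard presumably keeps this tolerant of Airflow versions where some of the listed attributes do not exist, instead of raising `AttributeError`.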
@@ -191,21 +206,13 @@ def generate_datajob(
:param capture_tags: bool - whether to set tags automatically from airflow task
:return: DataJob - returns the generated DataJob object
"""
from airflow.serialization.serialized_objects import SerializedBaseOperator

dataflow_urn = DataFlowUrn.create_from_ids(
orchestrator="airflow", env=cluster, flow_id=dag.dag_id
)
datajob = DataJob(id=task.task_id, flow_urn=dataflow_urn)
datajob.description = AirflowGenerator._get_description(task)

job_property_bag: Dict[str, str] = {
key: repr(value)
for (key, value) in SerializedBaseOperator.serialize_operator(task).items()
}
for key in task.get_serialized_fields():
if key not in job_property_bag:
job_property_bag[key] = repr(getattr(task, key))
job_property_bag: Dict[str, str] = {}

allowed_task_keys = [
"_downstream_task_ids",
@@ -223,9 +230,10 @@
"trigger_rule",
"wait_for_downstream",
]
job_property_bag = {
k: v for (k, v) in job_property_bag.items() if k in allowed_task_keys
}

for key in allowed_task_keys:
if hasattr(task, key):
job_property_bag[key] = repr(getattr(task, key))

datajob.properties = job_property_bag
base_url = conf.get("webserver", "base_url")
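
The task side applies the same pattern, replacing the `SerializedBaseOperator.serialize_operator()` call with an allow-listed `getattr`/`repr` loop. A hedged sketch with a made-up helper name and only the keys visible in this hunk; any operator instance works as input.

```python
from typing import Any, Dict


def build_task_properties(task: Any) -> Dict[str, str]:
    # Illustrative subset; the remaining allow-list keys are collapsed in the hunk above.
    allowed_task_keys = ["_downstream_task_ids", "trigger_rule", "wait_for_downstream"]
    return {
        key: repr(getattr(task, key))
        for key in allowed_task_keys
        if hasattr(task, key)
    }
```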