Extract DAG cycle tester (apache#7897)
mik-laj authored Mar 27, 2020
1 parent 2a98a61 commit 3cc631e
Showing 5 changed files with 221 additions and 179 deletions.
44 changes: 2 additions & 42 deletions airflow/models/dag.py
@@ -24,7 +24,7 @@
 import re
 import sys
 import traceback
-from collections import OrderedDict, defaultdict
+from collections import OrderedDict
 from datetime import datetime, timedelta
 from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union

@@ -39,9 +39,7 @@
 from airflow import settings, utils
 from airflow.configuration import conf
 from airflow.dag.base_dag import BaseDag
-from airflow.exceptions import (
-    AirflowDagCycleException, AirflowException, DagNotFound, DuplicateTaskIdFound, TaskNotFound,
-)
+from airflow.exceptions import AirflowException, DagNotFound, DuplicateTaskIdFound, TaskNotFound
 from airflow.models.base import ID_LEN, Base
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.dagbag import DagBag
@@ -1640,44 +1638,6 @@ def get_num_task_instances(dag_id, task_ids=None, states=None, session=None):
         qry = qry.filter(TaskInstance.state.in_(states))
         return qry.scalar()
 
-    def test_cycle(self):
-        """
-        Check to see if there are any cycles in the DAG. Returns False if no cycle found,
-        otherwise raises exception.
-        """
-        from airflow.models.dagbag import DagBag  # Avoid circular imports
-
-        # default of int is 0 which corresponds to CYCLE_NEW
-        visit_map = defaultdict(int)
-        for task_id in self.task_dict.keys():
-            # print('starting %s' % task_id)
-            if visit_map[task_id] == DagBag.CYCLE_NEW:
-                self._test_cycle_helper(visit_map, task_id)
-        return False
-
-    def _test_cycle_helper(self, visit_map, task_id):
-        """
-        Checks if a cycle exists from the input task using DFS traversal
-        """
-        from airflow.models.dagbag import DagBag  # Avoid circular imports
-
-        # print('Inspecting %s' % task_id)
-        if visit_map[task_id] == DagBag.CYCLE_DONE:
-            return False
-
-        visit_map[task_id] = DagBag.CYCLE_IN_PROGRESS
-
-        task = self.task_dict[task_id]
-        for descendant_id in task.get_direct_relative_ids():
-            if visit_map[descendant_id] == DagBag.CYCLE_IN_PROGRESS:
-                msg = "Cycle detected in DAG. Faulty task: {0} to {1}".format(
-                    task_id, descendant_id)
-                raise AirflowDagCycleException(msg)
-            else:
-                self._test_cycle_helper(visit_map, descendant_id)
-
-        visit_map[task_id] = DagBag.CYCLE_DONE
-
     @classmethod
     def get_serialized_fields(cls):
         """Stringified DAGs and operators contain exactly these fields."""

7 changes: 2 additions & 5 deletions airflow/models/dagbag.py
@@ -36,6 +36,7 @@
 from airflow.plugins_manager import integrate_dag_plugins
 from airflow.stats import Stats
 from airflow.utils import timezone
+from airflow.utils.dag_cycle_tester import test_cycle
 from airflow.utils.file import correct_maybe_zipped
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.timeout import timeout
@@ -78,10 +79,6 @@ class DagBag(BaseDagBag, LoggingMixin):
     :type store_serialized_dags: bool
     """
 
-    # static class variables to detect dag cycle
-    CYCLE_NEW = 0
-    CYCLE_IN_PROGRESS = 1
-    CYCLE_DONE = 2
     DAGBAG_IMPORT_TIMEOUT = conf.getint('core', 'DAGBAG_IMPORT_TIMEOUT')
     SCHEDULER_ZOMBIE_TASK_THRESHOLD = conf.getint('scheduler', 'scheduler_zombie_task_threshold')

@@ -317,7 +314,7 @@ def bag_dag(self, dag, parent_dag, root_dag):
         Throws AirflowDagCycleException if a cycle is detected in this dag or its subdags
         """
 
-        dag.test_cycle()  # throws if a task cycle is found
+        test_cycle(dag)  # throws if a task cycle is found
 
         dag.resolve_template_files()
         dag.last_loaded = timezone.utcnow()

59 changes: 59 additions & 0 deletions airflow/utils/dag_cycle_tester.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+DAG Cycle tester
+"""
+from collections import defaultdict
+from typing import Dict
+
+from airflow.exceptions import AirflowDagCycleException
+
+CYCLE_NEW = 0
+CYCLE_IN_PROGRESS = 1
+CYCLE_DONE = 2
+
+
+def test_cycle(dag):
+    """
+    Check to see if there are any cycles in the DAG. Returns False if no cycle found,
+    otherwise raises exception.
+    """
+    def _test_cycle_helper(visit_map: Dict[str, int], task_id: str) -> None:
+        """
+        Checks if a cycle exists from the input task using DFS traversal
+        """
+        if visit_map[task_id] == CYCLE_DONE:
+            return
+
+        visit_map[task_id] = CYCLE_IN_PROGRESS
+
+        task = dag.task_dict[task_id]
+        for descendant_id in task.get_direct_relative_ids():
+            if visit_map[descendant_id] == CYCLE_IN_PROGRESS:
+                msg = "Cycle detected in DAG. Faulty task: {0} to {1}".format(task_id, descendant_id)
+                raise AirflowDagCycleException(msg)
+            else:
+                _test_cycle_helper(visit_map, descendant_id)
+
+        visit_map[task_id] = CYCLE_DONE
+
+    # default of int is 0 which corresponds to CYCLE_NEW
+    dag_visit_map: Dict[str, int] = defaultdict(int)
+    for dag_task_id in dag.task_dict.keys():
+        if dag_visit_map[dag_task_id] == CYCLE_NEW:
+            _test_cycle_helper(dag_visit_map, dag_task_id)
+    return False
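
The extracted helper keeps the behaviour of the old DAG.test_cycle method: it returns False when the DAG is acyclic and raises AirflowDagCycleException as soon as the DFS reaches a task that is still marked in progress. A minimal usage sketch (the DAG id, task ids and start date below are illustrative; import paths are as of this point in the tree):

    from datetime import datetime

    from airflow.exceptions import AirflowDagCycleException
    from airflow.models import DAG
    from airflow.operators.dummy_operator import DummyOperator
    from airflow.utils.dag_cycle_tester import test_cycle

    with DAG('cycle_demo', start_date=datetime(2020, 1, 1)) as dag:
        op_a = DummyOperator(task_id='A')
        op_b = DummyOperator(task_id='B')
        op_a.set_downstream(op_b)          # A -> B, no cycle yet

    assert test_cycle(dag) is False        # acyclic DAG: returns False

    op_b.set_downstream(op_a)              # now A -> B -> A
    try:
        test_cycle(dag)
    except AirflowDagCycleException as err:
        print(err)                         # "Cycle detected in DAG. Faulty task: ..."

Because the tester only walks dag.task_dict via get_direct_relative_ids, it needs nothing but the DAG object itself and no longer depends on DagBag state.
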
133 changes: 1 addition & 132 deletions tests/models/test_dag.py
@@ -35,7 +35,7 @@
 
 from airflow import models, settings
 from airflow.configuration import conf
-from airflow.exceptions import AirflowDagCycleException, AirflowException, DuplicateTaskIdFound
+from airflow.exceptions import AirflowException, DuplicateTaskIdFound
 from airflow.jobs.scheduler_job import DagFileProcessor
 from airflow.models import DAG, DagModel, DagRun, DagTag, TaskFail, TaskInstance as TI
 from airflow.models.baseoperator import BaseOperator
@@ -556,137 +556,6 @@ def test_resolve_template_files_list(self):
 
         self.assertEqual(task.test_field, ['{{ ds }}', 'some_string'])
 
-    def test_cycle_empty(self):
-        # test empty
-        dag = DAG(
-            'dag',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'})
-
-        self.assertFalse(dag.test_cycle())
-
-    def test_cycle_single_task(self):
-        # test single task
-        dag = DAG(
-            'dag',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'})
-
-        with dag:
-            DummyOperator(task_id='A')
-
-        self.assertFalse(dag.test_cycle())
-
-    def test_cycle_no_cycle(self):
-        # test no cycle
-        dag = DAG(
-            'dag',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'})
-
-        # A -> B -> C
-        # B -> D
-        # E -> F
-        with dag:
-            op1 = DummyOperator(task_id='A')
-            op2 = DummyOperator(task_id='B')
-            op3 = DummyOperator(task_id='C')
-            op4 = DummyOperator(task_id='D')
-            op5 = DummyOperator(task_id='E')
-            op6 = DummyOperator(task_id='F')
-            op1.set_downstream(op2)
-            op2.set_downstream(op3)
-            op2.set_downstream(op4)
-            op5.set_downstream(op6)
-
-        self.assertFalse(dag.test_cycle())
-
-    def test_cycle_loop(self):
-        # test self loop
-        dag = DAG(
-            'dag',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'})
-
-        # A -> A
-        with dag:
-            op1 = DummyOperator(task_id='A')
-            op1.set_downstream(op1)
-
-        with self.assertRaises(AirflowDagCycleException):
-            dag.test_cycle()
-
-    def test_cycle_downstream_loop(self):
-        # test downstream self loop
-        dag = DAG(
-            'dag',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'})
-
-        # A -> B -> C -> D -> E -> E
-        with dag:
-            op1 = DummyOperator(task_id='A')
-            op2 = DummyOperator(task_id='B')
-            op3 = DummyOperator(task_id='C')
-            op4 = DummyOperator(task_id='D')
-            op5 = DummyOperator(task_id='E')
-            op1.set_downstream(op2)
-            op2.set_downstream(op3)
-            op3.set_downstream(op4)
-            op4.set_downstream(op5)
-            op5.set_downstream(op5)
-
-        with self.assertRaises(AirflowDagCycleException):
-            dag.test_cycle()
-
-    def test_cycle_large_loop(self):
-        # large loop
-        dag = DAG(
-            'dag',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'})
-
-        # A -> B -> C -> D -> E -> A
-        with dag:
-            op1 = DummyOperator(task_id='A')
-            op2 = DummyOperator(task_id='B')
-            op3 = DummyOperator(task_id='C')
-            op4 = DummyOperator(task_id='D')
-            op5 = DummyOperator(task_id='E')
-            op1.set_downstream(op2)
-            op2.set_downstream(op3)
-            op3.set_downstream(op4)
-            op4.set_downstream(op5)
-            op5.set_downstream(op1)
-
-        with self.assertRaises(AirflowDagCycleException):
-            dag.test_cycle()
-
-    def test_cycle_arbitrary_loop(self):
-        # test arbitrary loop
-        dag = DAG(
-            'dag',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'})
-
-        # E-> A -> B -> F -> A
-        #  -> C -> F
-        with dag:
-            op1 = DummyOperator(task_id='A')
-            op2 = DummyOperator(task_id='B')
-            op3 = DummyOperator(task_id='C')
-            op4 = DummyOperator(task_id='E')
-            op5 = DummyOperator(task_id='F')
-            op1.set_downstream(op2)
-            op1.set_downstream(op3)
-            op4.set_downstream(op1)
-            op3.set_downstream(op5)
-            op2.set_downstream(op5)
-            op5.set_downstream(op1)
-
-        with self.assertRaises(AirflowDagCycleException):
-            dag.test_cycle()
-
     def test_following_previous_schedule(self):
         """
         Make sure DST transitions are properly observed
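
The per-case cycle tests removed above translate directly to the extracted helper; presumably the relocated tests live in the fifth changed file, which is not expanded here. A minimal sketch of the self-loop case against the new function (unittest-style; the class name and start date are illustrative, not the relocated tests from this commit):

    import unittest
    from datetime import datetime

    from airflow.exceptions import AirflowDagCycleException
    from airflow.models import DAG
    from airflow.operators.dummy_operator import DummyOperator
    from airflow.utils.dag_cycle_tester import test_cycle


    class TestCycle(unittest.TestCase):
        def test_cycle_loop(self):
            # A -> A: the DFS marks A as in progress, meets it again, and raises
            dag = DAG('dag', start_date=datetime(2016, 1, 1),
                      default_args={'owner': 'owner1'})
            with dag:
                op1 = DummyOperator(task_id='A')
                op1.set_downstream(op1)

            with self.assertRaises(AirflowDagCycleException):
                test_cycle(dag)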
