From 3cc631e0cb67bd5cdd9ce4f3dedbb1ab611a00e9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kamil=20Bregu=C5=82a?=
Date: Fri, 27 Mar 2020 11:07:23 +0100
Subject: [PATCH] Extract DAG cycle tester (#7897)

---
 airflow/models/dag.py             |  44 +--------
 airflow/models/dagbag.py          |   7 +-
 airflow/utils/dag_cycle_tester.py |  59 +++++++++++
 tests/models/test_dag.py          | 133 +------------------------
 tests/utils/dag_cycle_tester.py   | 157 ++++++++++++++++++++++++++++++
 5 files changed, 221 insertions(+), 179 deletions(-)
 create mode 100644 airflow/utils/dag_cycle_tester.py
 create mode 100644 tests/utils/dag_cycle_tester.py

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 1f8098eda1161..98dc21976d40b 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -24,7 +24,7 @@
 import re
 import sys
 import traceback
-from collections import OrderedDict, defaultdict
+from collections import OrderedDict
 from datetime import datetime, timedelta
 from typing import Callable, Collection, Dict, FrozenSet, Iterable, List, Optional, Set, Type, Union
 
@@ -39,9 +39,7 @@
 from airflow import settings, utils
 from airflow.configuration import conf
 from airflow.dag.base_dag import BaseDag
-from airflow.exceptions import (
-    AirflowDagCycleException, AirflowException, DagNotFound, DuplicateTaskIdFound, TaskNotFound,
-)
+from airflow.exceptions import AirflowException, DagNotFound, DuplicateTaskIdFound, TaskNotFound
 from airflow.models.base import ID_LEN, Base
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.dagbag import DagBag
@@ -1640,44 +1638,6 @@ def get_num_task_instances(dag_id, task_ids=None, states=None, session=None):
         qry = qry.filter(TaskInstance.state.in_(states))
         return qry.scalar()
 
-    def test_cycle(self):
-        """
-        Check to see if there are any cycles in the DAG. Returns False if no cycle found,
-        otherwise raises exception.
-        """
-        from airflow.models.dagbag import DagBag  # Avoid circular imports
-
-        # default of int is 0 which corresponds to CYCLE_NEW
-        visit_map = defaultdict(int)
-        for task_id in self.task_dict.keys():
-            # print('starting %s' % task_id)
-            if visit_map[task_id] == DagBag.CYCLE_NEW:
-                self._test_cycle_helper(visit_map, task_id)
-        return False
-
-    def _test_cycle_helper(self, visit_map, task_id):
-        """
-        Checks if a cycle exists from the input task using DFS traversal
-        """
-        from airflow.models.dagbag import DagBag  # Avoid circular imports
-
-        # print('Inspecting %s' % task_id)
-        if visit_map[task_id] == DagBag.CYCLE_DONE:
-            return False
-
-        visit_map[task_id] = DagBag.CYCLE_IN_PROGRESS
-
-        task = self.task_dict[task_id]
-        for descendant_id in task.get_direct_relative_ids():
-            if visit_map[descendant_id] == DagBag.CYCLE_IN_PROGRESS:
-                msg = "Cycle detected in DAG. Faulty task: {0} to {1}".format(
-                    task_id, descendant_id)
-                raise AirflowDagCycleException(msg)
-            else:
-                self._test_cycle_helper(visit_map, descendant_id)
-
-        visit_map[task_id] = DagBag.CYCLE_DONE
-
     @classmethod
     def get_serialized_fields(cls):
         """Stringified DAGs and operators contain exactly these fields."""
diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index 9cf57b3251c18..f57f524a43997 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -36,6 +36,7 @@
 from airflow.plugins_manager import integrate_dag_plugins
 from airflow.stats import Stats
 from airflow.utils import timezone
+from airflow.utils.dag_cycle_tester import test_cycle
 from airflow.utils.file import correct_maybe_zipped
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.timeout import timeout
@@ -78,10 +79,6 @@ class DagBag(BaseDagBag, LoggingMixin):
     :type store_serialized_dags: bool
     """
 
-    # static class variables to detetct dag cycle
-    CYCLE_NEW = 0
-    CYCLE_IN_PROGRESS = 1
-    CYCLE_DONE = 2
     DAGBAG_IMPORT_TIMEOUT = conf.getint('core', 'DAGBAG_IMPORT_TIMEOUT')
     SCHEDULER_ZOMBIE_TASK_THRESHOLD = conf.getint('scheduler', 'scheduler_zombie_task_threshold')
 
@@ -317,7 +314,7 @@ def bag_dag(self, dag, parent_dag, root_dag):
         Throws AirflowDagCycleException if a cycle is detected in this dag or its subdags
         """
 
-        dag.test_cycle()  # throws if a task cycle is found
+        test_cycle(dag)  # throws if a task cycle is found
         dag.resolve_template_files()
         dag.last_loaded = timezone.utcnow()
 
diff --git a/airflow/utils/dag_cycle_tester.py b/airflow/utils/dag_cycle_tester.py
new file mode 100644
index 0000000000000..5b28b96e6d96c
--- /dev/null
+++ b/airflow/utils/dag_cycle_tester.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+DAG Cycle tester
+"""
+from collections import defaultdict
+from typing import Dict
+
+from airflow.exceptions import AirflowDagCycleException
+
+CYCLE_NEW = 0
+CYCLE_IN_PROGRESS = 1
+CYCLE_DONE = 2
+
+
+def test_cycle(dag):
+    """
+    Check to see if there are any cycles in the DAG. Returns False if no cycle found,
+    otherwise raises exception.
+    """
+    def _test_cycle_helper(visit_map: Dict[str, int], task_id: str) -> None:
+        """
+        Checks if a cycle exists from the input task using DFS traversal
+        """
+        if visit_map[task_id] == CYCLE_DONE:
+            return
+
+        visit_map[task_id] = CYCLE_IN_PROGRESS
+
+        task = dag.task_dict[task_id]
+        for descendant_id in task.get_direct_relative_ids():
+            if visit_map[descendant_id] == CYCLE_IN_PROGRESS:
+                msg = "Cycle detected in DAG. Faulty task: {0} to {1}".format(task_id, descendant_id)
+                raise AirflowDagCycleException(msg)
+            else:
+                _test_cycle_helper(visit_map, descendant_id)
+
+        visit_map[task_id] = CYCLE_DONE
+
+    # default of int is 0 which corresponds to CYCLE_NEW
+    dag_visit_map: Dict[str, int] = defaultdict(int)
+    for dag_task_id in dag.task_dict.keys():
+        if dag_visit_map[dag_task_id] == CYCLE_NEW:
+            _test_cycle_helper(dag_visit_map, dag_task_id)
+    return False
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index eae06fe91e9a2..2a556f152b32b 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -35,7 +35,7 @@
 
 from airflow import models, settings
 from airflow.configuration import conf
-from airflow.exceptions import AirflowDagCycleException, AirflowException, DuplicateTaskIdFound
+from airflow.exceptions import AirflowException, DuplicateTaskIdFound
 from airflow.jobs.scheduler_job import DagFileProcessor
 from airflow.models import DAG, DagModel, DagRun, DagTag, TaskFail, TaskInstance as TI
 from airflow.models.baseoperator import BaseOperator
@@ -556,137 +556,6 @@ def test_resolve_template_files_list(self):
 
         self.assertEqual(task.test_field, ['{{ ds }}', 'some_string'])
 
-    def test_cycle_empty(self):
-        # test empty
-        dag = DAG(
-            'dag',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'})
-
-        self.assertFalse(dag.test_cycle())
-
-    def test_cycle_single_task(self):
-        # test single task
-        dag = DAG(
-            'dag',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'})
-
-        with dag:
-            DummyOperator(task_id='A')
-
-        self.assertFalse(dag.test_cycle())
-
-    def test_cycle_no_cycle(self):
-        # test no cycle
-        dag = DAG(
-            'dag',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'})
-
-        # A -> B -> C
-        # B -> D
-        # E -> F
-        with dag:
-            op1 = DummyOperator(task_id='A')
-            op2 = DummyOperator(task_id='B')
-            op3 = DummyOperator(task_id='C')
-            op4 = DummyOperator(task_id='D')
-            op5 = DummyOperator(task_id='E')
-            op6 = DummyOperator(task_id='F')
-            op1.set_downstream(op2)
-            op2.set_downstream(op3)
-            op2.set_downstream(op4)
-            op5.set_downstream(op6)
-
-        self.assertFalse(dag.test_cycle())
-
-    def test_cycle_loop(self):
-        # test self loop
-        dag = DAG(
-            'dag',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'})
-
-        # A -> A
-        with dag:
-            op1 = DummyOperator(task_id='A')
-            op1.set_downstream(op1)
-
-        with self.assertRaises(AirflowDagCycleException):
-            dag.test_cycle()
-
-    def test_cycle_downstream_loop(self):
-        # test downstream self loop
-        dag = DAG(
-            'dag',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'})
-
-        # A -> B -> C -> D -> E -> E
-        with dag:
-            op1 = DummyOperator(task_id='A')
-            op2 = DummyOperator(task_id='B')
-            op3 = DummyOperator(task_id='C')
-            op4 = DummyOperator(task_id='D')
-            op5 = DummyOperator(task_id='E')
-            op1.set_downstream(op2)
-            op2.set_downstream(op3)
-            op3.set_downstream(op4)
-            op4.set_downstream(op5)
-            op5.set_downstream(op5)
-
-        with self.assertRaises(AirflowDagCycleException):
-            dag.test_cycle()
-
-    def test_cycle_large_loop(self):
-        # large loop
-        dag = DAG(
-            'dag',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'})
-
-        # A -> B -> C -> D -> E -> A
-        with dag:
-            op1 = DummyOperator(task_id='A')
-            op2 = DummyOperator(task_id='B')
-            op3 = DummyOperator(task_id='C')
-            op4 = DummyOperator(task_id='D')
-            op5 = DummyOperator(task_id='E')
-            op1.set_downstream(op2)
-            op2.set_downstream(op3)
-            op3.set_downstream(op4)
-            op4.set_downstream(op5)
-            op5.set_downstream(op1)
-
-        with self.assertRaises(AirflowDagCycleException):
-            dag.test_cycle()
-
-    def test_cycle_arbitrary_loop(self):
-        # test arbitrary loop
-        dag = DAG(
-            'dag',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'})
-
-        # E-> A -> B -> F -> A
-        # -> C -> F
-        with dag:
-            op1 = DummyOperator(task_id='A')
-            op2 = DummyOperator(task_id='B')
-            op3 = DummyOperator(task_id='C')
-            op4 = DummyOperator(task_id='E')
-            op5 = DummyOperator(task_id='F')
-            op1.set_downstream(op2)
-            op1.set_downstream(op3)
-            op4.set_downstream(op1)
-            op3.set_downstream(op5)
-            op2.set_downstream(op5)
-            op5.set_downstream(op1)
-
-        with self.assertRaises(AirflowDagCycleException):
-            dag.test_cycle()
-
     def test_following_previous_schedule(self):
         """
         Make sure DST transitions are properly observed
diff --git a/tests/utils/dag_cycle_tester.py b/tests/utils/dag_cycle_tester.py
new file mode 100644
index 0000000000000..fbe1411ad9a76
--- /dev/null
+++ b/tests/utils/dag_cycle_tester.py
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from airflow import DAG
+from airflow.exceptions import AirflowDagCycleException
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.utils.dag_cycle_tester import test_cycle
+from tests.models import DEFAULT_DATE
+
+
+class TestCycleTester(unittest.TestCase):
+    def test_cycle_empty(self):
+        # test empty
+        dag = DAG(
+            'dag',
+            start_date=DEFAULT_DATE,
+            default_args={'owner': 'owner1'})
+
+        self.assertFalse(test_cycle(dag))
+
+    def test_cycle_single_task(self):
+        # test single task
+        dag = DAG(
+            'dag',
+            start_date=DEFAULT_DATE,
+            default_args={'owner': 'owner1'})
+
+        with dag:
+            DummyOperator(task_id='A')
+
+        self.assertFalse(test_cycle(dag))
+
+    def test_cycle_no_cycle(self):
+        # test no cycle
+        dag = DAG(
+            'dag',
+            start_date=DEFAULT_DATE,
+            default_args={'owner': 'owner1'})
+
+        # A -> B -> C
+        # B -> D
+        # E -> F
+        with dag:
+            op1 = DummyOperator(task_id='A')
+            op2 = DummyOperator(task_id='B')
+            op3 = DummyOperator(task_id='C')
+            op4 = DummyOperator(task_id='D')
+            op5 = DummyOperator(task_id='E')
+            op6 = DummyOperator(task_id='F')
+            op1.set_downstream(op2)
+            op2.set_downstream(op3)
+            op2.set_downstream(op4)
+            op5.set_downstream(op6)
+
+        self.assertFalse(test_cycle(dag))
+
+    def test_cycle_loop(self):
+        # test self loop
+        dag = DAG(
+            'dag',
+            start_date=DEFAULT_DATE,
+            default_args={'owner': 'owner1'})
+
+        # A -> A
+        with dag:
+            op1 = DummyOperator(task_id='A')
+            op1.set_downstream(op1)
+
+        with self.assertRaises(AirflowDagCycleException):
+            self.assertFalse(test_cycle(dag))
+
+    def test_cycle_downstream_loop(self):
+        # test downstream self loop
+        dag = DAG(
+            'dag',
+            start_date=DEFAULT_DATE,
+            default_args={'owner': 'owner1'})
+
+        # A -> B -> C -> D -> E -> E
+        with dag:
+            op1 = DummyOperator(task_id='A')
+            op2 = DummyOperator(task_id='B')
+            op3 = DummyOperator(task_id='C')
+            op4 = DummyOperator(task_id='D')
+            op5 = DummyOperator(task_id='E')
+            op1.set_downstream(op2)
+            op2.set_downstream(op3)
+            op3.set_downstream(op4)
+            op4.set_downstream(op5)
+            op5.set_downstream(op5)
+
+        with self.assertRaises(AirflowDagCycleException):
+            self.assertFalse(test_cycle(dag))
+
+    def test_cycle_large_loop(self):
+        # large loop
+        dag = DAG(
+            'dag',
+            start_date=DEFAULT_DATE,
+            default_args={'owner': 'owner1'})
+
+        # A -> B -> C -> D -> E -> A
+        with dag:
+            op1 = DummyOperator(task_id='A')
+            op2 = DummyOperator(task_id='B')
+            op3 = DummyOperator(task_id='C')
+            op4 = DummyOperator(task_id='D')
+            op5 = DummyOperator(task_id='E')
+            op1.set_downstream(op2)
+            op2.set_downstream(op3)
+            op3.set_downstream(op4)
+            op4.set_downstream(op5)
+            op5.set_downstream(op1)
+
+        with self.assertRaises(AirflowDagCycleException):
+            self.assertFalse(test_cycle(dag))
+
+    def test_cycle_arbitrary_loop(self):
+        # test arbitrary loop
+        dag = DAG(
+            'dag',
+            start_date=DEFAULT_DATE,
+            default_args={'owner': 'owner1'})
+
+        # E-> A -> B -> F -> A
+        # -> C -> F
+        with dag:
+            op1 = DummyOperator(task_id='A')
+            op2 = DummyOperator(task_id='B')
+            op3 = DummyOperator(task_id='C')
+            op4 = DummyOperator(task_id='E')
+            op5 = DummyOperator(task_id='F')
+            op1.set_downstream(op2)
+            op1.set_downstream(op3)
+            op4.set_downstream(op1)
+            op3.set_downstream(op5)
+            op2.set_downstream(op5)
+            op5.set_downstream(op1)
+
+        with self.assertRaises(AirflowDagCycleException):
+            self.assertFalse(test_cycle(dag))
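
As a quick usage sketch for reviewers (not part of the patch itself): the extracted helper can be called directly on any DAG object. It assumes only names this change introduces or already uses (airflow.utils.dag_cycle_tester.test_cycle, AirflowDagCycleException, DummyOperator); the DAG id, task ids and start date below are illustrative placeholders.

    from datetime import datetime

    from airflow import DAG
    from airflow.exceptions import AirflowDagCycleException
    from airflow.operators.dummy_operator import DummyOperator
    from airflow.utils.dag_cycle_tester import test_cycle

    # Deliberately wire a cycle: A -> B -> A.
    with DAG('cycle_demo', start_date=datetime(2020, 1, 1)) as dag:
        op_a = DummyOperator(task_id='A')
        op_b = DummyOperator(task_id='B')
        op_a.set_downstream(op_b)
        op_b.set_downstream(op_a)

    try:
        test_cycle(dag)  # returns False when the DAG is acyclic
    except AirflowDagCycleException as err:
        # e.g. "Cycle detected in DAG. Faulty task: B to A"
        # (the reported task pair depends on traversal order)
        print(err)

DagBag.bag_dag now performs exactly this call, so DAGs containing cycles are still rejected at load time.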