Skip to content

Commit

Permalink
Classic TaskGroup setup/teardown (#29891)
Browse files Browse the repository at this point in the history
* Classic TaskGroup setup/teardown

This implements classic TaskGroup setup/teardown. Ensures that nested
TaskGroups are taken into account

* fixup! Classic TaskGroup setup/teardown
  • Loading branch information
ephraimbuddy authored Mar 14, 2023
1 parent e508e8b commit 848a396
Show file tree
Hide file tree
Showing 3 changed files with 343 additions and 6 deletions.
2 changes: 2 additions & 0 deletions airflow/decorators/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,8 @@ def task_group(
ui_color: str = "CornflowerBlue",
ui_fgcolor: str = "#000",
add_suffix_on_collision: bool = False,
setup: bool = False,
teardown: bool = False,
) -> Callable[[Callable[FParams, FReturn]], _TaskGroupFactory[FParams, FReturn]]:
...

Expand Down
33 changes: 31 additions & 2 deletions airflow/utils/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,18 @@ def __init__(
ui_color: str = "CornflowerBlue",
ui_fgcolor: str = "#000",
add_suffix_on_collision: bool = False,
setup: bool = False,
teardown: bool = False,
):
from airflow.models.dag import DagContext

if setup and teardown:
raise AirflowException("Cannot set both setup and teardown to True")

self.prefix_group_id = prefix_group_id
self.default_args = copy.deepcopy(default_args or {})
self.setup = setup
self.teardown = teardown

dag = dag or DagContext.get_current_dag()

Expand Down Expand Up @@ -231,15 +238,37 @@ def add(self, task: DAGNode) -> None:
if task.children:
raise AirflowException("Cannot add a non-empty TaskGroup")

if SetupTeardownContext.is_setup:
is_setup, is_teardown = self._check_is_setup_teardown(task)

if SetupTeardownContext.is_setup or is_setup:
if isinstance(task, AbstractOperator):
setattr(task, "_is_setup", True)
elif SetupTeardownContext.is_teardown:
elif SetupTeardownContext.is_teardown or is_teardown:
if isinstance(task, AbstractOperator):
setattr(task, "_is_teardown", True)

self.children[key] = task

def _check_is_setup_teardown(self, task_):
"""Check if setup or teardown is set for the task"""
from airflow.models.abstractoperator import AbstractOperator

def _find_setup_teardown(tg):
setup, teardown = tg.setup, tg.teardown
while tg and tg.parent_group:
if setup or teardown:
break
tg = tg.parent_group
setup, teardown = tg.setup, tg.teardown
return setup, teardown

if isinstance(task_, TaskGroup):
return _find_setup_teardown(task_)
if isinstance(task_, AbstractOperator):
tg = task_.task_group
return _find_setup_teardown(tg)
return False, False

def _remove(self, task: DAGNode) -> None:
key = task.node_id

Expand Down
Loading

0 comments on commit 848a396

Please sign in to comment.