diff --git a/airflow/example_dags/example_dynamic_task_mapping_with_no_taskflow_operators.py b/airflow/example_dags/example_dynamic_task_mapping_with_no_taskflow_operators.py new file mode 100644 index 0000000000000..90465d66e6e19 --- /dev/null +++ b/airflow/example_dags/example_dynamic_task_mapping_with_no_taskflow_operators.py @@ -0,0 +1,62 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Example DAG demonstrating the usage of dynamic task mapping with non-TaskFlow operators.""" +from __future__ import annotations + +from datetime import datetime + +from airflow import DAG +from airflow.models.baseoperator import BaseOperator + + +class AddOneOperator(BaseOperator): + """A custom operator that adds one to the input.""" + + def __init__(self, value, **kwargs): + super().__init__(**kwargs) + self.value = value + + def execute(self, context): + return self.value + 1 + + +class SumItOperator(BaseOperator): + """A custom operator that sums the input.""" + + template_fields = ("values",) + + def __init__(self, values, **kwargs): + super().__init__(**kwargs) + self.values = values + + def execute(self, context): + total = sum(self.values) + print(f"Total was {total}") + return total + + +with DAG( + dag_id="example_dynamic_task_mapping_with_no_taskflow_operators", + start_date=datetime(2022, 3, 4), + catchup=False, +): + # map the task to a list of values + add_one_task = AddOneOperator.partial(task_id="add_one").expand(value=[1, 2, 3]) + + # aggregate (reduce) the mapped tasks results + sum_it_task = SumItOperator(task_id="sum_it", values=add_one_task.output) diff --git a/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst b/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst index f6c95660ab825..ecfe8e24136a4 100644 --- a/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst +++ b/docs/apache-airflow/authoring-and-scheduling/dynamic-task-mapping.rst @@ -177,9 +177,8 @@ Mapping with non-TaskFlow operators It is possible to use ``partial`` and ``expand`` with classic style operators as well. Some arguments are not mappable and must be passed to ``partial()``, such as ``task_id``, ``queue``, ``pool``, and most other arguments to ``BaseOperator``. -.. code-block:: python - - BashOperator.partial(task_id="bash", do_xcom_push=False).expand(bash_command=["echo 1", "echo 2"]) +.. exampleinclude:: /../../airflow/example_dags/example_dynamic_task_mapping_with_no_taskflow_operators.py + :language: python .. note:: Only keyword arguments are allowed to be passed to ``partial()``. diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index b025a61ae574b..e656a4f9d5d95 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -39,6 +39,7 @@ import airflow from airflow.datasets import Dataset from airflow.decorators import teardown +from airflow.decorators.base import DecoratedOperator from airflow.exceptions import AirflowException, SerializationError from airflow.hooks.base import BaseHook from airflow.kubernetes.pod_generator import PodGenerator @@ -615,7 +616,8 @@ def validate_deserialized_task( # data; checking its entirety basically duplicates this validation # function, so we just do some satiny checks. serialized_task.operator_class["_task_type"] == type(task).__name__ - serialized_task.operator_class["_operator_name"] == task._operator_name + if isinstance(serialized_task.operator_class, DecoratedOperator): + serialized_task.operator_class["_operator_name"] == task._operator_name # Serialization cleans up default values in partial_kwargs, this # adds them back to both sides. diff --git a/tests/www/views/test_views_acl.py b/tests/www/views/test_views_acl.py index 8a820fe01b3a4..98cc2660b59cb 100644 --- a/tests/www/views/test_views_acl.py +++ b/tests/www/views/test_views_acl.py @@ -252,6 +252,8 @@ def test_dag_autocomplete_success(client_all_dags): ) expected = [ {"name": "airflow", "type": "owner"}, + {"name": "example_dynamic_task_mapping_with_no_taskflow_operators", "type": "dag"}, + {"name": "example_setup_teardown_taskflow", "type": "dag"}, {"name": "test_mapped_taskflow", "type": "dag"}, {"name": "tutorial_taskflow_api", "type": "dag"}, {"name": "tutorial_taskflow_api_virtualenv", "type": "dag"},