diff --git a/airflow/models/base.py b/airflow/models/base.py index 439308da70e..6b24e57a7fe 100644 --- a/airflow/models/base.py +++ b/airflow/models/base.py @@ -17,12 +17,19 @@ # under the License. import functools -from typing import Any, Type +import logging +from typing import TYPE_CHECKING, Any, Collection, Optional, Set, Type +import jinja2 from sqlalchemy import MetaData, String from sqlalchemy.ext.declarative import declarative_base from airflow.configuration import conf +from airflow.templates import SandboxedEnvironment +from airflow.utils.weight_rule import WeightRule + +if TYPE_CHECKING: + from airflow.models.dag import DAG SQL_ALCHEMY_SCHEMA = conf.get("core", "SQL_ALCHEMY_SCHEMA") @@ -34,9 +41,123 @@ ID_LEN = 250 -# used for typing class Operator: - """Class just used for Typing""" + """Common interface for operators, including unmapped and mapped.""" + + log: logging.Logger + + upstream_task_ids: Set[str] + downstream_task_ids: Set[str] + weight_rule: str + priority_weight: int + + # For derived classes to define which fields will get jinjaified. + template_fields: Collection[str] + # Defines which files extensions to look for in the templated fields. + template_ext: Collection[str] + + def get_dag(self) -> "Optional[DAG]": + raise NotImplementedError() + + @property + def dag_id(self) -> str: + """Returns dag id if it has one or an adhoc + owner""" + dag = self.get_dag() + if dag: + return self.dag.dag_id + return f"adhoc_{self.owner}" + + def get_template_env(self) -> jinja2.Environment: + """Fetch a Jinja template environment from the DAG or instantiate empty environment if no DAG.""" + dag = self.get_dag() + if dag: + return dag.get_template_env() + return SandboxedEnvironment(cache_size=0) + + def prepare_template(self) -> None: + """Hook triggered after the templated fields get replaced by their content. + + If you need your operator to alter the content of the file before the + template is rendered, it should override this method to do so. + """ + + def resolve_template_files(self) -> None: + """Getting the content of files for template_field / template_ext.""" + if self.template_ext: + for field in self.template_fields: + content = getattr(self, field, None) + if content is None: + continue + elif isinstance(content, str) and any(content.endswith(ext) for ext in self.template_ext): + env = self.get_template_env() + try: + setattr(self, field, env.loader.get_source(env, content)[0]) # type: ignore + except Exception: + self.log.exception("Failed to resolve template field %r", field) + elif isinstance(content, list): + env = self.get_template_env() + for i, item in enumerate(content): + if isinstance(item, str) and any(item.endswith(ext) for ext in self.template_ext): + try: + content[i] = env.loader.get_source(env, item)[0] # type: ignore + except Exception as e: + self.log.exception(e) + self.prepare_template() + + def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]: + """Get direct relative IDs to the current task, upstream or downstream.""" + if upstream: + return self.upstream_task_ids + return self.downstream_task_ids + + def get_flat_relative_ids( + self, + upstream: bool = False, + found_descendants: Optional[Set[str]] = None, + ) -> Set[str]: + """Get a flat set of relative IDs, upstream or downstream.""" + dag = self.get_dag() + if not dag: + return set() + + if not found_descendants: + found_descendants = set() + relative_ids = self.get_direct_relative_ids(upstream) + + for relative_id in relative_ids: + if relative_id not in found_descendants: + found_descendants.add(relative_id) + relative_task = dag.task_dict[relative_id] + relative_task.get_flat_relative_ids(upstream, found_descendants) + + return found_descendants + + @property + def priority_weight_total(self) -> int: + """ + Total priority weight for the task. It might include all upstream or downstream tasks. + + Depending on the weight rule: + + - WeightRule.ABSOLUTE - only own weight + - WeightRule.DOWNSTREAM - adds priority weight of all downstream tasks + - WeightRule.UPSTREAM - adds priority weight of all upstream tasks + """ + if self.weight_rule == WeightRule.ABSOLUTE: + return self.priority_weight + elif self.weight_rule == WeightRule.DOWNSTREAM: + upstream = False + elif self.weight_rule == WeightRule.UPSTREAM: + upstream = True + else: + upstream = False + dag = self.get_dag() + if dag is None: + return self.priority_weight + return self.priority_weight + sum( + dag.task_dict[task_id].priority_weight + for task_id in self.get_flat_relative_ids(upstream=upstream) + ) def get_id_collation_args(): diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 96f79735c42..dbe63b4e9c9 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -31,6 +31,7 @@ Any, Callable, ClassVar, + Collection, Dict, FrozenSet, Iterable, @@ -52,7 +53,6 @@ from sqlalchemy.orm import Session from sqlalchemy.orm.exc import NoResultFound -import airflow.templates from airflow.compat.functools import cached_property from airflow.configuration import conf from airflow.exceptions import AirflowException, TaskDeferred @@ -249,6 +249,12 @@ def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs): ) +DEFAULT_QUEUE = conf.get("operators", "default_queue") +DEFAULT_RETRIES = conf.getint("core", "default_task_retries", fallback=0) +DEFAULT_WEIGHT_RULE = conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM) +DEFAULT_TRIGGER_RULE = TriggerRule.ALL_SUCCESS + + @functools.total_ordering class BaseOperator(Operator, LoggingMixin, DAGNode, metaclass=BaseOperatorMeta): """ @@ -428,11 +434,10 @@ class derived from this one results in the creation of a task object, that is visible in Task Instance details View in the Webserver """ - # For derived classes to define which fields will get jinjaified - template_fields: Sequence[str] = () - # Defines which files extensions to look for in the templated fields - template_ext: Sequence[str] = () - # Template field renderers indicating type of the field, for example sql, json, bash + # Implementing Operator. + template_fields: Collection[str] = () + template_ext: Collection[str] = () + template_fields_renderers: Dict[str, str] = {} # Defines the color in the UI @@ -496,6 +501,7 @@ class derived from this one results in the creation of a task object, _lock_for_execution = False _dag: Optional["DAG"] = None + task_group: Optional["TaskGroup"] = None # subdag parameter is only set for SubDagOperator. # Setting it to None by default as other Operators do not have that field @@ -524,7 +530,7 @@ def __init__( email: Optional[Union[str, Iterable[str]]] = None, email_on_retry: bool = conf.getboolean('email', 'default_email_on_retry', fallback=True), email_on_failure: bool = conf.getboolean('email', 'default_email_on_failure', fallback=True), - retries: Optional[int] = conf.getint('core', 'default_task_retries', fallback=0), + retries: Optional[int] = DEFAULT_RETRIES, retry_delay: Union[timedelta, float] = timedelta(seconds=300), retry_exponential_backoff: bool = False, max_retry_delay: Optional[Union[timedelta, float]] = None, @@ -536,8 +542,8 @@ def __init__( params: Optional[Dict] = None, default_args: Optional[Dict] = None, priority_weight: int = 1, - weight_rule: str = conf.get('core', 'default_task_weight_rule', fallback=WeightRule.DOWNSTREAM), - queue: str = conf.get('operators', 'default_queue'), + weight_rule: str = DEFAULT_WEIGHT_RULE, + queue: str = DEFAULT_QUEUE, pool: Optional[str] = None, pool_slots: int = 1, sla: Optional[timedelta] = None, @@ -548,7 +554,7 @@ def __init__( on_retry_callback: Optional[TaskStateChangeCallback] = None, pre_execute: Optional[TaskPreExecuteHook] = None, post_execute: Optional[TaskPostExecuteHook] = None, - trigger_rule: str = TriggerRule.ALL_SUCCESS, + trigger_rule: str = DEFAULT_TRIGGER_RULE, resources: Optional[Dict] = None, run_as_user: Optional[str] = None, task_concurrency: Optional[int] = None, @@ -846,6 +852,9 @@ def get_outlet_defs(self): def node_id(self): return self.task_id + def get_dag(self) -> "Optional[DAG]": + return self._dag + @property # type: ignore[override] def dag(self) -> 'DAG': # type: ignore[override] """Returns the Operator's DAG if set, otherwise raises an error""" @@ -880,14 +889,6 @@ def has_dag(self): """Returns True if the Operator has been assigned to a DAG.""" return self._dag is not None - @property - def dag_id(self) -> str: - """Returns dag id if it has one or an adhoc + owner""" - if self.has_dag(): - return self.dag.dag_id - else: - return 'adhoc_' + self.owner - deps: Iterable[BaseTIDep] = frozenset( { NotInRetryPeriodDep(), @@ -939,38 +940,6 @@ def set_xcomargs_dependencies(self) -> None: arg = getattr(self, field) XComArg.apply_upstream_relationship(self, arg) - @property - def priority_weight_total(self) -> int: - """ - Total priority weight for the task. It might include all upstream or downstream tasks. - depending on the weight rule. - - - WeightRule.ABSOLUTE - only own weight - - WeightRule.DOWNSTREAM - adds priority weight of all downstream tasks - - WeightRule.UPSTREAM - adds priority weight of all upstream tasks - - """ - if self.weight_rule == WeightRule.ABSOLUTE: - return self.priority_weight - elif self.weight_rule == WeightRule.DOWNSTREAM: - upstream = False - elif self.weight_rule == WeightRule.UPSTREAM: - upstream = True - else: - upstream = False - - if not self._dag: - return self.priority_weight - from airflow.models.dag import DAG - - dag: DAG = self._dag - return self.priority_weight + sum( - map( - lambda task_id: dag.task_dict[task_id].priority_weight, - self.get_flat_relative_ids(upstream=upstream), - ) - ) - @cached_property def operator_extra_link_dict(self) -> Dict[str, Any]: """Returns dictionary of all extra links for the operator""" @@ -1164,45 +1133,6 @@ def _render_nested_template_fields( self._do_render_template_fields(content, nested_template_fields, context, jinja_env, seen_oids) - def get_template_env(self) -> jinja2.Environment: - """Fetch a Jinja template environment from the DAG or instantiate empty environment if no DAG.""" - return ( - self.dag.get_template_env() - if self.has_dag() - else airflow.templates.SandboxedEnvironment(cache_size=0) - ) - - def prepare_template(self) -> None: - """ - Hook that is triggered after the templated fields get replaced - by their content. If you need your operator to alter the - content of the file before the template is rendered, - it should override this method to do so. - """ - - def resolve_template_files(self) -> None: - """Getting the content of files for template_field / template_ext""" - if self.template_ext: - for field in self.template_fields: - content = getattr(self, field, None) - if content is None: - continue - elif isinstance(content, str) and any(content.endswith(ext) for ext in self.template_ext): - env = self.get_template_env() - try: - setattr(self, field, env.loader.get_source(env, content)[0]) # type: ignore - except Exception as e: - self.log.exception(e) - elif isinstance(content, list): - env = self.dag.get_template_env() - for i, item in enumerate(content): - if isinstance(item, str) and any(item.endswith(ext) for ext in self.template_ext): - try: - content[i] = env.loader.get_source(env, item)[0] # type: ignore - except Exception as e: - self.log.exception(e) - self.prepare_template() - @provide_session def clear( self, @@ -1260,27 +1190,6 @@ def get_task_instances( .all() ) - def get_flat_relative_ids( - self, - upstream: bool = False, - found_descendants: Optional[Set[str]] = None, - ) -> Set[str]: - """Get a flat set of relatives' ids, either upstream or downstream.""" - if not self._dag: - return set() - - if not found_descendants: - found_descendants = set() - relative_ids = self.get_direct_relative_ids(upstream) - - for relative_id in relative_ids: - if relative_id not in found_descendants: - found_descendants.add(relative_id) - relative_task = self._dag.task_dict[relative_id] - relative_task.get_flat_relative_ids(upstream, found_descendants) - - return found_descendants - def get_flat_relatives(self, upstream: bool = False): """Get a flat list of relatives, either upstream or downstream.""" if not self._dag: @@ -1356,16 +1265,6 @@ def dry_run(self) -> None: self.log.info('Rendering template for %s', field) self.log.info(content) - def get_direct_relative_ids(self, upstream: bool = False) -> Set[str]: - """ - Get set of the direct relative ids to the current task, upstream or - downstream. - """ - if upstream: - return self.upstream_task_ids - else: - return self.downstream_task_ids - def get_direct_relatives(self, upstream: bool = False) -> Iterable["DAGNode"]: """ Get list of the direct relatives to the current task, upstream or @@ -1635,7 +1534,7 @@ def _validate_kwarg_names_for_mapping( @attr.define(kw_only=True) -class MappedOperator(DAGNode): +class MappedOperator(Operator, LoggingMixin, DAGNode): """Object representing a mapped operator in a DAG""" def __repr__(self) -> str: @@ -1668,21 +1567,59 @@ def __repr__(self) -> str: deps: Iterable[BaseTIDep] = attr.ib() operator_extra_links: Iterable['BaseOperatorLink'] = () - params: Union[ParamsDict, dict] = attr.ib(factory=ParamsDict) - template_fields: Iterable[str] = attr.ib() + template_fields: Collection[str] = attr.ib() + template_ext: Collection[str] = attr.ib() + + weight_rule: str = attr.ib() + priority_weight: int = attr.ib() + trigger_rule: str = attr.ib() subdag: None = attr.ib(init=False) @_is_dummy.default - def _is_dummy_default(self): + def _is_dummy_from_operator_class(self): from airflow.operators.dummy import DummyOperator return issubclass(self.operator_class, DummyOperator) @deps.default - def _deps_from_class(self): + def _deps_from_operator_class(self): return self.operator_class.deps + @template_fields.default + def _template_fields_from_operator_class(self): + return self.operator_class.template_fields + + @template_ext.default + def _template_ext_from_operator_class(self): + return self.operator_class.template_ext + + @task_type.default + def _task_type_from_operator_class(self): + # Can be a string if we are de-serialized + val = self.operator_class + if isinstance(val, str): + return val.rsplit('.', 1)[-1] + return val.__name__ + + @task_group.default + def _task_group_default(self): + from airflow.utils.task_group import TaskGroupContext + + return TaskGroupContext.get_current_task_group(self.dag) + + @weight_rule.default + def _weight_rule_from_kwargs(self) -> str: + return self.partial_kwargs.get("weight_rule", DEFAULT_WEIGHT_RULE) + + @priority_weight.default + def _priority_weight_from_kwargs(self) -> int: + return self.partial_kwargs.get("priority_weight", 1) + + @trigger_rule.default + def _trigger_rule_from_kwargs(self) -> int: + return self.partial_kwargs.get("trigger_rule", DEFAULT_TRIGGER_RULE) + @classmethod def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) -> "MappedOperator": dag: Optional["DAG"] = getattr(operator, '_dag', None) @@ -1695,8 +1632,10 @@ def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) -> return MappedOperator( operator_class=type(operator), task_id=operator.task_id, - task_group=getattr(operator, 'task_group', None), - dag=getattr(operator, '_dag', None), + task_group=operator.task_group, + dag=dag, + upstream_task_ids=operator.upstream_task_ids, + downstream_task_ids=operator.downstream_task_ids, start_date=operator.start_date, end_date=operator.end_date, partial_kwargs={k: v for k, v in operator_init_kwargs.items() if k != "task_id"}, @@ -1704,7 +1643,6 @@ def from_operator(cls, operator: BaseOperator, mapped_kwargs: Dict[str, Any]) -> owner=operator.owner, max_active_tis_per_dag=operator.max_active_tis_per_dag, deps=operator.deps, - params=operator.params, ) @classmethod @@ -1722,6 +1660,8 @@ def from_decorator( Different from ``from_operator``, this DOES NOT validate ``mapped_kwargs``. The task decorator calling this should be responsible for validation. """ + from airflow.models.xcom_arg import XComArg + operator = MappedOperator( operator_class=decorator.operator_class, partial_kwargs=decorator.kwargs, @@ -1731,6 +1671,8 @@ def from_decorator( task_group=task_group, ) operator.mapped_kwargs.update(mapped_kwargs) + for arg in mapped_kwargs.values(): + XComArg.apply_upstream_relationship(operator, arg) return operator def __attrs_post_init__(self): @@ -1745,26 +1687,11 @@ def __attrs_post_init__(self): for arg in self.mapped_kwargs.values(): XComArg.apply_upstream_relationship(self, arg) - @task_type.default - def _default_task_type(self): - # Can be a string if we are de-serialized - val = self.operator_class - if isinstance(val, str): - return val.rsplit('.', 1)[-1] - return val.__name__ - - @task_group.default - def _default_task_group(self): - from airflow.utils.task_group import TaskGroupContext - - return TaskGroupContext.get_current_task_group(self.dag) - - @template_fields.default - def _template_fields_default(self): - return self.operator_class.template_fields + def get_dag(self) -> "Optional[DAG]": + return self.dag @property - def node_id(self): + def node_id(self) -> str: return self.task_id def map(self, **kwargs) -> "MappedOperator": @@ -1824,11 +1751,55 @@ def get_serialized_fields(cls): 'operator_extra_links', 'upstream_task_ids', 'task_type', + # These are automatically populated from partial_kwargs. In + # a perfect world, they should be properties like other + # partial_kwargs-populated values e.g. 'queue' below, but we + # must match BaseOperator's implementation and declare them + # as writable attributes instead. + 'weight_rule', + 'priority_weight', + 'trigger_rule', } | {'template_fields'} ) return cls.__serialized_fields + @property + def params(self) -> Union[dict, ParamsDict]: + return self.partial_kwargs.get("params", ParamsDict()) + + @property + def queue(self) -> str: + return self.partial_kwargs.get("queue", DEFAULT_QUEUE) + + @property + def run_as_user(self) -> Optional[str]: + return self.partial_kwargs.get("run_as_user") + + @property + def pool(self) -> str: + return self.partial_kwargs.get("pool") or Pool.DEFAULT_POOL_NAME + + @property + def pool_slots(self) -> int: + return self.partial_kwargs.get("pool_slots", 1) + + @property + def retries(self) -> Optional[int]: + return self.partial_kwargs.get("retries", DEFAULT_RETRIES) + + @property + def executor_config(self) -> Optional[dict]: + return self.partial_kwargs.get("executor_config") + + @property + def wait_for_downstream(self) -> bool: + return bool(self.partial_kwargs.get("wait_for_downstream")) + + @property + def depends_on_past(self) -> bool: + return self.partial_kwargs.get("depends_on_past") or self.wait_for_downstream + # TODO: Deprecate for Airflow 3.0 Chainable = Union[DependencyMixin, Sequence[DependencyMixin]] diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index f2f58b49b91..81d26bc490d 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -830,9 +830,13 @@ def create_ti(task: "BaseOperator") -> TI: for task_type, count in created_counts.items(): Stats.incr(f"task_instance_created-{task_type}", count) session.flush() - except IntegrityError as err: - self.log.info(str(err)) - self.log.info('Hit IntegrityError while creating the TIs for %s- %s', dag.dag_id, self.run_id) + except IntegrityError: + self.log.info( + 'Hit IntegrityError while creating the TIs for %s- %s', + dag.dag_id, + self.run_id, + exc_info=True, + ) self.log.info('Doing session rollback.') # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive. session.rollback() diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 31d7aad7bea..63820ffdfe5 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -644,6 +644,7 @@ def deserialize_operator(cls, encoded_op: Dict[str, Any]) -> Union[BaseOperator, deps=tuple(), is_dummy=False, template_fields=(), + template_ext=(), ) else: op = SerializedBaseOperator(task_id=encoded_op['task_id']) diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 92d273e3f6c..ee79aa9a75b 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -296,7 +296,7 @@ def test_jinja_invalid_expression_is_just_propagated(self): with pytest.raises(jinja2.exceptions.TemplateSyntaxError): task.render_template("{{ invalid expression }}", {}) - @mock.patch("airflow.templates.SandboxedEnvironment", autospec=True) + @mock.patch("airflow.models.base.SandboxedEnvironment", autospec=True) def test_jinja_env_creation(self, mock_jinja_env): """Verify if a Jinja environment is created only once when templating.""" task = MockOperator(task_id="op1", arg1="{{ foo }}", arg2="{{ bar }}") diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index a3030597ae1..84575a0c22a 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -2253,13 +2253,6 @@ def setup_class(self): with create_session() as session: session.query(TaskMap).delete() - def _run_ti_with_faked_mapped_dependants(self, ti): - # TODO: We can't actually put a MappedOperator in a DAG yet due to it - # lacking some functions we expect from BaseOperator, so we mock this - # instead to test what effect it has to TaskMap recording. - with mock.patch.object(ti.task, "has_mapped_dependants", new=lambda: True): - ti.run() - @pytest.mark.parametrize("xcom_value", [[1, 2, 3], {"a": 1, "b": 2}, "abc"]) def test_not_recorded_for_unused(self, dag_maker, xcom_value): """A value not used for task-mapping should not be recorded.""" @@ -2271,12 +2264,12 @@ def push_something(): push_something() - ti = dag_maker.create_dagrun().task_instances[0] + ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push_something") ti.run() assert dag_maker.session.query(TaskMap).count() == 0 - def test_error_if_unmappable(self, caplog, dag_maker): + def test_error_if_unmappable(self, dag_maker): """If an unmappable return value is used to map, fail the task that pushed the XCom.""" with dag_maker(dag_id="test_not_recorded_for_unused") as dag: @@ -2284,11 +2277,15 @@ def test_error_if_unmappable(self, caplog, dag_maker): def push_something(): return "abc" - push_something() + @dag.task() + def pull_something(value): + print(value) - ti = dag_maker.create_dagrun().task_instances[0] + pull_something.map(value=push_something()) + + ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push_something") with pytest.raises(UnmappableXComPushed) as ctx: - self._run_ti_with_faked_mapped_dependants(ti) + ti.run() assert dag_maker.session.query(TaskMap).count() == 0 assert ti.state == TaskInstanceState.FAILED @@ -2309,11 +2306,15 @@ def test_written_task_map(self, dag_maker, xcom_value, expected_length, expected def push_something(): return xcom_value - push_something() + @dag.task() + def pull_something(value): + print(value) + + pull_something.map(value=push_something()) dag_run = dag_maker.create_dagrun() ti = next(ti for ti in dag_run.task_instances if ti.task_id == "push_something") - self._run_ti_with_faked_mapped_dependants(ti) + ti.run() task_map = dag_maker.session.query(TaskMap).one() assert task_map.dag_id == "test_written_task_map" diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 0bf8b1e1bfe..35d0d6806f7 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -1603,6 +1603,7 @@ def test_mapped_operator_serde(): }, 'task_id': 'a', 'template_fields': ['bash_command', 'env'], + 'template_ext': ['.sh', '.bash'], } op = SerializedBaseOperator.deserialize_operator(serialized) @@ -1632,6 +1633,7 @@ def test_mapped_operator_xcomarg_serde(): 'partial_kwargs': {}, 'task_id': 'task_2', 'template_fields': ['arg1', 'arg2'], + 'template_ext': [], } op = SerializedBaseOperator.deserialize_operator(serialized)