From 0112221508b749d4ab813155a98244a0e846edff Mon Sep 17 00:00:00 2001 From: Usiel Riedl Date: Mon, 10 Jun 2024 16:08:56 +0800 Subject: [PATCH] Ensures DAG params order regardless of backend Fixes https://github.com/apache/airflow/issues/40154 This change adds an extra attribute to the serialized DAG param objects which helps us decide the order of the deserialized params dictionary later even if the backend messes with us. I decided not to limit this just to MySQL since the operation is inexpensive and may turn out to be helpful. I made sure the new test fails with the old implementation + MySQL. I assume this test will be executed with MySQL somewhere in the build actions? --- airflow/serialization/serialized_objects.py | 13 ++++++++++--- tests/models/test_serialized_dag.py | 17 +++++++++++++++++ 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py index 6e7f50a87c73d0..3f130b1aad4939 100644 --- a/airflow/serialization/serialized_objects.py +++ b/airflow/serialization/serialized_objects.py @@ -831,14 +831,17 @@ def is_serialized(val): def _serialize_params_dict(cls, params: ParamsDict | dict): """Serialize Params dict for a DAG or task.""" serialized_params = {} - for k, v in params.items(): + for idx, item in enumerate(params.items()): + k, v = item # TODO: As of now, we would allow serialization of params which are of type Param only. try: class_identity = f"{v.__module__}.{v.__class__.__name__}" except AttributeError: class_identity = "" if class_identity == "airflow.models.param.Param": - serialized_params[k] = cls._serialize_param(v) + serialized_param = cls._serialize_param(v) + serialized_param["__position"] = idx + serialized_params[k] = serialized_param else: raise ValueError( f"Params to a DAG or a Task can be only of type airflow.models.param.Param, " @@ -850,7 +853,11 @@ def _serialize_params_dict(cls, params: ParamsDict | dict): def _deserialize_params_dict(cls, encoded_params: dict) -> ParamsDict: """Deserialize a DAG's Params dict.""" op_params = {} - for k, v in encoded_params.items(): + sorted_params = sorted( + encoded_params.items(), + key=lambda item: item[1].get("__position", 0) if isinstance(item[1], dict) else 0, + ) + for k, v in sorted_params: if isinstance(v, dict) and "__class" in v: op_params[k] = cls._deserialize_param(v) else: diff --git a/tests/models/test_serialized_dag.py b/tests/models/test_serialized_dag.py index c46d4e18a07342..236cb884e6d8b4 100644 --- a/tests/models/test_serialized_dag.py +++ b/tests/models/test_serialized_dag.py @@ -206,6 +206,23 @@ def test_get_dag_dependencies_default_to_empty(self, dag_dependencies_fields): expected_dependencies = {dag_id: [] for dag_id in example_dags} assert SDM.get_dag_dependencies() == expected_dependencies + def test_order_of_dag_params_is_stable(self): + """ + https://github.com/apache/airflow/issues/40154 + This asserts that we have logic in place which guarantees the order + of the params is maintained - even if the backend (e.g. MySQL) mutates + the serialized DAG JSON. + """ + example_dags = make_example_dags(example_dags_module) + example_params_trigger_ui = example_dags.get("example_params_trigger_ui") + before = list(example_params_trigger_ui.params.keys()) + + SDM.write_dag(example_params_trigger_ui) + retrieved_dag = SDM.get_dag("example_params_trigger_ui") + after = list(retrieved_dag.params.keys()) + + assert before == after + def test_order_of_deps_is_consistent(self): """ Previously the 'dag_dependencies' node in serialized dag was converted to list from set.