Skip to content

Commit

Permalink
Ensures DAG params order regardless of backend (apache#40156)
Browse files Browse the repository at this point in the history
* Ensures DAG params order regardless of backend

Fixes apache#40154

This change adds an extra attribute to the serialized DAG param objects which helps us decide
the order of the deserialized params dictionary later even if the backend messes with us.

I decided not to limit this just to MySQL since the operation is inexpensive and may turn
out to be helpful.

I made sure the new test fails with the old implementation + MySQL. I assume this test will be
executed with MySQL somewhere in the build actions?

* Removes GitHub reference

Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com>

* Serialize DAG params as array of tuples to ensure ordering

Alternative to previous approach: We serialize the DAG params dict as a list of tuples which _should_ keep their ordering regardless of backend.

Backwards compatibility is ensured because if `encoded_params` is a `dict` (not the expected `list`) then `dict(encoded_params)` still works.

* Make backwards compatibility more explicit

Based on suggestions by @uranusjr with an additional fix to make mypy happy.

---------

Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com>
  • Loading branch information
2 people authored and romsharon98 committed Jul 26, 2024
1 parent 8363408 commit 5e182ff
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 21 deletions.
14 changes: 9 additions & 5 deletions airflow/serialization/schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@
"dag": {
"type": "object",
"properties": {
"params": { "$ref": "#/definitions/params_dict" },
"params": { "$ref": "#/definitions/params" },
"_dag_id": { "type": "string" },
"tasks": { "$ref": "#/definitions/tasks" },
"timezone": { "$ref": "#/definitions/timezone" },
Expand Down Expand Up @@ -206,9 +206,13 @@
"type": "array",
"additionalProperties": { "$ref": "#/definitions/operator" }
},
"params_dict": {
"type": "object",
"additionalProperties": {"$ref": "#/definitions/param" }
"params": {
"type": "array",
"prefixItems": [
{ "type": "string" },
{ "$ref": "#/definitions/param" }
],
"unevaluatedItems": false
},
"param": {
"$comment": "A param for a dag / operator",
Expand Down Expand Up @@ -258,7 +262,7 @@
"retry_delay": { "$ref": "#/definitions/timedelta" },
"retry_exponential_backoff": { "type": "boolean" },
"max_retry_delay": { "$ref": "#/definitions/timedelta" },
"params": { "$ref": "#/definitions/params_dict" },
"params": { "$ref": "#/definitions/params" },
"priority_weight": { "type": "number" },
"weight_rule": { "type": "string" },
"executor": { "type": "string" },
Expand Down
18 changes: 12 additions & 6 deletions airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,17 +827,17 @@ def is_serialized(val):
return class_(**kwargs)

@classmethod
def _serialize_params_dict(cls, params: ParamsDict | dict):
"""Serialize Params dict for a DAG or task."""
serialized_params = {}
def _serialize_params_dict(cls, params: ParamsDict | dict) -> list[tuple[str, dict]]:
"""Serialize Params dict for a DAG or task as a list of tuples to ensure ordering."""
serialized_params = []
for k, v in params.items():
# TODO: As of now, we would allow serialization of params which are of type Param only.
try:
class_identity = f"{v.__module__}.{v.__class__.__name__}"
except AttributeError:
class_identity = ""
if class_identity == "airflow.models.param.Param":
serialized_params[k] = cls._serialize_param(v)
serialized_params.append((k, cls._serialize_param(v)))
else:
raise ValueError(
f"Params to a DAG or a Task can be only of type airflow.models.param.Param, "
Expand All @@ -846,10 +846,16 @@ def _serialize_params_dict(cls, params: ParamsDict | dict):
return serialized_params

@classmethod
def _deserialize_params_dict(cls, encoded_params: dict) -> ParamsDict:
def _deserialize_params_dict(cls, encoded_params: list[tuple[str, dict]]) -> ParamsDict:
"""Deserialize a DAG's Params dict."""
if isinstance(encoded_params, collections.abc.Mapping):
# in 2.9.2 or earlier params were serialized as JSON objects
encoded_param_pairs: Iterable[tuple[str, dict]] = encoded_params.items()
else:
encoded_param_pairs = encoded_params

op_params = {}
for k, v in encoded_params.items():
for k, v in encoded_param_pairs:
if isinstance(v, dict) and "__class" in v:
op_params[k] = cls._deserialize_param(v)
else:
Expand Down
16 changes: 16 additions & 0 deletions tests/models/test_serialized_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,22 @@ def test_get_dag_dependencies_default_to_empty(self, dag_dependencies_fields):
expected_dependencies = {dag_id: [] for dag_id in example_dags}
assert SDM.get_dag_dependencies() == expected_dependencies

def test_order_of_dag_params_is_stable(self):
"""
This asserts that we have logic in place which guarantees the order
of the params is maintained - even if the backend (e.g. MySQL) mutates
the serialized DAG JSON.
"""
example_dags = make_example_dags(example_dags_module)
example_params_trigger_ui = example_dags.get("example_params_trigger_ui")
before = list(example_params_trigger_ui.params.keys())

SDM.write_dag(example_params_trigger_ui)
retrieved_dag = SDM.get_dag("example_params_trigger_ui")
after = list(retrieved_dag.params.keys())

assert before == after

def test_order_of_deps_is_consistent(self):
"""
Previously the 'dag_dependencies' node in serialized dag was converted to list from set.
Expand Down
42 changes: 32 additions & 10 deletions tests/serialization/test_dag_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def detect_task_dependencies(task: Operator) -> DagDependency | None: # type: i
},
"edge_info": {},
"dag_dependencies": [],
"params": {},
"params": [],
},
}

Expand Down Expand Up @@ -2082,6 +2082,25 @@ def test_params_upgrade(self):
assert isinstance(dag.params.get_param("none"), Param)
assert dag.params["str"] == "str"

def test_params_serialization_from_dict_upgrade(self):
"""In <=2.9.2 params were serialized as a JSON object instead of a list of key-value pairs.
This test asserts that the params are still deserialized properly."""
serialized = {
"__version": 1,
"dag": {
"_dag_id": "simple_dag",
"fileloc": "/path/to/file.py",
"tasks": [],
"timezone": "UTC",
"params": {"my_param": {"__class": "airflow.models.param.Param", "default": "str"}},
},
}
dag = SerializedDAG.from_dict(serialized)

param = dag.params.get_param("my_param")
assert isinstance(param, Param)
assert param.value == "str"

def test_params_serialize_default_2_2_0(self):
"""In 2.0.0, param ``default`` was assumed to be json-serializable objects and were not run though
the standard serializer function. In 2.2.2 we serialize param ``default``. We keep this
Expand All @@ -2093,7 +2112,7 @@ def test_params_serialize_default_2_2_0(self):
"fileloc": "/path/to/file.py",
"tasks": [],
"timezone": "UTC",
"params": {"str": {"__class": "airflow.models.param.Param", "default": "str"}},
"params": [["str", {"__class": "airflow.models.param.Param", "default": "str"}]],
},
}
SerializedDAG.validate_schema(serialized)
Expand All @@ -2110,14 +2129,17 @@ def test_params_serialize_default(self):
"fileloc": "/path/to/file.py",
"tasks": [],
"timezone": "UTC",
"params": {
"my_param": {
"default": "a string value",
"description": "hello",
"schema": {"__var": {"type": "string"}, "__type": "dict"},
"__class": "airflow.models.param.Param",
}
},
"params": [
[
"my_param",
{
"default": "a string value",
"description": "hello",
"schema": {"__var": {"type": "string"}, "__type": "dict"},
"__class": "airflow.models.param.Param",
},
]
],
},
}
SerializedDAG.validate_schema(serialized)
Expand Down

0 comments on commit 5e182ff

Please sign in to comment.