Skip to content

Commit

Permalink
Allow passing [] and {} as argument.
Browse files Browse the repository at this point in the history
Add `merge_dict` function to merge dict correctly
  • Loading branch information
superstar54 committed Sep 19, 2024
1 parent cbb998a commit 4260f2d
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 20 deletions.
2 changes: 1 addition & 1 deletion aiida_workgraph/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def prepare_for_python_task(task: dict, kwargs: dict, var_kwargs: dict) -> dict:
import os

# get the names kwargs for the PythonJob, which are the inputs before _wait
function_kwargs = {}
function_kwargs = kwargs.pop("function_kwargs", {})
# TODO better way to find the function_kwargs
input_names = [
name
Expand Down
11 changes: 9 additions & 2 deletions aiida_workgraph/engine/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,9 @@ def reset_task(
self.ctx._tasks[name]["execution_count"] = 0
for child_task in self.ctx._tasks[name]["children"]:
self.reset_task(child_task, reset_process=False, recursive=False)
elif self.ctx._tasks[name]["metadata"]["node_type"].upper() in ["IF", "ZONE"]:
for child_task in self.ctx._tasks[name]["children"]:
self.reset_task(child_task, reset_process=False, recursive=False)
if recursive:
# reset its child tasks
names = self.ctx._connectivity["child_node"][name]
Expand Down Expand Up @@ -816,8 +819,8 @@ def update_zone_task_state(self, name: str) -> None:
finished, _ = self.are_childen_finished(name)
if finished:
self.set_task_state_info(name, "state", "FINISHED")
self.update_parent_task_state(name)
self.report(f"Task: {name} finished.")
self.update_parent_task_state(name)

def should_run_while_task(self, name: str) -> tuple[bool, t.Any]:
"""Check if the while task should run."""
Expand Down Expand Up @@ -949,6 +952,7 @@ def check_while_conditions(self) -> bool:
task_name, socket_name = c.split(".")
if "task_name" != "context":
condition_tasks.append(task_name)
self.reset_task(task_name)
self.run_tasks(condition_tasks, continue_workgraph=False)
conditions = []
for c in self.ctx._workgraph["conditions"]:
Expand Down Expand Up @@ -1018,7 +1022,10 @@ def run_tasks(self, names: t.List[str], continue_workgraph: bool = True) -> None
)
continue
# skip if the task is already executed
if name in self.ctx._executed_tasks:
# or if the task is in a skippped state
if name in self.ctx._executed_tasks or self.get_task_state_info(
name, "state"
) in ["SKIPPED"]:
continue
self.ctx._executed_tasks.append(name)
print("-" * 60)
Expand Down
51 changes: 37 additions & 14 deletions aiida_workgraph/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,18 @@ def get_nested_dict(d: Dict, name: str, **kwargs) -> Any:
return current


def merge_dicts(existing: Any, new: Any) -> Any:
"""Recursively merges two dictionaries."""
if isinstance(existing, dict) and isinstance(new, dict):
for k, v in new.items():
if k in existing and isinstance(existing[k], dict) and isinstance(v, dict):
merge_dicts(existing[k], v)
else:
existing[k] = v
else:
return new


def update_nested_dict(d: Optional[Dict[str, Any]], key: str, value: Any) -> None:
"""
Update or create a nested dictionary structure based on a dotted key path.
Expand Down Expand Up @@ -178,11 +190,21 @@ def update_nested_dict(d: Optional[Dict[str, Any]], key: str, value: Any) -> Non
If the resulting dictionary is empty after the update, it will be set to `None`.
"""

keys = key.split(".")
current = d if d is not None else {}
for k in keys[:-1]:
current = current.setdefault(k, {})
current[keys[-1]] = value
# Handle merging instead of overwriting
last_key = keys[-1]
if (
last_key in current
and isinstance(current[last_key], dict)
and isinstance(value, dict)
):
merge_dicts(current[last_key], value)
else:
current[last_key] = value
# if current is empty, set it to None
if not current:
current = None
Expand All @@ -200,26 +222,27 @@ def is_empty(value: Any) -> bool:
return False


def update_nested_dict_with_special_keys(d: Dict[str, Any]) -> Dict[str, Any]:
def update_nested_dict_with_special_keys(data: Dict[str, Any]) -> Dict[str, Any]:
"""Remove None and empty value"""
d = {k: v for k, v in d.items() if v is not None and not is_empty(v)}
# data = {k: v for k, v in data.items() if v is not None and not is_empty(v)}
data = {k: v for k, v in data.items() if v is not None}
#
special_keys = [k for k in d.keys() if "." in k]
special_keys = [k for k in data.keys() if "." in k]
for key in special_keys:
value = d.pop(key)
update_nested_dict(d, key, value)
return d
value = data.pop(key)
update_nested_dict(data, key, value)
return data


def merge_properties(wgdata: Dict[str, Any]) -> None:
"""Merge sub properties to the root properties.
{
"base.pw.parameters": 2,
"base.pw.code": 1,
}
after merge:
{"base": {"pw": {"parameters": 2,
"code": 1}}
{
"base.pw.parameters": 2,
"base.pw.code": 1,
}
after merge:
{"base": {"pw": {"parameters": 2,
"code": 1}}
So that no "." in the key name.
"""
for _, task in wgdata["tasks"].items():
Expand Down
3 changes: 1 addition & 2 deletions tests/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def add(x, y):
wg = WorkGraph("test_PythonJob_retrieve_files")
wg.add_task("PythonJob", function=add, name="add")
# ------------------------- Submit the calculation -------------------
wg.submit(
wg.run(
inputs={
"add": {
"x": 2,
Expand All @@ -450,7 +450,6 @@ def add(x, y):
},
},
},
wait=True,
)
assert (
"result.txt" in wg.tasks["add"].outputs["retrieved"].value.list_object_names()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_while.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def test_while_workgraph(decorated_add, decorated_multiply, decorated_compare):
wg.workgraph_type = "WHILE"
wg.conditions = ["compare1.result"]
wg.context = {"n": 1}
wg.max_iteration = 10
wg.max_iteration = 5
wg.add_task(decorated_compare, name="compare1", x="{{n}}", y=20)
multiply1 = wg.add_task(
decorated_multiply, name="multiply1", x="{{ n }}", y=orm.Int(2)
Expand Down

0 comments on commit 4260f2d

Please sign in to comment.