Skip to content

Commit

Permalink
use add_node and add_link (#75)
Browse files Browse the repository at this point in the history
- Add `add_node` and `add_link`
- Make them the default API for the users.

Note: the original `nodes.new` and `links.new` are still kept, and are used in the low-level.
  • Loading branch information
superstar54 authored Jul 16, 2024
1 parent 089d6ed commit 50d552a
Show file tree
Hide file tree
Showing 48 changed files with 417 additions and 400 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ def multiply(x, y):

# Create a workgraph to link the tasks.
wg = WorkGraph("test_add_multiply")
wg.tasks.new(add, name="add1")
wg.tasks.new(multiply, name="multiply1")
wg.links.new(wg.tasks["add1"].outputs["result"], wg.tasks["multiply1"].inputs["x"])
wg.add_task(add, name="add1")
wg.add_task(multiply, name="multiply1")
wg.add_link(wg.tasks["add1"].outputs["result"], wg.tasks["multiply1"].inputs["x"])

```

Expand Down
19 changes: 18 additions & 1 deletion aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import aiida.orm
import node_graph
import aiida
import node_graph.link
from aiida_workgraph.socket import NodeSocket
from aiida_workgraph import USE_WIDGET
from aiida_workgraph.tasks import task_pool
from aiida_workgraph.task import Task
import time
from aiida_workgraph.collection import TaskCollection
from aiida_workgraph.utils.graph import (
Expand All @@ -11,10 +14,10 @@
link_creation_hook,
link_deletion_hook,
)
from typing import Any, Dict, List, Optional, Union

if USE_WIDGET:
from aiida_workgraph.widget import NodeGraphWidget
from typing import Any, Dict, List, Optional


class WorkGraph(node_graph.NodeGraph):
Expand Down Expand Up @@ -480,6 +483,20 @@ def _repr_mimebundle_(self, *args, **kwargs):
else:
return self._widget._ipython_display_(*args, **kwargs)

def add_task(
self, identifier: Union[str, callable], name: str = None, **kwargs
) -> Task:
"""Add a task to the workgraph."""
node = self.tasks.new(identifier, name, **kwargs)
return node

def add_link(
self, source: NodeSocket, target: NodeSocket
) -> node_graph.link.NodeLink:
"""Add a link between two nodes."""
link = self.links.new(source, target)
return link

def to_html(self, output: str = None, **kwargs):
"""Write a standalone html file to visualize the workgraph."""
self._widget.from_workgraph(self)
Expand Down
12 changes: 6 additions & 6 deletions docs/source/blog/aiida_python.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,10 @@
],
"source": [
"wg = WorkGraph(\"atomization_energy\")\n",
"pw_atom = wg.tasks.new(\"PythonJob\", function=emt, name=\"emt_atom\")\n",
"pw_mol = wg.tasks.new(\"PythonJob\", function=emt, name=\"emt_mol\")\n",
"pw_atom = wg.add_task(\"PythonJob\", function=emt, name=\"emt_atom\")\n",
"pw_mol = wg.add_task(\"PythonJob\", function=emt, name=\"emt_mol\")\n",
"# create the task to calculate the atomization energy\n",
"wg.tasks.new(\"PythonJob\", function=atomization_energy, name=\"atomization_energy\",\n",
"wg.add_task(\"PythonJob\", function=atomization_energy, name=\"atomization_energy\",\n",
" energy_atom=pw_atom.outputs[\"result\"],\n",
" energy_molecule=pw_mol.outputs[\"result\"])\n",
"wg.to_html()"
Expand Down Expand Up @@ -272,8 +272,8 @@
" return x*y + z\n",
"\n",
"wg = WorkGraph(\"PythonJob_parent_folder\")\n",
"wg.tasks.new(\"PythonJob\", function=add, name=\"add\")\n",
"wg.tasks.new(\"PythonJob\", function=multiply, name=\"multiply\",\n",
"wg.add_task(\"PythonJob\", function=add, name=\"add\")\n",
"wg.add_task(\"PythonJob\", function=multiply, name=\"multiply\",\n",
" parent_folder=wg.tasks[\"add\"].outputs[\"remote_folder\"],\n",
" )\n",
"\n",
Expand Down Expand Up @@ -324,7 +324,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.1.-1"
"version": "3.11.0"
},
"vscode": {
"interpreter": {
Expand Down
16 changes: 8 additions & 8 deletions docs/source/blog/workgraph_vs_workchain.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -480,10 +480,10 @@
"# create workgraph\n",
"wg = WorkGraph(\"add_multiply\")\n",
"# add tasks to workgraph\n",
"wg.tasks.new(add, name=\"add\")\n",
"wg.tasks.new(multiply, name=\"multiply\")\n",
"wg.add_task(add, name=\"add\")\n",
"wg.add_task(multiply, name=\"multiply\")\n",
"# link add and multiply tasks\n",
"wg.links.new(wg.tasks[\"add\"].outputs[0], wg.tasks[\"multiply\"].inputs[\"x\"])\n",
"wg.add_link(wg.tasks[\"add\"].outputs[0], wg.tasks[\"multiply\"].inputs[\"x\"])\n",
"\n",
"# Submit the workgraph\n",
"wg.submit(inputs = {\"add\": {\"x\": Int(1),\n",
Expand Down Expand Up @@ -1020,12 +1020,12 @@
"source": [
"wg2 = WorkGraph(name=\"add_multiply_add\")\n",
"# add tasks to workgraph\n",
"wg2.tasks.new(add, name=\"add\")\n",
"wg2.tasks.new(multiply, name=\"multiply\")\n",
"wg2.tasks.new(add, name=\"add2\")\n",
"wg2.add_task(add, name=\"add\")\n",
"wg2.add_task(multiply, name=\"multiply\")\n",
"wg2.add_task(add, name=\"add2\")\n",
"# link add and multiply tasks\n",
"wg2.links.new(wg2.tasks[\"add\"].outputs[0], wg2.tasks[\"multiply\"].inputs[\"x\"])\n",
"wg2.links.new(wg2.tasks[\"multiply\"].outputs[0], wg2.tasks[\"add2\"].inputs[\"x\"])\n",
"wg2.add_link(wg2.tasks[\"add\"].outputs[0], wg2.tasks[\"multiply\"].inputs[\"x\"])\n",
"wg2.add_link(wg2.tasks[\"multiply\"].outputs[0], wg2.tasks[\"add2\"].inputs[\"x\"])\n",
"\n",
"# Submit the workgraph\n",
"wg2.submit(inputs = {\"add\": {\"x\": Int(1),\n",
Expand Down
48 changes: 24 additions & 24 deletions docs/source/built-in/pythonjob.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@
" return x*y\n",
"\n",
"wg = WorkGraph(\"first_workflow\")\n",
"wg.tasks.new(add, name=\"add\")\n",
"wg.add_task(add, name=\"add\")\n",
"# we can also use a normal python function directly, but provide the \"PythonJob\" as the first argument\n",
"wg.tasks.new(\"PythonJob\", function=multiply, name=\"multiply\", x=wg.tasks[\"add\"].outputs[0])\n",
"wg.add_task(\"PythonJob\", function=multiply, name=\"multiply\", x=wg.tasks[\"add\"].outputs[0])\n",
"\n",
"# visualize the workgraph\n",
"wg.to_html()\n",
Expand Down Expand Up @@ -461,8 +461,8 @@
" return x*y + z\n",
"\n",
"wg = WorkGraph(\"PythonJob_parent_folder\")\n",
"wg.tasks.new(\"PythonJob\", function=add, name=\"add\")\n",
"wg.tasks.new(\"PythonJob\", function=multiply, name=\"multiply\",\n",
"wg.add_task(\"PythonJob\", function=add, name=\"add\")\n",
"wg.add_task(\"PythonJob\", function=multiply, name=\"multiply\",\n",
" parent_folder=wg.tasks[\"add\"].outputs[\"remote_folder\"],\n",
" )\n",
"\n",
Expand Down Expand Up @@ -553,7 +553,7 @@
"\n",
"\n",
"wg = WorkGraph(\"PythonJob_upload_files\")\n",
"wg.tasks.new(\"PythonJob\", function=add, name=\"add\")\n",
"wg.add_task(\"PythonJob\", function=add, name=\"add\")\n",
"\n",
"#------------------------- Submit the calculation -------------------\n",
"# we need use full path to the file\n",
Expand Down Expand Up @@ -654,9 +654,9 @@
],
"source": [
"wg = WorkGraph(\"atomization_energy\")\n",
"pw_atom = wg.tasks.new(\"PythonJob\", function=emt, name=\"emt_atom\")\n",
"pw_mol = wg.tasks.new(\"PythonJob\", function=emt, name=\"emt_mol\")\n",
"wg.tasks.new(\"PythonJob\", function=atomization_energy, name=\"atomization_energy\",\n",
"pw_atom = wg.add_task(\"PythonJob\", function=emt, name=\"emt_atom\")\n",
"pw_mol = wg.add_task(\"PythonJob\", function=emt, name=\"emt_mol\")\n",
"wg.add_task(\"PythonJob\", function=atomization_energy, name=\"atomization_energy\",\n",
" energy_atom=pw_atom.outputs[\"result\"],\n",
" energy_molecule=pw_mol.outputs[\"result\"])\n",
"wg.to_html()"
Expand Down Expand Up @@ -1065,8 +1065,8 @@
"\n",
"\n",
"wg = WorkGraph(\"PythonJob_shell_command\")\n",
"wg.tasks.new(\"PythonJob\", function=add, name=\"add\")\n",
"wg.tasks.new(\"PythonJob\", function=multiply, name=\"multiply\", x=wg.tasks[\"add\"].outputs[0])\n",
"wg.add_task(\"PythonJob\", function=add, name=\"add\")\n",
"wg.add_task(\"PythonJob\", function=multiply, name=\"multiply\", x=wg.tasks[\"add\"].outputs[0])\n",
"\n",
"# visualize the workgraph\n",
"wg.to_html()\n"
Expand Down Expand Up @@ -1351,14 +1351,14 @@
" return wg\n",
"\n",
"wg = WorkGraph()\n",
"wg.tasks.new(\"PythonJob\", function=add_multiply, name=\"add_multiply\")\n",
"wg.add_task(\"PythonJob\", function=add_multiply, name=\"add_multiply\")\n",
"\n",
"---------------------------------------------------------------------------\n",
"ValueError Traceback (most recent call last)\n",
"/tmp/ipykernel_3498848/1351840398.py in <cell line: 0>()\n",
" 8 \n",
" 9 wg = WorkGraph()\n",
"---> 10 wg.tasks.new(\"PythonJob\", function=add_multiply, name=\"add_multiply\")\n",
"---> 10 wg.add_task(\"PythonJob\", function=add_multiply, name=\"add_multiply\")\n",
"\n",
"~/repos/superstar54/aiida-workgraph/aiida_workgraph/collection.py in new(self, identifier, name, uuid, run_remotely, **kwargs)\n",
" 35 return super().new(identifier, name, uuid, **kwargs)\n",
Expand Down Expand Up @@ -1415,11 +1415,11 @@
"@task.graph_builder()\n",
"def add_multiply():\n",
" wg = WorkGraph()\n",
" wg.tasks.new(\"PythonJob\", function=add, name=\"add\")\n",
" wg.add_task(\"PythonJob\", function=add, name=\"add\")\n",
" return wg\n",
"\n",
"wg = WorkGraph()\n",
"wg.tasks.new(add_multiply, name=\"add_multiply\")"
"wg.add_task(add_multiply, name=\"add_multiply\")"
]
},
{
Expand All @@ -1435,7 +1435,7 @@
"In the context of an NSCF calculation, where data dependency exists on outputs from a SCF calculation, the workflow can be configured as follows:\n",
"\n",
"```python\n",
"nscf_task = wg.tasks.new(\"PythonJob\",\n",
"nscf_task = wg.add_task(\"PythonJob\",\n",
" function=pw_calculator,\n",
" name=\"nscf\",\n",
" parent_folder=scf_task.outputs[\"remote_folder\"],\n",
Expand All @@ -1454,15 +1454,15 @@
"For a Bader analysis requiring different charge density files:\n",
"\n",
"```python\n",
"bader_task = wg.tasks.new(\"PythonJob\",\n",
"bader_task = wg.add_task(\"PythonJob\",\n",
" function=bader_calculator,\n",
" name=\"bader\",\n",
" command=bader_command,\n",
" charge_density_folder=\"pp_valence_remote_folder\",\n",
" reference_charge_density_folder=\"pp_all_remote_folder\",\n",
")\n",
"wg.links.new(pp_valence.outputs[\"remote_folder\"], bader_task.inputs[\"copy_files\"])\n",
"wg.links.new(pp_all.outputs[\"remote_folder\"], bader_task.inputs[\"copy_files\"])\n",
"wg.add_link(pp_valence.outputs[\"remote_folder\"], bader_task.inputs[\"copy_files\"])\n",
"wg.add_link(pp_all.outputs[\"remote_folder\"], bader_task.inputs[\"copy_files\"])\n",
"```\n",
"\n",
"The `bader_calculator` function using specified charge density data:\n",
Expand Down Expand Up @@ -1564,7 +1564,7 @@
"then you can pass the value of `add_multiply.add` as an input to another task:\n",
"\n",
"```python\n",
"wg.tasks.new(\"PythonJob\",\n",
"wg.add_task(\"PythonJob\",\n",
" function=myfunc3,\n",
" name=\"myfunc3\",\n",
" x=wg.tasks[\"myfunc\"].outputs[\"add_multiply.add\"],\n",
Expand Down Expand Up @@ -1641,7 +1641,7 @@
" from aiida_workgraph import WorkGraph\n",
" wg = WorkGraph()\n",
" for key, atoms in scaled_atoms.items():\n",
" emt1 = wg.tasks.new(emt, name=f\"emt1_{key}\", atoms=atoms,\n",
" emt1 = wg.add_task(emt, name=f\"emt1_{key}\", atoms=atoms,\n",
" run_remotely=True)\n",
" emt1.set({\"computer\": \"localhost\"})\n",
" # save the output parameters to the context\n",
Expand Down Expand Up @@ -1672,18 +1672,18 @@
"atoms = bulk(\"Au\", cubic=True)\n",
"\n",
"wg = WorkGraph(\"pythonjob_eos_emt\")\n",
"scale_atoms_task = wg.tasks.new(\"PythonJob\",\n",
"scale_atoms_task = wg.add_task(\"PythonJob\",\n",
" function=generate_scaled_atoms,\n",
" name=\"scale_atoms\",\n",
" atoms=atoms,\n",
" )\n",
" # -------- calculate_enegies -----------\n",
"calculate_enegies_task = wg.tasks.new(calculate_enegies,\n",
"calculate_enegies_task = wg.add_task(calculate_enegies,\n",
" name=\"calculate_enegies\",\n",
" scaled_atoms=scale_atoms_task.outputs[\"scaled_atoms\"],\n",
" )\n",
" # -------- fit_eos -----------\n",
"wg.tasks.new(\"PythonJob\",\n",
"wg.add_task(\"PythonJob\",\n",
" function=fit_eos,\n",
" name=\"fit_eos\",\n",
" volumes=scale_atoms_task.outputs[\"volumes\"],\n",
Expand Down Expand Up @@ -2340,7 +2340,7 @@
" return x + y\n",
"\n",
"wg = WorkGraph(\"test_PythonJob_retrieve_files\")\n",
"wg.tasks.new(\"PythonJob\", function=add, name=\"add\")\n",
"wg.add_task(\"PythonJob\", function=add, name=\"add\")\n",
"# ------------------------- Submit the calculation -------------------\n",
"wg.submit(\n",
" inputs={\n",
Expand Down
14 changes: 7 additions & 7 deletions docs/source/built-in/shelljob.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
"from aiida_workgraph import WorkGraph\n",
"\n",
"wg = WorkGraph(name=\"test_shell_date\")\n",
"date_task = wg.tasks.new(\"ShellJob\", command=\"date\")\n",
"date_task = wg.add_task(\"ShellJob\", command=\"date\")\n",
"wg.submit(wait=True)\n",
"\n",
"# Print out the result:\n",
Expand Down Expand Up @@ -110,7 +110,7 @@
"source": [
"# Create a workgraph\n",
"wg = WorkGraph(name=\"test_shell_date_with_arguments\")\n",
"date_task = wg.tasks.new(\"ShellJob\", command=\"date\", arguments=['--iso-8601'])\n",
"date_task = wg.add_task(\"ShellJob\", command=\"date\", arguments=['--iso-8601'])\n",
"wg.submit(wait=True)\n",
"\n",
"# Print out the result:\n",
Expand Down Expand Up @@ -146,7 +146,7 @@
"from aiida.orm import SinglefileData\n",
"\n",
"wg = WorkGraph(name=\"test_shell_cat_with_file_arguments\")\n",
"cat_task = wg.tasks.new(\"ShellJob\", command=\"cat\",\n",
"cat_task = wg.add_task(\"ShellJob\", command=\"cat\",\n",
" arguments=[\"{file_a}\", \"{file_b}\"],\n",
" nodes={\n",
" 'file_a': SinglefileData.from_string('string a'),\n",
Expand Down Expand Up @@ -242,24 +242,24 @@
"# Create a workgraph\n",
"wg = WorkGraph(name=\"shell_add_mutiply_workflow\")\n",
"# echo x + y expression\n",
"echo_task_1 = wg.tasks.new(\"ShellJob\", name=\"echo_task_1\", command=\"echo\", arguments=[\"{x}\", \"+\", \"{y}\"],\n",
"echo_task_1 = wg.add_task(\"ShellJob\", name=\"echo_task_1\", command=\"echo\", arguments=[\"{x}\", \"+\", \"{y}\"],\n",
" nodes={'x': Int(2),\n",
" 'y': Int(3)},\n",
" )\n",
"# bc command to calculate the expression\n",
"bc_task_1 = wg.tasks.new(\"ShellJob\", name=\"bc_task_1\", command=\"bc\", arguments=[\"{expression}\"],\n",
"bc_task_1 = wg.add_task(\"ShellJob\", name=\"bc_task_1\", command=\"bc\", arguments=[\"{expression}\"],\n",
" parser=PickledData(parser),\n",
" nodes={'expression': echo_task_1.outputs[\"stdout\"]},\n",
" parser_outputs=[{\"name\": \"result\"}],\n",
" )\n",
"# echo result + z expression\n",
"echo_task_2 = wg.tasks.new(\"ShellJob\", name=\"echo_task_2\", command=\"echo\",\n",
"echo_task_2 = wg.add_task(\"ShellJob\", name=\"echo_task_2\", command=\"echo\",\n",
" arguments=[\"{result}\", \"*\", \"{z}\"],\n",
" nodes={'z': Int(4),\n",
" \"result\": bc_task_1.outputs[\"result\"]},\n",
" )\n",
"# bc command to calculate the expression\n",
"bc_task_2 = wg.tasks.new(\"ShellJob\", name=\"bc_task_2\", command=\"bc\", arguments=[\"{expression}\"],\n",
"bc_task_2 = wg.add_task(\"ShellJob\", name=\"bc_task_2\", command=\"bc\", arguments=[\"{expression}\"],\n",
" parser=PickledData(parser),\n",
" nodes={'expression': echo_task_2.outputs[\"stdout\"]},\n",
" parser_outputs=[{\"name\": \"result\"}],\n",
Expand Down
14 changes: 7 additions & 7 deletions docs/source/concept/task.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,9 @@
"source": [
"from aiida_workgraph import WorkGraph\n",
"wg = WorkGraph()\n",
"add_minus1 = wg.tasks.new(add_minus, name=\"add_minus1\")\n",
"multiply1 = wg.tasks.new(multiply, name=\"multiply1\")\n",
"wg.links.new(add_minus1.outputs[\"sum\"], multiply1.inputs[\"x\"])"
"add_minus1 = wg.add_task(add_minus, name=\"add_minus1\")\n",
"multiply1 = wg.add_task(multiply, name=\"multiply1\")\n",
"wg.add_link(add_minus1.outputs[\"sum\"], multiply1.inputs[\"x\"])"
]
},
{
Expand Down Expand Up @@ -197,7 +197,7 @@
"NormTask = build_task(norm)\n",
"\n",
"wg = WorkGraph()\n",
"norm_task = wg.tasks.new(NormTask, name=\"norm1\")\n",
"norm_task = wg.add_task(NormTask, name=\"norm1\")\n",
"norm_task.to_html()\n"
]
},
Expand Down Expand Up @@ -263,7 +263,7 @@
"from aiida.calculations.arithmetic.add import ArithmeticAddCalculation\n",
"\n",
"wg = WorkGraph()\n",
"add1 = wg.tasks.new(ArithmeticAddCalculation, name=\"add1\")"
"add1 = wg.add_task(ArithmeticAddCalculation, name=\"add1\")"
]
},
{
Expand Down Expand Up @@ -321,7 +321,7 @@
"source": [
"from aiida_workgraph import WorkGraph\n",
"wg = WorkGraph()\n",
"wg.tasks.new(MyAdd, name=\"add1\")\n"
"wg.add_task(MyAdd, name=\"add1\")\n"
]
},
{
Expand All @@ -330,7 +330,7 @@
"source": [
"One can also register the task in task pool, and then use its `identifer` directly.\n",
"```python\n",
"wg.tasks.new(\"MyAdd\", name=\"add1\")\n",
"wg.add_task(\"MyAdd\", name=\"add1\")\n",
"```"
]
}
Expand Down
Loading

0 comments on commit 50d552a

Please sign in to comment.