diff --git a/docs/tutorials/8-fireworks.md b/docs/tutorials/8-fireworks.md index 981f8c5d..f1ae2618 100644 --- a/docs/tutorials/8-fireworks.md +++ b/docs/tutorials/8-fireworks.md @@ -94,6 +94,16 @@ flow.update_config({"manager_config": {"_fworker": "fworker1"}}, name_filter="jo flow.update_config({"manager_config": {"_fworker": "fworker2"}}, name_filter="job2") ``` +NB: There are two ways to iterate over a `Flow`. The `iterflow` method iterates through a flow such that root nodes of the graph are always returned first. This has the benefit that the `job.output` references can always be resolved. +`Flow` also has an `__iter__` method, meaning you can write + +```py +for job_or_subflow in flow: + ... +``` + +to simply iterate through the `Flow.jobs` array. Note that `jobs` can also contain other flows. + ### Launching the Jobs As described above, convert the flow to a workflow via {obj}`flow_to_workflow` and add it to your launch pad. diff --git a/pyproject.toml b/pyproject.toml index a4946947..b1458520 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,6 +111,7 @@ show_missing = true exclude_lines = [ '^\s*@overload( |$)', '^\s*assert False(,|$)', + 'if TYPE_CHECKING:', 'if typing.TYPE_CHECKING:', ] diff --git a/src/jobflow/core/flow.py b/src/jobflow/core/flow.py index e1a620df..2daf69bf 100644 --- a/src/jobflow/core/flow.py +++ b/src/jobflow/core/flow.py @@ -2,21 +2,23 @@ from __future__ import annotations +import copy import logging -import typing import warnings +from typing import TYPE_CHECKING, Sequence from monty.json import MSONable +import jobflow from jobflow.core.reference import find_and_get_references from jobflow.utils import ValueEnum, contains_flow_or_job, suuid -if typing.TYPE_CHECKING: - from typing import Any, Callable +if TYPE_CHECKING: + from typing import Any, Callable, Iterator from networkx import DiGraph - import jobflow + from jobflow import Job __all__ = ["JobOrder", "Flow", "get_flow"] @@ -144,8 +146,94 @@ def __init__( self.add_jobs(jobs) self.output = output + def __len__(self) -> int: + """Get the number of jobs or subflows in the flow.""" + return len(self.jobs) + + def __getitem__(self, idx: int | slice) -> Flow | Job | tuple[Flow | Job, ...]: + """Get the job(s) or subflow(s) at the given index/slice.""" + return self.jobs[idx] + + def __setitem__( + self, idx: int | slice, value: Flow | Job | Sequence[Flow | Job] + ) -> None: + """Set the job(s) or subflow(s) at the given index/slice.""" + if ( + not isinstance(value, (Flow, jobflow.Job, tuple, list)) + or isinstance(value, (tuple, list)) + and not all(isinstance(v, (Flow, jobflow.Job)) for v in value) + ): + raise TypeError( + f"Flow can only contain Job or Flow objects, not {type(value).__name__}" + ) + jobs = list(self.jobs) + jobs[idx] = value # type: ignore[index, assignment] + self.jobs = tuple(jobs) + + def __iter__(self) -> Iterator[Flow | Job]: + """Iterate through the jobs in the flow.""" + return iter(self.jobs) + + def __contains__(self, item: Flow | Job) -> bool: + """Check if the flow contains a job or subflow.""" + return item in self.jobs + + def __add__(self, other: Job | Flow | Sequence[Flow | Job]) -> Flow: + """Add a job or subflow to the flow.""" + if not isinstance(other, (Flow, jobflow.Job, tuple, list)): + return NotImplemented + new_flow = self.__deepcopy__() + new_flow.add_jobs(other) + return new_flow + + def __sub__(self, other: Flow | Job) -> Flow: + """Remove a job or subflow from the flow.""" + if other not in self.jobs: + raise ValueError(f"{other!r} not found in flow") + new_flow = self.__deepcopy__() + new_flow.jobs = tuple([job for job in new_flow.jobs if job != other]) + return new_flow + + def __repr__(self, level=0, index=None) -> str: + """Get a string representation of the flow.""" + indent = " " * level + name, uuid = self.name, self.uuid + flow_index = f"{index}." if index is not None else "" + job_reprs = "\n".join( + f"{indent}{flow_index}{i}. " + f"{j.__repr__(level + 1, f'{flow_index}{i}') if isinstance(j, Flow) else j}" + for i, j in enumerate(self.jobs, 1) + ) + return f"Flow({name=}, {uuid=})\n{job_reprs}" + + def __eq__(self, other: object) -> bool: + """Check if the flow is equal to another flow.""" + if not isinstance(other, Flow): + return NotImplemented + return self.uuid == other.uuid + + def __hash__(self) -> int: + """Get the hash of the flow.""" + return hash(self.uuid) + + def __deepcopy__(self, memo: dict[int, Any] = None) -> Flow: + """Get a deep copy of the flow. + + Shallow copy doesn't make sense; jobs aren't allowed to belong to multiple flows + """ + kwds = self.as_dict() + for key in ("jobs", "@class", "@module", "@version"): + kwds.pop(key) + jobs = copy.deepcopy(self.jobs, memo) + new_flow = Flow(jobs=[], **kwds) + # reassign host + for job in jobs: + job.hosts = [new_flow.uuid] + new_flow.jobs = jobs + return new_flow + @property - def jobs(self) -> tuple[Flow | jobflow.Job, ...]: + def jobs(self) -> tuple[Flow | Job, ...]: """ Get the Jobs in the Flow. @@ -156,6 +244,20 @@ def jobs(self) -> tuple[Flow | jobflow.Job, ...]: """ return self._jobs + @jobs.setter + def jobs(self, jobs: Sequence[Flow | Job] | Job | Flow): + """ + Set the Jobs in the Flow. + + Parameters + ---------- + jobs + The list of Jobs/Flows of the Flow. + """ + if isinstance(jobs, (Flow, jobflow.Job)): + jobs = [jobs] + self._jobs = tuple(jobs) + @property def output(self) -> Any: """ @@ -666,7 +768,7 @@ def add_hosts_uuids( for j in self.jobs: j.add_hosts_uuids(hosts_uuids, prepend=prepend) - def add_jobs(self, jobs: list[Flow | jobflow.Job] | jobflow.Job | Flow): + def add_jobs(self, jobs: Job | Flow | Sequence[Flow | Job]) -> None: """ Add Jobs or Flows to the Flow. @@ -679,14 +781,14 @@ def add_jobs(self, jobs: list[Flow | jobflow.Job] | jobflow.Job | Flow): A list of Jobs and Flows. """ if not isinstance(jobs, (tuple, list)): - jobs = [jobs] + jobs = [jobs] # type: ignore[list-item] job_ids = set(self.all_uuids) hosts = [self.uuid, *self.hosts] for job in jobs: if job.host is not None and job.host != self.uuid: raise ValueError( - f"{job.__class__.__name__} {job.name} ({job.uuid}) already belongs " + f"{type(job).__name__} {job.name} ({job.uuid}) already belongs " f"to another flow." ) if job.uuid in job_ids: @@ -743,7 +845,7 @@ def remove_jobs(self, indices: int | list[int]): def get_flow( - flow: Flow | jobflow.Job | list[jobflow.Job], + flow: Flow | Job | list[jobflow.Job], ) -> Flow: """ Check dependencies and return flow object. diff --git a/src/jobflow/core/job.py b/src/jobflow/core/job.py index 8b9ee18c..cd9f19f6 100644 --- a/src/jobflow/core/job.py +++ b/src/jobflow/core/job.py @@ -13,7 +13,7 @@ from jobflow.utils.uuid import suuid if typing.TYPE_CHECKING: - from typing import Any, Callable, Hashable + from typing import Any, Callable, Hashable, Sequence from networkx import DiGraph from pydantic import BaseModel @@ -194,10 +194,7 @@ def get_job(*args, **kwargs) -> Job: args = args[1:] return Job( - function=f, - function_args=args, - function_kwargs=kwargs, - **job_kwargs, + function=f, function_args=args, function_kwargs=kwargs, **job_kwargs ) get_job.original = func @@ -366,6 +363,49 @@ def __init__( f"inputs to your Job." ) + def __repr__(self): + """Get a string representation of the job.""" + name, uuid = self.name, self.uuid + return f"Job({name=}, {uuid=})" + + def __contains__(self, item: Hashable) -> bool: + """ + Check if the job contains a reference to a given UUID. + + Parameters + ---------- + item + A UUID. + + Returns + ------- + bool + Whether the job contains a reference to the UUID. + """ + return item in self.input_uuids + + def __eq__(self, other: object) -> bool: + """ + Check if two jobs are equal. + + Parameters + ---------- + other + Another job. + + Returns + ------- + bool + Whether the jobs are equal. + """ + if not isinstance(other, Job): + return NotImplemented + return self.__dict__ == other.__dict__ + + def __hash__(self) -> int: + """Get the hash of the job.""" + return hash(self.uuid) + @property def input_references(self) -> tuple[jobflow.OutputReference, ...]: """ @@ -474,7 +514,7 @@ def host(self): """ return self.hosts[0] if self.hosts else None - def set_uuid(self, uuid: str): + def set_uuid(self, uuid: str) -> None: """ Set the UUID of the job. @@ -1079,7 +1119,7 @@ def __setattr__(self, key, value): else: super().__setattr__(key, value) - def add_hosts_uuids(self, hosts_uuids: str | list[str], prepend: bool = False): + def add_hosts_uuids(self, hosts_uuids: str | Sequence[str], prepend: bool = False): """ Add a list of UUIDs to the internal list of hosts. @@ -1095,7 +1135,7 @@ def add_hosts_uuids(self, hosts_uuids: str | list[str], prepend: bool = False): Insert the UUIDs at the beginning of the list rather than extending it. """ if not isinstance(hosts_uuids, (list, tuple)): - hosts_uuids = [hosts_uuids] + hosts_uuids = [hosts_uuids] # type: ignore if prepend: self.hosts[0:0] = hosts_uuids else: diff --git a/src/jobflow/core/reference.py b/src/jobflow/core/reference.py index 1f54fcb2..bb71d145 100644 --- a/src/jobflow/core/reference.py +++ b/src/jobflow/core/reference.py @@ -293,7 +293,7 @@ def as_dict(self): schema_dict = MontyEncoder().default(schema) if schema is not None else None data = { "@module": self.__class__.__module__, - "@class": self.__class__.__name__, + "@class": type(self).__name__, "@version": None, "uuid": self.uuid, "attributes": self.attributes, diff --git a/tests/core/test_flow.py b/tests/core/test_flow.py index f454d662..2216e753 100644 --- a/tests/core/test_flow.py +++ b/tests/core/test_flow.py @@ -9,7 +9,7 @@ def div(a, b=2): return a / b -def get_test_job(): +def get_test_job(*args): from jobflow import Job return Job(add, function_args=(1, 2)) @@ -456,7 +456,11 @@ def test_dag_validation(): job2 = Job(add, function_args=(job1.output, 2)) job1.function_args = (job2.output, 2) flow = Flow(jobs=[job1, job2]) - with pytest.raises(ValueError): + with pytest.raises( + ValueError, + match="Job connectivity contains cycles therefore job execution order " + "cannot be determined", + ): next(flow.iterflow()) @@ -504,7 +508,6 @@ def test_update_kwargs(): def test_update_maker_kwargs(): - # test no filter flow = get_maker_flow() flow.update_maker_kwargs({"b": 10}) @@ -687,6 +690,11 @@ def test_add_jobs(): with pytest.raises(ValueError): flow1.add_jobs(flow3) + # test passing single job to @jobs setter + flow1.jobs = add_job1 + assert len(flow1.jobs) == 1 + assert flow1.jobs[0] is add_job1 + def test_remove_jobs(): from jobflow.core.flow import Flow @@ -831,3 +839,162 @@ def test_update_config(): assert flow.jobs[0].config.resolve_references assert flow.jobs[1].config.manager_config == {"a": "b"} assert flow.jobs[1].config.resolve_references + + +def test_flow_magic_methods(): + from jobflow import Flow + + # prepare test jobs and flows + job1, job2, job3, job4, job5, job6 = map(get_test_job, range(6)) + + flow1 = Flow([job1]) + flow2 = Flow([job2, job3]) + + # test __len__ + assert len(flow1) == 1 + assert len(flow2) == 2 + + # test __getitem__ + assert flow2[0] == job2 + assert flow2[1] == job3 + + # test __setitem__ + flow2[0] = job4 + assert flow2[0] == job4 + + # test __iter__ + for job in flow2: + assert job in [job4, job3] + + # test __contains__ + assert job1 in flow1 + assert job4 in flow2 + assert job3 in flow2 + + # test __add__ + flow3 = flow1 + job5 + assert len(flow3) == 2 + assert job5 in flow3 + + # test __sub__ + flow4 = flow3 - job5 + assert len(flow4) == 1 == len(flow1) + assert job5 not in flow4 + + # test __eq__ and __hash__ + assert flow1 == flow1 + assert flow1 != flow2 + assert hash(flow1) != hash(flow2) + + # test __deepcopy__ + flow_copy = flow1.__deepcopy__() + assert flow_copy == flow1 + assert id(flow_copy) != id(flow1) + + # test __getitem__ with out of range index + with pytest.raises(IndexError): + _ = flow1[10] + + # test __setitem__ with out of range index + with pytest.raises(IndexError): + flow1[10] = job4 + + # test __contains__ with job not in flow + assert job5 not in flow1 + assert flow2 not in flow1 + + # test __add__ with non-job item + with pytest.raises(TypeError): + _ = job6 + "not a job" + + # test __sub__ with non-job item + with pytest.raises(TypeError): + _ = job6 - "not a job" + + # test __sub__ with job not in flow + with pytest.raises( + ValueError, match=r"Job\(name='add', uuid='.+'\) not found in flow" + ): + _ = flow1 - job5 + + # test __eq__ with non-flow item + assert flow1 != "not a flow" + + +def test_flow_magic_methods_edge_cases(): + from jobflow import Flow + + # prepare test jobs and flows + job1, job2, job3, job4, job5, job6 = map(get_test_job, range(6)) + Flow([job6]) + empty_flow = Flow([]) + flow1 = Flow([job1, job2, job3, job4]) + + # test negative indexing with __getitem__ and __setitem__ + assert flow1[-1] == job4 + flow1[-1] = job5 + assert flow1[-1] == job5 + + # test slicing with __getitem__ and __setitem__ + assert flow1[1:3] == (job2, job3) + flow1[1] = job4 # test single item + assert flow1[1] == job4 + flow1[1:3] = (job4, job5) # test multiple items with slicing + assert flow1[1:3] == (job4, job5) + + # test __add__ with bad type + assert flow1.__add__("string") == NotImplemented + + for val in (None, 1.0, 1, "1", [1], (1,), {1: 1}): + type_name = type(val).__name__ + with pytest.raises( + TypeError, + match=f"Flow can only contain Job or Flow objects, not {type_name}", + ): + flow1[1:3] = val + + # adding an empty flow still increases len by 1 + assert len(flow1 + empty_flow) == len(flow1) + 1 + + # test __add__ and __sub__ with job already in the flow + with pytest.raises( + ValueError, match="jobs array contains multiple jobs/flows with the same uuid" + ): + _ = flow1 + job1 + + with pytest.raises(ValueError, match="Job .+ already belongs to another flow"): + _ = flow1 + job6 + + +def test_flow_repr(): + from jobflow import Flow + + # prepare jobs and flows + job1, job2, job3, job4, job5, job6, job7 = map(get_test_job, range(7)) + + flow1 = Flow([job1]) + flow2 = Flow([job2, job3]) + sub_flow1 = Flow([job6, job7]) + flow3 = Flow([job4, job5, sub_flow1]) + flow4 = Flow([flow1, flow2, flow3]) + + flow_repr = repr(flow4).splitlines() + + lines = ( + "Flow(name='Flow', uuid='", + "1. Flow(name='Flow', uuid='", + " 1.1. Job(name='add', uuid='", + "2. Flow(name='Flow', uuid='", + " 2.1. Job(name='add', uuid='", + " 2.2. Job(name='add', uuid='", + "3. Flow(name='Flow', uuid='", + " 3.1. Job(name='add', uuid='", + " 3.2. Job(name='add', uuid='", + " 3.3. Flow(name='Flow', uuid='", + " 3.3.1. Job(name='add', uuid='", + " 3.3.2. Job(name='add', uuid='", + ) + + assert len(lines) == len(flow_repr) + for expected, line in zip(lines, flow_repr): + assert line.startswith(expected), f"{line=} doesn't start with {expected=}" diff --git a/tests/core/test_job.py b/tests/core/test_job.py index 3273dc75..8999ccc9 100644 --- a/tests/core/test_job.py +++ b/tests/core/test_job.py @@ -545,7 +545,7 @@ class MySchema(BaseModel): response = Response.from_job_returns( {"number": "5", "name": "Ian"}, output_schema=MySchema ) - assert response.output.__class__.__name__ == "MySchema" + assert type(response.output).__name__ == "MySchema" assert response.output.number == 5 assert response.output.name == "Ian" @@ -886,22 +886,22 @@ def add_schema_replace(a, b): test_job = add_schema(5, 6) response = test_job.run(memory_jobstore) - assert response.output.__class__.__name__ == "AddSchema" + assert type(response.output).__name__ == "AddSchema" assert response.output.result == 11 test_job = add_schema_dict(5, 6) response = test_job.run(memory_jobstore) - assert response.output.__class__.__name__ == "AddSchema" + assert type(response.output).__name__ == "AddSchema" assert response.output.result == 11 test_job = add_schema_response(5, 6) response = test_job.run(memory_jobstore) - assert response.output.__class__.__name__ == "AddSchema" + assert type(response.output).__name__ == "AddSchema" assert response.output.result == 11 test_job = add_schema_response_dict(5, 6) response = test_job.run(memory_jobstore) - assert response.output.__class__.__name__ == "AddSchema" + assert type(response.output).__name__ == "AddSchema" assert response.output.result == 11 test_job = add_schema_replace(5, 6) @@ -922,7 +922,7 @@ def add_schema_replace(a, b): def test_store_inputs(memory_jobstore): - from jobflow.core.job import OutputReference, store_inputs + from jobflow.core.job import Job, OutputReference, store_inputs test_job = store_inputs(1) test_job.run(memory_jobstore) @@ -935,6 +935,12 @@ def test_store_inputs(memory_jobstore): output = memory_jobstore.query_one({"uuid": test_job.uuid}, ["output"])["output"] assert OutputReference.from_dict(output) == ref + # test error msg for multiple stores + with pytest.raises( + ValueError, match="Cannot select True for multiple additional stores" + ): + _ = Job(function=sum, function_args=([1, 2],), store1=True, store2=True) + def test_pass_manager_config(): from jobflow import Flow, Job @@ -1261,3 +1267,31 @@ def use_maker(maker): response = test_job.run(memory_jobstore) assert response.replace.jobs[0].config == new_config assert response.replace.jobs[0].config_updates[0]["config"] == new_config + + +def test_job_magic_methods(): + from jobflow import Job + + # prepare test jobs + job1 = Job(function=sum, function_args=([1, 2],)) + job2 = Job(function=dict, function_args=((("a", 1), ("b", 2)),)) + job3 = Job(function=sum, function_args=([1, 2],)) + + # test __repr__ + assert repr(job1) == f"Job(name='sum', uuid='{job1.uuid}')" + assert repr(job2) == f"Job(name='dict', uuid='{job2.uuid}')" + assert repr(job3) == f"Job(name='sum', uuid='{job3.uuid}')" + assert repr(job1) != repr(job3) + + # test __contains__ (using some fake UUID) + # initial job.input_references is empty so can't test positive case + assert "fake-uuid" not in job1 + + # test __eq__ + assert job1 == job1 + assert job2 == job2 + assert job1 != job2 + assert job1 != job3 # Different UUIDs + + # test __hash__ + assert hash(job1) != hash(job2) != hash(job3) diff --git a/tests/managers/test_local.py b/tests/managers/test_local.py index 5658dbb5..77ba9a93 100644 --- a/tests/managers/test_local.py +++ b/tests/managers/test_local.py @@ -218,7 +218,7 @@ def test_replace_flow_nested(memory_jobstore, clean_dir, replace_flow_nested): assert len(responses[uuid1]) == 2 assert responses[uuid1][1].output == 11 assert responses[uuid1][1].replace is not None - assert responses[uuid1][2].output["first"].__class__.__name__ == "OutputReference" + assert type(responses[uuid1][2].output["first"]).__name__ == "OutputReference" assert responses[uuid2][1].output == "12345_end" # check store has the activity output