diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c8fa10a8..34373523 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,63 +1,63 @@
 default_language_version:
   python: python3
-exclude: '^src/atomate2/vasp/schemas/calc_types/'
+exclude: "^src/atomate2/vasp/schemas/calc_types/"
 repos:
-- repo: https://github.com/charliermarsh/ruff-pre-commit
-  rev: v0.0.250
-  hooks:
-  - id: ruff
-    args: [--fix]
-- repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v4.4.0
-  hooks:
-  - id: check-yaml
-  - id: fix-encoding-pragma
-    args: [--remove]
-  - id: end-of-file-fixer
-  - id: trailing-whitespace
-- repo: https://github.com/psf/black
-  rev: 22.12.0
-  hooks:
-  - id: black
-- repo: https://github.com/asottile/blacken-docs
-  rev: v1.12.1
-  hooks:
-  - id: blacken-docs
-    additional_dependencies: [black]
-    exclude: README.md
-- repo: https://github.com/pycqa/flake8
-  rev: 6.0.0
-  hooks:
-  - id: flake8
-    entry: pflake8
-    files: ^src/
-    additional_dependencies:
-    - pyproject-flake8==6.0.0
-    - flake8-bugbear==22.12.6
-    - flake8-typing-imports==1.14.0
-    - flake8-docstrings==1.6.0
-    - flake8-rst-docstrings==0.3.0
-    - flake8-rst==0.8.0
-- repo: https://github.com/pre-commit/pygrep-hooks
-  rev: v1.10.0
-  hooks:
-  - id: python-use-type-annotations
-  - id: rst-backticks
-  - id: rst-directive-colons
-  - id: rst-inline-touching-normal
-- repo: https://github.com/pre-commit/mirrors-mypy
-  rev: v0.991
-  hooks:
-  - id: mypy
-    files: ^src/
-    additional_dependencies:
-    - tokenize-rt==4.1.0
-    - types-pkg_resources==0.1.2
-    - types-paramiko
-- repo: https://github.com/codespell-project/codespell
-  rev: v2.2.2
-  hooks:
-  - id: codespell
-    stages: [commit, commit-msg]
-    args: [--ignore-words-list, 'titel,statics,ba,nd,te']
-    types_or: [python, rst, markdown]
+  - repo: https://github.com/charliermarsh/ruff-pre-commit
+    rev: v0.0.282
+    hooks:
+      - id: ruff
+        args: [--fix]
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v4.4.0
+    hooks:
+      - id: check-yaml
+      - id: fix-encoding-pragma
+        args: [--remove]
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
+  - repo: https://github.com/psf/black
+    rev: 22.12.0
+    hooks:
+      - id: black
+  - repo: https://github.com/asottile/blacken-docs
+    rev: v1.12.1
+    hooks:
+      - id: blacken-docs
+        additional_dependencies: [black]
+        exclude: README.md
+  - repo: https://github.com/pycqa/flake8
+    rev: 6.0.0
+    hooks:
+      - id: flake8
+        entry: pflake8
+        files: ^src/
+        additional_dependencies:
+          - pyproject-flake8==6.0.0
+          - flake8-bugbear==22.12.6
+          - flake8-typing-imports==1.14.0
+          - flake8-docstrings==1.6.0
+          - flake8-rst-docstrings==0.3.0
+          - flake8-rst==0.8.0
+  - repo: https://github.com/pre-commit/pygrep-hooks
+    rev: v1.10.0
+    hooks:
+      - id: python-use-type-annotations
+      - id: rst-backticks
+      - id: rst-directive-colons
+      - id: rst-inline-touching-normal
+  - repo: https://github.com/pre-commit/mirrors-mypy
+    rev: v0.991
+    hooks:
+      - id: mypy
+        files: ^src/
+        additional_dependencies:
+          - tokenize-rt==4.1.0
+          - types-pkg_resources==0.1.2
+          - types-paramiko
+  - repo: https://github.com/codespell-project/codespell
+    rev: v2.2.2
+    hooks:
+      - id: codespell
+        stages: [commit, commit-msg]
+        args: [--ignore-words-list, "titel,statics,ba,nd,te"]
+        types_or: [python, rst, markdown]
diff --git a/examples/schema.py b/examples/schema.py
index a540527c..b0da14e3 100644
--- a/examples/schema.py
+++ b/examples/schema.py
@@ -22,5 +22,5 @@ def compute(a: float, b: float):
 
 print(compute_job.output.total)
 # OutputReference(8ff2a94e-7633-42e9-8aa0-8479801347d5, .total)
-compute_job.output.not_in_schema
+_ = compute_job.output.not_in_schema
 # AttributeError: ComputeSchema does not have property 'not_in_schema'.
diff --git a/pyproject.toml b/pyproject.toml
index 485498a4..a4946947 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,54 +11,54 @@ license = { text = "modified BSD" }
 authors = [{ name = "Alex Ganose", email = "alexganose@gmail.com" }]
 dynamic = ["version"]
 classifiers = [
-    "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.8",
-    "Programming Language :: Python :: 3.9",
-    "Programming Language :: Python :: 3.10",
     "Development Status :: 2 - Pre-Alpha",
+    "Intended Audience :: Information Technology",
     "Intended Audience :: Science/Research",
     "Intended Audience :: System Administrators",
-    "Intended Audience :: Information Technology",
     "Operating System :: OS Independent",
-    "Topic :: Other/Nonlisted Topic",
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.8",
+    "Programming Language :: Python :: 3.9",
     "Topic :: Database :: Front-Ends",
+    "Topic :: Other/Nonlisted Topic",
     "Topic :: Scientific/Engineering",
 ]
 requires-python = ">=3.8"
 dependencies = [
+    "PyYAML",
+    "maggma>=0.38.1",
     "monty>=2021.5.9",
-    "pydash",
     "networkx",
-    "maggma>=0.38.1",
     "pydantic",
-    "PyYAML",
+    "pydash",
 ]
 
 [project.optional-dependencies]
 docs = [
-    "sphinx==7.1.1",
+    "autodoc_pydantic==1.9.0",
     "furo==2023.7.26",
-    "myst_parser==2.0.0",
     "ipython==8.14.0",
+    "myst_parser==2.0.0",
     "nbsphinx==0.9.2",
-    "autodoc_pydantic==1.9.0",
     "sphinx-copybutton==0.5.2",
+    "sphinx==7.1.1",
 ]
 dev = ["pre-commit>=2.12.1"]
-tests = ["pytest==7.4.0", "pytest-cov==4.1.0"]
+tests = ["pytest-cov==4.1.0", "pytest==7.4.0"]
 vis = ["matplotlib", "pydot"]
 fireworks = ["FireWorks"]
 strict = [
+    "FireWorks==2.0.3",
+    "PyYAML==6.0.1",
+    "maggma==0.51.24",
+    "matplotlib==3.7.2",
     "monty==2023.5.8",
+    "moto==4.1.14",
     "networkx==3.1",
-    "pydash==7.0.6",
-    "maggma==0.51.24",
     "pydantic==1.10.9",
-    "PyYAML==6.0.1",
-    "FireWorks==2.0.3",
-    "matplotlib==3.7.2",
+    "pydash==7.0.6",
     "pydot==1.4.2",
-    "moto==4.1.14",
     "typing-extensions==4.7.1",
 ]
@@ -91,9 +91,9 @@ no_strict_optional = true
 
 [tool.pytest.ini_options]
 filterwarnings = [
     "ignore:.*POTCAR.*:UserWarning",
-    "ignore:.*magmom.*:UserWarning",
-    "ignore:.*is not gzipped.*:UserWarning",
     "ignore:.*input structure.*:UserWarning",
+    "ignore:.*is not gzipped.*:UserWarning",
+    "ignore:.*magmom.*:UserWarning",
     "ignore::DeprecationWarning",
 ]
@@ -109,9 +109,9 @@ source = ["src/"]
 skip_covered = true
 show_missing = true
 exclude_lines = [
+    '^\s*@overload( |$)',
     '^\s*assert False(,|$)',
     'if typing.TYPE_CHECKING:',
-    '^\s*@overload( |$)',
 ]
 
 [tool.ruff]
@@ -134,6 +134,7 @@ select = [
     "W",   # pycodestyle
     "YTT", # flake8-2020
 ]
+ignore = ["B028", "PLW0603", "RUF013"]
 pydocstyle.convention = "numpy"
 isort.known-first-party = ["jobflow"]
diff --git a/src/jobflow/core/flow.py b/src/jobflow/core/flow.py
index 2be987a3..e1a620df 100644
--- a/src/jobflow/core/flow.py
+++ b/src/jobflow/core/flow.py
@@ -121,11 +121,11 @@ class Flow(MSONable):
     def __init__(
         self,
         jobs: list[Flow | jobflow.Job] | jobflow.Job | Flow,
-        output: Any | None = None,
+        output: Any = None,
         name: str = "Flow",
         order: JobOrder = JobOrder.AUTO,
         uuid: str = None,
-        hosts: list[str] | None = None,
+        hosts: list[str] = None,
     ):
         from jobflow.core.job import Job
@@ -336,8 +336,8 @@ def iterflow(self):
     def update_kwargs(
         self,
         update: dict[str, Any],
-        name_filter: str | None = None,
-        function_filter: Callable | None = None,
+        name_filter: str = None,
+        function_filter: Callable = None,
         dict_mod: bool = False,
     ):
         """
@@ -392,8 +392,8 @@ def update_kwargs(
     def update_maker_kwargs(
         self,
         update: dict[str, Any],
-        name_filter: str | None = None,
-        class_filter: type[jobflow.Maker] | None = None,
+        name_filter: str = None,
+        class_filter: type[jobflow.Maker] = None,
         nested: bool = True,
         dict_mod: bool = False,
     ):
@@ -511,8 +511,8 @@ def append_name(self, append_str: str, prepend: bool = False):
     def update_metadata(
         self,
         update: dict[str, Any],
-        name_filter: str | None = None,
-        function_filter: Callable | None = None,
+        name_filter: str = None,
+        function_filter: Callable = None,
         dict_mod: bool = False,
         dynamic: bool = True,
     ):
@@ -634,7 +634,7 @@ def update_config(
         )
 
     def add_hosts_uuids(
-        self, hosts_uuids: str | list[str] | None = None, prepend: bool = False
+        self, hosts_uuids: str | list[str] = None, prepend: bool = False
     ):
         """
         Add a list of UUIDs to the internal list of hosts.
diff --git a/src/jobflow/core/job.py b/src/jobflow/core/job.py
index b3ceda00..8b9ee18c 100644
--- a/src/jobflow/core/job.py
+++ b/src/jobflow/core/job.py
@@ -66,7 +66,7 @@ class JobConfig(MSONable):
     response_manager_config: dict = field(default_factory=dict)
 
 
-def job(method: Callable | None = None, **job_kwargs):
+def job(method: Callable = None, **job_kwargs):
     """
     Wrap a function to produce a :obj:`Job`.
 
@@ -179,7 +179,6 @@ def decorator(func):
         @wraps(func)
         def get_job(*args, **kwargs) -> Job:
-
             f = func
             if len(args) > 0:
                 # see if the first argument has a function with the same name as
@@ -309,15 +308,15 @@ def __init__(
         function: Callable,
         function_args: tuple[Any, ...] = None,
         function_kwargs: dict[str, Any] = None,
-        output_schema: type[BaseModel] | None = None,
+        output_schema: type[BaseModel] = None,
         uuid: str = None,
         index: int = 1,
-        name: str | None = None,
+        name: str = None,
         metadata: dict[str, Any] = None,
         config: JobConfig = None,
-        hosts: list[str] | None = None,
-        metadata_updates: list[dict[str, Any]] | None = None,
-        config_updates: list[dict[str, Any]] | None = None,
+        hosts: list[str] = None,
+        metadata_updates: list[dict[str, Any]] = None,
+        config_updates: list[dict[str, Any]] = None,
         **kwargs,
     ):
         from copy import deepcopy
@@ -574,7 +573,6 @@ def run(self, store: jobflow.JobStore) -> Response:
             passed_config = None
 
         if passed_config:
-
             if response.addition is not None:
                 pass_manager_config(response.addition, passed_config)
 
@@ -664,8 +662,8 @@ def resolve_args(
     def update_kwargs(
         self,
         update: dict[str, Any],
-        name_filter: str | None = None,
-        function_filter: Callable | None = None,
+        name_filter: str = None,
+        function_filter: Callable = None,
         dict_mod: bool = False,
     ):
         """
@@ -720,8 +718,8 @@ def update_kwargs(
     def update_maker_kwargs(
         self,
         update: dict[str, Any],
-        name_filter: str | None = None,
-        class_filter: type[jobflow.Maker] | None = None,
+        name_filter: str = None,
+        class_filter: type[jobflow.Maker] = None,
         nested: bool = True,
         dict_mod: bool = False,
     ):
@@ -853,8 +851,8 @@ def append_name(self, append_str: str, prepend: bool = False):
     def update_metadata(
         self,
         update: dict[str, Any],
-        name_filter: str | None = None,
-        function_filter: Callable | None = None,
+        name_filter: str = None,
+        function_filter: Callable = None,
         dict_mod: bool = False,
         dynamic: bool = True,
     ):
@@ -1134,11 +1132,11 @@ class Response(typing.Generic[T]):
         Stop executing all remaining jobs.
""" - output: T | None = None - detour: jobflow.Flow | Job | list[Job] | list[jobflow.Flow] | None = None - addition: jobflow.Flow | Job | list[Job] | list[jobflow.Flow] | None = None - replace: jobflow.Flow | Job | list[Job] | list[jobflow.Flow] | None = None - stored_data: dict[Hashable, Any] | None = None + output: T = None + detour: jobflow.Flow | Job | list[Job] | list[jobflow.Flow] = None + addition: jobflow.Flow | Job | list[Job] | list[jobflow.Flow] = None + replace: jobflow.Flow | Job | list[Job] | list[jobflow.Flow] = None + stored_data: dict[Hashable, Any] = None stop_children: bool = False stop_jobflow: bool = False @@ -1146,7 +1144,7 @@ class Response(typing.Generic[T]): def from_job_returns( cls, job_returns: Any | None, - output_schema: type[BaseModel] | None = None, + output_schema: type[BaseModel] = None, ) -> Response: """ Generate a :obj:`Response` from the outputs of a :obj:`Job`. @@ -1323,7 +1321,6 @@ def pass_manager_config( all_jobs: list[Job] = [] def get_jobs(arg): - if isinstance(arg, Job): all_jobs.append(arg) elif isinstance(arg, (list, tuple)): diff --git a/src/jobflow/core/maker.py b/src/jobflow/core/maker.py index 2b99689c..19ba9c39 100644 --- a/src/jobflow/core/maker.py +++ b/src/jobflow/core/maker.py @@ -132,8 +132,8 @@ def name(self): def update_kwargs( self, update: dict[str, Any], - name_filter: str | None = None, - class_filter: type[Maker] | None = None, + name_filter: str = None, + class_filter: type[Maker] = None, nested: bool = True, dict_mod: bool = False, ): @@ -234,8 +234,8 @@ def _update_kwargs_func(maker: Maker): def recursive_call( obj: Maker, func: Callable[[Maker], Maker], - name_filter: str | None = None, - class_filter: type[Maker] | None = None, + name_filter: str = None, + class_filter: type[Maker] = None, nested: bool = True, ): """Recursively call a function on all Maker objects in the object. diff --git a/src/jobflow/core/reference.py b/src/jobflow/core/reference.py index ea5d6e4a..1f54fcb2 100644 --- a/src/jobflow/core/reference.py +++ b/src/jobflow/core/reference.py @@ -13,7 +13,6 @@ from jobflow.utils.enum import ValueEnum if typing.TYPE_CHECKING: - import jobflow __all__ = [ @@ -95,7 +94,7 @@ def __init__( self, uuid: str, attributes: tuple[tuple[str, Any], ...] 
-        output_schema: type[BaseModel] | None = None,
+        output_schema: type[BaseModel] = None,
     ):
         super().__init__()
         self.uuid = uuid
@@ -111,7 +110,7 @@ def __init__(
     def resolve(
         self,
         store: jobflow.JobStore | None,
-        cache: dict[str, Any] | None = None,
+        cache: dict[str, Any] = None,
         on_missing: OnMissing = OnMissing.ERROR,
     ) -> Any:
         """
@@ -264,7 +263,7 @@ def __repr__(self) -> str:
         else:
             attribute_str = ""
 
-        return f"OutputReference({str(self.uuid)}{attribute_str})"
+        return f"OutputReference({self.uuid!s}{attribute_str})"
 
     def __hash__(self) -> int:
         """Return a hash of the reference."""
@@ -277,10 +276,8 @@ def __eq__(self, other: Any) -> bool:
             self.uuid == other.uuid
             and len(self.attributes) == len(other.attributes)
             and all(
-                [
-                    a[0] == b[0] and a[1] == b[1]
-                    for a, b in zip(self.attributes, other.attributes)
-                ]
+                a[0] == b[0] and a[1] == b[1]
+                for a, b in zip(self.attributes, other.attributes)
             )
         )
         return False
@@ -288,9 +285,7 @@ def __eq__(self, other: Any) -> bool:
     @property
     def attributes_formatted(self):
         """Get a formatted description of the attributes."""
-        return [
-            f".{x[1]}" if x[0] == "a" else f"[{repr(x[1])}]" for x in self.attributes
-        ]
+        return [f".{x[1]}" if x[0] == "a" else f"[{x[1]!r}]" for x in self.attributes]
 
     def as_dict(self):
         """Serialize the reference as a dict."""
@@ -310,7 +305,7 @@ def as_dict(self):
 def resolve_references(
     references: Sequence[OutputReference],
     store: jobflow.JobStore,
-    cache: dict[str, Any] | None = None,
+    cache: dict[str, Any] = None,
     on_missing: OnMissing = OnMissing.ERROR,
 ) -> dict[OutputReference, Any]:
     """
@@ -403,7 +398,7 @@ def find_and_get_references(arg: Any) -> tuple[OutputReference, ...]:
 def find_and_resolve_references(
     arg: Any,
     store: jobflow.JobStore,
-    cache: dict[str, Any] | None = None,
+    cache: dict[str, Any] = None,
     on_missing: OnMissing = OnMissing.ERROR,
 ) -> Any:
     """
diff --git a/src/jobflow/core/state.py b/src/jobflow/core/state.py
index 38e73673..f67dde40 100644
--- a/src/jobflow/core/state.py
+++ b/src/jobflow/core/state.py
@@ -26,8 +26,8 @@ class State:
     """State of the current job and store."""
 
-    job: jobflow.Job | None = None
-    store: jobflow.JobStore | None = None
+    job: jobflow.Job = None
+    store: jobflow.JobStore = None
 
     def reset(self):
         """Reset the current state."""
diff --git a/src/jobflow/core/store.py b/src/jobflow/core/store.py
index cb467876..b665e3cd 100644
--- a/src/jobflow/core/store.py
+++ b/src/jobflow/core/store.py
@@ -51,7 +51,7 @@ class JobStore(Store):
     def __init__(
         self,
         docs_store: Store,
-        additional_stores: dict[str, Store] | None = None,
+        additional_stores: dict[str, Store] = None,
         save: save_type = None,
         load: load_type = False,
     ):
@@ -117,7 +117,7 @@ def close(self):
         for additional_store in self.additional_stores.values():
             additional_store.close()
 
-    def count(self, criteria: dict | None = None) -> int:
+    def count(self, criteria: dict = None) -> int:
         """
         Count the number of documents matching the query criteria.
@@ -135,9 +135,9 @@ def count(self, criteria: dict | None = None) -> int:
 
     def query(
         self,
-        criteria: dict | None = None,
-        properties: dict | list | None = None,
-        sort: dict[str, Sort | int] | None = None,
+        criteria: dict = None,
+        properties: dict | list = None,
+        sort: dict[str, Sort | int] = None,
         skip: int = 0,
         limit: int = 0,
         load: load_type = None,
@@ -220,9 +220,9 @@ def query(
 
     def query_one(
         self,
-        criteria: dict | None = None,
-        properties: dict | list | None = None,
-        sort: dict[str, Sort | int] | None = None,
+        criteria: dict = None,
+        properties: dict | list = None,
+        sort: dict[str, Sort | int] = None,
         load: load_type = None,
     ) -> dict | None:
         """
@@ -255,7 +255,7 @@ def query_one(
     def update(
         self,
         docs: list[dict] | dict,
-        key: list | str | None = None,
+        key: list | str = None,
         save: bool | save_type = None,
     ):
         """
@@ -295,7 +295,7 @@ def update(
         blob_data = defaultdict(list)
         dict_docs = []
         for doc in docs:
-            doc = jsanitize(doc, strict=True, allow_bson=True)
+            doc = jsanitize(doc, strict=True, allow_bson=True)  # noqa: PLW2901
             dict_docs.append(doc)
 
             if save_keys:
@@ -364,9 +364,9 @@ def ensure_index(self, key: str, unique: bool = False) -> bool:
     def groupby(
         self,
         keys: list[str] | str,
-        criteria: dict | None = None,
-        properties: dict | list | None = None,
-        sort: dict[str, Sort | int] | None = None,
+        criteria: dict = None,
+        properties: dict | list = None,
+        sort: dict[str, Sort | int] = None,
         skip: int = 0,
         limit: int = 0,
         load: load_type = None,
@@ -449,7 +449,7 @@ def get_output(
         uuid: str,
         which: str | int = "last",
         load: load_type = False,
-        cache: dict[str, Any] | None = None,
+        cache: dict[str, Any] = None,
         on_missing: OnMissing = OnMissing.ERROR,
     ):
         """
@@ -515,7 +515,7 @@ def get_output(
             raise ValueError(f"UUID: {uuid}{istr} has no outputs.")
 
         refs = find_and_get_references(result["output"])
-        if any([ref.uuid == uuid for ref in refs]):
+        if any(ref.uuid == uuid for ref in refs):
             raise RuntimeError("Reference cycle detected - aborting.")
 
         return find_and_resolve_references(
@@ -537,7 +537,7 @@ def get_output(
         results = [r["output"] for r in results]
 
         refs = find_and_get_references(results)
-        if any([ref.uuid == uuid for ref in refs]):
+        if any(ref.uuid == uuid for ref in refs):
             raise RuntimeError("Reference cycle detected - aborting.")
 
         return find_and_resolve_references(
@@ -704,7 +704,7 @@ def _prepare_load(
             new_load[store_name] = store_load
         else:
             if not isinstance(store_load, (tuple, list)):
-                store_load = [store_load]
+                store_load = [store_load]  # noqa: PLW2901
 
             new_store_load = []
             for ltype in store_load:
@@ -731,7 +731,7 @@ def _prepare_save(
     new_save = {}
     for store_name, store_save in save.items():
         if not isinstance(store_save, (tuple, list, bool)):
-            store_save = [store_save]
+            store_save = [store_save]  # noqa: PLW2901
 
         new_save[store_name] = [
             o.value if isinstance(o, Enum) else o for o in store_save
diff --git a/src/jobflow/managers/fireworks.py b/src/jobflow/managers/fireworks.py
index aed8cb67..69f8de75 100644
--- a/src/jobflow/managers/fireworks.py
+++ b/src/jobflow/managers/fireworks.py
@@ -16,7 +16,7 @@
 
 def flow_to_workflow(
     flow: jobflow.Flow | jobflow.Job | list[jobflow.Job],
-    store: jobflow.JobStore | None = None,
+    store: jobflow.JobStore = None,
     **kwargs,
 ) -> Workflow:
     """
@@ -43,7 +43,7 @@ def flow_to_workflow(
     Workflow
         The job or flow as a workflow.
""" - from fireworks.core.firework import Firework, Workflow + from fireworks.core.firework import Workflow from jobflow.core.flow import get_flow @@ -61,9 +61,9 @@ def flow_to_workflow( def job_to_firework( job: jobflow.Job, - store: jobflow.JobStore | None = None, - parents: Sequence[str] | None = None, - parent_mapping: dict[str, Firework] | None = None, + store: jobflow.JobStore = None, + parents: Sequence[str] = None, + parent_mapping: dict[str, Firework] = None, **kwargs, ) -> Firework: """ @@ -139,7 +139,7 @@ class JobFiretask(FiretaskBase): the computer that runs the workflow will be used. """ - required_params = ["job", "store"] + required_params = ("job", "store") def run_task(self, fw_spec): """Run the job and handle any dynamic firework submissions.""" diff --git a/src/jobflow/managers/local.py b/src/jobflow/managers/local.py index 77591ac7..a29bf735 100644 --- a/src/jobflow/managers/local.py +++ b/src/jobflow/managers/local.py @@ -18,7 +18,7 @@ def run_locally( flow: jobflow.Flow | jobflow.Job | list[jobflow.Job], log: bool = True, - store: jobflow.JobStore | None = None, + store: jobflow.JobStore = None, create_folders: bool = False, ensure_success: bool = False, ) -> dict[str, dict[int, jobflow.Response]]: diff --git a/src/jobflow/utils/find.py b/src/jobflow/utils/find.py index ab841101..2588a8af 100644 --- a/src/jobflow/utils/find.py +++ b/src/jobflow/utils/find.py @@ -244,7 +244,7 @@ def get_root_locations(locations): sorted_locs = sorted(locations, key=lambda x: len(x)) root_locations = [] for loc in sorted_locs: - if any([loc[: len(rloc)] == rloc for rloc in root_locations]): + if any(loc[: len(rloc)] == rloc for rloc in root_locations): continue root_locations.append(loc) return root_locations diff --git a/tests/core/test_maker.py b/tests/core/test_maker.py index 65964d54..d08cb352 100644 --- a/tests/core/test_maker.py +++ b/tests/core/test_maker.py @@ -1,3 +1,5 @@ +from dataclasses import field + import pytest @@ -15,7 +17,6 @@ class BadMaker(Maker): @dataclass class BadMaker(Maker): - a = 1 with pytest.raises(NotImplementedError): @@ -27,7 +28,6 @@ def test_required_arguments_works(): @dataclass class MyMaker: - a: int name = "123" @@ -91,8 +91,8 @@ def make(self, a, b): @dataclass class DoubleAddMaker(Maker): name: str = "add_add" - add1: AddMaker = AddMaker() - add2: AddMaker = AddMaker() + add1: AddMaker = field(default_factory=AddMaker) + add2: AddMaker = field(default_factory=AddMaker) def make(self, a, b): first = self.add1.make(a, b) @@ -131,7 +131,7 @@ def make(self, a, b): @dataclass class DetourMaker(Maker): name: str = "add" - add_maker: Maker = AddMaker() + add_maker: Maker = field(default_factory=AddMaker) def make(self, a, b): detour = self.add_maker.make(a, b) @@ -198,7 +198,7 @@ def make(self, a, b): @dataclass class FakeDetourMaker(Maker): name: str = "add" - add_maker: MSONable = NotAMaker() + add_maker: MSONable = field(default_factory=NotAMaker) def make(self, a, b): detour = self.add_maker.make(a, b) diff --git a/tests/core/test_store.py b/tests/core/test_store.py index ba6d8e59..5066143e 100644 --- a/tests/core/test_store.py +++ b/tests/core/test_store.py @@ -299,9 +299,9 @@ def test_groupby(memory_jobstore): ) data = list(memory_jobstore.groupby("d")) assert len(data) == 2 - grouped_by_9 = [g[1] for g in data if g[0]["d"] == 9][0] + grouped_by_9 = next(g[1] for g in data if g[0]["d"] == 9) assert len(grouped_by_9) == 3 - grouped_by_10 = [g[1] for g in data if g[0]["d"] == 10][0] + grouped_by_10 = next(g[1] for g in data if g[0]["d"] == 10) 
     assert len(grouped_by_10) == 1
 
     data = list(memory_jobstore.groupby(["e", "d"]))
diff --git a/tests/managers/conftest.py b/tests/managers/conftest.py
index f025f111..dd170df7 100644
--- a/tests/managers/conftest.py
+++ b/tests/managers/conftest.py
@@ -291,13 +291,13 @@ def fw_dir():
     import tempfile
 
     old_cwd = os.getcwd()
-    newpath = tempfile.mkdtemp()
-    os.chdir(newpath)
+    new_path = tempfile.mkdtemp()
+    os.chdir(new_path)
 
     yield
 
     os.chdir(old_cwd)
-    shutil.rmtree(newpath)
+    shutil.rmtree(new_path)
 
 
 @pytest.fixture(scope="session")
diff --git a/tests/managers/test_fireworks.py b/tests/managers/test_fireworks.py
index 8feb5c8e..4821d239 100644
--- a/tests/managers/test_fireworks.py
+++ b/tests/managers/test_fireworks.py
@@ -111,10 +111,10 @@ def test_simple_flow(lpad, mongo_jobstore, fw_dir, simple_flow, capsys):
     rapidfire(lpad)
 
     # check workflow completed
-    fw_id = list(fw_ids.values())[0]
+    fw_id = next(iter(fw_ids.values()))
     wf = lpad.get_wf_by_fw_id(fw_id)
 
-    assert all([s == "COMPLETED" for s in wf.fw_states.values()])
+    assert all(s == "COMPLETED" for s in wf.fw_states.values())
 
     # check store has the activity output
     result = mongo_jobstore.query_one({"uuid": uuid})
@@ -142,10 +142,10 @@ def test_simple_flow_no_store(lpad, fw_dir, simple_flow, capsys):
     rapidfire(lpad)
 
     # check workflow completed
-    fw_id = list(fw_ids.values())[0]
+    fw_id = next(iter(fw_ids.values()))
     wf = lpad.get_wf_by_fw_id(fw_id)
 
-    assert all([s == "COMPLETED" for s in wf.fw_states.values()])
+    assert all(s == "COMPLETED" for s in wf.fw_states.values())
 
     # check store has the activity output
     result = SETTINGS.JOB_STORE.query_one({"uuid": uuid})
@@ -175,10 +175,10 @@ def test_simple_flow_metadata(
     rapidfire(lpad)
 
     # check workflow completed
-    fw_id = list(fw_ids.values())[0]
+    fw_id = next(iter(fw_ids.values()))
     wf = lpad.get_wf_by_fw_id(fw_id)
 
-    assert all([s == "COMPLETED" for s in wf.fw_states.values()])
+    assert all(s == "COMPLETED" for s in wf.fw_states.values())
     assert wf.fws[0].spec["tags"] == ["my_flow"]
 
     # check store has the activity output
@@ -216,7 +216,7 @@ def test_simple_flow_metadata(
     rapidfire(lpad)
 
     result = mongo_jobstore.query_one({"uuid": uuid})
-    fw_id = list(fw_ids.values())[0]
+    fw_id = next(iter(fw_ids.values()))
     assert result["metadata"] == {"fw_id": fw_id, "tags": ["my_flow"]}
 
     # Test flow with existing tags
@@ -258,10 +258,10 @@ def test_connected_flow(lpad, mongo_jobstore, fw_dir, connected_flow, capsys):
     rapidfire(lpad)
 
     # check workflow completed
-    fw_id = list(fw_ids.values())[0]
+    fw_id = next(iter(fw_ids.values()))
     wf = lpad.get_wf_by_fw_id(fw_id)
 
-    assert all([s == "COMPLETED" for s in wf.fw_states.values()])
+    assert all(s == "COMPLETED" for s in wf.fw_states.values())
 
     # check store has the activity output
     result1 = mongo_jobstore.query_one({"uuid": uuid1})
@@ -289,10 +289,10 @@ def test_nested_flow(lpad, mongo_jobstore, fw_dir, nested_flow, capsys):
     rapidfire(lpad)
 
     # check workflow completed
-    fw_id = list(fw_ids.values())[0]
+    fw_id = next(iter(fw_ids.values()))
     wf = lpad.get_wf_by_fw_id(fw_id)
 
-    assert all([s == "COMPLETED" for s in wf.fw_states.values()])
+    assert all(s == "COMPLETED" for s in wf.fw_states.values())
 
     # check store has the activity output
     result1 = mongo_jobstore.query_one({"uuid": uuid1})
@@ -321,12 +321,12 @@ def test_addition_flow(lpad, mongo_jobstore, fw_dir, addition_flow, capsys):
     rapidfire(lpad)
 
     # check workflow completed
-    fw_id = list(fw_ids.values())[0]
+    fw_id = next(iter(fw_ids.values()))
     wf = lpad.get_wf_by_fw_id(fw_id)
     uuids = [fw.tasks[0]["job"].uuid for fw in wf.fws]
-    uuid2 = [u for u in uuids if u != uuid1][0]
-    assert all([s == "COMPLETED" for s in wf.fw_states.values()])
+    uuid2 = next(u for u in uuids if u != uuid1)
+    assert all(s == "COMPLETED" for s in wf.fw_states.values())
 
     # check store has the activity output
     result1 = mongo_jobstore.query_one({"uuid": uuid1})
@@ -352,12 +352,12 @@ def test_detour_flow(lpad, mongo_jobstore, fw_dir, detour_flow, capsys):
     rapidfire(lpad)
 
     # check workflow completed
-    fw_id = list(fw_ids.values())[0]
+    fw_id = next(iter(fw_ids.values()))
     wf = lpad.get_wf_by_fw_id(fw_id)
     uuids = [fw.tasks[0]["job"].uuid for fw in wf.fws]
-    uuid2 = [u for u in uuids if u != uuid1 and u != uuid3][0]
-    assert all([s == "COMPLETED" for s in wf.fw_states.values()])
+    uuid2 = next(u for u in uuids if u != uuid1 and u != uuid3)
+    assert all(s == "COMPLETED" for s in wf.fw_states.values())
 
     # check store has the activity output
     result1 = mongo_jobstore.query_one({"uuid": uuid1})
@@ -388,10 +388,10 @@ def test_replace_flow(lpad, mongo_jobstore, fw_dir, replace_flow, capsys):
     rapidfire(lpad)
 
     # check workflow completed
-    fw_id = list(fw_ids.values())[0]
+    fw_id = next(iter(fw_ids.values()))
     wf = lpad.get_wf_by_fw_id(fw_id)
 
-    assert all([s == "COMPLETED" for s in wf.fw_states.values()])
+    assert all(s == "COMPLETED" for s in wf.fw_states.values())
 
     # check store has the activity output
     result1 = mongo_jobstore.query_one({"uuid": uuid1, "index": 1})
@@ -422,7 +422,7 @@ def test_stop_jobflow_flow(lpad, mongo_jobstore, fw_dir, stop_jobflow_flow, caps
     rapidfire(lpad)
 
     # check workflow completed
-    fw_id = list(fw_ids.values())[0]
+    fw_id = next(iter(fw_ids.values()))
     wf = lpad.get_wf_by_fw_id(fw_id)
 
     assert set(wf.fw_states.values()) == {"COMPLETED", "DEFUSED"}
@@ -450,7 +450,7 @@ def test_stop_jobflow_job(lpad, mongo_jobstore, fw_dir, stop_jobflow_job, capsys
     rapidfire(lpad)
 
     # check workflow completed
-    fw_id = list(fw_ids.values())[0]
+    fw_id = next(iter(fw_ids.values()))
     wf = lpad.get_wf_by_fw_id(fw_id)
 
     assert list(wf.fw_states.values()) == ["COMPLETED"]
@@ -480,7 +480,7 @@ def test_stop_children_flow(lpad, mongo_jobstore, fw_dir, stop_children_flow, ca
     rapidfire(lpad)
 
     # check workflow completed
-    fw_id = list(fw_ids.values())[0]
+    fw_id = next(iter(fw_ids.values()))
     wf = lpad.get_wf_by_fw_id(fw_id)
 
     states = Counter(wf.fw_states.values())
@@ -511,7 +511,7 @@ def test_error_flow(lpad, mongo_jobstore, fw_dir, error_flow):
     # run the workflow
     rapidfire(lpad)
     # check workflow completed
-    fw_id = list(fw_ids.values())[0]
+    fw_id = next(iter(fw_ids.values()))
     wf = lpad.get_wf_by_fw_id(fw_id)
 
     assert set(wf.fw_states.values()) == {"WAITING", "FIZZLED"}
@@ -527,7 +527,7 @@ def test_stored_data_flow(lpad, mongo_jobstore, fw_dir, stored_data_flow, capsys
    from jobflow.managers.fireworks import flow_to_workflow
 
     flow = stored_data_flow()
-    flow.jobs[0].uuid
+    _fw_id = flow.jobs[0].uuid
     wf = flow_to_workflow(flow, mongo_jobstore)
     fw_ids = lpad.add_wf(wf)
@@ -536,7 +536,7 @@ def test_stored_data_flow(lpad, mongo_jobstore, fw_dir, stored_data_flow, capsys
     rapidfire(lpad)
 
     # check workflow completed
-    fw_id = list(fw_ids.values())[0]
+    fw_id = next(iter(fw_ids.values()))
     wf = lpad.get_wf_by_fw_id(fw_id)
 
     assert list(wf.fw_states.values()) == ["COMPLETED"]
@@ -561,16 +561,14 @@ def test_detour_stop_flow(lpad, mongo_jobstore, fw_dir, detour_stop_flow, capsys
     rapidfire(lpad)
 
     # check workflow completed
-    fw_id = list(fw_ids.values())[0]
+    fw_id = next(iter(fw_ids.values()))
     wf = lpad.get_wf_by_fw_id(fw_id)
     uuids = [fw.tasks[0]["job"].uuid for fw in wf.fws]
-    uuid2 = [u for u in uuids if u != uuid1 and u != uuid3][0]
+    uuid2 = next(u for u in uuids if u != uuid1 and u != uuid3)
 
     # Sort by firework id explicitly instead of assuming they are sorted
-    states_dict = {
-        key: val for key, val in zip(list(wf.id_fw.keys()), list(wf.fw_states.values()))
-    }
+    states_dict = dict(zip(list(wf.id_fw.keys()), list(wf.fw_states.values())))
     sorted_states_dict = dict(sorted(states_dict.items()))
 
     assert list(sorted_states_dict.values()) == ["DEFUSED", "COMPLETED", "COMPLETED"]
@@ -602,13 +600,13 @@ def test_replace_and_detour_flow(
     rapidfire(lpad)
 
     # check workflow completed
-    fw_id = list(fw_ids.values())[0]
+    fw_id = next(iter(fw_ids.values()))
     wf = lpad.get_wf_by_fw_id(fw_id)
     uuids = [fw.tasks[0]["job"].uuid for fw in wf.fws]
-    uuid2 = [u for u in uuids if u != uuid1 and u != uuid3][0]
+    uuid2 = next(u for u in uuids if u != uuid1 and u != uuid3)
 
-    assert all([s == "COMPLETED" for s in wf.fw_states.values()])
+    assert all(s == "COMPLETED" for s in wf.fw_states.values())
 
     # check store has the activity output
     result1 = mongo_jobstore.query_one({"uuid": uuid1, "index": 1})
diff --git a/tests/managers/test_local.py b/tests/managers/test_local.py
index ff78b99e..5658dbb5 100644
--- a/tests/managers/test_local.py
+++ b/tests/managers/test_local.py
@@ -125,7 +125,7 @@ def test_addition_flow(memory_jobstore, clean_dir, addition_flow):
 
     # run with log
     responses = run_locally(flow, store=memory_jobstore)
-    uuid2 = [u for u in responses if u != uuid1][0]
+    uuid2 = next(u for u in responses if u != uuid1)
 
     # check responses has been filled
     assert len(responses) == 2
@@ -150,7 +150,7 @@ def test_detour_flow(memory_jobstore, clean_dir, detour_flow):
 
     # run with log
     responses = run_locally(flow, store=memory_jobstore)
-    uuid2 = [u for u in responses if u != uuid1 and u != uuid3][0]
+    uuid2 = next(u for u in responses if u != uuid1 and u != uuid3)
 
     # check responses has been filled
     assert len(responses) == 3
@@ -345,7 +345,7 @@ def test_detour_stop_flow(memory_jobstore, clean_dir, detour_stop_flow):
 
     # run with log
     responses = run_locally(flow, store=memory_jobstore)
-    uuid2 = [u for u in responses if u != uuid1 and u != uuid3][0]
+    uuid2 = next(u for u in responses if u != uuid1 and u != uuid3)
 
     # check responses has been filled
     assert len(responses) == 2