From bc5a9fb7309a7b5eb77bb1dc5956d2d6307df345 Mon Sep 17 00:00:00 2001 From: Jason Munro Date: Thu, 17 Nov 2022 15:23:23 -0800 Subject: [PATCH] Add task_id list and elements to task `search` method (#709) * Add task_id list and elements to search * Import id validation * Update task tests * Fix task test * Fix task tests --- mp_api/client/routes/tasks.py | 26 ++++++++++++++++++++----- tests/test_tasks.py | 36 ++++++++++++++--------------------- 2 files changed, 35 insertions(+), 27 deletions(-) diff --git a/mp_api/client/routes/tasks.py b/mp_api/client/routes/tasks.py index b6a6e628..086bd489 100644 --- a/mp_api/client/routes/tasks.py +++ b/mp_api/client/routes/tasks.py @@ -5,6 +5,8 @@ import warnings +from mp_api.client.core.utils import validate_ids + class TaskRester(BaseRester[TaskDoc]): @@ -22,9 +24,9 @@ def get_trajectory(self, task_id): :return: List of trajectory objects """ - traj_data = self._query_resource_data( - suburl=f"trajectory/{task_id}/", use_document_model=False - )[0].get("trajectories", None) + traj_data = self._query_resource_data(suburl=f"trajectory/{task_id}/", use_document_model=False)[0].get( + "trajectories", None + ) if traj_data is None: raise MPRestError(f"No trajectory data for {task_id} found") @@ -37,8 +39,7 @@ def search_task_docs(self, *args, **kwargs): # pragma: no cover """ warnings.warn( - "MPRester.tasks.search_task_docs is deprecated. " - "Please use MPRester.tasks.search instead.", + "MPRester.tasks.search_task_docs is deprecated. " "Please use MPRester.tasks.search instead.", DeprecationWarning, stacklevel=2, ) @@ -47,7 +48,10 @@ def search_task_docs(self, *args, **kwargs): # pragma: no cover def search( self, + task_ids: Optional[List[str]] = None, chemsys: Optional[Union[str, List[str]]] = None, + elements: Optional[List[str]] = None, + exclude_elements: Optional[List[str]] = None, formula: Optional[Union[str, List[str]]] = None, num_chunks: Optional[int] = None, chunk_size: int = 1000, @@ -58,8 +62,11 @@ def search( Query core task docs using a variety of search criteria. Arguments: + task_ids (List[str]): List of Materials Project IDs to return data for. chemsys (str, List[str]): A chemical system or list of chemical systems (e.g., Li-Fe-O, Si-*, [Si-O, Li-Fe-P]). + elements (List[str]): A list of elements. + exclude_elements (List[str]): A list of elements to exclude. formula (str, List[str]): A formula including anonymized formula or wild cards (e.g., Fe2O3, ABO3, Si*). A list of chemical formulas can also be passed (e.g., [Fe2O3, ABO3]). @@ -75,9 +82,18 @@ def search( query_params = {} # type: dict + if task_ids: + query_params.update({"task_ids": ",".join(validate_ids(task_ids))}) + if formula: query_params.update({"formula": formula}) + if elements: + query_params.update({"elements": ",".join(elements)}) + + if exclude_elements: + query_params.update({"exclude_elements": ",".join(exclude_elements)}) + if chemsys: if isinstance(chemsys, str): chemsys = [chemsys] diff --git a/tests/test_tasks.py b/tests/test_tasks.py index 3c11d95f..f2bab514 100644 --- a/tests/test_tasks.py +++ b/tests/test_tasks.py @@ -19,19 +19,22 @@ def rester(): "num_chunks", "all_fields", "fields", - "formula", # Timeout issue + "formula", + "elements", + "exclude_elements", ] sub_doc_fields = [] # type: list -alt_name_dict = {"formula": "task_id"} # type: dict +alt_name_dict = {"formula": "task_id", "task_ids": "task_id", "exclude_elements": "task_id"} # type: dict -custom_field_tests = {"chemsys": "Si-O"} # type: dict +custom_field_tests = { + "chemsys": "Si-O", + "task_ids": ["mp-149"], +} # type: dict -@pytest.mark.skipif( - os.environ.get("MP_API_KEY", None) is None, reason="No API key found." -) +@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") def test_client(rester): search_method = rester.search @@ -52,9 +55,7 @@ def test_client(rester): param: (-100, 100), "chunk_size": 1, "num_chunks": 1, - "fields": [ - project_field if project_field is not None else param - ], + "fields": [project_field if project_field is not None else param], } elif param_type == typing.Tuple[float, float]: project_field = alt_name_dict.get(param, None) @@ -62,9 +63,7 @@ def test_client(rester): param: (-100.12, 100.12), "chunk_size": 1, "num_chunks": 1, - "fields": [ - project_field if project_field is not None else param - ], + "fields": [project_field if project_field is not None else param], } elif param_type is bool: project_field = alt_name_dict.get(param, None) @@ -72,9 +71,7 @@ def test_client(rester): param: False, "chunk_size": 1, "num_chunks": 1, - "fields": [ - project_field if project_field is not None else param - ], + "fields": [project_field if project_field is not None else param], } elif param in custom_field_tests: project_field = alt_name_dict.get(param, None) @@ -82,9 +79,7 @@ def test_client(rester): param: custom_field_tests[param], "chunk_size": 1, "num_chunks": 1, - "fields": [ - project_field if project_field is not None else param - ], + "fields": [project_field if project_field is not None else param], } doc = search_method(**q)[0].dict() @@ -92,10 +87,7 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert ( - doc[project_field if project_field is not None else param] - is not None - ) + assert doc[project_field if project_field is not None else param] is not None @pytest.mark.xfail(reason="Temporary until redeploy")