Skip to content

Commit

Permalink
Add task_id list and elements to task search method (#709)
Browse files Browse the repository at this point in the history
* Add task_id list and elements to search

* Import id validation

* Update task tests

* Fix task test

* Fix task tests
  • Loading branch information
Jason Munro authored Nov 17, 2022
1 parent 80f4686 commit bc5a9fb
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 27 deletions.
26 changes: 21 additions & 5 deletions mp_api/client/routes/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import warnings

from mp_api.client.core.utils import validate_ids


class TaskRester(BaseRester[TaskDoc]):

Expand All @@ -22,9 +24,9 @@ def get_trajectory(self, task_id):
:return: List of trajectory objects
"""

traj_data = self._query_resource_data(
suburl=f"trajectory/{task_id}/", use_document_model=False
)[0].get("trajectories", None)
traj_data = self._query_resource_data(suburl=f"trajectory/{task_id}/", use_document_model=False)[0].get(
"trajectories", None
)

if traj_data is None:
raise MPRestError(f"No trajectory data for {task_id} found")
Expand All @@ -37,8 +39,7 @@ def search_task_docs(self, *args, **kwargs): # pragma: no cover
"""

warnings.warn(
"MPRester.tasks.search_task_docs is deprecated. "
"Please use MPRester.tasks.search instead.",
"MPRester.tasks.search_task_docs is deprecated. " "Please use MPRester.tasks.search instead.",
DeprecationWarning,
stacklevel=2,
)
Expand All @@ -47,7 +48,10 @@ def search_task_docs(self, *args, **kwargs): # pragma: no cover

def search(
self,
task_ids: Optional[List[str]] = None,
chemsys: Optional[Union[str, List[str]]] = None,
elements: Optional[List[str]] = None,
exclude_elements: Optional[List[str]] = None,
formula: Optional[Union[str, List[str]]] = None,
num_chunks: Optional[int] = None,
chunk_size: int = 1000,
Expand All @@ -58,8 +62,11 @@ def search(
Query core task docs using a variety of search criteria.
Arguments:
task_ids (List[str]): List of Materials Project IDs to return data for.
chemsys (str, List[str]): A chemical system or list of chemical systems
(e.g., Li-Fe-O, Si-*, [Si-O, Li-Fe-P]).
elements (List[str]): A list of elements.
exclude_elements (List[str]): A list of elements to exclude.
formula (str, List[str]): A formula including anonymized formula
or wild cards (e.g., Fe2O3, ABO3, Si*). A list of chemical formulas can also be passed
(e.g., [Fe2O3, ABO3]).
Expand All @@ -75,9 +82,18 @@ def search(

query_params = {} # type: dict

if task_ids:
query_params.update({"task_ids": ",".join(validate_ids(task_ids))})

if formula:
query_params.update({"formula": formula})

if elements:
query_params.update({"elements": ",".join(elements)})

if exclude_elements:
query_params.update({"exclude_elements": ",".join(exclude_elements)})

if chemsys:
if isinstance(chemsys, str):
chemsys = [chemsys]
Expand Down
36 changes: 14 additions & 22 deletions tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,22 @@ def rester():
"num_chunks",
"all_fields",
"fields",
"formula", # Timeout issue
"formula",
"elements",
"exclude_elements",
]

sub_doc_fields = [] # type: list

alt_name_dict = {"formula": "task_id"} # type: dict
alt_name_dict = {"formula": "task_id", "task_ids": "task_id", "exclude_elements": "task_id"} # type: dict

custom_field_tests = {"chemsys": "Si-O"} # type: dict
custom_field_tests = {
"chemsys": "Si-O",
"task_ids": ["mp-149"],
} # type: dict


@pytest.mark.skipif(
os.environ.get("MP_API_KEY", None) is None, reason="No API key found."
)
@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.")
def test_client(rester):
search_method = rester.search

Expand All @@ -52,50 +55,39 @@ def test_client(rester):
param: (-100, 100),
"chunk_size": 1,
"num_chunks": 1,
"fields": [
project_field if project_field is not None else param
],
"fields": [project_field if project_field is not None else param],
}
elif param_type == typing.Tuple[float, float]:
project_field = alt_name_dict.get(param, None)
q = {
param: (-100.12, 100.12),
"chunk_size": 1,
"num_chunks": 1,
"fields": [
project_field if project_field is not None else param
],
"fields": [project_field if project_field is not None else param],
}
elif param_type is bool:
project_field = alt_name_dict.get(param, None)
q = {
param: False,
"chunk_size": 1,
"num_chunks": 1,
"fields": [
project_field if project_field is not None else param
],
"fields": [project_field if project_field is not None else param],
}
elif param in custom_field_tests:
project_field = alt_name_dict.get(param, None)
q = {
param: custom_field_tests[param],
"chunk_size": 1,
"num_chunks": 1,
"fields": [
project_field if project_field is not None else param
],
"fields": [project_field if project_field is not None else param],
}

doc = search_method(**q)[0].dict()
for sub_field in sub_doc_fields:
if sub_field in doc:
doc = doc[sub_field]

assert (
doc[project_field if project_field is not None else param]
is not None
)
assert doc[project_field if project_field is not None else param] is not None


@pytest.mark.xfail(reason="Temporary until redeploy")
Expand Down

0 comments on commit bc5a9fb

Please sign in to comment.