Enhance output objects (#654)
* Update object formatting and attribute errors

* Remove newline

* Update error message

* MPDataEntry to MPDataDoc

* Swap repr and str formatting in output objects

* Ignore E501 for str and repr for output objects

* Linting
Jason Munro authored Aug 12, 2022
1 parent ef951a2 commit be24667
158 changes: 122 additions & 36 deletions src/mp_api/core/client.py
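As the commit message describes, query results are now wrapped in a dynamic MPDataDoc model with swapped __repr__/__str__ formatting and a guarded __getattr__ for unrequested fields. A minimal usage sketch of the resulting behavior, assuming the mp_api client of this era (the import path, API key, material ID, and summary endpoint are illustrative):

```python
# Hypothetical usage sketch; import path, ID, and fields are assumptions.
from mp_api import MPRester

with MPRester("YOUR_API_KEY") as mpr:
    doc = mpr.summary.get_data_by_id("mp-149", fields=["material_id", "band_gap"])

print(doc)        # __str__: bolded MPDataDoc<SummaryDoc> header + requested fields only
print(repr(doc))  # __repr__: every field, including fields_not_requested
doc.volume        # raises MPRestError: available, but not requested in 'fields'
```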
@@ -31,6 +31,7 @@
from requests.exceptions import RequestException
from tqdm.auto import tqdm
from urllib3.util.retry import Retry
from inspect import cleandoc

try:
from pymatgen.core import __version__ as pmg_version # type: ignore
@@ -133,7 +134,9 @@ def _create_session(api_key, include_user_agent):
sys.version_info.major, sys.version_info.minor, sys.version_info.micro
)
platform_info = "{}/{}".format(platform.system(), platform.release())
session.headers["user-agent"] = "{} ({} {})".format(pymatgen_info, python_info, platform_info)
session.headers["user-agent"] = "{} ({} {})".format(
pymatgen_info, python_info, platform_info
)

max_retry_num = MAPIClientSettings().MAX_RETRIES
retry = Retry(
@@ -222,7 +225,10 @@ def _post_resource(
message = data
else:
try:
message = ", ".join("{} - {}".format(entry["loc"][1], entry["msg"]) for entry in data)
message = ", ".join(
"{} - {}".format(entry["loc"][1], entry["msg"])
for entry in data
)
except (KeyError, IndexError):
message = str(data)

@@ -309,7 +315,14 @@ def _query_resource(
raise MPRestError(str(ex))

def _submit_requests(
self, url, criteria, use_document_model, parallel_param=None, num_chunks=None, chunk_size=None, timeout=None
self,
url,
criteria,
use_document_model,
parallel_param=None,
num_chunks=None,
chunk_size=None,
timeout=None,
) -> Dict:
"""
Handle submitting requests. Parallel requests supported if possible.
@@ -337,7 +350,9 @@ def _submit_requests(
# criteria dicts.
if parallel_param is not None:
param_length = len(criteria[parallel_param].split(","))
slice_size = int(param_length / MAPIClientSettings().NUM_PARALLEL_REQUESTS) or 1
slice_size = (
int(param_length / MAPIClientSettings().NUM_PARALLEL_REQUESTS) or 1
)

new_param_values = [
entry
@@ -362,7 +377,11 @@ def _submit_requests(
# Split list and generate multiple criteria
new_criteria = [
{
**{key: criteria[key] for key in criteria if key not in [parallel_param, "_limit"]},
**{
key: criteria[key]
for key in criteria
if key not in [parallel_param, "_limit"]
},
parallel_param: ",".join(list_chunk),
"_limit": new_limits[list_num],
}
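For context, a standalone sketch of the slicing logic above, with illustrative values standing in for the client's settings:

```python
# A comma-joined parameter is cut into roughly NUM_PARALLEL_REQUESTS slices.
param_values = "mp-1,mp-2,mp-3,mp-4,mp-5,mp-6,mp-7".split(",")
num_parallel = 3  # stands in for MAPIClientSettings().NUM_PARALLEL_REQUESTS
slice_size = int(len(param_values) / num_parallel) or 1  # -> 2

chunks = [
    param_values[i : i + slice_size]
    for i in range(0, len(param_values), slice_size)
]
# -> [['mp-1', 'mp-2'], ['mp-3', 'mp-4'], ['mp-5', 'mp-6'], ['mp-7']]
```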
Expand All @@ -385,9 +404,13 @@ def _submit_requests(
subtotals = []
remaining_docs_avail = {}

initial_params_list = [{"url": url, "verify": True, "params": copy(crit)} for crit in new_criteria]
initial_params_list = [
{"url": url, "verify": True, "params": copy(crit)} for crit in new_criteria
]

initial_data_tuples = self._multi_thread(use_document_model, initial_params_list)
initial_data_tuples = self._multi_thread(
use_document_model, initial_params_list
)

for data, subtotal, crit_ind in initial_data_tuples:

@@ -400,7 +423,9 @@ def _submit_requests(

# Rebalance if some parallel queries produced too few results
if len(remaining_docs_avail) > 1 and len(total_data["data"]) < chunk_size:
remaining_docs_avail = dict(sorted(remaining_docs_avail.items(), key=lambda item: item[1]))
remaining_docs_avail = dict(
sorted(remaining_docs_avail.items(), key=lambda item: item[1])
)

# Redistribute missing docs from initial chunk among queries
# which have head room with respect to remaining document number.
@@ -427,15 +452,19 @@ def _submit_requests(
new_limits[crit_ind] += fill_docs
fill_docs = 0

rebalance_params.append({"url": url, "verify": True, "params": copy(crit)})
rebalance_params.append(
{"url": url, "verify": True, "params": copy(crit)}
)

new_criteria[crit_ind]["_skip"] += crit["_limit"]
new_criteria[crit_ind]["_limit"] = chunk_size

# Obtain missing initial data after rebalancing
if len(rebalance_params) > 0:

rebalance_data_tuples = self._multi_thread(use_document_model, rebalance_params)
rebalance_data_tuples = self._multi_thread(
use_document_model, rebalance_params
)

for data, _, _ in rebalance_data_tuples:
total_data["data"].extend(data["data"])
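The redistribution idea, reduced to a self-contained sketch (the dict values are illustrative remaining-document counts per query index):

```python
# Queries that returned fewer docs than requested leave a deficit
# (fill_docs); it is pushed onto queries with documents still available,
# smallest remainder first, mirroring the sort above.
remaining_docs_avail = {0: 5, 2: 40, 1: 120}
fill_docs = 30  # docs missing after the initial round of requests

extra_requests = {}
for crit_ind, num_avail in sorted(remaining_docs_avail.items(), key=lambda kv: kv[1]):
    if fill_docs == 0:
        break
    take = min(fill_docs, num_avail)
    extra_requests[crit_ind] = take
    fill_docs -= take
# -> {0: 5, 2: 25}: query 0 contributes all 5 docs it has left, query 2 the rest
```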
Expand All @@ -449,7 +478,9 @@ def _submit_requests(
total_data["meta"] = last_data_entry["meta"]

# Get max number of response pages
max_pages = num_chunks if num_chunks is not None else ceil(total_num_docs / chunk_size)
max_pages = (
num_chunks if num_chunks is not None else ceil(total_num_docs / chunk_size)
)

# Get total number of docs needed
num_docs_needed = min((max_pages * chunk_size), total_num_docs)
@@ -461,10 +492,7 @@ def _submit_requests(
else "Retrieving documents"
)
pbar = (
tqdm(
desc=pbar_message,
total=num_docs_needed,
)
tqdm(desc=pbar_message, total=num_docs_needed,)
if not MAPIClientSettings().MUTE_PROGRESS_BARS
else None
)
@@ -543,7 +571,11 @@ def _submit_requests(
return total_data

def _multi_thread(
self, use_document_model: bool, params_list: List[dict], progress_bar: tqdm = None, timeout: int = None
self,
use_document_model: bool,
params_list: List[dict],
progress_bar: tqdm = None,
timeout: int = None,
):
"""
Handles setting up a threadpool and sending parallel requests
@@ -561,16 +593,22 @@ def _multi_thread(

return_data = []

params_gen = iter(params_list) # Iter necessary for islice to keep track of what has been accessed
params_gen = iter(
params_list
) # Iter necessary for islice to keep track of what has been accessed

params_ind = 0

with ThreadPoolExecutor(max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS) as executor:
with ThreadPoolExecutor(
max_workers=MAPIClientSettings().NUM_PARALLEL_REQUESTS
) as executor:

# Get list of initial futures defined by max number of parallel requests
futures = set()

for params in itertools.islice(params_gen, MAPIClientSettings().NUM_PARALLEL_REQUESTS):
for params in itertools.islice(
params_gen, MAPIClientSettings().NUM_PARALLEL_REQUESTS
):

future = executor.submit(
self._submit_request_and_process,
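The submission pattern used here, as a minimal self-contained sketch: keep at most NUM_PARALLEL_REQUESTS futures in flight and top the pool up as results complete (the function name and worker signature are illustrative):

```python
import itertools
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait

def run_all(params_list, worker, num_parallel=8):
    """Rolling-window submission: at most num_parallel requests in flight."""
    params_gen = iter(params_list)  # islice tracks what has been consumed
    results = []
    with ThreadPoolExecutor(max_workers=num_parallel) as executor:
        futures = {
            executor.submit(worker, **params)
            for params in itertools.islice(params_gen, num_parallel)
        }
        while futures:
            done, futures = wait(futures, return_when=FIRST_COMPLETED)
            for future in done:
                results.append(future.result())
            # Refill the pool with one new submission per completed future
            for params in itertools.islice(params_gen, len(done)):
                futures.add(executor.submit(worker, **params))
    return results
```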
@@ -632,9 +670,13 @@ def _submit_request_and_process(
Tuple with data and total number of docs matching the query in the database.
"""
try:
response = self.session.get(url=url, verify=verify, params=params, timeout=timeout)
response = self.session.get(
url=url, verify=verify, params=params, timeout=timeout
)
except requests.exceptions.ConnectTimeout:
raise MPRestError(f"REST query timed out on URL {url}. Try again with a smaller request.")
raise MPRestError(
f"REST query timed out on URL {url}. Try again with a smaller request."
)

if response.status_code == 200:

@@ -649,9 +691,8 @@ def _submit_request_and_process(
raw_doc_list = [self.document_model.parse_obj(d) for d in data["data"]] # type: ignore

# Temporarily removed until user-testing completed
# data["data"] = self._generate_returned_model(raw_doc_list)

data["data"] = raw_doc_list
data["data"] = self._generate_returned_model(raw_doc_list)
# data["data"] = raw_doc_list

meta_total_doc_num = data.get("meta", {}).get("total_doc", 1)

@@ -666,7 +707,10 @@ def _submit_request_and_process(
message = data
else:
try:
message = ", ".join("{} - {}".format(entry["loc"][1], entry["msg"]) for entry in data)
message = ", ".join(
"{} - {}".format(entry["loc"][1], entry["msg"])
for entry in data
)
except (KeyError, IndexError):
message = str(data)
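The error-message handling above (and its twin in _post_resource) expects FastAPI-style validation payloads; a quick sketch with an illustrative payload:

```python
# Each entry carries a location tuple and a message; they are joined into
# "field - message" pairs, falling back to str() for unexpected shapes.
data = [{"loc": ["query", "band_gap"], "msg": "value is not a valid float"}]
try:
    message = ", ".join(
        "{} - {}".format(entry["loc"][1], entry["msg"]) for entry in data
    )
except (KeyError, IndexError):
    message = str(data)
# -> "band_gap - value is not a valid float"
```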

Expand All @@ -680,25 +724,62 @@ def _generate_returned_model(self, data):
new_data = []

for doc in data:
set_data = {field: value for field, value in doc if field in doc.dict(exclude_unset=True)}
set_data = {
field: value
for field, value in doc
if field in doc.dict(exclude_unset=True)
}
unset_fields = [field for field in doc.__fields__ if field not in set_data]

data_model = create_model(
"MPDataEntry",
"MPDataDoc",
fields_not_requested=unset_fields,
__base__=self.document_model,
)

data_model.__fields__ = {
**{name: description for name, description in data_model.__fields__.items() if name in set_data},
**{
name: description
for name, description in data_model.__fields__.items()
if name in set_data
},
"fields_not_requested": data_model.__fields__["fields_not_requested"],
}

def new_repr(self) -> str:
extra = ", ".join(f"{n}={getattr(self, n)!r}" for n in data_model.__fields__)
return f"{self.__class__.__name__}<{self.__class__.__base__.__name__}>({extra})"
extra = ",\n".join(
f"\033[1m{n}\033[0;0m={getattr(self, n)!r}"
for n in data_model.__fields__
)

s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501
return s

def new_str(self) -> str:
extra = ",\n".join(
f"\033[1m{n}\033[0;0m={getattr(self, n)!r}"
for n in data_model.__fields__
if n != "fields_not_requested"
)

s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m\n{extra}\n\n\033[1mFields not requested:\033[0;0m\n{unset_fields}" # noqa: E501
return s

def new_getattr(self, attr) -> str:
if attr in unset_fields:
raise MPRestError(
f"'{attr}' data is available but has not been requested in 'fields'."
" A full list of unrequested fields can be found in `fields_not_requested`."
)
else:
raise AttributeError(
"%r object has no attribute %r"
% (self.__class__.__name__, attr)
)

data_model.__repr__ = new_repr
data_model.__str__ = new_str
data_model.__getattr__ = new_getattr

new_data.append(data_model(**set_data))
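A self-contained sketch of the dynamic-model construction above, using pydantic v1 (as this client does) and an illustrative stand-in document model:

```python
from typing import Optional
from pydantic import BaseModel, create_model

class SummaryDoc(BaseModel):  # stand-in for self.document_model
    material_id: Optional[str] = None
    band_gap: Optional[float] = None
    volume: Optional[float] = None

doc = SummaryDoc(material_id="mp-149", band_gap=0.85)
set_data = doc.dict(exclude_unset=True)  # only the fields the server returned
unset_fields = [f for f in doc.__fields__ if f not in set_data]

# Subclass the document model on the fly, recording what was not requested
MPDataDoc = create_model(
    "MPDataDoc", fields_not_requested=unset_fields, __base__=SummaryDoc
)
print(MPDataDoc(**set_data).fields_not_requested)  # -> ['volume']
```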

@@ -737,9 +818,7 @@ def _query_resource_data(
).get("data")

def get_data_by_id(
self,
document_id: str,
fields: Optional[List[str]] = None,
self, document_id: str, fields: Optional[List[str]] = None,
) -> T:
"""
Query the endpoint for a single document.
@@ -753,7 +832,10 @@ def get_data_by_id(
"""

if document_id is None:
raise ValueError("Please supply a specific ID. You can use the query method to find " "ids of interest.")
raise ValueError(
"Please supply a specific ID. You can use the query method to find "
"ids of interest."
)

if self.primary_key in ["material_id", "task_id"]:
validate_ids([document_id])
@@ -794,7 +876,9 @@ def get_data_by_id(
if not results:
raise MPRestError(f"No result for record {document_id}.")
elif len(results) > 1: # pragma: no cover
raise ValueError(f"Multiple records for {document_id}, this shouldn't happen. Please report as a bug.")
raise ValueError(
f"Multiple records for {document_id}, this shouldn't happen. Please report as a bug."
)
else:
return results[0]

@@ -901,7 +985,9 @@ def count(self, criteria: Optional[Dict] = None) -> Union[int, str]:
False,
False,
) # do not waste cycles decoding
results = self._query_resource(criteria=criteria, num_chunks=1, chunk_size=1)
results = self._query_resource(
criteria=criteria, num_chunks=1, chunk_size=1
)
self.monty_decode, self.use_document_model = user_preferences
return results["meta"]["total_doc"]
except Exception: # pragma: no cover
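The counting trick above, roughly: disable decoding and the document model, fetch a single document, and read the total from the response metadata. A hedged sketch (rester is a hypothetical client handle; attribute and method names mirror the snippet above):

```python
# Assumed names: rester stands in for a _post_resource-style client instance.
saved = (rester.monty_decode, rester.use_document_model)
rester.monty_decode, rester.use_document_model = False, False  # skip decoding
try:
    results = rester._query_resource(criteria={"elements": "Si"}, num_chunks=1, chunk_size=1)
    total = results["meta"]["total_doc"]
finally:
    rester.monty_decode, rester.use_document_model = saved
```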