diff --git a/mp_api/client/core/client.py b/mp_api/client/core/client.py index 3951acf7..83516c42 100644 --- a/mp_api/client/core/client.py +++ b/mp_api/client/core/client.py @@ -685,11 +685,26 @@ def _submit_request_and_process( # other sub-urls may use different document models # the client does not handle this in a particularly smart way currently if self.document_model and use_document_model: + raw_doc_list = [self.document_model.parse_obj(d) for d in data["data"]] # type: ignore - # Temporarily removed until user-testing completed - data["data"] = self._generate_returned_model(raw_doc_list) - # data["data"] = raw_doc_list + if len(raw_doc_list) > 0: + data_model, set_fields, _ = self._generate_returned_model( + raw_doc_list[0] + ) + + data["data"] = [ + data_model( + **{ + field: value + for field, value in raw_doc.dict().items() + if field in set_fields + } + ) + for raw_doc in raw_doc_list + ] + + # data["data"] = raw_doc_list meta_total_doc_num = data.get("meta", {}).get("total_doc", 1) @@ -715,70 +730,63 @@ def _submit_request_and_process( f"on URL {response.url} with message:\n{message}" ) - def _generate_returned_model(self, data): + def _generate_returned_model(self, doc): - new_data = [] + set_fields = [ + field for field, _ in doc if field in doc.dict(exclude_unset=True) + ] + unset_fields = [field for field in doc.__fields__ if field not in set_fields] - for doc in data: - set_data = { - field: value - for field, value in doc - if field in doc.dict(exclude_unset=True) - } - unset_fields = [field for field in doc.__fields__ if field not in set_data] + data_model = create_model( + "MPDataDoc", + fields_not_requested=unset_fields, + __base__=self.document_model, + ) - data_model = create_model( - "MPDataDoc", - fields_not_requested=unset_fields, - __base__=self.document_model, + data_model.__fields__ = { + **{ + name: description + for name, description in data_model.__fields__.items() + if name in set_fields + }, + "fields_not_requested": data_model.__fields__["fields_not_requested"], + } + + def new_repr(self) -> str: + extra = ",\n".join( + f"\033[1m{n}\033[0;0m={getattr(self, n)!r}" + for n in data_model.__fields__ ) - data_model.__fields__ = { - **{ - name: description - for name, description in data_model.__fields__.items() - if name in set_data - }, - "fields_not_requested": data_model.__fields__["fields_not_requested"], - } - - def new_repr(self) -> str: - extra = ",\n".join( - f"\033[1m{n}\033[0;0m={getattr(self, n)!r}" - for n in data_model.__fields__ - ) - - s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501 - return s + s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m(\n{extra}\n)" # noqa: E501 + return s - def new_str(self) -> str: - extra = ",\n".join( - f"\033[1m{n}\033[0;0m={getattr(self, n)!r}" - for n in data_model.__fields__ - if n != "fields_not_requested" - ) + def new_str(self) -> str: + extra = ",\n".join( + f"\033[1m{n}\033[0;0m={getattr(self, n)!r}" + for n in data_model.__fields__ + if n != "fields_not_requested" + ) - s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m\n{extra}\n\n\033[1mFields not requested:\033[0;0m\n{unset_fields}" # noqa: E501 - return s + s = f"\033[4m\033[1m{self.__class__.__name__}<{self.__class__.__base__.__name__}>\033[0;0m\033[0;0m\n{extra}\n\n\033[1mFields not requested:\033[0;0m\n{unset_fields}" # noqa: E501 + return s - def new_getattr(self, attr) -> str: - if attr in unset_fields: - raise AttributeError( - f"'{attr}' data is available but has not been requested in 'fields'." - " A full list of unrequested fields can be found in `fields_not_requested`." - ) - else: - raise AttributeError( - f"{self.__class__.__name__!r} object has no attribute {attr!r}" - ) - - data_model.__repr__ = new_repr - data_model.__str__ = new_str - data_model.__getattr__ = new_getattr + def new_getattr(self, attr) -> str: + if attr in unset_fields: + raise AttributeError( + f"'{attr}' data is available but has not been requested in 'fields'." + " A full list of unrequested fields can be found in `fields_not_requested`." + ) + else: + raise AttributeError( + f"{self.__class__.__name__!r} object has no attribute {attr!r}" + ) - new_data.append(data_model(**set_data)) + data_model.__repr__ = new_repr + data_model.__str__ = new_str + data_model.__getattr__ = new_getattr - return new_data + return data_model, set_fields, unset_fields def _query_resource_data( self,