From fd2034bd1d038bd3d62229a1ac80e7071936aece Mon Sep 17 00:00:00 2001 From: Sebastiaan Huber Date: Mon, 5 Mar 2018 19:18:56 +0100 Subject: [PATCH] Implement workaround for single output format limitation of MPDS API The MPDS API only supports a single output format for a query. Either the result of the query is returned as a 'json' object or the result is a concatenation of string formatted cif files. To import CifData nodes, however, we need both the json object to retrieve the required source information, but we also need the raw cif file, as we do not want to reconstruct the structure or cif ourselves from the basic structural data that is provided in the json. As a workaround, we fire the exact same query twice, once asking the result in the json format and the other in the cif format. We add the cif string to the json result entries by cross referencing the source id that is present in both the json entry and the raw cif string. A special MpdsCifEntry will then use that cif string to directly set the contents attribute. This will then prevent a separate HTTP request to the source uri to retrieve the cif content, which would also result in a 429 HTTP error due to too many requests being fired --- aiida/tools/dbimporters/plugins/mpds.py | 154 ++++++++++++++++++------ mpds.py | 44 ------- 2 files changed, 117 insertions(+), 81 deletions(-) delete mode 100755 mpds.py diff --git a/aiida/tools/dbimporters/plugins/mpds.py b/aiida/tools/dbimporters/plugins/mpds.py index 2d07fcb61f..45631d1890 100644 --- a/aiida/tools/dbimporters/plugins/mpds.py +++ b/aiida/tools/dbimporters/plugins/mpds.py @@ -8,6 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### +import copy +import enum import json import os import requests @@ -15,12 +17,22 @@ from aiida.tools.dbimporters.baseclasses import CifEntry, DbEntry, DbImporter, DbSearchResults + +class ApiFormat(enum.Enum): + JSON = 'json' + CIF = 'cif' + + +DEFAULT_API_FORMAT = ApiFormat.JSON +CIF_ENTRY_ID_TAG = '_pauling_file_entry' + + class MpdsDbImporter(DbImporter): """ Database importer for the Materials Platform for Data Science (MPDS) """ - _url = "https://api.mpds.io/v0/download/facet" + _url = 'https://api.mpds.io/v0/download/facet' _api_key = None _collection = 'structures' _pagesize = 1000 @@ -125,18 +137,40 @@ def query(self, query, collection=None): if collection is None: collection = self.collection - results = [] - if collection == 'structures': - for entry in self.structures.find(query): - results.append(entry) + + results = [] + results_cif = {} + results_json = [] + + for entry in self.structures.find(query, fmt=ApiFormat.JSON): + results_json.append(entry) + + for entry in self.structures.find(query, fmt=ApiFormat.CIF): + entry_id = self.get_id_from_cif(entry) + results_cif[entry_id] = entry + + for entry in results_json: + + entry_id = entry['entry'] + + try: + cif = results_cif[entry_id] + except KeyError: + # Corresponding cif file was not retrieved, skipping + continue + + result_entry = copy.deepcopy(entry) + result_entry['cif'] = cif + results.append(result_entry) + search_results = MpdsSearchResults(results, return_class=MpdsCifEntry) else: raise ValueError('Unsupported collection: {}'.format(collection)) return search_results - def find(self, query): + def find(self, query, fmt=DEFAULT_API_FORMAT): """ Query the database with a given dictionary of query parameters @@ -147,29 +181,49 @@ def find(self, query): pagesize = self.pagesize - response = self.get(q=json.dumps(query), pagesize=pagesize) - content = self.get_response_content(response) + response = self.get(q=json.dumps(query), fmt=ApiFormat.JSON, pagesize=pagesize) + content = self.get_response_content(response, fmt=ApiFormat.JSON) count = content['count'] npages = content['npages'] for page in range(0, npages): - response = self.get(q=json.dumps(query), pagesize=pagesize, page=page) - content = self.get_response_content(response) + response = self.get(q=json.dumps(query), fmt=fmt, pagesize=pagesize, page=page) + content = self.get_response_content(response, fmt=fmt) - if (page + 1) * pagesize > count: - last = count - (page * pagesize) - else: - last = pagesize + if fmt == ApiFormat.JSON: - for i in range(0, last): - result = content['out'][i] - result['license'] = content['disclaimer'] + if (page + 1) * pagesize > count: + last = count - (page * pagesize) + else: + last = pagesize - yield result + for i in range(0, last): + result = content['out'][i] + result['license'] = content['disclaimer'] - def get(self, fmt='json', **kwargs): + yield result + + elif fmt == ApiFormat.CIF: + + lines = content.splitlines() + cif = [] + for line in lines: + if cif: + if line.startswith('data_'): + text = '\n'.join(cif) + cif = [line] + yield text + else: + cif.append(line) + else: + if line.startswith('data_'): + cif.append(line) + if cif: + yield '\n'.join(cif) + + def get(self, fmt=DEFAULT_API_FORMAT, **kwargs): """ Perform a GET request to the REST API using the kwargs as request parameters The url and API key will be used that were set upon construction @@ -177,10 +231,10 @@ def get(self, fmt='json', **kwargs): :param fmt: the format of the response, 'cif' or json' (default) :param kwargs: parameters for the GET request """ - kwargs['fmt'] = fmt + kwargs['fmt'] = fmt.value return requests.get(url=self.url, params=kwargs, headers={'Key': self.api_key}) - def get_response_content(self, response): + def get_response_content(self, response, fmt=DEFAULT_API_FORMAT): """ Analyze the response of an HTTP GET request, verify that the response code is OK and return the json loaded response text @@ -189,16 +243,35 @@ def get_response_content(self, response): :raises RuntimeError: HTTP response is not 200 :raises ValueError: HTTP response 200 contained non zero error message """ - content = response.json() - error = content.get('error', None) - if not response.ok: - raise RuntimeError('HTTP[{}] request failed: {}'.format(response.status_code, error)) + raise RuntimeError('HTTP[{}] request failed: {}'.format(response.status_code, response.text)) + + if fmt == ApiFormat.JSON: + content = response.json() + error = content.get('error', None) + + if error is not None: + raise ValueError('Got error response: {}'.format(error)) + + return content + else: + return response.text + + def get_id_from_cif(self, cif): + """ + Extract the entry id from the string formatted cif response of the MPDS API - if error is not None: - raise ValueError('Got error response: {}'.format(error)) + :param cif: string representation of the cif file + :returns: entry id of the cif file or None if could not be found + """ + entry_id = None - return content + for line in cif.split('\n'): + if CIF_ENTRY_ID_TAG in line: + entry_id = line.split()[1] + break + + return entry_id class StructuresCollection(object): @@ -213,15 +286,15 @@ def engine(self): """ return self._engine - def find(self, query): + def find(self, query, fmt=DEFAULT_API_FORMAT): """ Query the structures collection with a given dictionary of query parameters :param query: a dictionary with the query parameters """ - for result in self.engine.find(query): + for result in self.engine.find(query, fmt=fmt): - if 'object_type' not in result or result['object_type'] != 'S': + if fmt != ApiFormat.CIF and ('object_type' not in result or result['object_type'] != 'S'): continue yield result @@ -252,18 +325,23 @@ class MpdsCifEntry(CifEntry, MpdsEntry): def __init__(self, url, **kwargs): """ - Overwrite the permanent 'reference' URI with a URI that points to the CIF contents + The DbSearchResults base class instantiates a new DbEntry by explicitly passing the url + of the entry as an argument. In this case it is the same as the 'uri' value that is + already contained in the source dictionary so we just copy it """ + cif = kwargs.pop('cif', None) kwargs['uri'] = url super(MpdsCifEntry, self).__init__(url, **kwargs) + if cif is not None: + self.cif = cif + class MpdsSearchResults(DbSearchResults): """ A collection of MpdsEntry query result entries """ - _base_url ='https://api.mpds.io/v0/download/s' _db_name = 'Materials Platform for Data Science' _db_uri = 'https://mpds.io/' _return_class = MpdsEntry @@ -275,7 +353,7 @@ def __init__(self, results, return_class=None): def _get_source_dict(self, result_dict): """ - Returns the source information dictionary of an MPDS query result entry + Return the source information dictionary of an MPDS query result entry :param result_dict: query result entry dictionary """ @@ -288,13 +366,15 @@ def _get_source_dict(self, result_dict): 'version': result_dict['version'], } + if 'cif' in result_dict: + source_dict['cif'] = result_dict['cif'] + return source_dict def _get_url(self, result_dict): """ - Return the URL that points to the raw CIF content of the entry + Return the permanent URI of the result entry :param result_dict: query result entry dictionary """ - url = '{}?q={}&fmt=cif&export=1'.format(self._base_url, result_dict['entry']) - return url + return result_dict['reference'] diff --git a/mpds.py b/mpds.py deleted file mode 100755 index d4759f2629..0000000000 --- a/mpds.py +++ /dev/null @@ -1,44 +0,0 @@ -#!/usr/bin/env runaiida - -import json -import sys - -def main(): - from aiida.tools.dbimporters import DbImporterFactory - - database = 'mpds' - importer_parameters = {} - query_parameters = { - 'query': { - 'elements': 'Ti', - 'classes': 'binary', - 'props': 'atomic structure', - }, - 'collection': 'structures' - } - - importer_class = DbImporterFactory(database) - importer = importer_class(**importer_parameters) - - try: - query_results = importer.query(**query_parameters) - except BaseException as exception: - print(exception) - sys.exit(1) - - count = 0 - limit = 10 - print len(query_results) - return - - for entry in query_results: - cif = entry.get_cif_node() - cif.store() - print cif.pk - count += 1 - if count > limit: - return - - -if __name__ == '__main__': - main() \ No newline at end of file