diff --git a/aiida/tools/dbimporters/plugins/mpds.py b/aiida/tools/dbimporters/plugins/mpds.py index 2d07fcb61f..45631d1890 100644 --- a/aiida/tools/dbimporters/plugins/mpds.py +++ b/aiida/tools/dbimporters/plugins/mpds.py @@ -8,6 +8,8 @@ # For further information please visit http://www.aiida.net # ########################################################################### +import copy +import enum import json import os import requests @@ -15,12 +17,22 @@ from aiida.tools.dbimporters.baseclasses import CifEntry, DbEntry, DbImporter, DbSearchResults + +class ApiFormat(enum.Enum): + JSON = 'json' + CIF = 'cif' + + +DEFAULT_API_FORMAT = ApiFormat.JSON +CIF_ENTRY_ID_TAG = '_pauling_file_entry' + + class MpdsDbImporter(DbImporter): """ Database importer for the Materials Platform for Data Science (MPDS) """ - _url = "https://api.mpds.io/v0/download/facet" + _url = 'https://api.mpds.io/v0/download/facet' _api_key = None _collection = 'structures' _pagesize = 1000 @@ -125,18 +137,40 @@ def query(self, query, collection=None): if collection is None: collection = self.collection - results = [] - if collection == 'structures': - for entry in self.structures.find(query): - results.append(entry) + + results = [] + results_cif = {} + results_json = [] + + for entry in self.structures.find(query, fmt=ApiFormat.JSON): + results_json.append(entry) + + for entry in self.structures.find(query, fmt=ApiFormat.CIF): + entry_id = self.get_id_from_cif(entry) + results_cif[entry_id] = entry + + for entry in results_json: + + entry_id = entry['entry'] + + try: + cif = results_cif[entry_id] + except KeyError: + # Corresponding cif file was not retrieved, skipping + continue + + result_entry = copy.deepcopy(entry) + result_entry['cif'] = cif + results.append(result_entry) + search_results = MpdsSearchResults(results, return_class=MpdsCifEntry) else: raise ValueError('Unsupported collection: {}'.format(collection)) return search_results - def find(self, query): + def find(self, query, fmt=DEFAULT_API_FORMAT): """ Query the database with a given dictionary of query parameters @@ -147,29 +181,49 @@ def find(self, query): pagesize = self.pagesize - response = self.get(q=json.dumps(query), pagesize=pagesize) - content = self.get_response_content(response) + response = self.get(q=json.dumps(query), fmt=ApiFormat.JSON, pagesize=pagesize) + content = self.get_response_content(response, fmt=ApiFormat.JSON) count = content['count'] npages = content['npages'] for page in range(0, npages): - response = self.get(q=json.dumps(query), pagesize=pagesize, page=page) - content = self.get_response_content(response) + response = self.get(q=json.dumps(query), fmt=fmt, pagesize=pagesize, page=page) + content = self.get_response_content(response, fmt=fmt) - if (page + 1) * pagesize > count: - last = count - (page * pagesize) - else: - last = pagesize + if fmt == ApiFormat.JSON: - for i in range(0, last): - result = content['out'][i] - result['license'] = content['disclaimer'] + if (page + 1) * pagesize > count: + last = count - (page * pagesize) + else: + last = pagesize - yield result + for i in range(0, last): + result = content['out'][i] + result['license'] = content['disclaimer'] - def get(self, fmt='json', **kwargs): + yield result + + elif fmt == ApiFormat.CIF: + + lines = content.splitlines() + cif = [] + for line in lines: + if cif: + if line.startswith('data_'): + text = '\n'.join(cif) + cif = [line] + yield text + else: + cif.append(line) + else: + if line.startswith('data_'): + cif.append(line) + if cif: + yield '\n'.join(cif) + + def get(self, fmt=DEFAULT_API_FORMAT, **kwargs): """ Perform a GET request to the REST API using the kwargs as request parameters The url and API key will be used that were set upon construction @@ -177,10 +231,10 @@ def get(self, fmt='json', **kwargs): :param fmt: the format of the response, 'cif' or json' (default) :param kwargs: parameters for the GET request """ - kwargs['fmt'] = fmt + kwargs['fmt'] = fmt.value return requests.get(url=self.url, params=kwargs, headers={'Key': self.api_key}) - def get_response_content(self, response): + def get_response_content(self, response, fmt=DEFAULT_API_FORMAT): """ Analyze the response of an HTTP GET request, verify that the response code is OK and return the json loaded response text @@ -189,16 +243,35 @@ def get_response_content(self, response): :raises RuntimeError: HTTP response is not 200 :raises ValueError: HTTP response 200 contained non zero error message """ - content = response.json() - error = content.get('error', None) - if not response.ok: - raise RuntimeError('HTTP[{}] request failed: {}'.format(response.status_code, error)) + raise RuntimeError('HTTP[{}] request failed: {}'.format(response.status_code, response.text)) + + if fmt == ApiFormat.JSON: + content = response.json() + error = content.get('error', None) + + if error is not None: + raise ValueError('Got error response: {}'.format(error)) + + return content + else: + return response.text + + def get_id_from_cif(self, cif): + """ + Extract the entry id from the string formatted cif response of the MPDS API - if error is not None: - raise ValueError('Got error response: {}'.format(error)) + :param cif: string representation of the cif file + :returns: entry id of the cif file or None if could not be found + """ + entry_id = None - return content + for line in cif.split('\n'): + if CIF_ENTRY_ID_TAG in line: + entry_id = line.split()[1] + break + + return entry_id class StructuresCollection(object): @@ -213,15 +286,15 @@ def engine(self): """ return self._engine - def find(self, query): + def find(self, query, fmt=DEFAULT_API_FORMAT): """ Query the structures collection with a given dictionary of query parameters :param query: a dictionary with the query parameters """ - for result in self.engine.find(query): + for result in self.engine.find(query, fmt=fmt): - if 'object_type' not in result or result['object_type'] != 'S': + if fmt != ApiFormat.CIF and ('object_type' not in result or result['object_type'] != 'S'): continue yield result @@ -252,18 +325,23 @@ class MpdsCifEntry(CifEntry, MpdsEntry): def __init__(self, url, **kwargs): """ - Overwrite the permanent 'reference' URI with a URI that points to the CIF contents + The DbSearchResults base class instantiates a new DbEntry by explicitly passing the url + of the entry as an argument. In this case it is the same as the 'uri' value that is + already contained in the source dictionary so we just copy it """ + cif = kwargs.pop('cif', None) kwargs['uri'] = url super(MpdsCifEntry, self).__init__(url, **kwargs) + if cif is not None: + self.cif = cif + class MpdsSearchResults(DbSearchResults): """ A collection of MpdsEntry query result entries """ - _base_url ='https://api.mpds.io/v0/download/s' _db_name = 'Materials Platform for Data Science' _db_uri = 'https://mpds.io/' _return_class = MpdsEntry @@ -275,7 +353,7 @@ def __init__(self, results, return_class=None): def _get_source_dict(self, result_dict): """ - Returns the source information dictionary of an MPDS query result entry + Return the source information dictionary of an MPDS query result entry :param result_dict: query result entry dictionary """ @@ -288,13 +366,15 @@ def _get_source_dict(self, result_dict): 'version': result_dict['version'], } + if 'cif' in result_dict: + source_dict['cif'] = result_dict['cif'] + return source_dict def _get_url(self, result_dict): """ - Return the URL that points to the raw CIF content of the entry + Return the permanent URI of the result entry :param result_dict: query result entry dictionary """ - url = '{}?q={}&fmt=cif&export=1'.format(self._base_url, result_dict['entry']) - return url + return result_dict['reference'] diff --git a/mpds.py b/mpds.py deleted file mode 100755 index d4759f2629..0000000000 --- a/mpds.py +++ /dev/null @@ -1,44 +0,0 @@ -#!/usr/bin/env runaiida - -import json -import sys - -def main(): - from aiida.tools.dbimporters import DbImporterFactory - - database = 'mpds' - importer_parameters = {} - query_parameters = { - 'query': { - 'elements': 'Ti', - 'classes': 'binary', - 'props': 'atomic structure', - }, - 'collection': 'structures' - } - - importer_class = DbImporterFactory(database) - importer = importer_class(**importer_parameters) - - try: - query_results = importer.query(**query_parameters) - except BaseException as exception: - print(exception) - sys.exit(1) - - count = 0 - limit = 10 - print len(query_results) - return - - for entry in query_results: - cif = entry.get_cif_node() - cif.store() - print cif.pk - count += 1 - if count > limit: - return - - -if __name__ == '__main__': - main() \ No newline at end of file