diff --git a/app.py b/app.py index 62b150eb..749d08e1 100644 --- a/app.py +++ b/app.py @@ -3,6 +3,7 @@ from mp_api.core.api import MAPI from mp_api.core.settings import MAPISettings +from maggma.stores import MongoURIStore, S3Store resources = {} @@ -13,68 +14,7 @@ db_suffix = os.environ.get("MAPI_DB_NAME_SUFFIX", db_version) debug = default_settings.DEBUG -materials_store_json = os.environ.get("MATERIALS_STORE", "materials_store.json") -bonds_store_json = os.environ.get("BONDS_STORE", "bonds_store.json") -formula_autocomplete_store_json = os.environ.get( - "FORMULA_AUTOCOMPLETE_STORE", "formula_autocomplete_store.json" -) -task_store_json = os.environ.get("TASK_STORE", "task_store.json") -thermo_store_json = os.environ.get("THERMO_STORE", "thermo_store.json") -phase_diagram_store_json = os.environ.get( - "PHASE_DIAGRAM_STORE", "phase_diagram_store.json" -) -dielectric_store_json = os.environ.get("DIELECTRIC_STORE", "dielectric_store.json") -piezoelectric_store_json = os.environ.get( - "PIEZOELECTRIC_STORE", "piezoelectric_store.json" -) -magnetism_store_json = os.environ.get("MAGNETISM_STORE", "magnetism_store.json") -phonon_bs_store_json = os.environ.get("PHONON_BS_STORE", "phonon_bs_store.json") -eos_store_json = os.environ.get("EOS_STORE", "eos_store.json") -similarity_store_json = os.environ.get("SIMILARITY_STORE", "similarity_store.json") -xas_store_json = os.environ.get("XAS_STORE", "xas_store.json") -gb_store_json = os.environ.get("GB_STORE", "xas_store.json") -fermi_store_json = os.environ.get("FERMI_STORE", "fermi_store.json") -elasticity_store_json = os.environ.get("ELASTICITY_STORE", "elasticity_store.json") -doi_store_json = os.environ.get("DOI_STORE", "doi_store.json") -substrates_store_json = os.environ.get("SUBSTRATES_STORE", "substrates_store.json") -surface_props_store_json = os.environ.get( - "SURFACE_PROPS_STORE", "surface_props_store.json" -) -robocrys_store_json = os.environ.get("ROBOCRYS_STORE", "robocrys_store.json") -synth_store_json = os.environ.get("SYNTH_STORE", "synth_store.json") -insertion_electrodes_store_json = os.environ.get( - "INSERTION_ELECTRODES_STORE", "insertion_electrodes_store.json" -) -molecules_store_json = os.environ.get("MOLECULES_STORE", "molecules_store.json") -oxi_states_store_json = os.environ.get("OXI_STATES_STORE", "oxi_states_store.json") -provenance_store_json = os.environ.get("PROVENANCE_STORE", "provenance_store.json") -summary_store_json = os.environ.get("SUMMARY_STORE", "summary_store.json") - -es_store_json = os.environ.get("ES_STORE", "es_store.json") - -bs_store_json = os.environ.get("BS_STORE", "bs_store.json") -dos_store_json = os.environ.get("DOS_STORE", "dos_store.json") - -s3_bs_index_json = os.environ.get("S3_BS_INDEX_STORE", "s3_bs_index.json") -s3_dos_index_json = os.environ.get("S3_DOS_INDEX_STORE", "s3_dos_index.json") - -s3_bs_json = os.environ.get("S3_BS_STORE", "s3_bs.json") -s3_dos_json = os.environ.get("S3_DOS_STORE", "s3_dos.json") - -s3_chgcar_index_json = os.environ.get("CHGCAR_INDEX_STORE", "chgcar_index_store.json") -s3_chgcar_json = os.environ.get("S3_CHGCAR_STORE", "s3_chgcar.json") - -mpcomplete_store_json = os.environ.get("MPCOMPLETE_STORE", "mpcomplete_store.json") - -consumer_settings_store_json = os.environ.get( - "CONSUMER_SETTINGS_STORE", "consumer_settings_store.json" -) - -general_store_json = os.environ.get("GENERAL_STORE_STORE", "general_store_store.json") - - if db_uri: - from maggma.stores import MongoURIStore, S3Store materials_store = MongoURIStore( uri=f"mongodb+srv://{db_uri}", @@ -331,48 +271,8 @@ key="submission_id", collection_name="general_store", ) - else: - materials_store = loadfn(materials_store_json) - bonds_store = loadfn(bonds_store_json) - formula_autocomplete_store = loadfn(formula_autocomplete_store_json) - task_store = loadfn(task_store_json) - thermo_store = loadfn(thermo_store_json) - phase_diagram_store = loadfn(phase_diagram_store_json) - dielectric_store = loadfn(dielectric_store_json) - piezoelectric_store = loadfn(piezoelectric_store_json) - magnetism_store = loadfn(magnetism_store_json) - phonon_bs_store = loadfn(phonon_bs_store_json) - eos_store = loadfn(eos_store_json) - similarity_store = loadfn(similarity_store_json) - xas_store = loadfn(xas_store_json) - gb_store = loadfn(gb_store_json) - fermi_store = loadfn(fermi_store_json) - elasticity_store = loadfn(elasticity_store_json) - doi_store = loadfn(doi_store_json) - substrates_store = loadfn(substrates_store_json) - surface_props_store = loadfn(surface_props_store_json) - robo_store = loadfn(robocrys_store_json) - synth_store = loadfn(synth_store_json) - insertion_electrodes_store = loadfn(insertion_electrodes_store_json) - molecules_store = loadfn(molecules_store_json) - oxi_states_store = loadfn(oxi_states_store_json) - provenance_store = loadfn(provenance_store_json) - summary_store = loadfn(summary_store_json) - - es_store = loadfn(es_store_json) - - s3_bs_index = loadfn(s3_bs_index_json) - s3_dos_index = loadfn(s3_dos_index_json) - s3_bs = loadfn(s3_bs_json) - s3_dos = loadfn(s3_dos_json) - - s3_chgcar_index = loadfn(s3_chgcar_index_json) - s3_chgcar = loadfn(s3_chgcar_json) - - mpcomplete_store = loadfn(mpcomplete_store_json) - consumer_settings_store = loadfn(consumer_settings_store_json) - general_store = loadfn(general_store_json) + raise RuntimeError("Must specify MongoDB Atlas URI") # Materials from mp_api.routes.materials.resources import ( diff --git a/src/mp_api/client.py b/src/mp_api/client.py index c96401f2..34bc63bc 100644 --- a/src/mp_api/client.py +++ b/src/mp_api/client.py @@ -14,10 +14,10 @@ from pymatgen.analysis.magnetism import Ordering from pymatgen.analysis.phase_diagram import PhaseDiagram from pymatgen.analysis.pourbaix_diagram import IonEntry -from pymatgen.analysis.wulff import WulffShape from pymatgen.core import Element, Structure from pymatgen.core.ion import Ion from pymatgen.io.vasp import Chgcar +from pymatgen.entries.computed_entries import ComputedEntry from pymatgen.symmetry.analyzer import SpacegroupAnalyzer from mp_api.core.client import BaseRester, MPRestError @@ -122,7 +122,11 @@ def __init__( self.session = BaseRester._create_session( api_key=api_key, include_user_agent=include_user_agent ) - self.contribs = Client(api_key) + + try: + self.contribs = Client(api_key) + except Exception: + self.contribs = None self._all_resters = [] @@ -180,7 +184,7 @@ def get_task_ids_associated_with_material_id( return list(tasks.values()) def get_structure_by_material_id( - self, material_id, final=True, conventional_unit_cell=False + self, material_id: str, final: bool = True, conventional_unit_cell: bool = False ) -> Union[Structure, List[Structure]]: """ Get a Structure corresponding to a material_id. @@ -232,10 +236,10 @@ def get_database_version(self): "db_version" ] - def get_materials_id_from_task_id(self, task_id): + def get_materials_id_from_task_id(self, task_id: str) -> Union[str, None]: """ Returns the current material_id from a given task_id. The - materials_id should rarely change, and is usually chosen from + material_id should rarely change, and is usually chosen from among the smallest numerical id from the group of task_ids for that material. However, in some circumstances it might change, and this method is useful for finding the new material_id. @@ -244,11 +248,11 @@ def get_materials_id_from_task_id(self, task_id): task_id (str): A task id. Returns: - materials_id (MPID) + material_id (MPID) """ docs = self.materials.search(task_ids=[task_id], fields=["material_id"]) if len(docs) == 1: # pragma: no cover - return str(docs[0].material_id) + return str(docs[0].material_id) # type: ignore elif len(docs) > 1: # pragma: no cover raise ValueError( f"Multiple documents return for {task_id}, this should not happen, please report it!" @@ -259,7 +263,7 @@ def get_materials_id_from_task_id(self, task_id): ) return None - def get_materials_id_references(self, material_id): + def get_materials_id_references(self, material_id: str) -> List[str]: """ Returns all references for a materials id. @@ -267,61 +271,65 @@ def get_materials_id_references(self, material_id): material_id (str): A material id. Returns: - BibTeX (str) + List of BibTeX references ([str]) """ return self.provenance.get_data_by_id(material_id).references - def get_materials_ids(self, chemsys_formula): + def get_materials_ids(self, chemsys_formula: str,) -> List[MPID]: """ Get all materials ids for a formula or chemsys. Args: - chemsys_formula (str): A chemical system (e.g., Li-Fe-O), - or formula (e.g., Fe2O3). + chemsys_formula (str): A chemical system (e.g., Li-Fe-O, Si-*), + or formula (e.g., Fe2O3, Si*). Returns: - ([MPID]) List of all materials ids. + List of all materials ids ([MPID]) """ + + if "-" in chemsys_formula: + input_params = {"chemsys": chemsys_formula} + else: + input_params = {"formula": chemsys_formula} + return sorted( doc.material_id for doc in self.materials.search_material_docs( - chemsys_formula=chemsys_formula, - all_fields=False, - fields=["material_id"], + **input_params, all_fields=False, fields=["material_id"], # type: ignore ) ) - def get_structures(self, chemsys_formula, final=True): + def get_structures(self, chemsys_formula: str, final=True) -> List[Structure]: """ - Get a list of Structures corresponding to a chemical system, formula, - or materials_id. + Get a list of Structures corresponding to a chemical system or formula. Args: - chemsys_formula_id (str): A chemical system (e.g., Li-Fe-O), - or formula (e.g., Fe2O3). + chemsys_formula (str): A chemical system (e.g., Li-Fe-O, Si-*), + or formula (e.g., Fe2O3, Si*). final (bool): Whether to get the final structure, or the list of initial (pre-relaxation) structures. Defaults to True. Returns: - List of Structure objects. + List of Structure objects. ([Structure]) """ + if "-" in chemsys_formula: + input_params = {"chemsys": chemsys_formula} + else: + input_params = {"formula": chemsys_formula} + if final: return [ doc.structure for doc in self.materials.search_material_docs( - chemsys_formula=chemsys_formula, - all_fields=False, - fields=["structure"], + **input_params, all_fields=False, fields=["structure"], # type: ignore ) ] else: structures = [] for doc in self.materials.search_material_docs( - chemsys_formula=chemsys_formula, - all_fields=False, - fields=["initial_structures"], + **input_params, all_fields=False, fields=["initial_structures"], # type: ignore ): structures.extend(doc.initial_structures) @@ -329,11 +337,11 @@ def get_structures(self, chemsys_formula, final=True): def find_structure( self, - filename_or_structure, - ltol=_EMMET_SETTINGS.LTOL, - stol=_EMMET_SETTINGS.STOL, - angle_tol=_EMMET_SETTINGS.ANGLE_TOL, - allow_multiple_results=False, + filename_or_structure: Union[str, Structure], + ltol: float = _EMMET_SETTINGS.LTOL, + stol: float = _EMMET_SETTINGS.STOL, + angle_tol: float = _EMMET_SETTINGS.ANGLE_TOL, + allow_multiple_results: bool = False, ) -> Union[List[str], str]: """ Finds matching structures from the Materials Project database. @@ -366,15 +374,15 @@ def find_structure( ) def get_entries( - self, chemsys_formula, sort_by_e_above_hull=False, + self, chemsys_formula: str, sort_by_e_above_hull=False, ): """ Get a list of ComputedEntries or ComputedStructureEntries corresponding to a chemical system or formula. Args: - chemsys_formula (str): A chemical system - (e.g., Li-Fe-O), or formula (e.g., Fe2O3). + chemsys_formula (str): A chemical system (e.g., Li-Fe-O, Si-*), + or formula (e.g., Fe2O3, Si*). sort_by_e_above_hull (bool): Whether to sort the list of entries by e_above_hull in ascending order. @@ -382,12 +390,17 @@ def get_entries( List of ComputedEntry or ComputedStructureEntry objects. """ + if "-" in chemsys_formula: + input_params = {"chemsys": chemsys_formula} + else: + input_params = {"formula": chemsys_formula} + entries = [] if sort_by_e_above_hull: for doc in self.thermo.search_thermo_docs( - chemsys_formula=chemsys_formula, + **input_params, # type: ignore all_fields=False, fields=["entries"], sort_fields=["energy_above_hull"], @@ -398,7 +411,7 @@ def get_entries( else: for doc in self.thermo.search_thermo_docs( - chemsys_formula=chemsys_formula, all_fields=False, fields=["entries"], + **input_params, all_fields=False, fields=["entries"], # type: ignore ): entries.extend(list(doc.entries.values())) @@ -707,7 +720,7 @@ def get_ion_entries( return ion_entries - def get_entry_by_material_id(self, material_id): + def get_entry_by_material_id(self, material_id: str): """ Get all ComputedEntry objects corresponding to a material_id. @@ -724,7 +737,7 @@ def get_entry_by_material_id(self, material_id): ) def get_entries_in_chemsys( - self, elements, use_gibbs: Optional[int] = None, + self, elements: Union[str, List[str]], use_gibbs: Optional[int] = None, ): """ Helper method to get a list of ComputedEntries in a chemical system. @@ -752,10 +765,10 @@ def get_entries_in_chemsys( for els in itertools.combinations(elements, i + 1): all_chemsyses.append("-".join(sorted(els))) - entries = [] + entries = [] # type: List[ComputedEntry] for chemsys in all_chemsyses: - entries.extend(self.get_entries(chemsys_formula=chemsys)) + entries.extend(self.get_entries(chemsys)) if use_gibbs: # replace the entries with GibbsComputedStructureEntry @@ -800,7 +813,7 @@ def get_dos_by_material_id(self, material_id: str): material_id=material_id ) - def get_phonon_dos_by_material_id(self, material_id): + def get_phonon_dos_by_material_id(self, material_id: str): """ Get phonon density of states data corresponding to a material_id. @@ -813,7 +826,7 @@ def get_phonon_dos_by_material_id(self, material_id): """ return self.phonon.get_data_by_id(material_id, fields=["ph_dos"]).ph_dos - def get_phonon_bandstructure_by_material_id(self, material_id): + def get_phonon_bandstructure_by_material_id(self, material_id: str): """ Get phonon dispersion data corresponding to a material_id. @@ -828,7 +841,9 @@ def get_phonon_bandstructure_by_material_id(self, material_id): def query( self, material_ids: Optional[List[MPID]] = None, - chemsys_formula: Optional[str] = None, + formula: Optional[str] = None, + chemsys: Optional[str] = None, + elements: Optional[List[str]] = None, exclude_elements: Optional[List[str]] = None, possible_species: Optional[List[str]] = None, nsites: Optional[Tuple[int, int]] = None, @@ -887,9 +902,10 @@ def query( Arguments: material_ids (List[MPID]): List of Materials Project IDs to return data for. - chemsys_formula (str): A chemical system (e.g., Li-Fe-O), - or formula including anonomyzed formula + formula (str): A formula including anonomyzed formula or wild cards (e.g., Fe2O3, ABO3, Si*). + chemsys (str): A chemical system including wild cards (e.g., Li-Fe-O, Si-*, *-*). + elements (List[str]): A list of elements. exclude_elements (List(str)): List of elements to exclude. possible_species (List(str)): List of element symbols appended with oxidation states. (e.g. Cr2+,O2-) @@ -962,7 +978,9 @@ def query( return self.summary.search_summary_docs( # type: ignore material_ids=material_ids, - chemsys_formula=chemsys_formula, + formula=formula, + chemsys=chemsys, + elements=elements, exclude_elements=exclude_elements, possible_species=possible_species, nsites=nsites, @@ -1031,7 +1049,7 @@ def submit_structures(self, structures, public_name, public_email): # TODO: call new MPComplete endpoint raise NotImplementedError - def get_wulff_shape(self, material_id) -> WulffShape: + def get_wulff_shape(self, material_id: str): """ Constructs a Wulff shape for a material. diff --git a/src/mp_api/routes/elasticity/query_operators.py b/src/mp_api/routes/elasticity/query_operators.py index c51dc3a1..954b9790 100644 --- a/src/mp_api/routes/elasticity/query_operators.py +++ b/src/mp_api/routes/elasticity/query_operators.py @@ -150,7 +150,7 @@ def query( return {"criteria": crit} -class ChemsysQuery(QueryOperator): +class ElasticityChemsysQuery(QueryOperator): """ Method to generate a query on chemsys data """ diff --git a/src/mp_api/routes/elasticity/resources.py b/src/mp_api/routes/elasticity/resources.py index d7e7d5ad..5499582a 100644 --- a/src/mp_api/routes/elasticity/resources.py +++ b/src/mp_api/routes/elasticity/resources.py @@ -3,7 +3,7 @@ from maggma.api.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery from mp_api.routes.elasticity.query_operators import ( - ChemsysQuery, + ElasticityChemsysQuery, BulkModulusQuery, ShearModulusQuery, PoissonQuery, @@ -15,7 +15,7 @@ def elasticity_resource(elasticity_store): elasticity_store, ElasticityDoc, query_operators=[ - ChemsysQuery(), + ElasticityChemsysQuery(), BulkModulusQuery(), ShearModulusQuery(), PoissonQuery(), diff --git a/src/mp_api/routes/electrodes/client.py b/src/mp_api/routes/electrodes/client.py index 1a6a44ca..21d27430 100644 --- a/src/mp_api/routes/electrodes/client.py +++ b/src/mp_api/routes/electrodes/client.py @@ -11,10 +11,12 @@ class ElectrodeRester(BaseRester[InsertionElectrodeDoc]): document_model = InsertionElectrodeDoc # type: ignore primary_key = "battery_id" - # TODO: This requires a model fix to function properly def search_electrode_docs( # pragma: ignore self, working_ion: Optional[Element] = None, + formula: Optional[str] = None, + elements: Optional[List[str]] = None, + exclude_elements: Optional[List[str]] = None, max_delta_volume: Optional[Tuple[float, float]] = None, average_voltage: Optional[Tuple[float, float]] = None, capacity_grav: Optional[Tuple[float, float]] = None, @@ -38,6 +40,9 @@ def search_electrode_docs( # pragma: ignore Arguments: working_ion (Element): Element of the working ion. + formula (str): Chemical formula of the framework material. + elements (List[str]): A list of elements for the framework material. + exclude_elements (List[str]): A list of elements to exclude for the framework material. max_delta_volume (Tuple[float,float]): Minimum and maximum value of the max volume change in percent for a particular voltage step. average_voltage (Tuple[float,float]): Minimum and maximum value of the average voltage for a particular @@ -71,6 +76,15 @@ def search_electrode_docs( # pragma: ignore if working_ion: query_params.update({"working_ion": str(working_ion)}) + if formula: + query_params.update({"formula": formula}) + + if elements: + query_params.update({"elements": ",".join(elements)}) + + if exclude_elements: + query_params.update({"exclude_elements": ",".join(exclude_elements)}) + if sort_fields: query_params.update( {"sort_fields": ",".join([s.strip() for s in sort_fields])} diff --git a/src/mp_api/routes/electronic_structure/client.py b/src/mp_api/routes/electronic_structure/client.py index ba7e1e56..17c225c1 100644 --- a/src/mp_api/routes/electronic_structure/client.py +++ b/src/mp_api/routes/electronic_structure/client.py @@ -16,6 +16,10 @@ class ElectronicStructureRester(BaseRester[ElectronicStructureDoc]): def search_electronic_structure_docs( self, + formula: Optional[str] = None, + chemsys: Optional[str] = None, + elements: Optional[List[str]] = None, + exclude_elements: Optional[List[str]] = None, band_gap: Optional[Tuple[float, float]] = None, efermi: Optional[Tuple[float, float]] = None, magnetic_ordering: Optional[Ordering] = None, @@ -31,6 +35,11 @@ def search_electronic_structure_docs( Query electronic structure docs using a variety of search criteria. Arguments: + formula (str): A formula including anonomyzed formula + or wild cards (e.g., Fe2O3, ABO3, Si*). + chemsys (str): A chemical system including wild cards (e.g., Li-Fe-O, Si-*, *-*). + elements (List[str]): A list of elements. + exclude_elements (List[str]): A list of elements to exclude. band_gap (Tuple[float,float]): Minimum and maximum band gap in eV to consider. efermi (Tuple[float,float]): Minimum and maximum fermi energy in eV to consider. magnetic_ordering (Ordering): Magnetic ordering of the material. @@ -49,6 +58,18 @@ def search_electronic_structure_docs( query_params = defaultdict(dict) # type: dict + if formula: + query_params.update({"formula": formula}) + + if chemsys: + query_params.update({"chemsys": chemsys}) + + if elements: + query_params.update({"elements": ",".join(elements)}) + + if exclude_elements: + query_params.update({"exclude_elements": ",".join(exclude_elements)}) + if band_gap: query_params.update( {"band_gap_min": band_gap[0], "band_gap_max": band_gap[1]} diff --git a/src/mp_api/routes/electronic_structure/resources.py b/src/mp_api/routes/electronic_structure/resources.py index 7b3916a6..f69310a2 100644 --- a/src/mp_api/routes/electronic_structure/resources.py +++ b/src/mp_api/routes/electronic_structure/resources.py @@ -6,6 +6,7 @@ from mp_api.routes.materials.query_operators import ( ElementsQuery, FormulaQuery, + ChemsysQuery, DeprecationQuery, ) @@ -25,6 +26,7 @@ def es_resource(es_store): query_operators=[ ESSummaryDataQuery(), FormulaQuery(), + ChemsysQuery(), ElementsQuery(), NumericQuery(model=ElectronicStructureDoc), DeprecationQuery(), diff --git a/src/mp_api/routes/materials/client.py b/src/mp_api/routes/materials/client.py index 897859d1..2c9bf3a3 100644 --- a/src/mp_api/routes/materials/client.py +++ b/src/mp_api/routes/materials/client.py @@ -42,7 +42,10 @@ def get_structure_by_material_id( def search_material_docs( self, - chemsys_formula: Optional[str] = None, + formula: Optional[str] = None, + chemsys: Optional[str] = None, + elements: Optional[List[str]] = None, + exclude_elements: Optional[List[str]] = None, task_ids: Optional[List[str]] = None, crystal_system: Optional[CrystalSystem] = None, spacegroup_number: Optional[int] = None, @@ -61,9 +64,11 @@ def search_material_docs( Query core material docs using a variety of search criteria. Arguments: - chemsys_formula (str): A chemical system (e.g., Li-Fe-O), - or formula including anonomyzed formula + formula (str): A formula including anonomyzed formula or wild cards (e.g., Fe2O3, ABO3, Si*). + chemsys (str): A chemical system including wild cards (e.g., Li-Fe-O, Si-*, *-*). + elements (List[str]): A list of elements. + exclude_elements (List[str]): A list of elements to exclude. task_ids (List[str]): List of Materials Project IDs to return data for. crystal_system (CrystalSystem): Crystal system of material. spacegroup_number (int): Space group number of material. @@ -85,8 +90,17 @@ def search_material_docs( query_params = {"deprecated": deprecated} # type: dict - if chemsys_formula: - query_params.update({"formula": chemsys_formula}) + if formula: + query_params.update({"formula": formula}) + + if chemsys: + query_params.update({"chemsys": chemsys}) + + if elements: + query_params.update({"elements": ",".join(elements)}) + + if exclude_elements: + query_params.update({"exclude_elements": ",".join(exclude_elements)}) if task_ids: query_params.update({"task_ids": ",".join(task_ids)}) diff --git a/src/mp_api/routes/materials/query_operators.py b/src/mp_api/routes/materials/query_operators.py index e98c515b..4a1a4db7 100644 --- a/src/mp_api/routes/materials/query_operators.py +++ b/src/mp_api/routes/materials/query_operators.py @@ -5,7 +5,7 @@ from fastapi import Body, HTTPException, Query from maggma.api.query_operator import QueryOperator from maggma.api.utils import STORE_PARAMS -from mp_api.routes.materials.utils import formula_to_criteria +from mp_api.routes.materials.utils import formula_to_criteria, chemsys_to_criteria from pymatgen.analysis.structure_matcher import ElementComparator, StructureMatcher from pymatgen.core.composition import Composition, CompositionError from pymatgen.core.periodic_table import Element @@ -34,7 +34,32 @@ def query( return {"criteria": crit} def ensure_indexes(self): # pragma: no cover - keys = ["chemsys", "formula_pretty", "formula_anonymous", "composition_reduced"] + keys = ["formula_pretty", "formula_anonymous", "composition_reduced"] + return [(key, False) for key in keys] + + +class ChemsysQuery(QueryOperator): + """ + Factory method to generate a dependency for querying by + chemical system with wild cards. + """ + + def query( + self, + chemsys: Optional[str] = Query( + None, description="Query by chemsys including wild cards", + ), + ) -> STORE_PARAMS: + + crit = {} + + if chemsys: + crit.update(chemsys_to_criteria(chemsys)) + + return {"criteria": crit} + + def ensure_indexes(self): # pragma: no cover + keys = ["chemsys", "elements", "nelements"] return [(key, False) for key in keys] diff --git a/src/mp_api/routes/materials/resources.py b/src/mp_api/routes/materials/resources.py index 28be934d..38a9789b 100644 --- a/src/mp_api/routes/materials/resources.py +++ b/src/mp_api/routes/materials/resources.py @@ -17,6 +17,7 @@ from mp_api.routes.materials.query_operators import ( ElementsQuery, FormulaQuery, + ChemsysQuery, DeprecationQuery, SymmetryQuery, MultiTaskIDQuery, @@ -57,6 +58,7 @@ def materials_resource(materials_store): MaterialsDoc, query_operators=[ FormulaQuery(), + ChemsysQuery(), ElementsQuery(), MultiTaskIDQuery(), SymmetryQuery(), diff --git a/src/mp_api/routes/materials/utils.py b/src/mp_api/routes/materials/utils.py index d918fb59..9f77d8dd 100644 --- a/src/mp_api/routes/materials/utils.py +++ b/src/mp_api/routes/materials/utils.py @@ -8,31 +8,14 @@ def formula_to_criteria(formula: str) -> Dict: Santizes formula into a dictionary to search with wild cards Arguments: - formula: a chemical formula with wildcards in it for unknown elements + formula: formula with wildcards in it for unknown elements Returns: Mongo style search criteria for this formula """ dummies = "ADEGJLMQRXZ" - if "-" in formula: - crit = {} # type: dict - eles = formula.split("-") - - if "*" in eles: - crit["nelements"] = len(eles) - crit["elements"] = {"$all": [ele for ele in eles if ele != "*"]} - - if crit["elements"]["$all"] == []: - del crit["elements"] - - return crit - else: - chemsys = "-".join(sorted(eles)) - crit["chemsys"] = chemsys - return crit - - elif "*" in formula: + if "*" in formula: # Wild card in formula nstars = formula.count("*") @@ -64,8 +47,36 @@ def formula_to_criteria(formula: str) -> Dict: comp = Composition(formula) # Paranoia below about floating-point "equality" crit = {} - crit["nelements"] = len(comp) + crit["nelements"] = len(comp) # type: ignore for el, n in comp.to_reduced_dict.items(): crit[f"composition_reduced.{el}"] = n return crit + + +def chemsys_to_criteria(chemsys: str) -> Dict: + """ + Santizes chemsys into a dictionary to search with wild cards + + Arguments: + formula: a chemiscal system with wildcards in it for unknown elements + + Returns: + Mongo style search criteria for this formula + """ + + crit = {} # type: dict + eles = chemsys.split("-") + + if "*" in eles: + crit["nelements"] = len(eles) + crit["elements"] = {"$all": [ele for ele in eles if ele != "*"]} + + if crit["elements"]["$all"] == []: + del crit["elements"] + + return crit + else: + chemsys = "-".join(sorted(eles)) + crit["chemsys"] = chemsys + return crit diff --git a/src/mp_api/routes/oxidation_states/resources.py b/src/mp_api/routes/oxidation_states/resources.py index dea27db8..5b3d97cb 100644 --- a/src/mp_api/routes/oxidation_states/resources.py +++ b/src/mp_api/routes/oxidation_states/resources.py @@ -7,7 +7,7 @@ SparseFieldsQuery, ) -from mp_api.routes.materials.query_operators import FormulaQuery +from mp_api.routes.materials.query_operators import FormulaQuery, ChemsysQuery from mp_api.routes.oxidation_states.query_operators import PossibleOxiStateQuery @@ -17,6 +17,7 @@ def oxi_states_resource(oxi_states_store): OxidationStateDoc, query_operators=[ FormulaQuery(), + ChemsysQuery(), PossibleOxiStateQuery(), SortQuery(), PaginationQuery(), diff --git a/src/mp_api/routes/summary/client.py b/src/mp_api/routes/summary/client.py index 1fad2d9b..7e6335a3 100644 --- a/src/mp_api/routes/summary/client.py +++ b/src/mp_api/routes/summary/client.py @@ -17,7 +17,9 @@ class SummaryRester(BaseRester[SummaryDoc]): def search_summary_docs( self, material_ids: Optional[List[MPID]] = None, - chemsys_formula: Optional[str] = None, + formula: Optional[str] = None, + chemsys: Optional[str] = None, + elements: Optional[List[str]] = None, exclude_elements: Optional[List[str]] = None, possible_species: Optional[List[str]] = None, nsites: Optional[Tuple[int, int]] = None, @@ -76,9 +78,10 @@ def search_summary_docs( Arguments: material_ids (List[MPID]): List of Materials Project IDs to return data for. - chemsys_formula (str): A chemical system (e.g., Li-Fe-O), - or formula including anonomyzed formula + formula (str): A formula including anonomyzed formula or wild cards (e.g., Fe2O3, ABO3, Si*). + chemsys (str): A chemical system including wild cards (e.g., Li-Fe-O, Si-*, *-*). + elements (List[str]): A list of elements. exclude_elements (List(str)): List of elements to exclude. possible_species (List(str)): List of element symbols appended with oxidation states. (e.g. Cr2+,O2-) @@ -201,8 +204,14 @@ def search_summary_docs( if deprecated is not None: query_params.update({"deprecated": deprecated}) - if chemsys_formula: - query_params.update({"formula": chemsys_formula}) + if formula: + query_params.update({"formula": formula}) + + if chemsys: + query_params.update({"chemsys": chemsys}) + + if elements: + query_params.update({"elements": ",".join(elements)}) if exclude_elements is not None: query_params.update({"exclude_elements": ",".join(exclude_elements)}) diff --git a/src/mp_api/routes/summary/resources.py b/src/mp_api/routes/summary/resources.py index 6c8d7491..7c4284e2 100644 --- a/src/mp_api/routes/summary/resources.py +++ b/src/mp_api/routes/summary/resources.py @@ -11,6 +11,7 @@ DeprecationQuery, ElementsQuery, FormulaQuery, + ChemsysQuery, SymmetryQuery, ) from mp_api.routes.oxidation_states.query_operators import PossibleOxiStateQuery @@ -35,6 +36,7 @@ def summary_resource(summary_store): query_operators=[ MaterialIDsSearchQuery(), FormulaQuery(), + ChemsysQuery(), ElementsQuery(), PossibleOxiStateQuery(), SymmetryQuery(), diff --git a/src/mp_api/routes/tasks/client.py b/src/mp_api/routes/tasks/client.py index 4ab4c46c..ec0030e1 100644 --- a/src/mp_api/routes/tasks/client.py +++ b/src/mp_api/routes/tasks/client.py @@ -24,7 +24,8 @@ def get_trajectory(self, task_id): def search_task_docs( self, - chemsys_formula: Optional[str] = None, + formula: Optional[str] = None, + chemsys: Optional[str] = None, num_chunks: Optional[int] = None, chunk_size: int = 1000, all_fields: bool = True, @@ -34,9 +35,9 @@ def search_task_docs( Query core task docs using a variety of search criteria. Arguments: - chemsys_formula (str): A chemical system (e.g., Li-Fe-O), - or formula including anonomyzed formula + formula (str): A formula including anonomyzed formula or wild cards (e.g., Fe2O3, ABO3, Si*). + chemsys (str): A chemical system including wild cards (e.g., Li-Fe-O, Si-*, *-*). num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. chunk_size (int): Number of data entries per chunk. Max size is 100. all_fields (bool): Whether to return all fields in the document. Defaults to True. @@ -49,8 +50,11 @@ def search_task_docs( query_params = {} # type: dict - if chemsys_formula: - query_params.update({"formula": chemsys_formula}) + if formula: + query_params.update({"formula": formula}) + + if chemsys: + query_params.update({"chemsys": chemsys}) return super().search( num_chunks=num_chunks, diff --git a/src/mp_api/routes/tasks/resources.py b/src/mp_api/routes/tasks/resources.py index 5c70a768..8db69c83 100644 --- a/src/mp_api/routes/tasks/resources.py +++ b/src/mp_api/routes/tasks/resources.py @@ -11,6 +11,7 @@ from mp_api.routes.materials.query_operators import ( ElementsQuery, FormulaQuery, + ChemsysQuery, ) @@ -20,6 +21,7 @@ def task_resource(task_store): TaskDoc, query_operators=[ FormulaQuery(), + ChemsysQuery(), ElementsQuery(), MultipleTaskIDsQuery(), SortQuery(), diff --git a/src/mp_api/routes/thermo/client.py b/src/mp_api/routes/thermo/client.py index b35a951d..4023a0c3 100644 --- a/src/mp_api/routes/thermo/client.py +++ b/src/mp_api/routes/thermo/client.py @@ -15,7 +15,8 @@ class ThermoRester(BaseRester[ThermoDoc]): def search_thermo_docs( self, material_ids: Optional[List[str]] = None, - chemsys_formula: Optional[str] = None, + formula: Optional[str] = None, + chemsys: Optional[str] = None, nelements: Optional[Tuple[int, int]] = None, is_stable: Optional[bool] = None, total_energy: Optional[Tuple[float, float]] = None, @@ -34,9 +35,9 @@ def search_thermo_docs( Arguments: material_ids (List[str]): List of Materials Project IDs to return data for. - chemsys_formula (str): A chemical system (e.g., Li-Fe-O), - or formula including anonomyzed formula + formula (str): A formula including anonomyzed formula or wild cards (e.g., Fe2O3, ABO3, Si*). + chemsys (str): A chemical system including wild cards (e.g., Li-Fe-O, Si-*, *-*). nelements (Tuple[int,int]): Minimum and maximum number of elements in the material to consider. is_stable (bool): Whether the material is stable. total_energy (Tuple[float,float]): Minimum and maximum corrected total energy in eV/atom to consider. @@ -59,8 +60,11 @@ def search_thermo_docs( query_params = defaultdict(dict) # type: dict - if chemsys_formula: - query_params.update({"formula": chemsys_formula}) + if formula: + query_params.update({"formula": formula}) + + if chemsys: + query_params.update({"chemsys": chemsys}) if material_ids: query_params.update({"material_ids": ",".join(material_ids)}) diff --git a/src/mp_api/routes/thermo/resources.py b/src/mp_api/routes/thermo/resources.py index 0062ad3f..37b64d43 100644 --- a/src/mp_api/routes/thermo/resources.py +++ b/src/mp_api/routes/thermo/resources.py @@ -9,7 +9,11 @@ SparseFieldsQuery, ) from mp_api.routes.thermo.query_operators import IsStableQuery -from mp_api.routes.materials.query_operators import MultiMaterialIDQuery, FormulaQuery +from mp_api.routes.materials.query_operators import ( + MultiMaterialIDQuery, + FormulaQuery, + ChemsysQuery, +) def phase_diagram_resource(phase_diagram_store): @@ -32,6 +36,7 @@ def thermo_resource(thermo_store): query_operators=[ MultiMaterialIDQuery(), FormulaQuery(), + ChemsysQuery(), IsStableQuery(), NumericQuery(model=ThermoDoc), SortQuery(), diff --git a/src/mp_api/routes/xas/client.py b/src/mp_api/routes/xas/client.py index 835f9d03..577de718 100644 --- a/src/mp_api/routes/xas/client.py +++ b/src/mp_api/routes/xas/client.py @@ -11,12 +11,12 @@ class XASRester(BaseRester[XASDoc]): primary_key = "spectrum_id" def search_xas_docs( - # TODO: add proper docstring self, edge: Optional[Edge] = None, absorbing_element: Optional[Element] = None, - required_elements: Optional[List[Element]] = None, formula: Optional[str] = None, + chemsys: Optional[str] = None, + elements: Optional[List[str]] = None, task_ids: Optional[List[str]] = None, sort_fields: Optional[List[str]] = None, num_chunks: Optional[int] = None, @@ -24,18 +24,46 @@ def search_xas_docs( all_fields: bool = True, fields: Optional[List[str]] = None, ): - query_params = { - "edge": str(edge.value) if edge else None, - "absorbing_element": str(absorbing_element) if absorbing_element else None, - "formula": formula, - } # type: dict + """ + Query core XAS docs using a variety of search criteria. + + Arguments: + edge (Edge): The absorption edge (e.g. K, L2, L3, L2,3). + formula (str): A formula including anonomyzed formula + or wild cards (e.g., Fe2O3, ABO3, Si*). + chemsys (str): A chemical system including wild cards (e.g., Li-Fe-O, Si-*, *-*). + elements (List[str]): A list of elements. + task_ids (List[str]): List of Materials Project IDs to return data for. + sort_fields (List[str]): Fields used to sort results. Prefix with '-' to sort in descending order. + num_chunks (int): Maximum number of chunks of data to yield. None will yield all possible. + chunk_size (int): Number of data entries per chunk. + all_fields (bool): Whether to return all fields in the document. Defaults to True. + fields (List[str]): List of fields in MaterialsCoreDoc to return data for. + Default is material_id, last_updated, and formula_pretty if all_fields is False. + + Returns: + ([MaterialsDoc]) List of material documents + """ + query_params = {} + + if edge: + query_params.update({"edge": str(edge.value)}) + + if absorbing_element: + query_params.update({"absorbing_element": str(absorbing_element.symbol)}) + + if formula: + query_params.update({"formula": formula}) + + if chemsys: + query_params.update({"chemsys": chemsys}) + + if elements: + query_params.update({"elements": ",".join(elements)}) if task_ids is not None: query_params["task_ids"] = ",".join(task_ids) - if required_elements: - query_params["elements"] = ",".join([str(el) for el in required_elements]) - if sort_fields: query_params.update( {"sort_fields": ",".join([s.strip() for s in sort_fields])} diff --git a/src/mp_api/routes/xas/resources.py b/src/mp_api/routes/xas/resources.py index 2df979d6..2d7011e4 100644 --- a/src/mp_api/routes/xas/resources.py +++ b/src/mp_api/routes/xas/resources.py @@ -2,7 +2,11 @@ from emmet.core.xas import XASDoc from maggma.api.query_operator import PaginationQuery, SortQuery, SparseFieldsQuery -from mp_api.routes.materials.query_operators import ElementsQuery, FormulaQuery +from mp_api.routes.materials.query_operators import ( + ElementsQuery, + FormulaQuery, + ChemsysQuery, +) from mp_api.routes.xas.query_operators import XASQuery, XASTaskIDQuery @@ -12,6 +16,7 @@ def xas_resource(xas_store): XASDoc, query_operators=[ FormulaQuery(), + ChemsysQuery(), ElementsQuery(), XASQuery(), XASTaskIDQuery(), diff --git a/tests/elasticity/test_query_operators.py b/tests/elasticity/test_query_operators.py index f7a2084c..ac4f5f2b 100644 --- a/tests/elasticity/test_query_operators.py +++ b/tests/elasticity/test_query_operators.py @@ -2,7 +2,7 @@ BulkModulusQuery, ShearModulusQuery, PoissonQuery, - ChemsysQuery, + ElasticityChemsysQuery, ) from monty.tempfile import ScratchDir @@ -96,7 +96,7 @@ def test_poisson_query(): def test_chemsys_query(): - op = ChemsysQuery() + op = ElasticityChemsysQuery() assert op.query(chemsys="Fe-Bi-O") == {"criteria": {"chemsys": "Bi-Fe-O"}} diff --git a/tests/electrodes/test_client.py b/tests/electrodes/test_client.py index 00013a20..2403259c 100644 --- a/tests/electrodes/test_client.py +++ b/tests/electrodes/test_client.py @@ -18,12 +18,23 @@ sub_doc_fields = [] # type: list -alt_name_dict = {} # type: dict +alt_name_dict = { + "formula": "battery_id", + "exclude_elements": "battery_id", +} # type: dict -custom_field_tests = {"working_ion": Element("Li")} # type: dict +custom_field_tests = { + "working_ion": Element("Li"), + "formula": "CoO2", + "chemsys": "Co-O", + "elements": ["Co", "O"], + "exclude_elements": ["Co"], +} # type: dict -@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") +@pytest.mark.skipif( + os.environ.get("MP_API_KEY", None) is None, reason="No API key found." +) @pytest.mark.parametrize("rester", resters) def test_client(rester): # Get specific search method @@ -76,4 +87,7 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert doc[project_field if project_field is not None else param] is not None + assert ( + doc[project_field if project_field is not None else param] + is not None + ) diff --git a/tests/electronic_structure/test_client.py b/tests/electronic_structure/test_client.py index 9f2a1e81..17ac8de0 100644 --- a/tests/electronic_structure/test_client.py +++ b/tests/electronic_structure/test_client.py @@ -32,14 +32,24 @@ def es_rester(): sub_doc_fields = [] # type: list -es_alt_name_dict = {} # type: dict +es_alt_name_dict = { + "exclude_elements": "material_id", + "formula": "material_id", +} # type: dict es_custom_field_tests = { "magnetic_ordering": Ordering.FM, + "formula": "CoO2", + "chemsys": "Co-O", + "elements": ["Co", "O"], + "exclude_elements": ["Co"], } # type: dict -@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") +@pytest.mark.xfail(reason="Until deployment of new API") +@pytest.mark.skipif( + os.environ.get("MP_API_KEY", None) is None, reason="No API key found." +) def test_es_client(es_rester): # Get specific search method search_method = None @@ -90,7 +100,10 @@ def test_es_client(es_rester): doc = search_method(**q)[0].dict() - assert doc[project_field if project_field is not None else param] is not None + assert ( + doc[project_field if project_field is not None else param] + is not None + ) bs_custom_field_tests = { @@ -113,7 +126,9 @@ def bs_rester(): rester.session.close() -@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") +@pytest.mark.skipif( + os.environ.get("MP_API_KEY", None) is None, reason="No API key found." +) def test_bs_client(bs_rester): # Get specific search method search_method = None @@ -165,7 +180,9 @@ def dos_rester(): rester.session.close() -@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") +@pytest.mark.skipif( + os.environ.get("MP_API_KEY", None) is None, reason="No API key found." +) def test_dos_client(dos_rester): # Get specific search method search_method = None @@ -190,4 +207,7 @@ def test_dos_client(dos_rester): if param != "projection_type" and param != "magnetic_ordering": doc = doc["total"]["1"] - assert doc[project_field if project_field is not None else param] is not None + assert ( + doc[project_field if project_field is not None else param] is not None + ) + diff --git a/tests/materials/test_client.py b/tests/materials/test_client.py index 8dd76df7..5ae48cca 100644 --- a/tests/materials/test_client.py +++ b/tests/materials/test_client.py @@ -19,14 +19,18 @@ sub_doc_fields = [] # type: list alt_name_dict = { - "chemsys_formula": "material_id", + "formula": "material_id", "crystal_system": "symmetry", "spacegroup_number": "symmetry", "spacegroup_symbol": "symmetry", + "exclude_elements": "material_id", } # type: dict custom_field_tests = { - "chemsys_formula": "Si", + "formula": "Si", + "chemsys": "Si-O", + "elements": ["Si", "O"], + "exclude_elements": ["Si"], "task_ids": ["mp-149"], "crystal_system": CrystalSystem.cubic, "spacegroup_number": 38, @@ -34,7 +38,10 @@ } # type: dict -@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") +@pytest.mark.xfail(reason="Until deployment of new API") +@pytest.mark.skipif( + os.environ.get("MP_API_KEY", None) is None, reason="No API key found." +) @pytest.mark.parametrize("rester", resters) def test_client(rester): # Get specific search method @@ -88,4 +95,7 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert doc[project_field if project_field is not None else param] is not None + assert ( + doc[project_field if project_field is not None else param] + is not None + ) diff --git a/tests/materials/test_query_operators.py b/tests/materials/test_query_operators.py index 346cef86..121f54c6 100644 --- a/tests/materials/test_query_operators.py +++ b/tests/materials/test_query_operators.py @@ -3,6 +3,7 @@ from mp_api.core.settings import MAPISettings from mp_api.routes.materials.query_operators import ( FormulaQuery, + ChemsysQuery, ElementsQuery, DeprecationQuery, SymmetryQuery, @@ -40,6 +41,20 @@ def test_formula_query(): } +def test_chemsys_query(): + op = ChemsysQuery() + assert op.query("Si-O") == {"criteria": {"chemsys": "O-Si"}} + + assert op.query("Si-*") == { + "criteria": {"nelements": 2, "elements": {"$all": ["Si"]}} + } + + with ScratchDir("."): + dumpfn(op, "temp.json") + new_op = loadfn("temp.json") + assert new_op.query("Si-O") == {"criteria": {"chemsys": "O-Si"}} + + def test_elements_query(): eles = ["Si", "O"] neles = ["N", "P"] diff --git a/tests/materials/test_utils.py b/tests/materials/test_utils.py index be7471d2..b2f6e578 100644 --- a/tests/materials/test_utils.py +++ b/tests/materials/test_utils.py @@ -1,4 +1,4 @@ -from mp_api.routes.materials.utils import formula_to_criteria +from mp_api.routes.materials.utils import formula_to_criteria, chemsys_to_criteria def test_formula_to_criteria(): @@ -16,7 +16,9 @@ def test_formula_to_criteria(): # Anonymous element assert formula_to_criteria("A2B3") == {"formula_anonymous": "A2B3"} + +def test_chemsys_to_criteria(): # Chemsys - assert formula_to_criteria("Si-O") == {"chemsys": "O-Si"} - assert formula_to_criteria("Si-*") == {"elements": {"$all": ["Si"]}, "nelements": 2} - assert formula_to_criteria("*-*-*") == {"nelements": 3} + assert chemsys_to_criteria("Si-O") == {"chemsys": "O-Si"} + assert chemsys_to_criteria("Si-*") == {"elements": {"$all": ["Si"]}, "nelements": 2} + assert chemsys_to_criteria("*-*-*") == {"nelements": 3} diff --git a/tests/tasks/test_client.py b/tests/tasks/test_client.py index d9a1cb3a..3bda5180 100644 --- a/tests/tasks/test_client.py +++ b/tests/tasks/test_client.py @@ -17,12 +17,15 @@ sub_doc_fields = [] # type: list -alt_name_dict = {"chemsys_formula": "task_id"} # type: dict +alt_name_dict = {"formula": "task_id"} # type: dict -custom_field_tests = {"chemsys_formula": "Si-O"} # type: dict +custom_field_tests = {"formula": "Si", "chemsys": "Si-O"} # type: dict -@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") +@pytest.mark.xfail(reason="Until deployment of new API") +@pytest.mark.skipif( + os.environ.get("MP_API_KEY", None) is None, reason="No API key found." +) @pytest.mark.parametrize("rester", resters) def test_client(rester): # Get specific search method @@ -48,7 +51,9 @@ def test_client(rester): param: (-100, 100), "chunk_size": 1, "num_chunks": 1, - "fields": [project_field if project_field is not None else param], + "fields": [ + project_field if project_field is not None else param + ], } elif param_type is typing.Tuple[float, float]: project_field = alt_name_dict.get(param, None) @@ -56,7 +61,9 @@ def test_client(rester): param: (-100.12, 100.12), "chunk_size": 1, "num_chunks": 1, - "fields": [project_field if project_field is not None else param], + "fields": [ + project_field if project_field is not None else param + ], } elif param_type is bool: project_field = alt_name_dict.get(param, None) @@ -64,7 +71,9 @@ def test_client(rester): param: False, "chunk_size": 1, "num_chunks": 1, - "fields": [project_field if project_field is not None else param], + "fields": [ + project_field if project_field is not None else param + ], } elif param in custom_field_tests: project_field = alt_name_dict.get(param, None) @@ -72,7 +81,9 @@ def test_client(rester): param: custom_field_tests[param], "chunk_size": 1, "num_chunks": 1, - "fields": [project_field if project_field is not None else param], + "fields": [ + project_field if project_field is not None else param + ], } doc = search_method(**q)[0].dict() @@ -80,4 +91,7 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert doc[project_field if project_field is not None else param] is not None + assert ( + doc[project_field if project_field is not None else param] + is not None + ) diff --git a/tests/test_mprester.py b/tests/test_mprester.py index e9e3226d..0a41e207 100644 --- a/tests/test_mprester.py +++ b/tests/test_mprester.py @@ -68,17 +68,24 @@ def test_get_materials_id_references(self, mpr): data = mpr.get_materials_id_references("mp-123") assert len(data) > 5 + @pytest.mark.xfail(reason="Until deployment of new API") def test_get_materials_ids_doc(self, mpr): mpids = mpr.get_materials_ids("Al2O3") random.shuffle(mpids) doc = mpr.materials.get_data_by_id(mpids.pop(0)) assert doc.formula_pretty == "Al2O3" + mpids = mpr.get_materials_ids(chemsys="Al-O") + random.shuffle(mpids) + doc = mpr.materials.get_data_by_id(mpids.pop(0)) + assert doc.chemsys == "Al-O" + + @pytest.mark.xfail(reason="Until deployment of new API") def test_get_structures(self, mpr): structs = mpr.get_structures("Mn3O4") assert len(structs) > 0 - structs = mpr.get_structures("Mn3O4", final=False) + structs = mpr.get_structures("Mn-O", final=False) assert len(structs) > 0 def test_find_structure(self, mpr): @@ -107,6 +114,7 @@ def test_get_entry_by_material_id(self, mpr): assert isinstance(e[0], ComputedEntry) assert e[0].composition.reduced_formula == "LiFePO4" + @pytest.mark.xfail(reason="Until deployment of new API") def test_get_entries(self, mpr): syms = ["Li", "Fe", "O"] chemsys = "Li-Fe-O" @@ -120,6 +128,13 @@ def test_get_entries(self, mpr): assert sorted_entries != entries + formula = "SiO2" + entries = mpr.get_entries(formula) + + for e in entries: + assert isinstance(e, ComputedEntry) + + @pytest.mark.xfail(reason="Until deployment of new API") def test_get_entries_in_chemsys(self, mpr): syms = ["Li", "Fe", "O"] syms2 = "Li-Fe-O" @@ -138,6 +153,7 @@ def test_get_entries_in_chemsys(self, mpr): for e in gibbs_entries: assert isinstance(e, GibbsComputedStructureEntry) + @pytest.mark.xfail(reason="Until deployment of new API") def test_get_pourbaix_entries(self, mpr): # test input chemsys as a list of elements pbx_entries = mpr.get_pourbaix_entries(["Fe", "Cr"]) @@ -173,6 +189,7 @@ def test_get_pourbaix_entries(self, mpr): # Ensure entries are pourbaix compatible PourbaixDiagram(pbx_entries) + @pytest.mark.xfail(reason="Until deployment of new API") def test_get_ion_entries(self, mpr): entries = mpr.get_entries_in_chemsys("Ti-O-H") pd = PhaseDiagram(entries) @@ -209,6 +226,7 @@ def test_get_wulff_shape(self, mpr): ws = mpr.get_wulff_shape("mp-126") assert isinstance(ws, WulffShape) + @pytest.mark.xfail(reason="Until deployment of new API") def test_query(self, mpr): excluded_params = [ @@ -222,7 +240,7 @@ def test_query(self, mpr): alt_name_dict = { "material_ids": "material_id", - "chemsys_formula": "formula_pretty", + "formula": "formula_pretty", "exclude_elements": "formula_pretty", "piezoelectric_modulus": "e_ij_max", "crystal_system": "symmetry", @@ -241,7 +259,9 @@ def test_query(self, mpr): custom_field_tests = { "material_ids": ["mp-149"], - "chemsys_formula": "SiO2", + "formula": "SiO2", + "chemsys": "Si-O", + "elements": ["Si", "O"], "exclude_elements": ["Si"], "possible_species": ["O2-"], "crystal_system": CrystalSystem.cubic, diff --git a/tests/thermo/test_client.py b/tests/thermo/test_client.py index bf6c5bde..7628a580 100644 --- a/tests/thermo/test_client.py +++ b/tests/thermo/test_client.py @@ -20,7 +20,7 @@ sub_doc_fields = [] # type: list alt_name_dict = { - "chemsys_formula": "formula_pretty", + "formula": "formula_pretty", "material_ids": "material_id", "total_energy": "energy_per_atom", "formation_energy": "formation_energy_per_atom", @@ -30,10 +30,12 @@ custom_field_tests = { "material_ids": ["mp-149"], - "chemsys_formula": "Si-O", + "formula": "SiO2", + "chemsys": "Si-O", } # type: dict +@pytest.mark.xfail(reason="Until deployment of new API") @pytest.mark.skipif( os.environ.get("MP_API_KEY", None) is None, reason="No API key found." ) diff --git a/tests/xas/test_client.py b/tests/xas/test_client.py index 2e92d2ac..fa4c1580 100644 --- a/tests/xas/test_client.py +++ b/tests/xas/test_client.py @@ -23,6 +23,7 @@ alt_name_dict = { "required_elements": "elements", "formula": "formula_pretty", + "exclude_elements": "material_id", } # type: dict custom_field_tests = { @@ -30,10 +31,15 @@ "absorbing_element": Element("Ce"), "required_elements": [Element("Ce")], "formula": "Ce(WO4)2", + "chemsys": "Ce-O-W", + "elements": ["Ce"], } # type: dict -@pytest.mark.skipif(os.environ.get("MP_API_KEY", None) is None, reason="No API key found.") +@pytest.mark.xfail(reason="Until deployment of new API") +@pytest.mark.skipif( + os.environ.get("MP_API_KEY", None) is None, reason="No API key found." +) @pytest.mark.parametrize("rester", resters) def test_client(rester): # Get specific search method @@ -59,7 +65,9 @@ def test_client(rester): param: (-100, 100), "chunk_size": 1, "num_chunks": 1, - "fields": [project_field if project_field is not None else param], + "fields": [ + project_field if project_field is not None else param + ], } elif param_type is typing.Tuple[float, float]: project_field = alt_name_dict.get(param, None) @@ -67,7 +75,9 @@ def test_client(rester): param: (-100.12, 100.12), "chunk_size": 1, "num_chunks": 1, - "fields": [project_field if project_field is not None else param], + "fields": [ + project_field if project_field is not None else param + ], } elif param_type is bool: project_field = alt_name_dict.get(param, None) @@ -75,7 +85,9 @@ def test_client(rester): param: False, "chunk_size": 1, "num_chunks": 1, - "fields": [project_field if project_field is not None else param], + "fields": [ + project_field if project_field is not None else param + ], } elif param in custom_field_tests: project_field = alt_name_dict.get(param, None) @@ -83,7 +95,9 @@ def test_client(rester): param: custom_field_tests[param], "chunk_size": 1, "num_chunks": 1, - "fields": [project_field if project_field is not None else param], + "fields": [ + project_field if project_field is not None else param + ], } doc = search_method(**q)[0].dict() @@ -91,4 +105,7 @@ def test_client(rester): if sub_field in doc: doc = doc[sub_field] - assert doc[project_field if project_field is not None else param] is not None + assert ( + doc[project_field if project_field is not None else param] + is not None + )