From 826598fba762a943523f98aba00ab3fef4afdd20 Mon Sep 17 00:00:00 2001
From: lbluque
Date: Wed, 17 Jan 2024 15:52:18 -0800
Subject: [PATCH 01/64] minor cleanup of lmdb_database

---
 ocpmodels/datasets/lmdb_database.py | 71 ++++++++++++++++-------------
 1 file changed, 39 insertions(+), 32 deletions(-)

diff --git a/ocpmodels/datasets/lmdb_database.py b/ocpmodels/datasets/lmdb_database.py
index 214315067b..fdc345774f 100644
--- a/ocpmodels/datasets/lmdb_database.py
+++ b/ocpmodels/datasets/lmdb_database.py
@@ -8,8 +8,10 @@
 https://gitlab.com/ase/ase/-/blob/master/LICENSE
 """
 
+from __future__ import annotations
 
 import os
+import typing
 import zlib
 from typing import Optional
 
@@ -18,6 +20,11 @@
 import orjson
 from ase.db.core import Database, now, ops
 from ase.db.row import AtomsRow
+from typing_extensions import Self
+
+if typing.TYPE_CHECKING:
+    from ase import Atoms
+
 
 # These are special keys in the ASE LMDB that hold
 # metadata and other info
@@ -25,9 +32,6 @@
 
 
 class LMDBDatabase(Database):
-    def __enter__(self) -> "LMDBDatabase":
-        return self
-
     def __init__(
         self,
         filename: Optional[str] = None,
@@ -79,7 +83,8 @@ def __init__(
         # Load all ids based on keys in the DB.
         self._load_ids()
 
-        return
+    def __enter__(self) -> Self:
+        return self
 
     def __exit__(self, exc_type, exc_value, tb) -> None:
         self.close()
@@ -89,7 +94,13 @@ def close(self) -> None:
         self.txn.commit()
         self.env.close()
 
-    def _write(self, atoms, key_value_pairs, data, id):
+    def _write(
+        self,
+        atoms: Atoms | AtomsRow,
+        key_value_pairs: dict,
+        data: Optional[dict],
+        idx: Optional[int] = None,
+    ) -> int:
         Database._write(self, atoms, key_value_pairs, data)
 
         mtime = now()
@@ -121,25 +132,24 @@ def _write(self, atoms, key_value_pairs, data, id):
             constraint.todict() for constraint in constraints
         ]
 
-        # json doesn't like Cell objects, so make it a cell
+        # json doesn't like Cell objects, so make it an array
         dct["cell"] = np.asarray(dct["cell"])
 
-        if id is None:
-            nextid = self._get_nextid()
-            id = nextid
-            nextid += 1
+        if idx is None:
+            idx = self._get_nextid()
+            nextid = idx + 1
         else:
-            data = self.txn.get("{id}".encode("ascii"))
+            data = self.txn.get(f"{idx}".encode("ascii"))
             assert data is not None
 
         # Add the new entry, then add the id and write the nextid
         self.txn.put(
-            f"{id}".encode("ascii"),
+            f"{idx}".encode("ascii"),
             zlib.compress(
                 orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY)
             ),
         )
-        self.ids.append(id)
+        self.ids.append(idx)
         self.txn.put(
             "nextid".encode("ascii"),
             zlib.compress(
@@ -147,12 +157,12 @@ def _write(self, atoms, key_value_pairs, data, id):
             ),
         )
 
-        return id
+        return idx
 
-    def delete(self, ids) -> None:
-        for id in ids:
-            self.txn.delete(f"{id}".encode("ascii"))
-            self.ids.remove(id)
+    def delete(self, ids: list[int]) -> None:
+        for idx in ids:
+            self.txn.delete(f"{idx}".encode("ascii"))
+            self.ids.remove(idx)
 
         self.deleted_ids += ids
         self.txn.put(
@@ -164,21 +174,21 @@ def delete(self, ids) -> None:
             ),
         )
 
-    def _get_row(self, id, include_data: bool = True):
-        if id is None:
+    def _get_row(self, idx: int, include_data: bool = True):
+        if idx is None:
             assert len(self.ids) == 1
-            id = self.ids[0]
-        data = self.txn.get(f"{id}".encode("ascii"))
+            idx = self.ids[0]
+        data = self.txn.get(f"{idx}".encode("ascii"))
 
         if data is not None:
             dct = orjson.loads(zlib.decompress(data))
         else:
-            raise KeyError(f"Id {id} missing from the database!")
+            raise KeyError(f"Id {idx} missing from the database!")
 
         if not include_data:
             dct.pop("data", None)
 
-        dct["id"] = id
+        dct["id"] = idx
         return AtomsRow(dct)
 
     def _get_row_by_index(self, index: int, include_data: bool = True):
@@
-202,12 +212,12 @@ def _get_row_by_index(self, index: int, include_data: bool = True): def _select( self, keys, - cmps, + cmps: list[tuple[str, str, str]], explain: bool = False, verbosity: int = 0, - limit=None, + limit: Optional[int] = None, offset: int = 0, - sort=None, + sort: Optional[str] = None, include_data: bool = True, columns: str = "all", ): @@ -215,16 +225,13 @@ def _select( yield {"explain": (0, 0, 0, "scan table")} return - if sort: + if sort is not None: if sort[0] == "-": reverse = True sort = sort[1:] else: reverse = False - def f(row): - return row.get(sort, missing) - rows = [] missing = [] for row in self._select(keys, cmps): @@ -248,10 +255,10 @@ def f(row): cmps = [(key, ops[op], val) for key, op, val in cmps] n = 0 - for id in self.ids: + for idx in self.ids: if n - offset == limit: return - row = self._get_row(id, include_data=False) + row = self._get_row(idx, include_data=include_data) for key in keys: if key not in row: From 324a645f5913829268bf09011a87de784a785ad7 Mon Sep 17 00:00:00 2001 From: lbluque Date: Wed, 17 Jan 2024 18:01:16 -0800 Subject: [PATCH 02/64] ase dataset compat for unified trainer and cleanup --- ocpmodels/datasets/ase_datasets.py | 75 +++++++++++++----------------- 1 file changed, 32 insertions(+), 43 deletions(-) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index 9e4d76b43e..cf52573d93 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import bisect import copy import functools @@ -7,7 +9,7 @@ import warnings from abc import ABC, abstractmethod from pathlib import Path -from typing import List +from typing import Any, Callable, Optional import ase import numpy as np @@ -18,6 +20,7 @@ from ocpmodels.common.registry import registry from ocpmodels.datasets.lmdb_database import LMDBDatabase from ocpmodels.datasets.target_metadata_guesser import guess_property_metadata +from ocpmodels.modules.transforms import DataTransforms from ocpmodels.preprocessing import AtomsToGraphs @@ -65,19 +68,24 @@ class AseAtomsDataset(Dataset, ABC): """ def __init__( - self, config, transform=None, atoms_transform=apply_one_tags + self, + config: dict, + atoms_transform: Callable[ + [ase.Atoms, Any], ase.Atoms + ] = apply_one_tags, + transform=None, ) -> None: self.config = config a2g_args = config.get("a2g_args", {}) - if a2g_args is None: - a2g_args = {} # Make sure we always include PBC info in the resulting atoms objects a2g_args["r_pbc"] = True self.a2g = AtomsToGraphs(**a2g_args) - self.transform = transform + self.key_mapping = self.config.get("key_mapping", None) + self.transforms = DataTransforms(self.config.get("transforms", {})) + self.atoms_transform = atoms_transform if self.config.get("keep_in_memory", False): @@ -91,7 +99,7 @@ def __len__(self) -> int: def __getitem__(self, idx): # Handle slicing if isinstance(idx, slice): - return [self[i] for i in range(*idx.indices(len(self.ids)))] + return [self[i] for i in range(*idx.indices(len(self)))] # Get atoms object via derived class method atoms = self.get_atoms_object(self.ids[idx]) @@ -105,10 +113,10 @@ def __getitem__(self, idx): sid = atoms.info.get("sid", self.ids[idx]) try: sid = tensor([sid]) + except (RuntimeError, ValueError, TypeError): warnings.warn( "Supplied sid is not numeric (or missing). Using dataset indices instead." 
) - except: sid = tensor([idx]) fid = atoms.info.get("fid", tensor([0])) @@ -118,11 +126,17 @@ def __getitem__(self, idx): data_object.fid = fid data_object.natoms = len(atoms) + if self.key_mapping is not None: + for _property in self.key_mapping: + # catch for test data not containing labels + if _property in data_object: + new_property = self.key_mapping[_property] + if new_property not in data_object: + data_object[new_property] = data_object[_property] + del data_object[_property] + # Transform data object - if self.transform is not None: - data_object = self.transform( - data_object, **self.config.get("transform_args", {}) - ) + data_object = self.transforms(data_object) if self.config.get("include_relaxed_energy", False): data_object.y_relaxed = self.get_relaxed_energy(self.ids[idx]) @@ -220,7 +234,7 @@ class AseReadDataset(AseAtomsDataset): """ - def load_dataset_get_ids(self, config) -> List[Path]: + def load_dataset_get_ids(self, config) -> list[Path]: self.ase_read_args = config.get("ase_read_args", {}) if ":" in self.ase_read_args.get("index", ""): @@ -374,32 +388,6 @@ def get_relaxed_energy(self, identifier): return relaxed_atoms.get_potential_energy(apply_constraint=False) -class dummy_list(list): - def __init__(self, max) -> None: - self.max = max - return - - def __len__(self): - return self.max - - def __getitem__(self, idx): - # Handle slicing - if isinstance(idx, slice): - return [self[i] for i in range(*idx.indices(self.max))] - - # Cast idx as int since it could be a tensor index - idx = int(idx) - - # Handle negative indices (referenced from end) - if idx < 0: - idx += self.max - - if 0 <= idx < self.max: - return idx - else: - raise IndexError - - @registry.register_dataset("ase_db") class AseDBDataset(AseAtomsDataset): """ @@ -444,15 +432,16 @@ class AseDBDataset(AseAtomsDataset): atoms_transform_args (dict): Additional keyword arguments for the atoms_transform callable - transform_args (dict): Additional keyword arguments for the transform callable + transforms (dict[str, dict]): Dictionary specifying data transforms as {transform_function: config} + where config is a dictionary specifying arguments to the transform_function atoms_transform (callable, optional): Additional preprocessing function applied to the Atoms object. Useful for applying tags, for example. - transform (callable, optional): Additional preprocessing function for the Data object + transform (callable, optional): deprecated? """ - def load_dataset_get_ids(self, config) -> dummy_list: + def load_dataset_get_ids(self, config) -> list[int]: if isinstance(config["src"], list): filepaths = config["src"] elif os.path.isfile(config["src"]): @@ -495,7 +484,7 @@ def load_dataset_get_ids(self, config) -> dummy_list: idlens = [len(ids) for ids in self.db_ids] self._idlen_cumulative = np.cumsum(idlens).tolist() - return dummy_list(sum(idlens)) + return list(range(sum(idlens))) def get_atoms_object(self, idx): # Figure out which db this should be indexed from. 
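To keep the intended call pattern in view, a minimal usage sketch for AseDBDataset under the cleaned-up config handling; the path and option values below are illustrative placeholders, not repository defaults:

    from ocpmodels.datasets.ase_datasets import AseDBDataset

    config = {
        "src": "/path/to/asedb.db",  # a single file, a folder of databases, or a list of paths
        "a2g_args": {"r_energy": True, "r_forces": True},  # forwarded to AtomsToGraphs
        "transforms": {},  # {transform_function: config}, handled by DataTransforms
    }
    dataset = AseDBDataset(config=config)
    data = dataset[0]  # a graph Data object with y, force, sid, fid, natoms set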
@@ -515,7 +504,7 @@ def get_atoms_object(self, idx):
 
         return atoms
 
-    def connect_db(self, address, connect_args={}):
+    def connect_db(self, address, connect_args: Optional[dict] = None):
         if connect_args is None:
             connect_args = {}
         db_type = connect_args.get("type", "extract_from_name")

From 6bb3b81afe8bbc9580c6f81570b3b891f922bca8 Mon Sep 17 00:00:00 2001
From: lbluque
Date: Thu, 18 Jan 2024 15:13:10 -0800
Subject: [PATCH 03/64] typo in docstring

---
 ocpmodels/preprocessing/atoms_to_graphs.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ocpmodels/preprocessing/atoms_to_graphs.py b/ocpmodels/preprocessing/atoms_to_graphs.py
index 34291f173a..ceaa03039a 100644
--- a/ocpmodels/preprocessing/atoms_to_graphs.py
+++ b/ocpmodels/preprocessing/atoms_to_graphs.py
@@ -129,7 +129,7 @@ def _reshape_features(self, c_index, n_index, n_distance, offsets):
         return edge_index, edge_distances, cell_offsets
 
     def convert(self, atoms: ase.Atoms, sid=None):
-        """Convert a single atomic stucture to a graph.
+        """Convert a single atomic structure to a graph.
 
         Args:
             atoms (ase.atoms.Atoms): An ASE atoms object.

From b4614c413fb25c10f2c2603607720ca34762892f Mon Sep 17 00:00:00 2001
From: lbluque
Date: Thu, 18 Jan 2024 16:09:43 -0800
Subject: [PATCH 04/64] key_mapping docstring

---
 ocpmodels/datasets/ase_datasets.py | 17 +++++++++++++++--
 1 file changed, 15 insertions(+), 2 deletions(-)

diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py
index cf52573d93..31a147e4f7 100644
--- a/ocpmodels/datasets/ase_datasets.py
+++ b/ocpmodels/datasets/ase_datasets.py
@@ -71,9 +71,9 @@ def __init__(
         self,
         config: dict,
         atoms_transform: Callable[
-            [ase.Atoms, Any], ase.Atoms
+            [ase.Atoms, Any, ...], ase.Atoms
         ] = apply_one_tags,
-        transform=None,
+        transform=None,  # is this deprecated?
    ) -> None:
         self.config = config
 
@@ -227,6 +227,10 @@ class AseReadDataset(AseAtomsDataset):
 
             transform_args (dict): Additional keyword arguments for the transform callable
 
+            key_mapping (dict[str, str]): Dictionary mapping each property's name as it is
+                stored in the dataset to the name the model expects. Only needed when the
+                two names differ.
+
             atoms_transform (callable, optional): Additional preprocessing function applied
                 to the Atoms object. Useful for applying tags, for example.
 
@@ -319,6 +323,10 @@ class AseReadMultiStructureDataset(AseAtomsDataset):
 
             transform_args (dict): Additional keyword arguments for the transform callable
 
+            key_mapping (dict[str, str]): Dictionary mapping each property's name as it is
+                stored in the dataset to the name the model expects. Only needed when the
+                two names differ.
+
             atoms_transform (callable, optional): Additional preprocessing function applied
                 to the Atoms object. Useful for applying tags, for example.
 
@@ -435,6 +443,10 @@ class AseDBDataset(AseAtomsDataset):
 
             transforms (dict[str, dict]): Dictionary specifying data transforms as {transform_function: config}
                 where config is a dictionary specifying arguments to the transform_function
 
+            key_mapping (dict[str, str]): Dictionary mapping each property's name as it is
+                stored in the dataset to the name the model expects. Only needed when the
+                two names differ.
+
             atoms_transform (callable, optional): Additional preprocessing function applied
                 to the Atoms object. Useful for applying tags, for example.
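A hedged sketch of the new key_mapping option; following the __getitem__ implementation from patch 02, the keys are property names as they appear on the converted Data object and the values are the names the model should see. The mapped names below are assumptions for illustration, not repository defaults:

    config = {
        "src": "/path/to/asedb.db",  # illustrative path
        "a2g_args": {"r_energy": True, "r_forces": True},
        # AtomsToGraphs stores the energy as "y" and the forces as "force";
        # rename them to the labels a trainer might look up.
        "key_mapping": {"y": "energy", "force": "forces"},
    }
    dataset = AseDBDataset(config=config)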
@@ -477,6 +489,7 @@ def load_dataset_get_ids(self, config) -> list[int]:
         if hasattr(db, "ids") and self.select_args == {}:
             self.db_ids.append(db.ids)
         else:
+            # this is the slow alternative
             self.db_ids.append(
                 [row.id for row in db.select(**self.select_args)]
             )

From d736b00905815f4e41c410c09d1519e17f1e9780 Mon Sep 17 00:00:00 2001
From: lbluque
Date: Thu, 18 Jan 2024 16:49:25 -0800
Subject: [PATCH 05/64] add stress to atoms_to_graphs.py and test

---
 ocpmodels/preprocessing/atoms_to_graphs.py  |  9 +++++++++
 tests/preprocessing/atoms.json              | 11 ++++++-----
 tests/preprocessing/test_atoms_to_graphs.py |  9 +++++++++
 3 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/ocpmodels/preprocessing/atoms_to_graphs.py b/ocpmodels/preprocessing/atoms_to_graphs.py
index ceaa03039a..f5aff17455 100644
--- a/ocpmodels/preprocessing/atoms_to_graphs.py
+++ b/ocpmodels/preprocessing/atoms_to_graphs.py
@@ -45,6 +45,7 @@ class AtomsToGraphs:
         radius (int or float): Cutoff radius in Angstroms to search for neighbors.
         r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned.
         r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned.
+        r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned.
         r_distances (bool): Return the distances with other properties.
             Default is False, so the distances will not be returned.
         r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned.
@@ -58,6 +59,7 @@ class AtomsToGraphs:
         radius (int or float): Cutoff radius in Angstroms to search for neighbors.
         r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned.
         r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned.
+        r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned.
         r_distances (bool): Return the distances with other properties.
             Default is False, so the distances will not be returned.
         r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned.
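The new r_stress flag can be exercised end to end with a structure that carries a single-point calculator, in the spirit of the accompanying unit test; the structure and property values below are illustrative:

    import numpy as np
    from ase import build
    from ase.calculators.singlepoint import SinglePointCalculator
    from ocpmodels.preprocessing import AtomsToGraphs

    atoms = build.bulk("Cu")
    atoms.calc = SinglePointCalculator(
        atoms, energy=1.0, forces=np.zeros((len(atoms), 3)), stress=np.zeros(6)
    )
    a2g = AtomsToGraphs(
        max_neigh=200, radius=6, r_energy=True, r_forces=True, r_stress=True
    )
    data = a2g.convert(atoms)
    assert data.stress.shape == (3, 3)  # voigt=False yields the full 3x3 tensor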
@@ -78,11 +80,13 @@ def __init__( r_edges: bool = True, r_fixed: bool = True, r_pbc: bool = False, + r_stress: bool = False, ) -> None: self.max_neigh = max_neigh self.radius = radius self.r_energy = r_energy self.r_forces = r_forces + self.r_stress = r_stress self.r_distances = r_distances self.r_fixed = r_fixed self.r_edges = r_edges @@ -181,6 +185,11 @@ def convert(self, atoms: ase.Atoms, sid=None): if self.r_forces: forces = torch.Tensor(atoms.get_forces(apply_constraint=False)) data.force = forces + if self.r_stress: + stress = torch.Tensor( + atoms.get_stress(apply_constraint=False, voigt=False) + ) + data.stress = stress if self.r_distances and self.r_edges: data.distances = edge_distances if self.r_fixed: diff --git a/tests/preprocessing/atoms.json b/tests/preprocessing/atoms.json index 97c6c47304..098525ffe9 100644 --- a/tests/preprocessing/atoms.json +++ b/tests/preprocessing/atoms.json @@ -1,20 +1,21 @@ {"1": { "calculator": "unknown", "calculator_parameters": {}, - "cell": {"array": {"__ndarray__": [[3, 3], "float64", [0.0, -8.07194878, 0.0, 6.93127032, 0.0, 0.08307657, 0.0, 0.0, 39.37850739]]}, "pbc": {"__ndarray__": [[3], "bool", [true, true, true]]}, "__ase_objtype__": "cell"}, + "cell": {"array": {"__ndarray__": [[3, 3], "float64", [0.0, -8.07194878, 0.0, 6.93127032, 0.0, 0.08307657, 0.0, 0.0, 39.37850739]]}, "__ase_objtype__": "cell"}, "constraints": [{"name": "FixAtoms", "kwargs": {"indices": [2, 3, 5, 6, 7, 9, 10, 11, 12, 14, 15, 17, 18, 19, 20, 22, 23, 24, 26, 27, 28, 30, 31, 33]}}], - "ctime": 20.460198850701047, + "ctime": 24.049479396943177, "energy": -135.66393572, "forces": {"__ndarray__": [[34, 3], "float64", [0.05011766, -0.01973735, 0.23846654, -0.12013861, -0.05240431, -0.22395961, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10578597, 0.01361956, -0.05699137, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03172177, 0.00066391, -0.01049754, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.00908246, -0.09729627, 0.00726873, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.02260358, -0.09508909, -0.01036104, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03928853, -0.04423657, 0.04053315, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.02912151, 0.05899768, -0.01100117, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.09680946, 0.06950572, 0.05602877, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03057741, 0.10594487, -0.04712197, 0.0, 0.0, 0.0]]}, "initial_charges": {"__ndarray__": [[34], "float64", [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]}, "initial_magmoms": {"__ndarray__": [[34], "float64", [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]}, "momenta": {"__ndarray__": [[34, 3], "float64", [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]}, - "mtime": 20.460198850701047, + "mtime": 24.049479396943177, "numbers": {"__ndarray__": [[34], "int64", [6, 8, 13, 13, 13, 13, 13, 13, 
13, 13, 29, 29, 29, 29, 29, 29, 29, 29, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34]]}, "pbc": {"__ndarray__": [[3], "bool", [true, true, true]]}, "positions": {"__ndarray__": [[34, 3], "float64", [-0.3289066593614256, -3.0340615866893037, 27.073342845551938, -0.0750331499077992, -2.8712314914365584, 28.205836912191387, 6.2092629718957655, -4.771209055418616, 21.953210855443853, 3.8988395550000003, -0.735234665418617, 18.643976120697392, 1.636610785518665, -1.2302542698255066, 23.72823397486728, 1.5884161381042343, -4.771209055418616, 15.334741779736007, 2.7436278118957658, -6.789196250418616, 21.91167257044385, 0.433204395, -2.7532218604186167, 18.602437835697394, 5.33707967127947, -3.0430981333485136, 25.502246117362063, 5.054051298104235, -6.789196250418616, 15.376280064736006, 3.8988395550000003, -4.771209055418616, 18.643976120697392, 1.5884161381042343, -0.735234665418617, 15.334741779736007, 6.2092629718957655, -0.735234665418617, 21.953210855443853, 1.7024669335227842, -4.898430878701221, 24.462466125364735, 2.7436278118957658, -2.7532218604186167, 21.91167257044385, 0.433204395, -6.789196250418616, 18.602437835697394, 5.0596241087542175, -7.073912126493459, 24.329534869886448, 5.054051298104235, -2.7532218604186167, 15.376280064736006, 1.5841717747237825, -4.763794809025211, 17.789819163977032, 6.205018677828017, -0.7278204190252113, 14.563661393015645, 3.8945952609322516, -0.7278204190252113, 21.09905389955426, 6.2730609484910635, -5.008717107687484, 24.37936591790035, 5.049806934723782, -6.796610416092535, 17.831357448977034, 2.739383517828017, -2.7606360260925347, 14.522123108015645, 0.4289601009322512, -2.7606360260925347, 21.05751561455426, 2.7016609108638554, -7.122213699359126, 24.33216256212159, 5.058295592171984, -2.7458076140252117, 17.84351570962914, 2.747872175276218, -6.781782004025211, 14.534281368667754, 0.43744868906774886, -6.781782004025211, 21.069673874375603, 3.0271987649116516, -2.983072135599385, 24.66107410517354, 3.903083849067749, -4.778623221092535, 21.111212159375604, 1.5926604321719833, -0.7426488310925348, 17.801977424629143, 6.319541839318875, -0.99856463967624, 24.661108015400288, 6.213507335276218, -4.778623221092535, 14.575819653667754]]}, + "stress": {"__ndarray__": [[3, 3], "float64", [-0.02096864, 0.0, 0.0, 0.0, -0.02096864, 0.0, 0.0, 0.0, -0.02096864]]}, "tags": {"__ndarray__": [[34], "int64", [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}, - "unique_id": "77df5102462860280bfa6b622c880125", - "user": "bwood"}, + "unique_id": "5b9364a49d520cbbb323adf19a4eac8f", + "user": "lbluque"}, "ids": [1], "nextid": 2} diff --git a/tests/preprocessing/test_atoms_to_graphs.py b/tests/preprocessing/test_atoms_to_graphs.py index b0a035a16e..b92cb8d5d7 100644 --- a/tests/preprocessing/test_atoms_to_graphs.py +++ b/tests/preprocessing/test_atoms_to_graphs.py @@ -27,6 +27,7 @@ def atoms_to_graphs_internals(request) -> None: radius=6, r_energy=True, r_forces=True, + r_stress=True, r_distances=True, ) request.cls.atg = test_object @@ -106,6 +107,10 @@ def test_convert(self) -> None: act_forces = self.atoms.get_forces(apply_constraint=False) forces = data.force.numpy() np.testing.assert_allclose(act_forces, forces) + # stress + act_stress = self.atoms.get_stress(apply_constraint=False, voigt=False) + stress = data.stress.numpy() + np.testing.assert_allclose(act_stress, stress) def test_convert_all(self) -> None: # run convert_all on a list with one atoms object @@ -129,3 +134,7 @@ def 
test_convert_all(self) -> None: act_forces = self.atoms.get_forces(apply_constraint=False) forces = data_list[0].force.numpy() np.testing.assert_allclose(act_forces, forces) + # stress + act_stress = self.atoms.get_stress(apply_constraint=False, voigt=False) + stress = data_list[0].stress.numpy() + np.testing.assert_allclose(act_stress, stress) From 0a17008c79e3cd5514577d8ba08beba369a4b352 Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 19 Jan 2024 11:47:44 -0800 Subject: [PATCH 06/64] allow adding target properties in atoms.info --- ocpmodels/datasets/ase_datasets.py | 6 +++--- ocpmodels/preprocessing/atoms_to_graphs.py | 16 +++++++++++++-- tests/preprocessing/atoms.json | 6 +++--- tests/preprocessing/test_atoms_to_graphs.py | 22 +++++++++++++++++++++ 4 files changed, 42 insertions(+), 8 deletions(-) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index 31a147e4f7..97a45b5c27 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -210,7 +210,7 @@ class AseReadDataset(AseAtomsDataset): default options will work for most users If you are using this for a training dataset, set - "r_energy":True and/or "r_forces":True as appropriate + "r_energy":True, "r_forces":True, and/or "r_stress":True as appropriate In that case, energy/forces must be in the files you read (ex. OUTCAR) ase_read_args (dict): Keyword arguments for ase.io.read() @@ -304,7 +304,7 @@ class AseReadMultiStructureDataset(AseAtomsDataset): default options will work for most users If you are using this for a training dataset, set - "r_energy":True and/or "r_forces":True as appropriate + "r_energy":True, "r_forces":True, and/or "r_stress":True as appropriate In that case, energy/forces must be in the files you read (ex. OUTCAR) ase_read_args (dict): Keyword arguments for ase.io.read() @@ -431,7 +431,7 @@ class AseDBDataset(AseAtomsDataset): default options will work for most users If you are using this for a training dataset, set - "r_energy":True and/or "r_forces":True as appropriate + "r_energy":True, "r_forces":True, and/or "r_stress":True as appropriate In that case, energy/forces must be in the database keep_in_memory (bool): Store data in memory. This helps avoid random reads if you need diff --git a/ocpmodels/preprocessing/atoms_to_graphs.py b/ocpmodels/preprocessing/atoms_to_graphs.py index f5aff17455..437ac6949f 100644 --- a/ocpmodels/preprocessing/atoms_to_graphs.py +++ b/ocpmodels/preprocessing/atoms_to_graphs.py @@ -5,7 +5,9 @@ LICENSE file in the root directory of this source tree. """ -from typing import Optional +from __future__ import annotations + +from typing import Optional, Sequence import ase.db.sqlite import ase.io.trajectory @@ -53,6 +55,8 @@ class AtomsToGraphs: Default is True, so the fixed indices will be returned. r_pbc (bool): Return the periodic boundary conditions with other properties. Default is False, so the periodic boundary conditions will not be returned. + r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.data with other + properties. Default is None, so no data will be returned as properties. Attributes: max_neigh (int): Maximum number of neighbors to consider. @@ -67,7 +71,8 @@ class AtomsToGraphs: Default is True, so the fixed indices will be returned. r_pbc (bool): Return the periodic boundary conditions with other properties. Default is False, so the periodic boundary conditions will not be returned. 
- + r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.data with other + properties. Default is None, so no data will be returned as properties. """ def __init__( @@ -81,6 +86,7 @@ def __init__( r_fixed: bool = True, r_pbc: bool = False, r_stress: bool = False, + r_data_keys: Optional[Sequence[str]] = None, ) -> None: self.max_neigh = max_neigh self.radius = radius @@ -91,6 +97,7 @@ def __init__( self.r_fixed = r_fixed self.r_edges = r_edges self.r_pbc = r_pbc + self.r_data_keys = r_data_keys def _get_neighbors_pymatgen(self, atoms: ase.Atoms): """Preforms nearest neighbor search and returns edge index, distances, @@ -203,6 +210,11 @@ def convert(self, atoms: ase.Atoms, sid=None): data.fixed = fixed_idx if self.r_pbc: data.pbc = torch.tensor(atoms.pbc) + if self.r_data_keys is not None: + for ( + data_key + ) in self.r_data_keys: # if key is not present let it raise error + data[data_key] = torch.Tensor(atoms.info[data_key]) return data diff --git a/tests/preprocessing/atoms.json b/tests/preprocessing/atoms.json index 098525ffe9..86d47cf6b6 100644 --- a/tests/preprocessing/atoms.json +++ b/tests/preprocessing/atoms.json @@ -3,19 +3,19 @@ "calculator_parameters": {}, "cell": {"array": {"__ndarray__": [[3, 3], "float64", [0.0, -8.07194878, 0.0, 6.93127032, 0.0, 0.08307657, 0.0, 0.0, 39.37850739]]}, "__ase_objtype__": "cell"}, "constraints": [{"name": "FixAtoms", "kwargs": {"indices": [2, 3, 5, 6, 7, 9, 10, 11, 12, 14, 15, 17, 18, 19, 20, 22, 23, 24, 26, 27, 28, 30, 31, 33]}}], - "ctime": 24.049479396943177, + "ctime": 24.049558230324397, "energy": -135.66393572, "forces": {"__ndarray__": [[34, 3], "float64", [0.05011766, -0.01973735, 0.23846654, -0.12013861, -0.05240431, -0.22395961, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.10578597, 0.01361956, -0.05699137, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03172177, 0.00066391, -0.01049754, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.00908246, -0.09729627, 0.00726873, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.02260358, -0.09508909, -0.01036104, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03928853, -0.04423657, 0.04053315, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.02912151, 0.05899768, -0.01100117, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, -0.09680946, 0.06950572, 0.05602877, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.03057741, 0.10594487, -0.04712197, 0.0, 0.0, 0.0]]}, "initial_charges": {"__ndarray__": [[34], "float64", [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]}, "initial_magmoms": {"__ndarray__": [[34], "float64", [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]}, "momenta": {"__ndarray__": [[34, 3], "float64", [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]}, - "mtime": 24.049479396943177, + "mtime": 24.049558230324397, "numbers": {"__ndarray__": [[34], "int64", [6, 8, 13, 13, 13, 
13, 13, 13, 13, 13, 29, 29, 29, 29, 29, 29, 29, 29, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34, 34]]}, "pbc": {"__ndarray__": [[3], "bool", [true, true, true]]}, "positions": {"__ndarray__": [[34, 3], "float64", [-0.3289066593614256, -3.0340615866893037, 27.073342845551938, -0.0750331499077992, -2.8712314914365584, 28.205836912191387, 6.2092629718957655, -4.771209055418616, 21.953210855443853, 3.8988395550000003, -0.735234665418617, 18.643976120697392, 1.636610785518665, -1.2302542698255066, 23.72823397486728, 1.5884161381042343, -4.771209055418616, 15.334741779736007, 2.7436278118957658, -6.789196250418616, 21.91167257044385, 0.433204395, -2.7532218604186167, 18.602437835697394, 5.33707967127947, -3.0430981333485136, 25.502246117362063, 5.054051298104235, -6.789196250418616, 15.376280064736006, 3.8988395550000003, -4.771209055418616, 18.643976120697392, 1.5884161381042343, -0.735234665418617, 15.334741779736007, 6.2092629718957655, -0.735234665418617, 21.953210855443853, 1.7024669335227842, -4.898430878701221, 24.462466125364735, 2.7436278118957658, -2.7532218604186167, 21.91167257044385, 0.433204395, -6.789196250418616, 18.602437835697394, 5.0596241087542175, -7.073912126493459, 24.329534869886448, 5.054051298104235, -2.7532218604186167, 15.376280064736006, 1.5841717747237825, -4.763794809025211, 17.789819163977032, 6.205018677828017, -0.7278204190252113, 14.563661393015645, 3.8945952609322516, -0.7278204190252113, 21.09905389955426, 6.2730609484910635, -5.008717107687484, 24.37936591790035, 5.049806934723782, -6.796610416092535, 17.831357448977034, 2.739383517828017, -2.7606360260925347, 14.522123108015645, 0.4289601009322512, -2.7606360260925347, 21.05751561455426, 2.7016609108638554, -7.122213699359126, 24.33216256212159, 5.058295592171984, -2.7458076140252117, 17.84351570962914, 2.747872175276218, -6.781782004025211, 14.534281368667754, 0.43744868906774886, -6.781782004025211, 21.069673874375603, 3.0271987649116516, -2.983072135599385, 24.66107410517354, 3.903083849067749, -4.778623221092535, 21.111212159375604, 1.5926604321719833, -0.7426488310925348, 17.801977424629143, 6.319541839318875, -0.99856463967624, 24.661108015400288, 6.213507335276218, -4.778623221092535, 14.575819653667754]]}, "stress": {"__ndarray__": [[3, 3], "float64", [-0.02096864, 0.0, 0.0, 0.0, -0.02096864, 0.0, 0.0, 0.0, -0.02096864]]}, "tags": {"__ndarray__": [[34], "int64", [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]}, - "unique_id": "5b9364a49d520cbbb323adf19a4eac8f", + "unique_id": "e0e1e88b155869bc277be4b4733ed932", "user": "lbluque"}, "ids": [1], "nextid": 2} diff --git a/tests/preprocessing/test_atoms_to_graphs.py b/tests/preprocessing/test_atoms_to_graphs.py index b92cb8d5d7..395e6a6769 100644 --- a/tests/preprocessing/test_atoms_to_graphs.py +++ b/tests/preprocessing/test_atoms_to_graphs.py @@ -22,6 +22,17 @@ def atoms_to_graphs_internals(request) -> None: index=0, format="json", ) + atoms.info["stiffness_tensor"] = np.array( + [ + [293, 121, 121, 0, 0, 0], + [121, 293, 121, 0, 0, 0], + [121, 121, 293, 0, 0, 0], + [0, 0, 0, 146, 0, 0], + [0, 0, 0, 0, 146, 0], + [0, 0, 0, 0, 0, 146], + ], + dtype=float, + ) test_object = AtomsToGraphs( max_neigh=200, radius=6, @@ -29,6 +40,7 @@ def atoms_to_graphs_internals(request) -> None: r_forces=True, r_stress=True, r_distances=True, + r_data_keys=["stiffness_tensor"], ) request.cls.atg = test_object request.cls.atoms = atoms @@ -111,6 +123,11 @@ def test_convert(self) -> None: act_stress = 
self.atoms.get_stress(apply_constraint=False, voigt=False) stress = data.stress.numpy() np.testing.assert_allclose(act_stress, stress) + # additional data (ie stiffness_tensor) + stiffness_tensor = data.stiffness_tensor.numpy() + np.testing.assert_allclose( + self.atoms.info["stiffness_tensor"], stiffness_tensor + ) def test_convert_all(self) -> None: # run convert_all on a list with one atoms object @@ -138,3 +155,8 @@ def test_convert_all(self) -> None: act_stress = self.atoms.get_stress(apply_constraint=False, voigt=False) stress = data_list[0].stress.numpy() np.testing.assert_allclose(act_stress, stress) + # additional data (ie stiffness_tensor) + stiffness_tensor = data_list[0].stiffness_tensor.numpy() + np.testing.assert_allclose( + self.atoms.info["stiffness_tensor"], stiffness_tensor + ) From 3a7f8107f2782e9e53363998975d54acec512569 Mon Sep 17 00:00:00 2001 From: lbluque Date: Tue, 23 Jan 2024 10:04:28 -0800 Subject: [PATCH 07/64] test using generic tensor property in ase_datasets --- tests/datasets/test_ase_datasets.py | 69 ++++++++++++----------------- 1 file changed, 28 insertions(+), 41 deletions(-) diff --git a/tests/datasets/test_ase_datasets.py b/tests/datasets/test_ase_datasets.py index d1767c9782..a272b6ef58 100644 --- a/tests/datasets/test_ase_datasets.py +++ b/tests/datasets/test_ase_datasets.py @@ -18,9 +18,15 @@ build.fcc111("Pt", size=[2, 2, 3], vacuum=8, periodic=True), ] for atoms in structures: - calc = SinglePointCalculator(atoms, energy=1, forces=atoms.positions) + calc = SinglePointCalculator( + atoms, + energy=1, + forces=atoms.positions, + stress=np.random.random((3, 3)), + ) atoms.calc = calc - atoms.info["test_extensive_property"] = 3 * len(atoms) + atoms.info["extensive_property"] = 3 * len(atoms) + atoms.info["tensor_property"] = np.random.random((6, 6)) structures[2].set_pbc(True) @@ -55,39 +61,18 @@ def test_ase_read_dataset() -> None: ) -def test_ase_db_dataset() -> None: - try: - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.db" - ) - ) - except FileNotFoundError: - pass - - with db.connect( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.db") - ) as database: +def test_ase_db_dataset(tmp_path) -> None: + with db.connect(tmp_path / "asedb.db") as database: for i, structure in enumerate(structures): database.write(structure) - dataset = AseDBDataset( - config={ - "src": os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.db" - ), - } - ) + dataset = AseDBDataset(config={"src": str(tmp_path / "asedb.db")}) assert len(dataset) == len(structures) data = dataset[0] del data - os.remove( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.db") - ) - def test_ase_db_dataset_folder() -> None: try: @@ -259,17 +244,23 @@ def test_lmdb_metadata_guesser() -> None: assert metadata["targets"]["forces"]["extensive"] is True assert metadata["targets"]["forces"]["type"] == "per-atom" - # Confirm forces metadata guessed properly - assert ( - metadata["targets"]["info.test_extensive_property"]["extensive"] - is True - ) - assert metadata["targets"]["info.test_extensive_property"]["shape"] == () + # Confirm stress metadata guessed properly + assert metadata["targets"]["stress"]["shape"] == (3, 3) + assert metadata["targets"]["stress"]["extensive"] is False + assert metadata["targets"]["stress"]["type"] == "per-image" + + # Confirm extensive_property metadata guessed properly + assert metadata["targets"]["info.extensive_property"]["extensive"] is True + assert 
metadata["targets"]["info.extensive_property"]["shape"] == () assert ( - metadata["targets"]["info.test_extensive_property"]["type"] - == "per-image" + metadata["targets"]["info.extensive_property"]["type"] == "per-image" ) + # Confirm tensor_property metadata guessed properly + assert metadata["targets"]["info.tensor_property"]["extensive"] is False + assert metadata["targets"]["info.tensor_property"]["shape"] == (6, 6) + assert metadata["targets"]["info.tensor_property"]["type"] == "per-image" + os.remove( os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb") ) @@ -312,14 +303,10 @@ def test_ase_metadata_guesser() -> None: assert metadata["targets"]["forces"]["type"] == "per-atom" # Confirm forces metadata guessed properly + assert metadata["targets"]["info.extensive_property"]["extensive"] is True + assert metadata["targets"]["info.extensive_property"]["shape"] == () assert ( - metadata["targets"]["info.test_extensive_property"]["extensive"] - is True - ) - assert metadata["targets"]["info.test_extensive_property"]["shape"] == () - assert ( - metadata["targets"]["info.test_extensive_property"]["type"] - == "per-image" + metadata["targets"]["info.extensive_property"]["type"] == "per-image" ) dataset = AseDBDataset( From f47a0b8145a7fabc6fb3778d948ec9b88eb1db47 Mon Sep 17 00:00:00 2001 From: lbluque Date: Tue, 23 Jan 2024 10:04:53 -0800 Subject: [PATCH 08/64] minor docstring/comments --- ocpmodels/datasets/ase_datasets.py | 1 + ocpmodels/preprocessing/atoms_to_graphs.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index 97a45b5c27..d738fe810e 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -512,6 +512,7 @@ def get_atoms_object(self, idx): atoms_row = self.dbs[db_idx]._get_row(self.db_ids[db_idx][el_idx]) atoms = atoms_row.toatoms() + # put data back into atoms info if isinstance(atoms_row.data, dict): atoms.info.update(atoms_row.data) diff --git a/ocpmodels/preprocessing/atoms_to_graphs.py b/ocpmodels/preprocessing/atoms_to_graphs.py index 437ac6949f..48897db50b 100644 --- a/ocpmodels/preprocessing/atoms_to_graphs.py +++ b/ocpmodels/preprocessing/atoms_to_graphs.py @@ -55,7 +55,7 @@ class AtomsToGraphs: Default is True, so the fixed indices will be returned. r_pbc (bool): Return the periodic boundary conditions with other properties. Default is False, so the periodic boundary conditions will not be returned. - r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.data with other + r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other properties. Default is None, so no data will be returned as properties. Attributes: @@ -71,7 +71,7 @@ class AtomsToGraphs: Default is True, so the fixed indices will be returned. r_pbc (bool): Return the periodic boundary conditions with other properties. Default is False, so the periodic boundary conditions will not be returned. - r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.data with other + r_data_keys (sequence of str, optional): Return values corresponding to given keys in atoms.info data with other properties. Default is None, so no data will be returned as properties. 
""" From c2a789e249b4108224539b7f542b6f993b3897f9 Mon Sep 17 00:00:00 2001 From: lbluque Date: Tue, 23 Jan 2024 14:57:46 -0800 Subject: [PATCH 09/64] handle stress in voigt notation in metadata guesser --- ocpmodels/datasets/target_metadata_guesser.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/ocpmodels/datasets/target_metadata_guesser.py b/ocpmodels/datasets/target_metadata_guesser.py index 844bd8127f..0ee58e80a9 100644 --- a/ocpmodels/datasets/target_metadata_guesser.py +++ b/ocpmodels/datasets/target_metadata_guesser.py @@ -1,6 +1,7 @@ import logging import numpy as np +from ase.stress import voigt_6_to_full_3x3_stress def uniform_atoms_lengths(atoms_lens) -> bool: @@ -184,6 +185,16 @@ def guess_property_metadata(atoms_list): np.array(atoms.calc.results[key]) for atoms in atoms_list ] + # stress needs to be handled separately in case it was saved in voigt (6, ) notation + # atoms2graphs will always request voigt=False so turn it into full 3x3 + if key == "stress": + target_samples = [ + voigt_6_to_full_3x3_stress(sample) + if sample.shape != (3, 3) + else sample + for sample in target_samples + ] + # Guess the metadata targets[f"{key}"] = guess_target_metadata( atoms_len, target_samples From 47f4578042e12e142108f4e4b5cbce1be0040327 Mon Sep 17 00:00:00 2001 From: lbluque Date: Tue, 23 Jan 2024 17:47:34 -0800 Subject: [PATCH 10/64] handle scalar generic values in a2g --- ocpmodels/preprocessing/atoms_to_graphs.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ocpmodels/preprocessing/atoms_to_graphs.py b/ocpmodels/preprocessing/atoms_to_graphs.py index 48897db50b..29e7246213 100644 --- a/ocpmodels/preprocessing/atoms_to_graphs.py +++ b/ocpmodels/preprocessing/atoms_to_graphs.py @@ -214,7 +214,11 @@ def convert(self, atoms: ase.Atoms, sid=None): for ( data_key ) in self.r_data_keys: # if key is not present let it raise error - data[data_key] = torch.Tensor(atoms.info[data_key]) + data[data_key] = ( + atoms.info[data_key] + if isinstance(atoms.info[data_key], (int, float)) + else torch.Tensor(atoms.info[data_key]) + ) return data From 48dc7d030d45f637a00ea4c079f85e0170bd03a4 Mon Sep 17 00:00:00 2001 From: lbluque Date: Wed, 24 Jan 2024 11:23:16 -0800 Subject: [PATCH 11/64] clean up ase dataset unit tests --- tests/datasets/test_ase_datasets.py | 404 +++++++--------------------- 1 file changed, 91 insertions(+), 313 deletions(-) diff --git a/tests/datasets/test_ase_datasets.py b/tests/datasets/test_ase_datasets.py index a272b6ef58..41a5b0c20c 100644 --- a/tests/datasets/test_ase_datasets.py +++ b/tests/datasets/test_ase_datasets.py @@ -1,6 +1,5 @@ -import os - import numpy as np +import pytest from ase import build, db from ase.calculators.singlepoint import SinglePointCalculator from ase.io import Trajectory, write @@ -22,7 +21,9 @@ atoms, energy=1, forces=atoms.positions, - stress=np.random.random((3, 3)), + # there is an issue with ASE db when writing a db with 3x3 stress it is flattened to (9,) and then + # errors when trying to read it + stress=np.random.random((6,)), ) atoms.calc = calc atoms.info["extensive_property"] = 3 * len(atoms) @@ -31,206 +32,92 @@ structures[2].set_pbc(True) -def test_ase_read_dataset() -> None: - for i, structure in enumerate(structures): - write( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), f"{i}.cif" - ), - structure, - ) - - dataset = AseReadDataset( - config={ - "src": os.path.join(os.path.dirname(os.path.abspath(__file__))), - "pattern": "*.cif", - } - ) - - assert len(dataset) == 
len(structures) - data = dataset[0] - del data - - dataset.close_db() - - for i in range(len(structures)): - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), f"{i}.cif" - ) - ) - - -def test_ase_db_dataset(tmp_path) -> None: - with db.connect(tmp_path / "asedb.db") as database: - for i, structure in enumerate(structures): - database.write(structure) - - dataset = AseDBDataset(config={"src": str(tmp_path / "asedb.db")}) - - assert len(dataset) == len(structures) - data = dataset[0] - - del data - - -def test_ase_db_dataset_folder() -> None: - try: - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb1.db" - ) - ) - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb2.db" - ) +@pytest.fixture( + scope="function", + params=[ + "db_dataset", + "db_dataset_folder", + "db_dataset_list", + "lmdb_dataset", + ], +) +def ase_dataset(request, tmp_path_factory): + tmp_path = tmp_path_factory.mktemp("dataset") + mult = 1 + a2g_args = { + "r_energy": True, + "r_forces": True, + "r_stress": True, + "r_data_keys": ["extensive_property", "tensor_property"], + } + if request.param == "db_dataset": + with db.connect(tmp_path / "asedb.db") as database: + for i, atoms in enumerate(structures): + database.write(atoms, data=atoms.info) + dataset = AseDBDataset( + config={"src": str(tmp_path / "asedb.db"), "a2g_args": a2g_args} ) - except FileNotFoundError: - pass - - with db.connect( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb1.db") - ) as database: - for i, structure in enumerate(structures): - database.write(structure) - - with db.connect( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb2.db") - ) as database: - for i, structure in enumerate(structures): - database.write(structure) - - dataset = AseDBDataset( - config={ - "src": os.path.join( - os.path.dirname(os.path.abspath(__file__)), "./" - ), - } - ) - - assert len(dataset) == len(structures) * 2 - data = dataset[0] - del data - - os.remove( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb1.db") - ) - os.remove( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb2.db") - ) - - -def test_ase_db_dataset_list() -> None: - try: - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb1.db" - ) + elif ( + request.param == "db_dataset_folder" + or request.param == "db_dataset_list" + ): + with db.connect(tmp_path / "asedb1.db") as database: + for i, atoms in enumerate(structures): + database.write(atoms, data=atoms.info) + with db.connect(tmp_path / "asedb2.db") as database: + for i, atoms in enumerate(structures): + database.write(atoms, data=atoms.info) + mult = 2 + src = ( + str(tmp_path) + if request.param == "db_dataset_folder" + else [str(tmp_path / "asedb1.db"), str(tmp_path / "asedb2.db")] ) - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb2.db" - ) + dataset = AseDBDataset(config={"src": src, "a2g_args": a2g_args}) + else: # "lmbd_dataset" + with LMDBDatabase(str(tmp_path / "asedb.lmdb")) as database: + for i, atoms in enumerate(structures): + database.write(atoms, data=atoms.info) + + dataset = AseDBDataset( + config={"src": str(tmp_path / "asedb.lmdb"), "a2g_args": a2g_args} ) - except FileNotFoundError: - pass - - with db.connect( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb1.db") - ) as database: - for i, structure in enumerate(structures): - database.write(structure) - - with db.connect( - 
os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb2.db") - ) as database: - for i, structure in enumerate(structures): - database.write(structure) - - dataset = AseDBDataset( - config={ - "src": [ - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb1.db" - ), - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb2.db" - ), - ] - } - ) - assert len(dataset) == len(structures) * 2 - data = dataset[0] - del data + return dataset, mult - os.remove( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb1.db") - ) - os.remove( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb2.db") - ) +def test_ase_dataset(ase_dataset): + dataset, mult = ase_dataset + assert len(dataset) == mult * len(structures) + for data in dataset: + assert hasattr(data, "y") + assert data.force.shape == (data.natoms, 3) + assert data.stress.shape == (3, 3) + assert data.tensor_property.shape == (6, 6) + assert isinstance(data.extensive_property, int) -def test_ase_lmdb_dataset() -> None: - try: - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb" - ) - ) - except FileNotFoundError: - pass - with LMDBDatabase( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb") - ) as database: - for i, structure in enumerate(structures): - database.write(structure) +def test_ase_read_dataset(tmp_path) -> None: + # unfortunately there is currently no clean (already implemented) way to save atoms.info when saving + # individual structures - so test separately + for i, structure in enumerate(structures): + write(tmp_path / f"{i}.cif", structure) - dataset = AseDBDataset( + dataset = AseReadDataset( config={ - "src": os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb" - ), + "src": str(tmp_path), + "pattern": "*.cif", } ) assert len(dataset) == len(structures) data = dataset[0] del data + dataset.close_db() - os.remove( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb") - ) - - -def test_lmdb_metadata_guesser() -> None: - # Cleanup old lmdb in case it's left over from previous tests - try: - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb" - ) - ) - except FileNotFoundError: - pass - - # Write an LMDB - with LMDBDatabase( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb") - ) as database: - for i, structure in enumerate(structures): - database.write(structure, data=structure.info) - dataset = AseDBDataset( - config={ - "src": os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb" - ), - } - ) +def test_ase_metadata_guesser(ase_dataset) -> None: + dataset, _ = ase_dataset metadata = dataset.get_metadata() @@ -261,61 +148,15 @@ def test_lmdb_metadata_guesser() -> None: assert metadata["targets"]["info.tensor_property"]["shape"] == (6, 6) assert metadata["targets"]["info.tensor_property"]["type"] == "per-image" - os.remove( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb") - ) - - -def test_ase_metadata_guesser() -> None: - try: - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.db" - ) - ) - except FileNotFoundError: - pass - - with db.connect( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.db") - ) as database: - for i, structure in enumerate(structures): - database.write(structure, data=structure.info) - - dataset = AseDBDataset( - config={ - "src": os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.db" - ), - } - 
) - - metadata = dataset.get_metadata() - - # Confirm energy metadata guessed properly - assert metadata["targets"]["energy"]["extensive"] is False - assert metadata["targets"]["energy"]["shape"] == () - assert metadata["targets"]["energy"]["type"] == "per-image" - - # Confirm forces metadata guessed properly - assert metadata["targets"]["forces"]["shape"] == (3,) - assert metadata["targets"]["forces"]["extensive"] is True - assert metadata["targets"]["forces"]["type"] == "per-atom" - # Confirm forces metadata guessed properly - assert metadata["targets"]["info.extensive_property"]["extensive"] is True - assert metadata["targets"]["info.extensive_property"]["shape"] == () - assert ( - metadata["targets"]["info.extensive_property"]["type"] == "per-image" - ) +def test_db_add_delete(tmp_path) -> None: + database = db.connect(tmp_path / "asedb.db") + for i, atoms in enumerate(structures): + database.write(atoms, data=atoms.info) - dataset = AseDBDataset( - config={ - "src": os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.db" - ), - } - ) + dataset = AseDBDataset(config={"src": str(tmp_path / "asedb.db")}) + assert len(dataset) == len(structures) + orig_len = len(dataset) database.delete([1]) @@ -324,55 +165,20 @@ def test_ase_metadata_guesser() -> None: build.bulk("Al"), ] - for i, structure in enumerate(new_structures): - database.write(structure) - - dataset = AseDBDataset( - config={ - "src": os.path.join( - os.path.dirname(os.path.abspath(__file__)), "asedb.db" - ), - } - ) - - assert len(dataset) == len(structures) + len(new_structures) - 1 - data = dataset[:] - assert data - - os.remove( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.db") - ) + for i, atoms in enumerate(new_structures): + database.write(atoms, data=atoms.info) + dataset = AseDBDataset(config={"src": str(tmp_path / "asedb.db")}) + assert len(dataset) == orig_len + len(new_structures) - 1 dataset.close_db() -def test_ase_multiread_dataset() -> None: - try: - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "test.traj" - ) - ) - except FileNotFoundError: - pass - - try: - os.remove( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "test_index_file" - ) - ) - except FileNotFoundError: - pass - +def test_ase_multiread_dataset(tmp_path) -> None: atoms_objects = [build.bulk("Cu", a=a) for a in np.linspace(3.5, 3.7, 10)] energies = np.linspace(1, 0, len(atoms_objects)) - traj = Trajectory( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "test.traj"), - mode="w", - ) + traj = Trajectory(tmp_path / "test.traj", mode="w") for atoms, energy in zip(atoms_objects, energies): calc = SinglePointCalculator( @@ -383,7 +189,7 @@ def test_ase_multiread_dataset() -> None: dataset = AseReadMultiStructureDataset( config={ - "src": os.path.join(os.path.dirname(os.path.abspath(__file__))), + "src": str(tmp_path), "pattern": "*.traj", "keep_in_memory": True, "atoms_transform_args": { @@ -393,35 +199,19 @@ def test_ase_multiread_dataset() -> None: ) assert len(dataset) == len(atoms_objects) - [dataset[:]] - f = open( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "test_index_file" - ), - "w", - ) - f.write( - f"{os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test.traj')} {len(atoms_objects)}" - ) - f.close() + with open(tmp_path / "test_index_file", "w") as f: + f.write(f"{tmp_path / 'test.traj'} {len(atoms_objects)}") dataset = AseReadMultiStructureDataset( - config={ - "index_file": os.path.join( - 
os.path.dirname(os.path.abspath(__file__)), "test_index_file"
-            )
-        },
+        config={"index_file": str(tmp_path / "test_index_file")},
     )
 
     assert len(dataset) == len(atoms_objects)
-    [dataset[:]]
 
     dataset = AseReadMultiStructureDataset(
         config={
-            "index_file": os.path.join(
-                os.path.dirname(os.path.abspath(__file__)), "test_index_file"
-            ),
+            "index_file": str(tmp_path / "test_index_file"),
             "a2g_args": {
                 "r_energy": True,
                 "r_forces": True,
@@ -431,7 +221,6 @@
     )
 
     assert len(dataset) == len(atoms_objects)
-    [dataset[:]]
 
     assert hasattr(dataset[0], "y_relaxed")
     assert dataset[0].y_relaxed != dataset[0].y
 
     dataset = AseReadDataset(
         config={
-            "src": os.path.join(os.path.dirname(os.path.abspath(__file__))),
+            "src": str(tmp_path),
             "pattern": "*.traj",
             "ase_read_args": {
                 "index": "0",
@@ -452,16 +241,5 @@
         }
     )
 
-    [dataset[:]]
-
     assert hasattr(dataset[0], "y_relaxed")
     assert dataset[0].y_relaxed != dataset[0].y
-
-    os.remove(
-        os.path.join(os.path.dirname(os.path.abspath(__file__)), "test.traj")
-    )
-    os.remove(
-        os.path.join(
-            os.path.dirname(os.path.abspath(__file__)), "test_index_file"
-        )
-    )

From 8549411bafe925cb843172a9db7cce9e09c7cd85 Mon Sep 17 00:00:00 2001
From: lbluque
Date: Thu, 25 Jan 2024 13:53:11 -0800
Subject: [PATCH 12/64] allow .aselmdb extensions

---
 ocpmodels/datasets/ase_datasets.py  |  5 +++--
 tests/datasets/test_ase_datasets.py | 11 ++++++++++-
 2 files changed, 13 insertions(+), 3 deletions(-)

diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py
index d738fe810e..948bb447fc 100644
--- a/ocpmodels/datasets/ase_datasets.py
+++ b/ocpmodels/datasets/ase_datasets.py
@@ -522,8 +522,9 @@ def connect_db(self, address, connect_args: Optional[dict] = None):
         if connect_args is None:
             connect_args = {}
         db_type = connect_args.get("type", "extract_from_name")
-        if db_type == "lmdb" or (
-            db_type == "extract_from_name" and address.split(".")[-1] == "lmdb"
+        if db_type in ("lmdb", "aselmdb") or (
+            db_type == "extract_from_name"
+            and address.split(".")[-1] in ("lmdb", "aselmdb")
         ):
             return LMDBDatabase(address, readonly=True, **connect_args)
         else:
diff --git a/tests/datasets/test_ase_datasets.py b/tests/datasets/test_ase_datasets.py
index 41a5b0c20c..38b5218dac 100644
--- a/tests/datasets/test_ase_datasets.py
+++ b/tests/datasets/test_ase_datasets.py
@@ -39,6 +39,7 @@
         "db_dataset_folder",
         "db_dataset_list",
         "lmdb_dataset",
+        "aselmdb_dataset",
     ],
 )
 def ase_dataset(request, tmp_path_factory):
@@ -74,7 +75,15 @@ def ase_dataset(request, tmp_path_factory):
             else [str(tmp_path / "asedb1.db"), str(tmp_path / "asedb2.db")]
         )
         dataset = AseDBDataset(config={"src": src, "a2g_args": a2g_args})
-    else:  # "lmbd_dataset"
+    elif request.param == "lmdb_dataset":
         with LMDBDatabase(str(tmp_path / "asedb.lmdb")) as database:
             for i, atoms in enumerate(structures):
                 database.write(atoms, data=atoms.info)
 
         dataset = AseDBDataset(
             config={"src": str(tmp_path / "asedb.lmdb"), "a2g_args": a2g_args}
         )
+    else:  # "aselmdb_dataset" with .aselmdb file extension
+        with LMDBDatabase(str(tmp_path / "asedb.aselmdb")) as database:
+            for i, atoms in enumerate(structures):
+                database.write(atoms, data=atoms.info)
+
+        dataset = AseDBDataset(
+            config={"src": str(tmp_path / "asedb.aselmdb"), "a2g_args": a2g_args}
+        )

From 3371cae583af49d52e30868e32a4a7e9d5f1cd1c Mon Sep 17 00:00:00 2001
From: lbluque
Date: Thu, 25 Jan 2024 14:51:35 -0800
Subject: [PATCH 13/64] fix minor bugs in lmdb database and update tests

---
 ocpmodels/datasets/lmdb_database.py | 91
++++++++++++++--------- tests/datasets/test_ase_lmdb.py | 110 +++++++++------------------- 2 files changed, 91 insertions(+), 110 deletions(-) diff --git a/ocpmodels/datasets/lmdb_database.py b/ocpmodels/datasets/lmdb_database.py index fdc345774f..8d09f99472 100644 --- a/ocpmodels/datasets/lmdb_database.py +++ b/ocpmodels/datasets/lmdb_database.py @@ -13,6 +13,7 @@ import os import typing import zlib +from pathlib import Path from typing import Optional import lmdb @@ -34,7 +35,7 @@ class LMDBDatabase(Database): def __init__( self, - filename: Optional[str] = None, + filename: Optional[str | Path] = None, create_indices: bool = True, use_lock_file: bool = False, serial: bool = False, @@ -47,7 +48,12 @@ def __init__( arguments, except that we add a readonly flag. """ super().__init__( - filename, create_indices, use_lock_file, serial, *args, **kwargs + Path(filename), + create_indices, + use_lock_file, + serial, + *args, + **kwargs, ) # Add a readonly mode for when we're only training @@ -57,7 +63,7 @@ def __init__( if self.readonly: # Open a new env self.env = lmdb.open( - self.filename, + str(self.filename), subdir=False, meminit=False, map_async=True, @@ -71,7 +77,7 @@ def __init__( else: # Open a new env with write access self.env = lmdb.open( - self.filename, + str(self.filename), map_size=1099511627776 * 2, subdir=False, meminit=False, @@ -81,6 +87,8 @@ def __init__( self.txn = self.env.begin(write=True) # Load all ids based on keys in the DB. + self.ids = [] + self.deleted_ids = [] self._load_ids() def __enter__(self) -> Self: @@ -136,35 +144,49 @@ def _write( dct["cell"] = np.asarray(dct["cell"]) if idx is None: - idx = self._get_nextid() + idx = self._nextid nextid = idx + 1 else: - data = self.txn.get("{id}".encode("ascii")) + data = self.txn.get(f"{idx}".encode("ascii")) assert data is not None - # Add the new entry, then add the id and write the nextid + # Add the new entry self.txn.put( f"{idx}".encode("ascii"), zlib.compress( orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY) ), ) - self.ids.append(idx) - self.txn.put( - "nextid".encode("ascii"), - zlib.compress( - orjson.dumps(nextid, option=orjson.OPT_SERIALIZE_NUMPY) - ), - ) + # only append if idx is not in ids + if idx not in self.ids: + self.ids.append(idx) + self.txn.put( + "nextid".encode("ascii"), + zlib.compress( + orjson.dumps(nextid, option=orjson.OPT_SERIALIZE_NUMPY) + ), + ) + # check if id is in removed ids and remove accordingly + if idx in self.deleted_ids: + self.deleted_ids.remove(idx) + self._write_deleted_ids() return idx - def delete(self, ids: list[int]) -> None: - for idx in ids: - self.txn.delete(f"{idx}".encode("ascii")) - self.ids.remove(idx) + def _update( + self, + idx: int, + key_value_pairs: Optional[dict] = None, + data: Optional[dict] = None, + ): + # hack this to play nicely with ASE code + row = self._get_row(idx, include_data=True) + if data is not None or key_value_pairs is not None: + self._write( + atoms=row, idx=idx, key_value_pairs=key_value_pairs, data=data + ) - self.deleted_ids += ids + def _write_deleted_ids(self): self.txn.put( "deleted_ids".encode("ascii"), zlib.compress( @@ -174,6 +196,14 @@ def delete(self, ids: list[int]) -> None: ), ) + def delete(self, ids: list[int]) -> None: + for idx in ids: + self.txn.delete(f"{idx}".encode("ascii")) + self.ids.remove(idx) + + self.deleted_ids += ids + self._write_deleted_ids() + def _get_row(self, idx: int, include_data: bool = True): if idx is None: assert len(self.ids) == 1 @@ -195,8 +225,7 @@ def _get_row_by_index(self, index: int, 
include_data: bool = True): """Auxiliary function to get the ith entry, rather than a specific id """ - id = self.ids[index] - data = self.txn.get(f"{id}".encode("ascii")) + data = self.txn.get(f"{self.ids[index]}".encode("ascii")) if data is not None: dct = orjson.loads(zlib.decompress(data)) @@ -303,16 +332,14 @@ def metadata(self, dct): ), ) - def _get_nextid(self): + @property + def _nextid(self): """Get the id of the next row to be written""" # Get the nextid nextid_data = self.txn.get("nextid".encode("ascii")) - if nextid_data is not None: - nextid = orjson.loads(zlib.decompress(nextid_data)) - else: - # This db is empty; start at 1! - nextid = 1 - + nextid = ( + orjson.loads(zlib.decompress(nextid_data)) if nextid_data else 1 + ) return nextid def count(self, selection=None, **kwargs) -> int: @@ -341,14 +368,10 @@ def _load_ids(self) -> None: # Load the deleted ids deleted_ids_data = self.txn.get("deleted_ids".encode("ascii")) - if deleted_ids_data is None: - self.deleted_ids = [] - else: + if deleted_ids_data is not None: self.deleted_ids = orjson.loads(zlib.decompress(deleted_ids_data)) # Reconstruct the full id list self.ids = [ - i - for i in range(1, self._get_nextid()) - if i not in set(self.deleted_ids) + i for i in range(1, self._nextid) if i not in set(self.deleted_ids) ] diff --git a/tests/datasets/test_ase_lmdb.py b/tests/datasets/test_ase_lmdb.py index 29ad95b668..52f0fec64f 100644 --- a/tests/datasets/test_ase_lmdb.py +++ b/tests/datasets/test_ase_lmdb.py @@ -1,25 +1,16 @@ -from pathlib import Path - import numpy as np -import tqdm +import pytest from ase import build from ase.calculators.singlepoint import SinglePointCalculator from ase.constraints import FixAtoms +from ase.db.row import AtomsRow from ocpmodels.datasets.lmdb_database import LMDBDatabase -DB_NAME = "ase_lmdb.lmdb" N_WRITES = 100 N_READS = 200 -def cleanup_asedb() -> None: - if Path(DB_NAME).is_file(): - Path(DB_NAME).unlink() - if Path(f"{DB_NAME}-lock").is_file(): - Path(f"{DB_NAME}-lock").unlink() - - test_structures = [ build.molecule("H2O", vacuum=4), build.bulk("Cu"), @@ -61,110 +52,77 @@ def generate_random_structure(): return slab -def write_random_atoms() -> None: - slab = build.fcc111("Cu", size=(4, 4, 3), vacuum=10.0) - with LMDBDatabase(DB_NAME) as db: +@pytest.fixture(scope="function") +def ase_lmbd_path(tmp_path_factory): + tmp_path = tmp_path_factory.mktemp("dataset") + with LMDBDatabase(tmp_path / "ase_lmdb.lmdb") as db: for structure in test_structures: db.write(structure) - for i in tqdm.tqdm(range(N_WRITES)): + for _ in range(N_WRITES): slab = generate_random_structure() - # Save the slab info, and make sure the info gets put in as data db.write(slab, data=slab.info) + return tmp_path / "ase_lmdb.lmdb" -def test_aselmdb_write() -> None: - # Representative structure - write_random_atoms() - - with LMDBDatabase(DB_NAME, readonly=True) as db: +def test_aselmdb_write(ase_lmbd_path) -> None: + with LMDBDatabase(ase_lmbd_path, readonly=True) as db: for i, structure in enumerate(test_structures): assert str(structure) == str(db._get_row_by_index(i).toatoms()) - cleanup_asedb() - - -def test_aselmdb_count() -> None: - # Representative structure - write_random_atoms() - with LMDBDatabase(DB_NAME, readonly=True) as db: +def test_aselmdb_count(ase_lmbd_path) -> None: + with LMDBDatabase(ase_lmbd_path, readonly=True) as db: assert db.count() == N_WRITES + len(test_structures) - cleanup_asedb() - - -def test_aselmdb_delete() -> None: - cleanup_asedb() - # Representative structure - 
write_random_atoms() - - with LMDBDatabase(DB_NAME) as db: +def test_aselmdb_delete(ase_lmbd_path) -> None: + with LMDBDatabase(ase_lmbd_path) as db: for i in range(5): # Note the available ids list is updating # but the ids themselves are fixed. db.delete([db.ids[0]]) - assert db.count() == N_WRITES + len(test_structures) - 5 - cleanup_asedb() - -def test_aselmdb_randomreads() -> None: - write_random_atoms() - - with LMDBDatabase(DB_NAME, readonly=True) as db: - for i in tqdm.tqdm(range(N_READS)): +def test_aselmdb_randomreads(ase_lmbd_path) -> None: + with LMDBDatabase(ase_lmbd_path, readonly=True) as db: + for _ in range(N_READS): total_size = db.count() - row = db._get_row_by_index(np.random.choice(total_size)).toatoms() - del row - cleanup_asedb() - + assert isinstance( + db._get_row_by_index(np.random.choice(total_size)), AtomsRow + ) -def test_aselmdb_constraintread() -> None: - write_random_atoms() - with LMDBDatabase(DB_NAME, readonly=True) as db: +def test_aselmdb_constraintread(ase_lmbd_path) -> None: + with LMDBDatabase(ase_lmbd_path, readonly=True) as db: atoms = db._get_row_by_index(2).toatoms() - assert type(atoms.constraints[0]) == FixAtoms + assert isinstance(atoms.constraints[0], FixAtoms) - cleanup_asedb() - -def update_keyvalue_pair() -> None: - write_random_atoms() - with LMDBDatabase(DB_NAME) as db: +def test_update_keyvalue_pair(ase_lmbd_path) -> None: + with LMDBDatabase(ase_lmbd_path) as db: db.update(1, test=5) - with LMDBDatabase(DB_NAME) as db: - row = db.get_row_by_id(1) + with LMDBDatabase(ase_lmbd_path) as db: + row = db._get_row(1) assert row.test == 5 - cleanup_asedb() - -def update_atoms() -> None: - write_random_atoms() - with LMDBDatabase(DB_NAME) as db: +def test_update_atoms(ase_lmbd_path) -> None: + with LMDBDatabase(ase_lmbd_path) as db: db.update(40, atoms=test_structures[-1]) - with LMDBDatabase(DB_NAME) as db: - row = db.get_row_by_id(40) + with LMDBDatabase(ase_lmbd_path) as db: + row = db._get_row(40) assert str(row.toatoms()) == str(test_structures[-1]) - cleanup_asedb() - -def test_metadata() -> None: - write_random_atoms() - - with LMDBDatabase(DB_NAME) as db: +def test_metadata(ase_lmbd_path) -> None: + with LMDBDatabase(ase_lmbd_path) as db: db.metadata = {"test": True} - with LMDBDatabase(DB_NAME, readonly=True) as db: + with LMDBDatabase(ase_lmbd_path, readonly=True) as db: assert db.metadata["test"] is True - - cleanup_asedb() From a0a2b2e84d40486c0c2a71ac37a51d471d638aaf Mon Sep 17 00:00:00 2001 From: lbluque Date: Thu, 25 Jan 2024 15:38:57 -0800 Subject: [PATCH 14/64] make connect_db staticmethod --- ocpmodels/datasets/ase_datasets.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index 948bb447fc..84f582f192 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -509,7 +509,7 @@ def get_atoms_object(self, idx): el_idx = idx - self._idlen_cumulative[db_idx - 1] assert el_idx >= 0 - atoms_row = self.dbs[db_idx]._get_row(self.db_ids[db_idx][el_idx]) + atoms_row = self.dbs[db_idx].get(self.db_ids[db_idx][el_idx]) atoms = atoms_row.toatoms() # put data back into atoms info @@ -518,7 +518,8 @@ def get_atoms_object(self, idx): return atoms - def connect_db(self, address, connect_args: Optional[dict] = None): + @staticmethod + def connect_db(address, connect_args: Optional[dict] = None): if connect_args is None: connect_args = {} db_type = connect_args.get("type", "extract_from_name") From 
237f000648c39448eec6aca34726df8a5d2e8b3c Mon Sep 17 00:00:00 2001 From: lbluque Date: Thu, 25 Jan 2024 15:58:07 -0800 Subject: [PATCH 15/64] remove redundant methods and make some private --- ocpmodels/datasets/ase_datasets.py | 34 +++++++++++++++++------------ ocpmodels/datasets/lmdb_database.py | 4 +--- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index 84f582f192..ca7844e82c 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -91,7 +91,7 @@ def __init__( if self.config.get("keep_in_memory", False): self.__getitem__ = functools.cache(self.__getitem__) - self.ids = self.load_dataset_get_ids(config) + self.ids = self._load_dataset_get_ids(config) def __len__(self) -> int: return len(self.ids) @@ -151,7 +151,7 @@ def get_atoms_object(self, identifier): ) @abstractmethod - def load_dataset_get_ids(self, config): + def _load_dataset_get_ids(self, config): # This function should return a list of ids that can be used to index into the database raise NotImplementedError( "Every ASE dataset needs to declare a function to load the dataset and return a list of ids." @@ -161,7 +161,7 @@ def close_db(self) -> None: # This method is sometimes called by a trainer pass - def guess_target_metadata(self, num_samples: int = 100): + def get_metadata(self, num_samples: int = 100): metadata = {} if num_samples < len(self): @@ -183,9 +183,6 @@ def guess_target_metadata(self, num_samples: int = 100): return metadata - def get_metadata(self): - return self.guess_target_metadata() - @registry.register_dataset("ase_read") class AseReadDataset(AseAtomsDataset): @@ -238,7 +235,7 @@ class AseReadDataset(AseAtomsDataset): """ - def load_dataset_get_ids(self, config) -> list[Path]: + def _load_dataset_get_ids(self, config) -> list[Path]: self.ase_read_args = config.get("ase_read_args", {}) if ":" in self.ase_read_args.get("index", ""): @@ -333,7 +330,7 @@ class AseReadMultiStructureDataset(AseAtomsDataset): transform (callable, optional): Additional preprocessing function for the Data object """ - def load_dataset_get_ids(self, config): + def _load_dataset_get_ids(self, config): self.ase_read_args = config.get("ase_read_args", {}) if not hasattr(self.ase_read_args, "index"): self.ase_read_args["index"] = ":" @@ -453,7 +450,7 @@ class AseDBDataset(AseAtomsDataset): transform (callable, optional): deprecated? """ - def load_dataset_get_ids(self, config) -> list[int]: + def _load_dataset_get_ids(self, config: dict) -> list[int]: if isinstance(config["src"], list): filepaths = config["src"] elif os.path.isfile(config["src"]): @@ -499,7 +496,14 @@ def load_dataset_get_ids(self, config) -> list[int]: return list(range(sum(idlens))) - def get_atoms_object(self, idx): + def get_atoms_object(self, idx: int) -> ase.Atoms: + """Get atoms object corresponding to datapoint idx. Useful to read other properties not in data object. + Args: + idx (int): index in dataset + + Returns: + atoms: ASE atoms corresponding to datapoint idx + """ # Figure out which db this should be indexed from. 
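        # Worked sketch of the lookup below (values illustrative, not from the
        # code): with two dbs of lengths [3, 5], self._idlen_cumulative is
        # [3, 8]; idx=4 bisects to db_idx=1, and el_idx = 4 - 3 = 1 picks the
        # second row of the second db.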
db_idx = bisect.bisect(self._idlen_cumulative, idx)
@@ -519,13 +523,15 @@ def get_atoms_object(self, idx: int) -> ase.Atoms:
         return atoms

     @staticmethod
-    def connect_db(address, connect_args: Optional[dict] = None):
+    def connect_db(
+        address: str | Path, connect_args: Optional[dict] = None
+    ) -> ase.db.core.Database:
         if connect_args is None:
             connect_args = {}
         db_type = connect_args.get("type", "extract_from_name")
         if db_type in ("lmdb", "aselmdb") or (
             db_type == "extract_from_name"
-            and address.split(".")[-1] in ("lmdb", "aselmdb")
+            and str(address).split(".")[-1] in ("lmdb", "aselmdb")
         ):
             return LMDBDatabase(address, readonly=True, **connect_args)
         else:
@@ -536,12 +542,12 @@ def close_db(self) -> None:
         if hasattr(db, "close"):
             db.close()

-    def get_metadata(self):
+    def get_metadata(self, num_samples: int = 100) -> dict:
         logging.warning(
             "You specified a folder of ASE dbs, so it's impossible to know which metadata to use. Using the first!"
         )
         if self.dbs[0].metadata == {}:
-            return self.guess_target_metadata()
+            return super().get_metadata(num_samples)
         else:
             return copy.deepcopy(self.dbs[0].metadata)

diff --git a/ocpmodels/datasets/lmdb_database.py b/ocpmodels/datasets/lmdb_database.py
index 8d09f99472..2264d25195 100644
--- a/ocpmodels/datasets/lmdb_database.py
+++ b/ocpmodels/datasets/lmdb_database.py
@@ -222,9 +222,7 @@ def _get_row(self, idx: int, include_data: bool = True):
         return AtomsRow(dct)

     def _get_row_by_index(self, index: int, include_data: bool = True):
-        """Auxiliary function to get the ith entry, rather than
-        a specific id
-        """
+        """Auxiliary function to get the ith entry, rather than a specific id"""
         data = self.txn.get(f"{self.ids[index]}".encode("ascii"))

         if data is not None:

From cae07655156b07811b768d0783c316d170017054 Mon Sep 17 00:00:00 2001
From: lbluque
Date: Thu, 25 Jan 2024 16:52:24 -0800
Subject: [PATCH 16/64] allow a list of paths in AseDBDataset

---
 ocpmodels/datasets/ase_datasets.py  | 18 +++++++++++------
 tests/datasets/test_ase_datasets.py | 31 ++++++++++++++++++++++-------
 2 files changed, 36 insertions(+), 13 deletions(-)

diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py
index ca7844e82c..0e34312a9e 100644
--- a/ocpmodels/datasets/ase_datasets.py
+++ b/ocpmodels/datasets/ase_datasets.py
@@ -2,12 +2,12 @@

 import bisect
 import copy
-import functools
-import glob
 import logging
 import os
 import warnings
 from abc import ABC, abstractmethod
+from functools import cache, reduce
+from glob import glob
 from pathlib import Path
 from typing import Any, Callable, Optional

@@ -89,7 +89,7 @@ def __init__(
         self.atoms_transform = atoms_transform

         if self.config.get("keep_in_memory", False):
-            self.__getitem__ = functools.cache(self.__getitem__)
+            self.__getitem__ = cache(self.__getitem__)

         self.ids = self._load_dataset_get_ids(config)

@@ -452,13 +452,19 @@ class AseDBDataset(AseAtomsDataset):

     def _load_dataset_get_ids(self, config: dict) -> list[int]:
         if isinstance(config["src"], list):
-            filepaths = config["src"]
+            if os.path.isdir(config["src"][0]):
+                filepaths = reduce(
+                    lambda x, y: x + y,
+                    (glob(f"{path}/*db") for path in config["src"]),
+                )
+            else:
+                filepaths = config["src"]
         elif os.path.isfile(config["src"]):
             filepaths = [config["src"]]
         elif os.path.isdir(config["src"]):
-            filepaths = glob.glob(f'{config["src"]}/*')
+            filepaths = glob(f'{config["src"]}/*db')
         else:
-            filepaths = glob.glob(config["src"])
+            filepaths = glob(config["src"])

         self.dbs = []

diff --git a/tests/datasets/test_ase_datasets.py b/tests/datasets/test_ase_datasets.py
index
38b5218dac..7505f63787 100644 --- a/tests/datasets/test_ase_datasets.py +++ b/tests/datasets/test_ase_datasets.py @@ -1,3 +1,5 @@ +import os + import numpy as np import pytest from ase import build, db @@ -38,7 +40,7 @@ "db_dataset", "db_dataset_folder", "db_dataset_list", - "lmdb_dataset", + "db_dataset_path_list" "lmdb_dataset", "aselmdb_dataset", ], ) @@ -62,12 +64,10 @@ def ase_dataset(request, tmp_path_factory): request.param == "db_dataset_folder" or request.param == "db_dataset_list" ): - with db.connect(tmp_path / "asedb1.db") as database: - for i, atoms in enumerate(structures): - database.write(atoms, data=atoms.info) - with db.connect(tmp_path / "asedb2.db") as database: - for i, atoms in enumerate(structures): - database.write(atoms, data=atoms.info) + for db_name in ("asedb1.db", "asedb2.db"): + with db.connect(tmp_path / db_name) as database: + for i, atoms in enumerate(structures): + database.write(atoms, data=atoms.info) mult = 2 src = ( str(tmp_path) @@ -75,6 +75,23 @@ def ase_dataset(request, tmp_path_factory): else [str(tmp_path / "asedb1.db"), str(tmp_path / "asedb2.db")] ) dataset = AseDBDataset(config={"src": src, "a2g_args": a2g_args}) + elif request.param == "db_dataset_path_list": + os.mkdir(tmp_path / "dir1") + os.mkdir(tmp_path / "dir2") + + for dir_name in ("dir1", "dir2"): + for db_name in ("asedb1.db", "asedb2.db"): + with db.connect(tmp_path / dir_name / db_name) as database: + for i, atoms in enumerate(structures): + database.write(atoms, data=atoms.info) + mult = 4 + dataset = AseDBDataset( + config={ + "src": [str(tmp_path / "dir1"), str(tmp_path / "dir2")], + "a2g_args": a2g_args, + } + ) + print(len(dataset.dbs)) elif request.param == "lmbd_dataset": with LMDBDatabase(str(tmp_path / "asedb.lmdb")) as database: for i, atoms in enumerate(structures): From dd0b5fc8cfd91942bd615308feca2093378bd3ed Mon Sep 17 00:00:00 2001 From: lbluque Date: Thu, 25 Jan 2024 17:30:04 -0800 Subject: [PATCH 17/64] remove sprinkled print statement --- tests/datasets/test_ase_datasets.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/datasets/test_ase_datasets.py b/tests/datasets/test_ase_datasets.py index 7505f63787..a9b67041ae 100644 --- a/tests/datasets/test_ase_datasets.py +++ b/tests/datasets/test_ase_datasets.py @@ -91,7 +91,6 @@ def ase_dataset(request, tmp_path_factory): "a2g_args": a2g_args, } ) - print(len(dataset.dbs)) elif request.param == "lmbd_dataset": with LMDBDatabase(str(tmp_path / "asedb.lmdb")) as database: for i, atoms in enumerate(structures): From 303120aa42c20887f716df9d35edce62dcd977b3 Mon Sep 17 00:00:00 2001 From: lbluque Date: Mon, 29 Jan 2024 11:55:33 -0800 Subject: [PATCH 18/64] remove deprecated transform kwarg --- ocpmodels/datasets/ase_datasets.py | 4 ---- ocpmodels/datasets/lmdb_dataset.py | 4 +--- 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index 0e34312a9e..7b61115df0 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -73,7 +73,6 @@ def __init__( atoms_transform: Callable[ [ase.Atoms, Any, ...], ase.Atoms ] = apply_one_tags, - transform=None, # is this deprecated? ) -> None: self.config = config @@ -230,9 +229,6 @@ class AseReadDataset(AseAtomsDataset): atoms_transform (callable, optional): Additional preprocessing function applied to the Atoms object. Useful for applying tags, for example. 
-
-    transform (callable, optional): Additional preprocessing function for the Data object
-
     """

     def _load_dataset_get_ids(self, config) -> list[Path]:
diff --git a/ocpmodels/datasets/lmdb_dataset.py b/ocpmodels/datasets/lmdb_dataset.py
index 93e13ed33c..6e649f52ea 100644
--- a/ocpmodels/datasets/lmdb_dataset.py
+++ b/ocpmodels/datasets/lmdb_dataset.py
@@ -44,11 +44,9 @@ class LmdbDataset(Dataset[T_co]):
     folder, but lmdb lengths are now calculated directly from the number of keys.

     Args:
         config (dict): Dataset configuration
-        transform (callable, optional): Data transform function.
-            (default: :obj:`None`)
     """

-    def __init__(self, config, transform=None) -> None:
+    def __init__(self, config) -> None:
         super(LmdbDataset, self).__init__()
         self.config = config

From 56df36d97503032d15226d39bcb3481bc42eed67 Mon Sep 17 00:00:00 2001
From: lbluque
Date: Mon, 29 Jan 2024 11:56:47 -0800
Subject: [PATCH 19/64] fix docstring typo

---
 ocpmodels/preprocessing/atoms_to_graphs.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ocpmodels/preprocessing/atoms_to_graphs.py b/ocpmodels/preprocessing/atoms_to_graphs.py
index 29e7246213..d3d8d48b86 100644
--- a/ocpmodels/preprocessing/atoms_to_graphs.py
+++ b/ocpmodels/preprocessing/atoms_to_graphs.py
@@ -47,7 +47,7 @@ class AtomsToGraphs:
         radius (int or float): Cutoff radius in Angstroms to search for neighbors.
         r_energy (bool): Return the energy with other properties. Default is False, so the energy will not be returned.
         r_forces (bool): Return the forces with other properties. Default is False, so the forces will not be returned.
-        r_forces (bool): Return the stress with other properties. Default is False, so the stress will not be returned.
+        r_stress (bool): Return the stress with other properties. Default is False, so the stress will not be returned.
         r_distances (bool): Return the distances with other properties. Default is False, so the distances will not be returned.
         r_edges (bool): Return interatomic edges with other properties. Default is True, so edges will be returned.

From 597e4216cdfeb82225a5f8e8b59c36196fefe0c1 Mon Sep 17 00:00:00 2001
From: lbluque
Date: Mon, 29 Jan 2024 12:09:13 -0800
Subject: [PATCH 20/64] rename keys function

---
 ocpmodels/datasets/_utils.py            | 33 +++++++++++++++++++++++++
 ocpmodels/datasets/ase_datasets.py      | 11 +++------
 ocpmodels/datasets/lmdb_dataset.py      | 11 +++------
 ocpmodels/datasets/oc22_lmdb_dataset.py | 10 +++-----
 4 files changed, 45 insertions(+), 20 deletions(-)
 create mode 100644 ocpmodels/datasets/_utils.py

diff --git a/ocpmodels/datasets/_utils.py b/ocpmodels/datasets/_utils.py
new file mode 100644
index 0000000000..c0c17db083
--- /dev/null
+++ b/ocpmodels/datasets/_utils.py
@@ -0,0 +1,33 @@
+"""
+Copyright (c) Facebook, Inc. and its affiliates.
+
+This source code is licensed under the MIT license found in the
+LICENSE file in the root directory of this source tree.
+""" + +from __future__ import annotations + +import typing + +if typing.TYPE_CHECKING: + from torch_geometric.data import Data + + +def rename_data_object_keys( + data_object: Data, key_mapping: dict[str, str] +) -> Data: + """Rename data object keys + + Args: + data_object: data object + key_mapping: dictionary specifying keys to rename and new names {prev_key: new_key} + """ + for _property in key_mapping: + # catch for test data not containing labels + if _property in data_object: + new_property = key_mapping[_property] + if new_property not in data_object: + data_object[new_property] = data_object[_property] + del data_object[_property] + + return data_object diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index 7b61115df0..32accdd3b8 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -18,6 +18,7 @@ from tqdm import tqdm from ocpmodels.common.registry import registry +from ocpmodels.datasets._utils import rename_data_object_keys from ocpmodels.datasets.lmdb_database import LMDBDatabase from ocpmodels.datasets.target_metadata_guesser import guess_property_metadata from ocpmodels.modules.transforms import DataTransforms @@ -126,13 +127,9 @@ def __getitem__(self, idx): data_object.natoms = len(atoms) if self.key_mapping is not None: - for _property in self.key_mapping: - # catch for test data not containing labels - if _property in data_object: - new_property = self.key_mapping[_property] - if new_property not in data_object: - data_object[new_property] = data_object[_property] - del data_object[_property] + data_object = rename_data_object_keys( + data_object, self.key_mapping + ) # Transform data object data_object = self.transforms(data_object) diff --git a/ocpmodels/datasets/lmdb_dataset.py b/ocpmodels/datasets/lmdb_dataset.py index 6e649f52ea..1c7e313ac2 100644 --- a/ocpmodels/datasets/lmdb_dataset.py +++ b/ocpmodels/datasets/lmdb_dataset.py @@ -21,6 +21,7 @@ from ocpmodels.common.registry import registry from ocpmodels.common.typing import assert_is_instance from ocpmodels.common.utils import pyg2_data_transform +from ocpmodels.datasets._utils import rename_data_object_keys from ocpmodels.datasets.target_metadata_guesser import guess_property_metadata from ocpmodels.modules.transforms import DataTransforms @@ -149,13 +150,9 @@ def __getitem__(self, idx: int) -> T_co: data_object = pyg2_data_transform(pickle.loads(datapoint_pickled)) if self.key_mapping is not None: - for _property in self.key_mapping: - # catch for test data not containing labels - if _property in data_object: - new_property = self.key_mapping[_property] - if new_property not in data_object: - data_object[new_property] = data_object[_property] - del data_object[_property] + data_object = rename_data_object_keys( + data_object, self.key_mapping + ) data_object = self.transforms(data_object) diff --git a/ocpmodels/datasets/oc22_lmdb_dataset.py b/ocpmodels/datasets/oc22_lmdb_dataset.py index aee0a2f81e..347f3d25d0 100644 --- a/ocpmodels/datasets/oc22_lmdb_dataset.py +++ b/ocpmodels/datasets/oc22_lmdb_dataset.py @@ -17,6 +17,7 @@ from ocpmodels.common.registry import registry from ocpmodels.common.typing import assert_is_instance as aii from ocpmodels.common.utils import pyg2_data_transform +from ocpmodels.datasets._utils import rename_data_object_keys from ocpmodels.modules.transforms import DataTransforms @@ -198,12 +199,9 @@ def __getitem__(self, idx): data_object[attr] -= lin_energy if self.key_mapping is not None: - for _property in 
self.key_mapping: - if _property in data_object: - new_property = self.key_mapping[_property] - if new_property not in data_object: - data_object[new_property] = data_object[_property] - del data_object[_property] + data_object = rename_data_object_keys( + data_object, self.key_mapping + ) # to jointly train on oc22+oc20, need to delete these oc20-only attributes # ensure otf_graph=1 in your model configuration From 11bd455f1db8112aea884cab8536b0d7b3040de7 Mon Sep 17 00:00:00 2001 From: lbluque Date: Mon, 29 Jan 2024 12:10:56 -0800 Subject: [PATCH 21/64] fix missing comma in tests --- tests/datasets/test_ase_datasets.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/datasets/test_ase_datasets.py b/tests/datasets/test_ase_datasets.py index a9b67041ae..5f08c17960 100644 --- a/tests/datasets/test_ase_datasets.py +++ b/tests/datasets/test_ase_datasets.py @@ -40,7 +40,8 @@ "db_dataset", "db_dataset_folder", "db_dataset_list", - "db_dataset_path_list" "lmdb_dataset", + "db_dataset_path_list", + "lmdb_dataset", "aselmdb_dataset", ], ) From 07f21726c9cc18eace04b0f6c8a971a44caa83d9 Mon Sep 17 00:00:00 2001 From: lbluque Date: Mon, 29 Jan 2024 13:24:13 -0800 Subject: [PATCH 22/64] set default r_edges in a2g in AseDatasets to false --- ocpmodels/datasets/ase_datasets.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index 32accdd3b8..2150b0e01f 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -79,6 +79,10 @@ def __init__( a2g_args = config.get("a2g_args", {}) + # set default to False if not set by user, assuming otf_graph will be used + if "r_edges" not in a2g_args: + a2g_args["r_edges"] = False + # Make sure we always include PBC info in the resulting atoms objects a2g_args["r_pbc"] = True self.a2g = AtomsToGraphs(**a2g_args) From d99d383c572293b2de5f9d3da3ede4cf42fe8e5b Mon Sep 17 00:00:00 2001 From: lbluque Date: Mon, 29 Jan 2024 13:32:40 -0800 Subject: [PATCH 23/64] simple unit-test for good measure --- tests/datasets/test_utils.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 tests/datasets/test_utils.py diff --git a/tests/datasets/test_utils.py b/tests/datasets/test_utils.py new file mode 100644 index 0000000000..d1367c011a --- /dev/null +++ b/tests/datasets/test_utils.py @@ -0,0 +1,18 @@ +import pytest +import torch +from torch_geometric.data import Data + +from ocpmodels.datasets._utils import rename_data_object_keys + + +@pytest.fixture() +def pyg_data(): + return Data(rand_tensor=torch.rand((3, 3))) + + +def test_rename_data_object_keys(pyg_data): + assert "rand_tensor" in pyg_data.keys + key_mapping = {"rand_tensor": "random_tensor"} + pyg_data = rename_data_object_keys(pyg_data, key_mapping) + assert "rand_tensor" not in pyg_data.keys + assert "random_tensor" in pyg_data.keys From 18fd2f11a70df82104edcfb0dcc303ec7a35730e Mon Sep 17 00:00:00 2001 From: lbluque Date: Wed, 31 Jan 2024 14:19:19 -0800 Subject: [PATCH 24/64] call _get_row directly --- ocpmodels/datasets/ase_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index 2150b0e01f..afbd6b6ce5 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -516,7 +516,7 @@ def get_atoms_object(self, idx: int) -> ase.Atoms: el_idx = idx - self._idlen_cumulative[db_idx - 1] assert el_idx >= 0 - atoms_row = 
self.dbs[db_idx].get(self.db_ids[db_idx][el_idx]) + atoms_row = self.dbs[db_idx]._get_row(self.db_ids[db_idx][el_idx]) atoms = atoms_row.toatoms() # put data back into atoms info From fd30b43ac380b5dd954c5effcf4238e965f021a9 Mon Sep 17 00:00:00 2001 From: lbluque Date: Thu, 1 Feb 2024 14:46:03 -0800 Subject: [PATCH 25/64] [wip] allow string sids --- ocpmodels/datasets/ase_datasets.py | 8 -------- ocpmodels/trainers/ocp_trainer.py | 4 ++-- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index afbd6b6ce5..aee25145f8 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -115,14 +115,6 @@ def __getitem__(self, idx): ) sid = atoms.info.get("sid", self.ids[idx]) - try: - sid = tensor([sid]) - except (RuntimeError, ValueError, TypeError): - warnings.warn( - "Supplied sid is not numeric (or missing). Using dataset indices instead." - ) - sid = tensor([idx]) - fid = atoms.info.get("fid", tensor([0])) # Convert to data object diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index 26c92bf0af..18548da51a 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -514,10 +514,10 @@ def predict( return predictions ### Get unique system identifiers - sids = batch.sid.tolist() + sids = list(batch.sid) ## Support naming structure for OC20 S2EF if "fid" in batch: - fids = batch.fid.tolist() + fids = list(batch.fid) systemids = [f"{sid}_{fid}" for sid, fid in zip(sids, fids)] else: systemids = [f"{sid}" for sid in sids] From 77a40dd103f0328f374dd69b0f851714bbca94e3 Mon Sep 17 00:00:00 2001 From: lbluque Date: Thu, 1 Feb 2024 16:42:59 -0800 Subject: [PATCH 26/64] raise a helpful error if AseAtomsAdaptor not available --- ocpmodels/preprocessing/atoms_to_graphs.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ocpmodels/preprocessing/atoms_to_graphs.py b/ocpmodels/preprocessing/atoms_to_graphs.py index d3d8d48b86..76b2d8dc00 100644 --- a/ocpmodels/preprocessing/atoms_to_graphs.py +++ b/ocpmodels/preprocessing/atoms_to_graphs.py @@ -19,8 +19,8 @@ try: from pymatgen.io.ase import AseAtomsAdaptor -except Exception: - pass +except ImportError: + AseAtomsAdaptor = None try: @@ -102,6 +102,11 @@ def __init__( def _get_neighbors_pymatgen(self, atoms: ase.Atoms): """Preforms nearest neighbor search and returns edge index, distances, and cell offsets""" + if AseAtomsAdaptor is None: + raise RuntimeError( + "Unable to import pymatgen.io.ase.AseAtomsAdaptor. Make sure pymatgen is properly installed." 
+ ) + struct = AseAtomsAdaptor.get_structure(atoms) _c_index, _n_index, _offsets, n_distance = struct.get_neighbor_list( r=self.radius, numerical_tol=0, exclude_self=True From c4417346e83e19a4d427e60c8a51eace57eda28c Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 9 Feb 2024 10:42:32 -0800 Subject: [PATCH 27/64] remove db extension in filepaths --- ocpmodels/datasets/ase_datasets.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index aee25145f8..a1f82f412b 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -444,14 +444,15 @@ def _load_dataset_get_ids(self, config: dict) -> list[int]: if os.path.isdir(config["src"][0]): filepaths = reduce( lambda x, y: x + y, - (glob(f"{path}/*db") for path in config["src"]), + (glob(f"{path}/*") for path in config["src"]), ) else: filepaths = config["src"] elif os.path.isfile(config["src"]): filepaths = [config["src"]] elif os.path.isdir(config["src"]): - filepaths = glob(f'{config["src"]}/*db') + filepaths = glob(f'{config["src"]}/*') + print(filepaths) else: filepaths = glob(config["src"]) From 5b13296876803f5b838a00794654f63d68f6818a Mon Sep 17 00:00:00 2001 From: lbluque Date: Thu, 15 Feb 2024 16:34:44 -0800 Subject: [PATCH 28/64] set logger to info level when trying to read non db files, remove print --- ocpmodels/datasets/ase_datasets.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index a1f82f412b..96b699db8a 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -452,7 +452,6 @@ def _load_dataset_get_ids(self, config: dict) -> list[int]: filepaths = [config["src"]] elif os.path.isdir(config["src"]): filepaths = glob(f'{config["src"]}/*') - print(filepaths) else: filepaths = glob(config["src"]) @@ -464,7 +463,7 @@ def _load_dataset_get_ids(self, config: dict) -> list[int]: self.connect_db(path, config.get("connect_args", {})) ) except ValueError: - logging.warning( + logging.info( f"Tried to connect to {path} but it's not an ASE database!" ) From 242b54fe770f60be41592c181067c6e4a2e8ebef Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 16 Feb 2024 18:21:53 -0800 Subject: [PATCH 29/64] set logging.debug to avoid saturating logs --- ocpmodels/datasets/ase_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index 96b699db8a..8e31c797e0 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -463,7 +463,7 @@ def _load_dataset_get_ids(self, config: dict) -> list[int]: self.connect_db(path, config.get("connect_args", {})) ) except ValueError: - logging.info( + logging.debug( f"Tried to connect to {path} but it's not an ASE database!" 
) From 6c678f10c937c642cf6d7e8e1a442a87be533fa7 Mon Sep 17 00:00:00 2001 From: Ethan Sunshine <93541000+emsunshine@users.noreply.github.com> Date: Mon, 26 Feb 2024 09:11:44 -0500 Subject: [PATCH 30/64] Update documentation for dataset config changes This PR is intended to address https://github.com/Open-Catalyst-Project/ocp/issues/629 --- TRAIN.md | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/TRAIN.md b/TRAIN.md index d16d08dfbe..37b9934f67 100644 --- a/TRAIN.md +++ b/TRAIN.md @@ -382,10 +382,8 @@ If your data is already in an [ASE Database](https://databases.fysik.dtu.dk/ase/ To use this dataset, we will just have to change our config files to use the ASE DB Dataset rather than the LMDB Dataset: ```yaml -task: - dataset: ase_db - dataset: + format: ase_db train: src: # The path/address to your ASE DB connect_args: @@ -420,10 +418,8 @@ It is possible to train/predict directly on ASE-readable files. This is only rec This dataset assumes a single structure will be obtained from each file: ```yaml -task: - dataset: ase_read - dataset: + format: ase_read train: src: # The folder that contains ASE-readable files pattern: # Pattern matching each file you want to read (e.g. "*/POSCAR"). Search recursively with two wildcards: "**/*.cif". @@ -443,10 +439,8 @@ dataset: This dataset supports reading files that each contain multiple structure (for example, an ASE .traj file). Using an index file, which tells the dataset how many structures each file contains, is recommended. Otherwise, the dataset is forced to load every file at startup and count the number of structures! ```yaml -task: - dataset: ase_read_multi - dataset: + format: ase_read_multi train: index_file: Filepath to an index file which contains each filename and the number of structures in each file. 
e.g.: /path/to/relaxation1.traj 200 From fd4d3e80d4a70fe2347afb0f3a76af59a970cd49 Mon Sep 17 00:00:00 2001 From: Ethan Sunshine <93541000+emsunshine@users.noreply.github.com> Date: Mon, 26 Feb 2024 09:46:06 -0500 Subject: [PATCH 31/64] Update atoms_to_graphs.py --- ocpmodels/preprocessing/atoms_to_graphs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ocpmodels/preprocessing/atoms_to_graphs.py b/ocpmodels/preprocessing/atoms_to_graphs.py index 34291f173a..962274f316 100644 --- a/ocpmodels/preprocessing/atoms_to_graphs.py +++ b/ocpmodels/preprocessing/atoms_to_graphs.py @@ -177,10 +177,10 @@ def convert(self, atoms: ase.Atoms, sid=None): data.cell_offsets = cell_offsets if self.r_energy: energy = atoms.get_potential_energy(apply_constraint=False) - data.y = energy + data.energy = energy if self.r_forces: forces = torch.Tensor(atoms.get_forces(apply_constraint=False)) - data.force = forces + data.forces = forces if self.r_distances and self.r_edges: data.distances = edge_distances if self.r_fixed: From 61ffef350022c6182fc3ae260bbe8e11cfd5c5a7 Mon Sep 17 00:00:00 2001 From: Ethan Sunshine <93541000+emsunshine@users.noreply.github.com> Date: Mon, 26 Feb 2024 11:18:21 -0500 Subject: [PATCH 32/64] Update test_ase_datasets.py --- tests/datasets/test_ase_datasets.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/datasets/test_ase_datasets.py b/tests/datasets/test_ase_datasets.py index d1767c9782..6df280e191 100644 --- a/tests/datasets/test_ase_datasets.py +++ b/tests/datasets/test_ase_datasets.py @@ -1,4 +1,4 @@ -import os +gyimport os import numpy as np from ase import build, db @@ -447,8 +447,8 @@ def test_ase_multiread_dataset() -> None: [dataset[:]] assert hasattr(dataset[0], "y_relaxed") - assert dataset[0].y_relaxed != dataset[0].y - assert dataset[-1].y_relaxed == dataset[-1].y + assert dataset[0].y_relaxed != dataset[0].energy + assert dataset[-1].y_relaxed == dataset[-1].energy dataset = AseReadDataset( config={ @@ -468,7 +468,7 @@ def test_ase_multiread_dataset() -> None: [dataset[:]] assert hasattr(dataset[0], "y_relaxed") - assert dataset[0].y_relaxed != dataset[0].y + assert dataset[0].y_relaxed != dataset[0].energy os.remove( os.path.join(os.path.dirname(os.path.abspath(__file__)), "test.traj") From e3ea55915eedbf1b359ba13a175b41d546d2b7cc Mon Sep 17 00:00:00 2001 From: Ethan Sunshine <93541000+emsunshine@users.noreply.github.com> Date: Mon, 26 Feb 2024 11:19:51 -0500 Subject: [PATCH 33/64] Update test_ase_datasets.py --- tests/datasets/test_ase_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/datasets/test_ase_datasets.py b/tests/datasets/test_ase_datasets.py index 6df280e191..072e5fd505 100644 --- a/tests/datasets/test_ase_datasets.py +++ b/tests/datasets/test_ase_datasets.py @@ -1,4 +1,4 @@ -gyimport os +import os import numpy as np from ase import build, db From 21ccf6ab1565b7c09488c5cd332ac53da0da779c Mon Sep 17 00:00:00 2001 From: Ethan Sunshine <93541000+emsunshine@users.noreply.github.com> Date: Mon, 26 Feb 2024 12:04:36 -0500 Subject: [PATCH 34/64] Update test_atoms_to_graphs.py --- tests/preprocessing/test_atoms_to_graphs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/preprocessing/test_atoms_to_graphs.py b/tests/preprocessing/test_atoms_to_graphs.py index b0a035a16e..8468e375ba 100644 --- a/tests/preprocessing/test_atoms_to_graphs.py +++ b/tests/preprocessing/test_atoms_to_graphs.py @@ -100,7 +100,7 @@ def test_convert(self) -> None: 
np.testing.assert_allclose(act_positions, positions) # check energy value act_energy = self.atoms.get_potential_energy(apply_constraint=False) - test_energy = data.y + test_energy = data.energy np.testing.assert_equal(act_energy, test_energy) # forces act_forces = self.atoms.get_forces(apply_constraint=False) @@ -123,7 +123,7 @@ def test_convert_all(self) -> None: np.testing.assert_allclose(act_positions, positions) # check energy value act_energy = self.atoms.get_potential_energy(apply_constraint=False) - test_energy = data_list[0].y + test_energy = data_list[0].energy np.testing.assert_equal(act_energy, test_energy) # forces act_forces = self.atoms.get_forces(apply_constraint=False) From b8a4c2f95ec5af9b138e3884f9ead671b278e13f Mon Sep 17 00:00:00 2001 From: Ethan Sunshine <93541000+emsunshine@users.noreply.github.com> Date: Mon, 26 Feb 2024 12:11:34 -0500 Subject: [PATCH 35/64] Update test_atoms_to_graphs.py --- tests/preprocessing/test_atoms_to_graphs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/preprocessing/test_atoms_to_graphs.py b/tests/preprocessing/test_atoms_to_graphs.py index 8468e375ba..3749286504 100644 --- a/tests/preprocessing/test_atoms_to_graphs.py +++ b/tests/preprocessing/test_atoms_to_graphs.py @@ -104,7 +104,7 @@ def test_convert(self) -> None: np.testing.assert_equal(act_energy, test_energy) # forces act_forces = self.atoms.get_forces(apply_constraint=False) - forces = data.force.numpy() + forces = data.forces.numpy() np.testing.assert_allclose(act_forces, forces) def test_convert_all(self) -> None: @@ -127,5 +127,5 @@ def test_convert_all(self) -> None: np.testing.assert_equal(act_energy, test_energy) # forces act_forces = self.atoms.get_forces(apply_constraint=False) - forces = data_list[0].force.numpy() + forces = data_list[0].forces.numpy() np.testing.assert_allclose(act_forces, forces) From ec17ce8ba10a83bb99319fd0fb49b2dc39f66fd7 Mon Sep 17 00:00:00 2001 From: lbluque Date: Mon, 26 Feb 2024 16:01:53 -0800 Subject: [PATCH 36/64] case for explicit a2g_args None values --- ocpmodels/datasets/ase_datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index 8e31c797e0..d1b71c9a3c 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -77,7 +77,7 @@ def __init__( ) -> None: self.config = config - a2g_args = config.get("a2g_args", {}) + a2g_args = config.get("a2g_args", {}) or {} # set default to False if not set by user, assuming otf_graph will be used if "r_edges" not in a2g_args: From 01863ddc55d3e8f6ecd30f558fbbe3239e92f67f Mon Sep 17 00:00:00 2001 From: Ethan Sunshine <93541000+emsunshine@users.noreply.github.com> Date: Tue, 27 Feb 2024 10:04:31 -0500 Subject: [PATCH 37/64] Update update_config() --- ocpmodels/common/utils.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/ocpmodels/common/utils.py b/ocpmodels/common/utils.py index 53b497e32d..d43ab87b96 100644 --- a/ocpmodels/common/utils.py +++ b/ocpmodels/common/utils.py @@ -1204,13 +1204,20 @@ def update_config(base_config): are now. Update old configs to fit the new expected structure. """ config = copy.deepcopy(base_config) - config["dataset"]["format"] = config["task"].get("dataset", "lmdb") + + # If config["dataset"]["format"] is missing, get it from the task (legacy location). + # If it is not there either, default to LMDB. 
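+    # Illustrative sketch (values hypothetical): given base_config =
+    # {"task": {"dataset": "oc22_lmdb"}, "dataset": {}}, the lookup below
+    # resolves to "oc22_lmdb"; if neither key is set, it falls back to "lmdb".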
+ config["dataset"]["format"] = config["dataset"].get("format", config["task"].get("dataset", "lmdb")) + ### Read task based off config structure, similar to OCPCalculator. if config["task"]["dataset"] in [ "trajectory_lmdb", "lmdb", "trajectory_lmdb_v2", "oc22_lmdb", + "ase_read", + "ase_read_multi", + "ase_db", ]: task = "s2ef" elif config["task"]["dataset"] == "single_point_lmdb": From 1c5ca262e6607dd1d70905183daefa3a68521fc0 Mon Sep 17 00:00:00 2001 From: Ethan Sunshine <93541000+emsunshine@users.noreply.github.com> Date: Tue, 27 Feb 2024 10:12:12 -0500 Subject: [PATCH 38/64] Update utils.py --- ocpmodels/common/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ocpmodels/common/utils.py b/ocpmodels/common/utils.py index d43ab87b96..c75ecb5fb2 100644 --- a/ocpmodels/common/utils.py +++ b/ocpmodels/common/utils.py @@ -1207,7 +1207,8 @@ def update_config(base_config): # If config["dataset"]["format"] is missing, get it from the task (legacy location). # If it is not there either, default to LMDB. - config["dataset"]["format"] = config["dataset"].get("format", config["task"].get("dataset", "lmdb")) + config["dataset"]["format"] = config["dataset"].get( + "format", config["task"].get("dataset", "lmdb")) ### Read task based off config structure, similar to OCPCalculator. if config["task"]["dataset"] in [ From 90a6f6eff48db8861b56db3c67d3301b525d64a4 Mon Sep 17 00:00:00 2001 From: Ethan Sunshine <93541000+emsunshine@users.noreply.github.com> Date: Tue, 27 Feb 2024 10:19:44 -0500 Subject: [PATCH 39/64] Update utils.py --- ocpmodels/common/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/ocpmodels/common/utils.py b/ocpmodels/common/utils.py index c75ecb5fb2..13b588ed14 100644 --- a/ocpmodels/common/utils.py +++ b/ocpmodels/common/utils.py @@ -1208,8 +1208,9 @@ def update_config(base_config): # If config["dataset"]["format"] is missing, get it from the task (legacy location). # If it is not there either, default to LMDB. config["dataset"]["format"] = config["dataset"].get( - "format", config["task"].get("dataset", "lmdb")) - + "format", config["task"].get("dataset", "lmdb") + ) + ### Read task based off config structure, similar to OCPCalculator. if config["task"]["dataset"] in [ "trajectory_lmdb", From 885deba90f8f5cd5260986be6222de2618d65ce4 Mon Sep 17 00:00:00 2001 From: Ethan Sunshine <93541000+emsunshine@users.noreply.github.com> Date: Tue, 27 Feb 2024 12:36:53 -0500 Subject: [PATCH 40/64] Update ocp_trainer.py More helpful warning for debug mode --- ocpmodels/trainers/ocp_trainer.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index 1ef82baf52..07a3c19397 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -527,7 +527,16 @@ def predict( for key in predictions: predictions[key] = np.array(predictions[key]) - self.save_results(predictions, results_file) + if self.is_debug: + try: + self.save_results(predictions, results_file) + except FileNotFoundError: + logging.warning( + "Predictions npz file not found. " + \ + "This file was not written since the trainer was running in debug mode." 
+ ) + else: + self.save_results(predictions, results_file) if self.ema: self.ema.restore() From 0903f032f3655afaca752873fa86b0779c1a27e3 Mon Sep 17 00:00:00 2001 From: Ethan Sunshine <93541000+emsunshine@users.noreply.github.com> Date: Tue, 27 Feb 2024 12:38:36 -0500 Subject: [PATCH 41/64] Update ocp_trainer.py --- ocpmodels/trainers/ocp_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index 07a3c19397..c210f048bc 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -533,7 +533,7 @@ def predict( except FileNotFoundError: logging.warning( "Predictions npz file not found. " + \ - "This file was not written since the trainer was running in debug mode." + "This file was not written since the trainer is running in debug mode." ) else: self.save_results(predictions, results_file) From 17ca6a9d2c2d7429675963271dfc00128aa81b69 Mon Sep 17 00:00:00 2001 From: Ethan Sunshine <93541000+emsunshine@users.noreply.github.com> Date: Tue, 27 Feb 2024 12:41:33 -0500 Subject: [PATCH 42/64] Update ocp_trainer.py --- ocpmodels/trainers/ocp_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index c210f048bc..519fa7fc94 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -532,8 +532,8 @@ def predict( self.save_results(predictions, results_file) except FileNotFoundError: logging.warning( - "Predictions npz file not found. " + \ - "This file was not written since the trainer is running in debug mode." + "Predictions npz file not found. " + + "This file was not written since the trainer is running in debug mode." ) else: self.save_results(predictions, results_file) From c4ca1b05a8268721a6b37d6edd138116ba39be98 Mon Sep 17 00:00:00 2001 From: Ethan Sunshine <93541000+emsunshine@users.noreply.github.com> Date: Tue, 27 Feb 2024 12:55:24 -0500 Subject: [PATCH 43/64] Update TRAIN.md --- TRAIN.md | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/TRAIN.md b/TRAIN.md index 37b9934f67..03719e23db 100644 --- a/TRAIN.md +++ b/TRAIN.md @@ -204,11 +204,11 @@ To train and validate an OC20 IS2RE/S2EF model on total energies instead of adso ```yaml task: - dataset: oc22_lmdb prediction_dtype: float32 ... dataset: + format: oc22_lmdb train: src: data/oc20/s2ef/train normalize_labels: False @@ -308,8 +308,8 @@ For the IS2RE-Total task, the model takes the initial structure as input and pre ```yaml trainer: energy # Use the EnergyTrainer -task: - dataset: oc22_lmdb # Use the OC22LmdbDataset +dataset: + format: oc22_lmdb # Use the OC22LmdbDataset ... ``` You can find examples configuration files in [`configs/oc22/is2re`](https://github.com/Open-Catalyst-Project/ocp/tree/main/configs/oc22/is2re). @@ -321,8 +321,8 @@ The S2EF-Total task takes a structure and predicts the total DFT energy and per- ```yaml trainer: forces # Use the ForcesTrainer -task: - dataset: oc22_lmdb # Use the OC22LmdbDataset +dataset: + format: oc22_lmdb # Use the OC22LmdbDataset ... ``` You can find examples configuration files in [`configs/oc22/s2ef`](https://github.com/Open-Catalyst-Project/ocp/tree/main/configs/oc22/s2ef). 
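For reference, the layout these documentation hunks converge on is sketched below; the `src` path is illustrative rather than taken from the docs:

```yaml
dataset:
  format: oc22_lmdb  # selects the dataset class (formerly task.dataset)
  train:
    src: data/oc22/s2ef/train  # illustrative path
```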
@@ -332,8 +332,8 @@ You can find examples configuration files in [`configs/oc22/s2ef`](https://githu Training on OC20 total energies whether independently or jointly with OC22 requires a path to the `oc20_ref` (download link provided below) to be specified in the configuration file. These are necessary to convert OC20 adsorption energies into their corresponding total energies. The following changes in the configuration file capture these changes: ```yaml -task: - dataset: oc22_lmdb +dataset: + format: oc22_lmdb ... dataset: From ce52b2fe6855e656f9bde0c59830c1baad388e7e Mon Sep 17 00:00:00 2001 From: lbluque Date: Tue, 27 Feb 2024 11:50:09 -0800 Subject: [PATCH 44/64] fix concatenating predictions --- ocpmodels/trainers/ocp_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index 3dbc67b398..0916a80d04 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -525,7 +525,9 @@ def predict( predictions["ids"].extend(systemids) for key in predictions: - predictions[key] = np.array(predictions[key]) + # allow for lists of 'zero dim' arrays + axis = 0 if isinstance(predictions[key][0], np.ndarray) else None + predictions[key] = np.concatenate(predictions[key], axis=axis) self.save_results(predictions, results_file) From 574190740b36d8ad0a83a9e527e4712905460501 Mon Sep 17 00:00:00 2001 From: lbluque Date: Tue, 27 Feb 2024 13:39:41 -0800 Subject: [PATCH 45/64] check if keys exist in atoms.info --- ocpmodels/preprocessing/atoms_to_graphs.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/ocpmodels/preprocessing/atoms_to_graphs.py b/ocpmodels/preprocessing/atoms_to_graphs.py index 76b2d8dc00..546fe6991b 100644 --- a/ocpmodels/preprocessing/atoms_to_graphs.py +++ b/ocpmodels/preprocessing/atoms_to_graphs.py @@ -216,14 +216,13 @@ def convert(self, atoms: ase.Atoms, sid=None): if self.r_pbc: data.pbc = torch.tensor(atoms.pbc) if self.r_data_keys is not None: - for ( - data_key - ) in self.r_data_keys: # if key is not present let it raise error - data[data_key] = ( - atoms.info[data_key] - if isinstance(atoms.info[data_key], (int, float)) - else torch.Tensor(atoms.info[data_key]) - ) + for data_key in self.r_data_keys: + if data_key in atoms.info: + data[data_key] = ( + atoms.info[data_key] + if isinstance(atoms.info[data_key], (int, float)) + else torch.Tensor(atoms.info[data_key]) + ) return data From 068b053e3b32b1e1a09faf5fcc449193287a544c Mon Sep 17 00:00:00 2001 From: Ethan Sunshine <93541000+emsunshine@users.noreply.github.com> Date: Wed, 28 Feb 2024 10:13:12 -0500 Subject: [PATCH 46/64] Update test_ase_datasets.py --- tests/datasets/test_ase_datasets.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/datasets/test_ase_datasets.py b/tests/datasets/test_ase_datasets.py index 26f7e75d7c..76e08a1693 100644 --- a/tests/datasets/test_ase_datasets.py +++ b/tests/datasets/test_ase_datasets.py @@ -117,7 +117,7 @@ def test_ase_dataset(ase_dataset): assert len(dataset) == mult * len(structures) for data in dataset: assert hasattr(data, "y") - assert data.force.shape == (data.natoms, 3) + assert data.forces.shape == (data.natoms, 3) assert data.stress.shape == (3, 3) assert data.tensor_property.shape == (6, 6) assert isinstance(data.extensive_property, int) @@ -268,4 +268,4 @@ def test_ase_multiread_dataset(tmp_path) -> None: ) assert hasattr(dataset[0], "y_relaxed") - assert dataset[0].y_relaxed != dataset[0].energy 
\ No newline at end of file + assert dataset[0].y_relaxed != dataset[0].energy From 987ba9fae641edbf4534615e50341f6ebc0f6110 Mon Sep 17 00:00:00 2001 From: lbluque Date: Mon, 4 Mar 2024 18:26:17 -0700 Subject: [PATCH 47/64] use list() to cast all batch.sid/fid --- ocpmodels/trainers/ocp_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index 67b714ef88..3bff4620d1 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -591,7 +591,7 @@ def run_relaxations(self, split="val"): if check_traj_files( batch, self.config["task"]["relax_opt"].get("traj_dir", None) ): - logging.info(f"Skipping batch: {batch[0].sid.tolist()}") + logging.info(f"Skipping batch: {list(batch[0].sid)}") continue relaxed_batch = ml_relax( @@ -606,7 +606,7 @@ def run_relaxations(self, split="val"): ) if self.config["task"].get("write_pos", False): - systemids = [str(i) for i in relaxed_batch.sid.tolist()] + systemids = [str(i) for i in list(relaxed_batch.sid)] natoms = relaxed_batch.natoms.tolist() positions = torch.split(relaxed_batch.pos, natoms) batch_relaxed_positions = [pos.tolist() for pos in positions] From 7995b5efd6bc92a7ccd9103e97336d20c109da7e Mon Sep 17 00:00:00 2001 From: lbluque Date: Tue, 5 Mar 2024 16:58:51 -0800 Subject: [PATCH 48/64] correctly stack predictions --- ocpmodels/trainers/ocp_trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index 3bff4620d1..6da91b6923 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -526,8 +526,7 @@ def predict( for key in predictions: # allow for lists of 'zero dim' arrays - axis = 0 if isinstance(predictions[key][0], np.ndarray) else None - predictions[key] = np.concatenate(predictions[key], axis=axis) + predictions[key] = np.vstack(predictions[key]).squeeze() if self.is_debug: try: From f0982bbc41dba8a472225a1d1aec225ec0906f18 Mon Sep 17 00:00:00 2001 From: lbluque Date: Mon, 18 Mar 2024 17:09:53 -0700 Subject: [PATCH 49/64] raise error on empty datasets --- ocpmodels/datasets/ase_datasets.py | 11 +++++++++-- tests/datasets/test_ase_datasets.py | 9 +++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index d1b71c9a3c..b20d925c54 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -97,6 +97,12 @@ def __init__( self.ids = self._load_dataset_get_ids(config) + if len(self.ids) == 0: + raise ValueError( + rf"No valid ase data found!" 
+ f"Double check that the src path and/or glob search pattern gives ASE compatible data: {config['src']}" + ) + def __len__(self) -> int: return len(self.ids) @@ -240,7 +246,7 @@ def _load_dataset_get_ids(self, config) -> list[Path]: self.relaxed_ase_read_args = copy.deepcopy(self.ase_read_args) self.relaxed_ase_read_args["index"] = "-1" - return list(self.path.glob(f'{config["pattern"]}')) + return list(self.path.glob(f'{config.get("pattern", "*")}')) def get_atoms_object(self, identifier): try: @@ -339,7 +345,7 @@ def _load_dataset_get_ids(self, config): self.path = Path(config["src"]) if self.path.is_file(): raise Exception("The specified src is not a directory") - filenames = list(self.path.glob(f'{config["pattern"]}')) + filenames = list(self.path.glob(f'{config.get("pattern", "*")}')) ids = [] @@ -397,6 +403,7 @@ class AseDBDataset(AseAtomsDataset): - the path an ASE DB, - the connection address of an ASE DB, - a folder with multiple ASE DBs, + - a list of folders with ASE DBs - a glob string to use to find ASE DBs, or - a list of ASE db paths/addresses. If a folder, every file will be attempted as an ASE DB, and warnings diff --git a/tests/datasets/test_ase_datasets.py b/tests/datasets/test_ase_datasets.py index 76e08a1693..34dd474118 100644 --- a/tests/datasets/test_ase_datasets.py +++ b/tests/datasets/test_ase_datasets.py @@ -269,3 +269,12 @@ def test_ase_multiread_dataset(tmp_path) -> None: assert hasattr(dataset[0], "y_relaxed") assert dataset[0].y_relaxed != dataset[0].energy + + +def test_empty_dataset(tmp_path): + # raises error on empty dataset + with pytest.raises(ValueError): + AseReadMultiStructureDataset(config={"src": str(tmp_path)}) + + with pytest.raises(ValueError): + AseDBDataset(config={"src": str(tmp_path)}) From 56531d7e8d8d55b735ffc013db9d599a10bc2c02 Mon Sep 17 00:00:00 2001 From: lbluque Date: Mon, 18 Mar 2024 17:12:13 -0700 Subject: [PATCH 50/64] raise ValueError instead of exception --- ocpmodels/datasets/ase_datasets.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index b20d925c54..ee8b47fa6a 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -240,7 +240,9 @@ def _load_dataset_get_ids(self, config) -> list[Path]: self.path = Path(config["src"]) if self.path.is_file(): - raise Exception("The specified src is not a directory") + ValueError( + f"The specified src is not a directory: {self.config['src']}" + ) if self.config.get("include_relaxed_energy", False): self.relaxed_ase_read_args = copy.deepcopy(self.ase_read_args) @@ -344,7 +346,10 @@ def _load_dataset_get_ids(self, config): self.path = Path(config["src"]) if self.path.is_file(): - raise Exception("The specified src is not a directory") + raise ValueError( + f"The specified src is not a directory: {self.config['src']}" + ) + filenames = list(self.path.glob(f'{config.get("pattern", "*")}')) ids = [] From b9e758d351a6c1393ae73cece30bc1c5138f3f1d Mon Sep 17 00:00:00 2001 From: lbluque Date: Mon, 18 Mar 2024 17:36:46 -0700 Subject: [PATCH 51/64] code cleanup --- ocpmodels/datasets/ase_datasets.py | 58 ++++++++++++++++++------------ 1 file changed, 36 insertions(+), 22 deletions(-) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index ee8b47fa6a..aa0c049b6d 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -1,3 +1,10 @@ +""" +Copyright (c) Meta, Inc. and its affiliates. 
+ +This source code is licensed under the MIT license found in the +LICENSE file in the root directory of this source tree. +""" + from __future__ import annotations import bisect @@ -142,7 +149,7 @@ def __getitem__(self, idx): return data_object @abstractmethod - def get_atoms_object(self, identifier): + def get_atoms_object(self, idx: str | int): # This function should return an ASE atoms object. raise NotImplementedError( "Returns an ASE atoms object. Derived classes should implement this function." @@ -155,6 +162,12 @@ def _load_dataset_get_ids(self, config): "Every ASE dataset needs to declare a function to load the dataset and return a list of ids." ) + @abstractmethod + def get_relaxed_energy(self, identifier): + raise NotImplementedError( + "IS2RE-Direct is not implemented with this dataset." + ) + def close_db(self) -> None: # This method is sometimes called by a trainer pass @@ -240,7 +253,7 @@ def _load_dataset_get_ids(self, config) -> list[Path]: self.path = Path(config["src"]) if self.path.is_file(): - ValueError( + raise ValueError( f"The specified src is not a directory: {self.config['src']}" ) @@ -250,11 +263,11 @@ def _load_dataset_get_ids(self, config) -> list[Path]: return list(self.path.glob(f'{config.get("pattern", "*")}')) - def get_atoms_object(self, identifier): + def get_atoms_object(self, idx): try: - atoms = ase.io.read(identifier, **self.ase_read_args) + atoms = ase.io.read(idx, **self.ase_read_args) except Exception as err: - warnings.warn(f"{err} occured for: {identifier}") + warnings.warn(f"{err} occured for: {idx}", stacklevel=2) raise err return atoms @@ -333,12 +346,12 @@ def _load_dataset_get_ids(self, config): self.ase_read_args["index"] = ":" if config.get("index_file", None) is not None: - f = open(config["index_file"], "r") - index = f.readlines() + with open(config["index_file"], "r") as f: + index = f.readlines() ids = [] for line in index: - filename = line.split(" ")[0] + filename = line.split(" ", maxsplit=1)[0] for i in range(int(line.split(" ")[1])): ids.append(f"{filename} {i}") @@ -360,30 +373,31 @@ def _load_dataset_get_ids(self, config): try: structures = ase.io.read(filename, **self.ase_read_args) except Exception as err: - warnings.warn(f"{err} occured for: {filename}") + warnings.warn(f"{err} occured for: {filename}", stacklevel=2) else: - for i, structure in enumerate(structures): + for i, _ in enumerate(structures): ids.append(f"{filename} {i}") return ids - def get_atoms_object(self, identifier): + def get_atoms_object(self, idx): try: + identifiers = idx.split(" ") atoms = ase.io.read( - "".join(identifier.split(" ")[:-1]), **self.ase_read_args - )[int(identifier.split(" ")[-1])] + "".join(identifiers[:-1]), **self.ase_read_args + )[int(identifiers[-1])] except Exception as err: - warnings.warn(f"{err} occured for: {identifier}") + warnings.warn(f"{err} occured for: {idx}", stacklevel=2) raise err if "sid" not in atoms.info: - atoms.info["sid"] = "".join(identifier.split(" ")[:-1]) + atoms.info["sid"] = "".join(identifiers[:-1]) if "fid" not in atoms.info: - atoms.info["fid"] = int(identifier.split(" ")[-1]) + atoms.info["fid"] = int(identifiers[-1]) return atoms - def get_metadata(self): + def get_metadata(self, num_samples: int = 100): return {} def get_relaxed_energy(self, identifier): @@ -538,11 +552,11 @@ def connect_db( db_type = connect_args.get("type", "extract_from_name") if db_type in ("lmdb", "aselmdb") or ( db_type == "extract_from_name" - and str(address).split(".")[-1] in ("lmdb", "aselmdb") + and 
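The index-file branch cleaned up in patch 51 expects one `<filename> <n_structures>` entry per line and expands it into per-frame ids. A condensed, runnable sketch of that expansion (helper name is mine):

```python
def expand_index_lines(lines: list[str]) -> list[str]:
    ids = []
    for line in lines:
        # maxsplit=1 splits on the first space only, matching the patch
        filename, count = line.split(" ", maxsplit=1)
        ids.extend(f"{filename} {i}" for i in range(int(count)))
    return ids

print(expand_index_lines(["a.traj 2", "b.traj 1"]))
# ['a.traj 0', 'a.traj 1', 'b.traj 0']
```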
str(address).rsplit(".", maxsplit=1)[-1] in ("lmdb", "aselmdb") ): return LMDBDatabase(address, readonly=True, **connect_args) - else: - return ase.db.connect(address, **connect_args) + + return ase.db.connect(address, **connect_args) def close_db(self) -> None: for db in self.dbs: @@ -555,8 +569,8 @@ def get_metadata(self, num_samples: int = 100) -> dict: ) if self.dbs[0].metadata == {}: return super().get_metadata(num_samples) - else: - return copy.deepcopy(self.dbs[0].metadata) + + return copy.deepcopy(self.dbs[0].metadata) def get_relaxed_energy(self, identifier): raise NotImplementedError( From f6bb5d57fc71eea183536493a3261bc86bdf5ec0 Mon Sep 17 00:00:00 2001 From: lbluque Date: Mon, 18 Mar 2024 17:42:52 -0700 Subject: [PATCH 52/64] rename get_atoms object -> get_atoms for brevity --- ocpmodels/datasets/ase_datasets.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index aa0c049b6d..cdecd82dff 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -119,7 +119,7 @@ def __getitem__(self, idx): return [self[i] for i in range(*idx.indices(len(self)))] # Get atoms object via derived class method - atoms = self.get_atoms_object(self.ids[idx]) + atoms = self.get_atoms(self.ids[idx]) # Transform atoms object if self.atoms_transform is not None: @@ -149,7 +149,7 @@ def __getitem__(self, idx): return data_object @abstractmethod - def get_atoms_object(self, idx: str | int): + def get_atoms(self, idx: str | int) -> ase.Atoms: # This function should return an ASE atoms object. raise NotImplementedError( "Returns an ASE atoms object. Derived classes should implement this function." @@ -172,13 +172,13 @@ def close_db(self) -> None: # This method is sometimes called by a trainer pass - def get_metadata(self, num_samples: int = 100): + def get_metadata(self, num_samples: int = 100) -> dict: metadata = {} if num_samples < len(self): metadata["targets"] = guess_property_metadata( [ - self.get_atoms_object(self.ids[idx]) + self.get_atoms(self.ids[idx]) for idx in np.random.choice( len(self), size=(num_samples,), replace=False ) @@ -186,10 +186,7 @@ def get_metadata(self, num_samples: int = 100): ) else: metadata["targets"] = guess_property_metadata( - [ - self.get_atoms_object(self.ids[idx]) - for idx in range(len(self)) - ] + [self.get_atoms(self.ids[idx]) for idx in range(len(self))] ) return metadata @@ -263,7 +260,7 @@ def _load_dataset_get_ids(self, config) -> list[Path]: return list(self.path.glob(f'{config.get("pattern", "*")}')) - def get_atoms_object(self, idx): + def get_atoms(self, idx: str | int) -> ase.Atoms: try: atoms = ase.io.read(idx, **self.ase_read_args) except Exception as err: @@ -272,7 +269,7 @@ def get_atoms_object(self, idx): return atoms - def get_relaxed_energy(self, identifier): + def get_relaxed_energy(self, identifier) -> float: relaxed_atoms = ase.io.read(identifier, **self.relaxed_ase_read_args) return relaxed_atoms.get_potential_energy(apply_constraint=False) @@ -340,7 +337,7 @@ class AseReadMultiStructureDataset(AseAtomsDataset): transform (callable, optional): Additional preprocessing function for the Data object """ - def _load_dataset_get_ids(self, config): + def _load_dataset_get_ids(self, config) -> list[str]: self.ase_read_args = config.get("ase_read_args", {}) if not hasattr(self.ase_read_args, "index"): self.ase_read_args["index"] = ":" @@ -380,7 +377,7 @@ def _load_dataset_get_ids(self, config): return ids - 
def get_atoms_object(self, idx): + def get_atoms(self, idx: str) -> ase.Atoms: try: identifiers = idx.split(" ") atoms = ase.io.read( @@ -397,10 +394,10 @@ def get_atoms_object(self, idx): return atoms - def get_metadata(self, num_samples: int = 100): + def get_metadata(self, num_samples: int = 100) -> dict: return {} - def get_relaxed_energy(self, identifier): + def get_relaxed_energy(self, identifier) -> float: relaxed_atoms = ase.io.read( "".join(identifier.split(" ")[:-1]), **self.ase_read_args )[-1] @@ -517,7 +514,7 @@ def _load_dataset_get_ids(self, config: dict) -> list[int]: return list(range(sum(idlens))) - def get_atoms_object(self, idx: int) -> ase.Atoms: + def get_atoms(self, idx: int) -> ase.Atoms: """Get atoms object corresponding to datapoint idx. Useful to read other properties not in data object. Args: idx (int): index in dataset From 2f6ac22b2232bfc785751ba1ef464a966d38f7fa Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 22 Mar 2024 15:04:44 -0700 Subject: [PATCH 53/64] revert to raise keyerror when data_keys are missing --- ocpmodels/preprocessing/atoms_to_graphs.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/ocpmodels/preprocessing/atoms_to_graphs.py b/ocpmodels/preprocessing/atoms_to_graphs.py index 0153b3e5dd..f88da439bd 100644 --- a/ocpmodels/preprocessing/atoms_to_graphs.py +++ b/ocpmodels/preprocessing/atoms_to_graphs.py @@ -205,7 +205,7 @@ def convert(self, atoms: ase.Atoms, sid=None): if self.r_distances and self.r_edges: data.distances = edge_distances if self.r_fixed: - fixed_idx = torch.zeros(natoms) + fixed_idx = torch.zeros(natoms, dtype=torch.int) if hasattr(atoms, "constraints"): from ase.constraints import FixAtoms @@ -217,12 +217,11 @@ def convert(self, atoms: ase.Atoms, sid=None): data.pbc = torch.tensor(atoms.pbc) if self.r_data_keys is not None: for data_key in self.r_data_keys: - if data_key in atoms.info: - data[data_key] = ( - atoms.info[data_key] - if isinstance(atoms.info[data_key], (int, float)) - else torch.Tensor(atoms.info[data_key]) - ) + data[data_key] = ( + atoms.info[data_key] + if isinstance(atoms.info[data_key], (int, float)) + else torch.Tensor(atoms.info[data_key]) + ) return data From b426842f59ce9328e84404c380b148636700d170 Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 22 Mar 2024 16:11:55 -0700 Subject: [PATCH 54/64] cast tensors to list using tolist and vstack relaxation pos --- ocpmodels/common/utils.py | 7 ++++++- ocpmodels/trainers/ocp_trainer.py | 25 +++++++++++++++++++------ 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/ocpmodels/common/utils.py b/ocpmodels/common/utils.py index 13b588ed14..a07c49b7b6 100644 --- a/ocpmodels/common/utils.py +++ b/ocpmodels/common/utils.py @@ -969,7 +969,12 @@ def check_traj_files(batch, traj_dir) -> bool: if traj_dir is None: return False traj_dir = Path(traj_dir) - traj_files = [traj_dir / f"{id}.traj" for id in batch[0].sid.tolist()] + sid_list = ( + batch.sid.tolist() + if isinstance(batch.sid, torch.Tensor) + else list(batch.sid) + ) + traj_files = [traj_dir / f"{sid}.traj" for sid in sid_list] return all(fl.exists() for fl in traj_files) diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index c58489230b..8a810204c8 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -515,10 +515,18 @@ def predict( return predictions ### Get unique system identifiers - sids = list(batch.sid) + sids = ( + batch.sid.tolist() + if isinstance(batch.sid, torch.Tensor) + else 
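`get_atoms` receives string ids of the form `<filename> <frame>` and re-derives both parts by splitting on spaces; note the patch rejoins the leading tokens with `"".join`, which effectively assumes the filename itself contains no spaces. An alternative split that tolerates them (an illustration, not what the patch ships):

```python
def split_identifier(idx: str) -> tuple[str, int]:
    # rsplit from the right: the frame index is always the last token.
    filename, frame = idx.rsplit(" ", maxsplit=1)
    return filename, int(frame)

print(split_identifier("runs/md run.xyz 4"))  # ('runs/md run.xyz', 4)
```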
list(batch.sid) + ) ## Support naming structure for OC20 S2EF if "fid" in batch: - fids = list(batch.fid) + fids = ( + batch.fid.tolist() + if isinstance(batch.fid, torch.Tensor) + else list(batch.fid) + ) systemids = [f"{sid}_{fid}" for sid, fid in zip(sids, fids)] else: systemids = [f"{sid}" for sid in sids] @@ -606,7 +614,12 @@ def run_relaxations(self, split="val"): ) if self.config["task"].get("write_pos", False): - systemids = [str(i) for i in list(relaxed_batch.sid)] + sid_list = ( + relaxed_batch.sid.tolist() + if isinstance(relaxed_batch.sid, torch.Tensor) + else list(relaxed_batch.sid) + ) + systemids = [str(sid) for sid in sid_list] natoms = relaxed_batch.natoms.tolist() positions = torch.split(relaxed_batch.pos, natoms) batch_relaxed_positions = [pos.tolist() for pos in positions] @@ -689,9 +702,8 @@ def run_relaxations(self, split="val"): # might be repeated to make no. of samples even across GPUs. _, idx = np.unique(gather_results["ids"], return_index=True) gather_results["ids"] = np.array(gather_results["ids"])[idx] - gather_results["pos"] = np.concatenate( - np.array(gather_results["pos"])[idx] - ) + + gather_results["pos"] = np.vstack(gather_results["pos"])[idx] gather_results["chunk_idx"] = np.cumsum( np.array(gather_results["chunk_idx"])[idx] )[ @@ -741,4 +753,5 @@ def run_relaxations(self, split="val"): if self.ema: self.ema.restore() + breakpoint() registry.unregister("set_deterministic_scatter") From 0709e46b6709ed4390a85661e88751dfac564cd8 Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 22 Mar 2024 16:55:27 -0700 Subject: [PATCH 55/64] remove r_energy, r_forces, r_stress and r_data_keys from test_dataset w use_train_settings --- ocpmodels/trainers/base_trainer.py | 13 +++++++++++++ ocpmodels/trainers/ocp_trainer.py | 1 - 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/ocpmodels/trainers/base_trainer.py b/ocpmodels/trainers/base_trainer.py index 0f981c34c0..1d3e1b9d1d 100644 --- a/ocpmodels/trainers/base_trainer.py +++ b/ocpmodels/trainers/base_trainer.py @@ -317,6 +317,19 @@ def load_datasets(self) -> None: if self.config["test_dataset"].get("use_train_settings", True): test_config = self.config["dataset"].copy() test_config.update(self.config["test_dataset"]) + # if a2g_args are used remove keys for labels + if "a2g_args" in test_config["dataset"]: + test_config["dataset"]["a2g_args"] = { + k: v + for k, v in test_config["dataset"]["a2g_args"] + if k + not in ( + "r_energy", + "r_forces", + "r_stress", + "r_data_keys", + ) + } else: test_config = self.config["test_dataset"] diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index 8a810204c8..1737eda350 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -753,5 +753,4 @@ def run_relaxations(self, split="val"): if self.ema: self.ema.restore() - breakpoint() registry.unregister("set_deterministic_scatter") From 310468d3c4214087f444ab0f1bce9bd4035e1936 Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 22 Mar 2024 17:02:21 -0700 Subject: [PATCH 56/64] fix test_dataset key --- ocpmodels/trainers/base_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ocpmodels/trainers/base_trainer.py b/ocpmodels/trainers/base_trainer.py index 1d3e1b9d1d..bfde165e95 100644 --- a/ocpmodels/trainers/base_trainer.py +++ b/ocpmodels/trainers/base_trainer.py @@ -318,10 +318,10 @@ def load_datasets(self) -> None: test_config = self.config["dataset"].copy() test_config.update(self.config["test_dataset"]) # if a2g_args are used 
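Patch 54 repeats the same tensor-or-list normalization in three places (`check_traj_files`, `predict`, `run_relaxations`). Factored out once, the pattern looks like this (helper name is mine):

```python
import torch

def ids_to_list(ids):
    # LMDB-backed batches carry sid/fid as a torch.Tensor; ASE-backed
    # datasets supply plain Python lists of str or int ids.
    return ids.tolist() if isinstance(ids, torch.Tensor) else list(ids)

print(ids_to_list(torch.tensor([7, 8])))    # [7, 8]
print(ids_to_list(["random0", "random1"]))  # ['random0', 'random1']
```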
remove keys for labels - if "a2g_args" in test_config["dataset"]: + if "a2g_args" in test_config["test_dataset"]: test_config["dataset"]["a2g_args"] = { k: v - for k, v in test_config["dataset"]["a2g_args"] + for k, v in test_config["test_dataset"]["a2g_args"] if k not in ( "r_energy", From 2422bb9b8fe52df2296e865f5a1f0f2a4e710df8 Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 22 Mar 2024 17:09:48 -0700 Subject: [PATCH 57/64] fix test_dataset key! --- ocpmodels/trainers/base_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ocpmodels/trainers/base_trainer.py b/ocpmodels/trainers/base_trainer.py index bfde165e95..1ddcc3e540 100644 --- a/ocpmodels/trainers/base_trainer.py +++ b/ocpmodels/trainers/base_trainer.py @@ -318,7 +318,7 @@ def load_datasets(self) -> None: test_config = self.config["dataset"].copy() test_config.update(self.config["test_dataset"]) # if a2g_args are used remove keys for labels - if "a2g_args" in test_config["test_dataset"]: + if "a2g_args" in test_config: test_config["dataset"]["a2g_args"] = { k: v for k, v in test_config["test_dataset"]["a2g_args"] From 3f2f4bb81827f2a27f9546ec65ead66b1251ab58 Mon Sep 17 00:00:00 2001 From: lbluque Date: Tue, 26 Mar 2024 13:32:43 -0700 Subject: [PATCH 58/64] revert to not setting a2g_args dataset keys --- ocpmodels/trainers/base_trainer.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/ocpmodels/trainers/base_trainer.py b/ocpmodels/trainers/base_trainer.py index 1ddcc3e540..0f981c34c0 100644 --- a/ocpmodels/trainers/base_trainer.py +++ b/ocpmodels/trainers/base_trainer.py @@ -317,19 +317,6 @@ def load_datasets(self) -> None: if self.config["test_dataset"].get("use_train_settings", True): test_config = self.config["dataset"].copy() test_config.update(self.config["test_dataset"]) - # if a2g_args are used remove keys for labels - if "a2g_args" in test_config: - test_config["dataset"]["a2g_args"] = { - k: v - for k, v in test_config["test_dataset"]["a2g_args"] - if k - not in ( - "r_energy", - "r_forces", - "r_stress", - "r_data_keys", - ) - } else: test_config = self.config["test_dataset"] From ac3c1c31ad53be8943ee570409e534fc99384c24 Mon Sep 17 00:00:00 2001 From: Muhammed Shuaibi Date: Tue, 26 Mar 2024 21:00:08 +0000 Subject: [PATCH 59/64] fix debug predict logic --- ocpmodels/trainers/ocp_trainer.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index 1737eda350..e4338f8f9f 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -411,6 +411,11 @@ def predict( results_file: Optional[str] = None, disable_tqdm: bool = False, ): + if self.is_debug and per_image: + raise FileNotFoundError( + "Predictions require debug mode to be turned off." + ) + ensure_fitted(self._unwrapped_model, warn=True) if distutils.is_master() and not disable_tqdm: @@ -537,16 +542,7 @@ def predict( # allow for lists of 'zero dim' arrays predictions[key] = np.vstack(predictions[key]).squeeze() - if self.is_debug: - try: - self.save_results(predictions, results_file) - except FileNotFoundError: - logging.warning( - "Predictions npz file not found. " - + "This file was not written since the trainer is running in debug mode." 
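Patches 55 through 57 try to strip label-producing `a2g_args` from the test split before patch 58 reverts the idea. For reference, the filtering they aim at needs `.items()` to unpack key/value pairs; iterating a dict directly yields keys only, so the comprehensions as written would raise. A corrected sketch of the intent:

```python
LABEL_KEYS = ("r_energy", "r_forces", "r_stress", "r_data_keys")

def strip_label_args(a2g_args: dict) -> dict:
    # .items() is required to unpack (k, v); the reverted patches omit it.
    return {k: v for k, v in a2g_args.items() if k not in LABEL_KEYS}

print(strip_label_args({"r_energy": True, "r_pbc": True}))
# {'r_pbc': True}
```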
- ) - else: - self.save_results(predictions, results_file) + self.save_results(predictions, results_file) if self.ema: self.ema.restore() From a4087a71b8aebabbaee4ef5a5cdcda3f10015cd2 Mon Sep 17 00:00:00 2001 From: Muhammed Shuaibi Date: Thu, 28 Mar 2024 01:38:10 +0000 Subject: [PATCH 60/64] support numpy 1.26 --- ocpmodels/trainers/base_trainer.py | 2 +- ocpmodels/trainers/ocp_trainer.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/ocpmodels/trainers/base_trainer.py b/ocpmodels/trainers/base_trainer.py index 0f981c34c0..cf54422791 100644 --- a/ocpmodels/trainers/base_trainer.py +++ b/ocpmodels/trainers/base_trainer.py @@ -871,7 +871,7 @@ def save_results( else: if f"{k}_chunk_idx" in keys or k == "forces": gather_results[k] = np.concatenate( - np.array(gather_results[k])[idx] + np.array(gather_results[k], dtype=object)[idx] ) else: gather_results[k] = np.array(gather_results[k])[idx] diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index e4338f8f9f..5202876158 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -539,8 +539,7 @@ def predict( predictions["ids"].extend(systemids) for key in predictions: - # allow for lists of 'zero dim' arrays - predictions[key] = np.vstack(predictions[key]).squeeze() + predictions[key] = np.array(predictions[key], dtype=object) self.save_results(predictions, results_file) From 07ea92f34422d5d5c37ba16e488a463f54a47434 Mon Sep 17 00:00:00 2001 From: Muhammed Shuaibi Date: Thu, 28 Mar 2024 15:00:11 +0000 Subject: [PATCH 61/64] fix numpy version --- env.common.yml | 1 + ocpmodels/trainers/base_trainer.py | 2 +- ocpmodels/trainers/ocp_trainer.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/env.common.yml b/env.common.yml index 7a22ee2ce3..c12632b6df 100644 --- a/env.common.yml +++ b/env.common.yml @@ -7,6 +7,7 @@ dependencies: - ase=3.22.1 - black=22.3.0 - e3nn=0.4.4 +- numpy=1.23.5 - matplotlib - numba - orjson diff --git a/ocpmodels/trainers/base_trainer.py b/ocpmodels/trainers/base_trainer.py index cf54422791..0f981c34c0 100644 --- a/ocpmodels/trainers/base_trainer.py +++ b/ocpmodels/trainers/base_trainer.py @@ -871,7 +871,7 @@ def save_results( else: if f"{k}_chunk_idx" in keys or k == "forces": gather_results[k] = np.concatenate( - np.array(gather_results[k], dtype=object)[idx] + np.array(gather_results[k])[idx] ) else: gather_results[k] = np.array(gather_results[k])[idx] diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index 5202876158..e0c0302e75 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -539,7 +539,7 @@ def predict( predictions["ids"].extend(systemids) for key in predictions: - predictions[key] = np.array(predictions[key], dtype=object) + predictions[key] = np.array(predictions[key]) self.save_results(predictions, results_file) From 47f47e261e858cd6eecee34548138a1768a0b313 Mon Sep 17 00:00:00 2001 From: Muhammed Shuaibi Date: Thu, 28 Mar 2024 22:03:52 +0000 Subject: [PATCH 62/64] revert write_pos --- ocpmodels/trainers/ocp_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index e0c0302e75..d4abc8ea80 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -698,7 +698,9 @@ def run_relaxations(self, split="val"): _, idx = np.unique(gather_results["ids"], return_index=True) gather_results["ids"] = 
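The `dtype=object` change in patch 60 and the `numpy=1.23.5` pin in patch 61 are two answers to the same problem: building an array from ragged per-system blocks relies on implicit object arrays, which NumPy 1.24 turned from a deprecation warning into a hard error. A small demonstration with made-up shapes:

```python
import numpy as np

ragged = [np.ones((3, 3)), np.zeros((2, 3))]  # systems with 3 and 2 atoms

# Explicit dtype=object works on every NumPy version; without it,
# np.array(ragged) raises ValueError on NumPy >= 1.24 (it only warned
# on the 1.23.x series the environment pin targets).
stacked = np.array(ragged, dtype=object)
print(stacked.shape)                  # (2,)
print(np.concatenate(stacked).shape)  # (5, 3)
```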
np.array(gather_results["ids"])[idx] - gather_results["pos"] = np.vstack(gather_results["pos"])[idx] + gather_results["pos"] = np.concatenate( + np.array(gather_results["pos"])[idx] + ) gather_results["chunk_idx"] = np.cumsum( np.array(gather_results["chunk_idx"])[idx] )[ From ca9dbafe5ff2b2fe500119aa3ec38c77bf0a2fbb Mon Sep 17 00:00:00 2001 From: lbluque Date: Thu, 28 Mar 2024 15:53:54 -0700 Subject: [PATCH 63/64] no list casting on batch lists --- ocpmodels/common/utils.py | 2 +- ocpmodels/trainers/ocp_trainer.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ocpmodels/common/utils.py b/ocpmodels/common/utils.py index a07c49b7b6..bdc1544d17 100644 --- a/ocpmodels/common/utils.py +++ b/ocpmodels/common/utils.py @@ -972,7 +972,7 @@ def check_traj_files(batch, traj_dir) -> bool: sid_list = ( batch.sid.tolist() if isinstance(batch.sid, torch.Tensor) - else list(batch.sid) + else batch.sid ) traj_files = [traj_dir / f"{sid}.traj" for sid in sid_list] return all(fl.exists() for fl in traj_files) diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index d4abc8ea80..8876299e1e 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -523,14 +523,14 @@ def predict( sids = ( batch.sid.tolist() if isinstance(batch.sid, torch.Tensor) - else list(batch.sid) + else batch.sid ) ## Support naming structure for OC20 S2EF if "fid" in batch: fids = ( batch.fid.tolist() if isinstance(batch.fid, torch.Tensor) - else list(batch.fid) + else batch.fid ) systemids = [f"{sid}_{fid}" for sid, fid in zip(sids, fids)] else: @@ -612,7 +612,7 @@ def run_relaxations(self, split="val"): sid_list = ( relaxed_batch.sid.tolist() if isinstance(relaxed_batch.sid, torch.Tensor) - else list(relaxed_batch.sid) + else relaxed_batch.sid ) systemids = [str(sid) for sid in sid_list] natoms = relaxed_batch.natoms.tolist() From bdbba480eeb0aa70f1be1f2d52087c776750bb8a Mon Sep 17 00:00:00 2001 From: lbluque Date: Fri, 29 Mar 2024 09:21:14 -0700 Subject: [PATCH 64/64] pretty logging --- ocpmodels/trainers/ocp_trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ocpmodels/trainers/ocp_trainer.py b/ocpmodels/trainers/ocp_trainer.py index 8876299e1e..336628f003 100644 --- a/ocpmodels/trainers/ocp_trainer.py +++ b/ocpmodels/trainers/ocp_trainer.py @@ -594,7 +594,9 @@ def run_relaxations(self, split="val"): if check_traj_files( batch, self.config["task"]["relax_opt"].get("traj_dir", None) ): - logging.info(f"Skipping batch: {list(batch.sid)}") + logging.info( + f"Skipping batch: {batch.sid.tolist() if isinstance(batch.sid, torch.Tensor) else batch.sid}" + ) continue relaxed_batch = ml_relax(