From 685aa4c3c6267f4b62ca9aff732cd0aae4b5ca58 Mon Sep 17 00:00:00 2001 From: jd15489 Date: Thu, 23 Nov 2023 16:54:08 +0000 Subject: [PATCH] Extended framework_indices to all parsers --- kinisi/parser.py | 80 ++++++++++++++++++++++++++++++------------------ 1 file changed, 50 insertions(+), 30 deletions(-) diff --git a/kinisi/parser.py b/kinisi/parser.py index 5e53341f..a3ed6d58 100644 --- a/kinisi/parser.py +++ b/kinisi/parser.py @@ -263,14 +263,12 @@ def __init__(self, structure, coords, latt = self.get_structure_coords_latt(atoms, sub_sample_traj, progress) if specie != None: - indices = self.get_indices(structure, specie) + indices = self.get_indices(structure, specie, framework_indices) elif isinstance(specie_indices, (list, tuple)): if isinstance(specie_indices[0], (list, tuple)): - coords, indices = self.get_molecules( - structure, coords, specie_indices, - masses) #Warning: This function changes the structure without changing the MDAnalysis object + coords, indices = self.get_molecules(structure, coords, specie_indices, masses, framework_indices) else: - indices = self.get_framework(structure, specie_indices) + indices = self.get_framework(structure, specie_indices, framework_indices) else: raise TypeError('Unrecognized type for specie or specie_indices') @@ -316,7 +314,8 @@ def get_structure_coords_latt( return structure, coords, latt @staticmethod - def get_indices(structure: "ase.atoms.Atoms", specie: "str") -> Tuple[np.ndarray, np.ndarray]: + def get_indices(structure: "ase.atoms.Atoms", specie: "str", + framework_indices: List[int]) -> Tuple[np.ndarray, np.ndarray]: """ Determine framework and non-framework indices for a :py:mod:`pymatgen` compatible file. @@ -337,11 +336,15 @@ def get_indices(structure: "ase.atoms.Atoms", specie: "str") -> Tuple[np.ndarray drift_indices.append(i) if len(indices) == 0: raise ValueError("There are no species selected to calculate the mean-squared displacement of.") + + if isinstance(framework_indices, (list, tuple)): + drift_indices = framework_indices + return indices, drift_indices @staticmethod - def get_molecules(structure: "MDAnalysis.universe.Universe", coords: List[np.ndarray], indices: List[int], - masses: List[float]) -> Tuple[np.ndarray, np.ndarray, Tuple[np.ndarray, np.ndarray]]: + def get_molecules(structure: "ase.atoms.Atoms", coords: List[np.ndarray], indices: List[int], masses: List[float], + framework_indices) -> Tuple[np.ndarray, np.ndarray, Tuple[np.ndarray, np.ndarray]]: """ Determine framework and non-framework indices for an :py:mod:`MDAnalysis` compatible file when specie_indices are provided and contain multiple molecules. Warning: This function changes the structure without changing the MDAnalysis object @@ -364,9 +367,12 @@ def get_molecules(structure: "MDAnalysis.universe.Universe", coords: List[np.nda n_molecules = indices.shape[0] - for i, site in enumerate(structure): - if i not in indices: - drift_indices.append(i) + if isinstance(framework_indices, (list, tuple)): + drift_indices = framework_indices + else: + for i, site in enumerate(structure): + if i not in indices: + drift_indices.append(i) if masses == None: weights = None @@ -396,7 +402,8 @@ def get_molecules(structure: "MDAnalysis.universe.Universe", coords: List[np.nda return new_coords, (new_indices, new_drift_indices) @staticmethod - def get_framework(structure: "MDAnalysis.universe.Universe", indices: List[int]) -> Tuple[np.ndarray, np.ndarray]: + def get_framework(structure: "MDAnalysis.universe.Universe", indices: List[int], + framework_indices: List[int]) -> Tuple[np.ndarray, np.ndarray]: """ Determine the framework indices from an :py:mod:`MDAnalysis` compatible file when indices are provided @@ -407,7 +414,10 @@ def get_framework(structure: "MDAnalysis.universe.Universe", indices: List[int]) :return: Tuple containing: indices for the atoms in the trajectory used in the calculation of the diffusion and indices of framework atoms. """ - drift_indices = [] + if isinstance(framework_indices, (list, tuple)): + drift_indices = framework_indices + else: + drift_indices = [] for i, site in enumerate(structure): if i not in indices: @@ -468,14 +478,12 @@ def __init__(self, structure, coords, latt = self.get_structure_coords_latt(structures, sub_sample_traj, progress) if specie != None: - indices = self.get_indices(structure, specie) + indices = self.get_indices(structure, specie, framework_indices) elif isinstance(specie_indices, (list, tuple)): if isinstance(specie_indices[0], (list, tuple)): - coords, indices = self.get_molecules( - structure, coords, specie_indices, - masses) #Warning: This function changes the structure without changing the MDAnalysis object + coords, indices = self.get_molecules(structure, coords, specie_indices, masses, framework_indices) else: - indices = self.get_framework(structure, specie_indices) + indices = self.get_framework(structure, specie_indices, framework_indices) else: raise TypeError('Unrecognized type for specie or specie_indices') @@ -531,12 +539,12 @@ def get_structure_coords_latt( return structure, coords, latt @staticmethod - def get_indices( - structure: "pymatgen.core.structure.Structure", - specie: Union["pymatgen.core.periodic_table.Element", "pymatgen.core.periodic_table.Specie", - "pymatgen.core.periodic_table.Species", List["pymatgen.core.periodic_table.Element"], - List["pymatgen.core.periodic_table.Specie"], List["pymatgen.core.periodic_table.Species"]] - ) -> Tuple[np.ndarray, np.ndarray]: + def get_indices(structure: "pymatgen.core.structure.Structure", + specie: Union["pymatgen.core.periodic_table.Element", "pymatgen.core.periodic_table.Specie", + "pymatgen.core.periodic_table.Species", List["pymatgen.core.periodic_table.Element"], + List["pymatgen.core.periodic_table.Specie"], + List["pymatgen.core.periodic_table.Species"]], + framework_indices: List[int]) -> Tuple[np.ndarray, np.ndarray]: """ Determine framework and non-framework indices for a :py:mod:`pymatgen` compatible file. @@ -557,11 +565,16 @@ def get_indices( drift_indices.append(i) if len(indices) == 0: raise ValueError("There are no species selected to calculate the mean-squared displacement of.") + + if isinstance(framework_indices, (list, tuple)): + drift_indices = framework_indices + return indices, drift_indices @staticmethod def get_molecules(structure: "MDAnalysis.universe.Universe", coords: List[np.ndarray], indices: List[int], - masses: List[float]) -> Tuple[np.ndarray, np.ndarray, Tuple[np.ndarray, np.ndarray]]: + masses: List[float], + framework_indices: List[int]) -> Tuple[np.ndarray, np.ndarray, Tuple[np.ndarray, np.ndarray]]: """ Determine framework and non-framework indices for an :py:mod:`MDAnalysis` compatible file when specie_indices are provided and contain multiple molecules. Warning: This function changes the structure without changing the MDAnalysis object @@ -584,9 +597,12 @@ def get_molecules(structure: "MDAnalysis.universe.Universe", coords: List[np.nda n_molecules = indices.shape[0] - for i, site in enumerate(structure): - if i not in indices: - drift_indices.append(i) + if isinstance(framework_indices, (list, tuple)): + drift_indices = framework_indices + else: + for i, site in enumerate(structure): + if i not in indices: + drift_indices.append(i) if masses == None: weights = None @@ -616,7 +632,8 @@ def get_molecules(structure: "MDAnalysis.universe.Universe", coords: List[np.nda return new_coords, (new_indices, new_drift_indices) @staticmethod - def get_framework(structure: "MDAnalysis.universe.Universe", indices: List[int]) -> Tuple[np.ndarray, np.ndarray]: + def get_framework(structure: "MDAnalysis.universe.Universe", indices: List[int], + framework_indices: List[int]) -> Tuple[np.ndarray, np.ndarray]: """ Determine the framework indices from an :py:mod:`MDAnalysis` compatible file when indices are provided @@ -627,7 +644,10 @@ def get_framework(structure: "MDAnalysis.universe.Universe", indices: List[int]) :return: Tuple containing: indices for the atoms in the trajectory used in the calculation of the diffusion and indices of framework atoms. """ - drift_indices = [] + if isinstance(framework_indices, (list, tuple)): + drift_indices = framework_indices + else: + drift_indices = [] for i, site in enumerate(structure): if i not in indices: