Skip to content

Commit

Permalink
Extended framework_indices to all parsers
Browse files Browse the repository at this point in the history
  • Loading branch information
jd15489 committed Nov 23, 2023
1 parent 4398a1a commit 685aa4c
Showing 1 changed file with 50 additions and 30 deletions.
80 changes: 50 additions & 30 deletions kinisi/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,14 +263,12 @@ def __init__(self,
structure, coords, latt = self.get_structure_coords_latt(atoms, sub_sample_traj, progress)

if specie != None:
indices = self.get_indices(structure, specie)
indices = self.get_indices(structure, specie, framework_indices)
elif isinstance(specie_indices, (list, tuple)):
if isinstance(specie_indices[0], (list, tuple)):
coords, indices = self.get_molecules(
structure, coords, specie_indices,
masses) #Warning: This function changes the structure without changing the MDAnalysis object
coords, indices = self.get_molecules(structure, coords, specie_indices, masses, framework_indices)
else:
indices = self.get_framework(structure, specie_indices)
indices = self.get_framework(structure, specie_indices, framework_indices)
else:
raise TypeError('Unrecognized type for specie or specie_indices')

Expand Down Expand Up @@ -316,7 +314,8 @@ def get_structure_coords_latt(
return structure, coords, latt

@staticmethod
def get_indices(structure: "ase.atoms.Atoms", specie: "str") -> Tuple[np.ndarray, np.ndarray]:
def get_indices(structure: "ase.atoms.Atoms", specie: "str",
framework_indices: List[int]) -> Tuple[np.ndarray, np.ndarray]:
"""
Determine framework and non-framework indices for a :py:mod:`pymatgen` compatible file.
Expand All @@ -337,11 +336,15 @@ def get_indices(structure: "ase.atoms.Atoms", specie: "str") -> Tuple[np.ndarray
drift_indices.append(i)
if len(indices) == 0:
raise ValueError("There are no species selected to calculate the mean-squared displacement of.")

if isinstance(framework_indices, (list, tuple)):
drift_indices = framework_indices

return indices, drift_indices

@staticmethod
def get_molecules(structure: "MDAnalysis.universe.Universe", coords: List[np.ndarray], indices: List[int],
masses: List[float]) -> Tuple[np.ndarray, np.ndarray, Tuple[np.ndarray, np.ndarray]]:
def get_molecules(structure: "ase.atoms.Atoms", coords: List[np.ndarray], indices: List[int], masses: List[float],
framework_indices) -> Tuple[np.ndarray, np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""
Determine framework and non-framework indices for an :py:mod:`MDAnalysis` compatible file when specie_indices are provided and contain multiple molecules. Warning: This function changes the structure without changing the MDAnalysis object
Expand All @@ -364,9 +367,12 @@ def get_molecules(structure: "MDAnalysis.universe.Universe", coords: List[np.nda

n_molecules = indices.shape[0]

for i, site in enumerate(structure):
if i not in indices:
drift_indices.append(i)
if isinstance(framework_indices, (list, tuple)):
drift_indices = framework_indices
else:
for i, site in enumerate(structure):
if i not in indices:
drift_indices.append(i)

if masses == None:
weights = None
Expand Down Expand Up @@ -396,7 +402,8 @@ def get_molecules(structure: "MDAnalysis.universe.Universe", coords: List[np.nda
return new_coords, (new_indices, new_drift_indices)

@staticmethod
def get_framework(structure: "MDAnalysis.universe.Universe", indices: List[int]) -> Tuple[np.ndarray, np.ndarray]:
def get_framework(structure: "MDAnalysis.universe.Universe", indices: List[int],
framework_indices: List[int]) -> Tuple[np.ndarray, np.ndarray]:
"""
Determine the framework indices from an :py:mod:`MDAnalysis` compatible file when indices are provided
Expand All @@ -407,7 +414,10 @@ def get_framework(structure: "MDAnalysis.universe.Universe", indices: List[int])
:return: Tuple containing: indices for the atoms in the trajectory used in the calculation of the
diffusion and indices of framework atoms.
"""
drift_indices = []
if isinstance(framework_indices, (list, tuple)):
drift_indices = framework_indices
else:
drift_indices = []

for i, site in enumerate(structure):
if i not in indices:
Expand Down Expand Up @@ -468,14 +478,12 @@ def __init__(self,
structure, coords, latt = self.get_structure_coords_latt(structures, sub_sample_traj, progress)

if specie != None:
indices = self.get_indices(structure, specie)
indices = self.get_indices(structure, specie, framework_indices)
elif isinstance(specie_indices, (list, tuple)):
if isinstance(specie_indices[0], (list, tuple)):
coords, indices = self.get_molecules(
structure, coords, specie_indices,
masses) #Warning: This function changes the structure without changing the MDAnalysis object
coords, indices = self.get_molecules(structure, coords, specie_indices, masses, framework_indices)
else:
indices = self.get_framework(structure, specie_indices)
indices = self.get_framework(structure, specie_indices, framework_indices)
else:
raise TypeError('Unrecognized type for specie or specie_indices')

Expand Down Expand Up @@ -531,12 +539,12 @@ def get_structure_coords_latt(
return structure, coords, latt

@staticmethod
def get_indices(
structure: "pymatgen.core.structure.Structure",
specie: Union["pymatgen.core.periodic_table.Element", "pymatgen.core.periodic_table.Specie",
"pymatgen.core.periodic_table.Species", List["pymatgen.core.periodic_table.Element"],
List["pymatgen.core.periodic_table.Specie"], List["pymatgen.core.periodic_table.Species"]]
) -> Tuple[np.ndarray, np.ndarray]:
def get_indices(structure: "pymatgen.core.structure.Structure",
specie: Union["pymatgen.core.periodic_table.Element", "pymatgen.core.periodic_table.Specie",
"pymatgen.core.periodic_table.Species", List["pymatgen.core.periodic_table.Element"],
List["pymatgen.core.periodic_table.Specie"],
List["pymatgen.core.periodic_table.Species"]],
framework_indices: List[int]) -> Tuple[np.ndarray, np.ndarray]:
"""
Determine framework and non-framework indices for a :py:mod:`pymatgen` compatible file.
Expand All @@ -557,11 +565,16 @@ def get_indices(
drift_indices.append(i)
if len(indices) == 0:
raise ValueError("There are no species selected to calculate the mean-squared displacement of.")

if isinstance(framework_indices, (list, tuple)):
drift_indices = framework_indices

return indices, drift_indices

@staticmethod
def get_molecules(structure: "MDAnalysis.universe.Universe", coords: List[np.ndarray], indices: List[int],
masses: List[float]) -> Tuple[np.ndarray, np.ndarray, Tuple[np.ndarray, np.ndarray]]:
masses: List[float],
framework_indices: List[int]) -> Tuple[np.ndarray, np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""
Determine framework and non-framework indices for an :py:mod:`MDAnalysis` compatible file when specie_indices are provided and contain multiple molecules. Warning: This function changes the structure without changing the MDAnalysis object
Expand All @@ -584,9 +597,12 @@ def get_molecules(structure: "MDAnalysis.universe.Universe", coords: List[np.nda

n_molecules = indices.shape[0]

for i, site in enumerate(structure):
if i not in indices:
drift_indices.append(i)
if isinstance(framework_indices, (list, tuple)):
drift_indices = framework_indices
else:
for i, site in enumerate(structure):
if i not in indices:
drift_indices.append(i)

if masses == None:
weights = None
Expand Down Expand Up @@ -616,7 +632,8 @@ def get_molecules(structure: "MDAnalysis.universe.Universe", coords: List[np.nda
return new_coords, (new_indices, new_drift_indices)

@staticmethod
def get_framework(structure: "MDAnalysis.universe.Universe", indices: List[int]) -> Tuple[np.ndarray, np.ndarray]:
def get_framework(structure: "MDAnalysis.universe.Universe", indices: List[int],
framework_indices: List[int]) -> Tuple[np.ndarray, np.ndarray]:
"""
Determine the framework indices from an :py:mod:`MDAnalysis` compatible file when indices are provided
Expand All @@ -627,7 +644,10 @@ def get_framework(structure: "MDAnalysis.universe.Universe", indices: List[int])
:return: Tuple containing: indices for the atoms in the trajectory used in the calculation of the
diffusion and indices of framework atoms.
"""
drift_indices = []
if isinstance(framework_indices, (list, tuple)):
drift_indices = framework_indices
else:
drift_indices = []

for i, site in enumerate(structure):
if i not in indices:
Expand Down

0 comments on commit 685aa4c

Please sign in to comment.