Skip to content

Commit

Permalink
removed duplicate get_framework functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jd15489 committed Jan 7, 2024
1 parent d47e764 commit 9594740
Showing 1 changed file with 29 additions and 76 deletions.
105 changes: 29 additions & 76 deletions kinisi/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def __init__(self,
if isinstance(specie_indices[0], (list, tuple)):
coords, indices = _get_molecules(structure, coords, specie_indices, masses, framework_indices)
else:
indices = self.get_framework(structure, specie_indices, framework_indices)
indices = _get_framework(structure, specie_indices, framework_indices)
else:
raise TypeError('Unrecognized type for specie or specie_indices')

Expand Down Expand Up @@ -345,30 +345,6 @@ def get_indices(structure: "ase.atoms.Atoms", specie: "str",

return indices, drift_indices

@staticmethod
def get_framework(structure: "MDAnalysis.universe.Universe", indices: List[int],
framework_indices: List[int]) -> Tuple[np.ndarray, np.ndarray]:
"""
Determine the framework indices from an :py:mod:`ase` compatible file when indices are provided
:param structure: Initial structure.
:param indices: Indices for the atoms in the trajectory used in the calculation of the
diffusion.
:param framework_indices: Indices of framework to be used in drift correction. If set to None will return all indices that are not in indices.
:return: Tuple containing: indices for the atoms in the trajectory used in the calculation of the
diffusion and indices of framework atoms.
"""
if isinstance(framework_indices, (list, tuple)):
drift_indices = framework_indices
else:
drift_indices = []

for i, site in enumerate(structure):
if i not in indices:
drift_indices.append(i)
return indices, drift_indices


class PymatgenParser(Parser):
"""
Expand Down Expand Up @@ -430,7 +406,7 @@ def __init__(self,
if isinstance(specie_indices[0], (list, tuple)):
coords, indices = _get_molecules(structure, coords, specie_indices, masses, framework_indices)
else:
indices = self.get_framework(structure, specie_indices, framework_indices)
indices = _get_framework(structure, specie_indices, framework_indices)
else:
raise TypeError('Unrecognized type for specie or specie_indices')

Expand Down Expand Up @@ -520,30 +496,6 @@ def get_indices(structure: "pymatgen.core.structure.Structure",

return indices, drift_indices

@staticmethod
def get_framework(structure: "pymatgen.core.structure.Structure", indices: List[int],
framework_indices: List[int]) -> Tuple[np.ndarray, np.ndarray]:
"""
Determine the framework indices from an :py:mod:`pymatgen` compatible file when indices are provided
:param structure: Initial structure.
:param indices: Indices for the atoms in the trajectory used in the calculation of the
diffusion.
:param framework_indices: Indices of framework to be used in drift correction. If set to None will return all indices that are not in indices.
:return: Tuple containing: indices for the atoms in the trajectory used in the calculation of the
diffusion and indices of framework atoms.
"""
if isinstance(framework_indices, (list, tuple)):
drift_indices = framework_indices
else:
drift_indices = []

for i, site in enumerate(structure):
if i not in indices:
drift_indices.append(i)
return indices, drift_indices


class MDAnalysisParser(Parser):
"""
Expand Down Expand Up @@ -615,7 +567,7 @@ def __init__(self,
structure, coords, specie_indices, masses, framework_indices
) #Warning: This function changes the structure without changing the MDAnalysis object
else:
indices = self.get_framework(structure, specie_indices, framework_indices)
indices = _get_framework(structure, specie_indices, framework_indices)
else:
raise TypeError('Unrecognized type for specie or specie_indices')

Expand Down Expand Up @@ -707,31 +659,6 @@ def get_indices(structure: "MDAnalysis.universe.Universe", specie: str,

return indices, drift_indices

@staticmethod
def get_framework(structure: "MDAnalysis.universe.Universe", indices: List[int],
framework_indices: List[int]) -> Tuple[np.ndarray, np.ndarray]:
"""
Determine the framework indices from an :py:mod:`MDAnalysis` compatible file when indices are provided
:param structure: Initial structure.
:param indices: Indices for the atoms in the trajectory used in the calculation of the
diffusion.
:param framework_indices: Indices of framework to be used in drift correction. If set to None will return all indices that are not in indices.
:return: Tuple containing: indices for the atoms in the trajectory used in the calculation of the
diffusion and indices of framework atoms.
"""
if isinstance(framework_indices, (list, tuple)):
drift_indices = framework_indices
else:
drift_indices = []

for i, site in enumerate(structure):
if i not in indices:
drift_indices.append(i)

return indices, drift_indices


def _get_matrix(dimensions: np.ndarray) -> np.ndarray:
"""
Expand Down Expand Up @@ -819,3 +746,29 @@ def _get_molecules(structure: "ase.atoms.Atoms" or "pymatgen.core.structure.Stru
new_coords = np.expand_dims(new_coords, axis=2)

return new_coords, (new_indices, new_drift_indices)


def _get_framework(structure: "ase.atoms.Atoms" or "pymatgen.core.structure.Structure"
or "MDAnalysis.universe.Universe", indices: List[int],
framework_indices: List[int]) -> Tuple[np.ndarray, np.ndarray]:
"""
Determine the framework indices from an :py:mod:`ase` or :py:mod:`pymatgen` or :py:mod:`MDAnalysis` compatible file when indices are provided
:param structure: Initial structure.
:param indices: Indices for the atoms in the trajectory used in the calculation of the
diffusion.
:param framework_indices: Indices of framework to be used in drift correction. If set to None will return all indices that are not in indices.
:return: Tuple containing: indices for the atoms in the trajectory used in the calculation of the
diffusion and indices of framework atoms.
"""
if isinstance(framework_indices, (list, tuple)):
drift_indices = framework_indices
else:
drift_indices = []

for i, site in enumerate(structure):
if i not in indices:
drift_indices.append(i)

return indices, drift_indices

0 comments on commit 9594740

Please sign in to comment.