Skip to content

Commit

Permalink
Added framework_indices keyword
Browse files Browse the repository at this point in the history
By using the framework_indices keyword the framework can be defined explicitly.


Former-commit-id: d8ffa1d
  • Loading branch information
jd15489 committed Nov 20, 2023
1 parent d6fe153 commit f595da5
Showing 1 changed file with 42 additions and 19 deletions.
61 changes: 42 additions & 19 deletions kinisi/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ class ASEParser(Parser):
:param memory_limit: Upper limit in the amount of computer memory that the displacements can occupy in
gigabytes (GB). Optional, defaults to :py:attr:`8.`.
:param progress: Print progress bars to screen. Optional, defaults to :py:attr:`True`.
:param specie_indices:
:param masses:
:param framework_indices:
"""

def __init__(self,
Expand Down Expand Up @@ -268,7 +271,7 @@ def __init__(self,
else:
indices = self.get_framework(structure, specie_indices)
else:
raise TypeError('Unrecognized type for Specie or Indices')
raise TypeError('Unrecognized type for specie or specie_indices')

self.coords_check = coords[0]

Expand Down Expand Up @@ -439,6 +442,9 @@ class PymatgenParser(Parser):
:param memory_limit: Upper limit in the amount of computer memory that the displacements can occupy in
gigabytes (GB). Optional, defaults to :py:attr:`8.`.
:param progress: Print progress bars to screen. Optional, defaults to :py:attr:`True`.
:param specie_indices:
:param masses:
:param framework_indices:
"""

def __init__(self,
Expand Down Expand Up @@ -470,7 +476,7 @@ def __init__(self,
else:
indices = self.get_framework(structure, specie_indices)
else:
raise TypeError('Unrecognized type for Specie or Indices')
raise TypeError('Unrecognized type for specie or specie_indices')

self.coords_check = coords[0]

Expand Down Expand Up @@ -663,6 +669,7 @@ class MDAnalysisParser(Parser):
must be set to None for this to function. Molecules can be specificed as a list of lists of indices.
This inner lists must all be on the same length.
:param masses: Optional, list of masses associated with the indices in specie_indices. Must be same shape as specie_indices.
:param framework_indices: Optional, list of framework indices to be used to correct framework drift.
"""

def __init__(self,
Expand Down Expand Up @@ -690,16 +697,16 @@ def __init__(self,
structure, coords, latt, volume = self.get_structure_coords_latt(universe, sub_sample_atoms, sub_sample_traj,
progress)
if specie != None:
indices = self.get_indices(structure, specie)
indices = self.get_indices(structure, specie, framework_indices)
elif isinstance(specie_indices, (list, tuple)):
if isinstance(specie_indices[0], (list, tuple)):
coords, indices = self.get_molecules(
structure, coords, specie_indices,
masses) #Warning: This function changes the structure without changing the MDAnalysis object
structure, coords, specie_indices, masses, framework_indices
) #Warning: This function changes the structure without changing the MDAnalysis object
else:
indices = self.get_framework(structure, specie_indices)
indices = self.get_framework(structure, specie_indices, framework_indices)
else:
raise TypeError('Unrecognized type for Specie or Indices')
raise TypeError('Unrecognized type for specie or specie_indices')

self.coords_check = coords[0]

Expand Down Expand Up @@ -734,7 +741,7 @@ def get_structure_coords_latt(
:param progress: Print progress bars to screen. Optional, defaults to :py:attr:`True`.
:return: Tuple containing: initial structure, fractional coordinates for all atoms,
lattice descriptions, and the cell volume
lattice descriptions, and the cell volume.
"""
coords, latt = [], []
first = True
Expand All @@ -757,12 +764,14 @@ def get_structure_coords_latt(
return structure, coords, latt, volume

@staticmethod
def get_indices(structure: "MDAnalysis.universe.Universe", specie: str) -> Tuple[np.ndarray, np.ndarray]:
def get_indices(structure: "MDAnalysis.universe.Universe", specie: str,
framework_indices: List[int]) -> Tuple[np.ndarray, np.ndarray]:
"""
Determine framework and non-framework indices for an :py:mod:`MDAnalysis` compatible file.
:param structure: Initial structure.
:param specie: Specie to calculate diffusivity for as a String, e.g. :py:attr:`'Li'`.
:param framework_indices: Indices of framework to be used in drift correction. If set to None will return all indices that are not specie.
:return: Tuple containing: indices for the atoms in the trajectory used in the calculation of the
diffusion and indices of framework atoms.
Expand All @@ -780,11 +789,16 @@ def get_indices(structure: "MDAnalysis.universe.Universe", specie: str) -> Tuple
indices.append(i)
else:
drift_indices.append(i)

if isinstance(framework_indices, (list, tuple)):
drift_indices = framework_indices

return indices, drift_indices

@staticmethod
def get_molecules(structure: "MDAnalysis.universe.Universe", coords: List[np.ndarray], indices: List[int],
masses: List[float]) -> Tuple[np.ndarray, np.ndarray, Tuple[np.ndarray, np.ndarray]]:
masses: List[float],
framework_indices: List[int]) -> Tuple[np.ndarray, np.ndarray, Tuple[np.ndarray, np.ndarray]]:
"""
Determine framework and non-framework indices for an :py:mod:`MDAnalysis` compatible file when specie_indices are provided and contain multiple molecules. Warning: This function changes the structure without changing the MDAnalysis object
Expand All @@ -793,7 +807,7 @@ def get_molecules(structure: "MDAnalysis.universe.Universe", coords: List[np.nda
:param indices: indices for the atoms in the molecules in the trajectory used in the calculation
of the diffusion.
:param masses: Masses associated with the molecule in indices.
:param framework_indices: Indices of framework to be used in drift correction. If set to None will return all indices that are not indices.
:return: Tuple containing: Tuple containing: fractional coordinates for centers and framework atoms
and Tuple containing: indices for centers used in the calculation
Expand All @@ -807,9 +821,12 @@ def get_molecules(structure: "MDAnalysis.universe.Universe", coords: List[np.nda

n_molecules = indices.shape[0]

for i, site in enumerate(structure):
if i not in indices:
drift_indices.append(i)
if isinstance(framework_indices, (list, tuple)):
drift_indices = framework_indices
else:
for i, site in enumerate(structure):
if i not in indices:
drift_indices.append(i)

if masses == None:
weights = None
Expand Down Expand Up @@ -839,22 +856,28 @@ def get_molecules(structure: "MDAnalysis.universe.Universe", coords: List[np.nda
return new_coords, (new_indices, new_drift_indices)

@staticmethod
def get_framework(structure: "MDAnalysis.universe.Universe", indices: List[int]) -> Tuple[np.ndarray, np.ndarray]:
def get_framework(structure: "MDAnalysis.universe.Universe", indices: List[int],
framework_indices: List[int]) -> Tuple[np.ndarray, np.ndarray]:
"""
Determine the framework indices from an :py:mod:`MDAnalysis` compatible file when indices are provided
:param structure: Initial structure.
:param indices: Indices for the atoms in the trajectory used in the calculation of the
diffusion.
:param framework_indices: Indices of framework to be used in drift correction. If set to None will return all indices that are not in indices.
:return: Tuple containing: indices for the atoms in the trajectory used in the calculation of the
diffusion and indices of framework atoms.
"""
drift_indices = []
if isinstance(framework_indices, (list, tuple)):
drift_indices = framework_indices
else:
drift_indices = []

for i, site in enumerate(structure):
if i not in indices:
drift_indices.append(i)

for i, site in enumerate(structure):
if i not in indices:
drift_indices.append(i)
return indices, drift_indices


Expand Down

0 comments on commit f595da5

Please sign in to comment.