Skip to content

Commit

Permalink
Add a function to filter exploding frames in MDAnalysis
Browse files Browse the repository at this point in the history
Fixes #26
  • Loading branch information
jbarnoud committed Oct 13, 2023
1 parent 69e9e46 commit 517b617
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@
Module providing conversion and utility methods for working with Narupa and MDAnalysis.
"""
from .converter import mdanalysis_to_frame_data, frame_data_to_mdanalysis
from .universe import NarupaParser, NarupaReader
from .universe import NarupaParser, NarupaReader, explosion_mask
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,52 @@ def batched(iterable, n):
it = iter(iterable)
while batch := tuple(islice(it, n)):
yield batch


def explosion_mask(trajectory, max_displacement):
"""
A mask to select the frames that are NOT explosions.
:param trajectory: The trajectory to mask.
:param max_displacement: The maximum displacement in Angstrom along
a given axis before the frame is considered as exploding.
:return: A list with one boolean per frame in the trajectory.
The boolean is True is the frame is NOT exploding.
:raises KeyError: if the frames in the trajectory do not have
a reset counter. This can be the case for narupa trajectories
recorded from a server that does not keep track of resets in
the frames, or if the universe has not been built from a narupa
trajectory recording.
Here is an example of how to write a trajectory that excludes the
exploding frames:
.. code:: python
import MDAnalysis as mda
from narupa.mdanalysis import NarupaParser, NarupaReader, explosion_mask
u = mda.Universe(
'hello.traj',
format=NarupaReader,
topology_format=NarupaParser,
)
mask = explosion_mask(u.trajectory, 100)
u.atoms.write('hello.pdb')
u.atoms.write('hello.xtc', frames=mask)
"""
mask = []
first = trajectory[0]
previous = first.positions
prev_reset = first.data["system.reset.counter"]
for i, ts in enumerate(trajectory):
reset = ts.data["system.reset.counter"]
diff = np.abs(ts.positions - previous)
has_reset = reset != prev_reset
is_explosion = (diff.max() > 100 and not has_reset) or np.any(
~np.isfinite(ts.positions)
)
mask.append(not is_explosion)
previous = ts.positions
prev_reset = reset
return mask

0 comments on commit 517b617

Please sign in to comment.