Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Read a trajectory recording as an MDAnalysis Universe #23

Merged
merged 4 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .frame_data import FrameData
from .frame_data import FrameData, MissingDataError
from .frame_client import FrameClient
from .frame_publisher import FramePublisher, FRAME_SERVICE_NAME
from .frame_server import FrameServer
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@
Module providing conversion and utility methods for working with Narupa and MDAnalysis.
"""
from .converter import mdanalysis_to_frame_data, frame_data_to_mdanalysis
from .universe import NarupaParser, NarupaReader
137 changes: 137 additions & 0 deletions python-libraries/narupa-mdanalysis/src/narupa/mdanalysis/recordings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
from typing import Generator
from narupa.protocol.trajectory import GetFrameResponse
from narupa.trajectory import FrameData, MissingDataError

MAGIC_NUMBER = 6661355757386708963


class Unpacker:
"""
Unpack data representations from a buffer of bytes.

The unpacking methods return the requested value from the next bytes of the
buffer and move the cursor forward. They raise an `IndexError` if the
buffer is too short to fullfil the request.
"""

_buffer: bytes
_cursor: int

def __init__(self, data: bytes):
self._buffer = data
self._cursor = 0

def unpack_bytes(self, n_bytes: int) -> bytes:
"""
Get the next `n_bytes` bytes from the buffer.

The method raises a `ValueError` if the requested number of bytes is
negative.
"""
if n_bytes < 0:
raise ValueError("Cannot unpack a negative number of bytes.")
end = self._cursor + n_bytes
if end >= len(self._buffer):
raise IndexError("Not enough bytes left in the buffer.")
bytes_to_return = self._buffer[self._cursor : end]
self._cursor = end
return bytes_to_return

def unpack_u64(self) -> int:
"""
Get an unsigned 64 bits integer from the next bytes of the buffer.
"""
buffer = self.unpack_bytes(8)
return int.from_bytes(buffer, "little", signed=False)

def unpack_u128(self) -> int:
"""
Get an unsigned 128 bits integer from the next bytes of the buffer.
"""
buffer = self.unpack_bytes(16)
return int.from_bytes(buffer, "little", signed=False)


class InvalidMagicNumber(Exception):
"""
The magic number read from a file is not the expected one.

The file may not be in the correct format, or it may be corrupted.
"""


class UnsuportedFormatVersion(Exception):
"""
The version of the file format is not supported by hthe parser.
"""

def __init__(self, format_version: int, supported_format_versions: tuple[int]):
self._format_version = format_version
self._supported_format_versions = supported_format_versions

def __str__(self) -> str:
return (
f"Version {self._format_version} of the format is not supported "
f"by this parser. The supported versions are "
f"{self._supported_format_versions}."
)


def iter_trajectory_recording(unpacker: Unpacker) -> Generator:
supported_format_versions = (2,)
magic_number = unpacker.unpack_u64()
if magic_number != MAGIC_NUMBER:
raise InvalidMagicNumber
format_version = unpacker.unpack_u64()
if format_version not in supported_format_versions:
raise UnsuportedFormatVersion(format_version, supported_format_versions)
while True:
try:
elapsed = unpacker.unpack_u128()
record_size = unpacker.unpack_u64()
buffer = unpacker.unpack_bytes(record_size)
except IndexError:
break
get_frame_response = GetFrameResponse()
get_frame_response.ParseFromString(buffer)
frame_index = get_frame_response.frame_index
frame = FrameData(get_frame_response.frame)
yield (elapsed, frame_index, frame)


def iter_trajectory_file(path) -> Generator:
with open(path, "rb") as infile:
data = infile.read()
unpacker = Unpacker(data)
yield from iter_trajectory_recording(unpacker)


def advance_to_first_particle_frame(frames):
for elapsed, frame_index, frame in frames:
try:
particle_count = frame.particle_count
except MissingDataError:
pass
else:
if particle_count > 0:
break
else:
return

yield (elapsed, frame_index, frame)
yield from frames


def advance_to_first_coordinate_frame(frames):
for elapsed, frame_index, frame in frames:
try:
frame.particle_positions
except MissingDataError:
pass
else:
break
else:
return

yield (elapsed, frame_index, frame)
yield from frames
229 changes: 229 additions & 0 deletions python-libraries/narupa-mdanalysis/src/narupa/mdanalysis/universe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
"""
Facilities to read a Narupa trajectory recording into an MDAnalysis Universe.

.. code:: python
import MDAnalysis as mda
from narupa.mdanalysis import NarupaReader, NarupaParser

u = mda.Universe(
'input.traj',
'input.traj',
format=NarupaReader,
topology_format=NarupaParser,
)

.. note::
A Narupa trajectory recording can have its topology change over time. It
can even contain trajectories for unrelated simulations. The topology in an
MDAnalysis Universe is constant. Only the frames corresponding to the first
topology are read in a Universe.

"""
import warnings
from itertools import takewhile, chain
from typing import NamedTuple, Type, Callable

from MDAnalysis.coordinates.base import ProtoReader
from MDAnalysis.coordinates.timestep import Timestep
from MDAnalysis.lib.util import openany
from MDAnalysis.core.topologyattrs import (
Atomnames,
Atomtypes,
Elements,
Resids,
Resnames,
Segids,
ChainIDs,
TopologyAttr,
)
from MDAnalysis.core.topology import Topology
from MDAnalysis.topology.base import TopologyReaderBase

from narupa.trajectory import FrameData
from narupa.trajectory.frame_data import (
PARTICLE_COUNT,
RESIDUE_COUNT,
CHAIN_COUNT,
PARTICLE_ELEMENTS,
PARTICLE_NAMES,
PARTICLE_RESIDUES,
RESIDUE_NAMES,
RESIDUE_CHAINS,
RESIDUE_IDS,
CHAIN_NAMES,
MissingDataError,
)

from .recordings import (
Unpacker,
iter_trajectory_recording,
advance_to_first_particle_frame,
advance_to_first_coordinate_frame,
)
from .converter import _to_chemical_symbol


class KeyConversion(NamedTuple):
attribute: Type[TopologyAttr]
conversion: Callable


def _as_is(value):
return value


def _trimmed(value):
return value.strip()


KEY_TO_ATTRIBUTE = {
PARTICLE_NAMES: KeyConversion(Atomnames, _trimmed),
RESIDUE_NAMES: KeyConversion(Resnames, _as_is),
RESIDUE_IDS: KeyConversion(Resids, _as_is),
CHAIN_NAMES: KeyConversion(Segids, _as_is),
}


class NarupaParser(TopologyReaderBase):
def parse(self, **kwargs):
with openany(self.filename, mode="rb") as infile:
data = infile.read()
unpacker = Unpacker(data)
# We assume the full topology is in the first frame with a particle
# count greater than 0. This will be true only most of the time.
# TODO: implement a more reliable way to get the full topology
try:
_, _, frame = next(
advance_to_first_particle_frame(iter_trajectory_recording(unpacker))
)
except StopIteration:
raise IOError("The file does not contain any frame.")

attrs = []
for frame_key, (attribute, converter) in KEY_TO_ATTRIBUTE.items():
try:
values = frame.arrays[frame_key]
except MissingDataError:
pass
else:
attrs.append(attribute([converter(value) for value in values]))

try:
elements = frame.arrays[PARTICLE_ELEMENTS]
except MissingDataError:
pass
else:
converted_elements = _to_chemical_symbol(elements)
attrs.append(Atomtypes(converted_elements))
attrs.append(Elements(converted_elements))

# TODO: generate these values if they are not part of the FrameData
residx = frame.arrays[PARTICLE_RESIDUES]
segidx = frame.arrays[RESIDUE_CHAINS]
n_atoms = int(frame.values[PARTICLE_COUNT])
n_residues = int(frame.values[RESIDUE_COUNT])
n_chains = int(frame.values[CHAIN_COUNT])

try:
chain_ids_per_chain = frame.arrays[CHAIN_NAMES]
except MissingDataError:
pass
else:
chain_ids_per_particle = [
chain_ids_per_chain[segidx[residx[atom]]] for atom in range(n_atoms)
]
attrs.append(ChainIDs(chain_ids_per_particle))

return Topology(
n_atoms,
n_residues,
n_chains,
attrs=attrs,
atom_resindex=residx,
residue_segindex=segidx,
)


class NarupaReader(ProtoReader):
units = {"time": "ps", "length": "nm", "velocity": "nm/ps"}

def __init__(self, filename, convert_units=True, **kwargs):
super().__init__()
self.filename = filename
self.convert_units = convert_units
with openany(filename, mode="rb") as infile:
data = infile.read()
unpacker = Unpacker(data)
recording = advance_to_first_coordinate_frame(
iter_trajectory_recording(unpacker)
)
try:
_, _, first_frame = next(recording)
except StopIteration:
raise IOError("Empty trajectory.")
self.n_atoms = first_frame.particle_count

non_topology_frames = takewhile(
lambda frame: not has_topology(frame),
map(lambda record: record[2], recording),
)
try:
next(recording)
except StopIteration:
pass
else:
warnings.warn(
"The simulation contains changes to the topology after the "
"first frame. Only the frames with the initial topology are "
"accessible in this Universe."
)
self._frames = list(chain([first_frame], non_topology_frames))
self.n_frames = len(self._frames)
self._read_frame(0)

def _read_frame(self, frame):
self._current_frame_index = frame
try:
frame_at_index = self._frames[frame]
except IndexError as err:
raise EOFError(err) from None

ts = Timestep(self.n_atoms)
ts.frame = frame
ts.positions = frame_at_index.particle_positions
try:
ts.time = frame_at_index.simulation_time
except MissingDataError:
pass
try:
ts.triclinic_dimensions = frame_at_index.box_vectors
except MissingDataError:
pass

ts.data.update(frame_at_index.values)

if self.convert_units:
self.convert_pos_from_native(ts._pos) # in-place !
if ts.dimensions is not None:
self.convert_pos_from_native(ts.dimensions[:3]) # in-place!
if ts.has_velocities:
# converts nm/ps to A/ps units
self.convert_velocities_from_native(ts._velocities)

self.ts = ts
return ts

def _read_next_timestep(self):
if self._current_frame_index is None:
frame = 0
else:
frame = self._current_frame_index + 1
return self._read_frame(frame)

def _reopen(self):
self._current_frame_index = None


def has_topology(frame: FrameData) -> bool:
topology_keys = set(list(KEY_TO_ATTRIBUTE.keys()) + [PARTICLE_ELEMENTS])
return bool(topology_keys.intersection(frame.array_keys))
Loading
Loading