diff --git a/package/CHANGELOG b/package/CHANGELOG index f098c9fe425..6d05547c9cb 100644 --- a/package/CHANGELOG +++ b/package/CHANGELOG @@ -49,6 +49,9 @@ Fixes residues with identical resids (Issue #1387, PR #2872) * H5MD files are now pickleable with H5PYPicklable (Issue #2890, PR #2894) * Fixed Janin analysis residue filtering (include CYSH) (Issue #2898) + * libmdaxdr and libdcd classes in their last frame can now be pickled + (Issue #2878, PR #2911) + Enhancements * Refactored analysis.helanal into analysis.helix_analysis diff --git a/package/MDAnalysis/lib/formats/libdcd.pyx b/package/MDAnalysis/lib/formats/libdcd.pyx index 2229b02f9bd..03864d90f72 100644 --- a/package/MDAnalysis/lib/formats/libdcd.pyx +++ b/package/MDAnalysis/lib/formats/libdcd.pyx @@ -252,7 +252,7 @@ cdef class DCDFile: self.__getstate__()) def __getstate__(self): - return self.is_open, self.current_frame + return self.is_open, self.current_frame, self.n_frames def __setstate__(self, state): is_open = state[0] @@ -261,8 +261,20 @@ cdef class DCDFile: return current_frame = state[1] - self.seek(current_frame - 1) - self.current_frame = current_frame + if current_frame < self.n_frames: + self.seek(current_frame) + elif current_frame == self.n_frames: + # cannot seek to self.n_frames (a.k.a. len(DCDFile)); + # instead, we seek to the previous frame and read next. + # which is the state of the file when we need to serialize + # at the end of the trajectory. + self.seek(current_frame - 1) + _ = self.read() + else: # pragma: no cover + raise RuntimeError("Invalid frame number {} > {} -- this should" + "not happen.".format(current_frame, + self.n_frames) + ) def tell(self): """ diff --git a/package/MDAnalysis/lib/formats/libmdaxdr.pyx b/package/MDAnalysis/lib/formats/libmdaxdr.pyx index 54d64166a7c..81aaddebcb6 100644 --- a/package/MDAnalysis/lib/formats/libmdaxdr.pyx +++ b/package/MDAnalysis/lib/formats/libmdaxdr.pyx @@ -306,8 +306,20 @@ cdef class _XDRFile: # where was I current_frame = state[1] - self.seek(current_frame - 1) - self.current_frame = current_frame + if current_frame < self.offsets.size: + self.seek(current_frame) + elif current_frame == self.offsets.size: + # cannot seek to self.offsets.size (a.k.a len(_XDRFile)); + # instead, we seek to the previous frame and read next. + # which is the state of the file when we need to serialize + # at the end of the trajectory. + self.seek(current_frame - 1) + _ = self.read() + else: # pragma: no cover + raise RuntimeError("Invalid frame number {} > {} -- this should" + "not happen.".format(current_frame, + self.offsets.size) + ) def seek(self, frame): """Seek to Frame. diff --git a/testsuite/MDAnalysisTests/formats/test_libdcd.py b/testsuite/MDAnalysisTests/formats/test_libdcd.py index 5328a625f6d..0fdc53a7321 100644 --- a/testsuite/MDAnalysisTests/formats/test_libdcd.py +++ b/testsuite/MDAnalysisTests/formats/test_libdcd.py @@ -85,25 +85,61 @@ def dcd(): yield dcd +def _assert_compare_readers(old_reader, new_reader): + # same as next(old_reader) + frame = old_reader.read() + # same as next(new_reader) + new_frame = new_reader.read() + + assert old_reader.fname == new_reader.fname + assert old_reader.tell() == new_reader.tell() + assert_almost_equal(frame.xyz, new_frame.xyz) + assert_almost_equal(frame.unitcell, new_frame.unitcell) + + def test_pickle(dcd): + mid = len(dcd) // 2 + dcd.seek(mid) + new_dcd = pickle.loads(pickle.dumps(dcd)) + _assert_compare_readers(dcd, new_dcd) + + +def test_pickle_last(dcd): + # This is the file state when DCDReader is in its last frame. + # (Issue #2878) + dcd.seek(len(dcd) - 1) - dump = pickle.dumps(dcd) - new_dcd = pickle.loads(dump) + _ = dcd.read() + new_dcd = pickle.loads(pickle.dumps(dcd)) assert dcd.fname == new_dcd.fname assert dcd.tell() == new_dcd.tell() + with pytest.raises(StopIteration): + new_dcd.read() def test_pickle_closed(dcd): dcd.seek(len(dcd) - 1) dcd.close() - dump = pickle.dumps(dcd) - new_dcd = pickle.loads(dump) + new_dcd = pickle.loads(pickle.dumps(dcd)) assert dcd.fname == new_dcd.fname assert dcd.tell() != new_dcd.tell() +def test_pickle_after_read(dcd): + _ = dcd.read() + new_dcd = pickle.loads(pickle.dumps(dcd)) + _assert_compare_readers(dcd, new_dcd) + + +def test_pickle_immediately(dcd): + new_dcd = pickle.loads(pickle.dumps(dcd)) + + assert dcd.fname == new_dcd.fname + assert dcd.tell() == new_dcd.tell() + + @pytest.mark.parametrize("new_frame", (10, 42, 21)) def test_seek_normal(new_frame, dcd): # frame seek within range is tested diff --git a/testsuite/MDAnalysisTests/formats/test_libmdaxdr.py b/testsuite/MDAnalysisTests/formats/test_libmdaxdr.py index aff0bddf1b5..78f31afc109 100644 --- a/testsuite/MDAnalysisTests/formats/test_libmdaxdr.py +++ b/testsuite/MDAnalysisTests/formats/test_libmdaxdr.py @@ -25,7 +25,7 @@ import numpy as np from numpy.testing import (assert_almost_equal, assert_array_almost_equal, - assert_array_equal) + assert_array_equal, assert_equal) from MDAnalysis.lib.formats.libmdaxdr import TRRFile, XTCFile @@ -127,24 +127,53 @@ def test_read_write_mode_file(self, xdr, tmpdir, fname): with pytest.raises(IOError): f.read() + @staticmethod + def _assert_compare_readers(old_reader, new_reader): + frame = old_reader.read() + new_frame = new_reader.read() + + assert old_reader.fname == new_reader.fname + assert old_reader.tell() == new_reader.tell() + + assert_equal(old_reader.offsets, new_reader.offsets) + assert_almost_equal(frame.x, new_frame.x) + assert_almost_equal(frame.box, new_frame.box) + assert frame.step == new_frame.step + assert_almost_equal(frame.time, new_frame.time) + def test_pickle(self, reader): + mid = len(reader) // 2 + reader.seek(mid) + new_reader = pickle.loads(pickle.dumps(reader)) + self._assert_compare_readers(reader, new_reader) + + def test_pickle_last_frame(self, reader): + # This is the file state when XDRReader is in its last frame. + # (Issue #2878) reader.seek(len(reader) - 1) - dump = pickle.dumps(reader) - new_reader = pickle.loads(dump) + _ = reader.read() + new_reader = pickle.loads(pickle.dumps(reader)) assert reader.fname == new_reader.fname assert reader.tell() == new_reader.tell() - assert_almost_equal(reader.offsets, new_reader.offsets) + with pytest.raises(StopIteration): + new_reader.read() def test_pickle_closed(self, reader): reader.seek(len(reader) - 1) reader.close() - dump = pickle.dumps(reader) - new_reader = pickle.loads(dump) + new_reader = pickle.loads(pickle.dumps(reader)) assert reader.fname == new_reader.fname assert reader.tell() != new_reader.tell() + def test_pickle_immediately(self, reader): + new_reader = pickle.loads(pickle.dumps(reader)) + + assert reader.fname == new_reader.fname + assert reader.tell() == new_reader.tell() + + @pytest.mark.parametrize("xdrfile, fname, offsets", ((XTCFile, XTC_multi_frame, XTC_OFFSETS),