diff --git a/package/MDAnalysis/auxiliary/base.py b/package/MDAnalysis/auxiliary/base.py index b17c1d7d0d2..d93612b1694 100644 --- a/package/MDAnalysis/auxiliary/base.py +++ b/package/MDAnalysis/auxiliary/base.py @@ -308,6 +308,10 @@ def __init__(self, represent_ts_as='closest', auxname=None, cutoff=-1, self.auxstep._dt = self.time - self.initial_time self.rewind() + def __getstate__(self): + # pickling AuxReader is not yet supported/tested; raise so it fails loudly + raise NotImplementedError + def copy(self): raise NotImplementedError("Copy not implemented for AuxReader") diff --git a/package/MDAnalysis/coordinates/DLPoly.py b/package/MDAnalysis/coordinates/DLPoly.py index 5e322e4ffc1..7901d05f59c 100644 --- a/package/MDAnalysis/coordinates/DLPoly.py +++ b/package/MDAnalysis/coordinates/DLPoly.py @@ -37,6 +37,7 @@ from . import base from . import core +from ..lib import util _DLPOLY_UNITS = {'length': 'Angstrom', 'velocity': 'Angstrom/ps', 'time': 'ps'} @@ -141,7 +142,7 @@ def _read_first_frame(self): ts.frame = 0 -class HistoryReader(base.ReaderBase): +class HistoryReader(base.ReaderBase, base._AsciiPickle): """Reads DLPoly format HISTORY files .. 
versionadded:: 0.11.0 @@ -154,9 +155,9 @@ def __init__(self, filename, **kwargs): super(HistoryReader, self).__init__(filename, **kwargs) # "private" file handle - self._file = open(self.filename, 'r') - self.title = self._file.readline().strip() - self._levcfg, self._imcon, self.n_atoms = np.int64(self._file.readline().split()[:3]) + self._f = util.anyopen(self.filename, 'r') + self.title = self._f.readline().strip() + self._levcfg, self._imcon, self.n_atoms = np.int64(self._f.readline().split()[:3]) self._has_vels = True if self._levcfg > 0 else False self._has_forces = True if self._levcfg == 2 else False @@ -170,20 +171,20 @@ def _read_next_timestep(self, ts=None): if ts is None: ts = self.ts - line = self._file.readline() # timestep line + line = self._f.readline() # timestep line if not line.startswith('timestep'): raise IOError if not self._imcon == 0: - ts._unitcell[0] = self._file.readline().split() - ts._unitcell[1] = self._file.readline().split() - ts._unitcell[2] = self._file.readline().split() + ts._unitcell[0] = self._f.readline().split() + ts._unitcell[1] = self._f.readline().split() + ts._unitcell[2] = self._f.readline().split() # If ids are given, put them in here # and later sort by them ids = [] for i in range(self.n_atoms): - line = self._file.readline().strip() # atom info line + line = self._f.readline().strip() # atom info line try: idx = int(line.split()[1]) except IndexError: @@ -192,11 +193,11 @@ def _read_next_timestep(self, ts=None): ids.append(idx) # Read in this order for now, then later reorder in place - ts._pos[i] = self._file.readline().split() + ts._pos[i] = self._f.readline().split() if self._has_vels: - ts._velocities[i] = self._file.readline().split() + ts._velocities[i] = self._f.readline().split() if self._has_forces: - ts._forces[i] = self._file.readline().split() + ts._forces[i] = self._f.readline().split() if ids: ids = np.array(ids) @@ -214,7 +215,7 @@ def _read_next_timestep(self, ts=None): def _read_frame(self, frame): 
"""frame is 0 based, error checking is done in base.getitem""" - self._file.seek(self._offsets[frame]) + self._f.seek(self._offsets[frame]) self.ts.frame = frame - 1 # gets +1'd in read_next_frame return self._read_next_timestep() @@ -234,7 +235,7 @@ def _read_n_frames(self): """ offsets = self._offsets = [] - with open(self.filename, 'r') as f: + with util.anyopen(self.filename, 'r') as f: n_frames = 0 f.readline() @@ -262,10 +263,10 @@ def _read_n_frames(self): def _reopen(self): self.close() - self._file = open(self.filename, 'r') - self._file.readline() # header is 2 lines - self._file.readline() + self._f = util.anyopen(self.filename, 'r') + self._f.readline() # header is 2 lines + self._f.readline() self.ts.frame = -1 def close(self): - self._file.close() + self._f.close() diff --git a/package/MDAnalysis/coordinates/GMS.py b/package/MDAnalysis/coordinates/GMS.py index 46be3a34c09..36251bfa086 100644 --- a/package/MDAnalysis/coordinates/GMS.py +++ b/package/MDAnalysis/coordinates/GMS.py @@ -47,7 +47,7 @@ import MDAnalysis.lib.util as util -class GMSReader(base.ReaderBase): +class GMSReader(base.ReaderBase, base._AsciiPickle): """Reads from an GAMESS output file :Data: @@ -82,7 +82,7 @@ def __init__(self, outfilename, **kwargs): super(GMSReader, self).__init__(outfilename, **kwargs) # the filename has been parsed to be either b(g)zipped or not - self.outfile = util.anyopen(self.filename) + self._f = util.anyopen(self.filename) # note that, like for xtc and trr files, _n_atoms and _n_frames are used quasi-private variables # to prevent the properties being recalculated @@ -177,7 +177,7 @@ def _read_out_n_frames(self): return len(offsets) def _read_frame(self, frame): - self.outfile.seek(self._offsets[frame]) + self._f.seek(self._offsets[frame]) self.ts.frame = frame - 1 # gets +1'd in _read_next return self._read_next_timestep() @@ -186,7 +186,7 @@ def _read_next_timestep(self, ts=None): if ts is None: ts = self.ts # check that the outfile object exists; if not 
reopen the trajectory - if self.outfile is None: + if self._f is None: self.open_trajectory() x = [] y = [] @@ -195,7 +195,7 @@ def _read_next_timestep(self, ts=None): flag = 0 counter = 0 - for line in self.outfile: + for line in self._f: if self.runtyp == 'optimize': if (flag == 0) and (re.match(r'^.NSERCH=.*', line) is not None): flag = 1 @@ -246,22 +246,22 @@ def _reopen(self): self.open_trajectory() def open_trajectory(self): - if self.outfile is not None: + if self._f is not None: raise IOError(errno.EALREADY, 'GMS file already opened', self.filename) if not os.path.exists(self.filename): # must check; otherwise might segmentation fault raise IOError(errno.ENOENT, 'GMS file not found', self.filename) - self.outfile = util.anyopen(self.filename) + self._f = util.anyopen(self.filename) # reset ts ts = self.ts ts.frame = -1 - return self.outfile + return self._f def close(self): """Close out trajectory file if it was open.""" - if self.outfile is None: + if self._f is None: return - self.outfile.close() - self.outfile = None + self._f.close() + self._f = None diff --git a/package/MDAnalysis/coordinates/GSD.py b/package/MDAnalysis/coordinates/GSD.py index e5968e32d34..50cbd9334d0 100644 --- a/package/MDAnalysis/coordinates/GSD.py +++ b/package/MDAnalysis/coordinates/GSD.py @@ -54,7 +54,7 @@ from . import base -class GSDReader(base.ReaderBase): +class GSDReader(base.ReaderBase, base._ExAsciiPickle): """Reader for the GSD format. 
""" @@ -76,23 +76,27 @@ def __init__(self, filename, **kwargs): super(GSDReader, self).__init__(filename, **kwargs) self.filename = filename self.open_trajectory() - self.n_atoms = self._file[0].particles.N + self.n_atoms = self._f[0].particles.N self.ts = self._Timestep(self.n_atoms, **self._ts_kwargs) self._read_next_timestep() - def open_trajectory(self) : + def open_trajectory(self): """opens the trajectory file using gsd.hoomd module""" self._frame = -1 - self._file = gsd.hoomd.open(self.filename,mode='rb') + self._f = gsd.hoomd.open(self.filename,mode='rb') + + def open_trajectory_for_pickle(self): + """opens the trajectory file while not reset frame""" + self._f = gsd.hoomd.open(self.filename, mode='rb') def close(self): """close reader""" - self._file.file.close() + self._f.file.close() @property def n_frames(self): """number of frames in trajectory""" - return len(self._file) + return len(self._f) def _reopen(self): """reopen trajectory""" @@ -101,7 +105,7 @@ def _reopen(self): def _read_frame(self, frame): try : - myframe = self._file[frame] + myframe = self._f[frame] except IndexError: raise_from(IOError, None) @@ -131,3 +135,4 @@ def _read_frame(self, frame): def _read_next_timestep(self) : """read next frame in trajectory""" return self._read_frame(self._frame + 1) + diff --git a/package/MDAnalysis/coordinates/LAMMPS.py b/package/MDAnalysis/coordinates/LAMMPS.py index 7dd70a26a94..cbea8e6e44c 100644 --- a/package/MDAnalysis/coordinates/LAMMPS.py +++ b/package/MDAnalysis/coordinates/LAMMPS.py @@ -454,7 +454,7 @@ def write(self, selection, frame=None): self._write_velocities(atoms) -class DumpReader(base.ReaderBase): +class DumpReader(base.ReaderBase, base._BAsciiPickle): """Reads the default `LAMMPS dump format`_ Expects trajectories produced by the default 'atom' style dump. 
@@ -478,7 +478,7 @@ def __init__(self, filename, **kwargs): def _reopen(self): self.close() - self._file = util.anyopen(self.filename) + self._f = util.anyopen(self.filename, 'rb') self.ts = self._Timestep(self.n_atoms, **self._ts_kwargs) self.ts.frame = -1 @@ -510,17 +510,17 @@ def n_frames(self): return len(self._offsets) def close(self): - if hasattr(self, '_file'): - self._file.close() + if hasattr(self, '_f'): + self._f.close() def _read_frame(self, frame): - self._file.seek(self._offsets[frame]) + self._f.seek(self._offsets[frame]) self.ts.frame = frame - 1 # gets +1'd in next return self._read_next_timestep() def _read_next_timestep(self): - f = self._file + f = self._f ts = self.ts ts.frame += 1 if ts.frame >= len(self): diff --git a/package/MDAnalysis/coordinates/PDB.py b/package/MDAnalysis/coordinates/PDB.py index 649bde13808..a7c8dceda7f 100644 --- a/package/MDAnalysis/coordinates/PDB.py +++ b/package/MDAnalysis/coordinates/PDB.py @@ -165,8 +165,7 @@ # Pairs of residue name / atom name in use to deduce PDB formatted atom names Pair = collections.namedtuple('Atom', 'resname name') - -class PDBReader(base.ReaderBase): +class PDBReader(base.ReaderBase, base._BAsciiPickle): """PDBReader that reads a `PDB-formatted`_ file, no frills. 
The following *PDB records* are parsed (see `PDB coordinate section`_ for @@ -292,7 +291,7 @@ def __init__(self, filename, **kwargs): if isinstance(filename, util.NamedStream) and isinstance(filename.stream, StringIO): filename.stream = BytesIO(filename.stream.getvalue().encode()) - pdbfile = self._pdbfile = util.anyopen(filename, 'rb') + pdbfile = self._f = util.anyopen(filename, 'rb') line = "magical" while line: @@ -360,7 +359,7 @@ def _reopen(self): # Pretend the current TS is -1 (in 0 based) so "next" is the # 0th frame self.close() - self._pdbfile = util.anyopen(self.filename, 'rb') + self._f = util.anyopen(self.filename, 'rb') self.ts.frame = -1 def _read_next_timestep(self, ts=None): @@ -400,8 +399,8 @@ def _read_frame(self, frame): occupancy = np.ones(self.n_atoms) # Seek to start and read until start of next frame - self._pdbfile.seek(start) - chunk = self._pdbfile.read(stop - start).decode() + self._f.seek(start) + chunk = self._f.read(stop - start).decode() tmp_buf = [] for line in chunk.splitlines(): @@ -459,7 +458,7 @@ def _read_frame(self, frame): return self.ts def close(self): - self._pdbfile.close() + self._f.close() class PDBWriter(base.WriterBase): diff --git a/package/MDAnalysis/coordinates/TRJ.py b/package/MDAnalysis/coordinates/TRJ.py index d87a4ca1824..a9009ed2c92 100644 --- a/package/MDAnalysis/coordinates/TRJ.py +++ b/package/MDAnalysis/coordinates/TRJ.py @@ -186,7 +186,7 @@ class Timestep(base.Timestep): order = 'C' -class TRJReader(base.ReaderBase): +class TRJReader(base.ReaderBase, base._AsciiPickle): """AMBER trajectory reader. Reads the ASCII formatted `AMBER TRJ format`_. Periodic box information @@ -218,7 +218,7 @@ def __init__(self, filename, n_atoms=None, **kwargs): self._n_atoms = n_atoms self._n_frames = None - self.trjfile = None # have _read_next_timestep() open it properly! + self._f = None # have _read_next_timestep() open it properly! 
self.ts = self._Timestep(self.n_atoms, **self._ts_kwargs) # FORMAT(10F8.3) (X(i), Y(i), Z(i), i=1,NATOM) @@ -243,22 +243,22 @@ def __init__(self, filename, n_atoms=None, **kwargs): self._read_next_timestep() def _read_frame(self, frame): - if self.trjfile is None: + if self._f is None: self.open_trajectory() - self.trjfile.seek(self._offsets[frame]) + self._f.seek(self._offsets[frame]) self.ts.frame = frame - 1 # gets +1'd in _read_next return self._read_next_timestep() def _read_next_timestep(self): # FORMAT(10F8.3) (X(i), Y(i), Z(i), i=1,NATOM) ts = self.ts - if self.trjfile is None: + if self._f is None: self.open_trajectory() # Read coordinat frame: # coordinates = numpy.zeros(3*self.n_atoms, dtype=np.float32) _coords = [] - for number, line in enumerate(self.trjfile): + for number, line in enumerate(self._f): try: _coords.extend(self.default_line_parser.read(line)) except ValueError: @@ -273,7 +273,7 @@ def _read_next_timestep(self): # Read box information if self.periodic: - line = next(self.trjfile) + line = next(self._f) box = self.box_line_parser.read(line) ts._unitcell[:3] = np.array(box, dtype=np.float32) ts._unitcell[3:] = [90., 90., 90.] # assumed @@ -320,7 +320,7 @@ def _detect_amber_box(self): self._read_next_timestep() ts = self.ts # TODO: what do we do with 1-frame trajectories? Try..except EOFError? 
- line = next(self.trjfile) + line = next(self._f) nentries = self.default_line_parser.number_of_matches(line) if nentries == 3: self.periodic = True @@ -371,8 +371,8 @@ def _reopen(self): def open_trajectory(self): """Open the trajectory for reading and load first frame.""" - self.trjfile = util.anyopen(self.filename) - self.header = self.trjfile.readline() # ignore first line + self._f = util.anyopen(self.filename) + self.header = self._f.readline() # ignore first line if len(self.header.rstrip()) > 80: # Chimera uses this check raise OSError( @@ -382,14 +382,14 @@ def open_trajectory(self): ts = self.ts ts.frame = -1 - return self.trjfile + return self._f def close(self): """Close trj trajectory file if it was open.""" - if self.trjfile is None: + if self._f is None: return - self.trjfile.close() - self.trjfile = None + self._f.close() + self._f = None class NCDFReader(base.ReaderBase): @@ -473,12 +473,11 @@ def __init__(self, filename, n_atoms=None, mmap=None, **kwargs): super(NCDFReader, self).__init__(filename, **kwargs) - self.trjfile = scipy.io.netcdf.netcdf_file(self.filename, - mmap=self._mmap) + self._f = scipy.io.netcdf.netcdf_file(self.filename, + mmap=self._mmap) - # AMBER NetCDF files should always have a convention try: - conventions = self.trjfile.Conventions + conventions = self._f.Conventions if not ('AMBER' in conventions.decode('utf-8').split(',') or 'AMBER' in conventions.decode('utf-8').split()): errmsg = ("NCDF trajectory {0} does not conform to AMBER " @@ -496,7 +495,7 @@ def __init__(self, filename, n_atoms=None, mmap=None, **kwargs): # AMBER NetCDF files should also have a ConventionVersion try: - ConventionVersion = self.trjfile.ConventionVersion.decode('utf-8') + ConventionVersion = self._f.ConventionVersion.decode('utf-8') if not ConventionVersion == self.version: wmsg = ("NCDF trajectory format is {0!s} but the reader " "implements format {1!s}".format( @@ -509,7 +508,7 @@ def __init__(self, filename, n_atoms=None, mmap=None, 
**kwargs): raise_from(ValueError(errmsg), None) # The AMBER NetCDF standard enforces 64 bit offsets - if not self.trjfile.version_byte == 2: + if not self._f.version_byte == 2: errmsg = ("NCDF trajectory {0} does not conform to AMBER " "specifications, as detailed in " "https://ambermd.org/netcdf/nctraj.xhtml " @@ -520,7 +519,7 @@ def __init__(self, filename, n_atoms=None, mmap=None, **kwargs): # The AMBER NetCDF standard enforces 3D coordinates try: - if not self.trjfile.dimensions['spatial'] == 3: + if not self._f.dimensions['spatial'] == 3: errmsg = "Incorrect spatial value for NCDF trajectory file" raise TypeError(errmsg) except KeyError: @@ -529,8 +528,8 @@ def __init__(self, filename, n_atoms=None, mmap=None, **kwargs): # AMBER NetCDF specs require program and programVersion. Warn users # if those attributes do not exist - if not (hasattr(self.trjfile, 'program') and - hasattr(self.trjfile, 'programVersion')): + if not (hasattr(self._f, 'program') and + hasattr(self._f, 'programVersion')): wmsg = ("NCDF trajectory {0} may not fully adhere to AMBER " "standards as either the `program` or `programVersion` " "attributes are missing".format(self.filename)) @@ -538,7 +537,7 @@ def __init__(self, filename, n_atoms=None, mmap=None, **kwargs): logger.warning(wmsg) try: - self.n_atoms = self.trjfile.dimensions['atom'] + self.n_atoms = self._f.dimensions['atom'] if n_atoms is not None and n_atoms != self.n_atoms: errmsg = ("Supplied n_atoms ({0}) != natom from ncdf ({1}). 
" "Note: n_atoms can be None and then the ncdf value " @@ -546,18 +545,19 @@ def __init__(self, filename, n_atoms=None, mmap=None, **kwargs): raise ValueError(errmsg) except KeyError: errmsg = ("NCDF trajectory {0} does not contain atom " - "information".format(self.filename)) + "information".format( + self._f.ConventionVersion, self.version)) raise_from(ValueError(errmsg), None) try: - self.n_frames = self.trjfile.dimensions['frame'] + self.n_frames = self._f.dimensions['frame'] # example trajectory when read with scipy.io.netcdf has # dimensions['frame'] == None (indicating a record dimension that # can grow) whereas if read with netCDF4 I get # len(dimensions['frame']) == 10: in any case, we need to get # the number of frames from somewhere such as the time variable: if self.n_frames is None: - self.n_frames = self.trjfile.variables['time'].shape[0] + self.n_frames = self._f.variables['time'].shape[0] except KeyError: raise_from( ValueError( @@ -567,7 +567,7 @@ def __init__(self, filename, n_atoms=None, mmap=None, **kwargs): None) try: - self.remarks = self.trjfile.title + self.remarks = self._f.title except AttributeError: self.remarks = "" # other metadata (*= requd): @@ -576,8 +576,8 @@ def __init__(self, filename, n_atoms=None, mmap=None, **kwargs): # checks for not-implemented features (other units would need to be # hacked into MDAnalysis.units) - self._verify_units(self.trjfile.variables['time'].units, 'picosecond') - self._verify_units(self.trjfile.variables['coordinates'].units, + self._verify_units(self._f.variables['time'].units, 'picosecond') + self._verify_units(self._f.variables['coordinates'].units, 'angstrom') # Check for scale_factor attributes for all data variables and @@ -589,32 +589,32 @@ def __init__(self, filename, n_atoms=None, mmap=None, **kwargs): 'velocities': 1.0, 'forces': 1.0} - for variable in self.trjfile.variables: - if hasattr(self.trjfile.variables[variable], 'scale_factor'): + for variable in self._f.variables: + if 
hasattr(self._f.variables[variable], 'scale_factor'): if variable in self.scale_factors: - scale_factor = self.trjfile.variables[variable].scale_factor + scale_factor = self._f.variables[variable].scale_factor self.scale_factors[variable] = scale_factor else: errmsg = ("scale_factors for variable {0} are " "not implemented".format(variable)) raise NotImplementedError(errmsg) - self.has_velocities = 'velocities' in self.trjfile.variables + self.has_velocities = 'velocities' in self._f.variables if self.has_velocities: - self._verify_units(self.trjfile.variables['velocities'].units, + self._verify_units(self._f.variables['velocities'].units, 'angstrom/picosecond') - self.has_forces = 'forces' in self.trjfile.variables + self.has_forces = 'forces' in self._f.variables if self.has_forces: - self._verify_units(self.trjfile.variables['forces'].units, + self._verify_units(self._f.variables['forces'].units, 'kilocalorie/mole/angstrom') - self.periodic = 'cell_lengths' in self.trjfile.variables + self.periodic = 'cell_lengths' in self._f.variables if self.periodic: - self._verify_units(self.trjfile.variables['cell_lengths'].units, + self._verify_units(self._f.variables['cell_lengths'].units, 'angstrom') # As of v1.0.0 only `degree` is accepted as a unit - cell_angle_units = self.trjfile.variables['cell_angles'].units + cell_angle_units = self._f.variables['cell_angles'].units self._verify_units(cell_angle_units, 'degree') self._current_frame = 0 @@ -628,6 +628,30 @@ def __init__(self, filename, n_atoms=None, mmap=None, **kwargs): # load first data frame self._read_frame(0) + def __getstate__(self): + state = self.__dict__.copy() + del state['_f'] + + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._f = scipy.io.netcdf.netcdf_file(self.filename, + mmap=self._mmap) + +# @property +# def n_frames(self): +# n_frames = self._f.dimensions['frame'] +# # example trajectory when read with scipy.io.netcdf has +# # dimensions['frame'] == None 
(indicating a record dimension that can +# # grow) whereas if read with netCDF4 I get len(dimensions['frame']) == +# # 10: in any case, we need to get the number of frames from somewhere +# # such as the time variable: +# if n_frames is None: +# n_frames = self._f.variables['time'].shape[0] +# +# return n_frames + @staticmethod def _verify_units(eval_unit, expected_units): if eval_unit.decode('utf-8') != expected_units: @@ -645,7 +669,7 @@ def parse_n_atoms(filename, **kwargs): def _read_frame(self, frame): ts = self.ts - if self.trjfile is None: + if self._f is None: raise IOError("Trajectory is closed") if np.dtype(type(frame)) != np.dtype(int): # convention... for netcdf could also be a slice @@ -653,21 +677,21 @@ def _read_frame(self, frame): if frame >= self.n_frames or frame < 0: raise IndexError("frame index must be 0 <= frame < {0}".format( self.n_frames)) - # note: self.trjfile.variables['coordinates'].shape == (frames, n_atoms, 3) - ts._pos[:] = (self.trjfile.variables['coordinates'][frame] * + # note: self._f.variables['coordinates'].shape == (frames, n_atoms, 3) + ts._pos[:] = (self._f.variables['coordinates'][frame] * self.scale_factors['coordinates']) - ts.time = (self.trjfile.variables['time'][frame] * + ts.time = (self._f.variables['time'][frame] * self.scale_factors['time']) if self.has_velocities: - ts._velocities[:] = (self.trjfile.variables['velocities'][frame] * + ts._velocities[:] = (self._f.variables['velocities'][frame] * self.scale_factors['velocities']) if self.has_forces: - ts._forces[:] = (self.trjfile.variables['forces'][frame] * + ts._forces[:] = (self._f.variables['forces'][frame] * self.scale_factors['forces']) if self.periodic: - ts._unitcell[:3] = (self.trjfile.variables['cell_lengths'][frame] * + ts._unitcell[:3] = (self._f.variables['cell_lengths'][frame] * self.scale_factors['cell_lengths']) - ts._unitcell[3:] = (self.trjfile.variables['cell_angles'][frame] * + ts._unitcell[3:] = (self._f.variables['cell_angles'][frame] * 
self.scale_factors['cell_angles']) if self.convert_units: self.convert_pos_from_native(ts._pos) # in-place ! @@ -697,8 +721,8 @@ def _read_next_timestep(self, ts=None): raise_from(IOError, None) def _get_dt(self): - t1 = self.trjfile.variables['time'][1] - t0 = self.trjfile.variables['time'][0] + t1 = self._f.variables['time'][1] + t0 = self._f.variables['time'][0] return t1 - t0 def close(self): @@ -710,9 +734,9 @@ def close(self): before the file can be closed. """ - if self.trjfile is not None: - self.trjfile.close() - self.trjfile = None + if self._f is not None: + self._f.close() + self._f = None def Writer(self, filename, **kwargs): """Returns a NCDFWriter for `filename` with the same parameters as this NCDF. @@ -870,7 +894,7 @@ def __init__(self, self.remarks = remarks or "AMBER NetCDF format (MDAnalysis.coordinates.trj.NCDFWriter)" self._first_frame = True # signals to open trajectory - self.trjfile = None # open on first write with _init_netcdf() + self._f = None # open on first write with _init_netcdf() self.periodic = None # detect on first write self.has_velocities = kwargs.get('velocities', False) self.has_forces = kwargs.get('forces', False) @@ -969,7 +993,7 @@ def _init_netcdf(self, periodic=True): ncfile.sync() self._first_frame = False - self.trjfile = ncfile + self._f = ncfile def is_periodic(self, ts): """Test if timestep ``ts`` contains a periodic box. 
@@ -1018,7 +1042,7 @@ def _write_next_frame(self, ag): raise IOError( "NCDFWriter: Timestep does not have the correct number of atoms") - if self.trjfile is None: + if self._f is None: # first time step: analyze data and open trajectory accordingly self._init_netcdf(periodic=self.is_periodic(ts)) @@ -1054,12 +1078,12 @@ def _write_next_timestep(self, ts): unitcell = self.convert_dimensions_to_unitcell(ts) # write step - self.trjfile.variables['coordinates'][self.curr_frame, :, :] = pos - self.trjfile.variables['time'][self.curr_frame] = time + self._f.variables['coordinates'][self.curr_frame, :, :] = pos + self._f.variables['time'][self.curr_frame] = time if self.periodic: - self.trjfile.variables['cell_lengths'][ + self._f.variables['cell_lengths'][ self.curr_frame, :] = unitcell[:3] - self.trjfile.variables['cell_angles'][ + self._f.variables['cell_angles'][ self.curr_frame, :] = unitcell[3:] if self.has_velocities: @@ -1067,19 +1091,19 @@ def _write_next_timestep(self, ts): if self.convert_units: velocities = self.convert_velocities_to_native( velocities, inplace=False) - self.trjfile.variables['velocities'][self.curr_frame, :, :] = velocities + self._f.variables['velocities'][self.curr_frame, :, :] = velocities if self.has_forces: forces = ts._forces if self.convert_units: forces = self.convert_forces_to_native( forces, inplace=False) - self.trjfile.variables['forces'][self.curr_frame, :, :] = forces + self._f.variables['forces'][self.curr_frame, :, :] = forces - self.trjfile.sync() + self._f.sync() self.curr_frame += 1 def close(self): - if self.trjfile is not None: - self.trjfile.close() - self.trjfile = None + if self._f is not None: + self._f.close() + self._f = None diff --git a/package/MDAnalysis/coordinates/TRZ.py b/package/MDAnalysis/coordinates/TRZ.py index 6598f077b1a..9b45bf5efb7 100644 --- a/package/MDAnalysis/coordinates/TRZ.py +++ b/package/MDAnalysis/coordinates/TRZ.py @@ -125,7 +125,7 @@ def dimensions(self, box): self._unitcell[:] = 
triclinic_vectors(box).reshape(9) -class TRZReader(base.ReaderBase): +class TRZReader(base.ReaderBase, base._BAsciiPickle): """Reads an IBIsCO or YASP trajectory file Attributes @@ -169,7 +169,7 @@ def __init__(self, trzfilename, n_atoms=None, **kwargs): if n_atoms is None: raise ValueError('TRZReader requires the n_atoms keyword') - self.trzfile = util.anyopen(self.filename, 'rb') + self._f = util.anyopen(self.filename, 'rb') self._cache = dict() self._n_atoms = n_atoms @@ -233,7 +233,7 @@ def _read_trz_header(self): ('p2', '<2i4'), ('force', '".format( n_atoms=len(self.atoms)) - def __getstate__(self): - raise NotImplementedError + @classmethod + def _unpickle_U(cls, top, traj, anchor): + """Special method used by __reduce__ to deserialise a Universe""" + # top is a Topology object at this point, but Universe can handle that + u = cls(top) + u.anchor_name = anchor + # maybe this is None, but that's still cool + u.trajectory = traj + + return u - def __setstate__(self, state): - raise NotImplementedError + def __reduce__(self): + # Can't quite use __setstate__/__getstate__ so go via __reduce__ + # Universe's two "legs" of topology and traj both serialise themselves + # the only other state held in Universe is anchor name? 
+ return (self._unpickle_U, (self._topology, self._trajectory, self.anchor_name)) # Properties @property diff --git a/testsuite/MDAnalysisTests/coordinates/test_netcdf.py b/testsuite/MDAnalysisTests/coordinates/test_netcdf.py index 81409b82d00..3e4a3ecf37f 100644 --- a/testsuite/MDAnalysisTests/coordinates/test_netcdf.py +++ b/testsuite/MDAnalysisTests/coordinates/test_netcdf.py @@ -57,7 +57,7 @@ def test_slice_iteration(self, universe): err_msg="slicing did not produce the expected frames") def test_metadata(self, universe): - data = universe.trajectory.trjfile + data = universe.trajectory._f assert_equal(data.Conventions.decode('utf-8'), 'AMBER') assert_equal(data.ConventionVersion.decode('utf-8'), '1.0') @@ -687,8 +687,8 @@ def _check_new_traj(self, universe, outfile): self.prec, err_msg="unitcells are not identical") # check that the NCDF data structures are the same - nc_orig = universe.trajectory.trjfile - nc_copy = uw.trajectory.trjfile + nc_orig = universe.trajectory._f + nc_copy = uw.trajectory._f # note that here 'dimensions' is a specific netcdf data structure and # not the unit cell dimensions in MDAnalysis diff --git a/testsuite/MDAnalysisTests/core/test_universe.py b/testsuite/MDAnalysisTests/core/test_universe.py index 5f8d73c0264..f2eea1a08e8 100644 --- a/testsuite/MDAnalysisTests/core/test_universe.py +++ b/testsuite/MDAnalysisTests/core/test_universe.py @@ -27,7 +27,7 @@ import os import subprocess - +import sys try: from cStringIO import StringIO except: @@ -272,10 +272,13 @@ def test_load_multiple_args(self): assert_equal(len(u.atoms), 3341, "Loading universe failed somehow") assert_equal(u.trajectory.n_frames, 2 * ref.trajectory.n_frames) - def test_pickle_raises_NotImplementedError(self): + @pytest.mark.xfail(sys.version_info < (3, 0), reason="pickle function not \ + working in python 2") + def test_pickle(self): u = mda.Universe(PSF, DCD) - with pytest.raises(NotImplementedError): - cPickle.dumps(u, protocol = cPickle.HIGHEST_PROTOCOL) + s = 
cPickle.dumps(u, protocol = cPickle.HIGHEST_PROTOCOL) + new_u = cPickle.loads(s) + assert_equal(u.atoms.names, new_u.atoms.names) @pytest.mark.parametrize('dtype', (int, np.float32, np.float64)) def test_set_dimensions(self, dtype): diff --git a/testsuite/MDAnalysisTests/parallelism/__init__.py b/testsuite/MDAnalysisTests/parallelism/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/testsuite/MDAnalysisTests/parallelism/test_multiprocessing.py b/testsuite/MDAnalysisTests/parallelism/test_multiprocessing.py new file mode 100644 index 00000000000..6b7fb117843 --- /dev/null +++ b/testsuite/MDAnalysisTests/parallelism/test_multiprocessing.py @@ -0,0 +1,163 @@ +# -*- Mode: python; tab-width: 4; indent-tabs-mode:nil; coding:utf-8 -*- +# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4 fileencoding=utf-8 +# +# MDAnalysis --- https://www.mdanalysis.org +# Copyright (c) 2006-2017 The MDAnalysis Development Team and contributors +# (see the file AUTHORS for the full list of names) +# +# Released under the GNU Public Licence, v2 or any higher version +# +# Please cite your use of MDAnalysis in published work: +# +# R. J. Gowers, M. Linke, J. Barnoud, T. J. E. Reddy, M. N. Melo, S. L. Seyler, +# D. L. Dotson, J. Domanski, S. Buchoux, I. M. Kenney, and O. Beckstein. +# MDAnalysis: A Python package for the rapid analysis of molecular dynamics +# simulations. In S. Benthall and S. Rostrup editors, Proceedings of the 15th +# Python in Science Conference, pages 102-109, Austin, TX, 2016. SciPy. +# doi: 10.25080/majora-629e541a-00e +# +# N. Michaud-Agrawal, E. J. Denning, T. B. Woolf, and O. Beckstein. +# MDAnalysis: A Toolkit for the Analysis of Molecular Dynamics Simulations. +# J. Comput. Chem. 
32 (2011), 2319--2327, doi:10.1002/jcc.21787 +# + +from __future__ import absolute_import +import sys +import multiprocessing +import numpy as np +import pytest +import pickle + +import MDAnalysis as mda +from MDAnalysis.coordinates.core import get_reader_for +from MDAnalysisTests.datafiles import ( + AUX_XVG, + CRD, + PSF, DCD, + DMS, + DLP_CONFIG, + DLP_HISTORY, + INPCRD, + GMS_ASYMOPT, + GRO, + GSD, + LAMMPSdata_mini, + LAMMPSDUMP, + mol2_molecules, + MMTF, + NCDF, + PDB_small, PDB_multiframe, + PDBQT_input, + PQR, + TRR, + TRJ, + TRZ, + TXYZ, + XTC, + XPDB_small, + XYZ_mini, XYZ, +) + +from numpy.testing import assert_equal + + +@pytest.fixture(params=[ + (PSF, DCD), + (GRO, XTC), + (PDB_multiframe,), + (XYZ,), +]) +def u(request): + if len(request.param) == 1: + f = request.param[0] + return mda.Universe(f) + else: + top, trj = request.param + return mda.Universe(top, trj) + +# Define target functions here +# inside test functions doesn't work +def cog(u, ag, frame_id): + u.trajectory[frame_id] + + return ag.center_of_geometry() + + +def getnames(u, ix): + # Check topology stuff works + return u.atoms[ix].name + + +@pytest.mark.xfail(sys.version_info < (3, 0), reason="pickle function not \ + working in python 2") +def test_multiprocess_COG(u): + ag = u.atoms[10:20] + + ref = np.array([cog(u, ag, i) + for i in range(4)]) + + p = multiprocessing.Pool(2) + res = np.array([p.apply(cog, args=(u, ag, i)) + for i in range(4)]) + p.close() + assert_equal(ref, res) + + +@pytest.mark.xfail(sys.version_info <= (3, 0), reason="pickle function not \ + working in python 2") +def test_multiprocess_names(u): + ref = [getnames(u, i) + for i in range(10)] + + p = multiprocessing.Pool(2) + res = [p.apply(getnames, args=(u, i)) + for i in range(10)] + p.close() + + assert_equal(ref, res) + +@pytest.fixture(params=[ + # formatname, filename + ('CRD', CRD, dict()), + ('DATA', LAMMPSdata_mini, dict(n_atoms=1)), + ('DCD', DCD, dict()), + ('DMS', DMS, dict()), + ('CONFIG', DLP_CONFIG, 
dict()), + ('HISTORY', DLP_HISTORY, dict()), + ('INPCRD', INPCRD, dict()), + ('LAMMPSDUMP', LAMMPSDUMP, dict()), + ('GMS', GMS_ASYMOPT, dict()), + ('GRO', GRO, dict()), + ('GSD', GSD, dict()), + ('MMTF', MMTF, dict()), + ('MOL2', mol2_molecules, dict()), + ('PDB', PDB_small, dict()), + ('PQR', PQR, dict()), + ('PDBQT', PDBQT_input, dict()), + ('TRR', TRR, dict()), + ('TRZ', TRZ, dict(n_atoms=8184)), + ('TRJ', TRJ, dict(n_atoms=252)), + ('XTC', XTC, dict()), + ('XPDB', XPDB_small, dict()), + ('XYZ', XYZ_mini, dict()), + ('NCDF', NCDF, dict()), + ('TXYZ', TXYZ, dict()), + ('memory', np.arange(60).reshape(2, 10, 3).astype(np.float64), dict()), + ('CHAIN', [GRO, GRO, GRO], dict()), +]) +def ref_reader(request): + fmt_name, filename, extras = request.param + + r = get_reader_for(filename, format=fmt_name)(filename, **extras) + try: + yield r + finally: + # make sure file handle is closed afterwards + r.close() + +def test_readers_pickle(ref_reader): + ps = pickle.dumps(ref_reader) + + reanimated = pickle.loads(ps) + + assert len(ref_reader) == len(reanimated)