Skip to content

Commit

Permalink
Merge pull request #690 from jbarnoud/issue-260-fix-trz
Browse files Browse the repository at this point in the history
TRZ reader and writer are compatible with python 3
  • Loading branch information
richardjgowers committed Feb 1, 2016
2 parents 34ce8fa + 767a9be commit ac8a98c
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 14 deletions.
3 changes: 3 additions & 0 deletions package/CHANGELOG
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ Changes

* xdrlib has been ported to cython. (Issue #441)
* util.NamedStream no longer inherits from basestring (Issue #649)
* Short TRZ titles are striped from trailing spaces. A friendlier error
message is raised when the TRZ writer is asked to write a title longer
than 80 characters. (Issue #689)

Fixes

Expand Down
34 changes: 23 additions & 11 deletions package/MDAnalysis/coordinates/TRZ.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@
.. autoclass:: TRZWriter
:members:
"""

import six
from six.moves import range

from sys import maxint
import sys
import warnings
import numpy as np
import os
Expand Down Expand Up @@ -222,7 +222,7 @@ def _read_trz_header(self):
('force', 'i4'),
('p3', 'i4')])
data = np.fromfile(self.trzfile, dtype=self._headerdtype, count=1)
self.title = ''.join(data['title'][0])
self.title = ''.join(c.decode('utf-8') for c in data['title'][0]).strip()
if data['force'] == 10:
self.has_force = False
elif data['force'] == 20:
Expand Down Expand Up @@ -358,13 +358,21 @@ def _seek(self, nframes):
.. versionadded:: 0.9.0
"""
maxi_l = long(maxint)

framesize = long(self._dtype.itemsize)
seeksize = framesize * nframes
# On python 2, seek has issues with long int. This is solve in python 3
# where there is no longer a distinction between int and long int.
if six.PY2:
framesize = long(self._dtype.itemsize)
seeksize = framesize * nframes
maxi_l = long(sys.maxint)
else:
framesize = self._dtype.itemsize
seeksize = framesize * nframes
maxi_l = seeksize + 1

if seeksize > maxi_l:
# Workaround for seek not liking long ints
# On python 3 this branch will never be used as we defined maxi_l
# greater than seeksize.
framesize = long(framesize)
seeksize = framesize * nframes

Expand Down Expand Up @@ -435,7 +443,8 @@ def __init__(self, filename, n_atoms, title='TRZ', convert_units=None):
:Keywords:
*title*
title of the trajectory
title of the trajectory; the title must be 80 characters or shorter,
a longer title raises a ValueError exception.
*convert_units*
units are converted to the MDAnalysis base format; ``None`` selects
the value of :data:`MDAnalysis.core.flags` ['convert_lengths'].
Expand All @@ -448,6 +457,9 @@ def __init__(self, filename, n_atoms, title='TRZ', convert_units=None):
raise ValueError("TRZWriter: no atoms in output trajectory")
self.n_atoms = n_atoms

if len(title) > 80:
raise ValueError("TRZWriter: 'title' must be 80 characters of shorter")

if convert_units is None:
convert_units = flags['convert_lengths']
self.convert_units = convert_units
Expand Down Expand Up @@ -504,7 +516,7 @@ def _writeheader(self, title):
('pad3', 'i4'), ('nrec', 'i4'), ('pad4', 'i4')])
out = np.zeros((), dtype=hdt)
out['pad1'], out['pad2'] = 80, 80
out['title'] = title
out['title'] = title + ' ' * (80 - len(title))
out['pad3'], out['pad4'] = 4, 4
out['nrec'] = 10
out.tofile(self.trzfile)
Expand All @@ -517,8 +529,8 @@ def write_next_timestep(self, ts):
# Gather data, faking it when unavailable
data = {}
faked_attrs = []
for att in ['pressure', 'pressure_tensor', 'total_energy', 'potential_energy',
'kinetic_energy', 'temperature']:
for att in ['pressure', 'pressure_tensor', 'total_energy',
'potential_energy', 'kinetic_energy', 'temperature']:
try:
data[att] = ts.data[att]
except KeyError:
Expand Down
2 changes: 1 addition & 1 deletion package/MDAnalysis/coordinates/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,13 +694,13 @@
from . import PQR
from . import TRJ
from . import TRR
from . import TRZ
from . import XTC
from . import XYZ

try:
from . import DCD
from . import LAMMPS
from . import TRZ
except ImportError as e:
# The import is expected to fail under Python 3.
# It should not fail on Python 2, however.
Expand Down
2 changes: 2 additions & 0 deletions testsuite/MDAnalysisTests/coordinates/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,8 @@ class RefTRZ(object):
dtype=np.float32)
ref_delta = 0.001
ref_time = 0.01
ref_title = ('ABCDEFGHIJKLMNOPQRSTUVWXYZ12345678901234'
'ABCDEFGHIJKLMNOPQRSTUVWXYZ12345678901234')


class RefLAMMPSData(object):
Expand Down
18 changes: 16 additions & 2 deletions testsuite/MDAnalysisTests/coordinates/test_trz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from six.moves import zip

from numpy.testing import (assert_equal, assert_array_almost_equal,
assert_almost_equal)
assert_almost_equal, assert_raises)
import tempdir
import numpy as np

Expand Down Expand Up @@ -96,6 +96,9 @@ def test_time(self):
assert_almost_equal(self.trz.time, self.ref_time, self.prec,
"wrong time value in trz")

def test_title(self):
assert_equal(self.ref_title, self.trz.title, "wrong title in trz")

def test_get_writer(self):
with tempdir.in_tempdir():
self.outfile = 'test-trz-writer.trz'
Expand Down Expand Up @@ -125,21 +128,24 @@ def setUp(self):
self.prec = 3
self.tmpdir = tempdir.TempDir()
self.outfile = self.tmpdir.name + '/test-trz-writer.trz'
self.outfile_long = self.tmpdir.name + '/test-trz-writer-long.trz'
self.Writer = mda.coordinates.TRZ.TRZWriter
self.title_to_write = 'Test title TRZ'

def tearDown(self):
del self.universe
del self.prec
try:
os.unlink(self.outfile)
os.unlink(self.outfile_long)
except OSError:
pass
del self.Writer
del self.tmpdir

def test_write_trajectory(self):
t = self.universe.trajectory
W = self.Writer(self.outfile, t.n_atoms)
W = self.Writer(self.outfile, t.n_atoms, title=self.title_to_write)
self._copy_traj(W)

def _copy_traj(self, writer):
Expand All @@ -149,6 +155,9 @@ def _copy_traj(self, writer):

uw = mda.Universe(TRZ_psf, self.outfile)

assert_equal(uw.trajectory.title, self.title_to_write,
"Title mismatch between original and written files.")

for orig_ts, written_ts in zip(self.universe.trajectory,
uw.trajectory):
assert_array_almost_equal(orig_ts._pos, written_ts._pos, self.prec,
Expand All @@ -169,6 +178,11 @@ def _copy_traj(self, writer):
written_ts.data[att], self.prec,
err_msg="TS equal failed for {0!s}".format(att))

def test_long_title(self):
title = '*' * 81
assert_raises(ValueError,
self.Writer, self.outfile, self.ref_n_atoms, title=title)


class TestTRZWriter2(object):
def setUp(self):
Expand Down

0 comments on commit ac8a98c

Please sign in to comment.