Skip to content

Commit

Permalink
more PSA tests
Browse files Browse the repository at this point in the history
- additional tests
- pylint fix in PSA code
- tests for load() still xfail
  • Loading branch information
orbeckst committed Apr 10, 2019
1 parent e405413 commit ba55155
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 12 deletions.
35 changes: 23 additions & 12 deletions package/MDAnalysis/analysis/psa.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@
from __future__ import division, absolute_import, print_function

import six
from six.moves import range, cPickle
from six.moves import range, cPickle, zip
from six import string_types

import os
Expand Down Expand Up @@ -1353,10 +1353,10 @@ def __init__(self, universes, reference=None, ref_select='name CA',

# Set default directory names for storing topology/reference structures,
# fitted trajectories, paths, distance matrices, and plots
self.datadirs = {'fitted_trajs' : '/fitted_trajs',
'paths' : '/paths',
'distance_matrices' : '/distance_matrices',
'plots' : '/plots'}
self.datadirs = {'fitted_trajs' : 'fitted_trajs',
'paths' : 'paths',
'distance_matrices' : 'distance_matrices',
'plots' : 'plots'}
for dir_name, directory in six.iteritems(self.datadirs):
try:
full_dir_name = os.path.join(self.targetdir, dir_name)
Expand Down Expand Up @@ -1465,7 +1465,7 @@ def generate_paths(self, align=False, filename='fitted', infix='', weights=None,
for i, u in enumerate(self.universes):
p = Path(u, self.u_reference, ref_select=self.ref_select,
path_select=self.path_select, ref_frame=ref_frame)
trj_dir = self.targetdir + self.datadirs['fitted_trajs']
trj_dir = os.path.join(self.targetdir, self.datadirs['fitted_trajs'])
postfix = '{0}{1}{2:03n}'.format(infix, '_psa', i+1)
top_name, fit_trj_name = p.run(align=align, filename=filename,
postfix=postfix,
Expand Down Expand Up @@ -1635,7 +1635,7 @@ def save_result(self, filename=None):
"""
filename = filename or 'psa_distances'
head = self.targetdir + self.datadirs['distance_matrices']
head = os.path.join(self.targetdir, self.datadirs['distance_matrices'])
outfile = os.path.join(head, filename)
if self.D is None:
raise NoDataError("Distance matrix has not been calculated yet")
Expand All @@ -1661,9 +1661,13 @@ def save_paths(self, filename=None):
-------
filename : str
See Also
--------
load
"""
filename = filename or 'path_psa'
head = self.targetdir + self.datadirs['paths']
head = os.path.join(self.targetdir, self.datadirs['paths'])
outfile = os.path.join(head, filename)
if self.paths is None:
raise NoDataError("Paths have not been calculated yet")
Expand All @@ -1682,6 +1686,13 @@ def save_paths(self, filename=None):
def load(self):
"""Load fitted paths specified by 'psa_path-names.pkl' in
:attr:`PSAnalysis.targetdir`.
All filenames are determined by :class:`PSAnalysis`.
See Also
--------
save_paths
"""
if not os.path.exists(self._paths_pkl):
raise NoDataError("Fitted trajectories cannot be loaded; save file" +
Expand All @@ -1690,7 +1701,7 @@ def load(self):
self.paths = [np.load(pname) for pname in self.path_names]
if os.path.exists(self._labels_pkl):
self.labels = np.load(self._labels_pkl)
print("Loaded paths from " + self._paths_pkl)
logger.info("Loaded paths from %r", self._paths_pkl)


def plot(self, filename=None, linkage='ward', count_sort=False,
Expand Down Expand Up @@ -1796,7 +1807,7 @@ def plot(self, filename=None, linkage='ward', count_sort=False,
tic.tick1On = tic.tick2On = False

if filename is not None:
head = self.targetdir + self.datadirs['plots']
head = os.path.join(self.targetdir, self.datadirs['plots'])
outfile = os.path.join(head, filename)
savefig(outfile, dpi=300, bbox_inches='tight')

Expand Down Expand Up @@ -1907,7 +1918,7 @@ def plot_annotated_heatmap(self, filename=None, linkage='ward', \
tic.tick1On = tic.tick2On = False

if filename is not None:
head = self.targetdir + self.datadirs['plots']
head = os.path.join(self.targetdir, self.datadirs['plots'])
outfile = os.path.join(head, filename)
savefig(outfile, dpi=600, bbox_inches='tight')

Expand Down Expand Up @@ -2011,7 +2022,7 @@ def plot_nearest_neighbors(self, filename=None, idx=0, \
tight_layout()

if filename is not None:
head = self.targetdir + self.datadirs['plots']
head = os.path.join(self.targetdir, self.datadirs['plots'])
outfile = os.path.join(head, filename)
savefig(outfile, dpi=300, bbox_inches='tight')

Expand Down
48 changes: 48 additions & 0 deletions testsuite/MDAnalysisTests/analysis/test_psa.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,54 @@ def test_reversal_frechet(self, frech_matrix):
err_msg = "Frechet distances did not increase after path reversal"
assert frech_matrix[1,2] >= frech_matrix[0,1], err_msg

def test_get_num_paths(self, psa):
assert psa.get_num_paths() == 3

def test_get_paths(self, psa):
paths = psa.get_paths()
assert len(paths) == 3
assert isinstance(paths, list)

def test_psa_pairs_ValueError(self, psa):
with pytest.raises(ValueError):
psa.psa_pairs

def test_psa_pairs(self, psa):
psa.run_pairs_analysis()
assert len(psa.psa_pairs) == 3

def test_hausdorff_pairs_ValueError(self, psa):
with pytest.raises(ValueError):
psa.hausdorff_pairs

def test_hausdorff_pairs(self, psa):
psa.run_pairs_analysis(hausdorff_pairs=True)
assert len(psa.hausdorff_pairs) == 3

def test_nearest_neighbors_ValueError(self, psa):
with pytest.raises(ValueError):
psa.nearest_neighbors

def test_nearest_neighbors(self, psa):
psa.run_pairs_analysis(neighbors=True)
assert len(psa.nearest_neighbors) == 3

@pytest.mark.xfail
def test_load(self, psa):
"""Test that the automatically saved files can be loaded"""
expected_path_names = psa.path_names[:]
expected_paths = [p.copy() for p in psa.paths]
psa.save_paths()
psa.load()
assert psa.path_names == expected_path_names
# manually compare paths because
# assert_almost_equal(psa.paths, expected_paths, decimal=6)
# raises a ValueError in the assertion code itself
assert len(psa.paths) == len(expected_paths)
for ipath, (observed, expected) in enumerate(zip(psa.paths, expected_paths)):
assert_almost_equal(observed, expected, decimal=6,
err_msg="loaded path {} does not agree with input".format(ipath))

def test_dendrogram_produced(self, plot_data):
"""Test whether Dendrogram dictionary object was produced"""
err_msg = "Dendrogram dictionary object was not produced"
Expand Down

0 comments on commit ba55155

Please sign in to comment.