Skip to content

Commit

Permalink
Test for immutability.
Browse files Browse the repository at this point in the history
  • Loading branch information
braniii committed Nov 9, 2023
1 parent 7b8722b commit c98cec9
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
22 changes: 12 additions & 10 deletions src/msmhelper/statetraj.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def states(self):
Numpy array holding active set of states.
"""
return self._states
return self._states.copy()

@property
def nstates(self):
Expand Down Expand Up @@ -108,7 +108,7 @@ def ntrajs(self):

@property
def nframes(self):
"""Return cummulative length of all trajectories.
"""Return cumulative length of all trajectories.
Returns
-------
Expand All @@ -131,7 +131,7 @@ def trajs(self):
if np.array_equal(self.states, np.arange(1, self.nstates + 1)):
return [traj + 1 for traj in self._trajs]
if np.array_equal(self.states, np.arange(self.nstates)):
return self._trajs
return self.index_trajs
return mh.shift_data(
self._trajs,
np.arange(self.nstates),
Expand Down Expand Up @@ -160,7 +160,7 @@ def index_trajs(self):
List of ndarrays holding the input data.
"""
return self._trajs
return [traj.copy() for traj in self._trajs]

@property
def index_trajs_flatten(self):
Expand Down Expand Up @@ -338,7 +338,7 @@ def states(self):
Numpy array holding active set of states.
"""
return self._macrostates
return self._macrostates.copy()

@property
def nstates(self):
Expand All @@ -362,8 +362,10 @@ def microstate_trajs(self):
List of ndarrays holding the input data.
"""
if np.array_equal(self.microstates, np.arange(self.nmicrostates)):
return self._trajs
if np.array_equal(self.microstates, np.arange(1, self.nstates + 1)):
return [traj + 1 for traj in self._trajs]
elif np.array_equal(self.microstates, np.arange(self.nmicrostates)):
return self.microstate_index_trajs
return mh.shift_data(
self._trajs,
np.arange(self.nmicrostates),
Expand Down Expand Up @@ -392,7 +394,7 @@ def microstate_index_trajs(self):
List of ndarrays holding the microstate index trajectory.
"""
return self._trajs
return [traj.copy() for traj in self._trajs]

@property
def microstate_index_trajs_flatten(self):
Expand Down Expand Up @@ -448,7 +450,7 @@ def microstates(self):
Numpy array holding active set of states.
"""
return self._states
return self._states.copy()

@property
def nmicrostates(self):
Expand All @@ -472,7 +474,7 @@ def state_assignment(self):
Micro to macrostate assignment vector.
"""
return self._state_assignment
return self._state_assignment.copy()

@property
def _state_assignment_idx(self):
Expand Down
10 changes: 10 additions & 0 deletions test/test_statetraj.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,16 @@ def test_nstates(state_traj, statetraj, macro_traj, macrotraj):
state_traj.nstates = 5


def test_states(state_traj):
"""Test immutability of states property."""
states = state_traj.states
states += 2
assert np.testing.assert_array_almost_equal(states - 2, state_traj.states)

with pytest.raises(AttributeError):
state_traj.states = states


def test_nframes(state_traj):
"""Test nframes property."""
assert state_traj.nframes == len(state_traj[0])
Expand Down

0 comments on commit c98cec9

Please sign in to comment.