diff --git a/CHANGELOG.md b/CHANGELOG.md index 044d667..5890f01 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added Features and Improvements 🙌: - Supporting latest Python 3.11 🎉 +### Bugfix 🐛: +- Fixed wrong connectivity in `mh.datasets.utils.propagate_MCMC` for different datatypes, #39 + ## [1.0.4] - 2023-05-22 ### Other changes: diff --git a/src/msmhelper/msm/timescales.py b/src/msmhelper/msm/timescales.py index dc3ecef..c0151c6 100644 --- a/src/msmhelper/msm/timescales.py +++ b/src/msmhelper/msm/timescales.py @@ -321,15 +321,15 @@ def _propagate_MCMC_step(cummat, idx_from): """Propagate a single step Markov chain Monte Carlo.""" rand = random.random() # noqa: S311 cummat_perm, state_perm = cummat - cummat_perm, state_perm = cummat_perm[idx_from], state_perm[idx_from] + cummat_from, state_from = cummat_perm[idx_from], state_perm[idx_from] - for idx, cummat_idx in enumerate(cummat_perm): + for idx, cummat_idx in enumerate(cummat_from): # strict less to ensure that rand=0 does not jump along unconnected # states with Tij=0. if rand < cummat_idx: - return state_perm[idx] + return state_from[idx] # this should never be reached, but needed for numba to ensure int return - return len(cummat_perm) - 1 + return state_from[np.argmax(cummat_from)] # pragma: no cover @numba.njit diff --git a/src/msmhelper/utils/datasets.py b/src/msmhelper/utils/datasets.py index 1a8d5de..e92cc2e 100644 --- a/src/msmhelper/utils/datasets.py +++ b/src/msmhelper/utils/datasets.py @@ -7,6 +7,7 @@ import numpy as np +from msmhelper.msm.msm import row_normalize_matrix from msmhelper.msm.timescales import _propagate_MCMC from .tests import is_transition_matrix from ._utils import shift_data @@ -47,8 +48,10 @@ def propagate_tmat(tmat, nsteps, start=None): raise ValueError('tmat needs to be a row-normalized matrix.') n_states = len(tmat) - cummat = np.cumsum(tmat, axis=1) - cummat[:, -1] = 1 # enforce exact normalization + cummat = np.cumsum( # enforce exact normalization + row_normalize_matrix(tmat), + axis=1, + ) cummat_perm = np.tile(np.arange(n_states), (n_states, 1)) if start is None: diff --git a/test/test_utils_datasets.py b/test/test_utils_datasets.py new file mode 100644 index 0000000..fb47045 --- /dev/null +++ b/test/test_utils_datasets.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +"""Tests for the filtering submodule. + +BSD 3-Clause License +Copyright (c) 2019-2023, Daniel Nagel +All rights reserved. + +""" +import numpy as np +import pytest +from msmhelper import msm +from msmhelper.utils import datasets + +NSTEPS = 1000000 +DECIMAL = 2 + + +def test_nagel20_4state(): + """Test nagel20_4state mcmc.""" + traj_mcmc = datasets.nagel20_4state(NSTEPS) + + np.testing.assert_array_almost_equal( + datasets.nagel20_4state.tmat, + msm.estimate_markov_model(trajs=traj_mcmc, lagtime=1)[0], + decimal=DECIMAL, + ) + + +def test_nagel20_6state(): + """Test nagel20_6state mcmc.""" + traj_mcmc = datasets.nagel20_6state(NSTEPS) + + np.testing.assert_array_almost_equal( + datasets.nagel20_6state.tmat, + msm.estimate_markov_model(trajs=traj_mcmc, lagtime=1)[0], + decimal=DECIMAL, + ) + + +@pytest.mark.parametrize('nstates', [2, 2, 3, 4, 5]) +def test_propagate_tmat(nstates): + tmat = np.random.uniform(size=(nstates, nstates)) + tmat = msm.row_normalize_matrix(tmat) + + traj_mcmc = datasets.propagate_tmat(tmat, NSTEPS) + np.testing.assert_array_almost_equal( + tmat, + msm.estimate_markov_model(trajs=traj_mcmc, lagtime=1)[0], + decimal=DECIMAL, + )