Skip to content

Commit

Permalink
Merge pull request #40 from moldyn/issue_39
Browse files Browse the repository at this point in the history
Fix Issue 39
  • Loading branch information
braniii authored Nov 3, 2023
2 parents 9f751f8 + 861b7a8 commit ed2fafb
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 6 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added Features and Improvements 🙌:
- Supporting latest Python 3.11 🎉

### Bugfix 🐛:
- Fixed wrong connectivity in `mh.datasets.utils.propagate_MCMC` for different datatypes, #39


## [1.0.4] - 2023-05-22
### Other changes:
Expand Down
8 changes: 4 additions & 4 deletions src/msmhelper/msm/timescales.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,15 +321,15 @@ def _propagate_MCMC_step(cummat, idx_from):
"""Propagate a single step Markov chain Monte Carlo."""
rand = random.random() # noqa: S311
cummat_perm, state_perm = cummat
cummat_perm, state_perm = cummat_perm[idx_from], state_perm[idx_from]
cummat_from, state_from = cummat_perm[idx_from], state_perm[idx_from]

for idx, cummat_idx in enumerate(cummat_perm):
for idx, cummat_idx in enumerate(cummat_from):
# strict less to ensure that rand=0 does not jump along unconnected
# states with Tij=0.
if rand < cummat_idx:
return state_perm[idx]
return state_from[idx]
# this should never be reached, but needed for numba to ensure int return
return len(cummat_perm) - 1
return state_from[np.argmax(cummat_from)] # pragma: no cover


@numba.njit
Expand Down
7 changes: 5 additions & 2 deletions src/msmhelper/utils/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import numpy as np

from msmhelper.msm.msm import row_normalize_matrix
from msmhelper.msm.timescales import _propagate_MCMC
from .tests import is_transition_matrix
from ._utils import shift_data
Expand Down Expand Up @@ -47,8 +48,10 @@ def propagate_tmat(tmat, nsteps, start=None):
raise ValueError('tmat needs to be a row-normalized matrix.')

n_states = len(tmat)
cummat = np.cumsum(tmat, axis=1)
cummat[:, -1] = 1 # enforce exact normalization
cummat = np.cumsum( # enforce exact normalization
row_normalize_matrix(tmat),
axis=1,
)
cummat_perm = np.tile(np.arange(n_states), (n_states, 1))

if start is None:
Expand Down
50 changes: 50 additions & 0 deletions test/test_utils_datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
"""Tests for the filtering submodule.
BSD 3-Clause License
Copyright (c) 2019-2023, Daniel Nagel
All rights reserved.
"""
import numpy as np
import pytest
from msmhelper import msm
from msmhelper.utils import datasets

NSTEPS = 1000000
DECIMAL = 2


def test_nagel20_4state():
"""Test nagel20_4state mcmc."""
traj_mcmc = datasets.nagel20_4state(NSTEPS)

np.testing.assert_array_almost_equal(
datasets.nagel20_4state.tmat,
msm.estimate_markov_model(trajs=traj_mcmc, lagtime=1)[0],
decimal=DECIMAL,
)


def test_nagel20_6state():
"""Test nagel20_6state mcmc."""
traj_mcmc = datasets.nagel20_6state(NSTEPS)

np.testing.assert_array_almost_equal(
datasets.nagel20_6state.tmat,
msm.estimate_markov_model(trajs=traj_mcmc, lagtime=1)[0],
decimal=DECIMAL,
)


@pytest.mark.parametrize('nstates', [2, 2, 3, 4, 5])
def test_propagate_tmat(nstates):
tmat = np.random.uniform(size=(nstates, nstates))
tmat = msm.row_normalize_matrix(tmat)

traj_mcmc = datasets.propagate_tmat(tmat, NSTEPS)
np.testing.assert_array_almost_equal(
tmat,
msm.estimate_markov_model(trajs=traj_mcmc, lagtime=1)[0],
decimal=DECIMAL,
)

0 comments on commit ed2fafb

Please sign in to comment.