This repository has been archived by the owner on Sep 11, 2023. It is now read-only.

Remove bhmm as dependency (#1550)
clonker authored Mar 17, 2022
1 parent f6a7a7d commit 6d1f334
Showing 19 changed files with 276 additions and 407 deletions.
72 changes: 0 additions & 72 deletions .circleci/config.yml

This file was deleted.

1 change: 0 additions & 1 deletion README.rst
@@ -85,6 +85,5 @@ Or start a discussion on our mailing list: pyemma-users@lists.fu-berlin.de
External Libraries
------------------
* mdtraj (LGPLv3): https://mdtraj.org
* bhmm (LGPLv3): http://github.com/bhmm/bhmm
* msmtools (LGLPv3): http://github.com/markovmodel/msmtools
* thermotools (LGLPv3): http://github.com/markovmodel/thermotools
32 changes: 1 addition & 31 deletions conftest.py
@@ -38,38 +38,8 @@ def session_fixture():
try:
os.mkdir(tempfile.tempdir)
except OSError as ose:
if 'exists' not in ose.strerror.lower():
if 'exists' not in ose.strerror.lower():
raise
yield
import shutil
shutil.rmtree(tempfile.tempdir, ignore_errors=True)


def pytest_collection_modifyitems(session, config, items):
circle_node_total, circle_node_index = read_circleci_env_variables()
deselected = []
for item in items:
i = hash(item.name)
if i % circle_node_total != circle_node_index:
deselected.append(item)
for item in deselected:
items.remove(item)

config.hook.pytest_deselected(items=deselected)


def read_circleci_env_variables():
"""Read and convert CIRCLE_* environment variables"""
circle_node_total = int(os.environ.get("CIRCLE_NODE_TOTAL", "1").strip() or "1")
circle_node_index = int(os.environ.get("CIRCLE_NODE_INDEX", "0").strip() or "0")

if circle_node_index >= circle_node_total:
raise RuntimeError("CIRCLE_NODE_INDEX={} >= CIRCLE_NODE_TOTAL={}, should be less".format(circle_node_index, circle_node_total))

return circle_node_total, circle_node_index


def pytest_report_header(config):
"""Add CircleCI information to report"""
circle_node_total, circle_node_index = read_circleci_env_variables()
return "CircleCI total nodes: {}, this node index: {}".format(circle_node_total, circle_node_index)
2 changes: 1 addition & 1 deletion devtools/azure-pipelines-win.yml
@@ -2,7 +2,7 @@ jobs:
- job:
displayName: Win
pool:
vmImage: vs2017-win2016
vmImage: "windows-2019"
strategy:
matrix:
Python37:
2 changes: 0 additions & 2 deletions devtools/conda-recipe/meta.yaml
@@ -11,7 +11,6 @@ source:

build:
script_env:
- CIRCLE_TEST_REPORTS
- OMP_NUM_THREADS
- PYEMMA_NJOBS
script: "{{ PYTHON }} -m pip install . --no-deps --ignore-installed --no-cache-dir -vvv"
@@ -39,7 +38,6 @@ requirements:
- deeptime >=0.3.0

run:
- bhmm >=0.6.3
- decorator >=4.0.0
- h5py
- intel-openmp # [osx]
2 changes: 1 addition & 1 deletion devtools/conda-recipe/run_test.py
@@ -18,7 +18,7 @@
cov_xml = os.path.join(xml_results_dest, 'coverage.xml')

print('junit destination:', junit_xml)
njobs_args = '-p no:xdist' # if os.getenv('TRAVIS') or os.getenv('CIRCLECI') else '-n2'
njobs_args = '-p no:xdist' # if os.getenv('TRAVIS') else '-n2'

pytest_args = ("-v --pyargs {test_pkg} "
"--cov={cover_pkg} "
16 changes: 11 additions & 5 deletions pyemma/_base/estimator.py
@@ -22,6 +22,8 @@
import sys
import os

from threadpoolctl import threadpool_limits

from pyemma._ext.sklearn.base import BaseEstimator as _BaseEstimator
from pyemma._ext.sklearn.parameter_search import ParameterGrid
from pyemma.util import types as _types
@@ -125,7 +127,7 @@ def _call_member(obj, name, failfast=True, *args, **kwargs):


def _estimate_param_scan_worker(estimator, params, X, evaluate, evaluate_args,
failfast, return_exceptions):
failfast, return_exceptions, limit_threads):
""" Method that runs estimation for several parameter settings.
Defined as a worker for parallelization
@@ -134,8 +136,9 @@ def _estimate_param_scan_worker(estimator, params, X, evaluate, evaluate_args,
# run estimation
model = None
try: # catch any exception
estimator.estimate(X, **params)
model = estimator.model
with threadpool_limits(limits=1 if limit_threads else None):
estimator.estimate(X, **params)
model = estimator.model
except KeyboardInterrupt:
# we want to be able to interactively interrupt the worker, no matter of failfast=False.
raise
@@ -326,11 +329,14 @@ def estimate_param_scan(estimator, X, param_sets, evaluate=None, evaluate_args=N
if logger_available:
logger.debug('estimating %s with n_jobs=%s', estimator, n_jobs)
# iterate over parameter settings
limit_threads = True
task_iter = ((estimator,
param_set, X,
evaluate,
evaluate_args,
failfast, return_exceptions)
failfast,
return_exceptions,
limit_threads)
for estimator, param_set in zip(estimators, param_sets))

from pathos.multiprocessing import Pool
@@ -358,7 +364,7 @@ def error_callback(*args, **kw):
with ctx:
for estimator, param_set in zip(estimators, param_sets):
res.append(_estimate_param_scan_worker(estimator, param_set, X,
evaluate, evaluate_args, failfast, return_exceptions))
evaluate, evaluate_args, failfast, return_exceptions, False))
if progress_reporter is not None:
progress_reporter._progress_update(1, stage='param-scan')

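The estimator.py change threads a limit_threads flag through to the worker so that each parallel estimation caps its native BLAS/OpenMP pools via threadpoolctl. A minimal sketch of that pattern, with a hypothetical estimator and data standing in for PyEMMA's objects:

from threadpoolctl import threadpool_limits

def run_one(estimator, data, limit_threads=True):
    # With limits=1, BLAS/OpenMP thread pools inside this worker are capped at a
    # single thread, so process-level parallelism does not oversubscribe cores.
    # limits=None leaves the pools untouched (the sequential fallback path).
    with threadpool_limits(limits=1 if limit_threads else None):
        estimator.estimate(data)
    return estimator.model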
1 change: 1 addition & 0 deletions pyemma/_base/serialization/pickle_extensions.py
@@ -120,6 +120,7 @@ class HDF5PersistentUnpickler(Unpickler):
'numpy',
'scipy',
'bhmm',
'deeptime'
)

def __init__(self, group, file):
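The unpickler's module whitelist now lists 'deeptime' alongside 'bhmm', so both newly written and legacy serialized objects can be restored. A generic sketch of the whitelist idea (not PyEMMA's actual implementation) using the standard pickle.Unpickler hook:

import importlib
import pickle

class WhitelistUnpickler(pickle.Unpickler):
    # Only objects whose defining module lives under one of these packages
    # are allowed to be reconstructed.
    ALLOWED = ('numpy', 'scipy', 'bhmm', 'deeptime')

    def find_class(self, module, name):
        if module.split('.')[0] not in self.ALLOWED:
            raise pickle.UnpicklingError(
                'refusing to load %s.%s: module not whitelisted' % (module, name))
        return getattr(importlib.import_module(module), name)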
2 changes: 1 addition & 1 deletion pyemma/coordinates/clustering/tests/util.py
@@ -59,7 +59,7 @@ def make_blobs(n_samples=100, n_features=2, centers=3, cluster_std=1.0,
centers = generator.uniform(center_box[0], center_box[1],
size=(centers, n_features))
else:
from bhmm._external.sklearn.utils import check_array
from sklearn.utils import check_array
centers = check_array(centers)
n_features = centers.shape[1]

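make_blobs previously imported check_array from bhmm's vendored sklearn copy; the upstream scikit-learn function provides the same validation. A short illustration, assuming scikit-learn is installed:

from sklearn.utils import check_array

# Converts list input to a 2D numeric ndarray and rejects NaN/inf by default.
centers = check_array([[0.0, 0.0], [5.0, 5.0], [0.0, 5.0]])
n_features = centers.shape[1]  # 2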
4 changes: 4 additions & 0 deletions pyemma/coordinates/tests/test_traj_info_cache.py
@@ -29,6 +29,8 @@

from unittest import mock

import pytest

from pyemma.coordinates import api
from pyemma.coordinates.data.feature_reader import FeatureReader
from pyemma.coordinates.data.numpy_filereader import NumPyFileReader
@@ -49,6 +51,7 @@
pdbfile = get_bpti_test_data()['top']


@pytest.mark.skip(reason="Sometimes causes CI to go spinning beach ball of death")
class TestTrajectoryInfoCache(unittest.TestCase):
@classmethod
def setUpClass(cls):
@@ -336,6 +339,7 @@ def load_module(self, fullname, path=None):
finally:
del sys.meta_path[0]

@pytest.mark.skip(reason="Sometimes causes CI to go spinning beach ball of death")
def test_in_memory_db(self):
""" new instance, not yet saved to disk, no lru cache avail """
old_cfg_dir = config.cfg_dir
51 changes: 34 additions & 17 deletions pyemma/msm/estimators/bayesian_hmsm.py
@@ -19,15 +19,13 @@

import numpy as _np


from pyemma._base.progress import ProgressReporterMixin
from pyemma.msm.estimators.maximum_likelihood_hmsm import MaximumLikelihoodHMSM as _MaximumLikelihoodHMSM
from pyemma.msm.models.hmsm import HMSM as _HMSM
from pyemma.msm.models.hmsm_sampled import SampledHMSM as _SampledHMSM
from pyemma.util.annotators import fix_docs
from pyemma.util.types import ensure_dtraj_list
from pyemma.util.units import TimeUnit
from bhmm import lag_observations as _lag_observations
from msmtools.estimation import number_of_states as _number_of_states

__author__ = 'noe'
@@ -220,7 +218,9 @@ def _estimate(self, dtrajs):

# if stride is different to init_hmsm, check if microstates in lagged-strided trajs are compatible
if self.stride != self.init_hmsm.stride:
dtrajs_lagged_strided = _lag_observations(dtrajs, self.lag, stride=self.stride)
from deeptime.markov import compute_dtrajs_effective
dtrajs_lagged_strided = compute_dtrajs_effective(dtrajs, lagtime=self.lag, n_states=-1,
stride=self.stride)
_nstates_obs = _number_of_states(dtrajs_lagged_strided, only_used=True)
_nstates_obs_full = _number_of_states(dtrajs)
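
bhmm's lag_observations is replaced here by deeptime.markov.compute_dtrajs_effective, which produces the lagged and strided discrete trajectories; n_states=-1 lets it infer the number of observed states. A brief sketch with a made-up trajectory and illustrative lag/stride values:

import numpy as np
from deeptime.markov import compute_dtrajs_effective

dtrajs = [np.array([0, 0, 1, 2, 2, 1, 0, 1, 2, 2])]  # hypothetical discrete trajectory
dtrajs_lagged_strided = compute_dtrajs_effective(dtrajs, lagtime=2, n_states=-1, stride=1)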

@@ -279,29 +279,46 @@ def _estimate(self, dtrajs):
if self.show_progress:
self._progress_register(self.nsamples, description='Sampling HMSMs', stage=0)

def call_back():
self._progress_update(1, stage=0)
from deeptime.util.callbacks import ProgressCallback
outer_self = self

class BHMMCallback(ProgressCallback):

def __call__(self, inc=1, *args, **kw):
super().__call__(inc, *args, **kw)
outer_self._progress_update(1, stage=0)

progress = BHMMCallback
else:
call_back = None
progress = None

from bhmm import discrete_hmm, bayesian_hmm
from deeptime.markov.hmm import BayesianHMM

if self.init_hmsm is not None:
hmm_mle = self.init_hmsm.hmm
estimator = BayesianHMM(hmm_mle, n_samples=self.nsamples, stride=self.stride,
initial_distribution_prior=self.p0_prior,
transition_matrix_prior=self.transition_matrix_prior,
store_hidden=self.store_hidden, reversible=self.reversible,
stationary=self.stationary)
else:
hmm_mle = discrete_hmm(self.initial_distribution, self.transition_matrix, B_init)

sampled_hmm = bayesian_hmm(self.discrete_trajectories_lagged, hmm_mle, nsample=self.nsamples,
reversible=self.reversible, stationary=self.stationary,
p0_prior=self.p0_prior, transition_matrix_prior=self.transition_matrix_prior,
store_hidden=self.store_hidden, call_back=call_back)

estimator = BayesianHMM.default(dtrajs, n_hidden_states=self.nstates, lagtime=self.lag,
n_samples=self.nsamples, stride=self.stride,
initial_distribution_prior=self.p0_prior,
transition_matrix_prior=self.transition_matrix_prior,
store_hidden=self.store_hidden, reversible=self.reversible,
stationary=self.stationary,
prior_submodel=True, separate=self.separate)

estimator.fit(dtrajs, n_burn_in=0, n_thin=1, progress=progress)
model = estimator.fetch_model()
if self.show_progress:
self._progress_force_finish(stage=0)

# Samples
sample_inp = [(m.transition_matrix, m.stationary_distribution, m.output_probabilities)
for m in sampled_hmm.sampled_hmms]
sample_inp = [(m.transition_model.transition_matrix, m.transition_model.stationary_distribution,
m.output_probabilities)
for m in model.samples]

samples = []
for P, pi, pobs in sample_inp: # restrict to observable set if necessary
@@ -310,7 +327,7 @@ def call_back():
samples.append(_HMSM(P, pobs, pi=pi, dt_model=self.dt_model))

# store results
self.sampled_trajs = [sampled_hmm.sampled_hmms[i].hidden_state_trajectories for i in range(self.nsamples)]
self.sampled_trajs = [model.samples[i].hidden_state_trajectories for i in range(self.nsamples)]
self.update_model_params(samples=samples)

# deal with connectivity
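Taken together, the sampling step now follows deeptime's estimator/model pattern (BayesianHMM.default, then fit, then fetch_model) instead of bhmm's bayesian_hmm call. A condensed sketch of that flow, with made-up data and illustrative parameter values:

import numpy as np
from deeptime.markov.hmm import BayesianHMM

dtrajs = [np.array([0, 0, 1, 2, 2, 1, 0, 1, 2, 2] * 50)]  # hypothetical data

# Estimator built around a maximum-likelihood initial HMM.
estimator = BayesianHMM.default(dtrajs, n_hidden_states=2, lagtime=1,
                                n_samples=20, reversible=True, stationary=False)
estimator.fit(dtrajs)
model = estimator.fetch_model()

# Each posterior sample exposes a transition model and output probabilities,
# mirroring what the PyEMMA wrapper extracts above.
for m in model.samples:
    P = m.transition_model.transition_matrix
    pi = m.transition_model.stationary_distribution
    pobs = m.output_probabilities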
(Diffs for the remaining 8 changed files are not shown.)