From 31fab1d4c828d3852bef5c632753f5a895d473c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Iv=C3=A1n=20Pulido?= Date: Thu, 7 Apr 2022 11:03:01 -0400 Subject: [PATCH] Avoid changing global context cache behavior on modules imports (#977) --- perses/dispersed/feptasks.py | 1 - perses/dispersed/smc.py | 6 ------ perses/dispersed/utils.py | 11 +++-------- perses/samplers/multistate.py | 1 - perses/samplers/samplers.py | 3 --- perses/tests/test_relative.py | 5 ----- 6 files changed, 3 insertions(+), 24 deletions(-) diff --git a/perses/dispersed/feptasks.py b/perses/dispersed/feptasks.py index dd614a4f1..b2475d21a 100644 --- a/perses/dispersed/feptasks.py +++ b/perses/dispersed/feptasks.py @@ -24,7 +24,6 @@ _logger = logging.getLogger("feptasks") _logger.setLevel(logging.INFO) -#cache.global_context_cache.platform = openmm.Platform.getPlatformByName('Reference') #this is just a local version EquilibriumFEPTask = namedtuple('EquilibriumInput', ['sampler_state', 'inputs', 'outputs']) NonequilibriumFEPTask = namedtuple('NonequilibriumFEPTask', ['particle', 'inputs']) diff --git a/perses/dispersed/smc.py b/perses/dispersed/smc.py index 0ccf3334c..15d399f52 100644 --- a/perses/dispersed/smc.py +++ b/perses/dispersed/smc.py @@ -1,6 +1,3 @@ -import openmmtools.cache as cache -import os -import copy from perses.dispersed.utils import * from openmmtools.states import ThermodynamicState, CompoundThermodynamicState, SamplerState @@ -12,17 +9,14 @@ from collections import namedtuple from perses.annihilation.lambda_protocol import LambdaProtocol from perses.annihilation.lambda_protocol import RelativeAlchemicalState -from perses.dispersed import * import random import pymbar from perses.dispersed.parallel import Parallelism -from openmmtools import utils # Instantiate logger logging.basicConfig(level = logging.NOTSET) _logger = logging.getLogger("sMC") _logger.setLevel(logging.INFO) -cache.global_context_cache.platform = configure_platform(utils.get_fastest_platform().getName()) EquilibriumFEPTask = namedtuple('EquilibriumInput', ['sampler_state', 'inputs', 'outputs']) DISTRIBUTED_ERROR_TOLERANCE = 1e-4 diff --git a/perses/dispersed/utils.py b/perses/dispersed/utils.py index 51734ebdf..fdceaaa7c 100644 --- a/perses/dispersed/utils.py +++ b/perses/dispersed/utils.py @@ -2,7 +2,6 @@ import os import copy -from openmmtools import cache import openmmtools.mcmc as mcmc import openmmtools.integrators as integrators import openmmtools.states as states @@ -20,7 +19,6 @@ import dask.distributed as distributed from scipy.special import logsumexp import openmmtools.cache as cache -from openmmtools import utils # Instantiate logger logging.basicConfig(level = logging.NOTSET) @@ -86,11 +84,8 @@ def configure_platform(platform_name='Reference', fallback_platform_name='CPU', print(f"conducting subsequent work with the following platform: {platform.getName()}") return platform -######### -cache.global_context_cache.platform = configure_platform(utils.get_fastest_platform().getName()) -######### -#smc functions +# smc functions def compute_survival_rate(sMC_particle_ancestries): """ compute the time-series survival rate as a function of resamples @@ -773,7 +768,7 @@ def anneal(self, if rethermalize: self.context.setVelocitiesToTemperature(self.thermodynamic_state.temperature) #rethermalize if noneq_trajectory_filename is not None: - self.save_configuration(idx, sampler_state, context) + self.save_configuration(idx, sampler_state) if return_timer: timer[idx] = time.time() - start_timer except Exception as e: @@ -889,7 +884,7 @@ def update_context(self, _lambda): self.thermodynamic_state.apply_to_context(self.context) - def save_configuration(self, iteration, sampler_state, context): + def save_configuration(self, iteration, sampler_state): """ pass a conditional save function diff --git a/perses/samplers/multistate.py b/perses/samplers/multistate.py index c5ef0800a..aac30abbf 100644 --- a/perses/samplers/multistate.py +++ b/perses/samplers/multistate.py @@ -7,7 +7,6 @@ from openmmtools.multistate import sams, replicaexchange from openmmtools import cache, utils from perses.dispersed.utils import configure_platform -cache.global_context_cache.platform = configure_platform(utils.get_fastest_platform().getName()) from openmmtools.states import CompoundThermodynamicState, SamplerState, ThermodynamicState from perses.dispersed.utils import create_endstates diff --git a/perses/samplers/samplers.py b/perses/samplers/samplers.py index 97e9af586..a0b4ec944 100644 --- a/perses/samplers/samplers.py +++ b/perses/samplers/samplers.py @@ -19,9 +19,6 @@ import numpy as np import time from openmmtools.states import SamplerState, ThermodynamicState -from openmmtools import cache, utils -from perses.dispersed.utils import configure_platform -cache.global_context_cache.platform = configure_platform(utils.get_fastest_platform().getName()) from perses.annihilation.ncmc_switching import NCMCEngine from perses.dispersed import feptasks diff --git a/perses/tests/test_relative.py b/perses/tests/test_relative.py index f677c2311..b4fb1acc1 100644 --- a/perses/tests/test_relative.py +++ b/perses/tests/test_relative.py @@ -25,11 +25,6 @@ running_on_github_actions = os.environ.get('GITHUB_ACTIONS', None) == 'true' -try: - cache.global_context_cache.platform = openmm.Platform.getPlatformByName("Reference") -except Exception: - cache.global_context_cache.platform = openmm.Platform.getPlatformByName("Reference") - ############################################# # CONSTANTS #############################################