Skip to content

Commit

Permalink
Avoid changing global context cache behavior on modules imports (#977)
Browse files Browse the repository at this point in the history
  • Loading branch information
ijpulidos authored Apr 7, 2022
1 parent 59c77fb commit 31fab1d
Show file tree
Hide file tree
Showing 6 changed files with 3 additions and 24 deletions.
1 change: 0 additions & 1 deletion perses/dispersed/feptasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
_logger = logging.getLogger("feptasks")
_logger.setLevel(logging.INFO)

#cache.global_context_cache.platform = openmm.Platform.getPlatformByName('Reference') #this is just a local version
EquilibriumFEPTask = namedtuple('EquilibriumInput', ['sampler_state', 'inputs', 'outputs'])
NonequilibriumFEPTask = namedtuple('NonequilibriumFEPTask', ['particle', 'inputs'])

Expand Down
6 changes: 0 additions & 6 deletions perses/dispersed/smc.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import openmmtools.cache as cache
import os
import copy
from perses.dispersed.utils import *

from openmmtools.states import ThermodynamicState, CompoundThermodynamicState, SamplerState
Expand All @@ -12,17 +9,14 @@
from collections import namedtuple
from perses.annihilation.lambda_protocol import LambdaProtocol
from perses.annihilation.lambda_protocol import RelativeAlchemicalState
from perses.dispersed import *
import random
import pymbar
from perses.dispersed.parallel import Parallelism
from openmmtools import utils
# Instantiate logger
logging.basicConfig(level = logging.NOTSET)
_logger = logging.getLogger("sMC")
_logger.setLevel(logging.INFO)

cache.global_context_cache.platform = configure_platform(utils.get_fastest_platform().getName())
EquilibriumFEPTask = namedtuple('EquilibriumInput', ['sampler_state', 'inputs', 'outputs'])
DISTRIBUTED_ERROR_TOLERANCE = 1e-4

Expand Down
11 changes: 3 additions & 8 deletions perses/dispersed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
import copy

from openmmtools import cache
import openmmtools.mcmc as mcmc
import openmmtools.integrators as integrators
import openmmtools.states as states
Expand All @@ -20,7 +19,6 @@
import dask.distributed as distributed
from scipy.special import logsumexp
import openmmtools.cache as cache
from openmmtools import utils

# Instantiate logger
logging.basicConfig(level = logging.NOTSET)
Expand Down Expand Up @@ -86,11 +84,8 @@ def configure_platform(platform_name='Reference', fallback_platform_name='CPU',
print(f"conducting subsequent work with the following platform: {platform.getName()}")
return platform

#########
cache.global_context_cache.platform = configure_platform(utils.get_fastest_platform().getName())
#########

#smc functions
# smc functions
def compute_survival_rate(sMC_particle_ancestries):
"""
compute the time-series survival rate as a function of resamples
Expand Down Expand Up @@ -773,7 +768,7 @@ def anneal(self,
if rethermalize:
self.context.setVelocitiesToTemperature(self.thermodynamic_state.temperature) #rethermalize
if noneq_trajectory_filename is not None:
self.save_configuration(idx, sampler_state, context)
self.save_configuration(idx, sampler_state)
if return_timer:
timer[idx] = time.time() - start_timer
except Exception as e:
Expand Down Expand Up @@ -889,7 +884,7 @@ def update_context(self, _lambda):
self.thermodynamic_state.apply_to_context(self.context)


def save_configuration(self, iteration, sampler_state, context):
def save_configuration(self, iteration, sampler_state):
"""
pass a conditional save function
Expand Down
1 change: 0 additions & 1 deletion perses/samplers/multistate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from openmmtools.multistate import sams, replicaexchange
from openmmtools import cache, utils
from perses.dispersed.utils import configure_platform
cache.global_context_cache.platform = configure_platform(utils.get_fastest_platform().getName())
from openmmtools.states import CompoundThermodynamicState, SamplerState, ThermodynamicState
from perses.dispersed.utils import create_endstates

Expand Down
3 changes: 0 additions & 3 deletions perses/samplers/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
import numpy as np
import time
from openmmtools.states import SamplerState, ThermodynamicState
from openmmtools import cache, utils
from perses.dispersed.utils import configure_platform
cache.global_context_cache.platform = configure_platform(utils.get_fastest_platform().getName())

from perses.annihilation.ncmc_switching import NCMCEngine
from perses.dispersed import feptasks
Expand Down
5 changes: 0 additions & 5 deletions perses/tests/test_relative.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,6 @@

running_on_github_actions = os.environ.get('GITHUB_ACTIONS', None) == 'true'

try:
cache.global_context_cache.platform = openmm.Platform.getPlatformByName("Reference")
except Exception:
cache.global_context_cache.platform = openmm.Platform.getPlatformByName("Reference")

#############################################
# CONSTANTS
#############################################
Expand Down

0 comments on commit 31fab1d

Please sign in to comment.