Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable custom bias classes for umbrella integration #177

Merged
merged 22 commits into from
Aug 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/pysages_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,5 @@ dipeptide
conformational
kT
Sobolev
pysages
metad
5 changes: 3 additions & 2 deletions examples/hoomd-blue/umbrella_integration/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import pysages
from pysages.colvars import Component
from pysages.methods import UmbrellaIntegration, SerialExecutor
from pysages.methods import HarmonicBias, UmbrellaIntegration, SerialExecutor


params = {"A": 0.5, "w": 0.2, "p": 2}
Expand Down Expand Up @@ -139,7 +139,8 @@ def main(argv):
cvs = [Component([0], 0)]

centers = list(np.linspace(args.start_path, args.end_path, args.replicas))
method = UmbrellaIntegration(cvs, args.k_spring, centers, args.log_period, args.log_delay)
biasers = [HarmonicBias(cvs, args.k_spring, c) for c in centers]
method = UmbrellaIntegration(biasers, args.log_period, args.log_delay)

context_args = {"mpi_enabled": args.mpi}

Expand Down
2 changes: 0 additions & 2 deletions examples/openmm/umbrella_integration/integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,10 @@ def main(argv):
for pos in center_pos:
centers.append((pos, pos))
method = UmbrellaIntegration(cvs, args.k_spring, centers, args.log_period, args.log_delay)

raw_result = pysages.run(
method,
generate_simulation,
args.time_steps,
# post_run_action=post_run_action,
executor=get_executor(args),
)
result = pysages.analyze(raw_result)
Expand Down
1 change: 1 addition & 0 deletions pysages/methods/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,4 @@
)
from .restraints import CVRestraints
from .unbiased import Unbiased
from .bias import Bias
70 changes: 70 additions & 0 deletions pysages/methods/bias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2020-2021: PySAGES contributors
# See LICENSE.md and CONTRIBUTORS.md at https://github.com/SSAGESLabs/PySAGES

"""
Generic abstract bias method.
"""

from abc import abstractmethod
from jax import numpy as np
from pysages.methods.core import SamplingMethod


class Bias(SamplingMethod):
"""
Abstract biasing class.
In this context a biasing methond ensures that a system is biased around a fixed `center` in CV space.
How this biasing is achieved is up to the individual implementation.
A common biasing form is implemented via the `HarmonicBias` class.

Biasing is commonly used in other advanced sampling methods, such as UmbrellaIntegration
or the ImprovedString method.
This abstract class defines an interface to interact with the CV center, such that method can rely on it.
"""

__special_args__ = {"center"}
snapshot_flags = {"positions", "indices"}

def __init__(self, cvs, center, **kwargs):
"""
Arguments
---------
cvs: Union[List, Tuple]
A list or tuple of collective variables, length `N`.
center:
An array of length `N` representing the minimum of the harmonic biasing potential.
"""
super().__init__(cvs, **kwargs)
self.cv_dimension = len(cvs)
self.center = center

def __getstate__(self):
state, kwargs = super().__getstate__()
state["center"] = self._center
return state, kwargs

@property
def center(self):
"""
Retrieve current center of the collective variable.
"""
return self._center

@center.setter
def center(self, center):
"""
Set the center of the collective variable to a new position.
"""
center = np.asarray(center)
if center.shape == ():
center = center.reshape(1)
if len(center.shape) != 1 or center.shape[0] != self.cv_dimension:
raise RuntimeError(
f"Invalid center shape expected {self.cv_dimension} got {center.shape}."
)
self._center = center

@abstractmethod
def build(self, snapshot, helpers, *args, **kwargs):
pass
35 changes: 6 additions & 29 deletions pysages/methods/harmonic_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

from jax import numpy as np

from pysages.methods.core import SamplingMethod, default_getstate, generalize
from pysages.methods.core import generalize
from pysages.methods.bias import Bias
from pysages.utils import JaxArray


Expand All @@ -40,13 +41,12 @@ def __repr__(self):
return repr("PySAGES" + type(self).__name__)


class HarmonicBias(SamplingMethod):
class HarmonicBias(Bias):
"""
Harmonic bias method class.
"""

__special_args__ = {"kspring", "center"}
snapshot_flags = {"positions", "indices"}
__special_args__ = Bias.__special_args__.union({"kspring"})

def __init__(self, cvs, kspring, center, **kwargs):
"""
Expand All @@ -59,15 +59,13 @@ def __init__(self, cvs, kspring, center, **kwargs):
center:
An array of length `N` representing the minimum of the harmonic biasing potential.
"""
super().__init__(cvs, **kwargs)
super().__init__(cvs, center, **kwargs)
self.cv_dimension = len(cvs)
self.kspring = kspring
self.center = center

def __getstate__(self):
state, kwargs = default_getstate(self)
state, kwargs = super().__getstate__()
InnocentBug marked this conversation as resolved.
Show resolved Hide resolved
state["kspring"] = self._kspring
state["center"] = self._center
return state, kwargs

@property
Expand Down Expand Up @@ -109,27 +107,6 @@ def kspring(self, kspring):
self._kspring = np.identity(N) * kspring
return self._kspring

@property
def center(self):
"""
Retrieve current center of the collective variable.
"""
return self._center

@center.setter
def center(self, center):
"""
Set the center of the collective variable to a new position.
"""
center = np.asarray(center)
if center.shape == ():
center = center.reshape(1)
if len(center.shape) != 1 or center.shape[0] != self.cv_dimension:
raise RuntimeError(
f"Invalid center shape expected {self.cv_dimension} got {center.shape}."
)
self._center = center

def build(self, snapshot, helpers, *args, **kwargs):
return _harmonic_bias(self, snapshot, helpers)

Expand Down
38 changes: 37 additions & 1 deletion pysages/methods/umbrella_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from copy import deepcopy
from typing import Callable, Optional, Union

import plum
from pysages.methods.core import Result, SamplingMethod, _run
from pysages.methods.harmonic_bias import HarmonicBias
from pysages.methods.utils import HistogramLogger, listify, SerialExecutor
Expand All @@ -33,7 +34,16 @@ class UmbrellaIntegration(SamplingMethod):
Note that this is not very accurate and usually requires more sophisticated analysis on top.
"""

def __init__(self, cvs, ksprings, centers, hist_periods, hist_offsets=0, **kwargs):
@plum.dispatch
def __init__(
self,
cvs,
ksprings,
centers,
hist_periods: Union[list, int],
hist_offsets: Union[list, int] = 0,
**kwargs
):
"""
Initialization, sets up the HarmonicBias subsamplers.

Expand Down Expand Up @@ -62,6 +72,32 @@ def __init__(self, cvs, ksprings, centers, hist_periods, hist_offsets=0, **kwarg
self.submethods = [HarmonicBias(cvs, k, c) for (k, c) in zip(ksprings, centers)]
self.histograms = [HistogramLogger(p, o) for (p, o) in zip(periods, offsets)]

@plum.dispatch
def __init__( # noqa: F811 # pylint: disable=C0116,E0102
self,
biasers: list,
hist_periods: Union[list, int],
hist_offsets: Union[list, int] = 0,
**kwargs
):
cvs = None
for bias in biasers:
if cvs is None:
cvs = bias.cvs
else:
if bias.cvs != cvs:
raise RuntimeError(
"Attempted run of UmbrellaSampling with different CVs"
" for the individual biaser."
)
super().__init__(cvs, **kwargs)
replicas = len(biasers)
periods = listify(hist_periods, replicas, "hist_periods", int)
offsets = listify(hist_offsets, replicas, "hist_offsets", int)

self.submethods = biasers
self.histograms = [HistogramLogger(p, o) for (p, o) in zip(periods, offsets)]

# We delegate the sampling work to HarmonicBias
# (or possibly other methods in the future)
def build(self): # pylint: disable=arguments-differ
Expand Down
6 changes: 1 addition & 5 deletions pysages/methods/unbiased.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from jax import numpy as np

from pysages.methods.core import SamplingMethod, default_getstate, generalize
from pysages.methods.core import SamplingMethod, generalize
from pysages.utils import JaxArray


Expand Down Expand Up @@ -57,10 +57,6 @@ def __init__(self, cvs, **kwargs):
kwargs["cv_grad"] = None
super().__init__(cvs, **kwargs)

def __getstate__(self):
state, kwargs = default_getstate(self)
return state, kwargs

def build(self, snapshot, helpers, *args, **kwargs):
return _unbias(self, snapshot, helpers)

Expand Down
1 change: 1 addition & 0 deletions tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@
"ReplicasConfiguration": {},
"SerialExecutor": {},
"CVRestraints": {"lower": (-pi, -pi), "upper": (pi, pi), "kl": (0.0, 1.0), "ku": (1.0, 0.0)},
"Bias": {"cvs": [pysages.colvars.Component([0], 0)], "center": 0.7},
}


Expand Down