Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add alternate walk to dynesty #3436

Merged
merged 2 commits into from
Sep 2, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 164 additions & 1 deletion pycbc/inference/sampler/dynesty.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
import os
import time
import numpy
import dynesty
import dynesty, dynesty.dynesty, dynesty.nestedsamplers
from dynesty.utils import unitcheck, reflect
from pycbc.pool import choose_pool
from dynesty import utils as dyfunc
from pycbc.inference.io import (DynestyFile, validate_checkpoint_files,
Expand Down Expand Up @@ -126,6 +127,12 @@ def __init__(self, model, nlive, nprocesses=1,
if len(reflective) == 0:
reflective = None

if 'sample' in kwargs:
if 'rwalk2' in kwargs['sample']:
dynesty.dynesty._SAMPLING["rwalk"] = sample_rwalk_mod
dynesty.nestedsamplers._SAMPLING["rwalk"] = sample_rwalk_mod
kwargs['sample'] = 'rwalk'

if self.nlive < 0:
# Interpret a negative input value for the number of live points
# (which is clearly an invalid input in all senses)
Expand Down Expand Up @@ -395,3 +402,159 @@ def logz_err(self):
dynesty sampler
"""
return self._sampler.results.logzerr[-1:][0]


def sample_rwalk_mod(args):
    """Random-walk proposal whose chain length adapts to the acceptance rate.

    Modified version of ``dynesty.sampling.sample_rwalk``, adapted from the
    version used in bilby/dynesty.

    Parameters
    ----------
    args : tuple
        ``(u, loglstar, axes, scale, prior_transform, loglikelihood,
        kwargs)``. ``u`` is the starting position, ``loglstar`` the
        log-likelihood threshold a proposal must reach to be accepted,
        ``axes`` and ``scale`` shape the proposal, and ``kwargs`` is a dict
        from which the following optional keys are read: ``nonbounded``,
        ``periodic``, ``reflective``, ``walks`` (default ``10 * len(u)``),
        ``maxmcmc`` (default 2000), ``nact`` (default 5) and ``old_act``
        (default ``walks``).

    Returns
    -------
    u, v, logl :
        The selected position, its prior-transformed value and its
        log-likelihood.
    ncall : int
        Number of accepted plus rejected proposals.
    blob : dict
        The ``accept``, ``reject`` and ``fail`` counters and the ``scale``.

    Notes
    -----
    The ``kwargs`` dict is modified in place: ``kwargs["old_act"]`` is set to
    the last ACT estimate (``numpy.inf`` if none was ever computed).
    """
    u, loglstar, axes, scale, prior_transform, loglikelihood, kwargs = args
    rstate = numpy.random

    # Boundary handling options.
    nonbounded = kwargs.get('nonbounded', None)
    periodic = kwargs.get('periodic', None)
    reflective = kwargs.get('reflective', None)

    # Chain-length controls.
    n = len(u)
    walks = kwargs.get('walks', 10 * n)  # minimum number of steps
    maxmcmc = kwargs.get('maxmcmc', 2000)  # maximum number of steps
    nact = kwargs.get('nact', 5)  # number of ACTs to run for
    old_act = kwargs.get('old_act', walks)

    accept = 0
    reject = 0
    nfail = 0
    act = numpy.inf
    # One (u, v, logl) entry per step, recorded only from the first accepted
    # jump onwards; the chain is therefore empty exactly while accept == 0.
    chain = []

    n_steps = 0
    while n_steps < nact * act:
        n_steps += 1

        # Draw a direction on the unit n-sphere and scale it according to
        # the dimensionality.
        direction = rstate.randn(n)
        direction /= numpy.linalg.norm(direction)
        offset = direction * rstate.rand() ** (1.0 / n)

        # Map the offset into the proposal distribution.
        u_prop = u + scale * numpy.dot(axes, offset)

        # Wrap periodic parameters, then reflect reflective ones.
        if periodic is not None:
            u_prop[periodic] = numpy.mod(u_prop[periodic], 1)
        if reflective is not None:
            u_prop[reflective] = reflect(u_prop[reflective])

        # Proposals that fail the unit cube check are never evaluated; the
        # chain just repeats its current state.
        if not unitcheck(u_prop, nonbounded):
            nfail += 1
            if chain:
                chain.append(chain[-1])
            continue

        # Evaluate the proposal against the likelihood threshold.
        v_prop = prior_transform(numpy.array(u_prop))
        logl_prop = loglikelihood(numpy.array(v_prop))
        if logl_prop >= loglstar:
            u = u_prop
            accept += 1
            chain.append((u_prop, v_prop, logl_prop))
        else:
            reject += 1
            if chain:
                chain.append(chain[-1])

        n_evaluated = accept + reject

        # Once the minimum number of steps is exceeded, (re)estimate the ACT.
        if n_evaluated > walks:
            act = estimate_nmcmc(
                accept_ratio=accept / (n_evaluated + nfail),
                old_act=old_act, maxmcmc=maxmcmc)

        # Give up once too many likelihood evaluations have been made.
        if n_evaluated > maxmcmc:
            logging.warning(
                "Hit maximum number of walks {} with accept={}, reject={}, "
                "and nfail={} try increasing maxmcmc"
                .format(maxmcmc, accept, reject, nfail))
            break

    # With a finite ACT, draw uniformly from the part of the chain past the
    # first half of the nact * act steps; int() is only applied to a finite
    # ACT.
    burn_in = int(.5 * nact * act) if numpy.isfinite(act) else None
    if burn_in is not None and burn_in < len(chain):
        u, v, logl = chain[numpy.random.randint(burn_in, len(chain))]
    else:
        logging.debug("Unable to find a new point using walk: "
                      "returning a random point")
        u = numpy.random.uniform(size=n)
        v = prior_transform(u)
        logl = loglikelihood(v)

    blob = {'accept': accept, 'reject': reject, 'fail': nfail, 'scale': scale}
    kwargs["old_act"] = act

    ncall = accept + reject
    return u, v, logl, ncall, blob


def estimate_nmcmc(accept_ratio, old_act, maxmcmc, safety=5, tau=None):
    """ Estimate the autocorrelation length of a chain from its acceptance
    fraction

    Uses ACL = (2/acc) - 1 multiplied by a safety margin, blended with the
    previous estimate. Code adapted from CPNest:
        - https://github.com/johnveitch/cpnest/blob/master/cpnest/sampler.py
        - http://github.com/farr/Ensemble.jl

    Parameters
    ----------
    accept_ratio: float [0, 1]
        Ratio of the number of accepted points to the total number of points
    old_act: int
        The ACT of the last iteration
    maxmcmc: int
        The maximum length of the MCMC chain to use; also the upper cap on
        the returned estimate
    safety: int
        A safety factor applied in the calculation; also the lower bound on
        the returned estimate
    tau: int (optional)
        The ACT, if given, otherwise taken to be ``maxmcmc / safety``.

    Returns
    -------
    int
        The estimate, truncated to an integer after capping at ``maxmcmc``
        and never smaller than ``safety``.
    """
    tau = maxmcmc / safety if tau is None else tau

    if accept_ratio == 0.0:
        # Nothing was accepted: just inflate the previous estimate.
        estimate = (1 + 1 / tau) * old_act
    else:
        carried_over = (1. - 1. / tau) * old_act
        correction = (safety / tau) * (2. / accept_ratio - 1.)
        estimate = carried_over + correction

    capped = float(min(estimate, maxmcmc))
    return max(safety, int(capped))