Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add order statistics checks #72

Merged
merged 15 commits into from
Feb 12, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 122 additions & 13 deletions cpnest/NestedSampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import pickle
import time
import logging
import bisect
import numpy as np
from numpy import logaddexp
from numpy import inf
from scipy.stats import kstest
from math import isnan
from . import nest2pos
from .nest2pos import logsubexp
Expand All @@ -19,6 +21,70 @@
logger = logging.getLogger('cpnest.NestedSampling')


class KeyOrderedList(list):
    """
    A list that is kept sorted according to a key function.

    Parameters
    ----------
    iterable : array_like
        Initial contents; sorted on construction.
    key : function, optional
        Function used to compute the sort key of each item. Defaults to
        the identity, i.e. the list is ordered by the values themselves.
    """
    def __init__(self, iterable, key=lambda x: x):
        ordered = sorted(iterable, key=key)
        super(KeyOrderedList, self).__init__(ordered)

        self._key = key
        # Cache of the keys, kept in lockstep with the list contents so
        # that insertion points can be found with a binary search.
        self._keys = [key(v) for v in ordered]

    def search(self, item):
        """
        Return the index at which ``item`` would be inserted.
        """
        return bisect.bisect(self._keys, self._key(item))

    def add(self, item):
        """
        Insert ``item`` while keeping the list ordered and return the
        index at which it was placed.
        """
        index = self.search(item)
        self.insert(index, item)
        self._keys.insert(index, self._key(item))
        return index


class OrderedLivePoints(KeyOrderedList):
    """
    Container of live points kept ordered by increasing log-likelihood.

    The log-likelihood must be pre-computed and accessible as the
    ``logL`` attribute of every live point.

    Parameters
    ----------
    live_points : array_like
        Initial set of live points.
    """
    def __init__(self, live_points):
        super(OrderedLivePoints, self).__init__(
            live_points, key=lambda x: x.logL)

    def insert_live_point(self, live_point):
        """
        Insert a live point, preserving the ordering, and return the
        index at which it was inserted.
        """
        return self.add(live_point)

    def remove_n_worst_points(self, n):
        """
        Remove the n lowest log-likelihood live points.
        """
        # Trim the cached keys as well so they stay in sync with the list
        del self[:n]
        del self._keys[:n]


class _NSintegralState(object):
"""
Stores the state of the nested sampling integrator
Expand Down Expand Up @@ -188,6 +254,8 @@ def __init__(self,
self.logLmax = self.manager.logLmax
self.iteration = 0
self.nested_samples = []
self.insertion_indices = []
self.rolling_p = []
self.logZ = None
self.state = _NSintegralState(self.Nlive)
sys.stdout.flush()
Expand Down Expand Up @@ -260,19 +328,26 @@ def consume_sample(self):
logLtmp=[p.logL for p in self.params[:nreplace]]

# Make sure we are mixing the chains
for i in np.random.permutation(range(len(self.worst))): self.manager.consumer_pipes[self.worst[i]].send(self.params[self.worst[i]])
for i in np.random.permutation(range(nreplace)): self.manager.consumer_pipes[self.worst[i]].send(self.params[self.worst[i]])
self.condition = logaddexp(self.state.logZ,self.logLmax.value - self.iteration/(float(self.Nlive))) - self.state.logZ

# Replace the points we just consumed with the next acceptable ones
for k in self.worst:
# Reversed since the for the first point the current number of
# live points is N - n_worst -1 (minus 1 because of counting from zero)
for k in reversed(self.worst):
self.iteration += 1
loops = 0
while(True):
loops += 1
acceptance, sub_acceptance, self.jumps, proposed = self.manager.consumer_pipes[self.queue_counter].recv()
if proposed.logL > self.logLmin.value:
# replace worst point with new one
self.params[k] = proposed
# Insert the new live point into the ordered list and
# return the index at which is was inserted, this will
# include the n worst points, so this subtracted next
index = self.params.insert_live_point(proposed)
# the index is then coverted to a value between [0, 1]
# accounting for the variable number of live points
self.insertion_indices.append((index - nreplace) / (self.Nlive - k - 1))
self.queue_counter = (self.queue_counter + 1) % len(self.manager.consumer_pipes)
self.accepted += 1
break
Expand All @@ -284,19 +359,48 @@ def consume_sample(self):
if self.verbose:
self.logger.info("{0:d}: n:{1:4d} NS_acc:{2:.3f} S{3:d}_acc:{4:.3f} sub_acc:{5:.3f} H: {6:.2f} logL {7:.5f} --> {8:.5f} dZ: {9:.3f} logZ: {10:.3f} logLmax: {11:.2f}"\
.format(self.iteration, self.jumps*loops, self.acceptance, k, acceptance, sub_acceptance, self.state.info,\
logLtmp[k], self.params[k].logL, self.condition, self.state.logZ, self.logLmax.value))
#sys.stderr.flush()
logLtmp[k], proposed.logL, self.condition, self.state.logZ, self.logLmax.value))

# points not removed earlier because they are used to resend to
# samplers if rejected
self.params.remove_n_worst_points(nreplace)

def get_worst_n_live_points(self, n):
    """
    Select the n lowest log-likelihood live points for evolution and
    update the current log-likelihood threshold accordingly.
    """
    # Keep the pool ordered by log-likelihood, lowest first
    self.params.sort(key=attrgetter('logL'))
    # The first n entries are the ones to be replaced
    self.worst = np.arange(n)
    # The threshold is the highest log-likelihood among the worst points
    threshold = np.float128(self.params[n - 1].logL)
    self.logLmin.value = threshold
    return np.float128(self.logLmin.value)

def check_insertion_indices(self, rolling=True, filename=None):
    """
    Test the insertion indices for uniformity with a KS test.

    With ``rolling=True`` only the most recent ``Nlive`` indices are
    checked (used periodically during the run); with ``rolling=False``
    the full set of indices is checked (used at the end of the run).
    If ``filename`` is given, all indices are saved to that file inside
    the output folder.
    """
    if not self.insertion_indices:
        return

    # Rolling checks only look at the most recent block of Nlive indices
    indices = (self.insertion_indices[-self.Nlive:] if rolling
               else self.insertion_indices)

    statistic, pvalue = kstest(indices, 'uniform', args=(0, 1))
    if rolling:
        self.logger.warning('Rolling KS test: D={0:.3}, p-value={1:.3}'.format(statistic, pvalue))
        self.rolling_p.append(pvalue)
    else:
        self.logger.warning('Final KS test: D={0:.3}, p-value={1:.3}'.format(statistic, pvalue))

    if filename is not None:
        np.savetxt(os.path.join(self.output_folder, filename),
                   self.insertion_indices,
                   newline='\n', delimiter=' ')


def reset(self):
"""
Initialise the pool of `cpnest.parameter.LivePoint` by
Expand All @@ -305,20 +409,22 @@ def reset(self):
# send all live points to the samplers for start
i = 0
nthreads=self.manager.nthreads
params = [None] * self.Nlive
with tqdm(total=self.Nlive, disable= not self.verbose, desc='CPNEST: populate samplers', position=nthreads) as pbar:
while i < self.Nlive:
for j in range(nthreads): self.manager.consumer_pipes[j].send(self.model.new_point())
for j in range(nthreads):
while i < self.Nlive:
acceptance,sub_acceptance,self.jumps,self.params[i] = self.manager.consumer_pipes[self.queue_counter].recv()
acceptance,sub_acceptance,self.jumps,params[i] = self.manager.consumer_pipes[self.queue_counter].recv()
self.queue_counter = (self.queue_counter + 1) % len(self.manager.consumer_pipes)
if np.isnan(self.params[i].logL):
self.logger.warn("Likelihood function returned NaN for params "+str(self.params))
if np.isnan(params[i].logL):
self.logger.warn("Likelihood function returned NaN for params "+str(params))
self.logger.warn("You may want to check your likelihood function")
if self.params[i].logP!=-np.inf and self.params[i].logL!=-np.inf:
if params[i].logP!=-np.inf and params[i].logL!=-np.inf:
i+=1
pbar.update()
break
self.params = OrderedLivePoints(params)
if self.verbose:
sys.stderr.write("\n")
sys.stderr.flush()
Expand All @@ -331,8 +437,7 @@ def nested_sampling_loop(self):
if not self.initialised:
self.reset()
if self.prior_sampling:
for i in range(self.Nlive):
self.nested_samples.append(self.params[i])
self.nested_samples = self.params
self.write_chain_to_file()
self.write_evidence_to_file()
self.logLmin.value = np.inf
Expand All @@ -348,6 +453,10 @@ def nested_sampling_loop(self):
if time.time() - self.last_checkpoint_time > self.manager.periodic_checkpoint_interval:
self.checkpoint()
self.last_checkpoint_time = time.time()

if (self.iteration % self.Nlive) < self.manager.nthreads:
self.check_insertion_indices()

except CheckPoint:
self.checkpoint()
# Run each pipe to get it to checkpoint
Expand Down
6 changes: 6 additions & 0 deletions cpnest/cpnest.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,8 @@ def run(self):
sys.exit(130)

if self.verbose >= 2:
self.NS.check_insertion_indices(rolling=False,
filename='insertion_indices.dat')
self.logger.critical(
"Saving nested samples in {0}".format(self.output)
)
Expand All @@ -297,10 +299,13 @@ def run(self):
)
self.posterior_samples = self.get_posterior_samples()
else:
self.NS.check_insertion_indices(rolling=False,
filename=None)
self.nested_samples = self.get_nested_samples(filename=None)
self.posterior_samples = self.get_posterior_samples(
filename=None
)

if self.verbose>=3 or self.NS.prior_sampling:
self.prior_samples = self.get_prior_samples(filename=None)
if self.verbose>=3 and not self.NS.prior_sampling:
Expand Down Expand Up @@ -523,6 +528,7 @@ def plot(self, corner = True):
ms=plotting_mcmc,
labels=pos.dtype.names,
filename=os.path.join(self.output,'corner.pdf'))
plot.plot_indices(self.NS.insertion_indices, filename=os.path.join(self.output, 'insertion_indices.pdf'))

def worker_sampler(self, producer_pipe, logLmin):
cProfile.runctx('self.sampler.produce_sample(producer_pipe, logLmin)', globals(), locals(), 'prof_sampler.prof')
Expand Down
31 changes: 31 additions & 0 deletions cpnest/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,37 @@ def plot_hist(x, name=None, prior_samples=None, mcmc_samples=None, filename=None
plt.savefig(filename, bbox_inches='tight')
plt.close()


def plot_indices(indices, filename=None, max_bins=30):
    """
    Histogram indices for insertion indices tests.

    Parameters
    ----------
    indices : list
        List of insertion indices, rescaled to [0, 1].
    filename : str, optional
        Filename used to save the resulting figure. If not specified the
        figure is not saved.
    max_bins : int, optional
        Maximum number of bins in the histogram.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111)

    # Aim for ~100 samples per bin, capped at max_bins, but always use
    # at least one bin: the original expression min(len // 100, max_bins)
    # yields bins=0 for fewer than 100 indices, which matplotlib rejects.
    nbins = max(1, min(len(indices) // 100, max_bins))
    ax.hist(indices, density=True, color='tab:blue', linewidth=1.25,
            histtype='step', bins=nbins)
    # Theoretical distribution: uniform density of 1 on [0, 1]
    ax.axhline(1, color='black', linewidth=1.25, linestyle=':', label='pdf')

    ax.legend()
    ax.set_xlabel('Insertion indices [0, 1]')

    if filename is not None:
        plt.savefig(filename, bbox_inches='tight')
        plt.close()


def plot_corner(xs, ps=None, ms=None, filename=None, **kwargs):
"""
Produce a corner plot
Expand Down