Skip to content

Commit

Permalink
allow start file to be specified for mcmc samplers (gwastro#3394)
Browse files Browse the repository at this point in the history
  • Loading branch information
Collin Capano authored and OliverEdy committed Apr 3, 2023
1 parent 8c53563 commit 7cc1dbd
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
7 changes: 6 additions & 1 deletion pycbc/inference/io/base_hdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,8 @@ def copy_samples(self, other, parameters=None, parameter_names=None,
# if list of desired parameters is different, rename
if set(parameters) != set(self.variable_params):
other.attrs['variable_params'] = parameters
if read_args is None:
read_args = {}
samples = self.read_samples(parameters, **read_args)
logging.info("Copying {} samples".format(samples.size))
# if different parameter names are desired, get them from the samples
Expand All @@ -671,7 +673,10 @@ def copy_samples(self, other, parameters=None, parameter_names=None,
samples = FieldArray.from_kwargs(**arrs)
other.attrs['variable_params'] = samples.fieldnames
logging.info("Writing samples")
other.write_samples(other, samples, **write_args)
if write_args is None:
write_args = {}
other.write_samples({p: samples[p] for p in samples.fieldnames},
**write_args)

def copy(self, other, ignore=None, parameters=None, parameter_names=None,
read_args=None, write_args=None):
Expand Down
10 changes: 8 additions & 2 deletions pycbc/inference/sampler/base_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,8 +454,14 @@ def set_state_from_file(self, filename):
def set_start_from_config(self, cp):
"""Sets the initial state of the sampler from config file
"""
init_prior = initial_dist_from_config(cp, self.variable_params)
self.set_p0(prior=init_prior)
if cp.has_option('sampler', 'start-file'):
start_file = cp.get('sampler', 'start-file')
logging.info("Using file %s for initial positions", start_file)
init_prior = None
else:
start_file = None
init_prior = initial_dist_from_config(cp, self.variable_params)
self.set_p0(samples_file=start_file, prior=init_prior)

def resume_from_checkpoint(self):
"""Resume the sampler from the checkpoint file
Expand Down

0 comments on commit 7cc1dbd

Please sign in to comment.