Skip to content

Commit

Permalink
Loading Cobaya post-processed samples (importance reweighting, adding…
Browse files Browse the repository at this point in the history
… derived, etc.) (#33)

* cobaya_post: simple case: adding a likelihood

* cobaya_post: priors

* cobaya_post: add/remove derived params

* cobaya_post: changing 1d priors

* cobaya_post: robustness, and placeholders for skip/thin

* cobaya: more error checking

* cobaya: bugfixes
  • Loading branch information
JesusTorrado authored and cmbant committed Jun 17, 2019
1 parent 4c3c07b commit 6422bc0
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 10 deletions.
4 changes: 2 additions & 2 deletions getdist/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -1088,8 +1088,8 @@ def array_dimension(a):
elif dim == 3:
for i, samples_i in enumerate(files_or_samples):
self.chains.append(WeightedSamples(
samples=samples_i, loglikes=None if loglikes is None else np.atleast_2d(loglikes)[i],
weights=None if weights is None else np.atleast_2d(weights)[i], **WSkwargs))
samples=samples_i, loglikes=None if loglikes is None else loglikes[i],
weights=None if weights is None else weights[i], **WSkwargs))
if self.paramNames is None:
self.paramNames = ParamNames(default=self.chains[0].n)
nchains = len(self.chains)
Expand Down
21 changes: 19 additions & 2 deletions getdist/mcsamples.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,14 +120,31 @@ def loadCobayaSamples(info, collections, name_tag=None,
if hasattr(collections, "data"):
collections = [collections]
# Check consistency between collections
columns = list(collections[0].data)
try:
columns = list(collections[0].data)
except AttributeError:
raise TypeError(
"The second argument does not appear to be a (list of) samples `Collection`.")
if not all([list(c.data) == columns for c in collections[1:]]):
raise ValueError("The given collections don't have the same columns.")
from getdist.yaml_format_tools import _p_label, _p_renames, _weight, _minuslogpost
from getdist.yaml_format_tools import get_info_params, get_range, is_derived_param
from getdist.yaml_format_tools import get_sampler_type
from getdist.yaml_format_tools import get_sampler_type, _post
# Check consistency with info
info_params = get_info_params(info)
# ####################################################################################
# TODO! What to do with slip/ignore_rows and thin?
# This are skip and thin that *has already been done*
skip = info.get(_post, {}).get("skip", 0)
thin = info.get(_post, {}).get("thin", 1)
# Maybe warn if trying to ignore rows twice?
if ignore_rows != 0 and skip != 0:
logging.warn("You are asking for rows to be ignored (%r), but some (%r) were "
"already ignored in the original chain.", ignore_rows, skip)
# Should we warn about thin too?
# Most importantly: do we want to save somewhere the fact that we have *already*
# thinned/skipped?
######################################################################################
assert set(columns[2:]) == set(info_params.keys()), (
"Info and collection(s) are not compatible, because their parameters differ: "
"the collection(s) have %r and the info has %r. " % (
Expand Down
31 changes: 25 additions & 6 deletions getdist/yaml_format_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
_chi2 = "chi2"
_weight = "weight"
_minuslogpost = "minuslogpost"

_post = "post"

# Exceptions
class InputSyntaxError(Exception):
Expand Down Expand Up @@ -113,14 +113,33 @@ def get_info_params(info):
continue
info_params_full[p] = info_params[p]
# Add prior and likelihoods
priors = [_prior_1d_name] + list(info.get(_prior, []))
likes = list(info.get(_likelihood))
# Account for post
remove = info.get(_post, {}).get("remove", {})
for param in remove.get(_params, []) or []:
info_params_full.pop(param, None)
for like in remove.get(_likelihood, []) or []:
likes.remove(like)
for prior in remove.get(_prior, []) or []:
priors.remove(prior)
add = info.get(_post, {}).get("add", {})
# Adding derived params and updating 1d priors
for param, pinfo in add.get(_params, {}).items():
pinfo_old = info_params_full.get(param, {})
pinfo_old.update(pinfo)
info_params_full[param] = pinfo_old
likes += list(add.get(_likelihood, []))
priors += list(add.get(_prior, []))
# Add the prior and the likelihood as derived parameters
info_params_full[_minuslogprior] = {_p_label: r"-\log\pi"}
for prior in [_prior_1d_name] + list(info.get(_prior, [])):
for prior in priors:
info_params_full[_minuslogprior + _separator + prior] = {
_p_label: r"-\log\pi_\mathrm{" + prior.replace("_", "\ ") + r"}"}
_p_label: r"-\log\pi_\mathrm{" + prior.replace("_", r"\ ") + r"}"}
info_params_full[_chi2] = {_p_label: r"\chi^2"}
for lik in info.get(_likelihood):
info_params_full[_chi2 + _separator + lik] = {
_p_label: r"\chi^2_\mathrm{" + lik.replace("_", "\ ") + r"}"}
for like in likes:
info_params_full[_chi2 + _separator + like] = {
_p_label: r"\chi^2_\mathrm{" + like.replace("_", r"\ ") + r"}"}
return info_params_full


Expand Down

0 comments on commit 6422bc0

Please sign in to comment.