Skip to content

Commit

Permalink
changes for Cobaya 2.0 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
cmbant committed Jul 12, 2019
1 parent 0ebc06f commit 1d2b226
Show file tree
Hide file tree
Showing 8 changed files with 159 additions and 118 deletions.
23 changes: 14 additions & 9 deletions getdist/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ def slice_or_none(x, start=None, end=None):
return getattr(x, "__getitem__", lambda _: None)(slice(start, end))


def chainFiles(root, chain_indices=None, ext='.txt', first_chain=0, last_chain=-1, chain_exclude=None):
def chainFiles(root, chain_indices=None, ext='.txt', separator="_",
first_chain=0, last_chain=-1, chain_exclude=None):
"""
Creates a list of file names for samples given a root name and optional filters
Expand All @@ -59,8 +60,8 @@ def chainFiles(root, chain_indices=None, ext='.txt', first_chain=0, last_chain=-
fname = root
if index > 0:
# deal with just-folder prefix
if not root.endswith("/"):
fname += '_'
if not root.endswith((os.sep, "/")):
fname += separator
fname += str(index)
if not fname.endswith(ext): fname += ext
if index > first_chain and not os.path.exists(fname) or 0 < last_chain < index: break
Expand Down Expand Up @@ -849,17 +850,22 @@ def __init__(self, root=None, jobItem=None, paramNamesFile=None, names=None, lab
:param kwargs: extra options for :class:`~.chains.WeightedSamples`'s constructor
"""
from getdist.cobaya_interface import get_sampler_type, _separator_files

self.chains = None
WeightedSamples.__init__(self, **kwargs)
self.jobItem = jobItem
self.ignore_lines = float(kwargs.get('ignore_rows', 0))
self.root = root
if not paramNamesFile and root:
mid = ('' if root.endswith("/") else "__")
if os.path.exists(root + '.paramnames'):
paramNamesFile = root + '.paramnames'
elif os.path.exists(root + mid + 'full.yaml'):
paramNamesFile = root + mid + 'full.yaml'
mid = not root.endswith((os.sep, "/"))
endings = ['.paramnames', ('__' if mid else '') + 'full.yaml',
(_separator_files if mid else '') + 'updated.yaml']
try:
paramNamesFile = next(
root + ending for ending in endings if os.path.exists(root + ending))
except StopIteration:
paramNamesFile = None
self.setParamNames(paramNamesFile or names)
if labels is not None:
self.paramNames.setLabels(labels)
Expand All @@ -871,7 +877,6 @@ def __init__(self, root=None, jobItem=None, paramNamesFile=None, names=None, lab
raise ValueError("Unknown sampler type %s" % sampler)
self.sampler = sampler.lower()
elif isinstance(paramNamesFile, six.string_types) and paramNamesFile.endswith("yaml"):
from getdist.yaml_format_tools import get_sampler_type
self.sampler = get_sampler_type(paramNamesFile)
else:
self.sampler = "mcmc"
Expand Down
81 changes: 5 additions & 76 deletions getdist/yaml_format_tools.py → getdist/cobaya_interface.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
# JT 2017-18
# JT 2017-19

from __future__ import division
from importlib import import_module
from six import string_types
from copy import deepcopy
import re
from collections import OrderedDict as odict
import numpy as np
import yaml


# Conventions
_prior = "prior"
Expand All @@ -21,6 +20,7 @@
_p_derived = "derived"
_p_renames = "renames"
_separator = "__"
_separator_files = "."
_minuslogprior = "minuslogprior"
_prior_1d_name = "0"
_chi2 = "chi2"
Expand All @@ -29,78 +29,6 @@
_post = "post"


# Exceptions
class InputSyntaxError(Exception):
"""Syntax error in YAML input."""


# Better loader for YAML
# 1. Matches 1e2 as 100 (no need for dot, or sign after e),
# from http://stackoverflow.com/a/30462009
# 2. Wrapper to load mappings as OrderedDict (for likelihoods and params),
# from http://stackoverflow.com/a/21912744
def yaml_load(text_stream, Loader=yaml.Loader, object_pairs_hook=odict, file_name=None):
class OrderedLoader(Loader):
pass

def construct_mapping(loader, node):
loader.flatten_mapping(node)
return object_pairs_hook(loader.construct_pairs(node))

OrderedLoader.add_constructor(
yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping)
OrderedLoader.add_implicit_resolver(
u'tag:yaml.org,2002:float',
re.compile(u'''^(?:
[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
|[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
|\\.[0-9_]+(?:[eE][-+][0-9]+)?
|[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
|[-+]?\\.(?:inf|Inf|INF)
|\\.(?:nan|NaN|NAN))$''', re.X),
list(u'-+0123456789.'))

# Ignore python objects
def dummy_object_loader(loader, suffix, node):
return None

OrderedLoader.add_multi_constructor(
u'tag:yaml.org,2002:python/name:', dummy_object_loader)
try:
return yaml.load(text_stream, OrderedLoader)
# Redefining the general exception to give more user-friendly information
except yaml.YAMLError as exception:
errstr = "Error in your input file " + ("'" + file_name + "'" if file_name else "")
if hasattr(exception, "problem_mark"):
line = 1 + exception.problem_mark.line
column = 1 + exception.problem_mark.column
signal = " --> "
signal_right = " <---- "
sep = "|"
context = 4
lines = text_stream.split("\n")
pre = ((("\n" + " " * len(signal) + sep).join(
[""] + lines[max(line - 1 - context, 0):line - 1]))) + "\n"
errorline = (signal + sep + lines[line - 1] +
signal_right + "column %s" % column)
post = ((("\n" + " " * len(signal) + sep).join(
[""] + lines[line + 1 - 1:min(line + 1 + context - 1, len(lines))]))) + "\n"
raise InputSyntaxError(
errstr + " at line %d, column %d." % (line, column) +
pre + errorline + post +
"Maybe inconsistent indentation, '=' instead of ':', "
"no space after ':', or a missing ':' on an empty group?")
else:
raise InputSyntaxError(errstr)


def yaml_load_file(input_file):
"""Wrapper to load a yaml file."""
with open(input_file, "r") as f:
lines = "".join(f.readlines())
return yaml_load(lines, file_name=input_file)


def get_info_params(info):
"""
Extracts parameter info from the new yaml format.
Expand Down Expand Up @@ -207,7 +135,8 @@ def expand_info_param(info_param):

def get_sampler_type(filename_or_info):
if isinstance(filename_or_info, string_types):
from getdist.yaml_tools import yaml_load_file
filename_or_info = yaml_load_file(filename_or_info)
default_sampler_for_chain_type = "mcmc"
sampler = list(filename_or_info.get(_sampler, [default_sampler_for_chain_type]))[0]
return {"mcmc": "mcmc", "polychord": "nested"}[sampler]
return {"mcmc": "mcmc", "polychord": "nested", "minimize": "minimize"}[sampler]
8 changes: 4 additions & 4 deletions getdist/gui/mainwindow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,8 +1006,8 @@ def _updateComboBoxRootname(self, listOfRoots):
self.comboBoxRootname.clear()
self.listRoots.show()
self.pushButtonRemove.show()
baseRoots = [(os.path.basename(root) if not root.endswith("/")
else os.path.basename(root[:-1]) + "/")
baseRoots = [(os.path.basename(root) if not root.endswith((os.sep, "/"))
else os.path.basename(root[:-1]) + os.sep)
for root in listOfRoots]
self.comboBoxRootname.addItems(baseRoots)
if len(baseRoots) > 1:
Expand Down Expand Up @@ -1039,8 +1039,8 @@ def newRootItem(self, root):
else:
path = self.rootdirname
# new style, if the prefix is just a folder
if root[-1] == "/":
path = "/".join(path.split("/")[:-1])
if root[-1] in (os.sep, "/"):
path = os.sep.join(path.replace("/", os.sep).split(os.sep)[:-1])
info = plots.RootInfo(root, path, self.batch)
plotter.sampleAnalyser.addRoot(info)

Expand Down
57 changes: 38 additions & 19 deletions getdist/mcsamples.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,16 +66,22 @@ def loadMCSamples(file_root, ini=None, jobItem=None, no_cache=False, settings={}
if settings and dist_settings: raise ValueError('Use settings or dist_settings')
if dist_settings: settings = dist_settings
files = chainFiles(file_root)
if not files: # try new Cobaya format
files = chainFiles(file_root, separator='.')
path, name = os.path.split(file_root)
path = getdist.cache_dir or path
if not os.path.exists(path): os.mkdir(path)
cachefile = os.path.join(path, name) + '.py_mcsamples'
samples = MCSamples(file_root, jobItem=jobItem, ini=ini, settings=settings)
if os.path.isfile(file_root + '.paramnames'):
allfiles = files + [file_root + '.ranges', file_root + '.paramnames', file_root + '.properties.ini']
else: # new format (txt+yaml)
mid = "" if file_root.endswith("/") else "__"
allfiles = files + [file_root + mid + ending for ending in ['input.yaml', 'full.yaml']]
else: # Cobaya
folder = os.path.dirname(file_root)
prefix = os.path.basename(file_root)
allfiles = files + [
os.path.join(folder, f) for f in os.listdir(folder) if (
f.startswith(prefix) and
any([f.lower().endswith(end) for end in ['updated.yaml', 'full.yaml']]))]
if not no_cache and os.path.exists(cachefile) and lastModified(allfiles) < os.path.getmtime(cachefile):
try:
with open(cachefile, 'rb') as inp:
Expand All @@ -95,12 +101,18 @@ def loadMCSamples(file_root, ini=None, jobItem=None, no_cache=False, settings={}
return samples


def loadCobayaSamples(info, collections, name_tag=None,
ignore_rows=0, ini=None, settings={}):
def loadCobayaSamples(*args, **kwargs):
logging.warning("'loadCobayaSamples' will be deprecated in the future. "
"Use 'MCSamplesFromCobaya' instead.")
return MCSamplesFromCobaya(*args, **kwargs)


def MCSamplesFromCobaya(info, collections, name_tag=None,
ignore_rows=0, ini=None, settings={}):
"""
Loads a set of samples from Cobaya's output.
Creates a set of samples from Cobaya's output.
Parameter names, ranges and labels are taken from the "info" dictionary
(always use the "full", updated one generated by `cobaya.run`).
(always use the "updated" one generated by `cobaya.run`).
For a description of the various analysis settings and default values see
`analysis_defaults.ini <http://getdist.readthedocs.org/en/latest/analysis_settings.html>`_.
Expand All @@ -114,6 +126,10 @@ def loadCobayaSamples(info, collections, name_tag=None,
:param settings: dictionary of analysis settings to override defaults
:return: The :class:`MCSamples` instance
"""
from getdist.cobaya_interface import _p_label, _p_renames, _weight, _minuslogpost
from getdist.cobaya_interface import get_info_params, get_range, is_derived_param
from getdist.cobaya_interface import get_sampler_type, _post

if not hasattr(info, "keys"):
raise TypeError("Cannot regonise arguments. Are you sure you are calling "
"with (info, collections, ...) in that order?")
Expand All @@ -127,9 +143,6 @@ def loadCobayaSamples(info, collections, name_tag=None,
"The second argument does not appear to be a (list of) samples `Collection`.")
if not all([list(c.data) == columns for c in collections[1:]]):
raise ValueError("The given collections don't have the same columns.")
from getdist.yaml_format_tools import _p_label, _p_renames, _weight, _minuslogpost
from getdist.yaml_format_tools import get_info_params, get_range, is_derived_param
from getdist.yaml_format_tools import get_sampler_type, _post
# Check consistency with info
info_params = get_info_params(info)
# ####################################################################################
Expand All @@ -139,8 +152,8 @@ def loadCobayaSamples(info, collections, name_tag=None,
thin = info.get(_post, {}).get("thin", 1)
# Maybe warn if trying to ignore rows twice?
if ignore_rows != 0 and skip != 0:
logging.warn("You are asking for rows to be ignored (%r), but some (%r) were "
"already ignored in the original chain.", ignore_rows, skip)
logging.warning("You are asking for rows to be ignored (%r), but some (%r) were "
"already ignored in the original chain.", ignore_rows, skip)
# Should we warn about thin too?
# Most importantly: do we want to save somewhere the fact that we have *already*
# thinned/skipped?
Expand Down Expand Up @@ -182,7 +195,8 @@ class MCSamples(Chains):
"""
The main high-level class for a collection of parameter samples.
Derives from :class:`.chains.Chains`, adding high-level functions including Kernel Density estimates, parameter ranges and custom settings.
Derives from :class:`.chains.Chains`, adding high-level functions including
Kernel Density estimates, parameter ranges and custom settings.
"""

def __init__(self, root=None, jobItem=None, ini=None, settings=None, ranges=None,
Expand Down Expand Up @@ -2083,10 +2097,15 @@ def _setLikeStats(self):

def _readRanges(self):
if self.root:
from getdist.cobaya_interface import _separator_files
ranges_file_classic = self.root + '.ranges'
ranges_file_new = (
self.root + ('' if self.root.endswith('/') else '__') + 'full.yaml')
for ranges_file in [ranges_file_classic, ranges_file_new]:
ranges_file_cobaya_old = (
self.root + ('' if self.root.endswith((os.sep, "/")) else '__') + 'full.yaml')
ranges_file_cobaya = (
self.root + (
'' if self.root.endswith((os.sep, "/")) else _separator_files) + 'updated.yaml')
for ranges_file in [
ranges_file_classic, ranges_file_cobaya_old, ranges_file_cobaya]:
if os.path.isfile(ranges_file):
self.ranges = ParamBounds(ranges_file)
return
Expand Down Expand Up @@ -2553,9 +2572,9 @@ def GetChainRootFiles(rootdir):
"""
pattern = os.path.join(rootdir, '*.paramnames')
files = [os.path.splitext(f)[0] for f in glob.glob(pattern)]
ending = 'full.yaml'
pattern = os.path.join(rootdir, "*" + ending)
files += [f[:-len(ending)].rstrip("_") for f in glob.glob(pattern)]
for ending in ['full.yaml', 'updated.yaml']:
pattern = os.path.join(rootdir, "*" + ending)
files += [f[:-len(ending)].rstrip("_.") for f in glob.glob(pattern)]
files.sort()
return files

Expand Down
7 changes: 3 additions & 4 deletions getdist/paramnames.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,10 +378,9 @@ def loadFromFile(self, fileName):
with open(fileName) as f:
self.names = [ParamInfo(line) for line in [s.strip() for s in f] if line != '']
elif extension.lower() in ('.yaml', '.yml'):
from getdist.yaml_format_tools import yaml_load_file, get_info_params
from getdist.yaml_format_tools import is_sampled_param, is_derived_param
from getdist.yaml_format_tools import _p_label, _p_renames

from getdist.yaml_tools import yaml_load_file
from getdist.cobaya_interface import get_info_params, is_sampled_param
from getdist.cobaya_interface import is_derived_param, _p_label, _p_renames
info_params = get_info_params(yaml_load_file(fileName))
# first sampled, then derived
self.names = [ParamInfo(name=param, label=(info or {}).get(_p_label, param),
Expand Down
4 changes: 2 additions & 2 deletions getdist/parampriors.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def loadFromFile(self, fileName):
if len(strings) == 3:
self.setRange(strings[0], strings[1:])
elif extension in ('.yaml', '.yml'):
from getdist.yaml_format_tools import yaml_load_file, get_info_params
from getdist.yaml_format_tools import get_range, is_fixed_param
from getdist.cobaya_interface import get_range, is_fixed_param, get_info_params
from getdist.yaml_tools import yaml_load_file
info_params = get_info_params(yaml_load_file(fileName))
for p, info in info_params.items():
if not is_fixed_param(info):
Expand Down
11 changes: 7 additions & 4 deletions getdist/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from paramgrid import gridconfig, batchjob
import getdist
from getdist import MCSamples, loadMCSamples, ParamNames, ParamInfo, IniFile
from getdist.chains import chainFiles
from getdist.paramnames import escapeLatex, makeList, mergeRenames
from getdist.parampriors import ParamBounds
from getdist.densities import Density1D, Density2D
Expand Down Expand Up @@ -446,8 +447,8 @@ def samplesForRoot(self, root, file_root=None, cache=True, settings=None):
if isinstance(root, MCSamples): return root
if os.path.isabs(root):
# deal with just-folder prefix
if root.endswith("/"):
root = os.path.basename(root[:-1]) + "/"
if root.endswith((os.sep, "/")):
root = os.path.basename(root[:-1]) + os.sep
else:
root = os.path.basename(root)
if root in self.mcsamples and cache: return self.mcsamples[root]
Expand All @@ -457,6 +458,7 @@ def samplesForRoot(self, root, file_root=None, cache=True, settings=None):
else:
dist_settings = {}
if not file_root:
from getdist.cobaya_interface import _separator_files
for chain_dir in self.chain_dirs:
if hasattr(chain_dir, "resolveRoot"):
jobItem = chain_dir.resolveRoot(root)
Expand All @@ -468,7 +470,8 @@ def samplesForRoot(self, root, file_root=None, cache=True, settings=None):
break
else:
name = os.path.join(chain_dir, root)
if os.path.exists(name + '_1.txt') or os.path.exists(name + '.txt'):
if any([chainFiles(name, separator=sep)
for sep in ['_', _separator_files]]):
file_root = name
break
if not file_root:
Expand Down Expand Up @@ -1944,7 +1947,7 @@ def triangle_plot(self, roots, params=None, legend_labels=None, plot_3d_with_par
:param title_limit:if not None, a maginalized limit (1,2..) to print as the title of the first root on the diagonal 1D plots
:param upper_kwargs: dict for same-named arguments for use when making upper-triangle 2D plots (contour_colors, etc). Set show_1d=False to not add to the diagonal.
:param diag1d_kwargs: list of dict for arguments when making 1D plots on grid diagonal
:param markers: optional dict giving marker values indexed by parameter, or a list of marker values for each parameter plotted
:param markers: optional dict giving marker values indexed by parameter, or a list of marker values for each parameter plotted
:param param_limits: a dictionary holding a mapping from parameter names to axis limits for that parameter
:param kwargs: optional keyword arguments for :func:`~GetDistPlotter.plot_2d` or :func:`~GetDistPlotter.plot_3d` (lower triangle only)
Expand Down
Loading

0 comments on commit 1d2b226

Please sign in to comment.