From 1d2b226accf586323ddbba320310f60492e9c416 Mon Sep 17 00:00:00 2001 From: Antony Lewis Date: Fri, 12 Jul 2019 14:34:43 +0100 Subject: [PATCH] changes for Cobaya 2.0 compatibility --- getdist/chains.py | 23 +++-- ...ml_format_tools.py => cobaya_interface.py} | 81 ++--------------- getdist/gui/mainwindow.py | 8 +- getdist/mcsamples.py | 57 ++++++++---- getdist/paramnames.py | 7 +- getdist/parampriors.py | 4 +- getdist/plots.py | 11 ++- getdist/yaml_tools.py | 86 +++++++++++++++++++ 8 files changed, 159 insertions(+), 118 deletions(-) rename getdist/{yaml_format_tools.py => cobaya_interface.py} (60%) create mode 100644 getdist/yaml_tools.py diff --git a/getdist/chains.py b/getdist/chains.py index 01b0b09..0afbe43 100644 --- a/getdist/chains.py +++ b/getdist/chains.py @@ -40,7 +40,8 @@ def slice_or_none(x, start=None, end=None): return getattr(x, "__getitem__", lambda _: None)(slice(start, end)) -def chainFiles(root, chain_indices=None, ext='.txt', first_chain=0, last_chain=-1, chain_exclude=None): +def chainFiles(root, chain_indices=None, ext='.txt', separator="_", + first_chain=0, last_chain=-1, chain_exclude=None): """ Creates a list of file names for samples given a root name and optional filters @@ -59,8 +60,8 @@ def chainFiles(root, chain_indices=None, ext='.txt', first_chain=0, last_chain=- fname = root if index > 0: # deal with just-folder prefix - if not root.endswith("/"): - fname += '_' + if not root.endswith((os.sep, "/")): + fname += separator fname += str(index) if not fname.endswith(ext): fname += ext if index > first_chain and not os.path.exists(fname) or 0 < last_chain < index: break @@ -849,17 +850,22 @@ def __init__(self, root=None, jobItem=None, paramNamesFile=None, names=None, lab :param kwargs: extra options for :class:`~.chains.WeightedSamples`'s constructor """ + from getdist.cobaya_interface import get_sampler_type, _separator_files + self.chains = None WeightedSamples.__init__(self, **kwargs) self.jobItem = jobItem self.ignore_lines = float(kwargs.get('ignore_rows', 0)) self.root = root if not paramNamesFile and root: - mid = ('' if root.endswith("/") else "__") - if os.path.exists(root + '.paramnames'): - paramNamesFile = root + '.paramnames' - elif os.path.exists(root + mid + 'full.yaml'): - paramNamesFile = root + mid + 'full.yaml' + mid = not root.endswith((os.sep, "/")) + endings = ['.paramnames', ('__' if mid else '') + 'full.yaml', + (_separator_files if mid else '') + 'updated.yaml'] + try: + paramNamesFile = next( + root + ending for ending in endings if os.path.exists(root + ending)) + except StopIteration: + paramNamesFile = None self.setParamNames(paramNamesFile or names) if labels is not None: self.paramNames.setLabels(labels) @@ -871,7 +877,6 @@ def __init__(self, root=None, jobItem=None, paramNamesFile=None, names=None, lab raise ValueError("Unknown sampler type %s" % sampler) self.sampler = sampler.lower() elif isinstance(paramNamesFile, six.string_types) and paramNamesFile.endswith("yaml"): - from getdist.yaml_format_tools import get_sampler_type self.sampler = get_sampler_type(paramNamesFile) else: self.sampler = "mcmc" diff --git a/getdist/yaml_format_tools.py b/getdist/cobaya_interface.py similarity index 60% rename from getdist/yaml_format_tools.py rename to getdist/cobaya_interface.py index d7addf0..84c879b 100644 --- a/getdist/yaml_format_tools.py +++ b/getdist/cobaya_interface.py @@ -1,13 +1,12 @@ -# JT 2017-18 +# JT 2017-19 from __future__ import division from importlib import import_module from six import string_types from copy import deepcopy -import re from collections import OrderedDict as odict import numpy as np -import yaml + # Conventions _prior = "prior" @@ -21,6 +20,7 @@ _p_derived = "derived" _p_renames = "renames" _separator = "__" +_separator_files = "." _minuslogprior = "minuslogprior" _prior_1d_name = "0" _chi2 = "chi2" @@ -29,78 +29,6 @@ _post = "post" -# Exceptions -class InputSyntaxError(Exception): - """Syntax error in YAML input.""" - - -# Better loader for YAML -# 1. Matches 1e2 as 100 (no need for dot, or sign after e), -# from http://stackoverflow.com/a/30462009 -# 2. Wrapper to load mappings as OrderedDict (for likelihoods and params), -# from http://stackoverflow.com/a/21912744 -def yaml_load(text_stream, Loader=yaml.Loader, object_pairs_hook=odict, file_name=None): - class OrderedLoader(Loader): - pass - - def construct_mapping(loader, node): - loader.flatten_mapping(node) - return object_pairs_hook(loader.construct_pairs(node)) - - OrderedLoader.add_constructor( - yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping) - OrderedLoader.add_implicit_resolver( - u'tag:yaml.org,2002:float', - re.compile(u'''^(?: - [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? - |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) - |\\.[0-9_]+(?:[eE][-+][0-9]+)? - |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* - |[-+]?\\.(?:inf|Inf|INF) - |\\.(?:nan|NaN|NAN))$''', re.X), - list(u'-+0123456789.')) - - # Ignore python objects - def dummy_object_loader(loader, suffix, node): - return None - - OrderedLoader.add_multi_constructor( - u'tag:yaml.org,2002:python/name:', dummy_object_loader) - try: - return yaml.load(text_stream, OrderedLoader) - # Redefining the general exception to give more user-friendly information - except yaml.YAMLError as exception: - errstr = "Error in your input file " + ("'" + file_name + "'" if file_name else "") - if hasattr(exception, "problem_mark"): - line = 1 + exception.problem_mark.line - column = 1 + exception.problem_mark.column - signal = " --> " - signal_right = " <---- " - sep = "|" - context = 4 - lines = text_stream.split("\n") - pre = ((("\n" + " " * len(signal) + sep).join( - [""] + lines[max(line - 1 - context, 0):line - 1]))) + "\n" - errorline = (signal + sep + lines[line - 1] + - signal_right + "column %s" % column) - post = ((("\n" + " " * len(signal) + sep).join( - [""] + lines[line + 1 - 1:min(line + 1 + context - 1, len(lines))]))) + "\n" - raise InputSyntaxError( - errstr + " at line %d, column %d." % (line, column) + - pre + errorline + post + - "Maybe inconsistent indentation, '=' instead of ':', " - "no space after ':', or a missing ':' on an empty group?") - else: - raise InputSyntaxError(errstr) - - -def yaml_load_file(input_file): - """Wrapper to load a yaml file.""" - with open(input_file, "r") as f: - lines = "".join(f.readlines()) - return yaml_load(lines, file_name=input_file) - - def get_info_params(info): """ Extracts parameter info from the new yaml format. @@ -207,7 +135,8 @@ def expand_info_param(info_param): def get_sampler_type(filename_or_info): if isinstance(filename_or_info, string_types): + from getdist.yaml_tools import yaml_load_file filename_or_info = yaml_load_file(filename_or_info) default_sampler_for_chain_type = "mcmc" sampler = list(filename_or_info.get(_sampler, [default_sampler_for_chain_type]))[0] - return {"mcmc": "mcmc", "polychord": "nested"}[sampler] + return {"mcmc": "mcmc", "polychord": "nested", "minimize": "minimize"}[sampler] diff --git a/getdist/gui/mainwindow.py b/getdist/gui/mainwindow.py index cd4d87b..d846c61 100644 --- a/getdist/gui/mainwindow.py +++ b/getdist/gui/mainwindow.py @@ -1006,8 +1006,8 @@ def _updateComboBoxRootname(self, listOfRoots): self.comboBoxRootname.clear() self.listRoots.show() self.pushButtonRemove.show() - baseRoots = [(os.path.basename(root) if not root.endswith("/") - else os.path.basename(root[:-1]) + "/") + baseRoots = [(os.path.basename(root) if not root.endswith((os.sep, "/")) + else os.path.basename(root[:-1]) + os.sep) for root in listOfRoots] self.comboBoxRootname.addItems(baseRoots) if len(baseRoots) > 1: @@ -1039,8 +1039,8 @@ def newRootItem(self, root): else: path = self.rootdirname # new style, if the prefix is just a folder - if root[-1] == "/": - path = "/".join(path.split("/")[:-1]) + if root[-1] in (os.sep, "/"): + path = os.sep.join(path.replace("/", os.sep).split(os.sep)[:-1]) info = plots.RootInfo(root, path, self.batch) plotter.sampleAnalyser.addRoot(info) diff --git a/getdist/mcsamples.py b/getdist/mcsamples.py index a3faf30..09304cc 100644 --- a/getdist/mcsamples.py +++ b/getdist/mcsamples.py @@ -66,6 +66,8 @@ def loadMCSamples(file_root, ini=None, jobItem=None, no_cache=False, settings={} if settings and dist_settings: raise ValueError('Use settings or dist_settings') if dist_settings: settings = dist_settings files = chainFiles(file_root) + if not files: # try new Cobaya format + files = chainFiles(file_root, separator='.') path, name = os.path.split(file_root) path = getdist.cache_dir or path if not os.path.exists(path): os.mkdir(path) @@ -73,9 +75,13 @@ def loadMCSamples(file_root, ini=None, jobItem=None, no_cache=False, settings={} samples = MCSamples(file_root, jobItem=jobItem, ini=ini, settings=settings) if os.path.isfile(file_root + '.paramnames'): allfiles = files + [file_root + '.ranges', file_root + '.paramnames', file_root + '.properties.ini'] - else: # new format (txt+yaml) - mid = "" if file_root.endswith("/") else "__" - allfiles = files + [file_root + mid + ending for ending in ['input.yaml', 'full.yaml']] + else: # Cobaya + folder = os.path.dirname(file_root) + prefix = os.path.basename(file_root) + allfiles = files + [ + os.path.join(folder, f) for f in os.listdir(folder) if ( + f.startswith(prefix) and + any([f.lower().endswith(end) for end in ['updated.yaml', 'full.yaml']]))] if not no_cache and os.path.exists(cachefile) and lastModified(allfiles) < os.path.getmtime(cachefile): try: with open(cachefile, 'rb') as inp: @@ -95,12 +101,18 @@ def loadMCSamples(file_root, ini=None, jobItem=None, no_cache=False, settings={} return samples -def loadCobayaSamples(info, collections, name_tag=None, - ignore_rows=0, ini=None, settings={}): +def loadCobayaSamples(*args, **kwargs): + logging.warning("'loadCobayaSamples' will be deprecated in the future. " + "Use 'MCSamplesFromCobaya' instead.") + return MCSamplesFromCobaya(*args, **kwargs) + + +def MCSamplesFromCobaya(info, collections, name_tag=None, + ignore_rows=0, ini=None, settings={}): """ - Loads a set of samples from Cobaya's output. + Creates a set of samples from Cobaya's output. Parameter names, ranges and labels are taken from the "info" dictionary - (always use the "full", updated one generated by `cobaya.run`). + (always use the "updated" one generated by `cobaya.run`). For a description of the various analysis settings and default values see `analysis_defaults.ini `_. @@ -114,6 +126,10 @@ def loadCobayaSamples(info, collections, name_tag=None, :param settings: dictionary of analysis settings to override defaults :return: The :class:`MCSamples` instance """ + from getdist.cobaya_interface import _p_label, _p_renames, _weight, _minuslogpost + from getdist.cobaya_interface import get_info_params, get_range, is_derived_param + from getdist.cobaya_interface import get_sampler_type, _post + if not hasattr(info, "keys"): raise TypeError("Cannot regonise arguments. Are you sure you are calling " "with (info, collections, ...) in that order?") @@ -127,9 +143,6 @@ def loadCobayaSamples(info, collections, name_tag=None, "The second argument does not appear to be a (list of) samples `Collection`.") if not all([list(c.data) == columns for c in collections[1:]]): raise ValueError("The given collections don't have the same columns.") - from getdist.yaml_format_tools import _p_label, _p_renames, _weight, _minuslogpost - from getdist.yaml_format_tools import get_info_params, get_range, is_derived_param - from getdist.yaml_format_tools import get_sampler_type, _post # Check consistency with info info_params = get_info_params(info) # #################################################################################### @@ -139,8 +152,8 @@ def loadCobayaSamples(info, collections, name_tag=None, thin = info.get(_post, {}).get("thin", 1) # Maybe warn if trying to ignore rows twice? if ignore_rows != 0 and skip != 0: - logging.warn("You are asking for rows to be ignored (%r), but some (%r) were " - "already ignored in the original chain.", ignore_rows, skip) + logging.warning("You are asking for rows to be ignored (%r), but some (%r) were " + "already ignored in the original chain.", ignore_rows, skip) # Should we warn about thin too? # Most importantly: do we want to save somewhere the fact that we have *already* # thinned/skipped? @@ -182,7 +195,8 @@ class MCSamples(Chains): """ The main high-level class for a collection of parameter samples. - Derives from :class:`.chains.Chains`, adding high-level functions including Kernel Density estimates, parameter ranges and custom settings. + Derives from :class:`.chains.Chains`, adding high-level functions including + Kernel Density estimates, parameter ranges and custom settings. """ def __init__(self, root=None, jobItem=None, ini=None, settings=None, ranges=None, @@ -2083,10 +2097,15 @@ def _setLikeStats(self): def _readRanges(self): if self.root: + from getdist.cobaya_interface import _separator_files ranges_file_classic = self.root + '.ranges' - ranges_file_new = ( - self.root + ('' if self.root.endswith('/') else '__') + 'full.yaml') - for ranges_file in [ranges_file_classic, ranges_file_new]: + ranges_file_cobaya_old = ( + self.root + ('' if self.root.endswith((os.sep, "/")) else '__') + 'full.yaml') + ranges_file_cobaya = ( + self.root + ( + '' if self.root.endswith((os.sep, "/")) else _separator_files) + 'updated.yaml') + for ranges_file in [ + ranges_file_classic, ranges_file_cobaya_old, ranges_file_cobaya]: if os.path.isfile(ranges_file): self.ranges = ParamBounds(ranges_file) return @@ -2553,9 +2572,9 @@ def GetChainRootFiles(rootdir): """ pattern = os.path.join(rootdir, '*.paramnames') files = [os.path.splitext(f)[0] for f in glob.glob(pattern)] - ending = 'full.yaml' - pattern = os.path.join(rootdir, "*" + ending) - files += [f[:-len(ending)].rstrip("_") for f in glob.glob(pattern)] + for ending in ['full.yaml', 'updated.yaml']: + pattern = os.path.join(rootdir, "*" + ending) + files += [f[:-len(ending)].rstrip("_.") for f in glob.glob(pattern)] files.sort() return files diff --git a/getdist/paramnames.py b/getdist/paramnames.py index df974a8..585e48d 100644 --- a/getdist/paramnames.py +++ b/getdist/paramnames.py @@ -378,10 +378,9 @@ def loadFromFile(self, fileName): with open(fileName) as f: self.names = [ParamInfo(line) for line in [s.strip() for s in f] if line != ''] elif extension.lower() in ('.yaml', '.yml'): - from getdist.yaml_format_tools import yaml_load_file, get_info_params - from getdist.yaml_format_tools import is_sampled_param, is_derived_param - from getdist.yaml_format_tools import _p_label, _p_renames - + from getdist.yaml_tools import yaml_load_file + from getdist.cobaya_interface import get_info_params, is_sampled_param + from getdist.cobaya_interface import is_derived_param, _p_label, _p_renames info_params = get_info_params(yaml_load_file(fileName)) # first sampled, then derived self.names = [ParamInfo(name=param, label=(info or {}).get(_p_label, param), diff --git a/getdist/parampriors.py b/getdist/parampriors.py index f3cc9a0..c6e078d 100644 --- a/getdist/parampriors.py +++ b/getdist/parampriors.py @@ -31,8 +31,8 @@ def loadFromFile(self, fileName): if len(strings) == 3: self.setRange(strings[0], strings[1:]) elif extension in ('.yaml', '.yml'): - from getdist.yaml_format_tools import yaml_load_file, get_info_params - from getdist.yaml_format_tools import get_range, is_fixed_param + from getdist.cobaya_interface import get_range, is_fixed_param, get_info_params + from getdist.yaml_tools import yaml_load_file info_params = get_info_params(yaml_load_file(fileName)) for p, info in info_params.items(): if not is_fixed_param(info): diff --git a/getdist/plots.py b/getdist/plots.py index 440e5b9..6747760 100644 --- a/getdist/plots.py +++ b/getdist/plots.py @@ -14,6 +14,7 @@ from paramgrid import gridconfig, batchjob import getdist from getdist import MCSamples, loadMCSamples, ParamNames, ParamInfo, IniFile +from getdist.chains import chainFiles from getdist.paramnames import escapeLatex, makeList, mergeRenames from getdist.parampriors import ParamBounds from getdist.densities import Density1D, Density2D @@ -446,8 +447,8 @@ def samplesForRoot(self, root, file_root=None, cache=True, settings=None): if isinstance(root, MCSamples): return root if os.path.isabs(root): # deal with just-folder prefix - if root.endswith("/"): - root = os.path.basename(root[:-1]) + "/" + if root.endswith((os.sep, "/")): + root = os.path.basename(root[:-1]) + os.sep else: root = os.path.basename(root) if root in self.mcsamples and cache: return self.mcsamples[root] @@ -457,6 +458,7 @@ def samplesForRoot(self, root, file_root=None, cache=True, settings=None): else: dist_settings = {} if not file_root: + from getdist.cobaya_interface import _separator_files for chain_dir in self.chain_dirs: if hasattr(chain_dir, "resolveRoot"): jobItem = chain_dir.resolveRoot(root) @@ -468,7 +470,8 @@ def samplesForRoot(self, root, file_root=None, cache=True, settings=None): break else: name = os.path.join(chain_dir, root) - if os.path.exists(name + '_1.txt') or os.path.exists(name + '.txt'): + if any([chainFiles(name, separator=sep) + for sep in ['_', _separator_files]]): file_root = name break if not file_root: @@ -1944,7 +1947,7 @@ def triangle_plot(self, roots, params=None, legend_labels=None, plot_3d_with_par :param title_limit:if not None, a maginalized limit (1,2..) to print as the title of the first root on the diagonal 1D plots :param upper_kwargs: dict for same-named arguments for use when making upper-triangle 2D plots (contour_colors, etc). Set show_1d=False to not add to the diagonal. :param diag1d_kwargs: list of dict for arguments when making 1D plots on grid diagonal - :param markers: optional dict giving marker values indexed by parameter, or a list of marker values for each parameter plotted + :param markers: optional dict giving marker values indexed by parameter, or a list of marker values for each parameter plotted :param param_limits: a dictionary holding a mapping from parameter names to axis limits for that parameter :param kwargs: optional keyword arguments for :func:`~GetDistPlotter.plot_2d` or :func:`~GetDistPlotter.plot_3d` (lower triangle only) diff --git a/getdist/yaml_tools.py b/getdist/yaml_tools.py new file mode 100644 index 0000000..2907d5a --- /dev/null +++ b/getdist/yaml_tools.py @@ -0,0 +1,86 @@ +# JT 2017-19 + +from __future__ import division +import re +from collections import OrderedDict as odict +import six + +if six.PY2: + ModuleNotFoundError = ImportError +try: + import yaml +except ModuleNotFoundError: + raise ModuleNotFoundError( + "You need to install 'PyYAML' in order to load Cobaya samples.") + + +# Exceptions +class InputSyntaxError(Exception): + """Syntax error in YAML input.""" + + +# Better loader for YAML +# 1. Matches 1e2 as 100 (no need for dot, or sign after e), +# from http://stackoverflow.com/a/30462009 +# 2. Wrapper to load mappings as OrderedDict (for likelihoods and params), +# from http://stackoverflow.com/a/21912744 +def yaml_load(text_stream, Loader=yaml.Loader, object_pairs_hook=odict, file_name=None): + class OrderedLoader(Loader): + pass + + def construct_mapping(loader, node): + loader.flatten_mapping(node) + return object_pairs_hook(loader.construct_pairs(node)) + + OrderedLoader.add_constructor( + yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG, construct_mapping) + OrderedLoader.add_implicit_resolver( + u'tag:yaml.org,2002:float', + re.compile(u'''^(?: + [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? + |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) + |\\.[0-9_]+(?:[eE][-+][0-9]+)? + |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* + |[-+]?\\.(?:inf|Inf|INF) + |\\.(?:nan|NaN|NAN))$''', re.X), + list(u'-+0123456789.')) + + # Ignore python objects + def dummy_object_loader(loader, suffix, node): + return None + + OrderedLoader.add_multi_constructor( + u'tag:yaml.org,2002:python/name:', dummy_object_loader) + try: + return yaml.load(text_stream, OrderedLoader) + # Redefining the general exception to give more user-friendly information + except yaml.YAMLError as exception: + errstr = "Error in your input file " + ("'" + file_name + "'" if file_name else "") + if hasattr(exception, "problem_mark"): + line = 1 + exception.problem_mark.line + column = 1 + exception.problem_mark.column + signal = " --> " + signal_right = " <---- " + sep = "|" + context = 4 + lines = text_stream.split("\n") + pre = ((("\n" + " " * len(signal) + sep).join( + [""] + lines[max(line - 1 - context, 0):line - 1]))) + "\n" + errorline = (signal + sep + lines[line - 1] + + signal_right + "column %s" % column) + post = ((("\n" + " " * len(signal) + sep).join( + [""] + lines[line + 1 - 1:min(line + 1 + context - 1, len(lines))]))) + "\n" + raise InputSyntaxError( + errstr + " at line %d, column %d." % (line, column) + + pre + errorline + post + + "Maybe inconsistent indentation, '=' instead of ':', " + "no space after ':', or a missing ':' on an empty group?") + else: + raise InputSyntaxError(errstr) + + +def yaml_load_file(input_file): + """Wrapper to load a yaml file.""" + with open(input_file, "r") as f: + lines = "".join(f.readlines()) + return yaml_load(lines, file_name=input_file)