diff --git a/cobaya_utilities/plots.py b/cobaya_utilities/plots.py index c044863..02e7bd7 100644 --- a/cobaya_utilities/plots.py +++ b/cobaya_utilities/plots.py @@ -164,23 +164,36 @@ def triangle_plot(*args, **kwargs): def plots_1d(*args, **kwargs): """Overloaded plots_1d function with additional features""" - default_plotter_options = {"width_inch": 20} + default_plotter_options = {"width_inch": 4} plotter_kwargs = {k: kwargs.get(k, v) for k, v in default_plotter_options.items()} - g = get_single_plotter(settings=get_default_settings(), **plotter_kwargs) + if legend_kwargs := kwargs.get("legend_kwargs"): + legend_labels = kwargs.get("legend_labels") + kwargs.update(dict(legend_labels=[])) + + g = get_subplot_plotter(settings=get_default_settings(), **plotter_kwargs) g.plots_1d(*args, **kwargs) + if kwargs.get("despine", True): + despine(g) + + if legend_kwargs: + g.add_legend(legend_labels, **legend_kwargs) + return g def plots_2d(*args, **kwargs): """Overloaded plots_2d function with additional features""" - default_plotter_options = {"width_inch": 20} + default_plotter_options = {"width_inch": 4} plotter_kwargs = {k: kwargs.get(k, v) for k, v in default_plotter_options.items()} g = get_subplot_plotter(settings=get_default_settings(), **plotter_kwargs) g.plots_2d(*args, **kwargs) + if kwargs.get("despine", True): + despine(g) + return g