Skip to content

Commit

Permalink
refine plot_1d with legend handler
Browse files Browse the repository at this point in the history
  • Loading branch information
xgarrido committed Oct 12, 2023
1 parent 9b482c2 commit 7a3e0a6
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions cobaya_utilities/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,23 +164,36 @@ def triangle_plot(*args, **kwargs):

def plots_1d(*args, **kwargs):
"""Overloaded plots_1d function with additional features"""
default_plotter_options = {"width_inch": 20}
default_plotter_options = {"width_inch": 4}
plotter_kwargs = {k: kwargs.get(k, v) for k, v in default_plotter_options.items()}

g = get_single_plotter(settings=get_default_settings(), **plotter_kwargs)
if legend_kwargs := kwargs.get("legend_kwargs"):
legend_labels = kwargs.get("legend_labels")
kwargs.update(dict(legend_labels=[]))

g = get_subplot_plotter(settings=get_default_settings(), **plotter_kwargs)
g.plots_1d(*args, **kwargs)

if kwargs.get("despine", True):
despine(g)

if legend_kwargs:
g.add_legend(legend_labels, **legend_kwargs)

return g


def plots_2d(*args, **kwargs):
"""Overloaded plots_2d function with additional features"""
default_plotter_options = {"width_inch": 20}
default_plotter_options = {"width_inch": 4}
plotter_kwargs = {k: kwargs.get(k, v) for k, v in default_plotter_options.items()}

g = get_subplot_plotter(settings=get_default_settings(), **plotter_kwargs)
g.plots_2d(*args, **kwargs)

if kwargs.get("despine", True):
despine(g)

return g


Expand Down

0 comments on commit 7a3e0a6

Please sign in to comment.