diff --git a/pyglotaran_extras/plotting/plot_doas.py b/pyglotaran_extras/plotting/plot_doas.py index cf75dbdc..477779b5 100644 --- a/pyglotaran_extras/plotting/plot_doas.py +++ b/pyglotaran_extras/plotting/plot_doas.py @@ -37,6 +37,7 @@ def plot_doas( cycler: Cycler | None = PlotStyle().cycler, oscillation_type: Literal["cos", "sin"] = "cos", title: str | None = "Damped oscillations", + legend_format_string: str = "{label}: v={frequency:.0f}, Γ={rate:.1f}", ) -> tuple[Figure, Axes]: """Plot DOAS (Damped Oscillation) related data of the optimization result. @@ -74,6 +75,11 @@ def plot_doas( Type of the oscillation to show in the oscillation plot. Defaults to "cos" title: str | None Title of the figure. Defaults to "Damped oscillations" + legend_format_string: str + Format string for each entry in the legend of the oscillation plot. Possible values which + can be replaced are ``label`` (label of the oscillation in the model definition), + ``frequency`` and ``rate``. Use ``""`` to remove the legend. Defaults to + ``"{label}: v={frequency:.0f}, Γ={rate:.1f}"`` Returns ------- @@ -115,10 +121,22 @@ def plot_doas( norm_factor = scales.max() if normalize is True else 1 - ((oscillations - 1) / osc_max * scales * norm_factor).sel(**time_sel_kwargs).plot.line( - x="time", ax=axes[0] + oscillations_to_plot = ((oscillations - 1) / osc_max * scales * norm_factor).sel( + **time_sel_kwargs ) + for oscillation_label in oscillations_to_plot.damped_oscillation.values: + oscillation = oscillations_to_plot.sel(damped_oscillation=[oscillation_label]) + frequency = oscillation.damped_oscillation_frequency.item() + rate = oscillation.damped_oscillation_rate.item() + oscillation.plot.line( + x="time", + ax=axes[0], + label=legend_format_string.format( + label=oscillation_label, frequency=frequency, rate=rate + ), + ) + (oscillations_spectra / spectra_max * scales / norm_factor).plot.line(x="spectral", ax=axes[1]) damped_oscillation_phase.plot.line(x="spectral", ax=axes[2]) @@ -134,6 +152,11 @@ def plot_doas( ) axes[2].set_ylabel("Phase (π)") + if not legend_format_string: + axes[0].get_legend().remove() + else: + axes[0].legend() + axes[1].get_legend().remove() axes[2].get_legend().remove()