From f7ac50e4513795e246b8d66cdd9bec1085b53d65 Mon Sep 17 00:00:00 2001 From: George Dang <53052793+gtdang@users.noreply.github.com> Date: Wed, 21 Aug 2024 16:10:59 -0400 Subject: [PATCH] fix: added a function to check for existing averaged dipoles -The GUI can average and append the average to the dipole list at an earlier stage than plotting with "data to compare" specified. This solution checks for an existing average, if not it will return an average. -This fixes an issue where the average_dipoles function was throwing an error because it was passed a list with an averaged dipole already. --- hnn_core/gui/_viz_manager.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/hnn_core/gui/_viz_manager.py b/hnn_core/gui/_viz_manager.py index 95a7cd756..42848b7e6 100644 --- a/hnn_core/gui/_viz_manager.py +++ b/hnn_core/gui/_viz_manager.py @@ -343,6 +343,17 @@ def _dynamic_rerender(fig): fig.tight_layout() +def _avg_dipole_check(dpls): + """Check for averaged dipole, else average the trials""" + # Check if there is an averaged dipole already + avg_dpls = [d for d in dpls if d.nave > 1] + if avg_dpls: + dpl = avg_dpls[0] + else: + dpl = average_dipoles(dpls) + return dpl + + def _plot_on_axes(b, simulations_widget, widgets_plot_type, data_widget, spectrogram_colormap_selection, max_spectral_frequency, @@ -427,7 +438,7 @@ def _plot_on_axes(b, simulations_widget, widgets_plot_type, t0 = 0.0 tstop = dpls_processed[-1].times[-1] if len(dpls_processed) > 1: - dpl = average_dipoles(dpls_processed) + dpl = _avg_dipole_check(dpls_processed) else: dpl = dpls_processed rmse = _rmse(dpl, target_dpl_processed, t0, tstop)