diff --git a/hnn_core/gui/_viz_manager.py b/hnn_core/gui/_viz_manager.py index a1eb83aae..afd8a0c73 100644 --- a/hnn_core/gui/_viz_manager.py +++ b/hnn_core/gui/_viz_manager.py @@ -17,7 +17,6 @@ from hnn_core.gui._logging import logger from hnn_core.viz import plot_dipole - _fig_placeholder = 'Run simulation to add figures here.' _plot_types = [ @@ -895,26 +894,6 @@ def _layout_template_change(self, template_type): # hide sim-data dropdown self.datasets_dropdown.layout.visibility = "hidden" - def save_simulation_csv(self, simulation_name): - # pre-allocate np.matrix to save - signals_matrix = (np.zeros((self.data["simulations"][simulation_name] - ['dpls'][0].data['agg'].size, 4))) - signals_matrix[:, 0] = (self.data["simulations"][simulation_name] - ['dpls'][0].times) - signals_matrix[:, 1] = (self.data["simulations"][simulation_name] - ['dpls'][0].data['agg']) - signals_matrix[:, 2] = (self.data["simulations"][simulation_name] - ['dpls'][0].data['L2']) - signals_matrix[:, 3] = (self.data["simulations"][simulation_name] - ['dpls'][0].data['L5']) - output = io.StringIO() - np.savetxt(output, signals_matrix, delimiter=',', - header='times,agg,L2,L5', fmt='%f, %f, %f, %f') - - # Get the string from StringIO - csv_string = output.getvalue() - return csv_string - @unlink_relink(attribute='figs_config_tab_link') def add_figure(self, b=None): """Add a figure and corresponding config tabs to the dashboard. diff --git a/hnn_core/gui/gui.py b/hnn_core/gui/gui.py index 32fda5c67..4e39b06a7 100644 --- a/hnn_core/gui/gui.py +++ b/hnn_core/gui/gui.py @@ -27,6 +27,8 @@ from hnn_core.params import (_extract_drive_specs_from_hnn_params, _read_json, _read_legacy_params) import base64 +import zipfile +import numpy as np class _OutputWidgetHandler(logging.Handler): @@ -240,19 +242,13 @@ def __init__(self, theme_color="#8A2BE2", b64 = base64.b64encode("".encode()) payload = b64.decode() # Initialliting HTML code for download button - self.html_download_button = ''' - - - - + self.html_download_button = ''' - + - - ''' # Create widget wrapper self.save_simuation_button = ( @@ -260,6 +256,7 @@ def __init__(self, theme_color="#8A2BE2", format(payload=payload, filename={""}, is_disabled="disabled", + btn_height=self.layout['run_btn'].height, color_theme=self.layout['theme_color']))) self.simulation_list_widget = Dropdown(options=[], @@ -438,16 +435,23 @@ def _run_button_clicked(b): self.simulation_list_widget) def _simulation_list_change(value): - csv_string = _simulation_in_csv_format(self._log_out, - self.viz_manager, - self.simulation_list_widget) - result_file = f"{value.new}.csv" - b64 = base64.b64encode(csv_string.encode()) + _simulation_data, file_extension = ( + _serialize_simulation(self._log_out, + self.data, + self.simulation_list_widget)) + + result_file = f"{value.new}{file_extension}" + if file_extension == ".csv": + b64 = base64.b64encode(_simulation_data.encode()) + else: + b64 = base64.b64encode(_simulation_data) + payload = b64.decode() self.save_simuation_button.value = ( self.html_download_button.format( payload=payload, filename=result_file, - is_disabled="", color_theme=self.layout['theme_color'])) + is_disabled="", btn_height=self.layout['run_btn'].height, + color_theme=self.layout['theme_color'])) self.widget_backend_selection.observe(_handle_backend_change, 'value') self.add_drive_button.on_click(_add_drive_button_clicked) @@ -1430,12 +1434,71 @@ def run_button_clicked(widget_simulation_name, log_out, drive_widgets, plot_type, {}, "plot") -def _simulation_in_csv_format(log_out, viz_manager, simulation_list_widget): +def _serialize_simulation(log_out, sim_data, simulation_list_widget): # Only download if there is there is at least one simulation - sin_name = simulation_list_widget.value + sim_name = simulation_list_widget.value + with log_out: - logger.info(f"Saving {sin_name}.txt") - return viz_manager.save_simulation_csv(sin_name) + logger.info(f"Saving {sim_name}.txt") + return serialize_simulation(sim_data, sim_name) + + +def serialize_simulation(simulations_data, simulation_name): + """ + Serializes the simulation data for a given simulation name + into either a single CSV file + or a ZIP file containing multiple CSVs, depending on the number + of trials in the simulation. + Parameters: + - simulation_name (str): The name of the simulation to serialize. + This name is used to access the corresponding + data in the 'simulations' dictionary + of the instance. + + Returns: + - tuple: A tuple where the first element is either + the CSV content (str) of a single trial + or the binary data of a ZIP file (bytes) + containing multiple CSV files, and the + second element is the file extension (either ".csv" or ".zip"). + """ + simulation_data = simulations_data["simulation_data"] + csv_trials_output = [] + # CSV file hearders + headers = 'times,agg,L2,L5' + fmt = '%f, %f, %f, %f' + + for dpl_trial in simulation_data[simulation_name]['dpls']: + # Combine all data columns at once + signals_matrix = np.column_stack(( + dpl_trial.times, + dpl_trial.data['agg'], + dpl_trial.data['L2'], + dpl_trial.data['L5'] + )) + + # Using StringIO to collect CSV data + with io.StringIO() as output: + np.savetxt(output, signals_matrix, delimiter=',', + header=headers, fmt=fmt) + csv_trials_output.append(output.getvalue()) + + if len(csv_trials_output) == 1: + # Return a single csv file + return csv_trials_output[0], ".csv" + else: + # Create zip file + return _create_zip(csv_trials_output, simulation_name), ".zip" + + +def _create_zip(csv_data_list, simulation_name): + # Zip all files and keep it in memory + with io.BytesIO() as zip_buffer: + with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zf: + for index, csv_data in enumerate(csv_data_list): + zf.writestr(f'{simulation_name}_{index+1}.csv', csv_data) + zip_buffer.seek(0) + return zip_buffer.read() def handle_backend_change(backend_type, backend_config, mpi_cmd, n_jobs): diff --git a/hnn_core/tests/test_gui.py b/hnn_core/tests/test_gui.py index 3b3fc902d..6f1822f89 100644 --- a/hnn_core/tests/test_gui.py +++ b/hnn_core/tests/test_gui.py @@ -10,6 +10,7 @@ from hnn_core.gui._viz_manager import (_idx2figname, _no_overlay_plot_types, unlink_relink) from hnn_core.gui.gui import _init_network_from_widgets +from hnn_core.gui.gui import serialize_simulation from hnn_core.network import pick_connection from hnn_core.network_models import jones_2009_model from hnn_core.parallel_backends import requires_mpi4py, requires_psutil @@ -563,3 +564,42 @@ def add_child_decorated(self, to_add): # Check if the widgets are relinked, the selected index should be synced gui.tab_group_1.selected_index = 0 assert gui.tab_group_2.selected_index == 0 + + +def test_gui_download_simulation(): + """Test the GUI download simulation pipeline.""" + gui = HNNGUI() + _ = gui.compose() + gui.params['N_pyr_x'] = 3 + gui.params['N_pyr_y'] = 3 + + # Run a simulation with 3 trials + gui.widget_dt.value = 0.85 + gui.widget_ntrials.value = 2 + + # Initiate 1rs simulation + sim_name = "sim1" + gui.widget_simulation_name.value = sim_name + + # Run simulation + gui.run_button.click() + + _, file_extension = ( + serialize_simulation(gui.data, sim_name)) + # result is a zip file + assert file_extension == ".zip" + + # Run a simulation with 1 trials + gui.widget_dt.value = 0.85 + gui.widget_ntrials.value = 1 + + # Initiate 2nd simulation + sim_name2 = "sim2" + gui.widget_simulation_name.value = sim_name2 + + # Run simulation + gui.run_button.click() + _, file_extension = ( + serialize_simulation(gui.data, sim_name2)) + # result is a single csv file + assert file_extension == ".csv"