Skip to content

Commit

Permalink
ENH: download multiple trials on a zip file
Browse files Browse the repository at this point in the history
TST: added download simulation test

DOC: Added docstring to serialize_simulation function

MAINT: Moved serialize simulation logic to gui.py

STY: Fixed flake8 errors in viz_manager.py

STY:Fixed flake8 errors
  • Loading branch information
kmilo9999 committed May 8, 2024
1 parent 4ded929 commit e881362
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 41 deletions.
21 changes: 0 additions & 21 deletions hnn_core/gui/_viz_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from hnn_core.gui._logging import logger
from hnn_core.viz import plot_dipole


_fig_placeholder = 'Run simulation to add figures here.'

_plot_types = [
Expand Down Expand Up @@ -895,26 +894,6 @@ def _layout_template_change(self, template_type):
# hide sim-data dropdown
self.datasets_dropdown.layout.visibility = "hidden"

def save_simulation_csv(self, simulation_name):
# pre-allocate np.matrix to save
signals_matrix = (np.zeros((self.data["simulations"][simulation_name]
['dpls'][0].data['agg'].size, 4)))
signals_matrix[:, 0] = (self.data["simulations"][simulation_name]
['dpls'][0].times)
signals_matrix[:, 1] = (self.data["simulations"][simulation_name]
['dpls'][0].data['agg'])
signals_matrix[:, 2] = (self.data["simulations"][simulation_name]
['dpls'][0].data['L2'])
signals_matrix[:, 3] = (self.data["simulations"][simulation_name]
['dpls'][0].data['L5'])
output = io.StringIO()
np.savetxt(output, signals_matrix, delimiter=',',
header='times,agg,L2,L5', fmt='%f, %f, %f, %f')

# Get the string from StringIO
csv_string = output.getvalue()
return csv_string

@unlink_relink(attribute='figs_config_tab_link')
def add_figure(self, b=None):
"""Add a figure and corresponding config tabs to the dashboard.
Expand Down
103 changes: 83 additions & 20 deletions hnn_core/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from hnn_core.params import (_extract_drive_specs_from_hnn_params, _read_json,
_read_legacy_params)
import base64
import zipfile
import numpy as np


class _OutputWidgetHandler(logging.Handler):
Expand Down Expand Up @@ -240,26 +242,21 @@ def __init__(self, theme_color="#8A2BE2",
b64 = base64.b64encode("".encode())
payload = b64.decode()
# Initialliting HTML code for download button
self.html_download_button = '''<html>
<head>
<meta name="viewport" content="width=device-width, initial-scale=1">
</head>
<body>
self.html_download_button = '''
<a download="{filename}" href="data:text/csv;base64,{payload}"
download>
<button style="background:{color_theme}"
class="p-Widget jupyter-widgets jupyter-button
widget-button mod-warning" {is_disabled} >Save Simulation</button>
<button style="background:{color_theme}; height:{btn_height}"
class=" jupyter-button
mod-warning" {is_disabled} >Save Simulation</button>
</a>
</body>
</html>
'''
# Create widget wrapper
self.save_simuation_button = (
HTML(self.html_download_button.
format(payload=payload,
filename={""},
is_disabled="disabled",
btn_height=self.layout['run_btn'].height,
color_theme=self.layout['theme_color'])))

self.simulation_list_widget = Dropdown(options=[],
Expand Down Expand Up @@ -438,16 +435,23 @@ def _run_button_clicked(b):
self.simulation_list_widget)

def _simulation_list_change(value):
csv_string = _simulation_in_csv_format(self._log_out,
self.viz_manager,
self.simulation_list_widget)
result_file = f"{value.new}.csv"
b64 = base64.b64encode(csv_string.encode())
_simulation_data, file_extension = (
_serialize_simulation(self._log_out,
self.data,
self.simulation_list_widget))

result_file = f"{value.new}{file_extension}"
if file_extension == ".csv":
b64 = base64.b64encode(_simulation_data.encode())
else:
b64 = base64.b64encode(_simulation_data)

payload = b64.decode()
self.save_simuation_button.value = (
self.html_download_button.format(
payload=payload, filename=result_file,
is_disabled="", color_theme=self.layout['theme_color']))
is_disabled="", btn_height=self.layout['run_btn'].height,
color_theme=self.layout['theme_color']))

self.widget_backend_selection.observe(_handle_backend_change, 'value')
self.add_drive_button.on_click(_add_drive_button_clicked)
Expand Down Expand Up @@ -1430,12 +1434,71 @@ def run_button_clicked(widget_simulation_name, log_out, drive_widgets,
plot_type, {}, "plot")


def _simulation_in_csv_format(log_out, viz_manager, simulation_list_widget):
def _serialize_simulation(log_out, sim_data, simulation_list_widget):
# Only download if there is there is at least one simulation
sin_name = simulation_list_widget.value
sim_name = simulation_list_widget.value

with log_out:
logger.info(f"Saving {sin_name}.txt")
return viz_manager.save_simulation_csv(sin_name)
logger.info(f"Saving {sim_name}.txt")
return serialize_simulation(sim_data, sim_name)


def serialize_simulation(simulations_data, simulation_name):
"""
Serializes the simulation data for a given simulation name
into either a single CSV file
or a ZIP file containing multiple CSVs, depending on the number
of trials in the simulation.
Parameters:
- simulation_name (str): The name of the simulation to serialize.
This name is used to access the corresponding
data in the 'simulations' dictionary
of the instance.
Returns:
- tuple: A tuple where the first element is either
the CSV content (str) of a single trial
or the binary data of a ZIP file (bytes)
containing multiple CSV files, and the
second element is the file extension (either ".csv" or ".zip").
"""
simulation_data = simulations_data["simulation_data"]
csv_trials_output = []
# CSV file hearders
headers = 'times,agg,L2,L5'
fmt = '%f, %f, %f, %f'

for dpl_trial in simulation_data[simulation_name]['dpls']:
# Combine all data columns at once
signals_matrix = np.column_stack((
dpl_trial.times,
dpl_trial.data['agg'],
dpl_trial.data['L2'],
dpl_trial.data['L5']
))

# Using StringIO to collect CSV data
with io.StringIO() as output:
np.savetxt(output, signals_matrix, delimiter=',',
header=headers, fmt=fmt)
csv_trials_output.append(output.getvalue())

if len(csv_trials_output) == 1:
# Return a single csv file
return csv_trials_output[0], ".csv"
else:
# Create zip file
return _create_zip(csv_trials_output, simulation_name), ".zip"


def _create_zip(csv_data_list, simulation_name):
# Zip all files and keep it in memory
with io.BytesIO() as zip_buffer:
with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zf:
for index, csv_data in enumerate(csv_data_list):
zf.writestr(f'{simulation_name}_{index+1}.csv', csv_data)
zip_buffer.seek(0)
return zip_buffer.read()


def handle_backend_change(backend_type, backend_config, mpi_cmd, n_jobs):
Expand Down
40 changes: 40 additions & 0 deletions hnn_core/tests/test_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from hnn_core.gui._viz_manager import (_idx2figname, _no_overlay_plot_types,
unlink_relink)
from hnn_core.gui.gui import _init_network_from_widgets
from hnn_core.gui.gui import serialize_simulation
from hnn_core.network import pick_connection
from hnn_core.network_models import jones_2009_model
from hnn_core.parallel_backends import requires_mpi4py, requires_psutil
Expand Down Expand Up @@ -563,3 +564,42 @@ def add_child_decorated(self, to_add):
# Check if the widgets are relinked, the selected index should be synced
gui.tab_group_1.selected_index = 0
assert gui.tab_group_2.selected_index == 0


def test_gui_download_simulation():
"""Test the GUI download simulation pipeline."""
gui = HNNGUI()
_ = gui.compose()
gui.params['N_pyr_x'] = 3
gui.params['N_pyr_y'] = 3

# Run a simulation with 3 trials
gui.widget_dt.value = 0.85
gui.widget_ntrials.value = 2

# Initiate 1rs simulation
sim_name = "sim1"
gui.widget_simulation_name.value = sim_name

# Run simulation
gui.run_button.click()

_, file_extension = (
serialize_simulation(gui.data, sim_name))
# result is a zip file
assert file_extension == ".zip"

# Run a simulation with 1 trials
gui.widget_dt.value = 0.85
gui.widget_ntrials.value = 1

# Initiate 2nd simulation
sim_name2 = "sim2"
gui.widget_simulation_name.value = sim_name2

# Run simulation
gui.run_button.click()
_, file_extension = (
serialize_simulation(gui.data, sim_name2))
# result is a single csv file
assert file_extension == ".csv"

0 comments on commit e881362

Please sign in to comment.