Skip to content

Commit

Permalink
ENH: Added gui features to export simulation to a csv file
Browse files Browse the repository at this point in the history
ENH: New download button using HTML. Changed file extension to csv
  • Loading branch information
kmilo9999 committed May 8, 2024
1 parent 33de268 commit 4ded929
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 5 deletions.
21 changes: 21 additions & 0 deletions hnn_core/gui/_viz_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from hnn_core.gui._logging import logger
from hnn_core.viz import plot_dipole


_fig_placeholder = 'Run simulation to add figures here.'

_plot_types = [
Expand Down Expand Up @@ -894,6 +895,26 @@ def _layout_template_change(self, template_type):
# hide sim-data dropdown
self.datasets_dropdown.layout.visibility = "hidden"

def save_simulation_csv(self, simulation_name):
# pre-allocate np.matrix to save
signals_matrix = (np.zeros((self.data["simulations"][simulation_name]
['dpls'][0].data['agg'].size, 4)))
signals_matrix[:, 0] = (self.data["simulations"][simulation_name]
['dpls'][0].times)
signals_matrix[:, 1] = (self.data["simulations"][simulation_name]
['dpls'][0].data['agg'])
signals_matrix[:, 2] = (self.data["simulations"][simulation_name]
['dpls'][0].data['L2'])
signals_matrix[:, 3] = (self.data["simulations"][simulation_name]
['dpls'][0].data['L5'])
output = io.StringIO()
np.savetxt(output, signals_matrix, delimiter=',',
header='times,agg,L2,L5', fmt='%f, %f, %f, %f')

# Get the string from StringIO
csv_string = output.getvalue()
return csv_string

@unlink_relink(attribute='figs_config_tab_link')
def add_figure(self, b=None):
"""Add a figure and corresponding config tabs to the dashboard.
Expand Down
70 changes: 65 additions & 5 deletions hnn_core/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from hnn_core.network import pick_connection
from hnn_core.params import (_extract_drive_specs_from_hnn_params, _read_json,
_read_legacy_params)
import base64


class _OutputWidgetHandler(logging.Handler):
Expand Down Expand Up @@ -147,12 +148,14 @@ def __init__(self, theme_color="#8A2BE2",
"header_height": f"{header_height}px",
"theme_color": theme_color,
"btn": Layout(height=f"{button_height}px", width='auto'),
"run_btn": Layout(height=f"{button_height}px", width='10%'),
"btn_full_w": Layout(height=f"{button_height}px", width='100%'),
"del_fig_btn": Layout(height=f"{button_height}px", width='auto'),
"log_out": Layout(border='1px solid gray',
height=f"{log_window_height-10}px",
overflow='auto'),
"viz_config": Layout(width='99%'),
"simulations_list": Layout(width=f'{left_sidebar_width-50}px'),
"visualization_window": Layout(
width=f"{viz_win_width-10}px",
height=f"{main_content_height-10}px",
Expand Down Expand Up @@ -234,6 +237,35 @@ def __init__(self, theme_color="#8A2BE2",
description='Load data',
button_style='success')

b64 = base64.b64encode("".encode())
payload = b64.decode()
# Initialliting HTML code for download button
self.html_download_button = '''<html>
<head>
<meta name="viewport" content="width=device-width, initial-scale=1">
</head>
<body>
<a download="{filename}" href="data:text/csv;base64,{payload}"
download>
<button style="background:{color_theme}"
class="p-Widget jupyter-widgets jupyter-button
widget-button mod-warning" {is_disabled} >Save Simulation</button>
</a>
</body>
</html>
'''
# Create widget wrapper
self.save_simuation_button = (
HTML(self.html_download_button.
format(payload=payload,
filename={""},
is_disabled="disabled",
color_theme=self.layout['theme_color'])))

self.simulation_list_widget = Dropdown(options=[],
value=None,
description='',
layout={'width': 'max-content'})
# Drive selection
self.widget_drive_type_selection = RadioButtons(
options=['Evoked', 'Poisson', 'Rhythmic'],
Expand All @@ -251,7 +283,7 @@ def __init__(self, theme_color="#8A2BE2",

# Dashboard level buttons
self.run_button = create_expanded_button(
'Run', 'success', layout=self.layout['btn'],
'Run', 'success', layout=self.layout['run_btn'],
button_color=self.layout['theme_color'])

self.load_connectivity_button = FileUpload(
Expand Down Expand Up @@ -316,7 +348,9 @@ def _init_ui_components(self):

self._log_window = HBox([self._log_out], layout=self.layout['log_out'])
self._operation_buttons = HBox(
[self.run_button, self.load_data_button],
[self.run_button, self.load_data_button,
self.save_simuation_button,
self.simulation_list_widget],
layout=self.layout['operation_box'])
# title
self._header = HTML(value=f"""<div
Expand Down Expand Up @@ -400,7 +434,20 @@ def _run_button_clicked(b):
self.widget_ntrials, self.widget_backend_selection,
self.widget_mpi_cmd, self.widget_n_jobs, self.params,
self._simulation_status_bar, self._simulation_status_contents,
self.connectivity_widgets, self.viz_manager)
self.connectivity_widgets, self.viz_manager,
self.simulation_list_widget)

def _simulation_list_change(value):
csv_string = _simulation_in_csv_format(self._log_out,
self.viz_manager,
self.simulation_list_widget)
result_file = f"{value.new}.csv"
b64 = base64.b64encode(csv_string.encode())
payload = b64.decode()
self.save_simuation_button.value = (
self.html_download_button.format(
payload=payload, filename=result_file,
is_disabled="", color_theme=self.layout['theme_color']))

self.widget_backend_selection.observe(_handle_backend_change, 'value')
self.add_drive_button.on_click(_add_drive_button_clicked)
Expand All @@ -409,8 +456,8 @@ def _run_button_clicked(b):
names='value')
self.load_drives_button.observe(_on_upload_drives, names='value')
self.run_button.on_click(_run_button_clicked)

self.load_data_button.observe(_on_upload_data, names='value')
self.simulation_list_widget.observe(_simulation_list_change, 'value')

def compose(self, return_layout=True):
"""Compose widgets.
Expand Down Expand Up @@ -1330,7 +1377,7 @@ def run_button_clicked(widget_simulation_name, log_out, drive_widgets,
all_data, dt, tstop, ntrials, backend_selection,
mpi_cmd, n_jobs, params, simulation_status_bar,
simulation_status_contents, connectivity_textfields,
viz_manager):
viz_manager, simulations_list_widget):
"""Run the simulation and plot outputs."""
log_out.clear_output()
simulation_data = all_data["simulation_data"]
Expand Down Expand Up @@ -1369,6 +1416,11 @@ def run_button_clicked(widget_simulation_name, log_out, drive_widgets,
simulation_status_bar.value = simulation_status_contents[
'finished']

sim_names = [sim_name for sim_name
in simulation_data]
simulations_list_widget.options = sim_names
simulations_list_widget.value = sim_names[0]

viz_manager.reset_fig_config_tabs()
viz_manager.add_figure()
fig_name = _idx2figname(viz_manager.data['fig_idx']['idx'] - 1)
Expand All @@ -1378,6 +1430,14 @@ def run_button_clicked(widget_simulation_name, log_out, drive_widgets,
plot_type, {}, "plot")


def _simulation_in_csv_format(log_out, viz_manager, simulation_list_widget):
# Only download if there is there is at least one simulation
sin_name = simulation_list_widget.value
with log_out:
logger.info(f"Saving {sin_name}.txt")
return viz_manager.save_simulation_csv(sin_name)


def handle_backend_change(backend_type, backend_config, mpi_cmd, n_jobs):
"""Switch backends between MPI and Joblib."""
backend_config.clear_output()
Expand Down

0 comments on commit 4ded929

Please sign in to comment.