Skip to content

Commit

Permalink
fix #14 : horizontal labels accept more than 1 dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
alixdamman committed Sep 27, 2017
1 parent 7e393af commit c2c5072
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 66 deletions.
86 changes: 49 additions & 37 deletions larray_editor/arrayadapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@


class LArrayDataAdapter(object):
def __init__(self, axes_model, hlabels_model, vlabels_model, data_model,
data=None, changes=None, current_filter=None, bg_gradient=None, bg_value=None):
def __init__(self, axes_model, hlabels_model, vlabels_model, data_model, data=None,
changes=None, current_filter=None, nb_dims_hlabels=1, bg_gradient=None, bg_value=None):
# set models
self.axes_model = axes_model
self.hlabels_model = hlabels_model
self.vlabels_model = vlabels_model
self.data_model = data_model
# set number of dims of hlabels
self.nb_dims_hlabels = nb_dims_hlabels
# set current filter
if current_filter is None:
current_filter = {}
Expand All @@ -31,38 +33,43 @@ def set_changes(self, changes=None):
assert isinstance(changes, dict)
self.changes = changes

def update_nb_dims_hlabels(self, nb_dims_hlabels):
self.nb_dims_hlabels = nb_dims_hlabels
self.update_axes_and_labels()

def get_axes_names(self):
return self.filtered_data.axes.display_names

def get_axes(self):
axes = self.filtered_data.axes
if len(axes) == 0:
return None
else:
axes_names = axes.display_names
if len(axes_names) >= 2:
axes_names = axes_names[:-2] + [axes_names[-2] + '\\' + axes_names[-1]]
return [[axis_name] for axis_name in axes_names]

def get_hlabels(self):
axes = self.filtered_data.axes
if len(axes) == 0:
axes_names = self.filtered_data.axes.display_names
if len(axes_names) == 0:
return None
elif len(axes.labels[-1]) == 0:
return [['']]
elif len(axes_names) == 1:
return [axes_names]
else:
return [[label] for label in axes.labels[-1]]
nb_dims_vlabels = len(axes_names) - self.nb_dims_hlabels
# axes corresponding to horizontal labels are set to the last column
res = [['' for c in range(nb_dims_vlabels-1)] + [axis_name] for axis_name in axes_names[nb_dims_vlabels:]]
# axes corresponding to vertical labels are set to the last row
res = res + [[axis_name for axis_name in axes_names[:nb_dims_vlabels]]]
return res

def get_vlabels(self):
def get_labels(self):
axes = self.filtered_data.axes
if len(axes) == 0:
return None
elif len(axes) == 1:
return [['']]
else:
labels = axes.labels[:-1]
prod = Product(labels)
return [_LazyDimLabels(prod, i) for i in range(len(labels))]
nb_dims_vlabels = len(axes) - self.nb_dims_hlabels
def get_labels_product(axes, extra_row=False):
if len(axes) == 0:
return None
else:
# XXX: appends a fake axis instead of using _LazyNone because
# _LazyNone mess up with LabelsArrayModel.get_values (in which slices are used)
if extra_row:
axes.append(la.Axis([' ']))
prod = Product(axes.labels)
return [_LazyDimLabels(prod, i) for i in range(len(axes.labels))]
vlabels = get_labels_product(axes[:nb_dims_vlabels])
hlabels = get_labels_product(axes[nb_dims_vlabels:], nb_dims_vlabels > 0)
return vlabels, hlabels

def get_2D_data(self):
"""Returns Numpy 2D ndarray"""
Expand Down Expand Up @@ -110,24 +117,29 @@ def set_data(self, data, bg_gradient=None, bg_value=None, current_filter=None):
self.bg_gradient = bg_gradient
self.update_filtered_data(current_filter)

def update_axes_and_labels(self):
axes = self.get_axes()
vlabels, hlabels = self.get_labels()
self.axes_model.set_data(axes)
self.hlabels_model.set_data(hlabels)
self.vlabels_model.set_data(vlabels)

def update_data_2D(self):
data_2D = self.get_2D_data()
changes_2D = self.get_changes_2D()
bg_value_2D = self.get_bg_value_2D(data_2D.shape)
self.data_model.set_data(data_2D, changes_2D)
self.data_model.set_background(self.bg_gradient, bg_value_2D)

def update_filtered_data(self, current_filter=None):
if current_filter is not None:
assert isinstance(current_filter, dict)
self.current_filter = current_filter
self.filtered_data = self.la_data[self.current_filter]
if np.isscalar(self.filtered_data):
self.filtered_data = la.aslarray(self.filtered_data)
axes = self.get_axes()
hlabels = self.get_hlabels()
vlabels = self.get_vlabels()
data_2D = self.get_2D_data()
changes_2D = self.get_changes_2D()
bg_value_2D = self.get_bg_value_2D(data_2D.shape)
self.axes_model.set_data(axes)
self.hlabels_model.set_data(hlabels)
self.vlabels_model.set_data(vlabels)
self.data_model.set_data(data_2D, changes_2D)
self.data_model.set_background(self.bg_gradient, bg_value_2D)
self.update_axes_and_labels()
self.update_data_2D()

def get_data(self):
return self.la_data
Expand Down
38 changes: 25 additions & 13 deletions larray_editor/arraymodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ class LabelsArrayModel(AbstractArrayModel):
font : QFont, optional
Font. Default is `Calibri` with size 11.
"""
def __init__(self, parent=None, data=None, readonly=False, font=None):
def __init__(self, parent=None, data=None, readonly=False, font=None, orientation=Qt.Horizontal):
self.orientation = orientation
AbstractArrayModel.__init__(self, parent, data, readonly, font)
self.font.setBold(True)

Expand All @@ -138,28 +139,39 @@ def _set_data(self, data, changes=None):
QMessageBox.critical(self.dialog, "Error", "Expected list or tuple.")
data = [[]]
self._data = data
self.total_rows = len(data[0])
self.total_cols = len(data) if self.total_rows > 0 else 0
if self.orientation == Qt.Horizontal:
self.total_rows = len(data) if self.total_cols > 0 else 0
self.total_cols = len(data[0])
else:
self.total_rows = len(data[0])
self.total_cols = len(data) if self.total_rows > 0 else 0
self._compute_rows_cols_loaded()

def flags(self, index):
"""Set editable flag"""
return Qt.ItemIsEnabled

def get_value(self, index):
i = index.row()
j = index.column()
# we need to inverse column and row because of the way vlabels are generated
return str(self._data[j][i])
if self.orientation == Qt.Horizontal:
i, j = index.row(), index.column()
else:
i, j = index.column(), index.row()
return str(self._data[i][j])

# XXX: I wonder if we shouldn't return a 2D Numpy array of strings?
def get_values(self, left=0, top=0, right=None, bottom=None):
if right is None:
right = self.total_rows
if bottom is None:
bottom = self.total_cols
values = [list(line[left:right]) for line in self._data[top:bottom]]
return values
if self.orientation == Qt.Horizontal:
if right is None:
right = self.total_cols
if bottom is None:
bottom = self.total_rows
return [list(line[left:right]) for line in self._data[top:bottom]]
else:
if right is None:
right = self.total_rows
if bottom is None:
bottom = self.total_cols
return [list(line[top:bottom]) for line in self._data[left:right]]

def data(self, index, role=Qt.DisplayRole):
# print('data', index.column(), index.row(), self.rowCount(), self.columnCount(), '\n', self._data)
Expand Down
47 changes: 31 additions & 16 deletions larray_editor/arraywidget.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ def __init__(self, parent, data, readonly=False, bg_value=None, bg_gradient=None
self.model_hlabels = LabelsArrayModel(parent=self, readonly=readonly)
self.view_hlabels = LabelsView(parent=self, model=self.model_hlabels, position=(TOP, RIGHT))

self.model_vlabels = LabelsArrayModel(parent=self, readonly=readonly)
self.model_vlabels = LabelsArrayModel(parent=self, readonly=readonly, orientation=Qt.Vertical)
self.view_vlabels = LabelsView(parent=self, model=self.model_vlabels, position=(BOTTOM, LEFT))

self.model_data = DataArrayModel(parent=self, readonly=readonly, minvalue=minvalue, maxvalue=maxvalue)
Expand Down Expand Up @@ -617,6 +617,13 @@ def __init__(self, parent, data, readonly=False, bg_value=None, bg_gradient=None
self.bgcolor_checkbox = bgcolor
btn_layout.addWidget(bgcolor)

label = QLabel("Horizontal Dimensions")
btn_layout.addWidget(label)
spin = QSpinBox(self)
spin.valueChanged.connect(self.nb_horizontal_dims_changed)
self.nb_horizontal_dims_spinbox = spin
btn_layout.addWidget(spin)

# Set widget layout
layout = QVBoxLayout()
layout.addLayout(self.filters_layout)
Expand Down Expand Up @@ -702,7 +709,8 @@ def dropEvent(self, event):

def set_data(self, data, bg_gradient=None, bg_value=None):
self.data_adapter.set_data(data, bg_gradient=bg_gradient, bg_value=bg_value)
self._update_digits_scientific(self.data_adapter.get_data())
self._update_digits_scientific_dims(self.data_adapter.get_data())
self.nb_horizontal_dims_spinbox.setValue(1)

# update filters
la_data = self.data_adapter.get_data()
Expand All @@ -725,7 +733,7 @@ def set_data(self, data, bg_gradient=None, bg_value=None):
self.view_vlabels.set_default_size()
self.view_data.set_default_size()

def _update_digits_scientific(self, data):
def _update_digits_scientific_dims(self, data):
"""
data : LArray
"""
Expand Down Expand Up @@ -755,6 +763,9 @@ def _update_digits_scientific(self, data):
self.bgcolor_checkbox.setChecked(self.model_data.bgcolor_enabled)
self.bgcolor_checkbox.setEnabled(self.model_data.bgcolor_enabled)

self.nb_horizontal_dims_spinbox.setMinimum(1)
self.nb_horizontal_dims_spinbox.setMaximum(max(1, self.data_adapter.ndim - 1))

def choose_scientific(self, data):
# max_digits = self.get_max_digits()
# default width can fit 8 chars
Expand Down Expand Up @@ -887,7 +898,7 @@ def dirty(self):
def accept_changes(self):
"""Accept changes"""
la_data = self.data_adapter.accept_changes()
self._update_digits_scientific(la_data)
self._update_digits_scientific_dims(la_data)

def reject_changes(self):
"""Reject changes"""
Expand All @@ -912,10 +923,13 @@ def digits_changed(self, value):
self.digits = value
self.model_data.set_format(self.cell_format)

def nb_horizontal_dims_changed(self, value):
self.data_adapter.update_nb_dims_hlabels(value)

def create_filter_combo(self, axis):
def filter_changed(checked_items):
filtered = self.data_adapter.change_filter(axis, checked_items)
self._update_digits_scientific(filtered)
self._update_digits_scientific_dims(filtered)
combo = FilterComboBox(self)
combo.addItems([str(l) for l in axis.labels])
combo.checkedItemsChanged.connect(filter_changed)
Expand Down Expand Up @@ -947,15 +961,15 @@ def _selection_data(self, headers=True, none_selects_all=True):
if not self.data_adapter.ndim:
return raw_data
# FIXME: this is extremely ad-hoc.
# TODO: in the future (pandas-based branch) we should use to_string(data[self._selection_filter()])
# TODO: in the future (multi_index supported) we should use to_string(data[self._selection_filter()])
dim_headers = self.model_axes.get_values()
hlabels = self.model_hlabels.get_values(top=col_min, bottom=col_max)
topheaders = [[dim_header[0] for dim_header in dim_headers] + [label[0] for label in hlabels]]
hlabels = self.model_hlabels.get_values(left=col_min, right=col_max)
topheaders = [dims + labels for dims, labels in zip(dim_headers, hlabels)]
if self.data_adapter.ndim == 1:
return chain(topheaders, [chain([''], row) for row in raw_data])
else:
assert self.data_adapter.ndim > 1
vlabels = self.model_vlabels.get_values(left=row_min, right=row_max)
vlabels = self.model_vlabels.get_values(top=row_min, bottom=row_max)
return chain(topheaders,
[chain([vlabels[j][r] for j in range(len(vlabels))], row)
for r, row in enumerate(raw_data)])
Expand Down Expand Up @@ -1039,12 +1053,13 @@ def plot(self):
row_min, row_max, col_min, col_max = self.view_data._selection_bounds()
dim_names = self.data_adapter.get_axes_names()
# labels
xlabels = [label[0] for label in self.model_hlabels.get_values(top=col_min, bottom=col_max)]
ylabels = self.model_vlabels.get_values(left=row_min, right=row_max)
# transpose ylabels
ylabels = [[str(ylabels[i][j]) for i in range(len(ylabels))] for j in range(len(ylabels[0]))]
# if there is only one dimension, ylabels is empty
if not ylabels:
xlabels = self.model_hlabels.get_values(left=col_min, right=col_max, bottom=self.data_adapter.nb_dims_hlabels)
xlabels = [[str(xlabels[i][j]) for i in range(len(xlabels))] for j in range(len(xlabels[0]))]
if self.data_adapter.ndim > 1:
ylabels = self.model_vlabels.get_values(top=row_min, bottom=row_max)
# transpose ylabels
ylabels = [[str(ylabels[i][j]) for i in range(len(ylabels))] for j in range(len(ylabels[0]))]
else:
ylabels = [[]]

assert data.ndim == 2
Expand All @@ -1064,7 +1079,7 @@ def plot(self):
else:
# plot each row as a line
xlabel = dim_names[-1]
xticklabels = [str(label) for label in xlabels]
xticklabels = ['\n'.join(row) for row in xlabels]
xdata = np.arange(col_max - col_min)
for row in range(len(data)):
ax.plot(xdata, data[row], label=' '.join(ylabels[row]))
Expand Down

0 comments on commit c2c5072

Please sign in to comment.