Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

scalar_level in MultiIndex #1426

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 29 additions & 25 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
merge_data_and_coords)
from .utils import (Frozen, SortedKeysDict, maybe_wrap_array, hashable,
decode_numpy_dict_values, ensure_us_time_resolution)
from .variable import (Variable, as_variable, IndexVariable,
broadcast_variables)
from .variable import Variable, as_variable, IndexVariable, broadcast_variables
from .pycompat import (iteritems, basestring, OrderedDict,
integer_types, dask_array_type, range)
from .options import OPTIONS
Expand Down Expand Up @@ -576,21 +575,16 @@ def _replace_vars_and_dims(self, variables, coord_names=None, dims=None,
return obj

def _replace_indexes(self, indexes):
"""
Convert some index-levels to scalar-levels.
indexes: mapping from dimension name to new index.
"""
if not len(indexes):
return self
variables = self._variables.copy()
for name, idx in indexes.items():
variables[name] = IndexVariable(name, idx)
obj = self._replace_vars_and_dims(variables)

# switch from dimension to level names, if necessary
dim_names = {}
for dim, idx in indexes.items():
if not isinstance(idx, pd.MultiIndex) and idx.name != dim:
dim_names[dim] = idx.name
if dim_names:
obj = obj.rename(dim_names)
return obj
variables[name] = variables[name].reset_levels(idx.names)
return self._replace_vars_and_dims(variables)

def copy(self, deep=False):
"""Returns a copy of this dataset.
Expand Down Expand Up @@ -627,7 +621,7 @@ def _level_coords(self):
for cname in self._coord_names:
var = self.variables[cname]
if var.ndim == 1:
level_names = var.to_index_variable().level_names
level_names = var.all_level_names
if level_names is not None:
dim, = var.dims
level_coords.update({lname: dim for lname in level_names})
Expand Down Expand Up @@ -1127,10 +1121,7 @@ def isel(self, drop=False, **indexers):
Dataset.isel_points
DataArray.isel
"""
invalid = [k for k in indexers if k not in self.dims]
if invalid:
raise ValueError("dimensions %r do not exist" % invalid)

indexers = indexing.get_dim_pos_indexers(self, indexers)
# all indexers should be int, slice or np.ndarrays
indexers = [(k, (np.asarray(v)
if not isinstance(v, integer_types + (slice,))
Expand Down Expand Up @@ -1607,6 +1598,9 @@ def expand_dims(self, dim, axis=None):
If dim is already a scalar coordinate, it will be promoted to a 1D
coordinate consisting of a single value.

If dim is a scalar-level of a MultiIndex, this level is converted to
an index-level.

Parameters
----------
dim : str or sequence of str.
Expand All @@ -1629,6 +1623,12 @@ def expand_dims(self, dim, axis=None):

if isinstance(dim, basestring):
dim = [dim]
else:
dim = list(dim)
# scalar-levels to be converted to index-levels
scalars = [d for d in dim if d in self._level_coords]
dim = [d for d in dim if d not in scalars]

if axis is not None and not isinstance(axis, (list, tuple)):
axis = [axis]

Expand All @@ -1653,7 +1653,8 @@ def expand_dims(self, dim, axis=None):
variables = OrderedDict()
for k, v in iteritems(self._variables):
if k not in dim:
if k in self._coord_names: # Do not change coordinates
if k in self._coord_names:
# Do not change coordinates
variables[k] = v
else:
result_ndim = len(v.dims) + len(axis)
Expand Down Expand Up @@ -1682,6 +1683,13 @@ def expand_dims(self, dim, axis=None):
# it will be promoted to a 1D coordinate with a single value.
variables[k] = v.set_dims(k)

# Convert scalar-level of MultiIndex to index-level
for k, v in iteritems(self._variables):
if v.scalar_level_names is not None and len(scalars) > 0:
level_dims = [s for s in scalars if s in
v.scalar_level_names] + list(v.dims)
variables[k] = v.set_dims(level_dims)

return self._replace_vars_and_dims(variables, self._coord_names)

def set_index(self, append=False, inplace=False, **indexes):
Expand Down Expand Up @@ -1768,11 +1776,7 @@ def reorder_levels(self, inplace=False, **dim_order):
replace_variables = {}
for dim, order in dim_order.items():
coord = self._variables[dim]
index = coord.to_index()
if not isinstance(index, pd.MultiIndex):
raise ValueError("coordinate %r has no MultiIndex" % dim)
replace_variables[dim] = IndexVariable(coord.dims,
index.reorder_levels(order))
replace_variables[dim] = coord.reorder_levels(dim, order)
variables = self._variables.copy()
variables.update(replace_variables)
return self._replace_vars_and_dims(variables, inplace=inplace)
Expand All @@ -1790,7 +1794,7 @@ def _stack_once(self, dims, new_dim):
variables[name] = stacked_var
else:
variables[name] = var.copy(deep=False)

# TODO move to IndexVariable method
# consider dropping levels that are unused?
levels = [self.get_index(dim) for dim in dims]
if hasattr(pd, 'RangeIndex'):
Expand Down
23 changes: 15 additions & 8 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,18 +214,25 @@ def _summarize_var_or_coord(name, var, col_width, show_values=True,
return front_str + values_str


def _summarize_coord_multiindex(coord, col_width, marker):
first_col = pretty_print(u' %s %s ' % (marker, coord.name), col_width)
return u'%s(%s) MultiIndex' % (first_col, unicode_type(coord.dims[0]))
def _summarize_coord_multiindex(coord, col_width, marker, name=None):
name = name or coord.name
first_col = pretty_print(u' %s %s ' % (marker, name), col_width)
if len(coord.dims) == 0:
return u'%sMultiIndex' % (first_col)
else:
return u'%s(%s) MultiIndex' % (first_col, unicode_type(coord.dims[0]))


def _summarize_coord_levels(coord, col_width, marker=u'-'):
relevant_coord = coord[:30]
if len(coord.dims) == 0:
relevant_coord = coord # scalar MultiIndex
else:
relevant_coord = coord[:30]
return u'\n'.join(
[_summarize_var_or_coord(lname,
relevant_coord.get_level_variable(lname),
col_width, marker=marker)
for lname in coord.level_names])
for lname in coord.all_level_names])


def _not_remote(var):
Expand All @@ -247,11 +254,11 @@ def summarize_coord(name, var, col_width):
is_index = name in var.dims
show_values = is_index or _not_remote(var)
marker = u'*' if is_index else u' '
if is_index:
coord = var.variable.to_index_variable()
if name in var.coords:
coord = var.variable
if coord.level_names is not None:
return u'\n'.join(
[_summarize_coord_multiindex(coord, col_width, marker),
[_summarize_coord_multiindex(coord, col_width, marker, name),
_summarize_coord_levels(coord, col_width)])
return _summarize_var_or_coord(name, var, col_width, show_values, marker)

Expand Down
Loading