Skip to content

Commit

Permalink
optimize reset_index calls
Browse files Browse the repository at this point in the history
  • Loading branch information
maximlt committed Jul 19, 2024
1 parent adbddf7 commit 326ee64
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 20 deletions.
34 changes: 23 additions & 11 deletions hvplot/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
process_crs,
process_intake,
process_xarray,
support_index,
check_library,
is_geodataframe,
process_derived_datetime_xarray,
Expand Down Expand Up @@ -1219,6 +1220,9 @@ def _process_data(
and y is None
and not by
):
# FIXME: This branch is currently broken, see
# https://github.com/holoviz/hvplot/issues/1364.
# Note also that Dask's reset_index doesn't accept a level, so this code
# would need to be adapted to support Dask.
self.data = data.stack().reset_index(1).rename(columns={'level_1': group_label})
by = group_label
x = 'index'
Expand All @@ -1245,11 +1249,8 @@ def _process_data(
self.variables = indexes + list(self.data.columns)

# Reset groupby dimensions
groupby_index = [g for g in groupby if g in indexes]
if groupby_index:
# Dask and Pandas reset_index don't accept the same arguments
reset_args = (groupby_index,) if isinstance(self.data, pd.DataFrame) else ()
self.data = self.data.reset_index(*reset_args)
if not support_index(self.data) and any(g for g in groupby if g in indexes):
self.data = self.data.reset_index()

if isinstance(by, (np.ndarray, pd.Series)):
by_cols = []
Expand Down Expand Up @@ -1541,7 +1542,11 @@ def __call__(self, kind, x, y):
if self.streaming:
raise NotImplementedError('Streaming and groupby not yet implemented')
data = self.data
if not self.gridded and any(g in self.indexes for g in groups):
if (
not support_index(data)
and not self.gridded
and any(g in self.indexes for g in groups)
):
data = data.reset_index()

if self.datatype in ('geopandas', 'spatialpandas'):
Expand Down Expand Up @@ -1992,7 +1997,7 @@ def single_chart(self, element, x, y, data=None):

if self.by:
if element is Bars and not self.subplots:
if any(y in self.indexes for y in ys):
if not support_index(data) and any(y in self.indexes for y in ys):
data = data.reset_index()
return (
element(data, ([x] if x else []) + self.by, ys)
Expand Down Expand Up @@ -2070,8 +2075,10 @@ def _process_chart_args(self, data, x, y, single_y=False, categories=None):
data = data.sort_values(x)

# set index to column if needed in hover_cols
if self.use_index and any(
c for c in self.hover_cols if c in self.indexes and c not in data.columns
if (
not support_index(data)
and self.use_index
and any(c for c in self.hover_cols if c in self.indexes and c not in data.columns)
):
data = data.reset_index()

Expand Down Expand Up @@ -2175,6 +2182,8 @@ def _category_plot(self, element, x, y, data):

id_vars = [x]
if any(v in self.indexes for v in id_vars):
# Calling reset_index() is required since id_vars from melt
# only accepts column names, not index names.
data = data.reset_index()
data = data[y + [x]]

Expand Down Expand Up @@ -2551,6 +2560,7 @@ def table(self, x=None, y=None, data=None):
self._error_if_unavailable('table')
data = self.data if data is None else data
if isinstance(data.index, (DatetimeIndex, MultiIndex)):
# Reset the index so that it is displayed in the table, since Bokeh doesn't show it.
data = data.reset_index()

cur_opts, compat_opts = self._get_compat_opts('Table')
Expand Down Expand Up @@ -2603,8 +2613,10 @@ def _process_gridded_args(self, data, x, y, z):
if isinstance(data, xr.DataArray):
data = data.to_dataset(name=data.name or 'value')
if is_tabular(data):
if self.use_index and any(
c for c in self.hover_cols if c in self.indexes and c not in data.columns
if (
not support_index(data)
and self.use_index
and any(c for c in self.hover_cols if c in self.indexes and c not in data.columns)
):
data = data.reset_index()
# calculate any derived time
Expand Down
26 changes: 18 additions & 8 deletions hvplot/tests/testcharts.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,18 +54,12 @@ def test_2d_set_hover_cols_to_list(self, kind, element):
@parameterized.expand([('points', Points), ('paths', Path)])
def test_2d_set_hover_cols_including_index(self, kind, element):
plot = self.cat_df.hvplot(x='x', y='y', hover_cols=['index'], kind=kind)
data = plot.data[0] if kind == 'paths' else plot.data
assert 'index' in data.columns
self.assertEqual(plot, element(self.cat_df.reset_index(), ['x', 'y'], ['index']))
self.assertEqual(plot, element(self.cat_df, ['x', 'y'], ['index']))

@parameterized.expand([('points', Points), ('paths', Path)])
def test_2d_set_hover_cols_to_all(self, kind, element):
plot = self.cat_df.hvplot(x='x', y='y', hover_cols='all', kind=kind)
data = plot.data[0] if kind == 'paths' else plot.data
assert 'index' in data.columns
self.assertEqual(
plot, element(self.cat_df.reset_index(), ['x', 'y'], ['index', 'category'])
)
self.assertEqual(plot, element(self.cat_df, ['x', 'y'], ['index', 'category']))

@parameterized.expand([('points', Points), ('paths', Path)])
def test_2d_set_hover_cols_to_all_with_use_index_as_false(self, kind, element):
Expand Down Expand Up @@ -115,6 +109,22 @@ def setUp(self):
def test_heatmap_2d_index_columns(self):
    """Smoke test: calling ``heatmap()`` without arguments on ``self.df`` must not raise."""
    accessor = self.df.hvplot
    accessor.heatmap()

@parameterized.expand([('points', Points), ('paths', Path)])
def test_2d_set_hover_cols_including_index(self, kind, element):
    """
    Check that requesting 'index' in ``hover_cols`` exposes it as a column of
    the plot data and as a value dimension of the returned element.
    """
    plot = self.cat_df.hvplot(x='x', y='y', hover_cols=['index'], kind=kind)

    # For 'paths' the frame is read from the first item of plot.data.
    if kind == 'paths':
        plotted = plot.data[0]
    else:
        plotted = plot.data
    assert 'index' in plotted.columns

    expected = element(self.cat_df.reset_index(), ['x', 'y'], ['index'])
    self.assertEqual(plot, expected)

@parameterized.expand([('points', Points), ('paths', Path)])
def test_2d_set_hover_cols_to_all(self, kind, element):
    """
    Check that ``hover_cols='all'`` exposes 'index' as a column of the plot
    data and yields 'index' and 'category' as value dimensions.
    """
    plot = self.cat_df.hvplot(x='x', y='y', hover_cols='all', kind=kind)

    # For 'paths' the frame is read from the first item of plot.data.
    if kind == 'paths':
        plotted = plot.data[0]
    else:
        plotted = plot.data
    assert 'index' in plotted.columns

    expected = element(self.cat_df.reset_index(), ['x', 'y'], ['index', 'category'])
    self.assertEqual(plot, expected)


class TestChart1D(ComparisonTestCase):
def setUp(self):
Expand Down
12 changes: 11 additions & 1 deletion hvplot/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,16 @@ def is_xarray_dataarray(data):
return isinstance(data, DataArray)


def support_index(data):
    """
    Return True when ``data`` is of a type whose index can be retained.

    HoloViews added support for retaining Pandas indexes in v1.19.0, so
    ``.reset_index()`` is no longer called for such data. Extend the tuple
    below when other data interfaces support that (geopandas, dask, etc.)
    """
    # Types for which the index is kept as is; currently Pandas DataFrames only.
    index_preserving_types = (pd.DataFrame,)
    return isinstance(data, index_preserving_types)


def process_intake(data, use_dask):
if data.container not in ('dataframe', 'xarray'):
raise NotImplementedError(
Expand Down Expand Up @@ -530,7 +540,7 @@ def process_xarray(
data = data.persist() if persist else data
else:
data = dataset.to_dataframe()
if len(data.index.names) > 1:
if not support_index(data) and len(data.index.names) > 1:
data = data.reset_index()
if len(dims) == 0:
dims = ['index']
Expand Down

0 comments on commit 326ee64

Please sign in to comment.