diff --git a/hvplot/converter.py b/hvplot/converter.py
index 9bd19e84d..7bb159c3d 100644
--- a/hvplot/converter.py
+++ b/hvplot/converter.py
@@ -64,6 +64,7 @@
     process_crs,
     process_intake,
     process_xarray,
+    support_index,
     check_library,
     is_geodataframe,
     process_derived_datetime_xarray,
@@ -1219,6 +1220,9 @@ def _process_data(
                 and y is None
                 and not by
             ):
+                # Broken, see https://github.com/holoviz/hvplot/issues/1364.
+                # Dask reset_index doesn't accept a level, so this would need to
+                # be adapted for Dask.
                 self.data = data.stack().reset_index(1).rename(columns={'level_1': group_label})
                 by = group_label
                 x = 'index'
@@ -1245,11 +1249,8 @@ def _process_data(
             self.variables = indexes + list(self.data.columns)
 
             # Reset groupby dimensions
-            groupby_index = [g for g in groupby if g in indexes]
-            if groupby_index:
-                # Dask and Pandas reset_index don't accept the same arguments
-                reset_args = (groupby_index,) if isinstance(self.data, pd.DataFrame) else ()
-                self.data = self.data.reset_index(*reset_args)
+            if not support_index(self.data) and any(g in indexes for g in groupby):
+                self.data = self.data.reset_index()
 
             if isinstance(by, (np.ndarray, pd.Series)):
                 by_cols = []
@@ -1541,7 +1542,11 @@ def __call__(self, kind, x, y):
                 if self.streaming:
                     raise NotImplementedError('Streaming and groupby not yet implemented')
                 data = self.data
-                if not self.gridded and any(g in self.indexes for g in groups):
+                if (
+                    not support_index(data)
+                    and not self.gridded
+                    and any(g in self.indexes for g in groups)
+                ):
                     data = data.reset_index()
 
                 if self.datatype in ('geopandas', 'spatialpandas'):
@@ -1992,7 +1997,7 @@ def single_chart(self, element, x, y, data=None):
 
         if self.by:
             if element is Bars and not self.subplots:
-                if any(y in self.indexes for y in ys):
+                if not support_index(data) and any(y in self.indexes for y in ys):
                     data = data.reset_index()
                 return (
                     element(data, ([x] if x else []) + self.by, ys)
@@ -2070,8 +2075,10 @@ def _process_chart_args(self, data, x, y, single_y=False, categories=None):
             data = data.sort_values(x)
 
         # set index to column if needed in hover_cols
-        if self.use_index and any(
-            c for c in self.hover_cols if c in self.indexes and c not in data.columns
+        if (
+            not support_index(data)
+            and self.use_index
+            and any(c in self.indexes and c not in data.columns for c in self.hover_cols)
         ):
             data = data.reset_index()
 
@@ -2175,6 +2182,8 @@ def _category_plot(self, element, x, y, data):
 
         id_vars = [x]
         if any(v in self.indexes for v in id_vars):
+            # Calling reset_index() is required since id_vars from melt
+            # only accepts column names, not index names.
             data = data.reset_index()
         data = data[y + [x]]
 
@@ -2551,6 +2560,7 @@ def table(self, x=None, y=None, data=None):
         self._error_if_unavailable('table')
         data = self.data if data is None else data
         if isinstance(data.index, (DatetimeIndex, MultiIndex)):
+            # To get the index displayed in the table as Bokeh doesn't show it.
             data = data.reset_index()
 
         cur_opts, compat_opts = self._get_compat_opts('Table')
@@ -2603,8 +2613,10 @@ def _process_gridded_args(self, data, x, y, z):
             if isinstance(data, xr.DataArray):
                 data = data.to_dataset(name=data.name or 'value')
         if is_tabular(data):
-            if self.use_index and any(
-                c for c in self.hover_cols if c in self.indexes and c not in data.columns
+            if (
+                not support_index(data)
+                and self.use_index
+                and any(c in self.indexes and c not in data.columns for c in self.hover_cols)
             ):
                 data = data.reset_index()
             # calculate any derived time
diff --git a/hvplot/tests/testcharts.py b/hvplot/tests/testcharts.py
index b8c2be173..c3107e929 100644
--- a/hvplot/tests/testcharts.py
+++ b/hvplot/tests/testcharts.py
@@ -54,18 +54,12 @@ def test_2d_set_hover_cols_to_list(self, kind, element):
     @parameterized.expand([('points', Points), ('paths', Path)])
     def test_2d_set_hover_cols_including_index(self, kind, element):
         plot = self.cat_df.hvplot(x='x', y='y', hover_cols=['index'], kind=kind)
-        data = plot.data[0] if kind == 'paths' else plot.data
-        assert 'index' in data.columns
-        self.assertEqual(plot, element(self.cat_df.reset_index(), ['x', 'y'], ['index']))
+        self.assertEqual(plot, element(self.cat_df, ['x', 'y'], ['index']))
 
     @parameterized.expand([('points', Points), ('paths', Path)])
     def test_2d_set_hover_cols_to_all(self, kind, element):
         plot = self.cat_df.hvplot(x='x', y='y', hover_cols='all', kind=kind)
-        data = plot.data[0] if kind == 'paths' else plot.data
-        assert 'index' in data.columns
-        self.assertEqual(
-            plot, element(self.cat_df.reset_index(), ['x', 'y'], ['index', 'category'])
-        )
+        self.assertEqual(plot, element(self.cat_df, ['x', 'y'], ['index', 'category']))
 
     @parameterized.expand([('points', Points), ('paths', Path)])
     def test_2d_set_hover_cols_to_all_with_use_index_as_false(self, kind, element):
@@ -115,6 +109,22 @@ def setUp(self):
     def test_heatmap_2d_index_columns(self):
         self.df.hvplot.heatmap()
 
+    @parameterized.expand([('points', Points), ('paths', Path)])
+    def test_2d_set_hover_cols_including_index(self, kind, element):
+        plot = self.cat_df.hvplot(x='x', y='y', hover_cols=['index'], kind=kind)
+        data = plot.data[0] if kind == 'paths' else plot.data
+        assert 'index' in data.columns
+        self.assertEqual(plot, element(self.cat_df.reset_index(), ['x', 'y'], ['index']))
+
+    @parameterized.expand([('points', Points), ('paths', Path)])
+    def test_2d_set_hover_cols_to_all(self, kind, element):
+        plot = self.cat_df.hvplot(x='x', y='y', hover_cols='all', kind=kind)
+        data = plot.data[0] if kind == 'paths' else plot.data
+        assert 'index' in data.columns
+        self.assertEqual(
+            plot, element(self.cat_df.reset_index(), ['x', 'y'], ['index', 'category'])
+        )
+
 
 class TestChart1D(ComparisonTestCase):
     def setUp(self):
diff --git a/hvplot/util.py b/hvplot/util.py
index 0c498529d..13f1b9249 100644
--- a/hvplot/util.py
+++ b/hvplot/util.py
@@ -446,6 +446,16 @@ def is_xarray_dataarray(data):
     return isinstance(data, DataArray)
 
 
+def support_index(data):
+    """
+    HoloViews added in v1.19.0 support for retaining Pandas indexes (no longer
+    calling .reset_index()).
+
+    Update this utility when other data interfaces support that (geopandas, dask, etc.)
+    """
+    return isinstance(data, pd.DataFrame)
+
+
 def process_intake(data, use_dask):
     if data.container not in ('dataframe', 'xarray'):
         raise NotImplementedError(
@@ -530,7 +540,7 @@ def process_xarray(
             data = data.persist() if persist else data
         else:
             data = dataset.to_dataframe()
-            if len(data.index.names) > 1:
+            if not support_index(data) and len(data.index.names) > 1:
                 data = data.reset_index()
         if len(dims) == 0:
             dims = ['index']