Skip to content

Commit

Permalink
Avoid importing from dask_cudf.core (#593)
Browse files (browse the repository at this point in the history)
This PR is intended to resolve test failures related to the recent dask-expr migration in `dask.dataframe`. I noticed that cuxfilter was importing from `dask_cudf.core`, which is no longer allowed when query planning is enabled. There may be other issues as well.

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
  - Ajay Thorve (https://github.com/AjayThorve)

URL: #593
  • Loading branch information
rjzamora authored Apr 26, 2024
1 parent e357bea commit 41819ce
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 25 deletions.
6 changes: 3 additions & 3 deletions python/cuxfilter/assets/numba_kernels/gpu_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def calc_value_counts(a_gpu, stride, min_value, custom_binning=False):
"""
custom_binning = custom_binning and stride

if isinstance(a_gpu, dask_cudf.core.Series):
if isinstance(a_gpu, dask_cudf.Series):
if not custom_binning:
val_count = a_gpu.value_counts()
else:
Expand Down Expand Up @@ -64,7 +64,7 @@ def calc_groupby(chart: Type[BaseChart], data, agg=None):

if agg is None:
temp_df[chart.y] = data.dropna(subset=[chart.x])[chart.y]
if isinstance(temp_df, dask_cudf.core.DataFrame):
if isinstance(temp_df, dask_cudf.DataFrame):
groupby_res = getattr(
temp_df.groupby(by=[chart.x], sort=True), chart.aggregate_fn
)()
Expand All @@ -76,7 +76,7 @@ def calc_groupby(chart: Type[BaseChart], data, agg=None):
else:
for key, agg_fn in agg.items():
temp_df[key] = data[key]
if isinstance(data, dask_cudf.core.DataFrame):
if isinstance(data, dask_cudf.DataFrame):
groupby_res = None
for key, agg_fn in agg.items():
groupby_res_temp = getattr(
Expand Down
4 changes: 2 additions & 2 deletions python/cuxfilter/charts/core/core_view_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def name(self):

def initiate_chart(self, dashboard_cls):
data = dashboard_cls._cuxfilter_df.data
if isinstance(data, dask_cudf.core.DataFrame):
if isinstance(data, dask_cudf.DataFrame):
if self.force_computation:
self.generate_chart(data.compute())
else:
Expand Down Expand Up @@ -113,7 +113,7 @@ def get_dashboard_view(self):
return pn.panel(self.chart, sizing_mode="stretch_both")

def reload_chart(self, data):
if isinstance(data, dask_cudf.core.DataFrame):
if isinstance(data, dask_cudf.DataFrame):
if self.force_computation:
self.chart.data = self._format_data(
data[self.columns].compute()
Expand Down
4 changes: 1 addition & 3 deletions python/cuxfilter/charts/core/non_aggregate/core_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,7 @@ def initiate_chart(self, dashboard_cls):
dashboard_cls._cuxfilter_df.data[self.node_y].min(),
dashboard_cls._cuxfilter_df.data[self.node_y].max(),
)
if isinstance(
dashboard_cls._cuxfilter_df.data, dask_cudf.core.DataFrame
):
if isinstance(dashboard_cls._cuxfilter_df.data, dask_cudf.DataFrame):
self.x_range = dd.compute(*self.x_range)
self.y_range = dd.compute(*self.y_range)

Expand Down
4 changes: 2 additions & 2 deletions python/cuxfilter/charts/datashader/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def format_source_data(self, data):
self.x_range = (self.source[self.x].min(), self.source[self.x].max())
self.y_range = (self.source[self.y].min(), self.source[self.y].max())

if isinstance(data, dask_cudf.core.DataFrame):
if isinstance(data, dask_cudf.DataFrame):
self.x_range = dd.compute(*self.x_range)
self.y_range = dd.compute(*self.y_range)

Expand Down Expand Up @@ -449,7 +449,7 @@ def format_source_data(self, data):
self.source[self.y].min().min(),
self.source[self.y].max().max(),
)
if isinstance(data, dask_cudf.core.DataFrame):
if isinstance(data, dask_cudf.DataFrame):
self.x_range = dd.compute(*self.x_range)
self.y_range = dd.compute(*self.y_range)

Expand Down
4 changes: 2 additions & 2 deletions python/cuxfilter/charts/panel_widgets/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def initiate_chart(self, dashboard_cls):
_series = dashboard_cls._cuxfilter_df.data[self.x].value_counts()
self.data_points = (
_series.compute().shape[0]
if isinstance(_series, dask_cudf.core.Series)
if isinstance(_series, dask_cudf.Series)
else _series.shape[0]
)
del _series
Expand Down Expand Up @@ -343,7 +343,7 @@ def calc_list_of_values(self, data):
"""
if self.label_map is None:
self.list_of_values = data[self.x].unique()
if isinstance(data, dask_cudf.core.DataFrame):
if isinstance(data, dask_cudf.DataFrame):
self.list_of_values = self.list_of_values.compute()

self.list_of_values = self.list_of_values.to_pandas().tolist()
Expand Down
15 changes: 2 additions & 13 deletions python/cuxfilter/tests/charts/datashader/test_graph_assets.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import dask_cudf
import pytest
import cudf
import cupy as cp
from cuxfilter.charts.datashader.custom_extensions import graph_assets
from dask.dataframe import assert_eq

from ..utils import initialize_df, df_types

Expand Down Expand Up @@ -50,15 +50,4 @@ def test_calc_connected_edges(
node_y_dtype=cp.float32,
).reset_index(drop=True)

res = (
res.compute().reset_index(drop=True)
if isinstance(res, dask_cudf.DataFrame)
else res
)
result = (
result.compute()
if isinstance(result, dask_cudf.DataFrame)
else result
)

assert res.to_pandas().equals(result.to_pandas())
assert_eq(res, result, check_divisions=False, check_index=False)

0 comments on commit 41819ce

Please sign in to comment.