From cc18207fb7e7eaa74bb2986c8ea7515da10eb738 Mon Sep 17 00:00:00 2001 From: Matthew Murray <41342305+Matt711@users.noreply.github.com> Date: Fri, 16 Aug 2024 17:27:52 -0400 Subject: [PATCH] [BUG] Fix `cudf.pandas` integration issues with `cuxfilter` (#619) This PR makes sure that the proxy objects returned/used by cudf.pandas work with cuxfilter. closes #605 Authors: - Matthew Murray (https://github.com/Matt711) - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Ajay Thorve (https://github.com/AjayThorve) URL: https://github.com/rapidsai/cuxfilter/pull/619 --- python/cuxfilter/charts/bokeh/plots/bar.py | 5 ++++- python/cuxfilter/charts/core/core_chart.py | 9 +++++++-- python/cuxfilter/charts/core/non_aggregate/core_graph.py | 9 +++++++-- .../charts/core/non_aggregate/core_stacked_line.py | 5 ++++- .../datashader/custom_extensions/holoviews_datashader.py | 7 ++++--- python/cuxfilter/charts/datashader/plots.py | 3 ++- python/cuxfilter/charts/panel_widgets/plots.py | 5 ++++- 7 files changed, 32 insertions(+), 11 deletions(-) diff --git a/python/cuxfilter/charts/bokeh/plots/bar.py b/python/cuxfilter/charts/bokeh/plots/bar.py index 3c5ea43c..f6408632 100644 --- a/python/cuxfilter/charts/bokeh/plots/bar.py +++ b/python/cuxfilter/charts/bokeh/plots/bar.py @@ -1,5 +1,6 @@ import holoviews as hv import param +import pandas as pd from cuxfilter.charts.core.aggregate import BaseAggregateChart from cuxfilter.assets.numba_kernels import calc_groupby import panel as pn @@ -10,7 +11,9 @@ class InteractiveBar(param.Parameterized): x = param.String("x", doc="x axis column name") y = param.List(["y"], doc="y axis column names as a list") source_df = param.ClassSelector( - class_=cudf.DataFrame, default=cudf.DataFrame(), doc="source dataframe" + class_=(cudf.DataFrame, pd.DataFrame), + default=cudf.DataFrame(), + doc="source dataframe", ) box_stream = param.ClassSelector( class_=hv.streams.SelectionXY, default=hv.streams.SelectionXY() diff --git 
a/python/cuxfilter/charts/core/core_chart.py b/python/cuxfilter/charts/core/core_chart.py index f89d267c..fa4555de 100644 --- a/python/cuxfilter/charts/core/core_chart.py +++ b/python/cuxfilter/charts/core/core_chart.py @@ -1,4 +1,5 @@ import cudf +import pandas as pd import dask_cudf import logging import panel as pn @@ -61,7 +62,9 @@ def library_specific_params(self): def x_dtype(self): if isinstance(self.source, ColumnDataSource): return self.source.data[self.data_x_axis].dtype - elif isinstance(self.source, (cudf.DataFrame, dask_cudf.DataFrame)): + elif isinstance( + self.source, (cudf.DataFrame, dask_cudf.DataFrame, pd.DataFrame) + ): return self.source[self.x].dtype return None @@ -69,7 +72,9 @@ def x_dtype(self): def y_dtype(self): if isinstance(self.source, ColumnDataSource): return self.source.data[self.data_x_axis].dtype - elif isinstance(self.source, (cudf.DataFrame, dask_cudf.DataFrame)): + elif isinstance( + self.source, (cudf.DataFrame, dask_cudf.DataFrame, pd.DataFrame) + ): return self.source[self.y].dtype return None diff --git a/python/cuxfilter/charts/core/non_aggregate/core_graph.py b/python/cuxfilter/charts/core/non_aggregate/core_graph.py index ab9b4101..d9fc9a51 100644 --- a/python/cuxfilter/charts/core/non_aggregate/core_graph.py +++ b/python/cuxfilter/charts/core/non_aggregate/core_graph.py @@ -1,5 +1,6 @@ from typing import Tuple import cudf +import pandas as pd import dask.dataframe as dd import dask_cudf import panel as pn @@ -169,13 +170,17 @@ def __init__( @property def x_dtype(self): - if isinstance(self.nodes, (cudf.DataFrame, dask_cudf.DataFrame)): + if isinstance( + self.nodes, (cudf.DataFrame, dask_cudf.DataFrame, pd.DataFrame) + ): return self.nodes[self.node_x].dtype return None @property def y_dtype(self): - if isinstance(self.nodes, (cudf.DataFrame, dask_cudf.DataFrame)): + if isinstance( + self.nodes, (cudf.DataFrame, dask_cudf.DataFrame, pd.DataFrame) + ): return self.nodes[self.node_y].dtype return None diff --git 
a/python/cuxfilter/charts/core/non_aggregate/core_stacked_line.py b/python/cuxfilter/charts/core/non_aggregate/core_stacked_line.py index 7ca73fbd..b4982781 100644 --- a/python/cuxfilter/charts/core/non_aggregate/core_stacked_line.py +++ b/python/cuxfilter/charts/core/non_aggregate/core_stacked_line.py @@ -1,4 +1,5 @@ import cudf +import pandas as pd import dask_cudf from typing import Tuple import panel as pn @@ -30,7 +31,9 @@ def y_dtype(self): overwriting the y_dtype property from BaseChart for stackedLines where self.y is a list of columns """ - if isinstance(self.source, (cudf.DataFrame, dask_cudf.DataFrame)): + if isinstance( + self.source, (cudf.DataFrame, dask_cudf.DataFrame, pd.DataFrame) + ): return self.source[self.y[0]].dtype return None diff --git a/python/cuxfilter/charts/datashader/custom_extensions/holoviews_datashader.py b/python/cuxfilter/charts/datashader/custom_extensions/holoviews_datashader.py index 548bf498..e66d0c6d 100644 --- a/python/cuxfilter/charts/datashader/custom_extensions/holoviews_datashader.py +++ b/python/cuxfilter/charts/datashader/custom_extensions/holoviews_datashader.py @@ -18,6 +18,7 @@ import requests from PIL import Image from io import BytesIO +import pandas as pd def load_image(url): @@ -180,7 +181,7 @@ def add_reset_event(self, callback_fn): class InteractiveDatashader(InteractiveDatashaderBase): source_df = param.ClassSelector( - class_=(cudf.DataFrame, dask_cudf.DataFrame), + class_=(cudf.DataFrame, dask_cudf.DataFrame, pd.DataFrame), doc="source cuDF/dask_cuDF dataframe", ) x = param.String("x") @@ -498,11 +499,11 @@ def view(self): class InteractiveDatashaderGraph(InteractiveDatashaderBase): nodes_df = param.ClassSelector( - class_=(cudf.DataFrame, dask_cudf.DataFrame), + class_=(cudf.DataFrame, dask_cudf.DataFrame, pd.DataFrame), doc="nodes cuDF/dask_cuDF dataframe", ) edges_df = param.ClassSelector( - class_=(cudf.DataFrame, dask_cudf.DataFrame), + class_=(cudf.DataFrame, dask_cudf.DataFrame, pd.DataFrame), 
doc="edges cuDF/dask_cuDF dataframe", ) node_x = param.String("x") diff --git a/python/cuxfilter/charts/datashader/plots.py b/python/cuxfilter/charts/datashader/plots.py index b53fd364..dfdfd61f 100644 --- a/python/cuxfilter/charts/datashader/plots.py +++ b/python/cuxfilter/charts/datashader/plots.py @@ -19,6 +19,7 @@ import dask.dataframe as dd import cupy as cp import cudf +import pandas as pd import holoviews as hv from bokeh import events from PIL import Image @@ -145,7 +146,7 @@ def format_source_data(self, dataframe): Ouput: """ - if isinstance(dataframe, cudf.DataFrame): + if isinstance(dataframe, (cudf.DataFrame, pd.DataFrame)): self.nodes = dataframe else: self.nodes = dataframe.data diff --git a/python/cuxfilter/charts/panel_widgets/plots.py b/python/cuxfilter/charts/panel_widgets/plots.py index f165bb4f..8bca3c65 100644 --- a/python/cuxfilter/charts/panel_widgets/plots.py +++ b/python/cuxfilter/charts/panel_widgets/plots.py @@ -7,6 +7,7 @@ from ...assets.cudf_utils import get_min_max from bokeh.models import ColumnDataSource import cudf +import pandas as pd import dask_cudf import panel as pn import uuid @@ -88,7 +89,9 @@ class DateRangeSlider(BaseWidget): def x_dtype(self): if isinstance(self.source, ColumnDataSource): return self.source.data[self.data_x_axis].dtype - elif isinstance(self.source, (cudf.DataFrame, dask_cudf.DataFrame)): + elif isinstance( + self.source, (cudf.DataFrame, dask_cudf.DataFrame, pd.DataFrame) + ): return self.source[self.x].dtype return None