From 6f43d6b94614134792671d2c1a90aff588682c94 Mon Sep 17 00:00:00 2001 From: Jon Mease Date: Sat, 10 Jun 2023 15:37:00 -0400 Subject: [PATCH] Use DataFrameLike protocol for the transformed_data signature The returned DataFrames can be arrow or Polars as well if those are used as the input to the Chart. --- altair/__init__.py | 1 + altair/utils/_transformed_data.py | 13 ++++++------- altair/utils/core.py | 9 +++++++-- altair/vegalite/v5/api.py | 13 +++++++------ 4 files changed, 21 insertions(+), 15 deletions(-) diff --git a/altair/__init__.py b/altair/__init__.py index 1a66239a5..3c0a19f13 100644 --- a/altair/__init__.py +++ b/altair/__init__.py @@ -124,6 +124,7 @@ "Cyclical", "Data", "DataFormat", + "DataFrameLike", "DataSource", "Datasets", "DateTime", diff --git a/altair/utils/_transformed_data.py b/altair/utils/_transformed_data.py index a343269e3..fa89f8727 100644 --- a/altair/utils/_transformed_data.py +++ b/altair/utils/_transformed_data.py @@ -1,7 +1,5 @@ from typing import List, Optional, Tuple, Dict, Iterable, overload, Union -import pandas as pd - from altair import ( Chart, FacetChart, @@ -11,6 +9,7 @@ ConcatChart, data_transformers, ) +from altair.utils.core import DataFrameLike from altair.utils.schemapi import Undefined Scope = Tuple[int, ...] @@ -24,7 +23,7 @@ def transformed_data( chart: Union[Chart, FacetChart], row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, -) -> Optional[pd.DataFrame]: +) -> Optional[DataFrameLike]: ... @@ -33,7 +32,7 @@ def transformed_data( chart: Union[LayerChart, HConcatChart, VConcatChart, ConcatChart], row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, -) -> List[pd.DataFrame]: +) -> List[DataFrameLike]: ... @@ -54,9 +53,9 @@ def transformed_data(chart, row_limit=None, exclude=None): Returns ------- - pandas DataFrame or list of pandas DataFrames or None - If input chart is a Chart or Facet Chart, returns a pandas DataFrame of the - transformed data. Otherwise, returns a list of pandas DataFrames of the + DataFrame or list of DataFrames or None + If input chart is a Chart or Facet Chart, returns a DataFrame of the + transformed data. Otherwise, returns a list of DataFrames of the transformed data """ try: diff --git a/altair/utils/core.py b/altair/utils/core.py index 41e886001..4821de3c6 100644 --- a/altair/utils/core.py +++ b/altair/utils/core.py @@ -18,9 +18,9 @@ from altair.utils.schemapi import SchemaBase if sys.version_info >= (3, 10): - from typing import ParamSpec + from typing import ParamSpec, Protocol else: - from typing_extensions import ParamSpec + from typing_extensions import ParamSpec, Protocol try: from pandas.api.types import infer_dtype as _infer_dtype @@ -32,6 +32,11 @@ _P = ParamSpec("_P") +class DataFrameLike(Protocol): + def __dataframe__(self, *args, **kwargs): + ... + + def infer_dtype(value): """Infer the dtype of the value. diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 683e5b155..ba7801143 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -21,6 +21,7 @@ from .display import renderers, VEGALITE_VERSION, VEGAEMBED_VERSION, VEGA_VERSION from .theme import themes from .compiler import vegalite_compilers +from ...utils.core import DataFrameLike if sys.version_info >= (3, 11): from typing import Self @@ -2661,7 +2662,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> Optional[pd.DataFrame]: + ) -> Optional[DataFrameLike]: """Evaluate a Chart's transforms Evaluate the data transforms associated with a Chart and return the @@ -2947,7 +2948,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> List[pd.DataFrame]: + ) -> List[DataFrameLike]: """Evaluate a ConcatChart's transforms Evaluate the data transforms associated with a ConcatChart and return the @@ -3044,7 +3045,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> List[pd.DataFrame]: + ) -> List[DataFrameLike]: """Evaluate a HConcatChart's transforms Evaluate the data transforms associated with a HConcatChart and return the @@ -3141,7 +3142,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> List[pd.DataFrame]: + ) -> List[DataFrameLike]: """Evaluate a VConcatChart's transforms Evaluate the data transforms associated with a VConcatChart and return the @@ -3237,7 +3238,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> List[pd.DataFrame]: + ) -> List[DataFrameLike]: """Evaluate a LayerChart's transforms Evaluate the data transforms associated with a LayerChart and return the @@ -3352,7 +3353,7 @@ def _transformed_data( self, row_limit: Optional[int] = None, exclude: Optional[Iterable[str]] = None, - ) -> Optional[pd.DataFrame]: + ) -> Optional[DataFrameLike]: """Evaluate a FacetChart's transforms Evaluate the data transforms associated with a FacetChart and return the