Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle pyarrow-backed columns in pandas 2 DataFrames #3128

Merged
merged 6 commits into from
Jul 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions altair/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,13 @@ def sanitize_geo_interface(geo: MutableMapping) -> dict:
return geo_dct


def numpy_is_subtype(dtype: Any, subtype: Any) -> bool:
    """Return whether ``dtype`` is a sub-dtype of ``subtype``, never raising.

    ``np.issubdtype`` raises for dtypes it cannot interpret (e.g. pandas
    extension dtypes such as pyarrow-backed columns); those are reported
    as "not a subtype" instead of propagating the error.
    """
    try:
        result = np.issubdtype(dtype, subtype)
    except (NotImplementedError, TypeError):
        # Not a numpy-compatible dtype at all -> cannot be a subtype.
        return False
    else:
        return result


def sanitize_dataframe(df: pd.DataFrame) -> pd.DataFrame: # noqa: C901
"""Sanitize a DataFrame to prepare it for serialization.

Expand Down Expand Up @@ -339,26 +346,27 @@ def to_list_if_array(val):
return val

for col_name, dtype in df.dtypes.items():
if str(dtype) == "category":
dtype_name = str(dtype)
if dtype_name == "category":
# Work around bug in to_json for categorical types in older versions of pandas
# https://github.com/pydata/pandas/issues/10778
# https://github.com/altair-viz/altair/pull/2170
col = df[col_name].astype(object)
df[col_name] = col.where(col.notnull(), None)
elif str(dtype) == "string":
elif dtype_name == "string":
# dedicated string datatype (since 1.0)
# https://pandas.pydata.org/pandas-docs/version/1.0.0/whatsnew/v1.0.0.html#dedicated-string-data-type
col = df[col_name].astype(object)
df[col_name] = col.where(col.notnull(), None)
elif str(dtype) == "bool":
elif dtype_name == "bool":
# convert numpy bools to objects; np.bool is not JSON serializable
df[col_name] = df[col_name].astype(object)
elif str(dtype) == "boolean":
elif dtype_name == "boolean":
# dedicated boolean datatype (since 1.0)
# https://pandas.io/docs/user_guide/boolean.html
col = df[col_name].astype(object)
df[col_name] = col.where(col.notnull(), None)
elif str(dtype).startswith("datetime"):
elif dtype_name.startswith("datetime") or dtype_name.startswith("timestamp"):
# Convert datetimes to strings. This needs to be a full ISO string
# with time, which is why we cannot use ``col.astype(str)``.
# This is because Javascript parses date-only times in UTC, but
Expand All @@ -368,18 +376,18 @@ def to_list_if_array(val):
df[col_name] = (
df[col_name].apply(lambda x: x.isoformat()).replace("NaT", "")
)
elif str(dtype).startswith("timedelta"):
elif dtype_name.startswith("timedelta"):
raise ValueError(
'Field "{col_name}" has type "{dtype}" which is '
"not supported by Altair. Please convert to "
"either a timestamp or a numerical value."
"".format(col_name=col_name, dtype=dtype)
)
elif str(dtype).startswith("geometry"):
elif dtype_name.startswith("geometry"):
# geopandas >=0.6.1 uses the dtype geometry. Continue here
# otherwise it will give an error on np.issubdtype(dtype, np.integer)
continue
elif str(dtype) in {
elif dtype_name in {
"Int8",
"Int16",
"Int32",
Expand All @@ -394,10 +402,10 @@ def to_list_if_array(val):
# https://pandas.pydata.org/pandas-docs/version/0.25/whatsnew/v0.24.0.html#optional-integer-na-support
col = df[col_name].astype(object)
df[col_name] = col.where(col.notnull(), None)
elif np.issubdtype(dtype, np.integer):
elif numpy_is_subtype(dtype, np.integer):
# convert integers to objects; np.int is not JSON serializable
df[col_name] = df[col_name].astype(object)
elif np.issubdtype(dtype, np.floating):
elif numpy_is_subtype(dtype, np.floating):
# For floats, convert to Python float: np.float is not JSON serializable
# Also convert NaN/inf values to null, as they are not JSON serializable
col = df[col_name]
Expand Down Expand Up @@ -635,7 +643,7 @@ def infer_vegalite_type_for_dfi_column(
# error message for the presence of datetime64.
#
# See https://github.com/pandas-dev/pandas/issues/54239
if "datetime64" in e.args[0]:
if "datetime64" in e.args[0] or "timestamp" in e.args[0]:
return "temporal"
raise e

Expand Down
20 changes: 20 additions & 0 deletions tests/utils/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@
from altair.utils.core import parse_shorthand, update_nested, infer_encoding_types
from altair.utils.core import infer_dtype

try:
import pyarrow as pa
except ImportError:
pa = None


FAKE_CHANNELS_MODULE = '''
"""Fake channels module for utility tests."""

Expand Down Expand Up @@ -148,6 +154,20 @@ def check(s, data, **kwargs):
check("month(t)", data, timeUnit="month", field="t", type="temporal")


@pytest.mark.skipif(pa is None, reason="pyarrow not installed")
def test_parse_shorthand_for_arrow_timestamp():
    """Shorthand type inference maps pyarrow timestamp columns to "temporal".

    Covers both a naive datetime column ("z") and a tz-aware one ("t"),
    which become arrow ``timestamp`` dtypes after the round-trip below.
    """
    data = pd.DataFrame(
        {
            "z": pd.date_range("2018-01-01", periods=5, freq="D"),
            "t": pd.date_range("2018-01-01", periods=5, freq="D").tz_localize("UTC"),
        }
    )
    # Convert to arrow-packed dtypes
    data = pa.Table.from_pandas(data).to_pandas(types_mapper=pd.ArrowDtype)
    assert parse_shorthand("z", data) == {"field": "z", "type": "temporal"}
    # Bug fix: this previously re-asserted "z", leaving the tz-aware "t"
    # column (arrow timestamp with timezone) untested.
    assert parse_shorthand("t", data) == {"field": "t", "type": "temporal"}


def test_parse_shorthand_all_aggregates():
aggregates = alt.Root._schema["definitions"]["AggregateOp"]["enum"]
for aggregate in aggregates:
Expand Down
36 changes: 36 additions & 0 deletions tests/utils/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@

from altair.utils import infer_vegalite_type, sanitize_dataframe

try:
import pyarrow as pa
except ImportError:
pa = None


def test_infer_vegalite_type():
def _check(arr, typ):
Expand Down Expand Up @@ -83,6 +88,37 @@ def test_sanitize_dataframe():
assert df.equals(df2)


@pytest.mark.skipif(pa is None, reason="pyarrow not installed")
def test_sanitize_dataframe_arrow_columns():
    """Sanitizing a pyarrow-backed DataFrame yields JSON-serializable records."""
    # Build a frame covering strings, floats, ints, bools, naive and
    # tz-aware datetimes, and categoricals.
    source = pd.DataFrame(
        {
            "s": list("abcde"),
            "f": np.arange(5, dtype=float),
            "i": np.arange(5, dtype=int),
            "b": np.array([True, False, True, True, False]),
            "d": pd.date_range("2012-01-01", periods=5, freq="H"),
            "c": pd.Series(list("ababc"), dtype="category"),
            "p": pd.date_range("2012-01-01", periods=5, freq="H").tz_localize("UTC"),
        }
    )
    # Round-trip through pyarrow so every column gets an ArrowDtype backend.
    arrow_backed = pa.Table.from_pandas(source).to_pandas(types_mapper=pd.ArrowDtype)
    cleaned = sanitize_dataframe(arrow_backed)
    records = cleaned.to_dict(orient="records")
    expected_first_record = {
        "s": "a",
        "f": 0.0,
        "i": 0,
        "b": True,
        "d": "2012-01-01T00:00:00",
        "c": "a",
        "p": "2012-01-01T00:00:00+00:00",
    }
    assert records[0] == expected_first_record

    # Make sure we can serialize to JSON without error
    json.dumps(records)


def test_sanitize_dataframe_colnames():
df = pd.DataFrame(np.arange(12).reshape(4, 3))

Expand Down