Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add astype(type, errors='null') to cast safely #1122

Merged
merged 3 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions bigframes/core/compile/ibis_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@


def cast_ibis_value(
value: ibis_types.Value, to_type: ibis_dtypes.DataType
value: ibis_types.Value, to_type: ibis_dtypes.DataType, safe: bool = False
) -> ibis_types.Value:
"""Perform compatible type casts of ibis values

Expand Down Expand Up @@ -176,7 +176,7 @@ def cast_ibis_value(
value = ibis_value_to_canonical_type(value)
if value.type() in good_casts:
if to_type in good_casts[value.type()]:
return value.cast(to_type)
return value.try_cast(to_type) if safe else value.cast(to_type)
else:
# this should never happen
raise TypeError(
Expand All @@ -188,10 +188,16 @@ def cast_ibis_value(
# BigQuery casts bools to lower case strings. Capitalize the result to match Pandas
# TODO(bmil): remove this workaround after fixing Ibis
if value.type() == ibis_dtypes.bool and to_type == ibis_dtypes.string:
return cast(ibis_types.StringValue, value.cast(to_type)).capitalize()
if safe:
return cast(ibis_types.StringValue, value.try_cast(to_type)).capitalize()
else:
return cast(ibis_types.StringValue, value.cast(to_type)).capitalize()

if value.type() == ibis_dtypes.bool and to_type == ibis_dtypes.float64:
return value.cast(ibis_dtypes.int64).cast(ibis_dtypes.float64)
if safe:
return value.try_cast(ibis_dtypes.int64).try_cast(ibis_dtypes.float64)
else:
return value.cast(ibis_dtypes.int64).cast(ibis_dtypes.float64)

if value.type() == ibis_dtypes.float64 and to_type == ibis_dtypes.bool:
return value != ibis_types.literal(0)
Expand Down
28 changes: 21 additions & 7 deletions bigframes/core/compile/scalar_op_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,7 +947,9 @@ def struct_field_op_impl(x: ibis_types.Value, op: ops.StructFieldOp):
return result.cast(result.type()(nullable=True)).name(name)


def numeric_to_datetime(x: ibis_types.Value, unit: str) -> ibis_types.TimestampValue:
def numeric_to_datetime(
x: ibis_types.Value, unit: str, safe: bool = False
) -> ibis_types.TimestampValue:
if not isinstance(x, ibis_types.IntegerValue) and not isinstance(
x, ibis_types.FloatingValue
):
Expand All @@ -956,7 +958,11 @@ def numeric_to_datetime(x: ibis_types.Value, unit: str) -> ibis_types.TimestampV
if unit not in UNIT_TO_US_CONVERSION_FACTORS:
raise ValueError(f"Cannot convert input with unit '{unit}'.")
x_converted = x * UNIT_TO_US_CONVERSION_FACTORS[unit]
x_converted = x_converted.cast(ibis_dtypes.int64)
x_converted = (
x_converted.try_cast(ibis_dtypes.int64)
if safe
else x_converted.cast(ibis_dtypes.int64)
)

# Note: Due to an issue where casting directly to a timestamp
# without a timezone does not work, we first cast to UTC. This
Expand All @@ -978,8 +984,11 @@ def astype_op_impl(x: ibis_types.Value, op: ops.AsTypeOp):

# When casting DATETIME column into INT column, we need to convert the column into TIMESTAMP first.
if to_type == ibis_dtypes.int64 and x.type() == ibis_dtypes.timestamp:
x_converted = x.cast(ibis_dtypes.Timestamp(timezone="UTC"))
return bigframes.core.compile.ibis_types.cast_ibis_value(x_converted, to_type)
utc_time_type = ibis_dtypes.Timestamp(timezone="UTC")
x_converted = x.try_cast(utc_time_type) if op.safe else x.cast(utc_time_type)
return bigframes.core.compile.ibis_types.cast_ibis_value(
x_converted, to_type, safe=op.safe
)

if to_type == ibis_dtypes.int64 and x.type() == ibis_dtypes.time:
# The conversion unit is set to "us" (microseconds) for consistency
Expand All @@ -991,15 +1000,20 @@ def astype_op_impl(x: ibis_types.Value, op: ops.AsTypeOp):
# with pandas converting int64[pyarrow] to timestamp[us][pyarrow],
# timestamp[us, tz=UTC][pyarrow], and time64[us][pyarrow].
unit = "us"
x_converted = numeric_to_datetime(x, unit)
x_converted = numeric_to_datetime(x, unit, safe=op.safe)
if to_type == ibis_dtypes.timestamp:
return x_converted.cast(ibis_dtypes.Timestamp())
return (
x_converted.try_cast(ibis_dtypes.Timestamp())
if op.safe
else x_converted.cast(ibis_dtypes.Timestamp())
)
elif to_type == ibis_dtypes.Timestamp(timezone="UTC"):
return x_converted
elif to_type == ibis_dtypes.time:
return x_converted.time()

return bigframes.core.compile.ibis_types.cast_ibis_value(x, to_type)
# TODO: either inline this function, or push rest of this op into the function
return bigframes.core.compile.ibis_types.cast_ibis_value(x, to_type, safe=op.safe)


@scalar_op_compiler.register_unary_op(ops.IsInOp, pass_op=True)
Expand Down
10 changes: 8 additions & 2 deletions bigframes/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from __future__ import annotations

import typing
from typing import Hashable, Optional, Sequence, Union
from typing import Hashable, Literal, Optional, Sequence, Union

import bigframes_vendored.constants as constants
import bigframes_vendored.pandas.core.indexes.base as vendored_pandas_index
Expand Down Expand Up @@ -324,11 +324,17 @@ def sort_values(self, *, ascending: bool = True, na_position: str = "last"):
def astype(
    self,
    dtype: Union[bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype],
    *,
    errors: Literal["raise", "null"] = "raise",
) -> Index:
    """Create an Index with values cast to ``dtype``.

    Args:
        dtype: A dtype supported by BigQuery DataFrames.
        errors: Control raising of exceptions on invalid data for the
            provided dtype. If 'raise' (default), invalid casts raise;
            if 'null', values that fail to cast become null instead.

    Returns:
        Index: Index with values cast to the specified dtype.

    Raises:
        ValueError: If ``errors`` is not one of 'raise' or 'null'.
        TypeError: If called on a MultiIndex (more than one level).
    """
    if errors not in ["raise", "null"]:
        # Message names the actual parameter ('errors', not 'error').
        raise ValueError("Argument 'errors' must be one of 'raise' or 'null'")
    if self.nlevels > 1:
        raise TypeError("Multiindex does not support 'astype'")
    # safe=True maps to SQL SAFE_CAST semantics (nulls instead of errors).
    return self._apply_unary_expr(
        ops.AsTypeOp(to_type=dtype, safe=(errors == "null")).as_expr(
            ex.free_var("arg")
        )
    )

def all(self) -> bool:
Expand Down
8 changes: 7 additions & 1 deletion bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,14 @@ def __iter__(self):
def astype(
    self,
    dtype: Union[bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype],
    *,
    errors: Literal["raise", "null"] = "raise",
) -> DataFrame:
    """Create a DataFrame with all columns cast to ``dtype``.

    Args:
        dtype: A dtype supported by BigQuery DataFrames.
        errors: Control raising of exceptions on invalid data for the
            provided dtype. If 'raise' (default), invalid casts raise;
            if 'null', values that fail to cast become null instead.

    Returns:
        DataFrame: DataFrame with values cast to the specified dtype.

    Raises:
        ValueError: If ``errors`` is not one of 'raise' or 'null'.
    """
    if errors not in ["raise", "null"]:
        # Message names the actual parameter ('errors', not 'error').
        raise ValueError("Argument 'errors' must be one of 'raise' or 'null'")
    # safe=True maps to SQL SAFE_CAST semantics (nulls instead of errors).
    return self._apply_unary_op(
        ops.AsTypeOp(to_type=dtype, safe=(errors == "null"))
    )

def _to_sql_query(
self, include_index: bool, enable_cache: bool = True
Expand Down
1 change: 1 addition & 0 deletions bigframes/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,7 @@ class AsTypeOp(UnaryOp):
name: typing.ClassVar[str] = "astype"
# TODO: Convert strings to dtype earlier
to_type: dtypes.DtypeString | dtypes.Dtype
safe: bool = False

def output_type(self, *input_types):
# TODO: We should do this conversion earlier
Expand Down
8 changes: 7 additions & 1 deletion bigframes/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,8 +352,14 @@ def __repr__(self) -> str:
def astype(
    self,
    dtype: Union[bigframes.dtypes.DtypeString, bigframes.dtypes.Dtype],
    *,
    errors: Literal["raise", "null"] = "raise",
) -> Series:
    """Create a Series with values cast to ``dtype``.

    Args:
        dtype: A dtype supported by BigQuery DataFrames.
        errors: Control raising of exceptions on invalid data for the
            provided dtype. If 'raise' (default), invalid casts raise;
            if 'null', values that fail to cast become null instead.

    Returns:
        Series: Series with values cast to the specified dtype.

    Raises:
        ValueError: If ``errors`` is not one of 'raise' or 'null'.
    """
    if errors not in ["raise", "null"]:
        # Message names the actual parameter ('errors', not 'error').
        raise ValueError("Argument 'errors' must be one of 'raise' or 'null'")
    # safe=True maps to SQL SAFE_CAST semantics (nulls instead of errors).
    return self._apply_unary_op(
        bigframes.operations.AsTypeOp(to_type=dtype, safe=(errors == "null"))
    )

def to_pandas(
self,
Expand Down
6 changes: 6 additions & 0 deletions tests/system/small/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3687,6 +3687,12 @@ def test_df_add_suffix(scalars_df_index, scalars_pandas_df_index, axis):
)


def test_df_astype_error_error(session):
    """astype must reject an ``errors`` value other than 'raise' or 'null'."""
    # Renamed from `input` to avoid shadowing the builtin.
    strings_df = pd.DataFrame(["hello", "world", "3.11", "4000"])
    with pytest.raises(ValueError):
        session.read_pandas(strings_df).astype("Float64", errors="bad_value")


def test_df_columns_filter_items(scalars_df_index, scalars_pandas_df_index):
if pd.__version__.startswith("2.0") or pd.__version__.startswith("1."):
pytest.skip("pandas filter items behavior different pre-2.1")
Expand Down
6 changes: 6 additions & 0 deletions tests/system/small/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,12 @@ def test_index_astype(scalars_df_index, scalars_pandas_df_index):
pd.testing.assert_index_equal(bf_result, pd_result)


def test_index_astype_error_error(session):
    """Index.astype must reject an ``errors`` value other than 'raise' or 'null'."""
    # Renamed from `input` to avoid shadowing the builtin.
    strings_index = pd.Index(["hello", "world", "3.11", "4000"])
    with pytest.raises(ValueError):
        session.read_pandas(strings_index).astype("Float64", errors="bad_value")


def test_index_any(scalars_df_index, scalars_pandas_df_index):
bf_result = scalars_df_index.set_index("int64_col").index.any()
pd_result = scalars_pandas_df_index.set_index("int64_col").index.any()
Expand Down
23 changes: 21 additions & 2 deletions tests/system/small/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3087,6 +3087,7 @@ def foo(x):
assert_series_equal(bf_result, pd_result, check_dtype=False)


@pytest.mark.parametrize("errors", ["raise", "null"])
@pytest.mark.parametrize(
("column", "to_type"),
[
Expand All @@ -3102,6 +3103,7 @@ def foo(x):
("int64_col", "time64[us][pyarrow]"),
("bool_col", "Int64"),
("bool_col", "string[pyarrow]"),
("bool_col", "Float64"),
("string_col", "binary[pyarrow]"),
("bytes_col", "string[pyarrow]"),
# pandas actually doesn't let folks convert to/from naive timestamp and
Expand Down Expand Up @@ -3137,12 +3139,29 @@ def foo(x):
],
)
@skip_legacy_pandas
def test_astype(scalars_df_index, scalars_pandas_df_index, column, to_type, errors):
    """Casting valid data matches pandas astype for both errors='raise' and 'null'."""
    bf_result = scalars_df_index[column].astype(to_type, errors=errors).to_pandas()
    pd_result = scalars_pandas_df_index[column].astype(to_type)
    pd.testing.assert_series_equal(bf_result, pd_result)


def test_astype_safe(session):
    """errors='null' yields nulls for uncastable values instead of raising."""
    # Renamed from `input` to avoid shadowing the builtin.
    strings = pd.Series(["hello", "world", "3.11", "4000"])
    # Fixed misspelled local name `exepcted` -> `expected`.
    expected = pd.Series(
        [None, None, 3.11, 4000],
        dtype="Float64",
        index=pd.Index([0, 1, 2, 3], dtype="Int64"),
    )
    result = session.read_pandas(strings).astype("Float64", errors="null").to_pandas()
    pd.testing.assert_series_equal(result, expected)


def test_series_astype_error_error(session):
    """Series.astype must reject an ``errors`` value other than 'raise' or 'null'."""
    # Renamed from `input` to avoid shadowing the builtin.
    strings = pd.Series(["hello", "world", "3.11", "4000"])
    with pytest.raises(ValueError):
        session.read_pandas(strings).astype("Float64", errors="bad_value")


@skip_legacy_pandas
def test_astype_numeric_to_int(scalars_df_index, scalars_pandas_df_index):
column = "numeric_col"
Expand Down
4 changes: 4 additions & 0 deletions third_party/bigframes_vendored/pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ def astype(self, dtype):
``pd.ArrowDtype(pa.time64("us"))``,
``pd.ArrowDtype(pa.timestamp("us"))``,
``pd.ArrowDtype(pa.timestamp("us", tz="UTC"))``.
errors ({'raise', 'null'}, default 'raise'):
Control raising of exceptions on invalid data for the provided dtype.
If 'raise', allow exceptions to be raised if any value fails to cast.
If 'null', assign a null value if a value fails to cast.

Returns:
bigframes.pandas.DataFrame:
Expand Down
5 changes: 5 additions & 0 deletions third_party/bigframes_vendored/pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,11 @@ def astype(self, dtype):

Args:
dtype (numpy dtype or pandas type):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it str or pandas.ExtensionDtype as the list in the generic.py?

A dtype supported by BigQuery DataFrames
errors ({'raise', 'null'}, default 'raise'):
Control raising of exceptions on invalid data for the provided dtype.
If 'raise', allow exceptions to be raised if any value fails to cast.
If 'null', assign a null value if a value fails to cast.

Returns:
Index: Index with values cast to specified dtype.
Expand Down