Skip to content

Commit

Permalink
fix: dataframe fillna with scalar. (#1132)
Browse files Browse the repository at this point in the history
* fix: dataframe fillna with string scalar.

* update type supports

* remove the test case where pandas has an issue

* update annotation
  • Loading branch information
Genesis929 authored Nov 13, 2024
1 parent 8d4da15 commit 37f8c32
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 7 deletions.
7 changes: 5 additions & 2 deletions bigframes/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ def _apply_binop(
how: str = "outer",
reverse: bool = False,
):
if isinstance(other, (float, int, bool)):
if isinstance(other, bigframes.dtypes.LOCAL_SCALAR_TYPES):
return self._apply_scalar_binop(other, op, reverse=reverse)
elif isinstance(other, DataFrame):
return self._apply_dataframe_binop(other, op, how=how, reverse=reverse)
Expand All @@ -752,7 +752,10 @@ def _apply_binop(
)

def _apply_scalar_binop(
self, other: float | int, op: ops.BinaryOp, reverse: bool = False
self,
other: bigframes.dtypes.LOCAL_SCALAR_TYPE,
op: ops.BinaryOp,
reverse: bool = False,
) -> DataFrame:
if reverse:
expr = op.as_expr(
Expand Down
19 changes: 19 additions & 0 deletions bigframes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,25 @@
# Used when storing Null expressions
DEFAULT_DTYPE = FLOAT_DTYPE

# Type alias for a single scalar value held in local Python memory. It covers
# the built-in, NumPy, decimal, datetime, and pandas scalar types listed below
# and is intended for use in annotations.
LOCAL_SCALAR_TYPE = Union[
# Booleans
bool,
np.bool_,
# Integers
int,
np.integer,
# Floating point and fixed-precision numbers
float,
np.floating,
decimal.Decimal,
# Text and binary strings
str,
np.str_,
bytes,
np.bytes_,
# Date and time values
datetime.datetime,
pd.Timestamp,
datetime.date,
datetime.time,
]
# Tuple of the member types of LOCAL_SCALAR_TYPE, extracted with
# typing.get_args so it can be passed directly to isinstance() for runtime
# checks.
LOCAL_SCALAR_TYPES = typing.get_args(LOCAL_SCALAR_TYPE)


# Will have a few dtype variants: simple (e.g. int, string, bool), complex (e.g. list, struct), and virtual (e.g. micro intervals, categorical)
@dataclass(frozen=True)
Expand Down
17 changes: 12 additions & 5 deletions tests/system/small/test_dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,13 +1020,20 @@ def test_df_interpolate(scalars_dfs):
)


def test_df_fillna(scalars_dfs):
@pytest.mark.parametrize(
"col, fill_value",
[
(["int64_col", "float64_col"], 3),
(["string_col"], "A"),
(["datetime_col"], pd.Timestamp("2023-01-01")),
],
)
def test_df_fillna(scalars_dfs, col, fill_value):
scalars_df, scalars_pandas_df = scalars_dfs
df = scalars_df[["int64_col", "float64_col"]].fillna(3)
bf_result = df.to_pandas()
pd_result = scalars_pandas_df[["int64_col", "float64_col"]].fillna(3)
bf_result = scalars_df[col].fillna(fill_value).to_pandas()
pd_result = scalars_pandas_df[col].fillna(fill_value)

pandas.testing.assert_frame_equal(bf_result, pd_result)
pd.testing.assert_frame_equal(bf_result, pd_result, check_dtype=False)


def test_df_replace_scalar_scalar(scalars_dfs):
Expand Down

0 comments on commit 37f8c32

Please sign in to comment.