Skip to content

Commit

Permalink
fix: improve replace on categoricals (#13223)
Browse files Browse the repository at this point in the history
  • Loading branch information
c-peters authored Dec 27, 2023
1 parent 3f00e78 commit 8d610b1
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 28 deletions.
18 changes: 17 additions & 1 deletion crates/polars-ops/src/series/ops/replace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,19 @@ pub fn replace(
return Ok(default);
}

let old = old.strict_cast(s.dtype())?;
let old = match (s.dtype(), old.dtype()) {
#[cfg(feature = "dtype-categorical")]
(DataType::Categorical(opt_rev_map, ord), DataType::String) => {
let dt = opt_rev_map
.as_ref()
.filter(|rev_map| rev_map.is_enum())
.map(|rev_map| DataType::Categorical(Some(rev_map.clone()), *ord))
.unwrap_or(DataType::Categorical(None, *ord));

old.strict_cast(&dt)?
},
_ => old.strict_cast(s.dtype())?,
};
let new = new.cast(&return_dtype)?;

if new.len() == 1 {
Expand Down Expand Up @@ -83,6 +95,10 @@ fn replace_by_multiple(

let replaced = joined.column("__POLARS_REPLACE_NEW").unwrap();

if replaced.null_count() == 0 {
return Ok(replaced.clone());
}

match joined.column("__POLARS_REPLACE_MASK") {
Ok(col) => {
let mask = col.bool()?;
Expand Down
86 changes: 59 additions & 27 deletions py-polars/tests/unit/operations/test_replace.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

import contextlib
from typing import Any

import pytest

import polars as pl
from polars.exceptions import CategoricalRemappingWarning
from polars.testing import assert_frame_equal, assert_series_equal


Expand Down Expand Up @@ -55,33 +57,6 @@ def test_replace_str_to_str_default_other(str_mapping: dict[str | None, str]) ->
assert_frame_equal(result, expected)


@pl.StringCache()
def test_replace_cat_to_str_err(str_mapping: dict[str | None, str]) -> None:
df = pl.DataFrame(
{"country_code": ["FR", None, "ES", "DE"]},
schema={"country_code": pl.Categorical},
)

with pytest.raises(
pl.InvalidOperationError,
match="casting to a non-enum variant with rev map is not supported for the user",
):
df.select(pl.col("country_code").replace(str_mapping))


# https://github.com/pola-rs/polars/issues/13164
@pl.StringCache()
def test_replace_cat_to_str_fast_path_err() -> None:
s = pl.Series(["a", "b"], dtype=pl.Categorical)
mapping = {"a": "c"}

with pytest.raises(
pl.InvalidOperationError,
match="casting to a non-enum variant with rev map is not supported for the user",
):
s.replace(mapping)


def test_replace_str_to_cat() -> None:
s = pl.Series(["a", "b", "c"])
mapping = {"a": "c", "b": "d"}
Expand Down Expand Up @@ -503,3 +478,60 @@ def test_map_dict_deprecated() -> None:
with pytest.deprecated_call():
result = s.to_frame().select(pl.col("a").map_dict({2: 100})).to_series()
assert_series_equal(result, expected)


@pytest.mark.parametrize(
("context", "dtype"),
[
(pl.StringCache(), pl.Categorical),
(pytest.warns(CategoricalRemappingWarning), pl.Categorical),
(contextlib.nullcontext(), pl.Enum(["a", "b", "OTHER"])),
],
)
def test_replace_cat_str(
context: contextlib.AbstractContextManager, # type: ignore[type-arg]
dtype: pl.DataType,
) -> None:
with context:
for old, new, expected in [
("a", "c", pl.Series("s", ["c", None], dtype=pl.Utf8)),
(["a", "b"], ["c", "d"], pl.Series("s", ["c", "d"], dtype=pl.Utf8)),
(pl.lit("a", dtype=dtype), "c", pl.Series("s", ["c", None], dtype=pl.Utf8)),
(
pl.Series(["a", "b"], dtype=dtype),
["c", "d"],
pl.Series("s", ["c", "d"], dtype=pl.Utf8),
),
]:
s = pl.Series("s", ["a", "b"], dtype=dtype)
s_replaced = s.replace(old, new, default=None) # type: ignore[arg-type]
assert_series_equal(s_replaced, expected)

s = pl.Series("s", ["a", "b"], dtype=dtype)
s_replaced = s.replace(old, new, default="OTHER") # type: ignore[arg-type]
assert_series_equal(s_replaced, expected.fill_null("OTHER"))


@pytest.mark.parametrize(
"context", [pl.StringCache(), pytest.warns(CategoricalRemappingWarning)]
)
def test_replace_cat_cat(
context: contextlib.AbstractContextManager, # type: ignore[type-arg]
) -> None:
with context:
dt = pl.Categorical
for old, new, expected in [
("a", pl.lit("c", dtype=dt), pl.Series("s", ["c", None], dtype=dt)),
(
["a", "b"],
pl.Series(["c", "d"], dtype=dt),
pl.Series("s", ["c", "d"], dtype=dt),
),
]:
s = pl.Series("s", ["a", "b"], dtype=dt)
s_replaced = s.replace(old, new, default=None) # type: ignore[arg-type]
assert_series_equal(s_replaced, expected)

s = pl.Series("s", ["a", "b"], dtype=dt)
s_replaced = s.replace(old, new, default=pl.lit("OTHER", dtype=dt)) # type: ignore[arg-type]
assert_series_equal(s_replaced, expected.fill_null("OTHER"))

0 comments on commit 8d610b1

Please sign in to comment.