Skip to content

Commit

Permalink
fix: Revert categorical unique code (#20540)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jan 3, 2025
1 parent 58c1745 commit ca36b66
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 27 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
use polars_compute::unique::{
DictionaryRangedUniqueState, PrimitiveRangedUniqueState, RangedUniqueKernel,
};
use polars_compute::unique::{DictionaryRangedUniqueState, RangedUniqueKernel};

use super::*;

Expand Down Expand Up @@ -43,32 +41,11 @@ impl CategoricalChunked {
Ok(out)
}
} else {
let has_nulls = (self.null_count() > 0) as u32;
let mut state = match cat_map.as_ref() {
RevMapping::Global(map, values, _) => {
if self.is_enum() {
PrimitiveRangedUniqueState::new(0, values.len() as u32 + has_nulls)
} else {
let mut min = u32::MAX;
let mut max = 0u32;

for &v in map.keys() {
min = min.min(v);
max = max.max(v);
}

PrimitiveRangedUniqueState::new(min, max + has_nulls)
}
},
RevMapping::Local(values, _) => {
PrimitiveRangedUniqueState::new(0, values.len() as u32 + has_nulls)
},
};

let mut state = DictionaryRangedUniqueState::new(cat_map.get_categories().to_boxed());
for chunk in self.physical().downcast_iter() {
state.append(chunk);
state.key_state().append(chunk);
}
let unique = state.finalize_unique();
let (_, unique, _) = state.finalize_unique().take();
let ca = unsafe {
UInt32Chunked::from_chunks_and_dtype_unchecked(
self.physical().name().clone(),
Expand Down
1 change: 1 addition & 0 deletions py-polars/polars/_utils/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,7 @@ def re_escape(s: str) -> str:
return re.sub(f"([{re_rust_metachars}])", r"\\\1", s)


# Do not rename or move `display_dot_graph`: it is used by Polars Cloud.
def display_dot_graph(
*,
dot: str,
Expand Down
20 changes: 20 additions & 0 deletions py-polars/tests/unit/datatypes/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,3 +905,23 @@ def test_categorical_unique() -> None:
s = pl.Series(["a", "b", None], dtype=pl.Categorical)
assert s.n_unique() == 3
assert s.unique().to_list() == ["a", "b", None]


@StringCache()
def test_categorical_unique_20539() -> None:
    """Check per-group `unique` on a Categorical column under the string cache.

    The function name ties this to issue 20539. Both the default `unique()`
    and `unique(maintain_order=True)` are expected to yield the same per-group
    category lists once the result is sorted by the group key.
    """
    frame = pl.DataFrame(
        {
            "number": [1, 1, 2, 2, 3],
            "letter": ["a", "b", "b", "c", "c"],
        }
    )
    letter = pl.col("letter")

    aggregated = (
        frame.cast({"letter": pl.Categorical})
        .group_by("number")
        .agg(
            unique=letter.unique(),
            unique_with_order=letter.unique(maintain_order=True),
        )
    )

    # group_by does not guarantee group order, so sort by the key before comparing.
    observed = aggregated.sort("number").to_dict(as_series=False)

    letters_per_group = [["a", "b"], ["b", "c"], ["c"]]
    expected = {
        "number": [1, 2, 3],
        "unique": letters_per_group,
        "unique_with_order": letters_per_group,
    }
    assert observed == expected

0 comments on commit ca36b66

Please sign in to comment.