Skip to content

Commit

Permalink
fix(rust): Validate column names in unique() for empty DataFrames (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Biswas-N authored Dec 29, 2024
1 parent ef32c9a commit b430f64
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
14 changes: 13 additions & 1 deletion crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -681,8 +681,20 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult

// Resolve the optional DISTINCT `subset`: expand any selectors against the
// input schema, then validate every resulting column name up front so an
// unknown column raises ColumnNotFound instead of being silently ignored
// (notably for empty DataFrames, where no row-level work would surface it).
let subset = options
.subset
.map(|s| {
// Expand wildcard/selector expressions into concrete column names.
let cols = expand_selectors(s, input_schema.as_ref(), &[])?;

// Checking if subset columns exist in the dataframe; remap the schema
// lookup failure to a dedicated "column not found" error for the caller.
for col in cols.iter() {
let _ = input_schema
.try_get(col)
.map_err(|_| polars_err!(col_not_found = col))?;
}

// Turbofish pins the closure's error type so `?` above has a concrete E.
Ok::<_, PolarsError>(cols)
})
// Option<Result<_>> -> Result<Option<_>>: propagate the first error,
// keep None when no subset was requested.
.transpose()?;

let options = DistinctOptionsIR {
subset,
maintain_order: options.maintain_order,
Expand Down
33 changes: 33 additions & 0 deletions py-polars/tests/unit/operations/unique/test_unique.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@

import re
from datetime import date
from typing import TYPE_CHECKING, Any

import pytest

import polars as pl
from polars.exceptions import ColumnNotFoundError
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
from polars._typing import PolarsDataType


def test_unique_predicate_pd() -> None:
lf = pl.LazyFrame(
Expand Down Expand Up @@ -163,6 +168,34 @@ def test_unique_with_null() -> None:
assert_frame_equal(df.unique(maintain_order=True), expected_df)


_BAD_SUBSET_SCHEMA = {"ID": pl.Int64, "Name": pl.String}
_EMPTY_DATA: dict[str, list[Any]] = {"ID": [], "Name": []}
_FILLED_DATA: dict[str, list[Any]] = {
    "ID": [1, 2, 1, 2],
    "Name": ["foo", "bar", "baz", "baa"],
}


@pytest.mark.parametrize(
    ("input_json_data", "input_schema", "subset"),
    [
        # Every (data, subset) pairing: unknown subset columns must fail
        # regardless of whether the frame is empty or populated.
        (data, _BAD_SUBSET_SCHEMA, subset)
        for data in (_EMPTY_DATA, _FILLED_DATA)
        for subset in ("id", ["age", "place"])
    ],
)
def test_unique_with_bad_subset(
    input_json_data: dict[str, list[Any]],
    input_schema: dict[str, PolarsDataType],
    subset: str | list[str],
) -> None:
    """unique() raises ColumnNotFoundError for subset columns absent from the schema."""
    frame = pl.DataFrame(input_json_data, schema=input_schema)

    with pytest.raises(ColumnNotFoundError, match="not found"):
        frame.unique(subset=subset)


def test_categorical_unique_19409() -> None:
df = pl.DataFrame({"x": [str(n % 50) for n in range(127)]}).cast(pl.Categorical)
uniq = df.unique()
Expand Down

0 comments on commit b430f64

Please sign in to comment.