Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow sorting of lists and arrays #20169

Merged
merged 4 commits into from
Dec 6, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: Allow sorting of lists and arrays
coastalwhite committed Dec 5, 2024
commit b6dc92ba344d94bd5624028a663cfef3a476b6c8
62 changes: 45 additions & 17 deletions crates/polars-core/src/chunked_array/ops/row_encode.rs
Original file line number Diff line number Diff line change
@@ -7,15 +7,29 @@ use crate::utils::_split_offsets;
use crate::POOL;

pub(crate) fn convert_series_for_row_encoding(s: &Series) -> PolarsResult<Series> {
use DataType::*;
use DataType as D;
let out = match s.dtype() {
D::Null
| D::Boolean
| D::UInt8
| D::UInt16
| D::UInt32
| D::UInt64
| D::Int8
| D::Int16
| D::Int32
| D::Int64
| D::Float32
| D::Float64
| D::String
| D::Binary
| D::BinaryOffset => s.clone(),

#[cfg(feature = "dtype-categorical")]
Categorical(_, _) | Enum(_, _) => s.rechunk(),
Binary | Boolean => s.clone(),
BinaryOffset => s.clone(),
String => s.clone(),
D::Categorical(_, _) | D::Enum(_, _) => s.rechunk(),

#[cfg(feature = "dtype-struct")]
Struct(_) => {
D::Struct(_) => {
let ca = s.struct_().unwrap();
let new_fields = ca
.fields_as_series()
@@ -29,17 +43,31 @@ pub(crate) fn convert_series_for_row_encoding(s: &Series) -> PolarsResult<Series
},
// We could fall back to the default branch, but Decimal is not a numeric dtype for now, so it is handled explicitly here.
#[cfg(feature = "dtype-decimal")]
Decimal(_, _) => s.clone(),
List(inner) if !inner.is_nested() => s.clone(),
Null => s.clone(),
_ => {
let phys = s.to_physical_repr().into_owned();
polars_ensure!(
phys.dtype().is_numeric(),
InvalidOperation: "cannot sort column of dtype `{}`", s.dtype()
);
phys
},
D::Decimal(_, _) => s.clone(),
D::Array(_, _) => s
.array()
.unwrap()
.apply_to_inner(&|s| convert_series_for_row_encoding(&s))
.unwrap()
.into_series(),
D::List(_) => s
.list()
.unwrap()
.apply_to_inner(&|s| convert_series_for_row_encoding(&s))
.unwrap()
.into_series(),
#[cfg(feature = "dtype-date")]
D::Date => s.to_physical_repr().into_owned(),
#[cfg(feature = "dtype-datetime")]
D::Datetime(_, _) => s.to_physical_repr().into_owned(),
#[cfg(feature = "dtype-duration")]
D::Duration(_) => s.to_physical_repr().into_owned(),
#[cfg(feature = "dtype-time")]
D::Time => s.to_physical_repr().into_owned(),

D::Object(_, _) | D::Unknown(_) => polars_bail!(
InvalidOperation: "cannot sort column of dtype `{}`", s.dtype()
),
};
Ok(out)
}
22 changes: 22 additions & 0 deletions crates/polars-core/src/chunked_array/ops/sort/arg_sort.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use polars_utils::itertools::Itertools;

use self::row_encode::_get_rows_encoded;
use super::*;

// Reduce monomorphisation.
@@ -149,3 +152,22 @@ where

ChunkedArray::with_chunk(name, IdxArr::from_data_default(Buffer::from(idx), None))
}

/// Computes the arg-sort of `by` by ordering its encoded rows.
///
/// `descending` and `nulls_last` are forwarded to `_get_rows_encoded` as
/// single-element slices. Each encoded row is paired with its position, the
/// pairs are sorted by comparing the encoded rows, and the positions are
/// returned in sorted order. When `parallel` is `true` the sort runs inside
/// `POOL` using `par_sort_by`; otherwise `sort_by` is used.
///
/// # Errors
///
/// Propagates any error returned by `_get_rows_encoded`.
pub(crate) fn arg_sort_row_fmt(
    by: &[Column],
    descending: bool,
    nulls_last: bool,
    parallel: bool,
) -> PolarsResult<IdxCa> {
    let encoded = _get_rows_encoded(by, &[descending], &[nulls_last])?;

    // Remember where every encoded row came from before reordering them.
    let mut indexed_rows: Vec<_> = encoded.iter().enumerate_idx().collect();

    if !parallel {
        indexed_rows.sort_by(|lhs, rhs| lhs.1.cmp(rhs.1));
    } else {
        POOL.install(|| indexed_rows.par_sort_by(|lhs, rhs| lhs.1.cmp(rhs.1)));
    }

    // Only the original positions are needed; the encoded rows are dropped.
    let sorted_idx: NoNull<IdxCa> = indexed_rows.into_iter().map(|(idx, _)| idx).collect();
    Ok(sorted_idx.into_inner())
}
1 change: 1 addition & 0 deletions crates/polars-core/src/chunked_array/ops/sort/mod.rs
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@ mod categorical;

use std::cmp::Ordering;

pub(crate) use arg_sort::arg_sort_row_fmt;
pub(crate) use arg_sort_multiple::argsort_multiple_row_fmt;
use arrow::bitmap::{Bitmap, MutableBitmap};
use arrow::buffer::Buffer;
18 changes: 18 additions & 0 deletions crates/polars-core/src/series/implementations/array.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::any::Any;
use std::borrow::Cow;

use self::sort::arg_sort_row_fmt;
use super::{private, MetadataFlags};
use crate::chunked_array::cast::CastOptions;
use crate::chunked_array::comparison::*;
@@ -89,6 +90,23 @@ impl SeriesTrait for SeriesWrap<ArrayChunked> {
self.0.shrink_to_fit()
}

/// Returns the indices that would sort this array column.
///
/// The column is cloned, converted into a `Column` and handed to
/// `arg_sort_row_fmt` together with the `descending`, `nulls_last` and
/// `multithreaded` settings from `options`.
///
/// # Panics
///
/// Panics if `arg_sort_row_fmt` returns an error.
fn arg_sort(&self, options: SortOptions) -> IdxCa {
    let column = (*self).clone().into_column();
    let by = [column];
    let sorted_idx = arg_sort_row_fmt(
        &by,
        options.descending,
        options.nulls_last,
        options.multithreaded,
    );
    sorted_idx.unwrap()
}

fn sort_with(&self, options: SortOptions) -> PolarsResult<Series> {
let idxs = self.arg_sort(options);
Ok(unsafe { self.take_unchecked(&idxs) })
}

fn slice(&self, offset: i64, length: usize) -> Series {
self.0.slice(offset, length).into_series()
}
18 changes: 18 additions & 0 deletions crates/polars-core/src/series/implementations/list.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use self::sort::arg_sort_row_fmt;
use super::*;
use crate::chunked_array::comparison::*;
#[cfg(feature = "algorithm_group_by")]
@@ -93,6 +94,23 @@ impl SeriesTrait for SeriesWrap<ListChunked> {
);
}

/// Returns the indices that would sort this list column.
///
/// The column is cloned, converted into a `Column` and handed to
/// `arg_sort_row_fmt` together with the `descending`, `nulls_last` and
/// `multithreaded` settings from `options`.
///
/// # Panics
///
/// Panics if `arg_sort_row_fmt` returns an error.
fn arg_sort(&self, options: SortOptions) -> IdxCa {
    let column = (*self).clone().into_column();
    let by = [column];
    let sorted_idx = arg_sort_row_fmt(
        &by,
        options.descending,
        options.nulls_last,
        options.multithreaded,
    );
    sorted_idx.unwrap()
}

fn sort_with(&self, options: SortOptions) -> PolarsResult<Series> {
let idxs = self.arg_sort(options);
Ok(unsafe { self.take_unchecked(&idxs) })
}

fn slice(&self, offset: i64, length: usize) -> Series {
self.0.slice(offset, length).into_series()
}
14 changes: 14 additions & 0 deletions py-polars/tests/unit/datatypes/test_array.py
Original file line number Diff line number Diff line change
@@ -383,3 +383,17 @@ def test_zero_width_array(fn: str) -> None:

df = pl.concat([a.to_frame(), b.to_frame()], how="horizontal")
df.select(c=expr_f(pl.col.a, pl.col.b))


def test_sort() -> None:
    """Check that sorting an `Array(Int64, w)` series yields the expected row order."""

    def tc(a: list[Any], b: list[Any], w: int) -> None:
        # Build the input and the expected output with the same name and dtype.
        unsorted = pl.Series("l", a, pl.Array(pl.Int64, w))
        expected = pl.Series("l", b, pl.Array(pl.Int64, w))

        assert_series_equal(unsorted.sort(), expected)

    # (input rows, expected rows after sorting, array width)
    cases = [
        ([], [], 1),
        ([[1]], [[1]], 1),
        ([[2], [1]], [[1], [2]], 1),
        ([[2, 1]], [[2, 1]], 2),
        ([[2, 1], [1, 2]], [[1, 2], [2, 1]], 2),
    ]
    for values, expected_values, width in cases:
        tc(values, expected_values, width)
14 changes: 14 additions & 0 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
@@ -839,3 +839,17 @@ def test_null_list_categorical_16405() -> None:

expected = pl.DataFrame([None], schema={"result": pl.List(pl.Categorical)})
assert_frame_equal(df, expected)


def test_sort() -> None:
    """Check that sorting a `List(Int64)` series yields the expected row order."""

    def tc(a: list[Any], b: list[Any]) -> None:
        # Build the input and the expected output with the same name and dtype.
        unsorted = pl.Series("l", a, pl.List(pl.Int64))
        expected = pl.Series("l", b, pl.List(pl.Int64))

        assert_series_equal(unsorted.sort(), expected)

    # (input rows, expected rows after sorting)
    cases = [
        ([], []),
        ([[1]], [[1]]),
        ([[1], []], [[], [1]]),
        ([[2, 1]], [[2, 1]]),
        ([[2, 1], [1, 2]], [[1, 2], [2, 1]]),
    ]
    for values, expected_values in cases:
        tc(values, expected_values)