Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support per-column nulls_last on sort operations #16639

Merged
merged 1 commit into from
Jun 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,14 @@ pub fn _arg_bottom_k(
sort_options: &mut SortMultipleOptions,
) -> PolarsResult<NoNull<IdxCa>> {
let from_n_rows = by_column[0].len();
_broadcast_descending(by_column.len(), &mut sort_options.descending);
let encoded = _get_rows_encoded(by_column, &sort_options.descending, sort_options.nulls_last)?;
_broadcast_bools(by_column.len(), &mut sort_options.descending);
_broadcast_bools(by_column.len(), &mut sort_options.nulls_last);

let encoded = _get_rows_encoded(
by_column,
&sort_options.descending,
&sort_options.nulls_last,
)?;
let arr = encoded.into_array();
let mut rows = arr
.values_iter()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,26 @@ pub(crate) fn arg_sort_multiple_impl<T: NullOrderCmp + Send + Copy>(
by: &[Series],
options: &SortMultipleOptions,
) -> PolarsResult<IdxCa> {
let nulls_last = &options.nulls_last;
let descending = &options.descending;

debug_assert_eq!(descending.len() - 1, by.len());
debug_assert_eq!(nulls_last.len() - 1, by.len());

let compare_inner: Vec<_> = by
.iter()
.map(|s| s.into_total_ord_inner())
.collect_trusted();

let first_descending = descending[0];
let first_nulls_last = nulls_last[0];

let compare = |tpl_a: &(_, T), tpl_b: &(_, T)| -> Ordering {
match (
first_descending,
tpl_a
.1
.null_order_cmp(&tpl_b.1, options.nulls_last ^ first_descending),
.null_order_cmp(&tpl_b.1, first_nulls_last ^ first_descending),
) {
// if ordering is equal, we check the other arrays until we find a non-equal ordering
// if we have exhausted all arrays, we keep the equal ordering.
Expand All @@ -52,7 +57,7 @@ pub(crate) fn arg_sort_multiple_impl<T: NullOrderCmp + Send + Copy>(
ordering_other_columns(
&compare_inner,
descending.get_unchecked(1..),
options.nulls_last,
nulls_last.get_unchecked(1..),
idx_a,
idx_b,
)
Expand Down Expand Up @@ -184,17 +189,19 @@ pub fn _get_rows_encoded_unordered(by: &[Series]) -> PolarsResult<RowsEncoded> {
pub fn _get_rows_encoded(
by: &[Series],
descending: &[bool],
nulls_last: bool,
nulls_last: &[bool],
) -> PolarsResult<RowsEncoded> {
debug_assert_eq!(by.len(), descending.len());
debug_assert_eq!(by.len(), nulls_last.len());

let mut cols = Vec::with_capacity(by.len());
let mut fields = Vec::with_capacity(by.len());
for (by, descending) in by.iter().zip(descending) {
let arr = _get_rows_encoded_compat_array(by)?;

for ((by, desc), null_last) in by.iter().zip(descending).zip(nulls_last) {
let arr = _get_rows_encoded_compat_array(by)?;
let sort_field = EncodingField {
descending: *descending,
nulls_last,
descending: *desc,
nulls_last: *null_last,
no_order: false,
};
match arr.data_type() {
Expand All @@ -203,12 +210,12 @@ pub fn _get_rows_encoded(
let arr = arr.as_any().downcast_ref::<StructArray>().unwrap();
for arr in arr.values() {
cols.push(arr.clone() as ArrayRef);
fields.push(sort_field)
fields.push(sort_field);
}
},
_ => {
cols.push(arr);
fields.push(sort_field)
fields.push(sort_field);
},
}
}
Expand All @@ -219,7 +226,7 @@ pub fn _get_rows_encoded_ca(
name: &str,
by: &[Series],
descending: &[bool],
nulls_last: bool,
nulls_last: &[bool],
) -> PolarsResult<BinaryOffsetChunked> {
_get_rows_encoded(by, descending, nulls_last)
.map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array()))
Expand All @@ -236,12 +243,13 @@ pub fn _get_rows_encoded_ca_unordered(
pub(crate) fn argsort_multiple_row_fmt(
by: &[Series],
mut descending: Vec<bool>,
nulls_last: bool,
mut nulls_last: Vec<bool>,
parallel: bool,
) -> PolarsResult<IdxCa> {
_broadcast_descending(by.len(), &mut descending);
_broadcast_bools(by.len(), &mut descending);
_broadcast_bools(by.len(), &mut nulls_last);

let rows_encoded = _get_rows_encoded(by, &descending, nulls_last)?;
let rows_encoded = _get_rows_encoded(by, &descending, &nulls_last)?;
let mut items: Vec<_> = rows_encoded.iter().enumerate_idx().collect();

if parallel {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,15 @@ mod test {

let out = df.sort(
["cat", "vals"],
SortMultipleOptions::default().with_order_descendings([false, false]),
SortMultipleOptions::default().with_order_descending_multi([false, false]),
)?;
let out = out.column("cat")?;
let cat = out.categorical()?;
assert_order(cat, &["a", "a", "b", "c"]);

let out = df.sort(
["vals", "cat"],
SortMultipleOptions::default().with_order_descendings([false, false]),
SortMultipleOptions::default().with_order_descending_multi([false, false]),
)?;
let out = out.column("cat")?;
let cat = out.categorical()?;
Expand Down
29 changes: 14 additions & 15 deletions crates/polars-core/src/chunked_array/ops/sort/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,14 +277,13 @@ where
fn ordering_other_columns<'a>(
compare_inner: &'a [Box<dyn TotalOrdInner + 'a>],
descending: &[bool],
nulls_last: bool,
nulls_last: &[bool],
idx_a: usize,
idx_b: usize,
) -> Ordering {
for (cmp, descending) in compare_inner.iter().zip(descending) {
// SAFETY:
// indices are in bounds
let ordering = unsafe { cmp.cmp_element_unchecked(idx_a, idx_b, nulls_last ^ descending) };
for ((cmp, descending), null_last) in compare_inner.iter().zip(descending).zip(nulls_last) {
// SAFETY: indices are in bounds
let ordering = unsafe { cmp.cmp_element_unchecked(idx_a, idx_b, null_last ^ descending) };
match (ordering, descending) {
(Ordering::Equal, _) => continue,
(_, true) => return ordering.reverse(),
Expand Down Expand Up @@ -557,7 +556,7 @@ impl StructChunked {
self.name(),
&[self.clone().into_series()],
&[options.descending],
options.nulls_last,
&[options.nulls_last],
)
.unwrap();
bin.arg_sort(Default::default())
Expand Down Expand Up @@ -670,10 +669,10 @@ pub(crate) fn convert_sort_column_multi_sort(s: &Series) -> PolarsResult<Series>
Ok(out)
}

pub fn _broadcast_descending(n_cols: usize, descending: &mut Vec<bool>) {
if n_cols > descending.len() && descending.len() == 1 {
while n_cols != descending.len() {
descending.push(descending[0]);
pub fn _broadcast_bools(n_cols: usize, values: &mut Vec<bool>) {
if n_cols > values.len() && values.len() == 1 {
while n_cols != values.len() {
values.push(values[0]);
}
}
}
Expand All @@ -689,10 +688,10 @@ pub(crate) fn prepare_arg_sort(
.map(convert_sort_column_multi_sort)
.collect::<PolarsResult<Vec<_>>>()?;

let first = columns.remove(0);
_broadcast_bools(n_cols, &mut sort_options.descending);
_broadcast_bools(n_cols, &mut sort_options.nulls_last);

// broadcast ordering
_broadcast_descending(n_cols, &mut sort_options.descending);
let first = columns.remove(0);
Ok((first, columns))
}

Expand Down Expand Up @@ -831,7 +830,7 @@ mod test {

let out = df.sort(
["groups", "values"],
SortMultipleOptions::default().with_order_descendings([true, false]),
SortMultipleOptions::default().with_order_descending_multi([true, false]),
)?;
let expected = df!(
"groups" => [3, 2, 1],
Expand All @@ -841,7 +840,7 @@ mod test {

let out = df.sort(
["values", "groups"],
SortMultipleOptions::default().with_order_descendings([false, true]),
SortMultipleOptions::default().with_order_descending_multi([false, true]),
)?;
let expected = df!(
"groups" => [2, 1, 3],
Expand Down
41 changes: 27 additions & 14 deletions crates/polars-core/src/chunked_array/ops/sort/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ pub struct SortOptions {
/// SortMultipleOptions::default()
/// .with_maintain_order(true)
/// .with_multithreaded(false)
/// .with_order_descendings([false, true])
/// .with_order_descending_multi([false, true])
/// .with_nulls_last(true),
/// )?;
///
Expand All @@ -83,15 +83,15 @@ pub struct SortMultipleOptions {
///
/// If only one value is given, it will broadcast to all columns.
///
/// Use [`SortMultipleOptions::with_order_descendings`]
/// Use [`SortMultipleOptions::with_order_descending_multi`]
/// or [`SortMultipleOptions::with_order_descending`] to modify.
///
/// # Safety
///
/// Len must matches the number of columns or equal to 1.
/// Len must match the number of columns, or equal 1.
pub descending: Vec<bool>,
/// Whether place null values last. Default `false`.
pub nulls_last: bool,
pub nulls_last: Vec<bool>,
/// Whether sort in multiple threads. Default `true`.
pub multithreaded: bool,
/// Whether maintain the order of equal elements. Default `false`.
Expand All @@ -113,7 +113,7 @@ impl Default for SortMultipleOptions {
fn default() -> Self {
Self {
descending: vec![false],
nulls_last: false,
nulls_last: vec![false],
multithreaded: true,
maintain_order: false,
}
Expand All @@ -126,12 +126,15 @@ impl SortMultipleOptions {
Self::default()
}

/// Specify order for each columns. Default all `false`.
/// Specify order for each column. Defaults all `false`.
///
/// # Safety
///
/// Len must matches the number of columns or equal to 1.
pub fn with_order_descendings(mut self, descending: impl IntoIterator<Item = bool>) -> Self {
/// Len must match the number of columns, or be equal to 1.
pub fn with_order_descending_multi(
mut self,
descending: impl IntoIterator<Item = bool>,
) -> Self {
self.descending = descending.into_iter().collect();
self
}
Expand All @@ -142,19 +145,29 @@ impl SortMultipleOptions {
self
}

/// Whether place null values last. Default `false`.
/// Specify whether to place nulls last, per-column. Defaults all `false`.
///
/// # Safety
///
/// Len must match the number of columns, or be equal to 1.
pub fn with_nulls_last_multi(mut self, nulls_last: impl IntoIterator<Item = bool>) -> Self {
self.nulls_last = nulls_last.into_iter().collect();
self
}

/// Whether to place null values last. Default `false`.
pub fn with_nulls_last(mut self, enabled: bool) -> Self {
self.nulls_last = enabled;
self.nulls_last = vec![enabled];
self
}

/// Whether sort in multiple threads. Default `true`.
/// Whether to sort in multiple threads. Default `true`.
pub fn with_multithreaded(mut self, enabled: bool) -> Self {
self.multithreaded = enabled;
self
}

/// Whether maintain the order of equal elements. Default `false`.
/// Whether to maintain the order of equal elements. Default `false`.
pub fn with_maintain_order(mut self, enabled: bool) -> Self {
self.maintain_order = enabled;
self
Expand Down Expand Up @@ -208,7 +221,7 @@ impl From<&SortOptions> for SortMultipleOptions {
fn from(value: &SortOptions) -> Self {
SortMultipleOptions {
descending: vec![value.descending],
nulls_last: value.nulls_last,
nulls_last: vec![value.nulls_last],
multithreaded: value.multithreaded,
maintain_order: value.maintain_order,
}
Expand All @@ -219,7 +232,7 @@ impl From<&SortMultipleOptions> for SortOptions {
fn from(value: &SortMultipleOptions) -> Self {
SortOptions {
descending: value.descending.first().copied().unwrap_or(false),
nulls_last: value.nulls_last,
nulls_last: value.nulls_last.first().copied().unwrap_or(false),
multithreaded: value.multithreaded,
maintain_order: value.maintain_order,
}
Expand Down
19 changes: 7 additions & 12 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1773,18 +1773,17 @@ impl DataFrame {
mut sort_options: SortMultipleOptions,
slice: Option<(i64, usize)>,
) -> PolarsResult<Self> {
// note that the by_column argument also contains evaluated expression from polars-lazy
// that may not even be present in this dataframe.
// note that the by_column argument also contains evaluated expression from
// polars-lazy that may not even be present in this dataframe.

// therefore when we try to set the first columns as sorted, we ignore the error
// as expressions are not present (they are renamed to _POLARS_SORT_COLUMN_i.
let first_descending = sort_options.descending[0];
let first_by_column = by_column[0].name().to_string();

let set_sorted = |df: &mut DataFrame| {
// Mark the first sort column as sorted
// if the column did not exists it is ok, because we sorted by an expression
// not present in the dataframe
// Mark the first sort column as sorted; if the column does not exist it
// is ok, because we sorted by an expression not present in the dataframe
let _ = df.apply(&first_by_column, |s| {
let mut s = s.clone();
if first_descending {
Expand All @@ -1795,14 +1794,11 @@ impl DataFrame {
s
});
};

if self.is_empty() {
let mut out = self.clone();
set_sorted(&mut out);

return Ok(out);
}

if let Some((0, k)) = slice {
return self.bottom_k_impl(k, by_column, sort_options);
}
Expand All @@ -1824,7 +1820,7 @@ impl DataFrame {
let s = &by_column[0];
let options = SortOptions {
descending: sort_options.descending[0],
nulls_last: sort_options.nulls_last,
nulls_last: sort_options.nulls_last[0],
multithreaded: sort_options.multithreaded,
maintain_order: sort_options.maintain_order,
};
Expand All @@ -1836,13 +1832,12 @@ impl DataFrame {
if let Some((offset, len)) = slice {
out = out.slice(offset, len);
}

return Ok(out.into_frame());
}
s.arg_sort(options)
},
_ => {
if sort_options.nulls_last
if sort_options.nulls_last.iter().all(|&x| x)
|| has_struct
|| std::env::var("POLARS_ROW_FMT_SORT").is_ok()
{
Expand Down Expand Up @@ -1899,7 +1894,7 @@ impl DataFrame {
/// df.sort(
/// &["sepal_width", "sepal_length"],
/// SortMultipleOptions::new()
/// .with_order_descendings([false, true])
/// .with_order_descending_multi([false, true])
/// )
/// }
/// ```
Expand Down
Loading