From 3fadc78dc368a05a47423f7d1d25e9c19a532c72 Mon Sep 17 00:00:00 2001 From: alexander-beedie Date: Fri, 31 May 2024 15:08:37 +0400 Subject: [PATCH] feat: Support per-column `nulls_last` on sort operations --- .../chunked_array/ops/sort/arg_bottom_k.rs | 10 +- .../ops/sort/arg_sort_multiple.rs | 34 +++--- .../src/chunked_array/ops/sort/categorical.rs | 4 +- .../src/chunked_array/ops/sort/mod.rs | 29 +++-- .../src/chunked_array/ops/sort/options.rs | 41 ++++--- crates/polars-core/src/frame/mod.rs | 19 ++- .../src/series/implementations/struct_.rs | 12 +- crates/polars-expr/src/expressions/sortby.rs | 26 +++-- crates/polars-lazy/src/frame/mod.rs | 2 +- crates/polars-lazy/src/tests/queries.rs | 4 +- crates/polars-lazy/src/tests/streaming.rs | 2 +- crates/polars-lazy/src/tests/tpch.rs | 2 +- crates/polars-ops/src/series/ops/various.rs | 8 +- .../src/executors/sinks/sort/sink.rs | 2 +- .../src/executors/sinks/sort/sink_multiple.rs | 11 +- .../src/logical_plan/alp/tree_format.rs | 9 +- crates/polars-sql/src/context.rs | 2 +- crates/polars-sql/src/functions.rs | 2 +- crates/polars-sql/src/sql_expr.rs | 2 +- crates/polars-sql/tests/ops_distinct_on.rs | 20 ++-- crates/polars-sql/tests/simple_exprs.rs | 2 +- py-polars/polars/_utils/various.py | 17 +++ py-polars/polars/dataframe/frame.py | 9 +- py-polars/polars/expr/expr.py | 31 ++--- py-polars/polars/functions/lazy.py | 11 +- py-polars/polars/lazyframe/frame.py | 29 ++--- py-polars/src/expr/general.rs | 6 +- py-polars/src/functions/lazy.rs | 2 +- py-polars/src/lazyframe/mod.rs | 8 +- py-polars/src/lazyframe/visitor/expr_nodes.rs | 4 +- py-polars/src/lazyframe/visitor/nodes.rs | 4 +- py-polars/tests/unit/datatypes/test_struct.py | 20 ++-- py-polars/tests/unit/operations/test_sort.py | 108 +++++++++++------- 33 files changed, 279 insertions(+), 213 deletions(-) diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs index dff7b1fed733..300e9f85611e 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs @@ -35,8 +35,14 @@ pub fn _arg_bottom_k( sort_options: &mut SortMultipleOptions, ) -> PolarsResult> { let from_n_rows = by_column[0].len(); - _broadcast_descending(by_column.len(), &mut sort_options.descending); - let encoded = _get_rows_encoded(by_column, &sort_options.descending, sort_options.nulls_last)?; + _broadcast_bools(by_column.len(), &mut sort_options.descending); + _broadcast_bools(by_column.len(), &mut sort_options.nulls_last); + + let encoded = _get_rows_encoded( + by_column, + &sort_options.descending, + &sort_options.nulls_last, + )?; let arr = encoded.into_array(); let mut rows = arr .values_iter() diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs index ca6651282259..8b909f09930a 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs @@ -27,21 +27,26 @@ pub(crate) fn arg_sort_multiple_impl( by: &[Series], options: &SortMultipleOptions, ) -> PolarsResult { + let nulls_last = &options.nulls_last; let descending = &options.descending; + debug_assert_eq!(descending.len() - 1, by.len()); + debug_assert_eq!(nulls_last.len() - 1, by.len()); + let compare_inner: Vec<_> = by .iter() .map(|s| s.into_total_ord_inner()) .collect_trusted(); let first_descending = descending[0]; + let first_nulls_last = nulls_last[0]; let compare = |tpl_a: &(_, T), tpl_b: &(_, T)| -> Ordering { match ( first_descending, tpl_a .1 - .null_order_cmp(&tpl_b.1, options.nulls_last ^ first_descending), + .null_order_cmp(&tpl_b.1, first_nulls_last ^ first_descending), ) { // if ordering is equal, we check the other arrays until we find a non-equal ordering // if we have exhausted all arrays, we keep the equal ordering. @@ -52,7 +57,7 @@ pub(crate) fn arg_sort_multiple_impl( ordering_other_columns( &compare_inner, descending.get_unchecked(1..), - options.nulls_last, + nulls_last.get_unchecked(1..), idx_a, idx_b, ) @@ -184,17 +189,19 @@ pub fn _get_rows_encoded_unordered(by: &[Series]) -> PolarsResult { pub fn _get_rows_encoded( by: &[Series], descending: &[bool], - nulls_last: bool, + nulls_last: &[bool], ) -> PolarsResult { debug_assert_eq!(by.len(), descending.len()); + debug_assert_eq!(by.len(), nulls_last.len()); + let mut cols = Vec::with_capacity(by.len()); let mut fields = Vec::with_capacity(by.len()); - for (by, descending) in by.iter().zip(descending) { - let arr = _get_rows_encoded_compat_array(by)?; + for ((by, desc), null_last) in by.iter().zip(descending).zip(nulls_last) { + let arr = _get_rows_encoded_compat_array(by)?; let sort_field = EncodingField { - descending: *descending, - nulls_last, + descending: *desc, + nulls_last: *null_last, no_order: false, }; match arr.data_type() { @@ -203,12 +210,12 @@ pub fn _get_rows_encoded( let arr = arr.as_any().downcast_ref::().unwrap(); for arr in arr.values() { cols.push(arr.clone() as ArrayRef); - fields.push(sort_field) + fields.push(sort_field); } }, _ => { cols.push(arr); - fields.push(sort_field) + fields.push(sort_field); }, } } @@ -219,7 +226,7 @@ pub fn _get_rows_encoded_ca( name: &str, by: &[Series], descending: &[bool], - nulls_last: bool, + nulls_last: &[bool], ) -> PolarsResult { _get_rows_encoded(by, descending, nulls_last) .map(|rows| BinaryOffsetChunked::with_chunk(name, rows.into_array())) @@ -236,12 +243,13 @@ pub fn _get_rows_encoded_ca_unordered( pub(crate) fn argsort_multiple_row_fmt( by: &[Series], mut descending: Vec, - nulls_last: bool, + mut nulls_last: Vec, parallel: bool, ) -> PolarsResult { - _broadcast_descending(by.len(), &mut descending); + _broadcast_bools(by.len(), &mut descending); + _broadcast_bools(by.len(), &mut nulls_last); - let rows_encoded = _get_rows_encoded(by, &descending, nulls_last)?; + let rows_encoded = _get_rows_encoded(by, &descending, &nulls_last)?; let mut items: Vec<_> = rows_encoded.iter().enumerate_idx().collect(); if parallel { diff --git a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs index 0594421f62d3..3b2c67db8eb2 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs @@ -177,7 +177,7 @@ mod test { let out = df.sort( ["cat", "vals"], - SortMultipleOptions::default().with_order_descendings([false, false]), + SortMultipleOptions::default().with_order_descending_multi([false, false]), )?; let out = out.column("cat")?; let cat = out.categorical()?; @@ -185,7 +185,7 @@ mod test { let out = df.sort( ["vals", "cat"], - SortMultipleOptions::default().with_order_descendings([false, false]), + SortMultipleOptions::default().with_order_descending_multi([false, false]), )?; let out = out.column("cat")?; let cat = out.categorical()?; diff --git a/crates/polars-core/src/chunked_array/ops/sort/mod.rs b/crates/polars-core/src/chunked_array/ops/sort/mod.rs index 0751e01e3cb7..5650e93505d1 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/mod.rs @@ -277,14 +277,13 @@ where fn ordering_other_columns<'a>( compare_inner: &'a [Box], descending: &[bool], - nulls_last: bool, + nulls_last: &[bool], idx_a: usize, idx_b: usize, ) -> Ordering { - for (cmp, descending) in compare_inner.iter().zip(descending) { - // SAFETY: - // indices are in bounds - let ordering = unsafe { cmp.cmp_element_unchecked(idx_a, idx_b, nulls_last ^ descending) }; + for ((cmp, descending), null_last) in compare_inner.iter().zip(descending).zip(nulls_last) { + // SAFETY: indices are in bounds + let ordering = unsafe { cmp.cmp_element_unchecked(idx_a, idx_b, null_last ^ descending) }; match (ordering, descending) { (Ordering::Equal, _) => continue, (_, true) => return ordering.reverse(), @@ -557,7 +556,7 @@ impl StructChunked { self.name(), &[self.clone().into_series()], &[options.descending], - options.nulls_last, + &[options.nulls_last], ) .unwrap(); bin.arg_sort(Default::default()) @@ -670,10 +669,10 @@ pub(crate) fn convert_sort_column_multi_sort(s: &Series) -> PolarsResult Ok(out) } -pub fn _broadcast_descending(n_cols: usize, descending: &mut Vec) { - if n_cols > descending.len() && descending.len() == 1 { - while n_cols != descending.len() { - descending.push(descending[0]); +pub fn _broadcast_bools(n_cols: usize, values: &mut Vec) { + if n_cols > values.len() && values.len() == 1 { + while n_cols != values.len() { + values.push(values[0]); } } } @@ -689,10 +688,10 @@ pub(crate) fn prepare_arg_sort( .map(convert_sort_column_multi_sort) .collect::>>()?; - let first = columns.remove(0); + _broadcast_bools(n_cols, &mut sort_options.descending); + _broadcast_bools(n_cols, &mut sort_options.nulls_last); - // broadcast ordering - _broadcast_descending(n_cols, &mut sort_options.descending); + let first = columns.remove(0); Ok((first, columns)) } @@ -831,7 +830,7 @@ mod test { let out = df.sort( ["groups", "values"], - SortMultipleOptions::default().with_order_descendings([true, false]), + SortMultipleOptions::default().with_order_descending_multi([true, false]), )?; let expected = df!( "groups" => [3, 2, 1], @@ -841,7 +840,7 @@ mod test { let out = df.sort( ["values", "groups"], - SortMultipleOptions::default().with_order_descendings([false, true]), + SortMultipleOptions::default().with_order_descending_multi([false, true]), )?; let expected = df!( "groups" => [2, 1, 3], diff --git a/crates/polars-core/src/chunked_array/ops/sort/options.rs b/crates/polars-core/src/chunked_array/ops/sort/options.rs index 49ff2ca52286..8726da26774a 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/options.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/options.rs @@ -63,7 +63,7 @@ pub struct SortOptions { /// SortMultipleOptions::default() /// .with_maintain_order(true) /// .with_multithreaded(false) -/// .with_order_descendings([false, true]) +/// .with_order_descending_multi([false, true]) /// .with_nulls_last(true), /// )?; /// @@ -83,15 +83,15 @@ pub struct SortMultipleOptions { /// /// If only one value is given, it will broadcast to all columns. /// - /// Use [`SortMultipleOptions::with_order_descendings`] + /// Use [`SortMultipleOptions::with_order_descending_multi`] /// or [`SortMultipleOptions::with_order_descending`] to modify. /// /// # Safety /// - /// Len must matches the number of columns or equal to 1. + /// Len must match the number of columns, or equal 1. pub descending: Vec, /// Whether place null values last. Default `false`. - pub nulls_last: bool, + pub nulls_last: Vec, /// Whether sort in multiple threads. Default `true`. pub multithreaded: bool, /// Whether maintain the order of equal elements. Default `false`. @@ -113,7 +113,7 @@ impl Default for SortMultipleOptions { fn default() -> Self { Self { descending: vec![false], - nulls_last: false, + nulls_last: vec![false], multithreaded: true, maintain_order: false, } @@ -126,12 +126,15 @@ impl SortMultipleOptions { Self::default() } - /// Specify order for each columns. Default all `false`. + /// Specify order for each column. Defaults all `false`. /// /// # Safety /// - /// Len must matches the number of columns or equal to 1. - pub fn with_order_descendings(mut self, descending: impl IntoIterator) -> Self { + /// Len must match the number of columns, or be equal to 1. + pub fn with_order_descending_multi( + mut self, + descending: impl IntoIterator, + ) -> Self { self.descending = descending.into_iter().collect(); self } @@ -142,19 +145,29 @@ impl SortMultipleOptions { self } - /// Whether place null values last. Default `false`. + /// Specify whether to place nulls last, per-column. Defaults all `false`. + /// + /// # Safety + /// + /// Len must match the number of columns, or be equal to 1. + pub fn with_nulls_last_multi(mut self, nulls_last: impl IntoIterator) -> Self { + self.nulls_last = nulls_last.into_iter().collect(); + self + } + + /// Whether to place null values last. Default `false`. pub fn with_nulls_last(mut self, enabled: bool) -> Self { - self.nulls_last = enabled; + self.nulls_last = vec![enabled]; self } - /// Whether sort in multiple threads. Default `true`. + /// Whether to sort in multiple threads. Default `true`. pub fn with_multithreaded(mut self, enabled: bool) -> Self { self.multithreaded = enabled; self } - /// Whether maintain the order of equal elements. Default `false`. + /// Whether to maintain the order of equal elements. Default `false`. pub fn with_maintain_order(mut self, enabled: bool) -> Self { self.maintain_order = enabled; self @@ -208,7 +221,7 @@ impl From<&SortOptions> for SortMultipleOptions { fn from(value: &SortOptions) -> Self { SortMultipleOptions { descending: vec![value.descending], - nulls_last: value.nulls_last, + nulls_last: vec![value.nulls_last], multithreaded: value.multithreaded, maintain_order: value.maintain_order, } @@ -219,7 +232,7 @@ impl From<&SortMultipleOptions> for SortOptions { fn from(value: &SortMultipleOptions) -> Self { SortOptions { descending: value.descending.first().copied().unwrap_or(false), - nulls_last: value.nulls_last, + nulls_last: value.nulls_last.first().copied().unwrap_or(false), multithreaded: value.multithreaded, maintain_order: value.maintain_order, } diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index a17f6f1ebde9..b3530be45c11 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -1773,8 +1773,8 @@ impl DataFrame { mut sort_options: SortMultipleOptions, slice: Option<(i64, usize)>, ) -> PolarsResult { - // note that the by_column argument also contains evaluated expression from polars-lazy - // that may not even be present in this dataframe. + // note that the by_column argument also contains evaluated expression from + // polars-lazy that may not even be present in this dataframe. // therefore when we try to set the first columns as sorted, we ignore the error // as expressions are not present (they are renamed to _POLARS_SORT_COLUMN_i. @@ -1782,9 +1782,8 @@ impl DataFrame { let first_by_column = by_column[0].name().to_string(); let set_sorted = |df: &mut DataFrame| { - // Mark the first sort column as sorted - // if the column did not exists it is ok, because we sorted by an expression - // not present in the dataframe + // Mark the first sort column as sorted; if the column does not exist it + // is ok, because we sorted by an expression not present in the dataframe let _ = df.apply(&first_by_column, |s| { let mut s = s.clone(); if first_descending { @@ -1795,14 +1794,11 @@ impl DataFrame { s }); }; - if self.is_empty() { let mut out = self.clone(); set_sorted(&mut out); - return Ok(out); } - if let Some((0, k)) = slice { return self.bottom_k_impl(k, by_column, sort_options); } @@ -1824,7 +1820,7 @@ impl DataFrame { let s = &by_column[0]; let options = SortOptions { descending: sort_options.descending[0], - nulls_last: sort_options.nulls_last, + nulls_last: sort_options.nulls_last[0], multithreaded: sort_options.multithreaded, maintain_order: sort_options.maintain_order, }; @@ -1836,13 +1832,12 @@ impl DataFrame { if let Some((offset, len)) = slice { out = out.slice(offset, len); } - return Ok(out.into_frame()); } s.arg_sort(options) }, _ => { - if sort_options.nulls_last + if sort_options.nulls_last.iter().all(|&x| x) || has_struct || std::env::var("POLARS_ROW_FMT_SORT").is_ok() { @@ -1899,7 +1894,7 @@ impl DataFrame { /// df.sort( /// &["sepal_width", "sepal_length"], /// SortMultipleOptions::new() - /// .with_order_descendings([false, true]) + /// .with_order_descending_multi([false, true]) /// ) /// } /// ``` diff --git a/crates/polars-core/src/series/implementations/struct_.rs b/crates/polars-core/src/series/implementations/struct_.rs index 6c267729f7da..971ef90904d4 100644 --- a/crates/polars-core/src/series/implementations/struct_.rs +++ b/crates/polars-core/src/series/implementations/struct_.rs @@ -315,13 +315,13 @@ impl SeriesTrait for SeriesWrap { fn sort_with(&self, options: SortOptions) -> PolarsResult { let df = self.0.clone().unnest(); - let desc = if options.descending { - vec![true; df.width()] - } else { - vec![false; df.width()] - }; + let n_cols = df.width(); + let desc = vec![options.descending; n_cols]; + let last = vec![options.nulls_last; n_cols]; - let multi_options = SortMultipleOptions::from(&options).with_order_descendings(desc); + let multi_options = SortMultipleOptions::from(&options) + .with_order_descending_multi(desc) + .with_nulls_last_multi(last); let out = df.sort_impl(df.columns.clone(), multi_options, None)?; Ok(StructChunked::new_unchecked(self.name(), &out.columns).into_series()) diff --git a/crates/polars-expr/src/expressions/sortby.rs b/crates/polars-expr/src/expressions/sortby.rs index 06c3ef65d976..e51e4722dac3 100644 --- a/crates/polars-expr/src/expressions/sortby.rs +++ b/crates/polars-expr/src/expressions/sortby.rs @@ -33,14 +33,14 @@ impl SortByExpr { } } -fn prepare_descending(descending: &[bool], by_len: usize) -> Vec { - match (descending.len(), by_len) { +fn prepare_bool_vec(values: &[bool], by_len: usize) -> Vec { + match (values.len(), by_len) { // Equal length. - (n_rdescending, n) if n_rdescending == n => descending.to_vec(), + (n_rvalues, n) if n_rvalues == n => values.to_vec(), // None given all false. (0, n) => vec![false; n], // Broadcast first. - (_, n) => vec![descending[0]; n], + (_, n) => vec![values[0]; n], } } @@ -141,7 +141,7 @@ fn sort_by_groups_multiple_by( let options = SortMultipleOptions { descending: descending.to_owned(), - nulls_last: false, + nulls_last: vec![false; descending.len()], multithreaded, maintain_order, }; @@ -157,7 +157,7 @@ fn sort_by_groups_multiple_by( let options = SortMultipleOptions { descending: descending.to_owned(), - nulls_last: false, + nulls_last: vec![false; descending.len()], multithreaded, maintain_order, }; @@ -178,8 +178,6 @@ impl PhysicalExpr for SortByExpr { } fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { let series_f = || self.input.evaluate(df, state); - let descending = prepare_descending(&self.sort_options.descending, self.by.len()); - let (series, sorted_idx) = if self.by.len() == 1 { let sorted_idx_f = || { let s_sort_by = self.by[0].evaluate(df, state)?; @@ -187,6 +185,9 @@ impl PhysicalExpr for SortByExpr { }; POOL.install(|| rayon::join(series_f, sorted_idx_f)) } else { + let descending = prepare_bool_vec(&self.sort_options.descending, self.by.len()); + let nulls_last = prepare_bool_vec(&self.sort_options.nulls_last, self.by.len()); + let sorted_idx_f = || { let s_sort_by = self .by @@ -200,7 +201,12 @@ impl PhysicalExpr for SortByExpr { }) .collect::>>()?; - let options = self.sort_options.clone().with_order_descendings(descending); + let options = self + .sort_options + .clone() + .with_order_descending_multi(descending) + .with_nulls_last_multi(nulls_last); + s_sort_by[0].arg_sort_multiple(&s_sort_by[1..], &options) }; POOL.install(|| rayon::join(series_f, sorted_idx_f)) @@ -225,7 +231,7 @@ impl PhysicalExpr for SortByExpr { state: &ExecutionState, ) -> PolarsResult> { let mut ac_in = self.input.evaluate_on_groups(df, groups, state)?; - let descending = prepare_descending(&self.sort_options.descending, self.by.len()); + let descending = prepare_bool_vec(&self.sort_options.descending, self.by.len()); let mut ac_sort_by = self .by diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index 44155a040995..333aa505b21d 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -292,7 +292,7 @@ impl LazyFrame { /// df.lazy().sort( /// &["sepal_width", "sepal_length"], /// SortMultipleOptions::new() - /// .with_order_descendings([false, true]) + /// .with_order_descending_multi([false, true]) /// ) /// } /// ``` diff --git a/crates/polars-lazy/src/tests/queries.rs b/crates/polars-lazy/src/tests/queries.rs index 7b3c76487080..87e54ef25802 100644 --- a/crates/polars-lazy/src/tests/queries.rs +++ b/crates/polars-lazy/src/tests/queries.rs @@ -1036,7 +1036,7 @@ fn test_arg_sort_multiple() -> PolarsResult<()> { .lazy() .select([arg_sort_by( [col("int"), col("flt")], - SortMultipleOptions::default().with_order_descendings([true, false]), + SortMultipleOptions::default().with_order_descending_multi([true, false]), )]) .collect()?; @@ -1054,7 +1054,7 @@ fn test_arg_sort_multiple() -> PolarsResult<()> { .lazy() .select([arg_sort_by( [col("str"), col("flt")], - SortMultipleOptions::default().with_order_descendings([true, false]), + SortMultipleOptions::default().with_order_descending_multi([true, false]), )]) .collect()?; Ok(()) diff --git a/crates/polars-lazy/src/tests/streaming.rs b/crates/polars-lazy/src/tests/streaming.rs index 1ca82f18b832..d8d76384ed0c 100644 --- a/crates/polars-lazy/src/tests/streaming.rs +++ b/crates/polars-lazy/src/tests/streaming.rs @@ -104,7 +104,7 @@ fn test_streaming_multiple_keys_aggregate() -> PolarsResult<()> { ]) .sort_by_exprs( [col("sugars_g"), col("calories")], - SortMultipleOptions::default().with_order_descendings([false, false]), + SortMultipleOptions::default().with_order_descending_multi([false, false]), ); assert_streaming_with_default(q, true, false); diff --git a/crates/polars-lazy/src/tests/tpch.rs b/crates/polars-lazy/src/tests/tpch.rs index 0a647615d0ea..49eed184f72a 100644 --- a/crates/polars-lazy/src/tests/tpch.rs +++ b/crates/polars-lazy/src/tests/tpch.rs @@ -79,7 +79,7 @@ fn test_q2() -> PolarsResult<()> { .sort_by_exprs( [cols(["s_acctbal", "n_name", "s_name", "p_partkey"])], SortMultipleOptions::default() - .with_order_descendings([true, false, false, false]) + .with_order_descending_multi([true, false, false, false]) .with_maintain_order(true), ) .limit(100) diff --git a/crates/polars-ops/src/series/ops/various.rs b/crates/polars-ops/src/series/ops/various.rs index 327d2b193bb6..af16dad993da 100644 --- a/crates/polars-ops/src/series/ops/various.rs +++ b/crates/polars-ops/src/series/ops/various.rs @@ -75,8 +75,12 @@ pub trait SeriesMethods: SeriesSealed { // for struct types we row-encode and recurse #[cfg(feature = "dtype-struct")] if matches!(s.dtype(), DataType::Struct(_)) { - let encoded = - _get_rows_encoded_ca("", &[s.clone()], &[options.descending], options.nulls_last)?; + let encoded = _get_rows_encoded_ca( + "", + &[s.clone()], + &[options.descending], + &[options.nulls_last], + )?; return encoded.into_series().is_sorted(options); } diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink.rs b/crates/polars-pipe/src/executors/sinks/sort/sink.rs index 2e19f27753b4..5bd51deba54a 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink.rs @@ -207,7 +207,7 @@ impl Sink for SortSink { dist, self.sort_idx, self.sort_options.descending[0], - self.sort_options.nulls_last, + self.sort_options.nulls_last[0], self.slice, context.verbose, self.mem_track.clone(), diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs index b24439330b80..1e7976afc431 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs @@ -1,7 +1,7 @@ use std::any::Any; use arrow::array::BinaryArray; -use polars_core::prelude::sort::_broadcast_descending; +use polars_core::prelude::sort::_broadcast_bools; use polars_core::prelude::sort::arg_sort_multiple::_get_rows_encoded_compat_array; use polars_core::prelude::*; use polars_core::series::IsSorted; @@ -16,10 +16,15 @@ const POLARS_SORT_COLUMN: &str = "__POLARS_SORT_COLUMN"; fn get_sort_fields(sort_idx: &[usize], sort_options: &SortMultipleOptions) -> Vec { let mut descending = sort_options.descending.clone(); - _broadcast_descending(sort_idx.len(), &mut descending); + let mut nulls_last = sort_options.nulls_last.clone(); + + _broadcast_bools(sort_idx.len(), &mut descending); + _broadcast_bools(sort_idx.len(), &mut nulls_last); + descending .into_iter() - .map(|descending| EncodingField::new_sorted(descending, sort_options.nulls_last)) + .zip(nulls_last) + .map(|(descending, nulls_last)| EncodingField::new_sorted(descending, nulls_last)) .collect() } diff --git a/crates/polars-plan/src/logical_plan/alp/tree_format.rs b/crates/polars-plan/src/logical_plan/alp/tree_format.rs index 7337a6c33201..0c1761fc95c0 100644 --- a/crates/polars-plan/src/logical_plan/alp/tree_format.rs +++ b/crates/polars-plan/src/logical_plan/alp/tree_format.rs @@ -52,11 +52,10 @@ impl fmt::Display for TreeFmtAExpr<'_> { for i in &sort_options.descending { write!(f, "{}", *i as u8)?; } - write!( - f, - "{}{}", - sort_options.nulls_last as u8, sort_options.multithreaded as u8 - )?; + for i in &sort_options.nulls_last { + write!(f, "{}", *i as u8)?; + } + write!(f, "{}", sort_options.multithreaded as u8)?; return Ok(()); }, AExpr::Filter { .. } => "filter", diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index a44788ae0ace..201b95d49a56 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -865,7 +865,7 @@ impl SQLContext { Ok(lf.sort_by_exprs( &by, SortMultipleOptions::default() - .with_order_descendings(descending) + .with_order_descending_multi(descending) .with_maintain_order(true), )) } diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index 037e9ef9fd21..2533f84906b2 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -1143,7 +1143,7 @@ impl SQLFunctionVisitor<'_> { cumulative_f( e.sort_by( &order_by, - SortMultipleOptions::default().with_order_descendings(desc.clone()), + SortMultipleOptions::default().with_order_descending_multi(desc.clone()), ), false, ) diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 321f55a3c548..2acd69f4a6f0 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -851,7 +851,7 @@ impl SQLExprVisitor<'_> { let (order_by, descending) = self.visit_order_by(order_by)?; base = base.sort_by( order_by, - SortMultipleOptions::default().with_order_descendings(descending), + SortMultipleOptions::default().with_order_descending_multi(descending), ); } if let Some(limit) = &expr.limit { diff --git a/crates/polars-sql/tests/ops_distinct_on.rs b/crates/polars-sql/tests/ops_distinct_on.rs index 77bc4652a9f8..6d9c81a7be41 100644 --- a/crates/polars-sql/tests/ops_distinct_on.rs +++ b/crates/polars-sql/tests/ops_distinct_on.rs @@ -16,22 +16,22 @@ fn test_distinct_on() { ctx.register("df", df.clone()); let sql = r#" - SELECT DISTINCT ON ("Name") - "Name", - "Record Date", - "Score" - FROM - df - ORDER BY - "Name", - "Record Date" DESC;"#; + SELECT DISTINCT ON ("Name") + "Name", + "Record Date", + "Score" + FROM + df + ORDER BY + "Name", + "Record Date" DESC;"#; let lf = ctx.execute(sql).unwrap(); let actual = lf.collect().unwrap(); let expected = df .sort_by_exprs( vec![col("Name"), col("Record Date")], SortMultipleOptions::default() - .with_order_descendings([false, true]) + .with_order_descending_multi([false, true]) .with_maintain_order(true), ) .group_by_stable(vec![col("Name")]) diff --git a/crates/polars-sql/tests/simple_exprs.rs b/crates/polars-sql/tests/simple_exprs.rs index 60d0277e1abc..7f9aae7aeeaa 100644 --- a/crates/polars-sql/tests/simple_exprs.rs +++ b/crates/polars-sql/tests/simple_exprs.rs @@ -560,7 +560,7 @@ fn test_group_by_2() -> PolarsResult<()> { ]) .sort_by_exprs( vec![col("count"), col("category")], - SortMultipleOptions::default().with_order_descendings([false, true]), + SortMultipleOptions::default().with_order_descending_multi([false, true]), ) .limit(2); let expected = expected.collect()?; diff --git a/py-polars/polars/_utils/various.py b/py-polars/polars/_utils/various.py index 5644b1cd2a86..3839686f8840 100644 --- a/py-polars/polars/_utils/various.py +++ b/py-polars/polars/_utils/various.py @@ -485,6 +485,23 @@ def _polars_warn(msg: str, category: type[Warning] = UserWarning) -> None: ) +def extend_bool( + value: bool | Sequence[bool], + n_match: int, + value_name: str, + match_name: str, +) -> Sequence[bool]: + """Ensure the given bool or sequence of bools is the correct length.""" + values = [value] * n_match if isinstance(value, bool) else value + if n_match != len(values): + msg = ( + f"the length of `{value_name}` ({len(values)}) " + f"does not match the length of `{match_name}` ({n_match})" + ) + raise ValueError(msg) + return values + + def in_terminal_that_supports_colour() -> bool: """ Determine (within reason) if we are in an interactive terminal that supports color. diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 0d40723d2e74..a5295cf61920 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -4546,7 +4546,7 @@ def sort( by: IntoExpr | Iterable[IntoExpr], *more_by: IntoExpr, descending: bool | Sequence[bool] = False, - nulls_last: bool = False, + nulls_last: bool | Sequence[bool] = False, multithreaded: bool = True, maintain_order: bool = False, ) -> DataFrame: @@ -4564,7 +4564,8 @@ def sort( Sort in descending order. When sorting by multiple columns, can be specified per column by passing a sequence of booleans. nulls_last - Place null values last. + Place null values last; can specify a single boolean applying to all columns + or a sequence of booleans for per-column control. multithreaded Sort using multiple threads. maintain_order @@ -4750,7 +4751,7 @@ def top_k( *, by: IntoExpr | Iterable[IntoExpr], descending: bool | Sequence[bool] = False, - nulls_last: bool | None = None, + nulls_last: bool | Sequence[bool] | None = None, maintain_order: bool | None = None, ) -> DataFrame: """ @@ -4851,7 +4852,7 @@ def bottom_k( *, by: IntoExpr | Iterable[IntoExpr], descending: bool | Sequence[bool] = False, - nulls_last: bool | None = None, + nulls_last: bool | Sequence[bool] | None = None, maintain_order: bool | None = None, ) -> DataFrame: """ diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 2a089b43e127..aaf57bf9210f 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -44,6 +44,7 @@ from polars._utils.unstable import issue_unstable_warning, unstable from polars._utils.various import ( BUILDING_SPHINX_DOCS, + extend_bool, find_stacklevel, no_default, normalize_filepath, @@ -2155,7 +2156,7 @@ def top_k_by( k: int | IntoExprColumn = 5, *, descending: bool | Sequence[bool] = False, - nulls_last: bool | None = None, + nulls_last: bool | Sequence[bool] | None = None, maintain_order: bool | None = None, multithreaded: bool | None = None, ) -> Self: @@ -2319,11 +2320,8 @@ def top_k_by( k = parse_as_expression(k) by = parse_as_list_of_expressions(by) - if isinstance(descending, bool): - descending = [descending] - elif len(by) != len(descending): - msg = f"the length of `descending` ({len(descending)}) does not match the length of `by` ({len(by)})" - raise ValueError(msg) + descending = extend_bool(descending, len(by), "descending", "by") + nulls_last = extend_bool(nulls_last, len(by), "nulls_last", "by") return self._from_pyexpr( self._pyexpr.top_k_by( k, @@ -2454,7 +2452,7 @@ def bottom_k_by( k: int | IntoExprColumn = 5, *, descending: bool | Sequence[bool] = False, - nulls_last: bool | None = None, + nulls_last: bool | Sequence[bool] | None = None, maintain_order: bool | None = None, multithreaded: bool | None = None, ) -> Self: @@ -2613,11 +2611,8 @@ def bottom_k_by( k = parse_as_expression(k) by = parse_as_list_of_expressions(by) - if isinstance(descending, bool): - descending = [descending] - elif len(by) != len(descending): - msg = f"the length of `descending` ({len(descending)}) does not match the length of `by` ({len(by)})" - raise ValueError(msg) + descending = extend_bool(descending, len(by), "descending", "by") + nulls_last = extend_bool(nulls_last, len(by), "nulls_last", "by") return self._from_pyexpr( self._pyexpr.bottom_k_by( k, @@ -2780,7 +2775,7 @@ def sort_by( by: IntoExpr | Iterable[IntoExpr], *more_by: IntoExpr, descending: bool | Sequence[bool] = False, - nulls_last: bool = False, + nulls_last: bool | Sequence[bool] = False, multithreaded: bool = True, maintain_order: bool = False, ) -> Self: @@ -2801,7 +2796,8 @@ def sort_by( Sort in descending order. When sorting by multiple columns, can be specified per column by passing a sequence of booleans. nulls_last - Place null values last. + Place null values last; can specify a single boolean applying to all columns + or a sequence of booleans for per-column control. multithreaded Sort using multiple threads. maintain_order @@ -2908,11 +2904,8 @@ def sort_by( └───────┴────────┴────────┘ """ by = parse_as_list_of_expressions(by, *more_by) - if isinstance(descending, bool): - descending = [descending] - elif len(by) != len(descending): - msg = f"the length of `descending` ({len(descending)}) does not match the length of `by` ({len(by)})" - raise ValueError(msg) + descending = extend_bool(descending, len(by), "descending", "by") + nulls_last = extend_bool(nulls_last, len(by), "nulls_last", "by") return self._from_pyexpr( self._pyexpr.sort_by( by, descending, nulls_last, multithreaded, maintain_order diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py index 302614a7ac19..5c9fe2917e9a 100644 --- a/py-polars/polars/functions/lazy.py +++ b/py-polars/polars/functions/lazy.py @@ -16,6 +16,7 @@ parse_as_list_of_expressions, ) from polars._utils.unstable import issue_unstable_warning, unstable +from polars._utils.various import extend_bool from polars._utils.wrap import wrap_df, wrap_expr from polars.datatypes import DTYPE_TEMPORAL_UNITS, Date, Datetime, Int64, UInt32 @@ -1640,7 +1641,7 @@ def arg_sort_by( exprs: IntoExpr | Iterable[IntoExpr], *more_exprs: IntoExpr, descending: bool | Sequence[bool] = False, - nulls_last: bool = False, + nulls_last: bool | Sequence[bool] = False, multithreaded: bool = True, maintain_order: bool = False, ) -> Expr: @@ -1725,12 +1726,8 @@ def arg_sort_by( └─────┘ """ exprs = parse_as_list_of_expressions(exprs, *more_exprs) - - if isinstance(descending, bool): - descending = [descending] * len(exprs) - elif len(exprs) != len(descending): - msg = f"the length of `descending` ({len(descending)}) does not match the length of `exprs` ({len(exprs)})" - raise ValueError(msg) + descending = extend_bool(descending, len(exprs), "descending", "exprs") + nulls_last = extend_bool(nulls_last, len(exprs), "nulls_last", "exprs") return wrap_expr( plr.arg_sort_by(exprs, descending, nulls_last, multithreaded, maintain_order) ) diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 71670591185e..68bf98379a68 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -42,6 +42,7 @@ from polars._utils.various import ( _in_notebook, _is_generator, + extend_bool, is_bool_sequence, is_sequence, normalize_filepath, @@ -1178,7 +1179,7 @@ def sort( by: IntoExpr | Iterable[IntoExpr], *more_by: IntoExpr, descending: bool | Sequence[bool] = False, - nulls_last: bool = False, + nulls_last: bool | Sequence[bool] = False, maintain_order: bool = False, multithreaded: bool = True, ) -> Self: @@ -1196,7 +1197,8 @@ def sort( Sort in descending order. When sorting by multiple columns, can be specified per column by passing a sequence of booleans. nulls_last - Place null values last. + Place null values last; can specify a single boolean applying to all columns + or a sequence of booleans for per-column control. maintain_order Whether the order should be maintained if elements are equal. Note that if `true` streaming is not possible and performance might be @@ -1278,12 +1280,8 @@ def sort( ) by = parse_as_list_of_expressions(by, *more_by) - - if isinstance(descending, bool): - descending = [descending] - elif len(by) != len(descending): - msg = f"the length of `descending` ({len(descending)}) does not match the length of `by` ({len(by)})" - raise ValueError(msg) + descending = extend_bool(descending, len(by), "descending", "by") + nulls_last = extend_bool(nulls_last, len(by), "nulls_last", "by") return self._from_pyldf( self._ldf.sort_by_exprs( by, descending, nulls_last, maintain_order, multithreaded @@ -1384,7 +1382,7 @@ def top_k( *, by: IntoExpr | Iterable[IntoExpr], descending: bool | Sequence[bool] = False, - nulls_last: bool | None = None, + nulls_last: bool | Sequence[bool] | None = None, maintain_order: bool | None = None, multithreaded: bool | None = None, ) -> Self: @@ -1502,11 +1500,8 @@ def top_k( multithreaded = True by = parse_as_list_of_expressions(by) - if isinstance(descending, bool): - descending = [descending] - elif len(by) != len(descending): - msg = f"the length of `descending` ({len(descending)}) does not match the length of `by` ({len(by)})" - raise ValueError(msg) + descending = extend_bool(descending, len(by), "descending", "by") + nulls_last = extend_bool(nulls_last, len(by), "nulls_last", "by") return self._from_pyldf( self._ldf.top_k( k, by, descending, nulls_last, maintain_order, multithreaded @@ -1519,7 +1514,7 @@ def bottom_k( *, by: IntoExpr | Iterable[IntoExpr], descending: bool | Sequence[bool] = False, - nulls_last: bool | None = None, + nulls_last: bool | Sequence[bool] | None = None, maintain_order: bool | None = None, multithreaded: bool | None = None, ) -> Self: @@ -1637,8 +1632,8 @@ def bottom_k( multithreaded = True by = parse_as_list_of_expressions(by) - if isinstance(descending, bool): - descending = [descending] + descending = extend_bool(descending, len(by), "descending", "by") + nulls_last = extend_bool(nulls_last, len(by), "nulls_last", "by") return self._from_pyldf( self._ldf.bottom_k( k, by, descending, nulls_last, maintain_order, multithreaded diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index 9e2475d99658..b80d687a6d5a 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -312,7 +312,7 @@ impl PyExpr { k: Self, by: Vec, descending: Vec, - nulls_last: bool, + nulls_last: Vec, maintain_order: bool, multithreaded: bool, ) -> Self { @@ -358,7 +358,7 @@ impl PyExpr { k: Self, by: Vec, descending: Vec, - nulls_last: bool, + nulls_last: Vec, maintain_order: bool, multithreaded: bool, ) -> Self { @@ -415,7 +415,7 @@ impl PyExpr { &self, by: Vec, descending: Vec, - nulls_last: bool, + nulls_last: Vec, multithreaded: bool, maintain_order: bool, ) -> Self { diff --git a/py-polars/src/functions/lazy.rs b/py-polars/src/functions/lazy.rs index 0eb7d293d8a4..d2d294beee9b 100644 --- a/py-polars/src/functions/lazy.rs +++ b/py-polars/src/functions/lazy.rs @@ -61,7 +61,7 @@ pub fn rolling_cov( pub fn arg_sort_by( by: Vec, descending: Vec, - nulls_last: bool, + nulls_last: Vec, multithreaded: bool, maintain_order: bool, ) -> PyExpr { diff --git a/py-polars/src/lazyframe/mod.rs b/py-polars/src/lazyframe/mod.rs index 7aa96c9640a9..3434a1992f88 100644 --- a/py-polars/src/lazyframe/mod.rs +++ b/py-polars/src/lazyframe/mod.rs @@ -494,7 +494,7 @@ impl PyLazyFrame { [by_column], SortMultipleOptions { descending: vec![descending], - nulls_last, + nulls_last: vec![nulls_last], multithreaded, maintain_order, }, @@ -506,7 +506,7 @@ impl PyLazyFrame { &self, by: Vec, descending: Vec, - nulls_last: bool, + nulls_last: Vec, maintain_order: bool, multithreaded: bool, ) -> Self { @@ -529,7 +529,7 @@ impl PyLazyFrame { k: IdxSize, by: Vec, descending: Vec, - nulls_last: bool, + nulls_last: Vec, maintain_order: bool, multithreaded: bool, ) -> Self { @@ -553,7 +553,7 @@ impl PyLazyFrame { k: IdxSize, by: Vec, descending: Vec, - nulls_last: bool, + nulls_last: Vec, maintain_order: bool, multithreaded: bool, ) -> Self { diff --git a/py-polars/src/lazyframe/visitor/expr_nodes.rs b/py-polars/src/lazyframe/visitor/expr_nodes.rs index 04102812d3e2..4c9d6493515f 100644 --- a/py-polars/src/lazyframe/visitor/expr_nodes.rs +++ b/py-polars/src/lazyframe/visitor/expr_nodes.rs @@ -308,7 +308,7 @@ pub struct SortBy { by: Vec, #[pyo3(get)] /// maintain_order, nulls_last, descending - sort_options: (bool, bool, Vec), + sort_options: (bool, Vec, Vec), } #[pyclass] @@ -594,7 +594,7 @@ pub(crate) fn into_py(py: Python<'_>, expr: &AExpr) -> PyResult { by: by.iter().map(|n| n.0).collect(), sort_options: ( sort_options.maintain_order, - sort_options.nulls_last, + sort_options.nulls_last.clone(), sort_options.descending.clone(), ), } diff --git a/py-polars/src/lazyframe/visitor/nodes.rs b/py-polars/src/lazyframe/visitor/nodes.rs index 15e6e32dce2e..9fd83245805a 100644 --- a/py-polars/src/lazyframe/visitor/nodes.rs +++ b/py-polars/src/lazyframe/visitor/nodes.rs @@ -141,7 +141,7 @@ pub struct Sort { #[pyo3(get)] by_column: Vec, #[pyo3(get)] - sort_options: (bool, bool, Vec), + sort_options: (bool, Vec, Vec), #[pyo3(get)] slice: Option<(i64, usize)>, } @@ -360,7 +360,7 @@ pub(crate) fn into_py(py: Python<'_>, plan: &IR) -> PyResult { by_column: by_column.iter().map(|e| e.into()).collect(), sort_options: ( sort_options.maintain_order, - sort_options.nulls_last, + sort_options.nulls_last.clone(), sort_options.descending.clone(), ), slice: *slice, diff --git a/py-polars/tests/unit/datatypes/test_struct.py b/py-polars/tests/unit/datatypes/test_struct.py index d53b22dd7c7e..13fd6dab253c 100644 --- a/py-polars/tests/unit/datatypes/test_struct.py +++ b/py-polars/tests/unit/datatypes/test_struct.py @@ -735,13 +735,19 @@ def test_struct_concat_self_no_rechunk() -> None: def test_sort_structs() -> None: - assert pl.DataFrame( - {"sex": ["male", "female", "female"], "age": [22, 38, 26]} - ).select(pl.struct(["sex", "age"]).sort()).unnest("sex").to_dict( - as_series=False - ) == { - "sex": ["female", "female", "male"], - "age": [26, 38, 22], + df = pl.DataFrame( + { + "sex": ["m", "f", "f", "f", "m", "m", "f"], + "age": [22, 38, 26, 24, 21, 46, 22], + }, + ) + df_sorted_as_struct = df.select(pl.struct(["sex", "age"]).sort()).unnest("sex") + df_expected = df.sort(by=["sex", "age"]) + + assert_frame_equal(df_expected, df_sorted_as_struct) + assert df_sorted_as_struct.to_dict(as_series=False) == { + "sex": ["f", "f", "f", "f", "m", "m", "m"], + "age": [22, 24, 26, 38, 21, 22, 46], } diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index 68dc0b7fc7f1..f1d823be8cf1 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -74,34 +74,54 @@ def test_sort_by() -> None: def test_expr_sort_by_nulls_last() -> None: + df = pl.DataFrame({"a": [1, 2, None, None, 5], "b": [None, 1, 1, 2, None]}) + # nulls last - df = pl.DataFrame( - {"a": [1, 2, None, None, 5], "b": [None, 1, 1, 2, None], "c": [2, 3, 1, 2, 1]} - ) + expected = pl.DataFrame({"a": [1, 2, 5, None, None], "b": [None, 1, None, 1, 2]}) + out = df.select(pl.all().sort_by("a", nulls_last=True)) + assert_frame_equal(out, expected) + + # nulls first (default) + expected = pl.DataFrame({"a": [None, None, 1, 2, 5], "b": [1, 2, None, 1, None]}) + for out in ( + df.select(pl.all().sort_by("a", nulls_last=False)), + df.select(pl.all().sort_by("a")), + ): + assert_frame_equal(out, expected) - out = df.select(pl.all().sort_by("a", nulls_last=True, maintain_order=True)) - excepted = pl.DataFrame( - { - "a": [1, 2, 5, None, None], - "b": [None, 1, None, 1, 2], - "c": [2, 3, 1, 1, 2], - } - ) +def test_expr_sort_by_multi_nulls_last() -> None: + df = pl.DataFrame({"x": [None, 1, None, 3], "y": [3, 2, None, 1]}) - assert_frame_equal(out, excepted) + res = df.sort("x", "y", nulls_last=[False, True]) + assert res.to_dict(as_series=False) == { + "x": [None, None, 1, 3], + "y": [3, None, 2, 1], + } - # nulls first + res = df.sort("x", "y", nulls_last=[True, False]) + assert res.to_dict(as_series=False) == { + "x": [1, 3, None, None], + "y": [2, 1, None, 3], + } - out = df.select(pl.all().sort_by("a", nulls_last=False, maintain_order=True)) + res = df.sort("x", "y", nulls_last=[True, False], descending=True) + assert res.to_dict(as_series=False) == { + "x": [3, 1, None, None], + "y": [1, 2, None, 3], + } - excepted = pl.DataFrame( - { - "a": [None, None, 1, 2, 5], - "b": [1, 2, None, 1, None], - "c": [1, 2, 2, 3, 1], - } - ) + res = df.sort("x", "y", nulls_last=[False, True], descending=True) + assert res.to_dict(as_series=False) == { + "x": [None, None, 3, 1], + "y": [3, None, 1, 2], + } + + res = df.sort("x", "y", nulls_last=[False, True], descending=[True, False]) + assert res.to_dict(as_series=False) == { + "x": [None, None, 3, 1], + "y": [3, None, 1, 2], + } def test_sort_by_exprs() -> None: @@ -114,38 +134,41 @@ def test_sort_by_exprs() -> None: def test_arg_sort_nulls() -> None: a = pl.Series("a", [1.0, 2.0, 3.0, None, None]) + assert a.arg_sort(nulls_last=True).to_list() == [0, 1, 2, 3, 4] assert a.arg_sort(nulls_last=False).to_list() == [3, 4, 0, 1, 2] - assert a.to_frame().sort(by="a", nulls_last=False).to_series().to_list() == [ - None, - None, - 1.0, - 2.0, - 3.0, - ] - assert a.to_frame().sort(by="a", nulls_last=True).to_series().to_list() == [ - 1.0, - 2.0, - 3.0, - None, - None, - ] + res = a.to_frame().sort(by="a", nulls_last=False).to_series().to_list() + assert res == [None, None, 1.0, 2.0, 3.0] + res = a.to_frame().sort(by="a", nulls_last=True).to_series().to_list() + assert res == [1.0, 2.0, 3.0, None, None] -def test_expr_arg_sort_nulls_last() -> None: + +@pytest.mark.parametrize( + ("nulls_last", "expected"), + [ + (True, [0, 1, 4, 3, 2]), + (False, [2, 3, 0, 1, 4]), + ([True, False], [0, 1, 4, 2, 3]), + ([False, True], [3, 2, 0, 1, 4]), + ], +) +def test_expr_arg_sort_nulls_last( + nulls_last: bool | list[bool], expected: list[int] +) -> None: df = pl.DataFrame( - {"a": [1, 2, None, None, 5], "b": [None, 1, 2, 1, None], "c": [2, 3, 1, 2, 1]} + { + "a": [1, 2, None, None, 5], + "b": [1, 2, None, 1, None], + "c": [2, 3, 1, 2, 1], + }, ) - out = ( - df.select(pl.arg_sort_by("a", "b", nulls_last=True, maintain_order=True)) + df.select(pl.arg_sort_by("a", "b", nulls_last=nulls_last, maintain_order=True)) .to_series() .to_list() ) - - expected = [0, 1, 4, 3, 2] - assert out == expected @@ -157,7 +180,6 @@ def test_arg_sort_window_functions() -> None: pl.arg_sort_by("Age").over("Id").alias("arg_sort_by"), ] ) - assert ( out["arg_sort"].to_list() == out["arg_sort_by"].to_list() == [0, 1, 0, 1, 0, 1] )