From fa12a8e1ba2ac61b230f46f1ec81b2d33882bdf4 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sat, 8 Jun 2024 11:43:16 +0200 Subject: [PATCH] refactor: Improve safety of amortized_iter (#16820) --- .../src/chunked_array/array/iterator.rs | 84 +++++----- .../src/chunked_array/comparison/mod.rs | 38 ++--- .../src/chunked_array/list/iterator.rs | 149 ++++++++---------- .../series/{unstable.rs => amortized_iter.rs} | 57 +++---- crates/polars-core/src/series/mod.rs | 2 +- crates/polars-core/src/utils/series.rs | 12 +- crates/polars-expr/src/expressions/filter.rs | 3 +- crates/polars-expr/src/expressions/gather.rs | 2 +- .../polars-expr/src/expressions/group_iter.rs | 55 +++---- crates/polars-expr/src/expressions/mod.rs | 3 +- .../src/chunked_array/array/dispersion.rs | 3 +- .../src/chunked_array/array/join.rs | 3 +- .../src/chunked_array/array/sum_mean.rs | 3 +- .../src/chunked_array/list/namespace.rs | 15 +- crates/polars-ops/src/series/ops/is_in.rs | 33 ++-- crates/polars-ops/src/series/ops/rolling.rs | 15 +- .../polars-plan/src/dsl/function_expr/list.rs | 67 ++++---- py-polars/src/series/export.rs | 6 +- 18 files changed, 241 insertions(+), 309 deletions(-) rename crates/polars-core/src/series/{unstable.rs => amortized_iter.rs} (62%) diff --git a/crates/polars-core/src/chunked_array/array/iterator.rs b/crates/polars-core/src/chunked_array/array/iterator.rs index 271e59aee490..52785f1208db 100644 --- a/crates/polars-core/src/chunked_array/array/iterator.rs +++ b/crates/polars-core/src/chunked_array/array/iterator.rs @@ -2,7 +2,7 @@ use std::ptr::NonNull; use super::*; use crate::chunked_array::list::iterator::AmortizedListIter; -use crate::series::unstable::{unstable_series_container_and_ptr, ArrayBox, UnstableSeries}; +use crate::series::amortized_iter::{unstable_series_container_and_ptr, AmortSeries, ArrayBox}; impl ArrayChunked { /// This is an iterator over a [`ArrayChunked`] that save allocations. @@ -23,11 +23,9 @@ impl ArrayChunked { /// that Series. /// /// # Safety - /// The lifetime of [UnstableSeries] is bound to the iterator. Keeping it alive + /// The lifetime of [AmortSeries] is bound to the iterator. Keeping it alive /// longer than the iterator is UB. - pub unsafe fn amortized_iter( - &self, - ) -> AmortizedListIter> + '_> { + pub fn amortized_iter(&self) -> AmortizedListIter> + '_> { self.amortized_iter_with_name("") } @@ -37,21 +35,14 @@ impl ArrayChunked { /// ChunkedArray is: /// 2. Vec< 3. ArrayRef> /// - /// The [`ArrayRef`] we indicated with 3. will be updated during iteration. + /// The ArrayRef we indicated with 3. will be updated during iteration. /// The Series will be pinned in memory, saving an allocation for /// 1. Arc<..> /// 2. Vec<...> /// - /// # Warning - /// Though memory safe in the sense that it will not read unowned memory, UB, or memory leaks - /// this function still needs precautions. The returned should never be cloned or taken longer - /// than a single iteration, as every call on `next` of the iterator will change the contents of - /// that Series. - /// - /// # Safety - /// The lifetime of [UnstableSeries] is bound to the iterator. Keeping it alive - /// longer than the iterator is UB. - pub unsafe fn amortized_iter_with_name( + /// If the returned `AmortSeries` is cloned, the local copy will be replaced and a new container + /// will be set. + pub fn amortized_iter_with_name( &self, name: &str, ) -> AmortizedListIter> + '_> { @@ -75,18 +66,21 @@ impl ArrayChunked { let (s, ptr) = unsafe { unstable_series_container_and_ptr(name, inner_values.clone(), &iter_dtype) }; - AmortizedListIter::new( - self.len(), - s, - NonNull::new(ptr).unwrap(), - self.downcast_iter().flat_map(|arr| arr.iter()), - inner_dtype.clone(), - ) + // SAFETY: `ptr` belongs to the `Series`. + unsafe { + AmortizedListIter::new( + self.len(), + s, + NonNull::new(ptr).unwrap(), + self.downcast_iter().flat_map(|arr| arr.iter()), + inner_dtype.clone(), + ) + } } - pub fn try_apply_amortized_to_list<'a, F>(&'a self, mut f: F) -> PolarsResult + pub fn try_apply_amortized_to_list(&self, mut f: F) -> PolarsResult where - F: FnMut(UnstableSeries<'a>) -> PolarsResult, + F: FnMut(AmortSeries) -> PolarsResult, { if self.is_empty() { return Ok(Series::new_empty( @@ -98,8 +92,7 @@ impl ArrayChunked { .clone()); } let mut fast_explode = self.null_count() == 0; - // SAFETY: lifetime of iterator is bound to this functions scope - let mut ca: ListChunked = unsafe { + let mut ca: ListChunked = { self.amortized_iter() .map(|opt_v| { opt_v @@ -128,9 +121,9 @@ impl ArrayChunked { /// # Safety /// Return series of `F` must has the same dtype and number of elements as input. #[must_use] - pub unsafe fn apply_amortized_same_type<'a, F>(&'a self, mut f: F) -> Self + pub unsafe fn apply_amortized_same_type(&self, mut f: F) -> Self where - F: FnMut(UnstableSeries<'a>) -> Series, + F: FnMut(AmortSeries) -> Series, { if self.is_empty() { return self.clone(); @@ -149,9 +142,9 @@ impl ArrayChunked { /// /// # Safety /// Return series of `F` must has the same dtype and number of elements as input if it is Ok. - pub unsafe fn try_apply_amortized_same_type<'a, F>(&'a self, mut f: F) -> PolarsResult + pub unsafe fn try_apply_amortized_same_type(&self, mut f: F) -> PolarsResult where - F: FnMut(UnstableSeries<'a>) -> PolarsResult, + F: FnMut(AmortSeries) -> PolarsResult, { if self.is_empty() { return Ok(self.clone()); @@ -180,7 +173,7 @@ impl ArrayChunked { ) -> Self where T: PolarsDataType, - F: FnMut(Option>, Option>) -> Option, + F: FnMut(Option, Option>) -> Option, { if self.is_empty() { return self.clone(); @@ -196,33 +189,36 @@ impl ArrayChunked { /// Apply a closure `F` elementwise. #[must_use] - pub fn apply_amortized_generic<'a, F, K, V>(&'a self, f: F) -> ChunkedArray + pub fn apply_amortized_generic(&self, f: F) -> ChunkedArray where V: PolarsDataType, - F: FnMut(Option>) -> Option + Copy, + F: FnMut(Option) -> Option + Copy, V::Array: ArrayFromIter>, { - // SAFETY: lifetime of iterator is bound to this functions scope - unsafe { self.amortized_iter().map(f).collect_ca(self.name()) } + { + self.amortized_iter().map(f).collect_ca(self.name()) + } } /// Try apply a closure `F` elementwise. - pub fn try_apply_amortized_generic<'a, F, K, V>(&'a self, f: F) -> PolarsResult> + pub fn try_apply_amortized_generic(&self, f: F) -> PolarsResult> where V: PolarsDataType, - F: FnMut(Option>) -> PolarsResult> + Copy, + F: FnMut(Option) -> PolarsResult> + Copy, V::Array: ArrayFromIter>, { - // SAFETY: lifetime of iterator is bound to this functions scope - unsafe { self.amortized_iter().map(f).try_collect_ca(self.name()) } + { + self.amortized_iter().map(f).try_collect_ca(self.name()) + } } - pub fn for_each_amortized<'a, F>(&'a self, f: F) + pub fn for_each_amortized(&self, f: F) where - F: FnMut(Option>), + F: FnMut(Option), { - // SAFETY: lifetime of iterator is bound to this functions scope - unsafe { self.amortized_iter().for_each(f) } + { + self.amortized_iter().for_each(f) + } } } diff --git a/crates/polars-core/src/chunked_array/comparison/mod.rs b/crates/polars-core/src/chunked_array/comparison/mod.rs index 57fdfc8c0591..9fae970632f5 100644 --- a/crates/polars-core/src/chunked_array/comparison/mod.rs +++ b/crates/polars-core/src/chunked_array/comparison/mod.rs @@ -560,34 +560,26 @@ where match (lhs.len(), rhs.len()) { (_, 1) => { let right = rhs.get_as_series(0).map(|s| s.with_name("")); - // SAFETY: values within iterator do not outlive the iterator itself - unsafe { - lhs.amortized_iter() - .map(|left| op(left.as_ref().map(|us| us.as_ref()), right.as_ref())) - .collect_trusted() - } + lhs.amortized_iter() + .map(|left| op(left.as_ref().map(|us| us.as_ref()), right.as_ref())) + .collect_trusted() }, (1, _) => { let left = lhs.get_as_series(0).map(|s| s.with_name("")); - // SAFETY: values within iterator do not outlive the iterator itself - unsafe { - rhs.amortized_iter() - .map(|right| op(left.as_ref(), right.as_ref().map(|us| us.as_ref()))) - .collect_trusted() - } - }, - // SAFETY: values within iterator do not outlive the iterator itself - _ => unsafe { - lhs.amortized_iter() - .zip(rhs.amortized_iter()) - .map(|(left, right)| { - op( - left.as_ref().map(|us| us.as_ref()), - right.as_ref().map(|us| us.as_ref()), - ) - }) + rhs.amortized_iter() + .map(|right| op(left.as_ref(), right.as_ref().map(|us| us.as_ref()))) .collect_trusted() }, + _ => lhs + .amortized_iter() + .zip(rhs.amortized_iter()) + .map(|(left, right)| { + op( + left.as_ref().map(|us| us.as_ref()), + right.as_ref().map(|us| us.as_ref()), + ) + }) + .collect_trusted(), } } diff --git a/crates/polars-core/src/chunked_array/list/iterator.rs b/crates/polars-core/src/chunked_array/list/iterator.rs index a86353a52152..19d6fc90c952 100644 --- a/crates/polars-core/src/chunked_array/list/iterator.rs +++ b/crates/polars-core/src/chunked_array/list/iterator.rs @@ -1,12 +1,15 @@ use std::marker::PhantomData; use std::ptr::NonNull; +use std::rc::Rc; + +use polars_utils::unwrap::UnwrapUncheckedRelease; use crate::prelude::*; -use crate::series::unstable::{unstable_series_container_and_ptr, ArrayBox, UnstableSeries}; +use crate::series::amortized_iter::{unstable_series_container_and_ptr, AmortSeries, ArrayBox}; pub struct AmortizedListIter<'a, I: Iterator>> { len: usize, - series_container: Box, + series_container: Rc, inner: NonNull, lifetime: PhantomData<&'a ArrayRef>, iter: I, @@ -16,7 +19,7 @@ pub struct AmortizedListIter<'a, I: Iterator>> { } impl<'a, I: Iterator>> AmortizedListIter<'a, I> { - pub(crate) fn new( + pub(crate) unsafe fn new( len: usize, series_container: Series, inner: NonNull, @@ -25,7 +28,7 @@ impl<'a, I: Iterator>> AmortizedListIter<'a, I> { ) -> Self { Self { len, - series_container: Box::new(series_container), + series_container: Rc::new(series_container), inner, lifetime: PhantomData, iter, @@ -35,7 +38,7 @@ impl<'a, I: Iterator>> AmortizedListIter<'a, I> { } impl<'a, I: Iterator>> Iterator for AmortizedListIter<'a, I> { - type Item = Option>; + type Item = Option; fn next(&mut self) -> Option { self.iter.next().map(|opt_val| { @@ -47,24 +50,23 @@ impl<'a, I: Iterator>> Iterator for AmortizedListIter<'a // SAFETY: // dtype is known unsafe { - let mut s = Series::from_chunks_and_dtype_unchecked( + let s = Series::from_chunks_and_dtype_unchecked( "", vec![array_ref], &self.inner_dtype.to_physical(), ) .cast_unchecked(&self.inner_dtype) .unwrap(); - // swap the new series with the container - std::mem::swap(&mut *self.series_container, &mut s); - // return a reference to the container - // this lifetime is now bound to 'a - return UnstableSeries::new( - &mut *(&mut *self.series_container as *mut Series), - ); + let inner = Rc::make_mut(&mut self.series_container); + *inner = s; + + return AmortSeries::new(self.series_container.clone()); } } // The series is cloned, we make a new container. - if Arc::strong_count(&self.series_container.0) > 1 { + if Arc::strong_count(&self.series_container.0) > 1 + || Rc::strong_count(&self.series_container) > 1 + { let (s, ptr) = unsafe { unstable_series_container_and_ptr( self.series_container.name(), @@ -72,25 +74,26 @@ impl<'a, I: Iterator>> Iterator for AmortizedListIter<'a self.series_container.dtype(), ) }; - *self.series_container.as_mut() = s; + self.series_container = Rc::new(s); self.inner = NonNull::new(ptr).unwrap(); } else { + // SAFETY: we checked the RC above; + let series_mut = unsafe { + Rc::get_mut(&mut self.series_container).unwrap_unchecked_release() + }; // update the inner state unsafe { *self.inner.as_mut() = array_ref }; // last iteration could have set the sorted flag (e.g. in compute_len) - self.series_container.clear_flags(); + series_mut.clear_flags(); // make sure that the length is correct - self.series_container._get_inner_mut().compute_len(); + series_mut._get_inner_mut().compute_len(); } // SAFETY: - // we cannot control the lifetime of an iterators `next` method. - // but as long as self is alive the reference to the series container is valid - let refer = &mut *self.series_container; + // inner belongs to Series. unsafe { - let s = std::mem::transmute::<&mut Series, &'a mut Series>(refer); - UnstableSeries::new_with_chunk(s, self.inner.as_ref()) + AmortSeries::new_with_chunk(self.series_container.clone(), self.inner.as_ref()) } }) }) @@ -106,7 +109,7 @@ impl<'a, I: Iterator>> Iterator for AmortizedListIter<'a unsafe impl<'a, I: Iterator>> TrustedLen for AmortizedListIter<'a, I> {} impl ListChunked { - /// This is an iterator over a [`ListChunked`] that save allocations. + /// This is an iterator over a [`ListChunked`] that saves allocations. /// A Series is: /// 1. [`Arc`] /// ChunkedArray is: @@ -117,25 +120,14 @@ impl ListChunked { /// 1. Arc<..> /// 2. Vec<...> /// - /// # Warning - /// Though memory safe in the sense that it will not read unowned memory, UB, or memory leaks - /// this function still needs precautions. The returned should never be cloned or taken longer - /// than a single iteration, as every call on `next` of the iterator will change the contents of - /// that Series. - /// - /// # Safety - /// The lifetime of [UnstableSeries] is bound to the iterator. Keeping it alive - /// longer than the iterator is UB. - pub unsafe fn amortized_iter( - &self, - ) -> AmortizedListIter> + '_> { + /// If the returned `AmortSeries` is cloned, the local copy will be replaced and a new container + /// will be set. + pub fn amortized_iter(&self) -> AmortizedListIter> + '_> { self.amortized_iter_with_name("") } - /// # Safety - /// The lifetime of [UnstableSeries] is bound to the iterator. Keeping it alive - /// longer than the iterator is UB. - pub unsafe fn amortized_iter_with_name( + /// See `amortized_iter`. + pub fn amortized_iter_with_name( &self, name: &str, ) -> AmortizedListIter> + '_> { @@ -159,45 +151,45 @@ impl ListChunked { let (s, ptr) = unsafe { unstable_series_container_and_ptr(name, inner_values.clone(), &iter_dtype) }; - AmortizedListIter::new( - self.len(), - s, - NonNull::new(ptr).unwrap(), - self.downcast_iter().flat_map(|arr| arr.iter()), - inner_dtype.clone(), - ) + // SAFETY: ptr belongs the the Series.. + unsafe { + AmortizedListIter::new( + self.len(), + s, + NonNull::new(ptr).unwrap(), + self.downcast_iter().flat_map(|arr| arr.iter()), + inner_dtype.clone(), + ) + } } /// Apply a closure `F` elementwise. #[must_use] - pub fn apply_amortized_generic<'a, F, K, V>(&'a self, f: F) -> ChunkedArray + pub fn apply_amortized_generic(&self, f: F) -> ChunkedArray where V: PolarsDataType, - F: FnMut(Option>) -> Option + Copy, + F: FnMut(Option) -> Option + Copy, V::Array: ArrayFromIter>, { // TODO! make an amortized iter that does not flatten - // SAFETY: unstable series never lives longer than the iterator. - unsafe { self.amortized_iter().map(f).collect_ca(self.name()) } + self.amortized_iter().map(f).collect_ca(self.name()) } - pub fn try_apply_amortized_generic<'a, F, K, V>(&'a self, f: F) -> PolarsResult> + pub fn try_apply_amortized_generic(&self, f: F) -> PolarsResult> where V: PolarsDataType, - F: FnMut(Option>) -> PolarsResult> + Copy, + F: FnMut(Option) -> PolarsResult> + Copy, V::Array: ArrayFromIter>, { // TODO! make an amortized iter that does not flatten - // SAFETY: unstable series never lives longer than the iterator. - unsafe { self.amortized_iter().map(f).try_collect_ca(self.name()) } + self.amortized_iter().map(f).try_collect_ca(self.name()) } - pub fn for_each_amortized<'a, F>(&'a self, f: F) + pub fn for_each_amortized(&self, f: F) where - F: FnMut(Option>), + F: FnMut(Option), { - // SAFETY: unstable series never lives longer than the iterator. - unsafe { self.amortized_iter().for_each(f) } + self.amortized_iter().for_each(f) } /// Zip with a `ChunkedArray` then apply a binary function `F` elementwise. @@ -207,14 +199,13 @@ impl ListChunked { T: PolarsDataType, &'a ChunkedArray: IntoIterator, I: TrustedLen>>, - F: FnMut(Option>, Option>) -> Option, + F: FnMut(Option, Option>) -> Option, { if self.is_empty() { return self.clone(); } let mut fast_explode = self.null_count() == 0; - // SAFETY: unstable series never lives longer than the iterator. - let mut out: ListChunked = unsafe { + let mut out: ListChunked = { self.amortized_iter() .zip(ca) .map(|(opt_s, opt_v)| { @@ -251,7 +242,7 @@ impl ListChunked { T: PolarsDataType, U: PolarsDataType, F: FnMut( - Option>, + Option, Option>, Option>, ) -> Option, @@ -260,8 +251,7 @@ impl ListChunked { return self.clone(); } let mut fast_explode = self.null_count() == 0; - // SAFETY: unstable series never lives longer than the iterator. - let mut out: ListChunked = unsafe { + let mut out: ListChunked = { self.amortized_iter() .zip(ca1.iter()) .zip(ca2.iter()) @@ -297,17 +287,13 @@ impl ListChunked { T: PolarsDataType, &'a ChunkedArray: IntoIterator, I: TrustedLen>>, - F: FnMut( - Option>, - Option>, - ) -> PolarsResult>, + F: FnMut(Option, Option>) -> PolarsResult>, { if self.is_empty() { return Ok(self.clone()); } let mut fast_explode = self.null_count() == 0; - // SAFETY: unstable series never lives longer than the iterator. - let mut out: ListChunked = unsafe { + let mut out: ListChunked = { self.amortized_iter() .zip(ca) .map(|(opt_s, opt_v)| { @@ -335,16 +321,15 @@ impl ListChunked { /// Apply a closure `F` elementwise. #[must_use] - pub fn apply_amortized<'a, F>(&'a self, mut f: F) -> Self + pub fn apply_amortized(&self, mut f: F) -> Self where - F: FnMut(UnstableSeries<'a>) -> Series, + F: FnMut(AmortSeries) -> Series, { if self.is_empty() { return self.clone(); } let mut fast_explode = self.null_count() == 0; - // SAFETY: unstable series never lives longer than the iterator. - let mut ca: ListChunked = unsafe { + let mut ca: ListChunked = { self.amortized_iter() .map(|opt_v| { opt_v.map(|v| { @@ -365,16 +350,15 @@ impl ListChunked { ca } - pub fn try_apply_amortized<'a, F>(&'a self, mut f: F) -> PolarsResult + pub fn try_apply_amortized(&self, mut f: F) -> PolarsResult where - F: FnMut(UnstableSeries<'a>) -> PolarsResult, + F: FnMut(AmortSeries) -> PolarsResult, { if self.is_empty() { return Ok(self.clone()); } let mut fast_explode = self.null_count() == 0; - // SAFETY: unstable series never lives longer than the iterator. - let mut ca: ListChunked = unsafe { + let mut ca: ListChunked = { self.amortized_iter() .map(|opt_v| { opt_v @@ -412,11 +396,8 @@ mod test { builder.append_series(&Series::new("", &[1, 1])).unwrap(); let ca = builder.finish(); - // SAFETY: unstable series never lives longer than the iterator. - unsafe { - ca.amortized_iter().zip(&ca).for_each(|(s1, s2)| { - assert!(s1.unwrap().as_ref().equals(&s2.unwrap())); - }) - }; + ca.amortized_iter().zip(&ca).for_each(|(s1, s2)| { + assert!(s1.unwrap().as_ref().equals(&s2.unwrap())); + }) } } diff --git a/crates/polars-core/src/series/unstable.rs b/crates/polars-core/src/series/amortized_iter.rs similarity index 62% rename from crates/polars-core/src/series/unstable.rs rename to crates/polars-core/src/series/amortized_iter.rs index 6dbcd603e4c5..7cdf8507c29f 100644 --- a/crates/polars-core/src/series/unstable.rs +++ b/crates/polars-core/src/series/amortized_iter.rs @@ -1,46 +1,35 @@ -use std::marker::PhantomData; use std::ptr::NonNull; +use std::rc::Rc; use polars_utils::unwrap::UnwrapUncheckedRelease; use crate::prelude::*; -/// A wrapper type that should make it a bit more clear that we should not clone Series -#[derive(Copy, Clone)] -pub struct UnstableSeries<'a> { - lifetime: PhantomData<&'a Series>, - // A series containing a single chunk ArrayRef - // the ArrayRef will be replaced by amortized_iter - // use with caution! - container: *mut Series, +/// A `[Series]` that amortizes a few allocations during iteration. +#[derive(Clone)] +pub struct AmortSeries { + container: Rc, // the ptr to the inner chunk, this saves some ptr chasing inner: NonNull, } /// We don't implement Deref so that the caller is aware of converting to Series -impl AsRef for UnstableSeries<'_> { +impl AsRef for AmortSeries { fn as_ref(&self) -> &Series { - unsafe { &*self.container } - } -} - -impl AsMut for UnstableSeries<'_> { - fn as_mut(&mut self) -> &mut Series { - unsafe { &mut *self.container } + self.container.as_ref() } } pub type ArrayBox = Box; -impl<'a> UnstableSeries<'a> { - pub fn new(series: &'a mut Series) -> Self { +impl AmortSeries { + pub fn new(series: Rc) -> Self { debug_assert_eq!(series.chunks().len(), 1); - let container = series as *mut Series; - let inner_chunk = series.array_ref(0); - UnstableSeries { - lifetime: PhantomData, + let inner_chunk = series.array_ref(0) as *const ArrayRef as *mut arrow::array::ArrayRef; + let container = series; + AmortSeries { container, - inner: NonNull::new(inner_chunk as *const ArrayRef as *mut ArrayRef).unwrap(), + inner: NonNull::new(inner_chunk).unwrap(), } } @@ -49,9 +38,8 @@ impl<'a> UnstableSeries<'a> { /// # Safety /// Inner chunks must be from `Series` otherwise the dtype may be incorrect and lead to UB. #[inline] - pub(crate) unsafe fn new_with_chunk(series: &'a mut Series, inner_chunk: &ArrayRef) -> Self { - UnstableSeries { - lifetime: PhantomData, + pub(crate) unsafe fn new_with_chunk(series: Rc, inner_chunk: &ArrayRef) -> Self { + AmortSeries { container: series, inner: NonNull::new(inner_chunk as *const ArrayRef as *mut ArrayRef) .unwrap_unchecked_release(), @@ -69,22 +57,29 @@ impl<'a> UnstableSeries<'a> { } #[inline] - /// Swaps inner state with the `array`. Prefer `UnstableSeries::with_array` as this + /// Swaps inner state with the `array`. Prefer `AmortSeries::with_array` as this /// restores the state. /// # Safety /// This swaps an underlying pointer that might be hold by other cloned series. pub unsafe fn swap(&mut self, array: &mut ArrayRef) { std::mem::swap(self.inner.as_mut(), array); + // ensure lengths are correct. - self.as_mut()._get_inner_mut().compute_len(); + unsafe { + let ptr = Rc::as_ptr(&self.container) as *mut Series; + (*ptr)._get_inner_mut().compute_len() + } } /// Temporary swaps out the array, and restores the original state /// when application of the function `f` is done. + /// + /// # Safety + /// Array must be from `Series` physical dtype. #[inline] - pub fn with_array(&mut self, array: &mut ArrayRef, f: F) -> T + pub unsafe fn with_array(&mut self, array: &mut ArrayRef, f: F) -> T where - F: Fn(&UnstableSeries) -> T, + F: Fn(&AmortSeries) -> T, { unsafe { self.swap(array); diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index fd35ec2e509e..c02862cc7458 100644 --- a/crates/polars-core/src/series/mod.rs +++ b/crates/polars-core/src/series/mod.rs @@ -2,6 +2,7 @@ pub use crate::prelude::ChunkCompare; use crate::prelude::*; +pub mod amortized_iter; mod any_value; pub mod arithmetic; mod comparison; @@ -11,7 +12,6 @@ mod into; pub(crate) mod iterator; pub mod ops; mod series_trait; -pub mod unstable; use std::borrow::Cow; use std::hash::{Hash, Hasher}; diff --git a/crates/polars-core/src/utils/series.rs b/crates/polars-core/src/utils/series.rs index 3312e6a9e109..1595101b2a43 100644 --- a/crates/polars-core/src/utils/series.rs +++ b/crates/polars-core/src/utils/series.rs @@ -1,14 +1,16 @@ +use std::rc::Rc; + use crate::prelude::*; -use crate::series::unstable::UnstableSeries; +use crate::series::amortized_iter::AmortSeries; -/// A utility that allocates an [`UnstableSeries`]. The applied function can then use that +/// A utility that allocates an [`AmortSeries`]. The applied function can then use that /// series container to save heap allocations and swap arrow arrays. pub fn with_unstable_series(dtype: &DataType, f: F) -> T where - F: Fn(&mut UnstableSeries) -> T, + F: Fn(&mut AmortSeries) -> T, { - let mut container = Series::full_null("", 0, dtype); - let mut us = UnstableSeries::new(&mut container); + let container = Series::full_null("", 0, dtype); + let mut us = AmortSeries::new(Rc::new(container)); f(&mut us) } diff --git a/crates/polars-expr/src/expressions/filter.rs b/crates/polars-expr/src/expressions/filter.rs index 09f28f4d9b9d..d9df88419ae7 100644 --- a/crates/polars-expr/src/expressions/filter.rs +++ b/crates/polars-expr/src/expressions/filter.rs @@ -55,8 +55,7 @@ impl PhysicalExpr for FilterExpr { // return an empty list if ca is empty. ListChunked::full_null_with_dtype(ca.name(), 0, ca.inner_dtype()) } else { - // SAFETY: unstable series never lives longer than the iterator. - unsafe { + { ca.amortized_iter() .zip(preds) .map(|(opt_s, opt_pred)| match (opt_s, opt_pred) { diff --git a/crates/polars-expr/src/expressions/gather.rs b/crates/polars-expr/src/expressions/gather.rs index fc1104ffad24..c54f8b9e8262 100644 --- a/crates/polars-expr/src/expressions/gather.rs +++ b/crates/polars-expr/src/expressions/gather.rs @@ -72,7 +72,7 @@ impl PhysicalExpr for GatherExpr { let s = idx.cast(&DataType::List(Box::new(IDX_DTYPE)))?; let idx = s.list().unwrap(); - let taken = unsafe { + let taken = { ac.aggregated() .list() .unwrap() diff --git a/crates/polars-expr/src/expressions/group_iter.rs b/crates/polars-expr/src/expressions/group_iter.rs index 6db8911a99f6..8c921a519bd1 100644 --- a/crates/polars-expr/src/expressions/group_iter.rs +++ b/crates/polars-expr/src/expressions/group_iter.rs @@ -1,17 +1,17 @@ -use std::pin::Pin; +use std::rc::Rc; -use polars_core::series::unstable::UnstableSeries; +use polars_core::series::amortized_iter::AmortSeries; use super::*; impl<'a> AggregationContext<'a> { /// # Safety - /// The lifetime of [UnstableSeries] is bound to the iterator. Keeping it alive + /// The lifetime of [AmortSeries] is bound to the iterator. Keeping it alive /// longer than the iterator is UB. pub(super) unsafe fn iter_groups( &mut self, keep_names: bool, - ) -> Box>> + '_> { + ) -> Box> + '_> { match self.agg_state() { AggState::Literal(_) => { self.groups(); @@ -59,45 +59,44 @@ impl<'a> AggregationContext<'a> { } } -struct LitIter<'a> { +struct LitIter { len: usize, offset: usize, - // UnstableSeries referenced that series + // AmortSeries referenced that series #[allow(dead_code)] - series_container: Pin>, - item: UnstableSeries<'a>, + series_container: Rc, + item: AmortSeries, } -impl<'a> LitIter<'a> { +impl LitIter { /// # Safety /// Caller must ensure the given `logical` dtype belongs to `array`. unsafe fn new(array: ArrayRef, len: usize, logical: &DataType, name: &str) -> Self { - let mut series_container = Box::pin(Series::from_chunks_and_dtype_unchecked( + let series_container = Rc::new(Series::from_chunks_and_dtype_unchecked( name, vec![array], logical, )); - let ref_s = &mut *series_container as *mut Series; Self { offset: 0, len, - series_container, + series_container: series_container.clone(), // SAFETY: we pinned the series so the location is still valid - item: UnstableSeries::new(unsafe { &mut *ref_s }), + item: AmortSeries::new(series_container), } } } -impl<'a> Iterator for LitIter<'a> { - type Item = Option>; +impl Iterator for LitIter { + type Item = Option; fn next(&mut self) -> Option { if self.len == self.offset { None } else { self.offset += 1; - Some(Some(self.item)) + Some(Some(self.item.clone())) } } @@ -106,19 +105,19 @@ impl<'a> Iterator for LitIter<'a> { } } -struct FlatIter<'a> { +struct FlatIter { current_array: ArrayRef, chunks: Vec, offset: usize, chunk_offset: usize, len: usize, - // UnstableSeries referenced that series + // AmortSeries referenced that series #[allow(dead_code)] - series_container: Pin>, - item: UnstableSeries<'a>, + series_container: Rc, + item: AmortSeries, } -impl<'a> FlatIter<'a> { +impl FlatIter { /// # Safety /// Caller must ensure the given `logical` dtype belongs to `array`. unsafe fn new(chunks: &[ArrayRef], len: usize, logical: &DataType, name: &str) -> Self { @@ -127,27 +126,25 @@ impl<'a> FlatIter<'a> { stack.push(chunk.clone()) } let current_array = stack.pop().unwrap(); - let mut series_container = Box::pin(Series::from_chunks_and_dtype_unchecked( + let series_container = Rc::new(Series::from_chunks_and_dtype_unchecked( name, vec![current_array.clone()], logical, )); - let ref_s = &mut *series_container as *mut Series; Self { current_array, chunks: stack, offset: 0, chunk_offset: 0, len, - series_container, - // SAFETY: we pinned the series so the location is still valid - item: UnstableSeries::new(unsafe { &mut *ref_s }), + series_container: series_container.clone(), + item: AmortSeries::new(series_container), } } } -impl<'a> Iterator for FlatIter<'a> { - type Item = Option>; +impl Iterator for FlatIter { + type Item = Option; fn next(&mut self) -> Option { if self.len == self.offset { @@ -168,7 +165,7 @@ impl<'a> Iterator for FlatIter<'a> { } self.offset += 1; self.chunk_offset += 1; - Some(Some(self.item)) + Some(Some(self.item.clone())) } } fn size_hint(&self) -> (usize, Option) { diff --git a/crates/polars-expr/src/expressions/mod.rs b/crates/polars-expr/src/expressions/mod.rs index 6fc8021b9886..5b4b5407a614 100644 --- a/crates/polars-expr/src/expressions/mod.rs +++ b/crates/polars-expr/src/expressions/mod.rs @@ -289,8 +289,7 @@ impl<'a> AggregationContext<'a> { }); }, _ => { - // SAFETY: unstable series never lives longer than the iterator. - let groups = unsafe { + let groups = { self.series() .list() .expect("impl error, should be a list at this point") diff --git a/crates/polars-ops/src/chunked_array/array/dispersion.rs b/crates/polars-ops/src/chunked_array/array/dispersion.rs index 48a4f4277a01..056b1b87d09a 100644 --- a/crates/polars-ops/src/chunked_array/array/dispersion.rs +++ b/crates/polars-ops/src/chunked_array/array/dispersion.rs @@ -42,8 +42,7 @@ pub(super) fn std_with_nulls(ca: &ArrayChunked, ddof: u8) -> PolarsResult { - // SAFETY: lifetime of iterator bound to scope of function - let out: Float64Chunked = unsafe { + let out: Float64Chunked = { ca.amortized_iter() .map(|s| s.and_then(|s| s.as_ref().std(ddof))) .collect() diff --git a/crates/polars-ops/src/chunked_array/array/join.rs b/crates/polars-ops/src/chunked_array/array/join.rs index b6711dabbf5e..0ba4a517ca0f 100644 --- a/crates/polars-ops/src/chunked_array/array/join.rs +++ b/crates/polars-ops/src/chunked_array/array/join.rs @@ -47,8 +47,7 @@ fn join_many( let mut buf = String::new(); let mut builder = StringChunkedBuilder::new(ca.name(), ca.len()); - // SAFETY: lifetime of iterator bound to scope of function - unsafe { ca.amortized_iter() } + { ca.amortized_iter() } .zip(separator) .for_each(|(opt_s, opt_sep)| match opt_sep { Some(separator) => { diff --git a/crates/polars-ops/src/chunked_array/array/sum_mean.rs b/crates/polars-ops/src/chunked_array/array/sum_mean.rs index 4bcab7da2741..60bd144317bc 100644 --- a/crates/polars-ops/src/chunked_array/array/sum_mean.rs +++ b/crates/polars-ops/src/chunked_array/array/sum_mean.rs @@ -59,8 +59,7 @@ pub(super) fn sum_array_numerical(ca: &ArrayChunked, inner_type: &DataType) -> S pub(super) fn sum_with_nulls(ca: &ArrayChunked, inner_dtype: &DataType) -> PolarsResult { use DataType::*; // TODO: add fast path for smaller ints? - // SAFETY: lifetime of iterator is bound to this functions scope - let mut out = unsafe { + let mut out = { match inner_dtype { Boolean => { let out: IdxCa = ca diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index 15b32c55a7ed..0306375af35f 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -136,8 +136,7 @@ pub trait ListNameSpaceImpl: AsList { // used to amortize heap allocs let mut buf = String::with_capacity(128); let mut builder = StringChunkedBuilder::new(ca.name(), ca.len()); - // SAFETY: unstable series never lives longer than the iterator. - unsafe { + { ca.amortized_iter() .zip(separator) .for_each(|(opt_s, opt_sep)| match opt_sep { @@ -428,8 +427,7 @@ pub trait ListNameSpaceImpl: AsList { let index_typed_index = |idx: &Series| { let idx = idx.cast(&IDX_DTYPE).unwrap(); - // SAFETY: unstable series never lives longer than the iterator. - unsafe { + { list_ca .amortized_iter() .map(|s| { @@ -451,8 +449,7 @@ pub trait ListNameSpaceImpl: AsList { match idx.dtype() { List(_) => { let idx_ca = idx.list().unwrap(); - // SAFETY: unstable series never lives longer than the iterator. - let mut out = unsafe { + let mut out = { list_ca .amortized_iter() .zip(idx_ca) @@ -479,8 +476,7 @@ pub trait ListNameSpaceImpl: AsList { if min >= 0 { index_typed_index(idx) } else { - // SAFETY: unstable series never lives longer than the iterator. - let mut out = unsafe { + let mut out = { list_ca .amortized_iter() .map(|opt_s| { @@ -684,8 +680,7 @@ pub trait ListNameSpaceImpl: AsList { let mut iters = Vec::with_capacity(other_len + 1); for s in other.iter_mut() { - // SAFETY: unstable series never lives longer than the iterator. - iters.push(unsafe { s.list()?.amortized_iter() }) + iters.push(s.list()?.amortized_iter()) } let mut first_iter: Box>> = ca.into_iter(); let mut builder = get_list_builder( diff --git a/crates/polars-ops/src/series/ops/is_in.rs b/crates/polars-ops/src/series/ops/is_in.rs index ec7f0cc944c0..afb89d725fdc 100644 --- a/crates/polars-ops/src/series/ops/is_in.rs +++ b/crates/polars-ops/src/series/ops/is_in.rs @@ -57,8 +57,7 @@ where }) } else { polars_ensure!(ca_in.len() == other.len(), ComputeError: "shapes don't match: expected {} elements in 'is_in' comparison, got {}", ca_in.len(), other.len()); - // SAFETY: unstable series never lives longer than the iterator. - unsafe { + { ca_in .iter() .zip(other.list()?.amortized_iter()) @@ -97,8 +96,7 @@ where polars_ensure!(ca_in.len() == other.len(), ComputeError: "shapes don't match: expected {} elements in 'is_in' comparison, got {}", ca_in.len(), other.len()); ca_in .iter() - // SAFETY: lifetime of iterator bound to scope of function - .zip(unsafe { other.array()?.amortized_iter() }) + .zip(other.array()?.amortized_iter()) .map(|(value, series)| match (value, series) { (val, Some(series)) => { let ca = series.as_ref().unpack::().unwrap(); @@ -183,8 +181,7 @@ fn is_in_string_list_categorical( } } else { polars_ensure!(ca_in.len() == other.len(), ComputeError: "shapes don't match: expected {} elements in 'is_in' comparison, got {}", ca_in.len(), other.len()); - // SAFETY: unstable series never lives longer than the iterator. - unsafe { + { ca_in .iter() .zip(other.list()?.amortized_iter()) @@ -257,8 +254,7 @@ fn is_in_binary_list(ca_in: &BinaryChunked, other: &Series) -> PolarsResult PolarsResult { let ca = series.as_ref().unpack::().unwrap(); @@ -322,7 +317,6 @@ fn is_in_boolean_list(ca_in: &BooleanChunked, other: &Series) -> PolarsResult PolarsResult PolarsResult { let ca = series.as_ref().unpack::().unwrap(); @@ -436,8 +428,7 @@ fn is_in_struct_list(ca_in: &StructChunked, other: &Series) -> PolarsResult PolarsResult { let ca = series.as_ref().struct_().unwrap(); @@ -531,9 +522,6 @@ fn is_in_struct(ca_in: &StructChunked, other: &Series) -> PolarsResult PolarsResult PolarsResult> { let length_ca = length_s.cast(&DataType::Int64)?; let length_ca = length_ca.i64().unwrap(); - // SAFETY: unstable series never lives longer than the iterator. - unsafe { - list_ca - .amortized_iter() - .zip(length_ca) - .map(|(opt_s, opt_length)| match (opt_s, opt_length) { - (Some(s), Some(length)) => Some(s.as_ref().slice(offset, length as usize)), - _ => None, - }) - .collect_trusted() - } + list_ca + .amortized_iter() + .zip(length_ca) + .map(|(opt_s, opt_length)| match (opt_s, opt_length) { + (Some(s), Some(length)) => Some(s.as_ref().slice(offset, length as usize)), + _ => None, + }) + .collect_trusted() }, (offset_len, 1) => { check_slice_arg_shape(offset_len, list_ca.len(), "offset")?; @@ -347,17 +344,14 @@ pub(super) fn slice(args: &mut [Series]) -> PolarsResult> { .unwrap_or(usize::MAX); let offset_ca = offset_s.cast(&DataType::Int64)?; let offset_ca = offset_ca.i64().unwrap(); - // SAFETY: unstable series never lives longer than the iterator. - unsafe { - list_ca - .amortized_iter() - .zip(offset_ca) - .map(|(opt_s, opt_offset)| match (opt_s, opt_offset) { - (Some(s), Some(offset)) => Some(s.as_ref().slice(offset, length_slice)), - _ => None, - }) - .collect_trusted() - } + list_ca + .amortized_iter() + .zip(offset_ca) + .map(|(opt_s, opt_offset)| match (opt_s, opt_offset) { + (Some(s), Some(offset)) => Some(s.as_ref().slice(offset, length_slice)), + _ => None, + }) + .collect_trusted() }, _ => { check_slice_arg_shape(offset_s.len(), list_ca.len(), "offset")?; @@ -369,22 +363,19 @@ pub(super) fn slice(args: &mut [Series]) -> PolarsResult> { let length_ca = length_s.cast(&DataType::Int64)?; let length_ca = length_ca.i64().unwrap(); - // SAFETY: unstable series never lives longer than the iterator. - unsafe { - list_ca - .amortized_iter() - .zip(offset_ca) - .zip(length_ca) - .map(|((opt_s, opt_offset), opt_length)| { - match (opt_s, opt_offset, opt_length) { - (Some(s), Some(offset), Some(length)) => { - Some(s.as_ref().slice(offset, length as usize)) - }, - _ => None, - } - }) - .collect_trusted() - } + list_ca + .amortized_iter() + .zip(offset_ca) + .zip(length_ca) + .map( + |((opt_s, opt_offset), opt_length)| match (opt_s, opt_offset, opt_length) { + (Some(s), Some(offset), Some(length)) => { + Some(s.as_ref().slice(offset, length as usize)) + }, + _ => None, + }, + ) + .collect_trusted() }, }; out.rename(s.name()); diff --git a/py-polars/src/series/export.rs b/py-polars/src/series/export.rs index bf9e0b210810..3478a1022272 100644 --- a/py-polars/src/series/export.rs +++ b/py-polars/src/series/export.rs @@ -44,8 +44,7 @@ impl PySeries { DataType::List(_) => { let v = PyList::empty_bound(py); let ca = series.list().unwrap(); - // SAFETY: unstable series never lives longer than the iterator. - for opt_s in unsafe { ca.amortized_iter() } { + for opt_s in ca.amortized_iter() { match opt_s { None => { v.append(py.None()).unwrap(); @@ -61,8 +60,7 @@ impl PySeries { DataType::Array(_, _) => { let v = PyList::empty_bound(py); let ca = series.array().unwrap(); - // SAFETY: lifetime of iterator bound to this scope - for opt_s in unsafe { ca.amortized_iter() } { + for opt_s in ca.amortized_iter() { match opt_s { None => { v.append(py.None()).unwrap();