Skip to content

Commit

Permalink
refactor: Improve safety of amortized_iter (#16820)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jun 8, 2024
1 parent 38149d6 commit fa12a8e
Show file tree
Hide file tree
Showing 18 changed files with 241 additions and 309 deletions.
84 changes: 40 additions & 44 deletions crates/polars-core/src/chunked_array/array/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::ptr::NonNull;

use super::*;
use crate::chunked_array::list::iterator::AmortizedListIter;
use crate::series::unstable::{unstable_series_container_and_ptr, ArrayBox, UnstableSeries};
use crate::series::amortized_iter::{unstable_series_container_and_ptr, AmortSeries, ArrayBox};

impl ArrayChunked {
/// This is an iterator over an [`ArrayChunked`] that saves allocations.
Expand All @@ -23,11 +23,9 @@ impl ArrayChunked {
/// that Series.
///
/// # Safety
/// The lifetime of [UnstableSeries] is bound to the iterator. Keeping it alive
/// The lifetime of [AmortSeries] is bound to the iterator. Keeping it alive
/// longer than the iterator is UB.
pub unsafe fn amortized_iter(
&self,
) -> AmortizedListIter<impl Iterator<Item = Option<ArrayBox>> + '_> {
pub fn amortized_iter(&self) -> AmortizedListIter<impl Iterator<Item = Option<ArrayBox>> + '_> {
self.amortized_iter_with_name("")
}

Expand All @@ -37,21 +35,14 @@ impl ArrayChunked {
/// ChunkedArray is:
/// 2. Vec< 3. ArrayRef>
///
/// The [`ArrayRef`] we indicated with 3. will be updated during iteration.
/// The ArrayRef we indicated with 3. will be updated during iteration.
/// The Series will be pinned in memory, saving an allocation for
/// 1. Arc<..>
/// 2. Vec<...>
///
/// # Warning
/// Though memory safe in the sense that it will not read unowned memory, UB, or memory leaks
/// this function still needs precautions. The returned should never be cloned or taken longer
/// than a single iteration, as every call on `next` of the iterator will change the contents of
/// that Series.
///
/// # Safety
/// The lifetime of [UnstableSeries] is bound to the iterator. Keeping it alive
/// longer than the iterator is UB.
pub unsafe fn amortized_iter_with_name(
/// If the returned `AmortSeries` is cloned, the local copy will be replaced and a new container
/// will be set.
pub fn amortized_iter_with_name(
&self,
name: &str,
) -> AmortizedListIter<impl Iterator<Item = Option<ArrayBox>> + '_> {
Expand All @@ -75,18 +66,21 @@ impl ArrayChunked {
let (s, ptr) =
unsafe { unstable_series_container_and_ptr(name, inner_values.clone(), &iter_dtype) };

AmortizedListIter::new(
self.len(),
s,
NonNull::new(ptr).unwrap(),
self.downcast_iter().flat_map(|arr| arr.iter()),
inner_dtype.clone(),
)
// SAFETY: `ptr` belongs to the `Series`.
unsafe {
AmortizedListIter::new(
self.len(),
s,
NonNull::new(ptr).unwrap(),
self.downcast_iter().flat_map(|arr| arr.iter()),
inner_dtype.clone(),
)
}
}

pub fn try_apply_amortized_to_list<'a, F>(&'a self, mut f: F) -> PolarsResult<ListChunked>
pub fn try_apply_amortized_to_list<F>(&self, mut f: F) -> PolarsResult<ListChunked>
where
F: FnMut(UnstableSeries<'a>) -> PolarsResult<Series>,
F: FnMut(AmortSeries) -> PolarsResult<Series>,
{
if self.is_empty() {
return Ok(Series::new_empty(
Expand All @@ -98,8 +92,7 @@ impl ArrayChunked {
.clone());
}
let mut fast_explode = self.null_count() == 0;
// SAFETY: lifetime of iterator is bound to this functions scope
let mut ca: ListChunked = unsafe {
let mut ca: ListChunked = {
self.amortized_iter()
.map(|opt_v| {
opt_v
Expand Down Expand Up @@ -128,9 +121,9 @@ impl ArrayChunked {
/// # Safety
/// The series returned by `F` must have the same dtype and number of elements as the input.
#[must_use]
pub unsafe fn apply_amortized_same_type<'a, F>(&'a self, mut f: F) -> Self
pub unsafe fn apply_amortized_same_type<F>(&self, mut f: F) -> Self
where
F: FnMut(UnstableSeries<'a>) -> Series,
F: FnMut(AmortSeries) -> Series,
{
if self.is_empty() {
return self.clone();
Expand All @@ -149,9 +142,9 @@ impl ArrayChunked {
///
/// # Safety
/// The series returned by `F` must have the same dtype and number of elements as the input if it is Ok.
pub unsafe fn try_apply_amortized_same_type<'a, F>(&'a self, mut f: F) -> PolarsResult<Self>
pub unsafe fn try_apply_amortized_same_type<F>(&self, mut f: F) -> PolarsResult<Self>
where
F: FnMut(UnstableSeries<'a>) -> PolarsResult<Series>,
F: FnMut(AmortSeries) -> PolarsResult<Series>,
{
if self.is_empty() {
return Ok(self.clone());
Expand Down Expand Up @@ -180,7 +173,7 @@ impl ArrayChunked {
) -> Self
where
T: PolarsDataType,
F: FnMut(Option<UnstableSeries<'a>>, Option<T::Physical<'a>>) -> Option<Series>,
F: FnMut(Option<AmortSeries>, Option<T::Physical<'a>>) -> Option<Series>,
{
if self.is_empty() {
return self.clone();
Expand All @@ -196,33 +189,36 @@ impl ArrayChunked {

/// Apply a closure `F` elementwise.
#[must_use]
pub fn apply_amortized_generic<'a, F, K, V>(&'a self, f: F) -> ChunkedArray<V>
pub fn apply_amortized_generic<F, K, V>(&self, f: F) -> ChunkedArray<V>
where
V: PolarsDataType,
F: FnMut(Option<UnstableSeries<'a>>) -> Option<K> + Copy,
F: FnMut(Option<AmortSeries>) -> Option<K> + Copy,
V::Array: ArrayFromIter<Option<K>>,
{
// SAFETY: lifetime of iterator is bound to this functions scope
unsafe { self.amortized_iter().map(f).collect_ca(self.name()) }
{
self.amortized_iter().map(f).collect_ca(self.name())
}
}

/// Try apply a closure `F` elementwise.
pub fn try_apply_amortized_generic<'a, F, K, V>(&'a self, f: F) -> PolarsResult<ChunkedArray<V>>
pub fn try_apply_amortized_generic<F, K, V>(&self, f: F) -> PolarsResult<ChunkedArray<V>>
where
V: PolarsDataType,
F: FnMut(Option<UnstableSeries<'a>>) -> PolarsResult<Option<K>> + Copy,
F: FnMut(Option<AmortSeries>) -> PolarsResult<Option<K>> + Copy,
V::Array: ArrayFromIter<Option<K>>,
{
// SAFETY: lifetime of iterator is bound to this functions scope
unsafe { self.amortized_iter().map(f).try_collect_ca(self.name()) }
{
self.amortized_iter().map(f).try_collect_ca(self.name())
}
}

pub fn for_each_amortized<'a, F>(&'a self, f: F)
pub fn for_each_amortized<F>(&self, f: F)
where
F: FnMut(Option<UnstableSeries<'a>>),
F: FnMut(Option<AmortSeries>),
{
// SAFETY: lifetime of iterator is bound to this functions scope
unsafe { self.amortized_iter().for_each(f) }
{
self.amortized_iter().for_each(f)
}
}
}

Expand Down
38 changes: 15 additions & 23 deletions crates/polars-core/src/chunked_array/comparison/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -560,34 +560,26 @@ where
match (lhs.len(), rhs.len()) {
(_, 1) => {
let right = rhs.get_as_series(0).map(|s| s.with_name(""));
// SAFETY: values within iterator do not outlive the iterator itself
unsafe {
lhs.amortized_iter()
.map(|left| op(left.as_ref().map(|us| us.as_ref()), right.as_ref()))
.collect_trusted()
}
lhs.amortized_iter()
.map(|left| op(left.as_ref().map(|us| us.as_ref()), right.as_ref()))
.collect_trusted()
},
(1, _) => {
let left = lhs.get_as_series(0).map(|s| s.with_name(""));
// SAFETY: values within iterator do not outlive the iterator itself
unsafe {
rhs.amortized_iter()
.map(|right| op(left.as_ref(), right.as_ref().map(|us| us.as_ref())))
.collect_trusted()
}
},
// SAFETY: values within iterator do not outlive the iterator itself
_ => unsafe {
lhs.amortized_iter()
.zip(rhs.amortized_iter())
.map(|(left, right)| {
op(
left.as_ref().map(|us| us.as_ref()),
right.as_ref().map(|us| us.as_ref()),
)
})
rhs.amortized_iter()
.map(|right| op(left.as_ref(), right.as_ref().map(|us| us.as_ref())))
.collect_trusted()
},
_ => lhs
.amortized_iter()
.zip(rhs.amortized_iter())
.map(|(left, right)| {
op(
left.as_ref().map(|us| us.as_ref()),
right.as_ref().map(|us| us.as_ref()),
)
})
.collect_trusted(),
}
}

Expand Down
Loading

0 comments on commit fa12a8e

Please sign in to comment.