Skip to content

Commit

Permalink
simplify, mark unsafety
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed May 22, 2024
1 parent 57a0f6b commit b23f7be
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use arrow::bitmap::MutableBitmap;
use polars_core::downcast_as_macro_arg_physical;
use polars_core::export::num::{NumCast, Zero};
use polars_core::prelude::*;
use polars_utils::vec::PushUnchecked;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -61,7 +60,7 @@ where
let mut out = Vec::with_capacity(chunked_arr.len());
let mut iter = chunked_arr.iter().skip(first);
for _ in 0..first {
unsafe { out.push_unchecked(Zero::zero()) }
out.push(Zero::zero());
}

// The next element of `iter` is definitely `Some(Some(v))`, because we skipped the first
Expand Down
32 changes: 20 additions & 12 deletions crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use std::ops::{Add, Div, Mul, Sub};

use arrow::array::PrimitiveArray;
use arrow::bitmap::MutableBitmap;
use arrow::legacy::trusted_len::TrustedLenPush;
use bytemuck::allocation::zeroed_vec;
use polars_core::export::num::{NumCast, Zero};
use polars_core::prelude::*;
Expand Down Expand Up @@ -34,15 +33,15 @@ where
}

/// # Safety
/// `x` must be non-empty.
/// - `x` must be non-empty.
/// - `sorting_indices` must be the same size as `x`
#[inline]
unsafe fn signed_interp_by<T, F>(
y_start: T,
y_end: T,
x: &[F],
out: &mut [T],
sorting_indices: &[IdxSize],
low_idx: usize,
) where
T: Sub<Output = T>
+ Mul<Output = T>
Expand All @@ -63,11 +62,11 @@ unsafe fn signed_interp_by<T, F>(
iter = x.slice_unchecked(1..x.len() - 1).iter();
}
let slope = range_y / range_x;
for (offset, x_i) in iter.enumerate() {
for (idx, x_i) in iter.enumerate() {
let x_delta = NumCast::from(*x_i - *x_start).unwrap();
let v = linear_itp(y_start, x_delta, slope);
unsafe {
let out_idx = sorting_indices.get_unchecked(low_idx + offset + 1);
let out_idx = sorting_indices.get_unchecked(idx + 1);
*out.get_unchecked_mut(*out_idx as usize) = v;
}
}
Expand Down Expand Up @@ -101,7 +100,7 @@ where
let mut out = Vec::with_capacity(chunked_arr.len());
let mut iter = chunked_arr.iter().enumerate().skip(first);
for _ in 0..first {
unsafe { out.push_unchecked(Zero::zero()) }
out.push(Zero::zero());
}

// The next element of `iter` is definitely `Some(idx, Some(v))`, because we skipped the first
Expand Down Expand Up @@ -136,8 +135,8 @@ where
}

for i in last..chunked_arr.len() {
unsafe { validity.set_unchecked(i, false) };
out.push(Zero::zero())
unsafe { validity.set_unchecked(i, false) }
out.push(Zero::zero());
}

let array = PrimitiveArray::new(
Expand All @@ -160,7 +159,7 @@ fn interpolate_impl_by<T, F, I>(
where
T: PolarsNumericType,
F: PolarsIntegerType,
I: Fn(T::Native, T::Native, &[F::Native], &mut [T::Native], &[IdxSize], usize),
I: Fn(T::Native, T::Native, &[F::Native], &mut [T::Native], &[IdxSize]),
{
// This implementation differs from pandas in that boundary `None`s are not removed.
// This prevents a lot of errors due to expressions leading to different lengths.
Expand Down Expand Up @@ -205,9 +204,15 @@ where
} else {
for (high_idx, next) in iter.by_ref() {
if let Some(high) = next {
let x = unsafe { &by_sorted_values.slice_unchecked(low_idx..high_idx + 1) };
interpolation_branch(low, high, x, &mut out, sorting_indices, low_idx);
// SAFETY: we are in bounds, and the slices are the same length (and non-empty)
unsafe {
interpolation_branch(
low,
high,
by_sorted_values.slice_unchecked(low_idx..high_idx + 1),
&mut out,
sorting_indices.slice_unchecked(low_idx..high_idx + 1),
);
let out_idx = sorting_indices.get_unchecked(high_idx);
*out.get_unchecked_mut(*out_idx as usize) = high;
}
Expand Down Expand Up @@ -263,7 +268,10 @@ pub fn interpolate_by(s: &Series, by: &Series, by_is_sorted: bool) -> PolarsResu
if is_sorted {
interpolate_impl_by_sorted(ca, by, signed_interp_by_sorted).map(|x| x.into_series())
} else {
interpolate_impl_by(ca, by, safe{signed_interp_by}).map(|x| x.into_series())
interpolate_impl_by(ca, by, |y_start, y_end, x, out, sorting_indices| unsafe {
signed_interp_by(y_start, y_end, x, out, sorting_indices)
})
.map(|x| x.into_series())
}
}

Expand Down

0 comments on commit b23f7be

Please sign in to comment.