Skip to content

Commit

Permalink
fix(16905): Fix a number of edge cases where assignment corrupted a S…
Browse files Browse the repository at this point in the history
…eries (#16930)

Co-authored-by: Itamar Turner-Trauring <itamar@pythonspeed.com>
  • Loading branch information
itamarst and pythonspeed authored Jun 15, 2024
1 parent 6b36e9b commit 9a3e032
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 25 deletions.
2 changes: 1 addition & 1 deletion crates/polars-core/src/series/comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ macro_rules! impl_compare {
lhs.0.$method(&rhs.0)
},

dt => polars_bail!(InvalidOperation: "could apply comparison on series of dtype '{}; operand names: '{}', '{}'", dt, lhs.name(), rhs.name()),
dt => polars_bail!(InvalidOperation: "could not apply comparison on series of dtype '{}; operand names: '{}', '{}'", dt, lhs.name(), rhs.name()),
};
out.rename(lhs.name());
PolarsResult::Ok(out)
Expand Down
7 changes: 4 additions & 3 deletions crates/polars-ops/src/chunked_array/scatter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use polars_core::utils::arrow::types::NativeType;
use polars_utils::index::check_bounds;

pub trait ChunkedSet<T: Copy> {
/// Invariant for implementations: if the scatter() fails, typically because
/// of bad indexes, then self should remain unmodified.
fn scatter<V>(self, idx: &[IdxSize], values: V) -> PolarsResult<Series>
where
V: IntoIterator<Item = Option<T>>;
Expand Down Expand Up @@ -88,7 +90,7 @@ unsafe fn scatter_impl<V, T: NativeType>(
}
}

impl<T: PolarsOpsNumericType> ChunkedSet<T::Native> for ChunkedArray<T>
impl<T: PolarsOpsNumericType> ChunkedSet<T::Native> for &mut ChunkedArray<T>
where
ChunkedArray<T>: IntoSeries,
{
Expand All @@ -97,8 +99,7 @@ where
V: IntoIterator<Item = Option<T::Native>>,
{
check_bounds(idx, self.len() as IdxSize)?;
let mut ca = self.rechunk();
drop(self);
let mut ca = std::mem::take(self).rechunk();

// SAFETY:
// we will not modify the length
Expand Down
63 changes: 44 additions & 19 deletions py-polars/src/series/scatter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,46 @@ use crate::PySeries;
#[pymethods]
impl PySeries {
fn scatter(&mut self, idx: PySeries, values: PySeries) -> PyResult<()> {
// we take the value because we want a ref count
// of 1 so that we can have mutable access
// we take the value because we want a ref count of 1 so that we can
// have mutable access cheaply via _get_inner_mut().
let s = std::mem::take(&mut self.series);
match scatter(s, &idx.series, &values.series) {
Ok(out) => {
self.series = out;
Ok(())
},
Err(e) => Err(PyErr::from(PyPolarsErr::from(e))),
Err((s, e)) => {
// Restore original series:
self.series = s;
Err(PyErr::from(PyPolarsErr::from(e)))
},
}
}
}

fn scatter(mut s: Series, idx: &Series, values: &Series) -> PolarsResult<Series> {
fn scatter(mut s: Series, idx: &Series, values: &Series) -> Result<Series, (Series, PolarsError)> {
let logical_dtype = s.dtype().clone();

let idx = polars_ops::prelude::convert_to_unsigned_index(idx, s.len())?;
let idx = match polars_ops::prelude::convert_to_unsigned_index(idx, s.len()) {
Ok(idx) => idx,
Err(err) => return Err((s, err)),
};
let idx = idx.rechunk();
let idx = idx.downcast_iter().next().unwrap();

if idx.null_count() > 0 {
return Err(PolarsError::ComputeError(
"index values should not be null".into(),
return Err((
s,
PolarsError::ComputeError("index values should not be null".into()),
));
}

let idx = idx.values().as_slice();

let mut values = values.to_physical_repr().cast(&s.dtype().to_physical())?;
let mut values = match values.to_physical_repr().cast(&s.dtype().to_physical()) {
Ok(values) => values,
Err(err) => return Err((s, err)),
};

// Broadcast values input
if values.len() == 1 && idx.len() > 1 {
Expand All @@ -46,58 +57,68 @@ fn scatter(mut s: Series, idx: &Series, values: &Series) -> PolarsResult<Series>
// do not shadow, otherwise s is not dropped immediately
// and we want to have mutable access
s = s.to_physical_repr().into_owned();
let s_mut_ref = &mut s;
scatter_impl(s_mut_ref, logical_dtype, idx, &values).map_err(|err| (s, err))
}

fn scatter_impl(
s: &mut Series,
logical_dtype: DataType,
idx: &[IdxSize],
values: &Series,
) -> PolarsResult<Series> {
let mutable_s = s._get_inner_mut();

let s = match logical_dtype.to_physical() {
DataType::Int8 => {
let ca: &mut ChunkedArray<Int8Type> = mutable_s.as_mut();
let values = values.i8()?;
std::mem::take(ca).scatter(idx, values)
ca.scatter(idx, values)
},
DataType::Int16 => {
let ca: &mut ChunkedArray<Int16Type> = mutable_s.as_mut();
let values = values.i16()?;
std::mem::take(ca).scatter(idx, values)
ca.scatter(idx, values)
},
DataType::Int32 => {
let ca: &mut ChunkedArray<Int32Type> = mutable_s.as_mut();
let values = values.i32()?;
std::mem::take(ca).scatter(idx, values)
ca.scatter(idx, values)
},
DataType::Int64 => {
let ca: &mut ChunkedArray<Int64Type> = mutable_s.as_mut();
let values = values.i64()?;
std::mem::take(ca).scatter(idx, values)
ca.scatter(idx, values)
},
DataType::UInt8 => {
let ca: &mut ChunkedArray<UInt8Type> = mutable_s.as_mut();
let values = values.u8()?;
std::mem::take(ca).scatter(idx, values)
ca.scatter(idx, values)
},
DataType::UInt16 => {
let ca: &mut ChunkedArray<UInt16Type> = mutable_s.as_mut();
let values = values.u16()?;
std::mem::take(ca).scatter(idx, values)
ca.scatter(idx, values)
},
DataType::UInt32 => {
let ca: &mut ChunkedArray<UInt32Type> = mutable_s.as_mut();
let values = values.u32()?;
std::mem::take(ca).scatter(idx, values)
ca.scatter(idx, values)
},
DataType::UInt64 => {
let ca: &mut ChunkedArray<UInt64Type> = mutable_s.as_mut();
let values = values.u64()?;
std::mem::take(ca).scatter(idx, values)
ca.scatter(idx, values)
},
DataType::Float32 => {
let ca: &mut ChunkedArray<Float32Type> = mutable_s.as_mut();
let values = values.f32()?;
std::mem::take(ca).scatter(idx, values)
ca.scatter(idx, values)
},
DataType::Float64 => {
let ca: &mut ChunkedArray<Float64Type> = mutable_s.as_mut();
let values = values.f64()?;
std::mem::take(ca).scatter(idx, values)
ca.scatter(idx, values)
},
DataType::Boolean => {
let ca = s.bool()?;
Expand All @@ -109,7 +130,11 @@ fn scatter(mut s: Series, idx: &Series, values: &Series) -> PolarsResult<Series>
let values = values.str()?;
ca.scatter(idx, values)
},
_ => panic!("not yet implemented for dtype: {logical_dtype}"),
_ => {
return Err(PolarsError::ComputeError(
format!("not yet implemented for dtype: {logical_dtype}").into(),
));
},
};

s.and_then(|s| s.cast(&logical_dtype))
Expand Down
25 changes: 24 additions & 1 deletion py-polars/tests/unit/series/test_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,36 @@ def test_scatter() -> None:
assert s.scatter([0, 1], [False, True]).to_list() == [False, True, True]

# set negative indices
a = pl.Series(range(5))
a = pl.Series("r", range(5))
a[-2] = None
a[-5] = None
assert a.to_list() == [None, 1, 2, None, 4]

a = pl.Series("x", [1, 2])
with pytest.raises(pl.OutOfBoundsError):
a[-100] = None
assert_series_equal(a, pl.Series("x", [1, 2]))


def test_index_with_None_errors_16905() -> None:
s = pl.Series("s", [1, 2, 3])
with pytest.raises(pl.ComputeError, match="index values should not be null"):
s[[1, None]] = 5
# The error doesn't trash the series, as it used to:
assert_series_equal(s, pl.Series("s", [1, 2, 3]))


def test_object_dtype_16905() -> None:
obj = object()
s = pl.Series("s", [obj, 27], dtype=pl.Object)
# This operation is not semantically wrong, it might be supported in the
# future, but for now it isn't.
with pytest.raises(pl.InvalidOperationError):
s[0] = 5
# The error doesn't trash the series, as it used to:
assert s.dtype == pl.Object
assert s.name == "s"
assert s.to_list() == [obj, 27]


def test_scatter_datetime() -> None:
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ def test_err_invalid_comparison() -> None:

with pytest.raises(
pl.InvalidOperationError,
match="could apply comparison on series of dtype 'object; operand names: 'a', 'b'",
match="could not apply comparison on series of dtype 'object; operand names: 'a', 'b'",
):
_ = pl.Series("a", [object()]) == pl.Series("b", [object])

Expand Down

0 comments on commit 9a3e032

Please sign in to comment.