Skip to content

Commit

Permalink
fix: Raise on invalid arithmetic shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jun 16, 2024
1 parent 12fb43a commit 24fb425
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
18 changes: 18 additions & 0 deletions crates/polars-core/src/series/arithmetic/borrowed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -524,8 +524,22 @@ impl Rem for &Series {
}
}

fn check_lengths(a: &Series, b: &Series) -> PolarsResult<()> {
match (a.len(), b.len()) {
// broadcasting
(1, _) | (_, 1) => Ok(()),
// equal
(a, b) if a == b => Ok(()),
// unequal
(a, b) => {
polars_bail!(InvalidOperation: "cannot do arithmetic operation on series of different lengths: got {} and {}", a, b)
},
}
}

impl Series {
pub fn try_sub(&self, rhs: &Self) -> PolarsResult<Self> {
check_lengths(self, rhs)?;
match (self.dtype(), rhs.dtype()) {
#[cfg(feature = "dtype-struct")]
(DataType::Struct(_), DataType::Struct(_)) => {
Expand All @@ -539,6 +553,7 @@ impl Series {
}

pub fn try_add(&self, rhs: &Self) -> PolarsResult<Self> {
check_lengths(self, rhs)?;
match (self.dtype(), rhs.dtype()) {
#[cfg(feature = "dtype-struct")]
(DataType::Struct(_), DataType::Struct(_)) => {
Expand All @@ -552,6 +567,7 @@ impl Series {
}

pub fn try_mul(&self, rhs: &Self) -> PolarsResult<Self> {
check_lengths(self, rhs)?;
use DataType::*;
match (self.dtype(), rhs.dtype()) {
#[cfg(feature = "dtype-struct")]
Expand All @@ -575,6 +591,7 @@ impl Series {
}

pub fn try_div(&self, rhs: &Self) -> PolarsResult<Self> {
check_lengths(self, rhs)?;
use DataType::*;
match (self.dtype(), rhs.dtype()) {
#[cfg(feature = "dtype-struct")]
Expand All @@ -599,6 +616,7 @@ impl Series {
}

pub fn try_rem(&self, rhs: &Self) -> PolarsResult<Self> {
check_lengths(self, rhs)?;
match (self.dtype(), rhs.dtype()) {
#[cfg(feature = "dtype-struct")]
(DataType::Struct(_), DataType::Struct(_)) => {
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/unit/operations/arithmetic/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,3 +739,11 @@ def test_arithmetic_duration_div_multiply() -> None:
timedelta(microseconds=7500),
],
}


def test_invalid_shapes_err() -> None:
with pytest.raises(
pl.InvalidOperationError,
match=r"cannot do arithmetic operation on series of different lengths: got 2 and 3",
):
pl.Series([1, 2]) + pl.Series([1, 2, 3])

0 comments on commit 24fb425

Please sign in to comment.