diff --git a/crates/polars-core/src/testing.rs b/crates/polars-core/src/testing.rs index 6c630b02af96..cb9f6e5389ab 100644 --- a/crates/polars-core/src/testing.rs +++ b/crates/polars-core/src/testing.rs @@ -6,7 +6,7 @@ use crate::prelude::*; impl Series { /// Check if series are equal. Note that `None == None` evaluates to `false` pub fn equals(&self, other: &Series) -> bool { - if self.null_count() > 0 || other.null_count() > 0 || self.dtype() != other.dtype() { + if self.null_count() > 0 || other.null_count() > 0 { false } else { self.equals_missing(other) @@ -14,10 +14,10 @@ impl Series { } /// Check if all values in series are equal where `None == None` evaluates to `true`. - /// Two [`Datetime`](DataType::Datetime) series are *not* equal if their timezones are different, regardless - /// if they represent the same UTC time or not. pub fn equals_missing(&self, other: &Series) -> bool { match (self.dtype(), other.dtype()) { + // Two [`Datetime`](DataType::Datetime) series are *not* equal if their timezones + // are different, regardless if they represent the same UTC time or not. #[cfg(feature = "timezones")] (DataType::Datetime(_, tz_lhs), DataType::Datetime(_, tz_rhs)) => { if tz_lhs != tz_rhs { @@ -27,17 +27,14 @@ impl Series { _ => {}, } - // differences from Partial::eq in that numerical dtype may be different - self.len() == other.len() - && self.name() == other.name() - && self.null_count() == other.null_count() - && { - let eq = self.equal_missing(other); - match eq { - Ok(b) => b.all(), - Err(_) => false, - } + // Differs from Partial::eq in that numerical dtype may be different + self.len() == other.len() && self.null_count() == other.null_count() && { + let eq = self.equal_missing(other); + match eq { + Ok(b) => b.all(), + Err(_) => false, } + } } /// Get a pointer to the underlying data of this [`Series`]. @@ -99,7 +96,7 @@ impl DataFrame { return false; } for (left, right) in self.get_columns().iter().zip(other.get_columns()) { - if !left.equals(right) { + if left.name() != right.name() || !left.equals(right) { return false; } } @@ -125,7 +122,7 @@ impl DataFrame { return false; } for (left, right) in self.get_columns().iter().zip(other.get_columns()) { - if !left.equals_missing(right) { + if left.name() != right.name() || !left.equals_missing(right) { return false; } } @@ -191,10 +188,11 @@ mod test { } #[test] - fn test_series_dtype_noteq() { + fn test_series_dtype_not_equal() { let s_i32 = Series::new("a", &[1_i32, 2_i32]); let s_i64 = Series::new("a", &[1_i64, 2_i64]); - assert!(!s_i32.equals(&s_i64)); + assert!(s_i32.dtype() != s_i64.dtype()); + assert!(s_i32.equals(&s_i64)); } #[test] diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 129e30fa2921..3ab381fd06c6 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -4120,6 +4120,7 @@ def equals( other: Series, *, check_dtypes: bool = False, + check_names: bool = False, null_equal: bool = True, ) -> bool: """ @@ -4131,6 +4132,8 @@ def equals( Series to compare with. check_dtypes Require data types to match. + check_names + Require names to match. null_equal Consider null values as equal. @@ -4148,7 +4151,10 @@ def equals( False """ return self._s.equals( - other._s, check_dtypes=check_dtypes, null_equal=null_equal + other._s, + check_dtypes=check_dtypes, + check_names=check_names, + null_equal=null_equal, ) def cast( diff --git a/py-polars/src/series/mod.rs b/py-polars/src/series/mod.rs index bda26f533eb5..872dbaf2252a 100644 --- a/py-polars/src/series/mod.rs +++ b/py-polars/src/series/mod.rs @@ -324,10 +324,19 @@ impl PySeries { self.series.has_validity() } - fn equals(&self, other: &PySeries, check_dtypes: bool, null_equal: bool) -> bool { + fn equals( + &self, + other: &PySeries, + check_dtypes: bool, + check_names: bool, + null_equal: bool, + ) -> bool { if check_dtypes && (self.series.dtype() != other.series.dtype()) { return false; } + if check_names && (self.series.name() != other.series.name()) { + return false; + } if null_equal { self.series.equals_missing(&other.series) } else { diff --git a/py-polars/tests/unit/series/test_equals.py b/py-polars/tests/unit/series/test_equals.py index da607b934936..989554656253 100644 --- a/py-polars/tests/unit/series/test_equals.py +++ b/py-polars/tests/unit/series/test_equals.py @@ -30,6 +30,13 @@ def test_equals() -> None: assert s3.dt.convert_time_zone("Asia/Tokyo").equals(s4) is True +def test_series_equals_check_names() -> None: + s1 = pl.Series("foo", [1, 2, 3]) + s2 = pl.Series("bar", [1, 2, 3]) + assert s1.equals(s2) is True + assert s1.equals(s2, check_names=True) is False + + def test_eq_list_cmp_list() -> None: s = pl.Series([[1], [1, 2]]) result = s == [1, 2]