diff --git a/py-polars/polars/selectors.py b/py-polars/polars/selectors.py index 4cb11506b3f6..9d3cedb47e85 100644 --- a/py-polars/polars/selectors.py +++ b/py-polars/polars/selectors.py @@ -385,10 +385,6 @@ def __and__(self, other: Any) -> Expr: ... def __and__(self, other: Any) -> SelectorType | Expr: if is_column(other): colname = other.meta.output_name() - if self._attrs["name"] == "by_name" and ( - params := self._attrs["params"] - ).get("require_all", True): - return by_name(*params["*names"], colname) other = by_name(colname) if is_selector(other): return _selector_proxy_( @@ -399,6 +395,12 @@ def __and__(self, other: Any) -> SelectorType | Expr: else: return self.as_expr().__and__(other) + def __rand__(self, other: Any) -> Expr: + if is_column(other): + colname = other.meta.output_name() + return by_name(colname) & self + return self.as_expr().__rand__(other) + @overload def __or__(self, other: SelectorType) -> SelectorType: ... @@ -417,6 +419,11 @@ def __or__(self, other: Any) -> SelectorType | Expr: else: return self.as_expr().__or__(other) + def __ror__(self, other: Any) -> Expr: + if is_column(other): + other = by_name(other.meta.output_name()) + return self.as_expr().__ror__(other) + @overload def __xor__(self, other: SelectorType) -> SelectorType: ... @@ -435,21 +442,6 @@ def __xor__(self, other: Any) -> SelectorType | Expr: else: return self.as_expr().__or__(other) - def __rand__(self, other: Any) -> Expr: - if is_column(other): - colname = other.meta.output_name() - if self._attrs["name"] == "by_name" and ( - params := self._attrs["params"] - ).get("require_all", True): - return by_name(colname, *params["*names"]) - other = by_name(colname) - return self.as_expr().__rand__(other) - - def __ror__(self, other: Any) -> Expr: - if is_column(other): - other = by_name(other.meta.output_name()) - return self.as_expr().__ror__(other) - def __rxor__(self, other: Any) -> Expr: if is_column(other): other = by_name(other.meta.output_name()) diff --git a/py-polars/tests/unit/test_selectors.py b/py-polars/tests/unit/test_selectors.py index dd2c415c9a13..f4e29e9194c6 100644 --- a/py-polars/tests/unit/test_selectors.py +++ b/py-polars/tests/unit/test_selectors.py @@ -182,11 +182,17 @@ def test_selector_by_name(df: pl.DataFrame) -> None: # check "by_name & col" for selector_expr, expected in ( - (cs.by_name("abc", "cde") & pl.col("ghi"), ["abc", "cde", "ghi"]), - (pl.col("ghi") & cs.by_name("cde", "abc"), ["ghi", "cde", "abc"]), + (cs.by_name("abc", "cde") & pl.col("ghi"), []), + (cs.by_name("abc", "cde") & pl.col("cde"), ["cde"]), + (pl.col("cde") & cs.by_name("cde", "abc"), ["cde"]), ): assert df.select(selector_expr).columns == expected + # check "by_name & by_name" + assert df.select( + cs.by_name("abc", "cde", "def", "eee") & cs.by_name("cde", "eee", "fgg") + ).columns == ["cde", "eee"] + # expected errors with pytest.raises(ColumnNotFoundError, match="xxx"): df.select(cs.by_name("xxx", "fgg", "!!!"))