diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index f9cbf6fb3..6c0ad0164 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -295,7 +295,7 @@ def join( self, other: Self, *, - how: Literal["left", "inner", "outer", "cross", "anti"] = "inner", + how: Literal["left", "inner", "outer", "cross", "anti", "semi"] = "inner", left_on: str | list[str] | None = None, right_on: str | list[str] | None = None, ) -> Self: @@ -339,7 +339,7 @@ def join( n_bytes=8, columns=[*self.columns, *other.columns] ) - other = ( + other_native = ( other._native_dataframe.loc[:, right_on] .rename( # rename to avoid creating extra columns in join columns=dict(zip(right_on, left_on)) # type: ignore[arg-type] @@ -348,7 +348,7 @@ def join( ) return self._from_native_dataframe( self._native_dataframe.merge( - other, + other_native, how="outer", indicator=indicator_token, left_on=left_on, @@ -359,6 +359,23 @@ def join( .reset_index(drop=True) ) + if how == "semi": + other_native = ( + other._native_dataframe.loc[:, right_on] + .rename( # rename to avoid creating extra columns in join + columns=dict(zip(right_on, left_on)) # type: ignore[arg-type] + ) + .drop_duplicates() # avoids potential rows duplication from inner join + ) + return self._from_native_dataframe( + self._native_dataframe.merge( + other_native, + how="inner", + left_on=left_on, + right_on=left_on, + ).reset_index(drop=True) + ) + return self._from_native_dataframe( self._native_dataframe.merge( other._native_dataframe, diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 1a3f47032..4d07d4b4c 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -174,13 +174,14 @@ def join( self, other: Self, *, - how: Literal["inner", "cross", "anti"] = "inner", + how: Literal["inner", "cross", "semi", "anti"] = "inner", left_on: str | list[str] | None = None, right_on: str | list[str] | None = None, ) -> Self: - _supported_joins = {"inner", "cross", "anti"} + _supported_joins = ("inner", "cross", "anti", "semi") + if how not in _supported_joins: - msg = f"Only the following join stragies are supported: {_supported_joins}" + msg = f"Only the following join stragies are supported: {_supported_joins}; found '{how}'." raise NotImplementedError(msg) if how == "cross" and (left_on or right_on): @@ -1475,7 +1476,7 @@ def join( self, other: Self, *, - how: Literal["inner", "cross", "anti"] = "inner", + how: Literal["inner", "cross", "semi", "anti"] = "inner", left_on: str | list[str] | None = None, right_on: str | list[str] | None = None, ) -> Self: @@ -1487,8 +1488,9 @@ def join( how: Join strategy. - * *inner*: Returns rows that have matching values in both tables - * *cross*: Returns the Cartesian product of rows from both tables + * *inner*: Returns rows that have matching values in both tables. + * *cross*: Returns the Cartesian product of rows from both tables. + * *semi*: Filter rows that have a match in the right table. * *anti*: Filter rows that do not have a match in the right table. left_on: Name(s) of the left join column(s). @@ -2900,7 +2902,7 @@ def join( self, other: Self, *, - how: Literal["inner", "cross", "anti"] = "inner", + how: Literal["inner", "cross", "semi", "anti"] = "inner", left_on: str | list[str] | None = None, right_on: str | list[str] | None = None, ) -> Self: @@ -2912,8 +2914,9 @@ def join( how: Join strategy. - * *inner*: Returns rows that have matching values in both tables - * *cross*: Returns the Cartesian product of rows from both tables + * *inner*: Returns rows that have matching values in both tables. + * *cross*: Returns the Cartesian product of rows from both tables. + * *semi*: Filter rows that have a match in the right table. * *anti*: Filter rows that do not have a match in the right table. left_on: Join column of the left DataFrame. diff --git a/tests/frame/join_test.py b/tests/frame/join_test.py new file mode 100644 index 000000000..61ded0de1 --- /dev/null +++ b/tests/frame/join_test.py @@ -0,0 +1,110 @@ +from __future__ import annotations + +import re +from typing import Any + +import pandas as pd +import pytest + +import narwhals.stable.v1 as nw +from narwhals._pandas_like.utils import Implementation +from tests.utils import compare_dicts + + +def test_inner_join(constructor_with_lazy: Any) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + df = nw.from_native(constructor_with_lazy(data)).lazy() + df_right = df + result = df.join(df_right, left_on=["a", "b"], right_on=["a", "b"], how="inner") + result_native = nw.to_native(result) + expected = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9], "z_right": [7.0, 8, 9]} + compare_dicts(result_native, expected) + + result = df.collect().join(df_right.collect(), left_on="a", right_on="a", how="inner") # type: ignore[assignment] + result_native = nw.to_native(result) + expected = { + "a": [1, 3, 2], + "b": [4, 4, 6], + "b_right": [4, 4, 6], + "z": [7.0, 8, 9], + "z_right": [7.0, 8, 9], + } + compare_dicts(result_native, expected) + + +def test_cross_join(constructor_with_lazy: Any) -> None: + data = {"a": [1, 3, 2]} + df = nw.from_native(constructor_with_lazy(data)) + result = df.join(df, how="cross") # type: ignore[arg-type] + + expected = {"a": [1, 1, 1, 3, 3, 3, 2, 2, 2], "a_right": [1, 3, 2, 1, 3, 2, 1, 3, 2]} + compare_dicts(result, expected) + + with pytest.raises(ValueError, match="Can not pass left_on, right_on for cross join"): + df.join(df, how="cross", left_on="a") # type: ignore[arg-type] + + +def test_cross_join_non_pandas() -> None: + data = {"a": [1, 3, 2]} + df = nw.from_native(pd.DataFrame(data)) + # HACK to force testing for a non-pandas codepath + df._dataframe._implementation = Implementation.MODIN + result = df.join(df, how="cross") # type: ignore[arg-type] + expected = {"a": [1, 1, 1, 3, 3, 3, 2, 2, 2], "a_right": [1, 3, 2, 1, 3, 2, 1, 3, 2]} + compare_dicts(result, expected) + + +@pytest.mark.parametrize( + ("join_key", "filter_expr", "expected"), + [ + (["a", "b"], (nw.col("b") < 5), {"a": [2], "b": [6], "z": [9]}), + (["b"], (nw.col("b") < 5), {"a": [2], "b": [6], "z": [9]}), + (["b"], (nw.col("b") > 5), {"a": [1, 3], "b": [4, 4], "z": [7.0, 8.0]}), + ], +) +def test_anti_join( + constructor_with_lazy: Any, + join_key: list[str], + filter_expr: nw.Expr, + expected: dict[str, list[Any]], +) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + df = nw.from_native(constructor_with_lazy(data)) + other = df.filter(filter_expr) + result = df.join(other, how="anti", left_on=join_key, right_on=join_key) # type: ignore[arg-type] + compare_dicts(result, expected) + + +@pytest.mark.parametrize( + ("join_key", "filter_expr", "expected"), + [ + (["a"], (nw.col("b") > 5), {"a": [2], "b": [6], "z": [9]}), + (["b"], (nw.col("b") < 5), {"a": [1, 3], "b": [4, 4], "z": [7, 8]}), + (["a", "b"], (nw.col("b") < 5), {"a": [1, 3], "b": [4, 4], "z": [7, 8]}), + ], +) +def test_semi_join( + constructor: Any, + join_key: list[str], + filter_expr: nw.Expr, + expected: dict[str, list[Any]], +) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + df = nw.from_native(constructor(data)) + other = df.filter(filter_expr) + result = df.join(other, how="semi", left_on=join_key, right_on=join_key) # type: ignore[arg-type] + compare_dicts(result, expected) + + +@pytest.mark.parametrize("how", ["left", "right", "full"]) +def test_join_not_implemented(constructor_with_lazy: Any, how: str) -> None: + data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9]} + df = nw.from_native(constructor_with_lazy(data)) + + with pytest.raises( + NotImplementedError, + match=re.escape( + f"Only the following join stragies are supported: ('inner', 'cross', 'anti', 'semi'); found '{how}'." + ), + ): + df.join(df, left_on="a", right_on="a", how=how) # type: ignore[arg-type] diff --git a/tests/frame/test_common.py b/tests/frame/test_common.py index 626eebdac..3234a64a1 100644 --- a/tests/frame/test_common.py +++ b/tests/frame/test_common.py @@ -13,7 +13,6 @@ import pytest import narwhals.stable.v1 as nw -from narwhals._pandas_like.utils import Implementation from narwhals.functions import _get_deps_info from narwhals.functions import _get_sys_info from narwhals.functions import show_versions @@ -132,79 +131,6 @@ def test_lit_error(df_raw: Any) -> None: _ = df.with_columns(nw.lit([1, 2]).alias("lit")) -@pytest.mark.parametrize( - "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] -) -def test_join(df_raw: Any) -> None: - df = nw.from_native(df_raw).lazy() - df_right = df - result = df.join(df_right, left_on=["a", "b"], right_on=["a", "b"], how="inner") - result_native = nw.to_native(result) - expected = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.0, 8, 9], "z_right": [7.0, 8, 9]} - compare_dicts(result_native, expected) - - with pytest.raises(NotImplementedError): - result = df.join(df_right, left_on="a", right_on="a", how="left") # type: ignore[arg-type] - - result = df.collect().join(df_right.collect(), left_on="a", right_on="a", how="inner") # type: ignore[assignment] - result_native = nw.to_native(result) - expected = { - "a": [1, 3, 2], - "b": [4, 4, 6], - "b_right": [4, 4, 6], - "z": [7.0, 8, 9], - "z_right": [7.0, 8, 9], - } - compare_dicts(result_native, expected) - - -@pytest.mark.parametrize("df_raw", [df_polars, df_lazy, df_pandas, df_mpd]) -def test_cross_join(df_raw: Any) -> None: - df = nw.from_native(df_raw).select("a") - result = df.join(df, how="cross") # type: ignore[arg-type] - - expected = {"a": [1, 1, 1, 3, 3, 3, 2, 2, 2], "a_right": [1, 3, 2, 1, 3, 2, 1, 3, 2]} - compare_dicts(result, expected) - - with pytest.raises(ValueError, match="Can not pass left_on, right_on for cross join"): - df.join(df, how="cross", left_on="a") # type: ignore[arg-type] - - -def test_cross_join_non_pandas() -> None: - df = nw.from_native(df_pandas).select("a") - # HACK to force testing for a non-pandas codepath - df._dataframe._implementation = Implementation.MODIN - result = df.join(df, how="cross") # type: ignore[arg-type] - expected = {"a": [1, 1, 1, 3, 3, 3, 2, 2, 2], "a_right": [1, 3, 2, 1, 3, 2, 1, 3, 2]} - compare_dicts(result, expected) - - -@pytest.mark.parametrize( - "df_raw", - [ - df_polars, - df_lazy, - df_pandas, - # df_mpd, (TODO(Unassigned): understand the difference between ipython/jupyter and pytest runs) - ], -) -@pytest.mark.parametrize( - ("join_key", "filter_expr", "expected"), - [ - (["a", "b"], (nw.col("b") < 5), {"a": [2], "b": [6], "z": [9]}), - (["b"], (nw.col("b") < 5), {"a": [2], "b": [6], "z": [9]}), - (["b"], (nw.col("b") > 5), {"a": [1, 3], "b": [4, 4], "z": [7.0, 8.0]}), - ], -) -def test_anti_join( - df_raw: Any, join_key: list[str], filter_expr: nw.Expr, expected: dict[str, list[Any]] -) -> None: - df = nw.from_native(df_raw) - other = df.filter(filter_expr) - result = df.join(other, how="anti", left_on=join_key, right_on=join_key) # type: ignore[arg-type] - compare_dicts(result, expected) - - @pytest.mark.parametrize( "df_raw", [df_pandas, df_lazy, df_pandas_nullable, df_pandas_pyarrow] )