-
-
Notifications
You must be signed in to change notification settings - Fork 17.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: Add leftsemi merge #57979
ENH: Add leftsemi merge #57979
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1385,6 +1385,33 @@ cdef class PyObjectHashTable(HashTable): | |
k = kh_put_pymap(self.table, <PyObject*>val, &ret) | ||
self.table.vals[k] = i | ||
|
||
@cython.wraparound(False) | ||
@cython.boundscheck(False) | ||
def hash_inner_join(self, ndarray[object] values, object mask = None) -> tuple[ndarray, ndarray]: | ||
cdef: | ||
Py_ssize_t i, n = len(values) | ||
object val | ||
khiter_t k | ||
Int64Vector locs = Int64Vector() | ||
Int64Vector self_locs = Int64Vector() | ||
Int64VectorData *l | ||
Int64VectorData *sl | ||
# mask not implemented | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we raise NotImplementedError here for now? |
||
|
||
l = &locs.data | ||
sl = &self_locs.data | ||
|
||
for i in range(n): | ||
val = values[i] | ||
hash(val) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is the point of this just to raise an error if an object is not hashable? |
||
|
||
k = kh_get_pymap(self.table, <PyObject*>val) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is the cast required here? I assume it would know we are already working with PyObjects given the declaration of values |
||
if k != self.table.n_buckets: | ||
append_data_int64(l, i) | ||
append_data_int64(sl, self.table.vals[k]) | ||
|
||
return self_locs.to_array(), locs.to_array() | ||
|
||
def lookup(self, ndarray[object] values, object mask = None) -> ndarray: | ||
# -> np.ndarray[np.intp] | ||
# mask not yet implemented | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -447,7 +447,7 @@ def closed(self) -> bool: | |
AnyAll = Literal["any", "all"] | ||
|
||
# merge | ||
MergeHow = Literal["left", "right", "inner", "outer", "cross"] | ||
MergeHow = Literal["left", "right", "inner", "outer", "cross", "leftsemi"] | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
So if If |
||
MergeValidate = Literal[ | ||
"one_to_one", | ||
"1:1", | ||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -315,7 +315,7 @@ | |||||
----------%s | ||||||
right : DataFrame or named Series | ||||||
Object to merge with. | ||||||
how : {'left', 'right', 'outer', 'inner', 'cross'}, default 'inner' | ||||||
how : {'left', 'right', 'outer', 'inner', 'leftsemi', 'cross'}, default 'inner' | ||||||
Type of merge to be performed. | ||||||
|
||||||
* left: use only keys from left frame, similar to a SQL left outer join; | ||||||
|
@@ -326,6 +326,11 @@ | |||||
join; sort keys lexicographically. | ||||||
* inner: use intersection of keys from both frames, similar to a SQL inner | ||||||
join; preserve the order of the left keys. | ||||||
* leftsemi: Filter for rows in the left that have a match on the right; | ||||||
preserve the order of the left keys. Doesn't support `left_index`, `right_index`, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
`indicator` or `validate`. | ||||||
|
||||||
.. versionadded:: 3.0 | ||||||
* cross: creates the cartesian product from both frames, preserves the order | ||||||
of the left keys. | ||||||
on : label or list | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -166,7 +166,8 @@ def merge( | |
validate=validate, | ||
) | ||
else: | ||
op = _MergeOperation( | ||
klass = _MergeOperation if how != "leftsemi" else _SemiMergeOperation | ||
op = klass( | ||
left_df, | ||
right_df, | ||
how=how, | ||
|
@@ -817,7 +818,6 @@ def _validate_tolerance(self, left_join_keys: list[ArrayLike]) -> None: | |
# Overridden by AsOfMerge | ||
pass | ||
|
||
@final | ||
def _reindex_and_concat( | ||
self, | ||
join_index: Index, | ||
|
@@ -945,7 +945,6 @@ def _indicator_post_merge(self, result: DataFrame) -> DataFrame: | |
result = result.drop(labels=["_left_indicator", "_right_indicator"], axis=1) | ||
return result | ||
|
||
@final | ||
def _maybe_restore_index_levels(self, result: DataFrame) -> None: | ||
""" | ||
Restore index levels specified as `on` parameters | ||
|
@@ -989,7 +988,6 @@ def _maybe_restore_index_levels(self, result: DataFrame) -> None: | |
if names_to_restore: | ||
result.set_index(names_to_restore, inplace=True) | ||
|
||
@final | ||
def _maybe_add_join_keys( | ||
self, | ||
result: DataFrame, | ||
|
@@ -1683,7 +1681,7 @@ def get_join_indexers( | |
left_keys: list[ArrayLike], | ||
right_keys: list[ArrayLike], | ||
sort: bool = False, | ||
how: JoinHow = "inner", | ||
how: JoinHow + Literal["leftsemi"] = "inner", | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is there any reason why we don't just add |
||
) -> tuple[npt.NDArray[np.intp] | None, npt.NDArray[np.intp] | None]: | ||
""" | ||
|
||
|
@@ -1740,7 +1738,8 @@ def get_join_indexers( | |
right = Index(rkey) | ||
|
||
if ( | ||
left.is_monotonic_increasing | ||
how != "leftsemi" | ||
and left.is_monotonic_increasing | ||
and right.is_monotonic_increasing | ||
and (left.is_unique or right.is_unique) | ||
): | ||
|
@@ -1883,6 +1882,48 @@ def _convert_to_multiindex(index: Index) -> MultiIndex: | |
return tuple(join_levels), tuple(join_codes), tuple(join_names) | ||
|
||
|
||
class _SemiMergeOperation(_MergeOperation): | ||
def __init__(self, *args, **kwargs): | ||
if kwargs.get("validate", None): | ||
raise NotImplementedError("validate is not supported for semi-join.") | ||
|
||
super().__init__(*args, **kwargs) | ||
if self.left_index or self.right_index: | ||
raise NotImplementedError( | ||
"left_index or right_index are not supported for semi-join." | ||
) | ||
elif self.indicator: | ||
raise NotImplementedError("indicator is not supported for semi-join.") | ||
elif self.sort: | ||
raise NotImplementedError( | ||
"sort is not supported for semi-join. Sort your DataFrame afterwards." | ||
) | ||
|
||
def _maybe_add_join_keys( | ||
self, | ||
result: DataFrame, | ||
left_indexer: npt.NDArray[np.intp] | None, | ||
right_indexer: npt.NDArray[np.intp] | None, | ||
) -> None: | ||
return | ||
|
||
def _maybe_restore_index_levels(self, result: DataFrame) -> None: | ||
return | ||
|
||
def _reindex_and_concat( | ||
self, | ||
join_index: Index, | ||
left_indexer: npt.NDArray[np.intp] | None, | ||
right_indexer: npt.NDArray[np.intp] | None, | ||
) -> DataFrame: | ||
left = self.left[:] | ||
|
||
if left_indexer is not None and not is_range_indexer(left_indexer, len(left)): | ||
lmgr = left._mgr.take(left_indexer, axis=1, verify=False) | ||
left = left._constructor_from_mgr(lmgr, axes=lmgr.axes) | ||
return left | ||
|
||
|
||
class _OrderedMerge(_MergeOperation): | ||
_merge_type = "ordered_merge" | ||
|
||
|
@@ -2470,7 +2511,7 @@ def _factorize_keys( | |
lk = ensure_int64(lk.codes) | ||
rk = ensure_int64(rk.codes) | ||
|
||
elif isinstance(lk, ExtensionArray) and lk.dtype == rk.dtype: | ||
elif how != "leftsemi" and isinstance(lk, ExtensionArray) and lk.dtype == rk.dtype: | ||
if (isinstance(lk.dtype, ArrowDtype) and is_string_dtype(lk.dtype)) or ( | ||
isinstance(lk.dtype, StringDtype) | ||
and lk.dtype.storage in ["pyarrow", "pyarrow_numpy"] | ||
|
@@ -2560,14 +2601,18 @@ def _factorize_keys( | |
lk_data, rk_data = lk, rk # type: ignore[assignment] | ||
lk_mask, rk_mask = None, None | ||
|
||
hash_join_available = how == "inner" and not sort and lk.dtype.kind in "iufb" | ||
hash_join_available = how == "inner" and not sort | ||
if hash_join_available: | ||
rlab = rizer.factorize(rk_data, mask=rk_mask) | ||
if rizer.get_count() == len(rlab): | ||
ridx, lidx = rizer.hash_inner_join(lk_data, lk_mask) | ||
return lidx, ridx, -1 | ||
else: | ||
llab = rizer.factorize(lk_data, mask=lk_mask) | ||
elif how == "leftsemi": | ||
# populate hashtable for right and then do a hash join | ||
rizer.factorize(rk_data, mask=rk_mask) | ||
return rizer.hash_inner_join(lk_data, lk_mask)[1], None, -1 # type: ignore[return-value] | ||
else: | ||
llab = rizer.factorize(lk_data, mask=lk_mask) | ||
rlab = rizer.factorize(rk_data, mask=rk_mask) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import pytest | ||
|
||
import pandas.util._test_decorators as td | ||
|
||
import pandas as pd | ||
import pandas._testing as tm | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"vals_left, vals_right, dtype", | ||
[ | ||
([1, 2, 3], [1, 2], "int64"), | ||
(["a", "b", "c"], ["a", "b"], "object"), | ||
pytest.param( | ||
["a", "b", "c"], | ||
["a", "b"], | ||
"string[pyarrow]", | ||
marks=td.skip_if_no("pyarrow"), | ||
), | ||
], | ||
) | ||
def test_leftsemi(vals_left, vals_right, dtype): | ||
vals_left = pd.Series(vals_left, dtype=dtype) | ||
vals_right = pd.Series(vals_right, dtype=dtype) | ||
left = pd.DataFrame({"a": vals_left, "b": [1, 2, 3]}) | ||
right = pd.DataFrame({"a": vals_right, "c": 1}) | ||
expected = pd.DataFrame({"a": vals_right, "b": [1, 2]}) | ||
result = left.merge(right, how="leftsemi") | ||
tm.assert_frame_equal(result, expected) | ||
|
||
right = pd.DataFrame({"d": vals_right, "c": 1}) | ||
result = left.merge(right, how="leftsemi", left_on="a", right_on="d") | ||
tm.assert_frame_equal(result, expected) | ||
|
||
right = pd.DataFrame({"d": vals_right, "c": 1}) | ||
result = left.merge(right, how="leftsemi", left_on=["a", "b"], right_on=["d", "c"]) | ||
tm.assert_frame_equal(result, expected.head(1)) | ||
|
||
|
||
def test_leftsemi_invalid(): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we add a test for raising NotImplementedError when passing a mask? |
||
left = pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) | ||
right = pd.DataFrame({"a": [1, 2], "c": 1}) | ||
|
||
msg = "left_index or right_index are not supported for semi-join." | ||
with pytest.raises(NotImplementedError, match=msg): | ||
left.merge(right, how="leftsemi", left_index=True, right_on="a") | ||
with pytest.raises(NotImplementedError, match=msg): | ||
left.merge(right, how="leftsemi", right_index=True, left_on="a") | ||
|
||
msg = "validate is not supported for semi-join." | ||
with pytest.raises(NotImplementedError, match=msg): | ||
left.merge(right, how="leftsemi", validate="one_to_one") | ||
|
||
msg = "indicator is not supported for semi-join." | ||
with pytest.raises(NotImplementedError, match=msg): | ||
left.merge(right, how="leftsemi", indicator=True) | ||
|
||
msg = "sort is not supported for semi-join. Sort your DataFrame afterwards." | ||
with pytest.raises(NotImplementedError, match=msg): | ||
left.merge(right, how="leftsemi", sort=True) |
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Stylistic nit but would be better to not use single character variable names, especially those that conflict with debugger aliases. The cython debugger is hard enough to use as is :-)