diff --git a/crates/polars-ops/src/frame/join/merge_sorted.rs b/crates/polars-ops/src/frame/join/merge_sorted.rs index 8312c19af2cc..ffb0913a6a84 100644 --- a/crates/polars-ops/src/frame/join/merge_sorted.rs +++ b/crates/polars-ops/src/frame/join/merge_sorted.rs @@ -36,7 +36,7 @@ pub fn _merge_sorted_dfs( return Ok(right.clone()); } - let merge_indicator = series_to_merge_indicator(left_s, right_s); + let merge_indicator = series_to_merge_indicator(left_s, right_s)?; let new_columns = left .get_columns() .iter() @@ -90,7 +90,10 @@ fn merge_series(lhs: &Series, rhs: &Series, merge_indicator: &[bool]) -> PolarsR .fields_as_series() .iter() .zip(rhs.fields_as_series()) - .map(|(lhs, rhs)| merge_series(lhs, &rhs, merge_indicator)) + .map(|(lhs, rhs)| { + merge_series(lhs, &rhs, merge_indicator) + .map(|merged| merged.with_name(lhs.name().clone())) + }) .collect::>>()?; StructChunked::from_series(PlSmallStr::EMPTY, new_fields[0].len(), new_fields.iter()) .unwrap() @@ -139,11 +142,11 @@ where unsafe { iter.trust_my_length(total_len).collect_trusted() } } -fn series_to_merge_indicator(lhs: &Series, rhs: &Series) -> Vec { +fn series_to_merge_indicator(lhs: &Series, rhs: &Series) -> PolarsResult> { let lhs_s = lhs.to_physical_repr().into_owned(); let rhs_s = rhs.to_physical_repr().into_owned(); - match lhs_s.dtype() { + let out = match lhs_s.dtype() { DataType::Boolean => { let lhs = lhs_s.bool().unwrap(); let rhs = rhs_s.bool().unwrap(); @@ -159,6 +162,13 @@ fn series_to_merge_indicator(lhs: &Series, rhs: &Series) -> Vec { let rhs = rhs_s.binary().unwrap(); get_merge_indicator(lhs.into_iter(), rhs.into_iter()) }, + #[cfg(feature = "dtype-struct")] + DataType::Struct(_) => { + let options = SortOptions::default(); + let lhs = lhs_s.struct_().unwrap().get_row_encoded(options)?; + let rhs = rhs_s.struct_().unwrap().get_row_encoded(options)?; + get_merge_indicator(lhs.into_iter(), rhs.into_iter()) + }, _ => { with_match_physical_numeric_polars_type!(lhs_s.dtype(), |$T| { let lhs: &ChunkedArray<$T> = lhs_s.as_ref().as_ref().as_ref(); @@ -168,7 +178,8 @@ fn series_to_merge_indicator(lhs: &Series, rhs: &Series) -> Vec { }) }, - } + }; + Ok(out) } // get a boolean values, left: true, right: false diff --git a/py-polars/tests/unit/operations/test_merge_sorted.py b/py-polars/tests/unit/operations/test_merge_sorted.py index d851368b6159..d8d9f65dc1fd 100644 --- a/py-polars/tests/unit/operations/test_merge_sorted.py +++ b/py-polars/tests/unit/operations/test_merge_sorted.py @@ -172,11 +172,36 @@ def test_merge_sorted_parametric_string(lhs: pl.Series, rhs: pl.Series) -> None: assert_series_equal(merge_sorted, append_sorted) +@given( + lhs=series( + name="a", + allowed_dtypes=[ + pl.Struct({"x": pl.Int32, "y": pl.Struct({"x": pl.Int8, "y": pl.Int8})}) + ], + allow_null=False, + ), # Nulls see: https://github.com/pola-rs/polars/issues/20991 + rhs=series( + name="a", + allowed_dtypes=[ + pl.Struct({"x": pl.Int32, "y": pl.Struct({"x": pl.Int8, "y": pl.Int8})}) + ], + allow_null=False, + ), # Nulls see: https://github.com/pola-rs/polars/issues/20991 +) +def test_merge_sorted_parametric_struct(lhs: pl.Series, rhs: pl.Series) -> None: + l_df = pl.DataFrame([lhs.sort()]) + r_df = pl.DataFrame([rhs.sort()]) + + merge_sorted = l_df.lazy().merge_sorted(r_df.lazy(), "a").collect().get_column("a") + append_sorted = lhs.append(rhs).sort() + + assert_series_equal(merge_sorted, append_sorted) + + @given( s=series( name="a", excluded_dtypes=[ - pl.Struct, # Bug. See https://github.com/pola-rs/polars/issues/20986 pl.Categorical( ordering="lexical" ), # Bug. See https://github.com/pola-rs/polars/issues/21025