diff --git a/py-polars/polars/testing/asserts/frame.py b/py-polars/polars/testing/asserts/frame.py index 5d6112b6cb08..800289d2952e 100644 --- a/py-polars/polars/testing/asserts/frame.py +++ b/py-polars/polars/testing/asserts/frame.py @@ -257,6 +257,7 @@ def assert_frame_not_equal( """ __tracebackhide__ = True + _assert_correct_input_type(left, right) try: assert_frame_equal( left=left, @@ -272,5 +273,5 @@ def assert_frame_not_equal( except AssertionError: return else: - msg = "frames are equal" + msg = "frames are equal (but are expected not to be)" raise AssertionError(msg) diff --git a/py-polars/polars/testing/asserts/series.py b/py-polars/polars/testing/asserts/series.py index ad316f565aad..65e5169cab74 100644 --- a/py-polars/polars/testing/asserts/series.py +++ b/py-polars/polars/testing/asserts/series.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from polars._utils.deprecation import deprecate_renamed_parameter from polars.datatypes import ( @@ -20,6 +20,19 @@ from polars import DataType +def _assert_correct_input_type(left: Any, right: Any) -> bool: + __tracebackhide__ = True + + if not (isinstance(left, Series) and isinstance(right, Series)): + raise_assertion_error( + "inputs", + "unexpected input types", + type(left).__name__, + type(right).__name__, + ) + return True + + @deprecate_renamed_parameter("check_dtype", "check_dtypes", version="0.20.31") def assert_series_equal( left: Series, @@ -90,13 +103,7 @@ def assert_series_equal( """ __tracebackhide__ = True - if not (isinstance(left, Series) and isinstance(right, Series)): # type: ignore[redundant-expr] - raise_assertion_error( - "inputs", - "unexpected input types", - type(left).__name__, - type(right).__name__, - ) + _assert_correct_input_type(left, right) if left.len() != right.len(): raise_assertion_error("Series", "length mismatch", left.len(), right.len()) @@ -404,6 +411,7 @@ def assert_series_not_equal( """ __tracebackhide__ = True + _assert_correct_input_type(left, right) try: assert_series_equal( left=left, @@ -419,5 +427,5 @@ def assert_series_not_equal( except AssertionError: return else: - msg = "Series are equal" + msg = "Series are equal (but are expected not to be)" raise AssertionError(msg) diff --git a/py-polars/tests/unit/testing/test_assert_frame_equal.py b/py-polars/tests/unit/testing/test_assert_frame_equal.py index 4cb8b3f5106f..f6a1c9f192cd 100644 --- a/py-polars/tests/unit/testing/test_assert_frame_equal.py +++ b/py-polars/tests/unit/testing/test_assert_frame_equal.py @@ -278,13 +278,17 @@ def test_assert_frame_equal_pass() -> None: assert_frame_equal(df1, df2) -def test_assert_frame_equal_types() -> None: +@pytest.mark.parametrize( + "assert_function", + [assert_frame_equal, assert_frame_not_equal], +) +def test_assert_frame_equal_types(assert_function: Any) -> None: df1 = pl.DataFrame({"a": [1, 2]}) srs1 = pl.Series(values=[1, 2], name="a") with pytest.raises( AssertionError, match=r"inputs are different \(unexpected input types\)" ): - assert_frame_equal(df1, srs1) # type: ignore[arg-type] + assert_function(df1, srs1) def test_assert_frame_equal_length_mismatch() -> None: @@ -295,6 +299,7 @@ def test_assert_frame_equal_length_mismatch() -> None: match=r"DataFrames are different \(number of rows does not match\)", ): assert_frame_equal(df1, df2) + assert_frame_not_equal(df1, df2) def test_assert_frame_equal_column_mismatch() -> None: @@ -304,6 +309,7 @@ def test_assert_frame_equal_column_mismatch() -> None: AssertionError, match="columns \\['a'\\] in left DataFrame, but not in right" ): assert_frame_equal(df1, df2) + assert_frame_not_equal(df1, df2) def test_assert_frame_equal_column_mismatch2() -> None: @@ -314,6 +320,7 @@ def test_assert_frame_equal_column_mismatch2() -> None: match="columns \\['b', 'c'\\] in right LazyFrame, but not in left", ): assert_frame_equal(df1, df2) + assert_frame_not_equal(df1, df2) def test_assert_frame_equal_column_mismatch_order() -> None: @@ -323,6 +330,7 @@ def test_assert_frame_equal_column_mismatch_order() -> None: assert_frame_equal(df1, df2) assert_frame_equal(df1, df2, check_column_order=False) + assert_frame_not_equal(df1, df2) def test_assert_frame_equal_check_row_order() -> None: @@ -331,25 +339,33 @@ def test_assert_frame_equal_check_row_order() -> None: with pytest.raises(AssertionError, match="value mismatch for column 'a'"): assert_frame_equal(df1, df2) + assert_frame_equal(df1, df2, check_row_order=False) + assert_frame_not_equal(df1, df2) def test_assert_frame_equal_check_row_col_order() -> None: df1 = pl.DataFrame({"a": [1, 2], "b": [4, 3]}) - df3 = pl.DataFrame({"b": [3, 4], "a": [2, 1]}) + df2 = pl.DataFrame({"b": [3, 4], "a": [2, 1]}) with pytest.raises(AssertionError, match="columns are not in the same order"): - assert_frame_equal(df1, df3, check_row_order=False) - assert_frame_equal(df1, df3, check_row_order=False, check_column_order=False) + assert_frame_equal(df1, df2, check_row_order=False) + + assert_frame_equal(df1, df2, check_row_order=False, check_column_order=False) + assert_frame_not_equal(df1, df2) -def test_assert_frame_equal_check_row_order_unsortable() -> None: +@pytest.mark.parametrize( + "assert_function", + [assert_frame_equal, assert_frame_not_equal], +) +def test_assert_frame_equal_check_row_order_unsortable(assert_function: Any) -> None: df1 = pl.DataFrame({"a": [object(), object()], "b": [3, 4]}) df2 = pl.DataFrame({"a": [object(), object()], "b": [4, 3]}) with pytest.raises( TypeError, match="cannot set `check_row_order=False`.*unsortable columns" ): - assert_frame_equal(df1, df2, check_row_order=False) + assert_function(df1, df2, check_row_order=False) def test_assert_frame_equal_dtypes_mismatch() -> None: @@ -360,6 +376,9 @@ def test_assert_frame_equal_dtypes_mismatch() -> None: with pytest.raises(AssertionError, match="dtypes do not match"): assert_frame_equal(df1, df2, check_column_order=False) + assert_frame_not_equal(df1, df2, check_column_order=False) + assert_frame_not_equal(df1, df2) + def test_assert_frame_not_equal() -> None: df = pl.DataFrame({"a": [1, 2]}) diff --git a/py-polars/tests/unit/testing/test_assert_series_equal.py b/py-polars/tests/unit/testing/test_assert_series_equal.py index 92ebe13a0104..c523fe193a30 100644 --- a/py-polars/tests/unit/testing/test_assert_series_equal.py +++ b/py-polars/tests/unit/testing/test_assert_series_equal.py @@ -35,10 +35,11 @@ def test_assert_series_equal_parametric_array(data: st.DataObject) -> None: def test_compare_series_value_mismatch() -> None: srs1 = pl.Series([1, 2, 3]) srs2 = pl.Series([2, 3, 4]) - assert_series_not_equal(srs1, srs2) + with pytest.raises( - AssertionError, match=r"Series are different \(exact value mismatch\)" + AssertionError, + match=r"Series are different \(exact value mismatch\)", ): assert_series_equal(srs1, srs2) @@ -46,25 +47,33 @@ def test_compare_series_value_mismatch() -> None: def test_compare_series_empty_equal() -> None: srs1 = pl.Series([]) srs2 = pl.Series(()) - assert_series_equal(srs1, srs2) - with pytest.raises(AssertionError): + + with pytest.raises( + AssertionError, + match=r"Series are equal \(but are expected not to be\)", + ): assert_series_not_equal(srs1, srs2) def test_assert_series_equal_check_order() -> None: srs1 = pl.Series([1, 2, 3, None]) srs2 = pl.Series([2, None, 3, 1]) - assert_series_equal(srs1, srs2, check_order=False) - with pytest.raises(AssertionError): + + with pytest.raises( + AssertionError, + match=r"Series are equal \(but are expected not to be\)", + ): assert_series_not_equal(srs1, srs2, check_order=False) def test_assert_series_equal_check_order_unsortable_type() -> None: s = pl.Series([object(), object()]) - - with pytest.raises(TypeError): + with pytest.raises( + TypeError, + match="cannot set `check_order=False` on Series with unsortable data type", + ): assert_series_equal(s, s, check_order=False) @@ -123,32 +132,45 @@ def test_compare_series_value_mismatch_string() -> None: assert_series_not_equal(srs1, srs2) with pytest.raises( - AssertionError, match=r"Series are different \(exact value mismatch\)" + AssertionError, + match=r"Series are different \(exact value mismatch\)", ): assert_series_equal(srs1, srs2) -def test_compare_series_type_mismatch() -> None: +def test_compare_series_dtype_mismatch() -> None: srs1 = pl.Series([1, 2, 3]) - srs2 = pl.DataFrame({"col1": [2, 3, 4]}) + srs2 = pl.Series([1.0, 2.0, 3.0]) + assert_series_not_equal(srs1, srs2) with pytest.raises( - AssertionError, match=r"inputs are different \(unexpected input types\)" + AssertionError, + match=r"Series are different \(dtype mismatch\)", ): - assert_series_equal(srs1, srs2) # type: ignore[arg-type] + assert_series_equal(srs1, srs2) + + +@pytest.mark.parametrize( + "assert_function", [assert_series_equal, assert_series_not_equal] +) +def test_compare_series_input_type_mismatch(assert_function: Any) -> None: + srs1 = pl.Series([1, 2, 3]) + srs2 = pl.DataFrame({"col1": [2, 3, 4]}) - srs3 = pl.Series([1.0, 2.0, 3.0]) - assert_series_not_equal(srs1, srs3) with pytest.raises( - AssertionError, match=r"Series are different \(dtype mismatch\)" + AssertionError, + match=r"inputs are different \(unexpected input types\)", ): - assert_series_equal(srs1, srs3) + assert_function(srs1, srs2) def test_compare_series_name_mismatch() -> None: srs1 = pl.Series(values=[1, 2, 3], name="srs1") srs2 = pl.Series(values=[1, 2, 3], name="srs2") - with pytest.raises(AssertionError, match=r"Series are different \(name mismatch\)"): + with pytest.raises( + AssertionError, + match=r"Series are different \(name mismatch\)", + ): assert_series_equal(srs1, srs2) @@ -158,7 +180,8 @@ def test_compare_series_length_mismatch() -> None: assert_series_not_equal(srs1, srs2) with pytest.raises( - AssertionError, match=r"Series are different \(length mismatch\)" + AssertionError, + match=r"Series are different \(length mismatch\)", ): assert_series_equal(srs1, srs2) @@ -167,7 +190,8 @@ def test_compare_series_value_exact_mismatch() -> None: srs1 = pl.Series([1.0, 2.0, 3.0]) srs2 = pl.Series([1.0, 2.0 + 1e-7, 3.0]) with pytest.raises( - AssertionError, match=r"Series are different \(exact value mismatch\)" + AssertionError, + match=r"Series are different \(exact value mismatch\)", ): assert_series_equal(srs1, srs2, check_exact=True) @@ -537,7 +561,10 @@ def test_assert_series_equal_full_series() -> None: def test_assert_series_not_equal() -> None: s = pl.Series("a", [1, 2]) - with pytest.raises(AssertionError, match="Series are equal"): + with pytest.raises( + AssertionError, + match=r"Series are equal \(but are expected not to be\)", + ): assert_series_not_equal(s, s) @@ -546,7 +573,10 @@ def test_assert_series_equal_nested_list_float() -> None: s1 = pl.Series([[1.0, 2.0], [3.0, 4.0]], dtype=pl.List(pl.Float64)) s2 = pl.Series([[1.0, 2.0], [3.0, 4.9]], dtype=pl.List(pl.Float64)) - with pytest.raises(AssertionError): + with pytest.raises( + AssertionError, + match=r"Series are different \(nested value mismatch\)", + ): assert_series_equal(s1, s2) @@ -560,7 +590,10 @@ def test_assert_series_equal_nested_struct_float() -> None: dtype=pl.Struct({"a": pl.Float64, "b": pl.Float64}), ) - with pytest.raises(AssertionError): + with pytest.raises( + AssertionError, + match=r"Series are different \(nested value mismatch\)", + ): assert_series_equal(s1, s2) @@ -570,7 +603,10 @@ def test_assert_series_equal_full_null_incompatible_dtypes_raises() -> None: # You could argue this should pass, but it's rare enough not to warrant the # additional check - with pytest.raises(AssertionError, match="incompatible data types"): + with pytest.raises( + AssertionError, + match="incompatible data types", + ): assert_series_equal(s1, s2, check_dtypes=False) @@ -595,9 +631,16 @@ def test_assert_series_equal_uint_overflow() -> None: s1 = pl.Series([1, 2, 3], dtype=pl.UInt8) s2 = pl.Series([2, 3, 4], dtype=pl.UInt8) - with pytest.raises(AssertionError): + with pytest.raises( + AssertionError, + match=r"Series are different \(exact value mismatch\)", + ): assert_series_equal(s1, s2, atol=0) - with pytest.raises(AssertionError): + + with pytest.raises( + AssertionError, + match=r"Series are different \(exact value mismatch\)", + ): assert_series_equal(s1, s2, atol=1) left = pl.Series( @@ -616,7 +659,10 @@ def test_assert_series_equal_uint_always_checked_exactly() -> None: s1 = pl.Series([1, 3], dtype=pl.UInt8) s2 = pl.Series([2, 4], dtype=pl.Int64) - with pytest.raises(AssertionError): + with pytest.raises( + AssertionError, + match=r"Series are different \(exact value mismatch\)", + ): assert_series_equal(s1, s2, atol=1, check_dtypes=False) @@ -624,9 +670,15 @@ def test_assert_series_equal_nested_int_always_checked_exactly() -> None: s1 = pl.Series([[1, 2], [3, 4]]) s2 = pl.Series([[1, 2], [3, 5]]) - with pytest.raises(AssertionError): + with pytest.raises( + AssertionError, + match=r"Series are different \(exact value mismatch\)", + ): assert_series_equal(s1, s2, atol=1) - with pytest.raises(AssertionError): + with pytest.raises( + AssertionError, + match=r"Series are different \(exact value mismatch\)", + ): assert_series_equal(s1, s2, check_exact=True) @@ -635,7 +687,9 @@ def test_assert_series_equal_array_equal(check_exact: bool) -> None: s1 = pl.Series([[1.0, 2.0], [3.0, 4.0]], dtype=pl.Array(pl.Float64, 2)) s2 = pl.Series([[1.0, 2.0], [3.0, 4.2]], dtype=pl.Array(pl.Float64, 2)) - with pytest.raises(AssertionError): + with pytest.raises( + AssertionError, match=r"Series are different \(nested value mismatch\)" + ): assert_series_equal(s1, s2, check_exact=check_exact)