Skip to content

Commit

Permalink
refactor assert_equal_data
Browse files Browse the repository at this point in the history
  • Loading branch information
EdAbati committed Oct 18, 2024
1 parent a9db656 commit 85c3d45
Showing 1 changed file with 29 additions and 18 deletions.
47 changes: 29 additions & 18 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,36 +30,47 @@ def zip_strict(left: Sequence[Any], right: Sequence[Any]) -> Iterator[Any]:
return zip(left, right)


def _to_py_object(value: Any) -> Any:
# PyArrow: return scalars as Python objects
if hasattr(value, "as_py"): # pragma: no cover
return value.as_py()
# cuDF: returns cupy scalars as Python objects
if hasattr(value, "item"): # pragma: no cover
return value.item()
return value


def _to_comparable_list(column_values: Any) -> Any:
if (
hasattr(column_values, "_compliant_series")
and column_values._compliant_series._implementation is Implementation.CUDF
): # pragma: no cover
column_values = column_values.to_pandas()
if hasattr(column_values, "to_list"):
return column_values.to_list()
return [_to_py_object(v) for v in column_values]


def assert_equal_data(result: Any, expected: dict[str, Any]) -> None:
if hasattr(result, "collect"):
result = result.collect()
if hasattr(result, "columns"):
for key in result.columns:
assert key in expected
result = {key: _to_comparable_list(result[key]) for key in expected}
for key in expected:
result_key = result[key]
if (
hasattr(result_key, "_compliant_series")
and result_key._compliant_series._implementation is Implementation.CUDF
): # pragma: no cover
result_key = result_key.to_pandas()
for lhs, rhs in zip_strict(result_key, expected[key]):
if hasattr(lhs, "as_py"):
lhs = lhs.as_py() # noqa: PLW2901
if hasattr(rhs, "as_py"): # pragma: no cover
rhs = rhs.as_py() # noqa: PLW2901
if hasattr(lhs, "item"): # pragma: no cover
lhs = lhs.item() # noqa: PLW2901
if hasattr(rhs, "item"): # pragma: no cover
rhs = rhs.item() # noqa: PLW2901
expected_key = expected[key]
for i, (lhs, rhs) in enumerate(zip_strict(result_key, expected_key)):
if isinstance(lhs, float) and not math.isnan(lhs):
assert math.isclose(lhs, rhs, rel_tol=0, abs_tol=1e-6), (lhs, rhs)
are_valid_values = math.isclose(lhs, rhs, rel_tol=0, abs_tol=1e-6)
elif isinstance(lhs, float) and math.isnan(lhs) and rhs is not None:
assert math.isnan(rhs), (lhs, rhs) # pragma: no cover
are_valid_values = math.isnan(rhs) # pragma: no cover
elif pd.isna(lhs):
assert pd.isna(rhs), (lhs, rhs)
are_valid_values = pd.isna(rhs)
else:
assert lhs == rhs, (lhs, rhs)
are_valid_values = lhs == rhs
assert are_valid_values, f"Mismatch at index {i}: {lhs} != {rhs}\nExpected: {expected}\nGot: {result}"


def maybe_get_modin_df(df_pandas: pd.DataFrame) -> Any:
Expand Down

0 comments on commit 85c3d45

Please sign in to comment.