Skip to content

Commit

Permalink
fix(python): Fix edge case in DataFrame constructor data orientation …
Browse files Browse the repository at this point in the history
…inference (#16975)
  • Loading branch information
stinodego authored Jun 16, 2024
1 parent 9e3e9b0 commit 9f097ea
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 22 deletions.
60 changes: 38 additions & 22 deletions py-polars/polars/_utils/construction/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ def _parse_schema_overrides(
col = col[0]
column_names.append(col)

if n_expected is not None and len(column_names) != n_expected:
msg = "data does not match the number of columns"
raise ShapeError(msg)

# determine column dtypes from schema and lookup_names
lookup: dict[str, str] | None = (
{
Expand Down Expand Up @@ -533,18 +537,11 @@ def _sequence_of_sequence_to_pydf(
infer_schema_length: int | None,
) -> PyDataFrame:
if orient is None:
# note: limit type-checking to smaller data; larger values are much more
# likely to indicate col orientation anyway, so minimise extra checks.
if len(first_element) > 1000:
orient = "col" if schema and len(schema) == len(data) else "row"
elif (schema is not None and len(schema) == len(data)) or not schema:
# check if element types in the first 'row' resolve to a single dtype.
row_types = {type(value) for value in first_element if value is not None}
if int in row_types and float in row_types:
row_types.discard(int)
orient = "col" if len(row_types) == 1 else "row"
else:
orient = "row"
orient = _infer_data_orientation(
first_element,
len_data=len(data),
len_schema=len(schema) if schema is not None else None,
)

if orient == "row":
column_names, schema_overrides = _unpack_schema(
Expand All @@ -555,13 +552,6 @@ def _sequence_of_sequence_to_pydf(
if schema_overrides
else {}
)
if (
column_names
and len(first_element) > 0
and len(first_element) != len(column_names)
):
msg = "the row data does not match the number of columns"
raise ShapeError(msg)

unpack_nested = False
for col, tp in local_schema_override.items():
Expand Down Expand Up @@ -589,7 +579,7 @@ def _sequence_of_sequence_to_pydf(
)
return pydf

if orient == "col" or orient is None:
elif orient == "col":
column_names, schema_overrides = _unpack_schema(
schema, schema_overrides=schema_overrides, n_expected=len(data)
)
Expand All @@ -604,8 +594,34 @@ def _sequence_of_sequence_to_pydf(
]
return PyDataFrame(data_series)

msg = f"`orient` must be one of {{'col', 'row', None}}, got {orient!r}"
raise ValueError(msg)
else:
msg = f"`orient` must be one of {{'col', 'row', None}}, got {orient!r}"
raise ValueError(msg)


def _infer_data_orientation(
first_element: Sequence[Any] | np.ndarray[Any, Any],
len_data: int,
len_schema: int | None = None,
) -> Orientation:
# Check if element types in the first 'row' resolve to a single dtype.
# Note: limit type-checking to smaller data; larger values are much more
# likely to indicate col orientation anyway, so minimize extra checks.
if len(first_element) <= 1000 and (len_schema is None or len_schema == len_data):
row_types = {type(value) for value in first_element if value is not None}
if int in row_types and float in row_types:
row_types.discard(int)
return "row" if len(row_types) > 1 else "col"

elif (
len_schema is not None
and len_schema == len(first_element)
and len_schema != len_data
):
return "row"

else:
return "col"


def _sequence_of_series_to_pydf(
Expand Down
11 changes: 11 additions & 0 deletions py-polars/tests/unit/constructors/test_constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pydantic import BaseModel, Field, TypeAdapter

import polars as pl
from polars._utils.construction.dataframe import _infer_data_orientation
from polars._utils.construction.utils import try_get_type_hints
from polars.datatypes import PolarsDataType, numpy_char_code_to_dtype
from polars.dependencies import dataclasses, pydantic
Expand Down Expand Up @@ -1622,3 +1623,13 @@ def test_array_construction() -> None:
df = pl.from_dicts(rows, schema=schema)
assert df.schema == schema
assert df.rows() == [("a", [1, 2, 3]), ("b", [2, 3, 4])]


def test_infer_data_orientation() -> None:
    """Check the orientation inferred for a handful of representative inputs."""
    # Each case is (positional args for `_infer_data_orientation`, expected result);
    # the args are (first_element, len_data[, len_schema]).
    cases = [
        (([1], 1), "col"),
        (([1, 2], 2), "col"),
        (([1, 2], 2, 2), "col"),
        (([1, 2, 3], 2), "col"),
        (([1, 2, 3], 2, 2), "col"),
        # schema length matches the first element but not the data length
        (([1, 2, 3], 2, 3), "row"),
        # mixed value types in the first element
        (([1, "x"], 2), "row"),
    ]
    for call_args, expected in cases:
        assert _infer_data_orientation(*call_args) == expected

0 comments on commit 9f097ea

Please sign in to comment.