feat(python): Allow use of Python types in cs.by_dtype and col
alexander-beedie committed Dec 29, 2024
1 parent f43a7d4 · commit e0df706
Showing 3 changed files with 129 additions and 19 deletions.
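
In short, `pl.col` and `cs.by_dtype` now accept plain Python types alongside Polars dtypes, with `int`, `float`, `datetime` and `timedelta` expanding to the matching dtype groups and other Python types (e.g. `bool`, `Decimal`) resolved via `parse_into_dtype`. A minimal usage sketch of the new behaviour (the example frame is made up, not taken from this commit):

from datetime import datetime

import polars as pl
import polars.selectors as cs

df = pl.DataFrame(
    {
        "id": [1, 2],
        "value": [1.5, 2.5],
        "ts": [datetime(2024, 1, 1), datetime(2024, 1, 2)],
    },
    schema_overrides={"id": pl.UInt32, "value": pl.Float32},
)

# `int` expands to every integer dtype, `float` to every float dtype
df.select(cs.by_dtype(int, float)).columns   # ['id', 'value']

# the same Python types can now be passed to `pl.col`
df.select(pl.col(datetime)).columns          # ['ts']
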
77 changes: 64 additions & 13 deletions py-polars/polars/functions/col.py
@@ -2,26 +2,40 @@

import contextlib
from collections.abc import Iterable
from datetime import datetime, timedelta
from typing import TYPE_CHECKING

from polars._utils.wrap import wrap_expr
from polars.datatypes import is_polars_dtype
from polars.datatypes import Datetime, Duration, is_polars_dtype, parse_into_dtype
from polars.datatypes.group import (
DATETIME_DTYPES,
DURATION_DTYPES,
FLOAT_DTYPES,
INTEGER_DTYPES,
)

with contextlib.suppress(ImportError): # Module not available when building docs
import polars.polars as plr

if TYPE_CHECKING:
from polars._typing import PolarsDataType
from polars._typing import PolarsDataType, PythonDataType
from polars.expr.expr import Expr

__all__ = ["col"]


def _create_col(
name: str | PolarsDataType | Iterable[str] | Iterable[PolarsDataType],
*more_names: str | PolarsDataType,
name: (
str
| PolarsDataType
| PythonDataType
| Iterable[str]
| Iterable[PolarsDataType | PythonDataType]
),
*more_names: str | PolarsDataType | PythonDataType,
) -> Expr:
"""Create one or more column expressions representing column(s) in a DataFrame."""
dtypes: list[PolarsDataType]
if more_names:
if isinstance(name, str):
names_str = [name]
@@ -41,7 +55,11 @@ def _create_col(
if isinstance(name, str):
return wrap_expr(plr.col(name))
elif is_polars_dtype(name):
return wrap_expr(plr.dtype_cols([name]))
dtypes = _polars_dtype_match(name)
return wrap_expr(plr.dtype_cols(dtypes))
elif isinstance(name, type):
dtypes = _python_dtype_match(name)
return wrap_expr(plr.dtype_cols(dtypes))
elif isinstance(name, Iterable):
names = list(name)
if not names:
@@ -51,7 +69,15 @@ def _create_col(
if isinstance(item, str):
return wrap_expr(plr.cols(names))
elif is_polars_dtype(item):
return wrap_expr(plr.dtype_cols(names))
dtypes = []
for nm in names:
dtypes.extend(_polars_dtype_match(nm)) # type: ignore[arg-type]
return wrap_expr(plr.dtype_cols(dtypes))
elif isinstance(item, type):
dtypes = []
for nm in names:
dtypes.extend(_python_dtype_match(nm)) # type: ignore[arg-type]
return wrap_expr(plr.dtype_cols(dtypes))
else:
msg = (
"invalid input for `col`"
@@ -67,6 +93,26 @@ def _create_col(
raise TypeError(msg)


def _python_dtype_match(tp: PythonDataType) -> list[PolarsDataType]:
if tp is int:
return list(INTEGER_DTYPES)
elif tp is float:
return list(FLOAT_DTYPES)
elif tp is datetime:
return list(DATETIME_DTYPES)
elif tp is timedelta:
return list(DURATION_DTYPES)
return [parse_into_dtype(tp)]


def _polars_dtype_match(tp: PolarsDataType) -> list[PolarsDataType]:
if Datetime.is_(tp):
return list(DATETIME_DTYPES)
elif Duration.is_(tp):
return list(DURATION_DTYPES)
return [tp]


class Col:
"""
Create Polars column expressions.
@@ -79,8 +125,7 @@ class Col:
This helper class enables an alternative syntax for creating a column expression
through attribute lookup. For example `col.foo` creates an expression equal to
`col("foo")`.
See the :func:`__getattr__` method for further documentation.
`col("foo")`. See the :func:`__getattr__` method for further documentation.
The function call syntax is considered the idiomatic way of constructing a column
expression. The alternative attribute syntax can be useful for quick prototyping as
@@ -126,18 +171,24 @@ class Col:

def __call__(
self,
name: str | PolarsDataType | Iterable[str] | Iterable[PolarsDataType],
*more_names: str | PolarsDataType,
name: (
str
| PolarsDataType
| PythonDataType
| Iterable[str]
| Iterable[PolarsDataType | PythonDataType]
),
*more_names: str | PolarsDataType | PythonDataType,
) -> Expr:
"""
Create one or more column expressions representing column(s) in a DataFrame.
Create one or more expressions representing columns in a DataFrame.
Parameters
----------
name
The name or datatype of the column(s) to represent.
Accepts regular expression input.
Regular expressions should start with `^` and end with `$`.
Accepts regular expression input; regular expressions
should start with `^` and end with `$`.
*more_names
Additional names or datatypes of columns to represent,
specified as positional arguments.
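
Per `_python_dtype_match` above, a Python type passed to `pl.col` selects every column whose dtype falls in the corresponding group: `float` covers Float32 and Float64 alike, and `datetime` covers every Datetime time unit and time zone, while other Python types go through `parse_into_dtype`. A rough illustration (hypothetical frame, not part of the commit):

from datetime import datetime

import polars as pl

df = pl.DataFrame(
    {"f32": [1.0], "f64": [2.0], "dt": [datetime(2024, 1, 1)]},
    schema_overrides={"f32": pl.Float32, "dt": pl.Datetime("ms")},
)

# `float` matches Float32 and Float64 columns alike
df.select(pl.col(float)).columns      # ['f32', 'f64']

# `datetime` matches Datetime columns regardless of time unit or zone
df.select(pl.col(datetime)).columns   # ['dt']
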
16 changes: 11 additions & 5 deletions py-polars/polars/selectors.py
@@ -40,9 +40,10 @@

if TYPE_CHECKING:
import sys
from collections.abc import Iterable

from polars import DataFrame, LazyFrame
from polars._typing import PolarsDataType, SelectorType, TimeUnit
from polars._typing import PolarsDataType, PythonDataType, SelectorType, TimeUnit

if sys.version_info >= (3, 11):
from typing import Self
@@ -868,7 +869,12 @@ def boolean() -> SelectorType:


def by_dtype(
*dtypes: PolarsDataType | Collection[PolarsDataType],
*dtypes: (
PolarsDataType
| PythonDataType
| Iterable[PolarsDataType]
| Iterable[PythonDataType]
),
) -> SelectorType:
"""
Select all columns matching the given dtypes.
@@ -931,13 +937,13 @@ def by_dtype(
│ foo ┆ -3265500 │
└───────┴──────────┘
"""
all_dtypes: list[PolarsDataType] = []
all_dtypes: list[PolarsDataType | PythonDataType] = []
for tp in dtypes:
if is_polars_dtype(tp):
if is_polars_dtype(tp) or isinstance(tp, type):
all_dtypes.append(tp)
elif isinstance(tp, Collection):
for t in tp:
if not is_polars_dtype(t):
if not (is_polars_dtype(t) or isinstance(t, type)):
msg = f"invalid dtype: {t!r}"
raise TypeError(msg)
all_dtypes.append(t)
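
The selector side mirrors this: `cs.by_dtype` accepts Python types alongside Polars dtypes, and group-like inputs — the Python `datetime` type or the bare `pl.Datetime` class — match every Datetime variant, whatever its time unit or time zone, as the tests below exercise. A small sketch along the same lines as those tests:

from datetime import datetime

import polars as pl
import polars.selectors as cs

dfx = pl.DataFrame(
    {"dt1": [], "dt2": []},
    schema={"dt1": pl.Datetime("ms"), "dt2": pl.Datetime(time_zone="Asia/Tokyo")},
)

# the Python `datetime` type and the bare `pl.Datetime` class both match
# tz-naive and tz-aware Datetime columns, whatever the time unit
for dt in (datetime, pl.Datetime):
    assert dfx.select(cs.by_dtype(dt)).columns == ["dt1", "dt2"]
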
55 changes: 54 additions & 1 deletion py-polars/tests/unit/test_selectors.py
@@ -1,5 +1,6 @@
from collections import OrderedDict
from datetime import datetime
from datetime import datetime, timedelta
from decimal import Decimal as PyDecimal
from typing import Any
from zoneinfo import ZoneInfo

@@ -110,6 +111,58 @@ def test_selector_by_dtype(df: pl.DataFrame) -> None:
"qqR": pl.String(),
}
)
assert df.select(
cs.by_dtype(pl.Datetime("ns"), pl.Float32, pl.UInt32, pl.Date)
).schema == pl.Schema(
{
"bbb": pl.UInt32,
"def": pl.Float32,
"JJK": pl.Date,
}
)

# select using python types
assert df.select(cs.by_dtype(int, float)).schema == pl.Schema(
{
"abc": pl.UInt16,
"bbb": pl.UInt32,
"cde": pl.Float64,
"def": pl.Float32,
}
)
assert df.select(cs.by_dtype(bool, datetime, timedelta)).schema == pl.Schema(
{
"eee": pl.Boolean(),
"fgg": pl.Boolean(),
"Lmn": pl.Duration("us"),
"opp": pl.Datetime("ms"),
}
)

# cover timezones and decimal
dfx = pl.DataFrame(
{"idx": [], "dt1": [], "dt2": []},
schema_overrides={
"idx": pl.Decimal(24),
"dt1": pl.Datetime("ms"),
"dt2": pl.Datetime(time_zone="Asia/Tokyo"),
},
)
assert dfx.select(cs.by_dtype(PyDecimal)).schema == pl.Schema(
{"idx": pl.Decimal(24)},
)
assert dfx.select(cs.by_dtype(pl.Datetime(time_zone="*"))).schema == pl.Schema(
{"dt2": pl.Datetime(time_zone="Asia/Tokyo")}
)
assert dfx.select(cs.by_dtype(pl.Datetime("ms", None))).schema == pl.Schema(
{"dt1": pl.Datetime("ms")},
)
for dt in (datetime, pl.Datetime):
assert dfx.select(cs.by_dtype(dt)).schema == pl.Schema(
{"dt1": pl.Datetime("ms"), "dt2": pl.Datetime(time_zone="Asia/Tokyo")},
)

# empty selection selects nothing
assert df.select(cs.by_dtype()).schema == {}
assert df.select(cs.by_dtype([])).schema == {}

