Skip to content

Commit

Permalink
docs(python): Minor doctest fixes (#17002)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Jun 16, 2024
1 parent ce62b33 commit 308df5d
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 84 deletions.
8 changes: 5 additions & 3 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,10 +317,10 @@ class DataFrame:
│ 2 ┆ 4 │
└─────┴─────┘
Constructing a DataFrame from a list of lists, row orientation inferred:
Constructing a DataFrame from a list of lists, row orientation specified:
>>> data = [[1, 2, 3], [4, 5, 6]]
>>> df6 = pl.DataFrame(data, schema=["a", "b", "c"])
>>> df6 = pl.DataFrame(data, schema=["a", "b", "c"], orient="row")
>>> df6
shape: (2, 3)
┌─────┬─────┬─────┐
Expand Down Expand Up @@ -9250,7 +9250,9 @@ def n_unique(self, subset: str | Expr | Sequence[str | Expr] | None = None) -> i
If instead you want to count the number of unique values per-column, you can
also use expression-level syntax to return a new frame containing that result:
>>> df = pl.DataFrame([[1, 2, 3], [1, 2, 4]], schema=["a", "b", "c"])
>>> df = pl.DataFrame(
... [[1, 2, 3], [1, 2, 4]], schema=["a", "b", "c"], orient="row"
... )
>>> df_nunique = df.select(pl.all().n_unique())
In aggregate context there is also an equivalent method for returning the
Expand Down
12 changes: 6 additions & 6 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,10 +267,10 @@ class LazyFrame:
│ 2 ┆ 4 │
└─────┴─────┘
Constructing a LazyFrame from a list of lists, row orientation inferred:
Constructing a LazyFrame from a list of lists, row orientation specified:
>>> data = [[1, 2, 3], [4, 5, 6]]
>>> lf6 = pl.LazyFrame(data, schema=["a", "b", "c"])
>>> lf6 = pl.LazyFrame(data, schema=["a", "b", "c"], orient="row")
>>> lf6.collect()
shape: (2, 3)
┌─────┬─────┬─────┐
Expand Down Expand Up @@ -420,7 +420,7 @@ def columns(self) -> list[str]:
... "ham": ["a", "b", "c"],
... }
... ).select("foo", "bar")
>>> lf.columns
>>> lf.columns # doctest: +SKIP
['foo', 'bar']
"""
issue_warning(
Expand Down Expand Up @@ -462,7 +462,7 @@ def dtypes(self) -> list[DataType]:
... "ham": ["a", "b", "c"],
... }
... )
>>> lf.dtypes
>>> lf.dtypes # doctest: +SKIP
[Int64, Float64, String]
"""
issue_warning(
Expand Down Expand Up @@ -498,7 +498,7 @@ def schema(self) -> Schema:
... "ham": ["a", "b", "c"],
... }
... )
>>> lf.schema
>>> lf.schema # doctest: +SKIP
Schema({'foo': Int64, 'bar': Float64, 'ham': String})
"""
issue_warning(
Expand Down Expand Up @@ -537,7 +537,7 @@ def width(self) -> int:
... "bar": [4, 5, 6],
... }
... )
>>> lf.width
>>> lf.width # doctest: +SKIP
2
"""
issue_warning(
Expand Down
148 changes: 74 additions & 74 deletions py-polars/polars/ml/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,80 @@


class PolarsDataset(TensorDataset): # type: ignore[misc]
"""Specialized TensorDataset for Polars DataFrames."""
"""
TensorDataset class specialized for use with Polars DataFrames.
Parameters
----------
frame
Polars DataFrame containing the data that will be retrieved as Tensors.
label
One or more column names or expressions that label the feature data; results
in `(features,label)` tuples, where all non-label columns are considered
to be features. If no label is designated then each returned item is a
simple `(features,)` tuple containing all row elements.
features
One or more column names or expressions that represent the feature data.
If not provided, all columns not designated as labels are considered to be
features.
Notes
-----
* Integer, slice, range, integer list/Tensor Dataset indexing is all supported.
* Designating multi-element labels is also supported.
Examples
--------
>>> from torch.utils.data import DataLoader
>>> df = pl.DataFrame(
... data=[
... (0, 1, 1.5),
... (1, 0, -0.5),
... (2, 0, 0.0),
... (3, 1, -2.25),
... ],
... schema=["lbl", "feat1", "feat2"],
... orient="row",
... )
Create a Dataset from a Polars DataFrame, standardising the dtype and
distinguishing the label/feature columns.
>>> ds = df.to_torch("dataset", label="lbl", dtype=pl.Float32)
>>> ds # doctest: +IGNORE_RESULT
<PolarsDataset [len:4, features:2, labels:1] at 0x156B033B0>
>>> ds.features
tensor([[ 1.0000, 1.5000],
[ 0.0000, -0.5000],
[ 0.0000, 0.0000],
[ 1.0000, -2.2500]])
>>> ds[0]
(tensor([1.0000, 1.5000]), tensor(0.))
The Dataset can be used standalone, or in conjunction with a DataLoader.
>>> dl = DataLoader(ds, batch_size=2)
>>> list(dl)
[[tensor([[ 1.0000, 1.5000],
[ 0.0000, -0.5000]]),
tensor([0., 1.])],
[tensor([[ 0.0000, 0.0000],
[ 1.0000, -2.2500]]),
tensor([2., 3.])]]
Note that the label can be given as an expression as well as a column name,
allowing for independent transform and dtype adjustment from the feature
columns.
>>> ds = df.to_torch(
... "dataset",
... dtype=pl.Float32,
... label=(pl.col("lbl") * 8).cast(pl.Int16),
... )
>>> ds[:2]
(tensor([[ 1.0000, 1.5000],
[ 0.0000, -0.5000]]), tensor([0, 8], dtype=torch.int16))
"""

tensors: tuple[Tensor, ...]
labels: Tensor | None
Expand All @@ -41,79 +114,6 @@ def __init__(
label: str | Expr | Sequence[str | Expr] | None = None,
features: str | Expr | Sequence[str | Expr] | None = None,
):
"""
TensorDataset class specialized for use with Polars DataFrames.
Parameters
----------
frame
Polars DataFrame containing the data that will be retrieved as Tensors.
label
One or more column names or expressions that label the feature data; results
in `(features,label)` tuples, where all non-label columns are considered
to be features. If no label is designated then each returned item is a
simple `(features,)` tuple containing all row elements.
features
One or more column names or expressions that represent the feature data.
If not provided, all columns not designated as labels are considered to be
features.
Notes
-----
* Integer, slice, range, integer list/Tensor Dataset indexing is all supported.
* Designating multi-element labels is also supported.
Examples
--------
>>> from torch.utils.data import DataLoader
>>> df = pl.DataFrame(
... data=[
... (0, 1, 1.5),
... (1, 0, -0.5),
... (2, 0, 0.0),
... (3, 1, -2.25),
... ],
... schema=["lbl", "feat1", "feat2"],
... )
Create a Dataset from a Polars DataFrame, standardising the dtype and
distinguishing the label/feature columns.
>>> ds = df.to_torch("dataset", label="lbl", dtype=pl.Float32)
>>> ds # doctest: +IGNORE_RESULT
<PolarsDataset [len:4, features:2, labels:1] at 0x156B033B0>
>>> ds.features
tensor([[ 1.0000, 1.5000],
[ 0.0000, -0.5000],
[ 0.0000, 0.0000],
[ 1.0000, -2.2500]])
>>> ds[0]
(tensor([1.0000, 1.5000]), tensor(0.))
The Dataset can be used standalone, or in conjunction with a DataLoader.
>>> dl = DataLoader(ds, batch_size=2)
>>> list(dl)
[[tensor([[ 1.0000, 1.5000],
[ 0.0000, -0.5000]]),
tensor([0., 1.])],
[tensor([[ 0.0000, 0.0000],
[ 1.0000, -2.2500]]),
tensor([2., 3.])]]
Note that the label can be given as an expression as well as a column name,
allowing for independent transform and dtype adjustment from the feature
columns.
>>> ds = df.to_torch(
... "dataset",
... dtype=pl.Float32,
... label=(pl.col("lbl") * 8).cast(pl.Int16),
... )
>>> ds[:2]
(tensor([[ 1.0000, 1.5000],
[ 0.0000, -0.5000]]), tensor([0, 8], dtype=torch.int16))
"""
if isinstance(label, (str, Expr)):
label = [label]

Expand Down
1 change: 1 addition & 0 deletions py-polars/polars/sql/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ def execute(
... ("The Shawshank Redemption", 1994, 25_000_000, 28_341_469, 9.3),
... ],
... schema=["title", "release_year", "budget", "gross", "imdb_score"],
... orient="row",
... )
>>> ctx = pl.SQLContext(films=df)
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/docs/run_doctest.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def __iter__(self) -> Iterator[Any]:
IGNORE_RESULT = doctest.register_optionflag("IGNORE_RESULT")

# Set doctests to fail on warnings
warnings.simplefilter("error", DeprecationWarning)
warnings.simplefilter("error", Warning)
warnings.filterwarnings(
"ignore",
message="datetime.datetime.utcfromtimestamp\\(\\) is deprecated.*",
Expand Down

0 comments on commit 308df5d

Please sign in to comment.