diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index f9e02db248ce..04a0311674bf 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -317,10 +317,10 @@ class DataFrame: │ 2 ┆ 4 │ └─────┴─────┘ - Constructing a DataFrame from a list of lists, row orientation inferred: + Constructing a DataFrame from a list of lists, row orientation specified: >>> data = [[1, 2, 3], [4, 5, 6]] - >>> df6 = pl.DataFrame(data, schema=["a", "b", "c"]) + >>> df6 = pl.DataFrame(data, schema=["a", "b", "c"], orient="row") >>> df6 shape: (2, 3) ┌─────┬─────┬─────┐ @@ -9250,7 +9250,9 @@ def n_unique(self, subset: str | Expr | Sequence[str | Expr] | None = None) -> i If instead you want to count the number of unique values per-column, you can also use expression-level syntax to return a new frame containing that result: - >>> df = pl.DataFrame([[1, 2, 3], [1, 2, 4]], schema=["a", "b", "c"]) + >>> df = pl.DataFrame( + ... [[1, 2, 3], [1, 2, 4]], schema=["a", "b", "c"], orient="row" + ... ) >>> df_nunique = df.select(pl.all().n_unique()) In aggregate context there is also an equivalent method for returning the diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index 6fb3fb41e199..989909a17800 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -267,10 +267,10 @@ class LazyFrame: │ 2 ┆ 4 │ └─────┴─────┘ - Constructing a LazyFrame from a list of lists, row orientation inferred: + Constructing a LazyFrame from a list of lists, row orientation specified: >>> data = [[1, 2, 3], [4, 5, 6]] - >>> lf6 = pl.LazyFrame(data, schema=["a", "b", "c"]) + >>> lf6 = pl.LazyFrame(data, schema=["a", "b", "c"], orient="row") >>> lf6.collect() shape: (2, 3) ┌─────┬─────┬─────┐ @@ -420,7 +420,7 @@ def columns(self) -> list[str]: ... "ham": ["a", "b", "c"], ... } ... ).select("foo", "bar") - >>> lf.columns + >>> lf.columns # doctest: +SKIP ['foo', 'bar'] """ issue_warning( @@ -462,7 +462,7 @@ def dtypes(self) -> list[DataType]: ... "ham": ["a", "b", "c"], ... } ... ) - >>> lf.dtypes + >>> lf.dtypes # doctest: +SKIP [Int64, Float64, String] """ issue_warning( @@ -498,7 +498,7 @@ def schema(self) -> Schema: ... "ham": ["a", "b", "c"], ... } ... ) - >>> lf.schema + >>> lf.schema # doctest: +SKIP Schema({'foo': Int64, 'bar': Float64, 'ham': String}) """ issue_warning( @@ -537,7 +537,7 @@ def width(self) -> int: ... "bar": [4, 5, 6], ... } ... ) - >>> lf.width + >>> lf.width # doctest: +SKIP 2 """ issue_warning( diff --git a/py-polars/polars/ml/torch.py b/py-polars/polars/ml/torch.py index d38340f8d130..034d95b25631 100644 --- a/py-polars/polars/ml/torch.py +++ b/py-polars/polars/ml/torch.py @@ -28,7 +28,80 @@ class PolarsDataset(TensorDataset): # type: ignore[misc] - """Specialized TensorDataset for Polars DataFrames.""" + """ + TensorDataset class specialized for use with Polars DataFrames. + + Parameters + ---------- + frame + Polars DataFrame containing the data that will be retrieved as Tensors. + label + One or more column names or expressions that label the feature data; results + in `(features,label)` tuples, where all non-label columns are considered + to be features. If no label is designated then each returned item is a + simple `(features,)` tuple containing all row elements. + features + One or more column names or expressions that represent the feature data. + If not provided, all columns not designated as labels are considered to be + features. + + Notes + ----- + * Integer, slice, range, integer list/Tensor Dataset indexing is all supported. + * Designating multi-element labels is also supported. + + Examples + -------- + >>> from torch.utils.data import DataLoader + >>> df = pl.DataFrame( + ... data=[ + ... (0, 1, 1.5), + ... (1, 0, -0.5), + ... (2, 0, 0.0), + ... (3, 1, -2.25), + ... ], + ... schema=["lbl", "feat1", "feat2"], + ... orient="row", + ... ) + + Create a Dataset from a Polars DataFrame, standardising the dtype and + distinguishing the label/feature columns. + + >>> ds = df.to_torch("dataset", label="lbl", dtype=pl.Float32) + >>> ds # doctest: +IGNORE_RESULT + + >>> ds.features + tensor([[ 1.0000, 1.5000], + [ 0.0000, -0.5000], + [ 0.0000, 0.0000], + [ 1.0000, -2.2500]]) + >>> ds[0] + (tensor([1.0000, 1.5000]), tensor(0.)) + + The Dataset can be used standalone, or in conjunction with a DataLoader. + + >>> dl = DataLoader(ds, batch_size=2) + >>> list(dl) + [[tensor([[ 1.0000, 1.5000], + [ 0.0000, -0.5000]]), + tensor([0., 1.])], + [tensor([[ 0.0000, 0.0000], + [ 1.0000, -2.2500]]), + tensor([2., 3.])]] + + Note that the label can be given as an expression as well as a column name, + allowing for independent transform and dtype adjustment from the feature + columns. + + >>> ds = df.to_torch( + ... "dataset", + ... dtype=pl.Float32, + ... label=(pl.col("lbl") * 8).cast(pl.Int16), + ... ) + >>> ds[:2] + (tensor([[ 1.0000, 1.5000], + [ 0.0000, -0.5000]]), tensor([0, 8], dtype=torch.int16)) + """ tensors: tuple[Tensor, ...] labels: Tensor | None @@ -41,79 +114,6 @@ def __init__( label: str | Expr | Sequence[str | Expr] | None = None, features: str | Expr | Sequence[str | Expr] | None = None, ): - """ - TensorDataset class specialized for use with Polars DataFrames. - - Parameters - ---------- - frame - Polars DataFrame containing the data that will be retrieved as Tensors. - label - One or more column names or expressions that label the feature data; results - in `(features,label)` tuples, where all non-label columns are considered - to be features. If no label is designated then each returned item is a - simple `(features,)` tuple containing all row elements. - features - One or more column names or expressions that represent the feature data. - If not provided, all columns not designated as labels are considered to be - features. - - Notes - ----- - * Integer, slice, range, integer list/Tensor Dataset indexing is all supported. - * Designating multi-element labels is also supported. - - Examples - -------- - >>> from torch.utils.data import DataLoader - >>> df = pl.DataFrame( - ... data=[ - ... (0, 1, 1.5), - ... (1, 0, -0.5), - ... (2, 0, 0.0), - ... (3, 1, -2.25), - ... ], - ... schema=["lbl", "feat1", "feat2"], - ... ) - - Create a Dataset from a Polars DataFrame, standardising the dtype and - distinguishing the label/feature columns. - - >>> ds = df.to_torch("dataset", label="lbl", dtype=pl.Float32) - >>> ds # doctest: +IGNORE_RESULT - - >>> ds.features - tensor([[ 1.0000, 1.5000], - [ 0.0000, -0.5000], - [ 0.0000, 0.0000], - [ 1.0000, -2.2500]]) - >>> ds[0] - (tensor([1.0000, 1.5000]), tensor(0.)) - - The Dataset can be used standalone, or in conjunction with a DataLoader. - - >>> dl = DataLoader(ds, batch_size=2) - >>> list(dl) - [[tensor([[ 1.0000, 1.5000], - [ 0.0000, -0.5000]]), - tensor([0., 1.])], - [tensor([[ 0.0000, 0.0000], - [ 1.0000, -2.2500]]), - tensor([2., 3.])]] - - Note that the label can be given as an expression as well as a column name, - allowing for independent transform and dtype adjustment from the feature - columns. - - >>> ds = df.to_torch( - ... "dataset", - ... dtype=pl.Float32, - ... label=(pl.col("lbl") * 8).cast(pl.Int16), - ... ) - >>> ds[:2] - (tensor([[ 1.0000, 1.5000], - [ 0.0000, -0.5000]]), tensor([0, 8], dtype=torch.int16)) - """ if isinstance(label, (str, Expr)): label = [label] diff --git a/py-polars/polars/sql/context.py b/py-polars/polars/sql/context.py index 99dbd0a58cca..7b45a5899429 100644 --- a/py-polars/polars/sql/context.py +++ b/py-polars/polars/sql/context.py @@ -381,6 +381,7 @@ def execute( ... ("The Shawshank Redemption", 1994, 25_000_000, 28_341_469, 9.3), ... ], ... schema=["title", "release_year", "budget", "gross", "imdb_score"], + ... orient="row", ... ) >>> ctx = pl.SQLContext(films=df) diff --git a/py-polars/tests/docs/run_doctest.py b/py-polars/tests/docs/run_doctest.py index 39b95a548ddc..09ad6b41c14d 100644 --- a/py-polars/tests/docs/run_doctest.py +++ b/py-polars/tests/docs/run_doctest.py @@ -114,7 +114,7 @@ def __iter__(self) -> Iterator[Any]: IGNORE_RESULT = doctest.register_optionflag("IGNORE_RESULT") # Set doctests to fail on warnings - warnings.simplefilter("error", DeprecationWarning) + warnings.simplefilter("error", Warning) warnings.filterwarnings( "ignore", message="datetime.datetime.utcfromtimestamp\\(\\) is deprecated.*",