feat(python): Add infer_schema parameter to read_csv / scan_csv #17617

Merged
30 changes: 24 additions & 6 deletions py-polars/polars/io/csv/functions.py
@@ -53,6 +53,7 @@ def read_csv(
ignore_errors: bool = False,
try_parse_dates: bool = False,
n_threads: int | None = None,
infer_schema: bool = True,
infer_schema_length: int | None = N_INFER_DEFAULT,
batch_size: int = 8192,
n_rows: int | None = None,
@@ -126,7 +127,7 @@ def read_csv(
Before using this option, try to increase the number of lines used for schema
inference with e.g `infer_schema_length=10000` or override automatic dtype
inference for specific columns with the `schema_overrides` option or use
`infer_schema_length=0` to read all columns as `pl.String` to check which
`infer_schema=False` to read all columns as `pl.String` to check which
values might cause an issue.
try_parse_dates
Try to automatically parse dates. Most ISO8601-like formats can
@@ -136,10 +137,15 @@
n_threads
Number of threads to use in csv parsing.
Defaults to the number of physical cpu's of your system.
infer_schema
When `True`, the schema is inferred from the data using the first
`infer_schema_length` rows.
When `False`, the schema is not inferred and will be `pl.String` if not
specified in `schema` or `schema_overrides`.
infer_schema_length
The maximum number of rows to scan for schema inference.
If set to `0`, all columns will be read as `pl.String`.
If set to `None`, the full data may be scanned *(this is slow)*.
Set `infer_schema=False` to read all columns as `pl.String`.
batch_size
Number of lines to read into the buffer at once.
Modify this to change performance.
@@ -184,7 +190,7 @@ def read_csv(
with windows line endings (`\r\n`), one can go with the default `\n`. The extra
`\r` will be removed when processed.
raise_if_empty
When there is no data in the source,`NoDataError` is raised. If this parameter
When there is no data in the source, `NoDataError` is raised. If this parameter
is set to False, an empty DataFrame (with no columns) is returned instead.
truncate_ragged_lines
Truncate lines that are longer than the schema.
@@ -410,6 +416,9 @@ def read_csv(
for column_name, column_dtype in schema_overrides.items()
}

if not infer_schema:
infer_schema_length = 0

with prepare_file_arg(
source,
encoding=encoding,
@@ -922,6 +931,7 @@ def scan_csv(
ignore_errors: bool = False,
cache: bool = True,
with_column_names: Callable[[list[str]], list[str]] | None = None,
infer_schema: bool = True,
infer_schema_length: int | None = N_INFER_DEFAULT,
n_rows: int | None = None,
encoding: CsvEncoding = "utf8",
@@ -989,17 +999,22 @@
utf8 values to be treated as the empty string you can set this param True.
ignore_errors
Try to keep reading lines if some lines yield errors.
First try `infer_schema_length=0` to read all columns as
First try `infer_schema=False` to read all columns as
`pl.String` to check which values might cause an issue.
cache
Cache the result after reading.
with_column_names
Apply a function over the column names just in time (when they are determined);
this function will receive (and should return) a list of column names.
infer_schema
When `True`, the schema is inferred from the data using the first
`infer_schema_length` rows.
When `False`, the schema is not inferred and will be `pl.String` if not
specified in `schema` or `schema_overrides`.
infer_schema_length
The maximum number of rows to scan for schema inference.
If set to `0`, all columns will be read as `pl.String`.
If set to `None`, the full data may be scanned *(this is slow)*.
Set `infer_schema=False` to read all columns as `pl.String`.
n_rows
Stop reading from CSV file after reading `n_rows`.
encoding : {'utf8', 'utf8-lossy'}
@@ -1029,7 +1044,7 @@
scanning a headerless CSV file). If the given list is shorter than the width of
the DataFrame the remaining columns will have their original name.
raise_if_empty
When there is no data in the source,`NoDataError` is raised. If this parameter
When there is no data in the source, `NoDataError` is raised. If this parameter
is set to False, an empty LazyFrame (with no columns) is returned instead.
truncate_ragged_lines
Truncate lines that are longer than the schema.
@@ -1153,6 +1168,9 @@ def with_column_names(cols: list[str]) -> list[str]:
normalize_filepath(source, check_not_directory=False) for source in source
]

if not infer_schema:
infer_schema_length = 0

return _scan_csv_impl(
source,
has_header=has_header,
13 changes: 13 additions & 0 deletions py-polars/tests/unit/io/test_csv.py
@@ -113,6 +113,19 @@ def test_normalize_filepath(io_files_path: Path) -> None:
)


def test_infer_schema_false() -> None:
csv = textwrap.dedent(
"""\
a,b,c
1,2,3
1,2,3
"""
)
f = io.StringIO(csv)
df = pl.read_csv(f, infer_schema=False)
assert df.dtypes == [pl.String, pl.String, pl.String]


def test_csv_null_values() -> None:
csv = textwrap.dedent(
"""\