Skip to content

Commit

Permalink
feat!: Do not parse hive partitions from user provided directory/glob path (#17055)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Jun 19, 2024
1 parent 14dd2ca commit 4d15fd4
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 13 deletions.
6 changes: 2 additions & 4 deletions crates/polars-lazy/src/scan/file_list_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ fn expand_paths(
} else if !path.ends_with("/")
&& is_file_cloud(path.to_str().unwrap(), cloud_options)?
{
expand_start_idx = 0;
out_paths.push(path.clone());
continue;
} else if !glob {
Expand Down Expand Up @@ -120,15 +121,12 @@ fn expand_paths(
out_paths.push(path.map_err(to_compute_err)?);
}
} else {
expand_start_idx = 0;
out_paths.push(path.clone());
}
}
}

// Todo:
// This maintains existing behavior - will remove very soon.
expand_start_idx = 0;

Ok((
out_paths.into_iter().collect::<Arc<[_]>>(),
expand_start_idx,
Expand Down
19 changes: 12 additions & 7 deletions crates/polars-plan/src/plans/hive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,18 @@ impl HivePartitions {
}

let schema = match schema {
Some(s) => {
polars_ensure!(
s.len() == partitions.len(),
SchemaMismatch: "path does not match the provided Hive schema"
);
s
},
Some(schema) => Arc::new(
partitions
.iter()
.map(|s| {
let mut field = s.field().into_owned();
if let Some(dtype) = schema.get(field.name()) {
field.dtype = dtype.clone();
};
field
})
.collect::<Schema>(),
),
None => Arc::new(partitions.as_slice().into()),
};

Expand Down
83 changes: 81 additions & 2 deletions py-polars/tests/unit/io/test_hive.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import warnings
from collections import OrderedDict
from functools import partial
from pathlib import Path
from typing import Any
from typing import Any, Callable

import pyarrow.parquet as pq
import pytest
Expand Down Expand Up @@ -187,7 +188,7 @@ def test_hive_partitioned_err(io_files_path: Path, tmp_path: Path) -> None:
df.write_parquet(root / "file.parquet")

with pytest.raises(DuplicateError, match="invalid Hive partition schema"):
pl.scan_parquet(root / "**/*.parquet", hive_partitioning=True).collect()
pl.scan_parquet(tmp_path, hive_partitioning=True).collect()


@pytest.mark.write_disk()
Expand Down Expand Up @@ -273,3 +274,81 @@ def test_read_parquet_hive_schema_with_pyarrow() -> None:
match="cannot use `hive_partitions` with `use_pyarrow=True`",
):
pl.read_parquet("test.parquet", hive_schema={"c": pl.Int32}, use_pyarrow=True)


@pytest.mark.parametrize(
    ("scan_func", "write_func"),
    [
        (pl.scan_parquet, pl.DataFrame.write_parquet),
    ],
)
@pytest.mark.parametrize(
    ("append_glob", "glob"),
    [
        ("**/*.bin", True),
        ("", True),
        ("", False),
    ],
)
def test_hive_partition_directory_scan(
    tmp_path: Path,
    append_glob: str,
    scan_func: Callable[[Any], pl.LazyFrame],
    write_func: Callable[[pl.DataFrame, Path], None],
    glob: bool,
) -> None:
    """Scanning a directory (or glob under it) picks up hive parts from the
    *relative* path only: partition keys in the user-provided prefix
    (e.g. ``tmp_path / "a=1"``) must NOT be parsed as hive columns.
    """
    tmp_path.mkdir(exist_ok=True)

    # Four partitions over (a, b); "x" is the only column stored in the files.
    dfs = [
        pl.DataFrame({'x': 5 * [1], 'a': 1, 'b': 1}),
        pl.DataFrame({'x': 5 * [2], 'a': 1, 'b': 2}),
        pl.DataFrame({'x': 5 * [3], 'a': 2, 'b': 1}),
        pl.DataFrame({'x': 5 * [4], 'a': 2, 'b': 2}),
    ]  # fmt: skip

    # Lay out a hive tree: a=<a>/b=<b>/data.bin, with a/b dropped from the data.
    for df in dfs:
        a = df.item(0, "a")
        b = df.item(0, "b")
        path = tmp_path / f"a={a}/b={b}/data.bin"
        path.parent.mkdir(exist_ok=True, parents=True)
        write_func(df.drop("a", "b"), path)

    df = pl.concat(dfs)
    hive_schema = df.lazy().select("a", "b").collect_schema()

    # Bind the hive schema and glob setting once; the calls below only vary
    # the path and the hive_partitioning flag.
    scan = partial(scan_func, hive_schema=hive_schema, glob=glob)

    # Scanning the root directory yields both hive columns.
    out = scan(
        tmp_path / append_glob,
        hive_partitioning=True,
    ).collect()
    assert_frame_equal(out, df)

    # With hive partitioning disabled, only the file columns come back.
    out = scan(tmp_path / append_glob, hive_partitioning=False).collect()
    assert_frame_equal(out, df.drop("a", "b"))

    # "a=1" is part of the user-provided prefix, so only "b" is parsed as a
    # hive column — "a" must not appear in the output.
    out = scan(
        tmp_path / "a=1" / append_glob,
        hive_partitioning=True,
    ).collect()
    assert_frame_equal(out, df.filter(a=1).drop("a"))

    out = scan(
        tmp_path / "a=1" / append_glob,
        hive_partitioning=False,
    ).collect()
    assert_frame_equal(out, df.filter(a=1).drop("a", "b"))

    # Scanning a single file directly: no directory expansion happens, so no
    # hive columns are parsed from the path prefix unless requested.
    path = tmp_path / "a=1/b=1/data.bin"

    out = scan(path, hive_partitioning=True).collect()
    assert_frame_equal(out, dfs[0])

    out = scan(path, hive_partitioning=False).collect()
    assert_frame_equal(out, dfs[0].drop("a", "b"))

0 comments on commit 4d15fd4

Please sign in to comment.