From c6d6d7378d433e63c77badf460ae1710051f0093 Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Mon, 15 Jul 2024 15:52:53 +1000 Subject: [PATCH] fix: Expand brackets in async glob expansion (#17630) --- crates/polars-io/src/cloud/glob.rs | 5 ++--- crates/polars-io/src/utils/path.rs | 2 +- py-polars/tests/unit/io/test_scan.py | 10 ++++++++++ 3 files changed, 13 insertions(+), 4 deletions(-) diff --git a/crates/polars-io/src/cloud/glob.rs b/crates/polars-io/src/cloud/glob.rs index dc202c5459d3..27efcaf01b67 100644 --- a/crates/polars-io/src/cloud/glob.rs +++ b/crates/polars-io/src/cloud/glob.rs @@ -19,8 +19,7 @@ fn extract_prefix_expansion(url: &str) -> PolarsResult<(String, Option)> let mut expansion = String::new(); let mut last_split_was_wildcard = false; for split in splits { - let has_star = split.contains('*'); - if expansion.is_empty() && !has_star { + if expansion.is_empty() && memchr::memchr2(b'*', b'[', split.as_bytes()).is_none() { // We are still gathering splits in the prefix. if !prefix.is_empty() { prefix.push(DELIMITER); @@ -44,7 +43,7 @@ fn extract_prefix_expansion(url: &str) -> PolarsResult<(String, Option)> expansion.push(DELIMITER); } // Handle '.' inside a split. - if split.contains('.') || split.contains('*') { + if memchr::memchr2(b'.', b'*', split.as_bytes()).is_some() { let processed = split.replace('.', "\\."); expansion.push_str(&processed.replace('*', "([^/]*)")); continue; diff --git a/crates/polars-io/src/utils/path.rs b/crates/polars-io/src/utils/path.rs index 27fa46d6b4d1..5c2dee8bd271 100644 --- a/crates/polars-io/src/utils/path.rs +++ b/crates/polars-io/src/utils/path.rs @@ -315,7 +315,7 @@ pub fn expand_paths_hive( if path.extension() != ext { polars_bail!( InvalidOperation: r#"directory contained paths with different file extensions: \ - first path: {}, second path: {}. Please use a glob pattern to explicitly specify + first path: {}, second path: {}. Please use a glob pattern to explicitly specify \ which files to read (e.g. "dir/**/*", "dir/**/*.parquet")"#, out_paths[i - 1].to_str().unwrap(), path.to_str().unwrap() ); diff --git a/py-polars/tests/unit/io/test_scan.py b/py-polars/tests/unit/io/test_scan.py index 3a1e05d604c6..2cc9cbb025b1 100644 --- a/py-polars/tests/unit/io/test_scan.py +++ b/py-polars/tests/unit/io/test_scan.py @@ -655,3 +655,13 @@ def test_scan_include_file_name( # Test codepaths that materialize empty DataFrames assert_frame_equal(lf.head(0).collect(streaming=streaming), df.head(0)) + + +@pytest.mark.write_disk() +def test_async_path_expansion_bracket_17629(tmp_path: Path) -> None: + path = tmp_path / "data.parquet" + + df = pl.DataFrame({"x": 1}) + df.write_parquet(path) + + assert_frame_equal(pl.scan_parquet(tmp_path / "[d]ata.parquet").collect(), df)