diff --git a/crates/polars-io/src/partition.rs b/crates/polars-io/src/partition.rs index cdc1768226ea..de821efb0d60 100644 --- a/crates/polars-io/src/partition.rs +++ b/crates/polars-io/src/partition.rs @@ -129,6 +129,7 @@ where path } +/// Write a partitioned parquet dataset. This functionality is unstable. pub fn write_partitioned_dataset( df: &DataFrame, path: &Path, @@ -139,100 +140,150 @@ pub fn write_partitioned_dataset( where S: AsRef, { - let base_path = path; + // Note: When adding support for formats other than Parquet, avoid writing the partitioned + // columns into the file. We write them for parquet because they are encoded efficiently with + // RLE and also gives us a way to get the hive schema from the parquet file for free. + let get_hive_path_part = { + let schema = &df.schema(); - for (path_part, part_df) in get_hive_partitions_iter(df, partition_by)? { - let dir = base_path.join(path_part); - std::fs::create_dir_all(&dir)?; + let partition_by_col_idx = partition_by + .iter() + .map(|x| { + let Some(i) = schema.index_of(x.as_ref()) else { + polars_bail!(ColumnNotFound: "{}", x.as_ref()) + }; + Ok(i) + }) + .collect::>>()?; - let n_files = (part_df.estimated_size() / chunk_size).clamp(1, 0xf_ffff_ffff_ffff); - let rows_per_file = (df.height() / n_files).saturating_add(1); + const CHAR_SET: &percent_encoding::AsciiSet = &percent_encoding::CONTROLS + .add(b'/') + .add(b'=') + .add(b':') + .add(b' '); - fn get_path_for_index(i: usize) -> String { - // Use a fixed-width file name so that it sorts properly. - format!("{:013x}.parquet", i) + move |df: &DataFrame| { + let cols = df.get_columns(); + + partition_by_col_idx + .iter() + .map(|&i| { + let s = &cols[i].slice(0, 1).cast(&DataType::String).unwrap(); + + format!( + "{}={}", + s.name(), + percent_encoding::percent_encode( + s.str() + .unwrap() + .get(0) + .unwrap_or("__HIVE_DEFAULT_PARTITION__") + .as_bytes(), + CHAR_SET + ) + ) + }) + .collect::>() + .join("/") } + }; - for (i, slice_start) in (0..part_df.height()).step_by(rows_per_file).enumerate() { - let f = std::fs::File::create(dir.join(get_path_for_index(i)))?; + let base_path = path; + let groups = df.group_by(partition_by)?.take_groups(); - file_write_options - .to_writer(f) - .finish(&mut part_df.slice(slice_start as i64, rows_per_file))?; - } - } + let init_part_base_dir = |part_df: &DataFrame| { + let path_part = get_hive_path_part(part_df); + let dir = base_path.join(path_part); + std::fs::create_dir_all(&dir)?; - Ok(()) -} + PolarsResult::Ok(dir) + }; -/// Creates an iterator of (hive partition path, DataFrame) pairs, e.g.: -/// ("a=1/b=1", DataFrame) -fn get_hive_partitions_iter<'a, S>( - df: &'a DataFrame, - partition_by: &'a [S], -) -> PolarsResult + 'a>> -where - S: AsRef, -{ - let schema = df.schema(); - - let partition_by_col_idx = partition_by - .iter() - .map(|x| { - let Some(i) = schema.index_of(x.as_ref()) else { - polars_bail!(ColumnNotFound: "{}", x.as_ref()) - }; - Ok(i) - }) - .collect::>>()?; - - let get_hive_path_part = move |df: &DataFrame| { - const CHAR_SET: &percent_encoding::AsciiSet = &percent_encoding::CONTROLS - .add(b'/') - .add(b'=') - .add(b':') - .add(b' '); + fn get_path_for_index(i: usize) -> String { + // Use a fixed-width file name so that it sorts properly. + format!("{:08x}.parquet", i) + } - let cols = df.get_columns(); + let get_n_files_and_rows_per_file = |part_df: &DataFrame| { + let n_files = (part_df.estimated_size() / chunk_size).clamp(1, 0xffff_ffff); + let rows_per_file = (df.height() / n_files).saturating_add(1); + (n_files, rows_per_file) + }; - partition_by_col_idx - .iter() - .map(|&i| { - let s = &cols[i].slice(0, 1).cast(&DataType::String).unwrap(); - - format!( - "{}={}", - s.name(), - percent_encoding::percent_encode( - s.str() - .unwrap() - .get(0) - .unwrap_or("__HIVE_DEFAULT_PARTITION__") - .as_bytes(), - CHAR_SET - ) - ) - }) - .collect::>() - .join("/") + let write_part = |mut df: DataFrame, path: &Path| { + let f = std::fs::File::create(path)?; + file_write_options.to_writer(f).finish(&mut df)?; + PolarsResult::Ok(()) }; - let groups = df.group_by(partition_by)?; - let groups = groups.take_groups(); - - let out: Box> = match groups { - GroupsProxy::Idx(idx) => Box::new(idx.into_iter().map(move |(_, group)| { - let part_df = - unsafe { df._take_unchecked_slice_sorted(&group, false, IsSorted::Ascending) }; - (get_hive_path_part(&part_df), part_df) - })), - GroupsProxy::Slice { groups, .. } => { - Box::new(groups.into_iter().map(move |[offset, len]| { - let part_df = df.slice(offset as i64, len as usize); - (get_hive_path_part(&part_df), part_df) - })) - }, + // This is sqrt(N) of the actual limit - we chunk the input both at the groups + // proxy level and within every group. + const MAX_OPEN_FILES: usize = 8; + + let finish_part_df = |df: DataFrame| { + let dir_path = init_part_base_dir(&df)?; + let (n_files, rows_per_file) = get_n_files_and_rows_per_file(&df); + + if n_files == 1 { + write_part(df.clone(), &dir_path.join(get_path_for_index(0))) + } else { + (0..df.height()) + .step_by(rows_per_file) + .enumerate() + .collect::>() + .chunks(MAX_OPEN_FILES) + .map(|chunk| { + chunk + .into_par_iter() + .map(|&(idx, slice_start)| { + let df = df.slice(slice_start as i64, rows_per_file); + write_part(df.clone(), &dir_path.join(get_path_for_index(idx))) + }) + .reduce( + || PolarsResult::Ok(()), + |a, b| if a.is_err() { a } else { b }, + ) + }) + .collect::>>()?; + Ok(()) + } }; - Ok(out) + POOL.install(|| match groups { + GroupsProxy::Idx(idx) => idx + .all() + .chunks(MAX_OPEN_FILES) + .map(|chunk| { + chunk + .par_iter() + .map(|group| { + let df = unsafe { + df._take_unchecked_slice_sorted(group, false, IsSorted::Ascending) + }; + finish_part_df(df) + }) + .reduce( + || PolarsResult::Ok(()), + |a, b| if a.is_err() { a } else { b }, + ) + }) + .collect::>>(), + GroupsProxy::Slice { groups, .. } => groups + .chunks(MAX_OPEN_FILES) + .map(|chunk| { + chunk + .into_par_iter() + .map(|&[offset, len]| { + let df = df.slice(offset as i64, len as usize); + finish_part_df(df) + }) + .reduce( + || PolarsResult::Ok(()), + |a, b| if a.is_err() { a } else { b }, + ) + }) + .collect::>>(), + })?; + + Ok(()) } diff --git a/crates/polars-io/src/prelude.rs b/crates/polars-io/src/prelude.rs index 143bc912163d..2da80949cb62 100644 --- a/crates/polars-io/src/prelude.rs +++ b/crates/polars-io/src/prelude.rs @@ -9,5 +9,7 @@ pub use crate::json::*; pub use crate::ndjson::core::*; #[cfg(feature = "parquet")] pub use crate::parquet::{metadata::*, read::*, write::*}; +#[cfg(feature = "parquet")] +pub use crate::partition::write_partitioned_dataset; pub use crate::shared::{SerReader, SerWriter}; pub use crate::utils::*; diff --git a/py-polars/tests/unit/io/test_hive.py b/py-polars/tests/unit/io/test_hive.py index 3a7d3a477592..f3bb97a7fad7 100644 --- a/py-polars/tests/unit/io/test_hive.py +++ b/py-polars/tests/unit/io/test_hive.py @@ -681,7 +681,7 @@ def test_hive_write(tmp_path: Path, df: pl.DataFrame) -> None: @pytest.mark.slow() @pytest.mark.write_disk() -def test_hive_write_multiple_files(tmp_path: Path, monkeypatch: Any) -> None: +def test_hive_write_multiple_files(tmp_path: Path) -> None: chunk_size = 262_144 n_rows = 100_000 df = pl.select(a=pl.repeat(0, n_rows), b=pl.int_range(0, n_rows))