From 805183bd9ae32b628e641d9552f1d0dded4a7a85 Mon Sep 17 00:00:00 2001 From: Daniel Mesejo Date: Thu, 22 Aug 2024 03:53:30 +0200 Subject: [PATCH] feat: enable list of paths for read_csv (#824) --- python/datafusion/context.py | 7 +++++-- python/datafusion/tests/test_context.py | 16 ++++++++++++++++ src/context.rs | 15 +++++++-------- 3 files changed, 28 insertions(+), 10 deletions(-) diff --git a/python/datafusion/context.py b/python/datafusion/context.py index 922cc87a..47f2b9cf 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -883,7 +883,7 @@ def read_json( def read_csv( self, - path: str | pathlib.Path, + path: str | pathlib.Path | list[str] | list[pathlib.Path], schema: pyarrow.Schema | None = None, has_header: bool = True, delimiter: str = ",", @@ -914,9 +914,12 @@ def read_csv( """ if table_partition_cols is None: table_partition_cols = [] + + path = [str(p) for p in path] if isinstance(path, list) else str(path) + return DataFrame( self.ctx.read_csv( - str(path), + path, schema, has_header, delimiter, diff --git a/python/datafusion/tests/test_context.py b/python/datafusion/tests/test_context.py index 8373659b..1b424db8 100644 --- a/python/datafusion/tests/test_context.py +++ b/python/datafusion/tests/test_context.py @@ -484,6 +484,22 @@ def test_read_csv(ctx): csv_df.select(column("c1")).show() +def test_read_csv_list(ctx): + csv_df = ctx.read_csv(path=["testing/data/csv/aggregate_test_100.csv"]) + expected = csv_df.count() * 2 + + double_csv_df = ctx.read_csv( + path=[ + "testing/data/csv/aggregate_test_100.csv", + "testing/data/csv/aggregate_test_100.csv", + ] + ) + actual = double_csv_df.count() + + double_csv_df.select(column("c1")).show() + assert actual == expected + + def test_read_csv_compressed(ctx, tmp_path): test_data_path = "testing/data/csv/aggregate_test_100.csv" diff --git a/src/context.rs b/src/context.rs index 50c4a199..d7890e3f 100644 --- a/src/context.rs +++ b/src/context.rs @@ -805,7 +805,7 @@ impl PySessionContext { file_compression_type=None))] pub fn read_csv( &self, - path: PathBuf, + path: &Bound<'_, PyAny>, schema: Option>, has_header: bool, delimiter: &str, @@ -815,10 +815,6 @@ impl PySessionContext { file_compression_type: Option, py: Python, ) -> PyResult { - let path = path - .to_str() - .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?; - let delimiter = delimiter.as_bytes(); if delimiter.len() != 1 { return Err(PyValueError::new_err( @@ -833,13 +829,16 @@ impl PySessionContext { .file_extension(file_extension) .table_partition_cols(convert_table_partition_cols(table_partition_cols)?) .file_compression_type(parse_file_compression_type(file_compression_type)?); + options.schema = schema.as_ref().map(|x| &x.0); - if let Some(py_schema) = schema { - options.schema = Some(&py_schema.0); - let result = self.ctx.read_csv(path, options); + if path.is_instance_of::() { + let paths = path.extract::>()?; + let paths = paths.iter().map(|p| p as &str).collect::>(); + let result = self.ctx.read_csv(paths, options); let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?); Ok(df) } else { + let path = path.extract::()?; let result = self.ctx.read_csv(path, options); let df = PyDataFrame::new(wait_for_future(py, result).map_err(DataFusionError::from)?); Ok(df)