diff --git a/docs/source/user-guide/io/arrow.rst b/docs/source/user-guide/io/arrow.rst
new file mode 100644
index 00000000..d571aa99
--- /dev/null
+++ b/docs/source/user-guide/io/arrow.rst
@@ -0,0 +1,73 @@
+.. Licensed to the Apache Software Foundation (ASF) under one
+.. or more contributor license agreements.  See the NOTICE file
+.. distributed with this work for additional information
+.. regarding copyright ownership.  The ASF licenses this file
+.. to you under the Apache License, Version 2.0 (the
+.. "License"); you may not use this file except in compliance
+.. with the License.  You may obtain a copy of the License at
+
+.. http://www.apache.org/licenses/LICENSE-2.0
+
+.. Unless required by applicable law or agreed to in writing,
+.. software distributed under the License is distributed on an
+.. "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+.. KIND, either express or implied.  See the License for the
+.. specific language governing permissions and limitations
+.. under the License.
+
+Arrow
+=====
+
+DataFusion implements the
+`Apache Arrow PyCapsule interface <https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html>`_
+for importing and exporting DataFrames with zero copy. With this feature, any Python
+project that implements this interface can share data back and forth with DataFusion
+without copying the underlying data.
+
+We can demonstrate this using `pyarrow <https://arrow.apache.org/docs/python/>`_.
+
+Importing to DataFusion
+-----------------------
+
+Here we will create an Arrow table and import it into DataFusion.
+
+To import an Arrow table, use :py:func:`datafusion.context.SessionContext.from_arrow`.
+This accepts any Python object that implements
+`__arrow_c_stream__ <https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html>`_,
+or that implements
+`__arrow_c_array__ <https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html>`_
+and returns a struct array. Common pyarrow sources you can use are:
+
+- `Array <https://arrow.apache.org/docs/python/generated/pyarrow.Array.html>`_ (it must be a struct array)
+- `Record Batch <https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatch.html>`_
+- `Record Batch Reader <https://arrow.apache.org/docs/python/generated/pyarrow.RecordBatchReader.html>`_
+- `Table <https://arrow.apache.org/docs/python/generated/pyarrow.Table.html>`_
+
+.. ipython:: python
+
+    from datafusion import SessionContext, col, lit
+    import pyarrow as pa
+
+    data = {"a": [1, 2, 3], "b": [4, 5, 6]}
+    table = pa.Table.from_pydict(data)
+
+    ctx = SessionContext()
+    df = ctx.from_arrow(table)
+    df
+
+Exporting from DataFusion
+-------------------------
+
+DataFusion DataFrames implement the ``__arrow_c_stream__`` PyCapsule interface, so any
+Python library that accepts Arrow input can import a DataFusion DataFrame directly.
+
+.. warning::
+    Exporting a DataFrame will trigger its execution, which may be a time consuming
+    task. That is, it causes a :py:func:`datafusion.dataframe.DataFrame.collect`
+    operation to occur.
+
+.. ipython:: python
+
+    df = df.select((col("a") * lit(1.5)).alias("c"), lit("df").alias("d"))
+    pa.table(df)
+
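As an illustrative sketch of how the exported stream described above can be consumed incrementally rather than materialized with ``pa.table``, something like the following should work; it assumes a recent pyarrow (14.0 or newer), where ``RecordBatchReader.from_stream`` accepts any object implementing ``__arrow_c_stream__``, and per the warning above the DataFrame is still fully executed when the stream is consumed::

    import pyarrow as pa
    from datafusion import SessionContext, col, lit

    ctx = SessionContext()
    df = ctx.from_arrow(pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}))
    df = df.select((col("a") * lit(1.5)).alias("c"))

    # Consuming the __arrow_c_stream__ capsule triggers execution of the DataFrame.
    reader = pa.RecordBatchReader.from_stream(df)
    for batch in reader:
        print(batch.num_rows)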
diff --git a/docs/source/user-guide/io/index.rst b/docs/source/user-guide/io/index.rst
index af08240f..05411327 100644
--- a/docs/source/user-guide/io/index.rst
+++ b/docs/source/user-guide/io/index.rst
@@ -21,8 +21,8 @@ IO
 .. toctree::
    :maxdepth: 2
 
+   arrow
+   avro
    csv
-   parquet
    json
-   avro
-
+   parquet
diff --git a/examples/import.py b/examples/import.py
index a249a1c4..cd965cb4 100644
--- a/examples/import.py
+++ b/examples/import.py
@@ -54,5 +54,5 @@
 
 # Convert Arrow Table to datafusion DataFrame
 arrow_table = pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
-df = ctx.from_arrow_table(arrow_table)
+df = ctx.from_arrow(arrow_table)
 assert type(df) == datafusion.DataFrame
diff --git a/python/datafusion/context.py b/python/datafusion/context.py
index d4e50cfe..283f71e1 100644
--- a/python/datafusion/context.py
+++ b/python/datafusion/context.py
@@ -586,19 +586,31 @@ def from_pydict(
         """
         return DataFrame(self.ctx.from_pydict(data, name))
 
-    def from_arrow_table(
-        self, data: pyarrow.Table, name: str | None = None
-    ) -> DataFrame:
-        """Create a :py:class:`~datafusion.dataframe.DataFrame` from an Arrow table.
+    def from_arrow(self, data: Any, name: str | None = None) -> DataFrame:
+        """Create a :py:class:`~datafusion.dataframe.DataFrame` from an Arrow source.
+
+        The Arrow data source can be any object that implements either
+        ``__arrow_c_stream__`` or ``__arrow_c_array__``. For the latter, it must return
+        a struct array. Examples include Table, RecordBatch, and RecordBatchReader.
 
         Args:
-            data: Arrow table.
+            data: Arrow data source.
             name: Name of the DataFrame.
 
         Returns:
             DataFrame representation of the Arrow table.
         """
-        return DataFrame(self.ctx.from_arrow_table(data, name))
+        return DataFrame(self.ctx.from_arrow(data, name))
+
+    @deprecated("Use ``from_arrow`` instead.")
+    def from_arrow_table(
+        self, data: pyarrow.Table, name: str | None = None
+    ) -> DataFrame:
+        """Create a :py:class:`~datafusion.dataframe.DataFrame` from an Arrow table.
+
+        This is an alias for :py:func:`from_arrow`.
+        """
+        return self.from_arrow(data, name)
 
     def from_pandas(self, data: pandas.DataFrame, name: str | None = None) -> DataFrame:
         """Create a :py:class:`~datafusion.dataframe.DataFrame` from a Pandas DataFrame.
diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py
index fa739844..4f176013 100644
--- a/python/datafusion/dataframe.py
+++ b/python/datafusion/dataframe.py
@@ -524,3 +524,19 @@ def unnest_columns(self, *columns: str, preserve_nulls: bool = True) -> DataFram
         """
         columns = [c for c in columns]
         return DataFrame(self.df.unnest_columns(columns, preserve_nulls=preserve_nulls))
+
+    def __arrow_c_stream__(self, requested_schema: pa.Schema) -> Any:
+        """Export an Arrow PyCapsule Stream.
+
+        This will execute and collect the DataFrame. We will attempt to respect the
+        requested schema, but only trivial transformations will be applied, such as
+        returning only the fields listed in the requested schema if their data types
+        match those in the DataFrame.
+
+        Args:
+            requested_schema: Attempt to provide the DataFrame using this schema.
+
+        Returns:
+            Arrow PyCapsule object.
+        """
+        return self.df.__arrow_c_stream__(requested_schema)
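To illustrate the claim in the new ``from_arrow`` docstring that any object implementing the Arrow PyCapsule protocol is accepted, not only pyarrow types, here is a small sketch; ``ArrowStreamWrapper`` is a hypothetical class used only for illustration and simply delegates to pyarrow::

    import pyarrow as pa
    from datafusion import SessionContext


    class ArrowStreamWrapper:
        """Hypothetical non-pyarrow source that only exposes __arrow_c_stream__."""

        def __init__(self, table: pa.Table) -> None:
            self._table = table

        def __arrow_c_stream__(self, requested_schema=None):
            # Delegate to pyarrow; any producer of an Arrow stream capsule works.
            return self._table.__arrow_c_stream__(requested_schema)


    ctx = SessionContext()
    df = ctx.from_arrow(ArrowStreamWrapper(pa.table({"a": [1, 2, 3]})))
    assert df.count() == 3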
diff --git a/python/datafusion/tests/test_context.py b/python/datafusion/tests/test_context.py
index 66d7e013..0184280c 100644
--- a/python/datafusion/tests/test_context.py
+++ b/python/datafusion/tests/test_context.py
@@ -156,7 +156,7 @@ def test_from_arrow_table(ctx):
     table = pa.Table.from_pydict(data)
 
     # convert to DataFrame
-    df = ctx.from_arrow_table(table)
+    df = ctx.from_arrow(table)
     tables = list(ctx.catalog().database().names())
 
     assert df
@@ -166,13 +166,42 @@
     assert df.collect()[0].num_rows == 3
 
 
+def record_batch_generator(num_batches: int):
+    schema = pa.schema([("a", pa.int64()), ("b", pa.int64())])
+    for i in range(num_batches):
+        yield pa.RecordBatch.from_arrays(
+            [pa.array([1, 2, 3]), pa.array([4, 5, 6])], schema=schema
+        )
+
+
+@pytest.mark.parametrize(
+    "source",
+    [
+        # __arrow_c_array__ sources
+        pa.array([{"a": 1, "b": 4}, {"a": 2, "b": 5}, {"a": 3, "b": 6}]),
+        # __arrow_c_stream__ sources
+        pa.RecordBatch.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}),
+        pa.RecordBatchReader.from_batches(
+            pa.schema([("a", pa.int64()), ("b", pa.int64())]), record_batch_generator(1)
+        ),
+        pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]}),
+    ],
+)
+def test_from_arrow_sources(ctx, source) -> None:
+    df = ctx.from_arrow(source)
+    assert df
+    assert isinstance(df, DataFrame)
+    assert df.schema().names == ["a", "b"]
+    assert df.count() == 3
+
+
 def test_from_arrow_table_with_name(ctx):
     # create a PyArrow table
     data = {"a": [1, 2, 3], "b": [4, 5, 6]}
     table = pa.Table.from_pydict(data)
 
     # convert to DataFrame with optional name
-    df = ctx.from_arrow_table(table, name="tbl")
+    df = ctx.from_arrow(table, name="tbl")
     tables = list(ctx.catalog().database().names())
 
     assert df
@@ -185,7 +214,7 @@ def test_from_arrow_table_empty(ctx):
     table = pa.Table.from_pydict(data, schema=schema)
 
     # convert to DataFrame
-    df = ctx.from_arrow_table(table)
+    df = ctx.from_arrow(table)
     tables = list(ctx.catalog().database().names())
 
     assert df
@@ -200,7 +229,7 @@ def test_from_arrow_table_empty_no_schema(ctx):
     table = pa.Table.from_pydict(data)
 
     # convert to DataFrame
-    df = ctx.from_arrow_table(table)
+    df = ctx.from_arrow(table)
     tables = list(ctx.catalog().database().names())
 
     assert df
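The new tests exercise ``from_arrow`` but not the deprecated ``from_arrow_table`` alias. A possible follow-up test sketch, assuming the ``@deprecated`` decorator used in ``context.py`` emits a ``DeprecationWarning`` at call time (the behaviour of the common ``deprecated`` and ``typing_extensions`` implementations)::

    import pyarrow as pa
    import pytest


    def test_from_arrow_table_deprecated_alias(ctx) -> None:
        table = pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})

        # The alias should still work while steering users toward from_arrow().
        with pytest.warns(DeprecationWarning):
            df = ctx.from_arrow_table(table)

        assert df.count() == 3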
diff --git a/python/datafusion/tests/test_dataframe.py b/python/datafusion/tests/test_dataframe.py
index e5e0c9c8..477bc0fc 100644
--- a/python/datafusion/tests/test_dataframe.py
+++ b/python/datafusion/tests/test_dataframe.py
@@ -47,7 +47,7 @@ def df():
         names=["a", "b", "c"],
     )
 
-    return ctx.create_dataframe([[batch]])
+    return ctx.from_arrow(batch)
 
 
 @pytest.fixture
@@ -835,13 +835,42 @@ def test_write_compressed_parquet_missing_compression_level(df, tmp_path, compre
     df.write_parquet(str(path), compression=compression)
 
 
-# ctx = SessionContext()
-
-# # create a RecordBatch and a new DataFrame from it
-# batch = pa.RecordBatch.from_arrays(
-#     [pa.array([1, 2, 3]), pa.array([4, 5, 6]), pa.array([8, 5, 8])],
-#     names=["a", "b", "c"],
-# )
-
-# df = ctx.create_dataframe([[batch]])
-# test_execute_stream(df)
+def test_dataframe_export(df) -> None:
+    # Guarantees that the canonical pyarrow implementation can read our
+    # DataFrame export
+    table = pa.table(df)
+    assert table.num_columns == 3
+    assert table.num_rows == 3
+
+    desired_schema = pa.schema([("a", pa.int64())])
+
+    # Verify we can request a schema
+    table = pa.table(df, schema=desired_schema)
+    assert table.num_columns == 1
+    assert table.num_rows == 3
+
+    # Expect a table of nulls if the schemas don't overlap
+    desired_schema = pa.schema([("g", pa.string())])
+    table = pa.table(df, schema=desired_schema)
+    assert table.num_columns == 1
+    assert table.num_rows == 3
+    for i in range(0, 3):
+        assert table[0][i].as_py() is None
+
+    # Expect an error when we cannot convert the schema
+    desired_schema = pa.schema([("a", pa.float32())])
+    failed_convert = False
+    try:
+        table = pa.table(df, schema=desired_schema)
+    except Exception:
+        failed_convert = True
+    assert failed_convert
+
+    # Expect an error when a non-nullable field would be left unset
+    desired_schema = pa.schema([("g", pa.string(), False)])
+    failed_convert = False
+    try:
+        table = pa.table(df, schema=desired_schema)
+    except Exception:
+        failed_convert = True
+    assert failed_convert
diff --git a/src/context.rs b/src/context.rs
index a43599cf..4433d94c 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -20,12 +20,15 @@ use std::path::PathBuf;
 use std::str::FromStr;
 use std::sync::Arc;
 
+use arrow::array::RecordBatchReader;
+use arrow::ffi_stream::ArrowArrayStreamReader;
+use arrow::pyarrow::FromPyArrow;
 use datafusion::execution::session_state::SessionStateBuilder;
 use object_store::ObjectStore;
 use url::Url;
 use uuid::Uuid;
 
-use pyo3::exceptions::{PyKeyError, PyValueError};
+use pyo3::exceptions::{PyKeyError, PyTypeError, PyValueError};
 use pyo3::prelude::*;
 
 use crate::catalog::{PyCatalog, PyTable};
@@ -444,7 +447,7 @@ impl PySessionContext {
         let table = table_class.call_method1("from_pylist", args)?;
 
         // Convert Arrow Table to datafusion DataFrame
-        let df = self.from_arrow_table(table, name, py)?;
+        let df = self.from_arrow(table, name, py)?;
         Ok(df)
     }
 
@@ -463,29 +466,42 @@ impl PySessionContext {
         let table = table_class.call_method1("from_pydict", args)?;
 
         // Convert Arrow Table to datafusion DataFrame
-        let df = self.from_arrow_table(table, name, py)?;
+        let df = self.from_arrow(table, name, py)?;
         Ok(df)
     }
 
     /// Construct datafusion dataframe from Arrow Table
-    pub fn from_arrow_table(
+    pub fn from_arrow(
         &mut self,
         data: Bound<'_, PyAny>,
         name: Option<&str>,
         py: Python,
     ) -> PyResult<PyDataFrame> {
-        // Instantiate pyarrow Table object & convert to batches
-        let table = data.call_method0("to_batches")?;
+        let (schema, batches) =
+            if let Ok(stream_reader) = ArrowArrayStreamReader::from_pyarrow_bound(&data) {
+                // Works for any object that implements __arrow_c_stream__ in pycapsule.
+
+                let schema = stream_reader.schema().as_ref().to_owned();
+                let batches = stream_reader
+                    .collect::<Result<Vec<RecordBatch>, arrow::error::ArrowError>>()
+                    .map_err(DataFusionError::from)?;
+
+                (schema, batches)
+            } else if let Ok(array) = RecordBatch::from_pyarrow_bound(&data) {
+                // While this says RecordBatch, it will work for any object that implements
+                // __arrow_c_array__ and returns a StructArray.
+
+                (array.schema().as_ref().to_owned(), vec![array])
+            } else {
+                return Err(PyTypeError::new_err(
+                    "Expected either an Arrow Array or Arrow Stream in from_arrow().",
+                ));
+            };
 
-        let schema = data.getattr("schema")?;
-        let schema = schema.extract::<PyArrowType<Schema>>()?;
-
-        // Cast PyAny to RecordBatch type
         // Because create_dataframe() expects a vector of vectors of record batches
         // here we need to wrap the vector of record batches in an additional vector
-        let batches = table.extract::<Vec<PyArrowType<RecordBatch>>>()?;
-        let list_of_batches = PyArrowType::from(vec![batches.0]);
-        self.create_dataframe(list_of_batches, name, Some(schema), py)
+        let list_of_batches = PyArrowType::from(vec![batches]);
+        self.create_dataframe(list_of_batches, name, Some(schema.into()), py)
     }
 
     /// Construct datafusion dataframe from pandas
@@ -504,7 +520,7 @@ impl PySessionContext {
         let table = table_class.call_method1("from_pandas", args)?;
 
         // Convert Arrow Table to datafusion DataFrame
-        let df = self.from_arrow_table(table, name, py)?;
+        let df = self.from_arrow(table, name, py)?;
         Ok(df)
     }
 
@@ -518,7 +534,7 @@ impl PySessionContext {
         let table = data.call_method0("to_arrow")?;
 
         // Convert Arrow Table to datafusion DataFrame
-        let df = self.from_arrow_table(table, name, data.py())?;
+        let df = self.from_arrow(table, name, data.py())?;
         Ok(df)
     }
 
diff --git a/src/dataframe.rs b/src/dataframe.rs
index 4db59d4f..22b05226 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -15,8 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.
 
+use std::ffi::CString;
 use std::sync::Arc;
 
+use arrow::array::{new_null_array, RecordBatch, RecordBatchIterator, RecordBatchReader};
+use arrow::compute::can_cast_types;
+use arrow::error::ArrowError;
+use arrow::ffi::FFI_ArrowSchema;
+use arrow::ffi_stream::FFI_ArrowArrayStream;
 use datafusion::arrow::datatypes::Schema;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::arrow::util::pretty;
@@ -29,7 +35,7 @@ use datafusion_common::UnnestOptions;
 use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::pybacked::PyBackedStr;
-use pyo3::types::PyTuple;
+use pyo3::types::{PyCapsule, PyTuple};
 use tokio::task::JoinHandle;
 
 use crate::errors::py_datafusion_err;
@@ -451,6 +457,39 @@ impl PyDataFrame {
         Ok(table)
     }
 
+    fn __arrow_c_stream__<'py>(
+        &'py mut self,
+        py: Python<'py>,
+        requested_schema: Option<Bound<'py, PyCapsule>>,
+    ) -> PyResult<Bound<'py, PyCapsule>> {
+        let mut batches = wait_for_future(py, self.df.as_ref().clone().collect())?;
+        let mut schema: Schema = self.df.schema().to_owned().into();
+
+        if let Some(schema_capsule) = requested_schema {
+            validate_pycapsule(&schema_capsule, "arrow_schema")?;
+
+            let schema_ptr = unsafe { schema_capsule.reference::<FFI_ArrowSchema>() };
+            let desired_schema = Schema::try_from(schema_ptr).map_err(DataFusionError::from)?;
+
+            schema = project_schema(schema, desired_schema).map_err(DataFusionError::ArrowError)?;
+
+            batches = batches
+                .into_iter()
+                .map(|record_batch| record_batch_into_schema(record_batch, &schema))
+                .collect::<Result<Vec<RecordBatch>, ArrowError>>()
+                .map_err(DataFusionError::ArrowError)?;
+        }
+
+        let batches_wrapped = batches.into_iter().map(Ok);
+
+        let reader = RecordBatchIterator::new(batches_wrapped, Arc::new(schema));
+        let reader: Box<dyn RecordBatchReader + Send> = Box::new(reader);
+
+        let ffi_stream = FFI_ArrowArrayStream::new(reader);
+        let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
+        PyCapsule::new_bound(py, ffi_stream, Some(stream_capsule_name))
+    }
+
     fn execute_stream(&self, py: Python) -> PyResult<PyRecordBatchStream> {
         // create a Tokio runtime to run the async code
         let rt = &get_tokio_runtime(py).0;
@@ -539,3 +578,78 @@ fn print_dataframe(py: Python, df: DataFrame) -> PyResult<()> {
     print.call1((result,))?;
     Ok(())
 }
+
+fn project_schema(from_schema: Schema, to_schema: Schema) -> Result<Schema, ArrowError> {
+    let merged_schema = Schema::try_merge(vec![from_schema, to_schema.clone()])?;
+
+    let project_indices: Vec<usize> = to_schema
+        .fields
+        .iter()
+        .map(|field| field.name())
+        .filter_map(|field_name| merged_schema.index_of(field_name).ok())
+        .collect();
+
+    merged_schema.project(&project_indices)
+}
+
+fn record_batch_into_schema(
+    record_batch: RecordBatch,
+    schema: &Schema,
+) -> Result<RecordBatch, ArrowError> {
+    let schema = Arc::new(schema.clone());
+    let base_schema = record_batch.schema();
+    if base_schema.fields().len() == 0 {
+        // Nothing to project
+        return Ok(RecordBatch::new_empty(schema));
+    }
+
+    let array_size = record_batch.column(0).len();
+    let mut data_arrays = Vec::with_capacity(schema.fields().len());
+
+    for field in schema.fields() {
+        let desired_data_type = field.data_type();
+        if let Some(original_data) = record_batch.column_by_name(field.name()) {
+            let original_data_type = original_data.data_type();
+
+            if can_cast_types(original_data_type, desired_data_type) {
+                data_arrays.push(arrow::compute::kernels::cast::cast(
+                    original_data,
+                    desired_data_type,
+                )?);
+            } else if field.is_nullable() {
+                data_arrays.push(new_null_array(desired_data_type, array_size));
+            } else {
+                return Err(ArrowError::CastError(format!("Attempting to cast to non-nullable and non-castable field {} during schema projection.", field.name())));
+            }
+        } else {
+            if !field.is_nullable() {
+                return Err(ArrowError::CastError(format!(
+                    "Attempting to set null to non-nullable field {} during schema projection.",
+                    field.name()
+                )));
+            }
+            data_arrays.push(new_null_array(desired_data_type, array_size));
+        }
+    }
+
+    RecordBatch::try_new(schema, data_arrays)
+}
+
+fn validate_pycapsule(capsule: &Bound<PyCapsule>, name: &str) -> PyResult<()> {
+    let capsule_name = capsule.name()?;
+    if capsule_name.is_none() {
+        return Err(PyValueError::new_err(
+            "Expected schema PyCapsule to have name set.",
+        ));
+    }
+
+    let capsule_name = capsule_name.unwrap().to_str()?;
+    if capsule_name != name {
+        return Err(PyValueError::new_err(format!(
+            "Expected name '{}' in PyCapsule, instead got '{}'",
+            name, capsule_name
+        )));
+    }
+
+    Ok(())
+}
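As a closing illustration of what the ``__arrow_c_stream__`` export enables downstream, a consumer library can accept a DataFusion DataFrame without depending on DataFusion at all by duck-typing on the PyCapsule protocol. The helper below is only a sketch (the function name is illustrative); ``pa.table`` already understands the protocol, as the new ``test_dataframe_export`` test demonstrates::

    from typing import Any

    import pyarrow as pa


    def to_arrow_table(source: Any) -> pa.Table:
        """Accept any Arrow C stream producer, e.g. a datafusion DataFrame."""
        if hasattr(source, "__arrow_c_stream__"):
            # For a DataFusion DataFrame this triggers collect(), per the docs above.
            return pa.table(source)
        raise TypeError("source does not implement __arrow_c_stream__")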