feat(python): Implement Arrow PyCapsule Interface for Series/DataFrame export (#17676)
kylebarron authored Jul 25, 2024
1 parent 7137895 commit 9978d88
Showing 9 changed files with 169 additions and 2 deletions.
2 changes: 2 additions & 0 deletions crates/polars-arrow/src/ffi/stream.rs
@@ -19,6 +19,8 @@ impl Drop for ArrowArrayStream {
}
}

unsafe impl Send for ArrowArrayStream {}

impl ArrowArrayStream {
/// Creates an empty [`ArrowArrayStream`] used to import from a producer.
pub fn empty() -> Self {
8 changes: 8 additions & 0 deletions py-polars/polars/dataframe/frame.py
@@ -1264,6 +1264,14 @@ def __deepcopy__(self, memo: None = None) -> DataFrame:
def _ipython_key_completions_(self) -> list[str]:
return self.columns

def __arrow_c_stream__(self, requested_schema: object) -> object:
"""
Export a DataFrame via the Arrow PyCapsule Interface.
https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html
"""
return self._df.__arrow_c_stream__(requested_schema)

def _repr_html_(self, *, _from_series: bool = False) -> str:
"""
Format output data in HTML for display in Jupyter Notebooks.
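With this dunder in place, any consumer that understands the Arrow PyCapsule protocol can ingest a DataFrame directly, without an explicit to_arrow() call. A minimal sketch, assuming a pyarrow version recent enough (roughly 14.0+) to accept objects exposing __arrow_c_stream__:

    import polars as pl
    import pyarrow as pa

    df = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})
    # pyarrow consumes the stream capsule produced by DataFrame.__arrow_c_stream__
    table = pa.table(df)
    # the data round-trips back into polars unchanged
    assert pl.from_arrow(table).equals(df)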
8 changes: 8 additions & 0 deletions py-polars/polars/series/series.py
@@ -1447,6 +1447,14 @@ def __array_ufunc__(
)
raise NotImplementedError(msg)

def __arrow_c_stream__(self, requested_schema: object) -> object:
"""
Export a Series via the Arrow PyCapsule Interface.
https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html
"""
return self._s.__arrow_c_stream__(requested_schema)

def _repr_html_(self) -> str:
"""Format output data in HTML for display in Jupyter Notebooks."""
return self.to_frame()._repr_html_(_from_series=True)
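The Series export works the same way. A minimal sketch, again assuming a pyarrow version whose constructors accept stream-exporting objects (as the test added in this commit relies on):

    import polars as pl
    import pyarrow as pa

    s = pl.Series("a", [1, 2, 3, None])
    # pa.chunked_array consumes the stream capsule exported by Series.__arrow_c_stream__
    chunked = pa.chunked_array(s)
    assert chunked.combine_chunks() == pa.array([1, 2, 3, None])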
14 changes: 13 additions & 1 deletion py-polars/src/dataframe/export.rs
@@ -2,11 +2,12 @@ use polars::export::arrow::record_batch::RecordBatch;
use polars_core::export::arrow::datatypes::IntegerType;
use polars_core::utils::arrow::compute::cast::CastOptionsImpl;
use pyo3::prelude::*;
use pyo3::types::{PyList, PyTuple};
use pyo3::types::{PyCapsule, PyList, PyTuple};

use super::*;
use crate::conversion::{ObjectValue, Wrap};
use crate::interop;
use crate::interop::arrow::to_py::dataframe_to_stream;
use crate::prelude::PyCompatLevel;

#[pymethods]
@@ -130,4 +131,15 @@ impl PyDataFrame {
Ok(rbs)
})
}

#[allow(unused_variables)]
#[pyo3(signature = (requested_schema=None))]
fn __arrow_c_stream__<'py>(
&'py mut self,
py: Python<'py>,
requested_schema: Option<PyObject>,
) -> PyResult<Bound<'py, PyCapsule>> {
self.df.align_chunks();
dataframe_to_stream(&self.df, py)
}
}
78 changes: 78 additions & 0 deletions py-polars/src/interop/arrow/to_py.rs
@@ -1,9 +1,17 @@
use std::ffi::CString;

use arrow::datatypes::ArrowDataType;
use arrow::ffi;
use arrow::record_batch::RecordBatch;
use polars::datatypes::CompatLevel;
use polars::frame::DataFrame;
use polars::prelude::{ArrayRef, ArrowField};
use polars::series::Series;
use polars_core::utils::arrow;
use polars_error::PolarsResult;
use pyo3::ffi::Py_uintptr_t;
use pyo3::prelude::*;
use pyo3::types::PyCapsule;

/// Arrow array to Python.
pub(crate) fn to_py_array(
@@ -49,3 +57,73 @@ pub(crate) fn to_py_rb(

Ok(record.to_object(py))
}

/// Export a series to a C stream via a PyCapsule according to the Arrow PyCapsule Interface
/// https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html
pub(crate) fn series_to_stream<'py>(
series: &'py Series,
py: Python<'py>,
) -> PyResult<Bound<'py, PyCapsule>> {
let field = series.field().to_arrow(CompatLevel::newest());
let iter = Box::new(series.chunks().clone().into_iter().map(Ok)) as _;
let stream = ffi::export_iterator(iter, field);
let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
PyCapsule::new_bound(py, stream, Some(stream_capsule_name))
}

pub(crate) fn dataframe_to_stream<'py>(
df: &'py DataFrame,
py: Python<'py>,
) -> PyResult<Bound<'py, PyCapsule>> {
let iter = Box::new(DataFrameStreamIterator::new(df));
let field = iter.field();
let stream = ffi::export_iterator(iter, field);
let stream_capsule_name = CString::new("arrow_array_stream").unwrap();
PyCapsule::new_bound(py, stream, Some(stream_capsule_name))
}

pub struct DataFrameStreamIterator {
columns: Vec<Series>,
data_type: ArrowDataType,
idx: usize,
n_chunks: usize,
}

impl DataFrameStreamIterator {
fn new(df: &DataFrame) -> Self {
let schema = df.schema().to_arrow(CompatLevel::newest());
let data_type = ArrowDataType::Struct(schema.fields);

Self {
columns: df.get_columns().to_vec(),
data_type,
idx: 0,
n_chunks: df.n_chunks(),
}
}

fn field(&self) -> ArrowField {
ArrowField::new("", self.data_type.clone(), false)
}
}

impl Iterator for DataFrameStreamIterator {
type Item = PolarsResult<ArrayRef>;

fn next(&mut self) -> Option<Self::Item> {
if self.idx >= self.n_chunks {
None
} else {
// create a batch of the columns with the same chunk no.
let batch_cols = self
.columns
.iter()
.map(|s| s.to_arrow(self.idx, CompatLevel::newest()))
.collect();
self.idx += 1;

let array = arrow::array::StructArray::new(self.data_type.clone(), batch_cols, None);
Some(Ok(Box::new(array)))
}
}
}
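On the Rust side, DataFrameStreamIterator surfaces each chunk of the DataFrame as one StructArray (one field per column), so a consumer can pull record batches lazily instead of materializing the whole table at once. A hedged sketch of what that looks like from Python, assuming a pyarrow version that provides RecordBatchReader.from_stream (approximately 15.0+):

    import polars as pl
    import pyarrow as pa

    # build a DataFrame with two chunks and keep them separate
    df = pl.concat(
        [pl.DataFrame({"a": [1, 2]}), pl.DataFrame({"a": [3, 4]})],
        rechunk=False,
    )
    # each polars chunk arrives as one record batch from the stream
    reader = pa.RecordBatchReader.from_stream(df)
    assert len(list(reader)) == df.n_chunks()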
13 changes: 12 additions & 1 deletion py-polars/src/series/export.rs
@@ -1,7 +1,8 @@
use polars_core::prelude::*;
use pyo3::prelude::*;
use pyo3::types::PyList;
use pyo3::types::{PyCapsule, PyList};

use crate::interop::arrow::to_py::series_to_stream;
use crate::prelude::*;
use crate::{interop, PySeries};

@@ -157,4 +158,14 @@ impl PySeries {
)
})
}

#[allow(unused_variables)]
#[pyo3(signature = (requested_schema=None))]
fn __arrow_c_stream__<'py>(
&'py self,
py: Python<'py>,
requested_schema: Option<PyObject>,
) -> PyResult<Bound<'py, PyCapsule>> {
series_to_stream(&self.series, py)
}
}
18 changes: 18 additions & 0 deletions py-polars/tests/unit/interop/test_interop.py
@@ -13,6 +13,7 @@
from polars.exceptions import ComputeError, UnstableWarning
from polars.interchange.protocol import CompatLevel
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.utils.pycapsule_utils import PyCapsuleStreamHolder


def test_arrow_list_roundtrip() -> None:
@@ -749,3 +750,20 @@ def test_compat_level(monkeypatch: pytest.MonkeyPatch) -> None:
assert len(df.write_ipc_stream(None).getbuffer()) == 544
assert len(df.write_ipc_stream(None, compat_level=oldest).getbuffer()) == 672
assert len(df.write_ipc_stream(None, compat_level=newest).getbuffer()) == 544


def test_df_pycapsule_interface() -> None:
df = pl.DataFrame(
{
"a": [1, 2, 3],
"b": ["a", "b", "c"],
"c": ["fooooooooooooooooooooo", "bar", "looooooooooooooooong string"],
}
)
out = pa.table(PyCapsuleStreamHolder(df))
assert df.shape == out.shape
assert df.schema.names() == out.schema.names

df2 = pl.from_arrow(out)
assert isinstance(df2, pl.DataFrame)
assert df.equals(df2)
8 changes: 8 additions & 0 deletions py-polars/tests/unit/series/test_series.py
@@ -29,6 +29,7 @@
ShapeError,
)
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.utils.pycapsule_utils import PyCapsuleStreamHolder

if TYPE_CHECKING:
from zoneinfo import ZoneInfo
@@ -628,6 +629,13 @@ def test_arrow() -> None:
)


def test_pycapsule_interface() -> None:
a = pl.Series("a", [1, 2, 3, None])
out = pa.chunked_array(PyCapsuleStreamHolder(a))
out_arr = out.combine_chunks()
assert out_arr == pa.array([1, 2, 3, None])


def test_get() -> None:
a = pl.Series("a", [1, 2, 3])
pos_idxs = pl.Series("idxs", [2, 0, 1, 0], dtype=pl.Int8)
22 changes: 22 additions & 0 deletions py-polars/tests/unit/utils/pycapsule_utils.py
@@ -0,0 +1,22 @@
from typing import Any


class PyCapsuleStreamHolder:
"""
Hold the Arrow C Stream pycapsule.
A class that exposes _only_ the Arrow C Stream interface via Arrow PyCapsules.
This ensures that pyarrow is seeing _only_ the `__arrow_c_stream__` dunder, and
that nothing else (e.g. the dataframe or array interface) is actually being
used.
This is used by tests across multiple files.
"""

arrow_obj: Any

def __init__(self, arrow_obj: object) -> None:
self.arrow_obj = arrow_obj

def __arrow_c_stream__(self, requested_schema: object = None) -> object:
return self.arrow_obj.__arrow_c_stream__(requested_schema)
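
The holder deliberately exposes nothing but __arrow_c_stream__, so any conversion that succeeds must have gone through the PyCapsule interface rather than some other export path. A minimal sketch of standalone usage, mirroring the tests above (the import path assumes this repository's test layout):

    import polars as pl
    import pyarrow as pa

    from tests.unit.utils.pycapsule_utils import PyCapsuleStreamHolder

    df = pl.DataFrame({"a": [1, 2, 3]})
    # pyarrow only sees __arrow_c_stream__ on the holder, so this conversion
    # can only go through the Arrow PyCapsule Interface
    table = pa.table(PyCapsuleStreamHolder(df))
    assert table.num_rows == 3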
