Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Implement Arrow PyCapsule Interface for Series/DataFrame export #17676

Merged
merged 14 commits into from
Jul 25, 2024
2 changes: 2 additions & 0 deletions crates/polars-arrow/src/ffi/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ impl Drop for ArrowArrayStream {
}
}

// NOTE(review): this asserts that an `ArrowArrayStream` may be moved to another thread.
// Nothing visible in this hunk establishes why that is sound; presumably it relies on the
// Arrow C stream interface allowing a stream to be consumed from any thread as long as
// calls are serialized. TODO: confirm and record it as a proper `// SAFETY:` justification.
unsafe impl Send for ArrowArrayStream {}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


impl ArrowArrayStream {
/// Creates an empty [`ArrowArrayStream`] used to import from a producer.
pub fn empty() -> Self {
Expand Down
8 changes: 8 additions & 0 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1264,6 +1264,14 @@ def __deepcopy__(self, memo: None = None) -> DataFrame:
def _ipython_key_completions_(self) -> list[str]:
    """Return the column names, offered as key completions in IPython."""
    return self.columns

def __arrow_c_stream__(self, requested_schema: object) -> object:
"""
Export a DataFrame via the Arrow PyCapsule Interface.

https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html
"""
return self._df.__arrow_c_stream__(requested_schema)

def _repr_html_(self, *, _from_series: bool = False) -> str:
"""
Format output data in HTML for display in Jupyter Notebooks.
Expand Down
8 changes: 8 additions & 0 deletions py-polars/polars/series/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1447,6 +1447,14 @@ def __array_ufunc__(
)
raise NotImplementedError(msg)

def __arrow_c_stream__(self, requested_schema: object) -> object:
"""
Export a Series via the Arrow PyCapsule Interface.

https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html
"""
return self._s.__arrow_c_stream__(requested_schema)

def _repr_html_(self) -> str:
"""Format output data in HTML for display in Jupyter Notebooks."""
return self.to_frame()._repr_html_(_from_series=True)
Expand Down
14 changes: 13 additions & 1 deletion py-polars/src/dataframe/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ use polars::export::arrow::record_batch::RecordBatch;
use polars_core::export::arrow::datatypes::IntegerType;
use polars_core::utils::arrow::compute::cast::CastOptionsImpl;
use pyo3::prelude::*;
use pyo3::types::{PyList, PyTuple};
use pyo3::types::{PyCapsule, PyList, PyTuple};

use super::*;
use crate::conversion::{ObjectValue, Wrap};
use crate::interop;
use crate::interop::arrow::to_py::dataframe_to_stream;
use crate::prelude::PyCompatLevel;

#[pymethods]
Expand Down Expand Up @@ -130,4 +131,15 @@ impl PyDataFrame {
Ok(rbs)
})
}

/// Export this DataFrame as an Arrow C stream wrapped in a `PyCapsule`.
///
/// `requested_schema` is accepted so the Python-side signature has the
/// parameter, but it is not read here (hence the `unused_variables` allowance).
#[allow(unused_variables)]
#[pyo3(signature = (requested_schema=None))]
fn __arrow_c_stream__<'py>(
    &'py mut self,
    py: Python<'py>,
    requested_schema: Option<PyObject>,
) -> PyResult<Bound<'py, PyCapsule>> {
    // Align the chunks first: the stream iterator reads the same chunk index
    // from every column when it assembles a batch.
    self.df.align_chunks();
    let aligned = &self.df;
    dataframe_to_stream(aligned, py)
}
}
78 changes: 78 additions & 0 deletions py-polars/src/interop/arrow/to_py.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,17 @@
use std::ffi::CString;
kylebarron marked this conversation as resolved.
Show resolved Hide resolved

use arrow::datatypes::ArrowDataType;
use arrow::ffi;
use arrow::record_batch::RecordBatch;
use polars::datatypes::CompatLevel;
use polars::frame::DataFrame;
use polars::prelude::{ArrayRef, ArrowField};
use polars::series::Series;
use polars_core::utils::arrow;
use polars_error::PolarsResult;
use pyo3::ffi::Py_uintptr_t;
use pyo3::prelude::*;
use pyo3::types::PyCapsule;

/// Arrow array to Python.
pub(crate) fn to_py_array(
Expand Down Expand Up @@ -49,3 +57,73 @@ pub(crate) fn to_py_rb(

Ok(record.to_object(py))
}

/// Export a series to a C stream via a PyCapsule according to the Arrow PyCapsule Interface
/// https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html
///
/// The exported stream yields the chunks of `series` in order, and the capsule is named
/// `arrow_array_stream`.
pub(crate) fn series_to_stream<'py>(
    series: &'py Series,
    py: Python<'py>,
) -> PyResult<Bound<'py, PyCapsule>> {
    // The literal has no interior NUL byte, so building the C string cannot fail.
    let capsule_name = CString::new("arrow_array_stream").unwrap();
    let arrow_field = series.field().to_arrow(CompatLevel::newest());

    // Clone the chunk list so the boxed iterator owns its data instead of borrowing `series`.
    let owned_chunks = series.chunks().clone();
    let chunk_iter = Box::new(owned_chunks.into_iter().map(Ok)) as _;

    let stream = ffi::export_iterator(chunk_iter, arrow_field);
    PyCapsule::new_bound(py, stream, Some(capsule_name))
}

/// Export a DataFrame to a C stream wrapped in a PyCapsule named `arrow_array_stream`.
///
/// The stream is driven by a [`DataFrameStreamIterator`], which yields one struct array
/// per chunk index of `df`.
pub(crate) fn dataframe_to_stream<'py>(
    df: &'py DataFrame,
    py: Python<'py>,
) -> PyResult<Bound<'py, PyCapsule>> {
    let batches = DataFrameStreamIterator::new(df);
    // Read the stream's field before the iterator is moved into the exporter.
    let struct_field = batches.field();
    let stream = ffi::export_iterator(Box::new(batches), struct_field);

    // The literal has no interior NUL byte, so building the C string cannot fail.
    let capsule_name = CString::new("arrow_array_stream").unwrap();
    PyCapsule::new_bound(py, stream, Some(capsule_name))
}

/// Iterator over a DataFrame's chunks that yields one Arrow struct array per chunk index.
pub struct DataFrameStreamIterator {
    /// Columns of the source frame, copied from `DataFrame::get_columns` at construction.
    columns: Vec<Series>,
    /// Struct data type built from the fields of the frame's Arrow schema.
    data_type: ArrowDataType,
    /// Index of the next chunk to emit.
    idx: usize,
    /// Number of chunks to emit, taken from `DataFrame::n_chunks` at construction.
    n_chunks: usize,
}

impl DataFrameStreamIterator {
    /// Builds an iterator positioned at the first chunk of `df`.
    ///
    /// The columns and the chunk count are captured up front, so the iterator does not
    /// borrow `df` afterwards.
    fn new(df: &DataFrame) -> Self {
        let arrow_schema = df.schema().to_arrow(CompatLevel::newest());

        Self {
            columns: df.get_columns().to_vec(),
            data_type: ArrowDataType::Struct(arrow_schema.fields),
            idx: 0,
            n_chunks: df.n_chunks(),
        }
    }

    /// Returns the unnamed, non-nullable struct field describing every emitted array.
    fn field(&self) -> ArrowField {
        let struct_type = self.data_type.clone();
        ArrowField::new("", struct_type, false)
    }
}

impl Iterator for DataFrameStreamIterator {
    type Item = PolarsResult<ArrayRef>;

    /// Emits the next chunk as a struct array, or `None` once every chunk has been emitted.
    fn next(&mut self) -> Option<Self::Item> {
        // Guard clause: all chunks have already been handed out.
        if self.idx >= self.n_chunks {
            return None;
        }

        // Take the chunk with the current index from every column to form one batch.
        let chunk_idx = self.idx;
        let batch_cols = self
            .columns
            .iter()
            .map(|column| column.to_arrow(chunk_idx, CompatLevel::newest()))
            .collect();
        self.idx += 1;

        let batch = arrow::array::StructArray::new(self.data_type.clone(), batch_cols, None);
        Some(Ok(Box::new(batch)))
    }
}
13 changes: 12 additions & 1 deletion py-polars/src/series/export.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use polars_core::prelude::*;
use pyo3::prelude::*;
use pyo3::types::PyList;
use pyo3::types::{PyCapsule, PyList};

use crate::interop::arrow::to_py::series_to_stream;
use crate::prelude::*;
use crate::{interop, PySeries};

Expand Down Expand Up @@ -157,4 +158,14 @@ impl PySeries {
)
})
}

/// Export this Series as an Arrow C stream wrapped in a `PyCapsule`.
///
/// `requested_schema` is accepted so the Python-side signature has the
/// parameter, but it is not read here (hence the `unused_variables` allowance).
#[allow(unused_variables)]
#[pyo3(signature = (requested_schema=None))]
fn __arrow_c_stream__<'py>(
    &'py self,
    py: Python<'py>,
    requested_schema: Option<PyObject>,
) -> PyResult<Bound<'py, PyCapsule>> {
    let series = &self.series;
    series_to_stream(series, py)
}
}
18 changes: 18 additions & 0 deletions py-polars/tests/unit/interop/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from polars.exceptions import ComputeError, UnstableWarning
from polars.interchange.protocol import CompatLevel
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.utils.pycapsule_utils import PyCapsuleStreamHolder


def test_arrow_list_roundtrip() -> None:
Expand Down Expand Up @@ -749,3 +750,20 @@ def test_compat_level(monkeypatch: pytest.MonkeyPatch) -> None:
assert len(df.write_ipc_stream(None).getbuffer()) == 544
assert len(df.write_ipc_stream(None, compat_level=oldest).getbuffer()) == 672
assert len(df.write_ipc_stream(None, compat_level=newest).getbuffer()) == 544


def test_df_pycapsule_interface() -> None:
df = pl.DataFrame(
{
"a": [1, 2, 3],
"b": ["a", "b", "c"],
"c": ["fooooooooooooooooooooo", "bar", "looooooooooooooooong string"],
}
)
out = pa.table(PyCapsuleStreamHolder(df))
assert df.shape == out.shape
assert df.schema.names() == out.schema.names
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could drop df just now and make sure that the recreated df2 below still gets the expected contents (instead of crashing or whatever else).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the test to not hold a bare capsule, but rather call the underlying object's __arrow_c_stream__ method. I'm not sure what you're suggesting for this test, since I need to check below that df and df2 are equal. Are you suggesting that I should drop df again after that? That isn't possible when this utility class doesn't hold bare capsules.


df2 = pl.from_arrow(out)
assert isinstance(df2, pl.DataFrame)
assert df.equals(df2)
8 changes: 8 additions & 0 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ShapeError,
)
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.utils.pycapsule_utils import PyCapsuleStreamHolder

if TYPE_CHECKING:
from zoneinfo import ZoneInfo
Expand Down Expand Up @@ -628,6 +629,13 @@ def test_arrow() -> None:
)


def test_pycapsule_interface() -> None:
a = pl.Series("a", [1, 2, 3, None])
out = pa.chunked_array(PyCapsuleStreamHolder(a))
out_arr = out.combine_chunks()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same idea here (drop a before doing things with out)

assert out_arr == pa.array([1, 2, 3, None])


def test_get() -> None:
a = pl.Series("a", [1, 2, 3])
pos_idxs = pl.Series("idxs", [2, 0, 1, 0], dtype=pl.Int8)
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/utils/pycapsule_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any


class PyCapsuleStreamHolder:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is put in a helper file because it's used by tests both in this PR and in https://github.com/pola-rs/polars/pull/17693/files. Let me know if there's a better place to put this test helper.

"""
Hold the Arrow C Stream pycapsule.

A class that exposes _only_ the Arrow C Stream interface via Arrow PyCapsules.
This ensures that pyarrow is seeing _only_ the `__arrow_c_stream__` dunder, and
that nothing else (e.g. the dataframe or array interface) is actually being
used.

This is used by tests across multiple files.
"""

arrow_obj: Any

def __init__(self, arrow_obj: object) -> None:
    # Keep a reference to the wrapped object; its `__arrow_c_stream__` is
    # only invoked when this holder's own `__arrow_c_stream__` is called.
    self.arrow_obj = arrow_obj

def __arrow_c_stream__(self, requested_schema: object = None) -> object:
return self.arrow_obj.__arrow_c_stream__(requested_schema)