GH-40066: [Python] Support requested_schema in __arrow_c_stream__() #40070

Merged (26 commits) on Feb 28, 2024

Changes from all commits
8 changes: 8 additions & 0 deletions python/pyarrow/includes/libarrow_python.pxd
@@ -283,6 +283,14 @@ cdef extern from "arrow/python/ipc.h" namespace "arrow::py":
object)


cdef extern from "arrow/python/ipc.h" namespace "arrow::py" nogil:
cdef cppclass CCastingRecordBatchReader" arrow::py::CastingRecordBatchReader" \
(CRecordBatchReader):
@staticmethod
CResult[shared_ptr[CRecordBatchReader]] Make(shared_ptr[CRecordBatchReader],
shared_ptr[CSchema])


cdef extern from "arrow/python/extension_type.h" namespace "arrow::py":
cdef cppclass CPyExtensionType \
" arrow::py::PyExtensionType"(CExtensionType):
39 changes: 33 additions & 6 deletions python/pyarrow/ipc.pxi
@@ -772,6 +772,38 @@ cdef class RecordBatchReader(_Weakrefable):
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def cast(self, target_schema):
"""
Wrap this reader with one that casts each batch lazily as it is pulled.
Currently only a safe cast to target_schema is implemented.

Parameters
----------
target_schema : Schema
Schema to cast to; the names and order of fields must match.

Returns
-------
RecordBatchReader
"""
cdef:
shared_ptr[CSchema] c_schema
shared_ptr[CRecordBatchReader] c_reader
RecordBatchReader out

if self.schema.names != target_schema.names:
raise ValueError("Target schema's field names are not matching "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit, but you can use f-strings now rather than explicit format calls.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

f"the table's field names: {self.schema.names}, "
f"{target_schema.names}")

c_schema = pyarrow_unwrap_schema(target_schema)
c_reader = GetResultValue(CCastingRecordBatchReader.Make(
self.reader, c_schema))

out = RecordBatchReader.__new__(RecordBatchReader)
out.reader = c_reader
return out
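A quick usage sketch of the new method (based on test_record_batch_reader_cast added later in this PR; names are illustrative):

    import pyarrow as pa

    data = [pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a'])]
    reader = pa.RecordBatchReader.from_batches(data[0].schema, data)
    # Wraps the reader; each batch is cast lazily as it is pulled.
    casted = reader.cast(pa.schema([pa.field('a', pa.int32())]))
    assert casted.schema == pa.schema([pa.field('a', pa.int32())])
    assert casted.read_all() == pa.Table.from_batches(data).cast(casted.schema)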

def _export_to_c(self, out_ptr):
"""
Export to a C ArrowArrayStream struct, given its pointer.
@@ -827,8 +859,6 @@ cdef class RecordBatchReader(_Weakrefable):
The schema to which the stream should be cast, passed as a
PyCapsule containing a C ArrowSchema representation of the
requested schema.
- Currently, this is not supported and will raise a
- NotImplementedError if the schema doesn't match the current schema.

Returns
-------
@@ -840,11 +870,8 @@

if requested_schema is not None:
out_schema = Schema._import_from_c_capsule(requested_schema)
- # TODO: figure out a way to check if one schema is castable to
- # another. Once we have that, we can perform validation here and
- # if successful creating a wrapping reader that casts each batch.
- if self.schema != out_schema:
- raise NotImplementedError("Casting to requested_schema")
+ return self.cast(out_schema).__arrow_c_stream__()

stream_capsule = alloc_c_stream(&c_stream)

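End to end, the new protocol path looks like this; a minimal sketch based on test_roundtrip_batch_reader_capsule_requested_schema below (_import_from_c_capsule is a private helper, used here only for illustration):

    import pyarrow as pa

    batch = pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['ints'])
    requested = pa.schema([('ints', pa.int32())])
    # The consumer passes the requested schema as an ArrowSchema PyCapsule;
    # instead of raising NotImplementedError, the producer now casts.
    capsule = batch.__arrow_c_stream__(requested.__arrow_c_schema__())
    reader = pa.RecordBatchReader._import_from_c_capsule(capsule)
    assert reader.schema == requested
    assert reader.read_next_batch() == batch.cast(requested)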
66 changes: 66 additions & 0 deletions python/pyarrow/src/arrow/python/ipc.cc
@@ -19,6 +19,7 @@

#include <memory>

#include "arrow/compute/cast.h"
#include "arrow/python/pyarrow.h"

namespace arrow {
@@ -63,5 +64,70 @@ Result<std::shared_ptr<RecordBatchReader>> PyRecordBatchReader::Make(
return reader;
}

CastingRecordBatchReader::CastingRecordBatchReader() = default;

Status CastingRecordBatchReader::Init(std::shared_ptr<RecordBatchReader> parent,
std::shared_ptr<Schema> schema) {
std::shared_ptr<Schema> src = parent->schema();

// The check for names has already been done in Python where it's easier to
// generate a nice error message.
int num_fields = schema->num_fields();
if (src->num_fields() != num_fields) {
return Status::Invalid("Number of fields not equal");
}

// Ensure all columns can be cast before succeeding
for (int i = 0; i < num_fields; i++) {
if (!compute::CanCast(*src->field(i)->type(), *schema->field(i)->type())) {
return Status::TypeError("Field ", i, " cannot be cast from ",
src->field(i)->type()->ToString(), " to ",
schema->field(i)->type()->ToString());
}
}

parent_ = std::move(parent);
schema_ = std::move(schema);

return Status::OK();
}

std::shared_ptr<Schema> CastingRecordBatchReader::schema() const { return schema_; }

Status CastingRecordBatchReader::ReadNext(std::shared_ptr<RecordBatch>* batch) {
std::shared_ptr<RecordBatch> out;
ARROW_RETURN_NOT_OK(parent_->ReadNext(&out));
if (!out) {
batch->reset();
return Status::OK();
}

auto num_columns = out->num_columns();
auto options = compute::CastOptions::Safe();
ArrayVector columns(num_columns);
for (int i = 0; i < num_columns; i++) {
const Array& src = *out->column(i);
if (!schema_->field(i)->nullable() && src.null_count() > 0) {
return Status::Invalid(
"Can't cast array that contains nulls to non-nullable field at index ", i);
}

ARROW_ASSIGN_OR_RAISE(columns[i],
compute::Cast(src, schema_->field(i)->type(), options));
}

*batch = RecordBatch::Make(schema_, out->num_rows(), std::move(columns));
return Status::OK();
}

Result<std::shared_ptr<RecordBatchReader>> CastingRecordBatchReader::Make(
std::shared_ptr<RecordBatchReader> parent, std::shared_ptr<Schema> schema) {
auto reader = std::shared_ptr<CastingRecordBatchReader>(new CastingRecordBatchReader());
ARROW_RETURN_NOT_OK(reader->Init(parent, schema));
return reader;
}

Status CastingRecordBatchReader::Close() { return parent_->Close(); }

} // namespace py
} // namespace arrow
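Note the two-stage validation above: castability is checked eagerly in Init() via compute::CanCast, while nulls under a non-nullable target field can only be detected batch by batch in ReadNext(). A sketch of the observable behavior, mirroring test_record_batch_reader_cast_nulls below:

    import pyarrow as pa

    data = [pa.record_batch([pa.array([1, None], type=pa.int64())], names=['a'])]
    reader = pa.RecordBatchReader.from_batches(data[0].schema, data)
    target = pa.schema([pa.field('a', pa.int32(), nullable=False)])
    casted = reader.cast(target)  # succeeds: int64 -> int32 is castable
    # The null is only discovered when the offending batch is pulled:
    # casted.read_all() raises pa.lib.ArrowInvalid ("Can't cast array ...")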
20 changes: 20 additions & 0 deletions python/pyarrow/src/arrow/python/ipc.h
@@ -48,5 +48,25 @@ class ARROW_PYTHON_EXPORT PyRecordBatchReader : public RecordBatchReader {
OwnedRefNoGIL iterator_;
};

class ARROW_PYTHON_EXPORT CastingRecordBatchReader : public RecordBatchReader {
public:
std::shared_ptr<Schema> schema() const override;

Status ReadNext(std::shared_ptr<RecordBatch>* batch) override;

static Result<std::shared_ptr<RecordBatchReader>> Make(
std::shared_ptr<RecordBatchReader> parent, std::shared_ptr<Schema> schema);

Status Close() override;

protected:
CastingRecordBatchReader();

Status Init(std::shared_ptr<RecordBatchReader> parent, std::shared_ptr<Schema> schema);

std::shared_ptr<RecordBatchReader> parent_;
std::shared_ptr<Schema> schema_;
};

} // namespace py
} // namespace arrow
23 changes: 23 additions & 0 deletions python/pyarrow/table.pxi
@@ -2681,6 +2681,29 @@ cdef class RecordBatch(_Tabular):

return pyarrow_wrap_batch(c_batch)

def cast(self, Schema target_schema, safe=None, options=None):
"""
Cast batch values to another schema.

Parameters
----------
target_schema : Schema
Schema to cast to; the names and order of fields must match.
safe : bool, default True
Check for overflows or other unsafe conversions.
options : CastOptions, default None
Additional checks passed via CastOptions

Returns
-------
RecordBatch
"""
# Wrap the more general Table cast implementation
tbl = Table.from_batches([self])
casted_tbl = tbl.cast(target_schema, safe=safe, options=options)
casted_batch, = casted_tbl.to_batches()
return casted_batch

def _to_pandas(self, options, **kwargs):
return Table.from_batches([self])._to_pandas(options, **kwargs)

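A short sketch of the new RecordBatch.cast (mirroring test_record_batch_cast added below):

    import pyarrow as pa

    rb = pa.record_batch([pa.array([None, 1]), pa.array([False, True])],
                         names=['a', 'b'])
    target = pa.schema([pa.field('a', pa.int32()),
                        pa.field('b', pa.bool_(), nullable=False)])
    casted = rb.cast(target)  # delegates to Table.cast under the hood
    assert casted.schema == target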
18 changes: 16 additions & 2 deletions python/pyarrow/tests/test_cffi.py
@@ -577,9 +577,8 @@ def test_roundtrip_reader_capsule(constructor):

obj = constructor(schema, batches)

- # TODO: turn this to ValueError once we implement validation.
bad_schema = pa.schema({'ints': pa.int32()})
- with pytest.raises(NotImplementedError):
+ with pytest.raises(pa.lib.ArrowTypeError, match="Field 0 cannot be cast"):
obj.__arrow_c_stream__(bad_schema.__arrow_c_schema__())

# Can work with matching schema
@@ -591,6 +590,21 @@
assert batch.equals(expected)


def test_roundtrip_batch_reader_capsule_requested_schema():
batch = make_batch()
requested_schema = pa.schema([('ints', pa.list_(pa.int64()))])
requested_capsule = requested_schema.__arrow_c_schema__()
batch_as_requested = batch.cast(requested_schema)

capsule = batch.__arrow_c_stream__(requested_capsule)
assert PyCapsule_IsValid(capsule, b"arrow_array_stream") == 1
imported_reader = pa.RecordBatchReader._import_from_c_capsule(capsule)
assert imported_reader.schema == requested_schema
assert imported_reader.read_next_batch().equals(batch_as_requested)
with pytest.raises(StopIteration):
imported_reader.read_next_batch()


def test_roundtrip_batch_reader_capsule():
batch = make_batch()

68 changes: 65 additions & 3 deletions python/pyarrow/tests/test_ipc.py
@@ -1226,10 +1226,15 @@ def __arrow_c_stream__(self, requested_schema=None):
reader = pa.RecordBatchReader.from_stream(wrapper, schema=data[0].schema)
assert reader.read_all() == expected

- # If schema doesn't match, raises NotImplementedError
- with pytest.raises(NotImplementedError):
+ # Passing a different but castable schema works
+ good_schema = pa.schema([pa.field("a", pa.int32())])
+ reader = pa.RecordBatchReader.from_stream(wrapper, schema=good_schema)
+ assert reader.read_all() == expected.cast(good_schema)
+
+ # If schema doesn't match, raises TypeError
+ with pytest.raises(pa.lib.ArrowTypeError, match='Field 0 cannot be cast'):
pa.RecordBatchReader.from_stream(
- wrapper, schema=pa.schema([pa.field('a', pa.int32())])
+ wrapper, schema=pa.schema([pa.field('a', pa.list_(pa.int32()))])
)

# Proper type errors for wrong input
@@ -1238,3 +1243,60 @@ def __arrow_c_stream__(self, requested_schema=None):

with pytest.raises(TypeError):
pa.RecordBatchReader.from_stream(expected, schema=data[0])


def test_record_batch_reader_cast():
schema_src = pa.schema([pa.field('a', pa.int64())])
data = [
pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']),
pa.record_batch([pa.array([4, 5, 6], type=pa.int64())], names=['a']),
]
table_src = pa.Table.from_batches(data)

# Cast to same type should always work
reader = pa.RecordBatchReader.from_batches(schema_src, data)
assert reader.cast(schema_src).read_all() == table_src

# Check non-trivial cast
schema_dst = pa.schema([pa.field('a', pa.int32())])
reader = pa.RecordBatchReader.from_batches(schema_src, data)
assert reader.cast(schema_dst).read_all() == table_src.cast(schema_dst)

# Check error for field name/length mismatch
reader = pa.RecordBatchReader.from_batches(schema_src, data)
with pytest.raises(ValueError, match="Target schema's field names"):
reader.cast(pa.schema([]))

# Check error for impossible cast in call to .cast()
reader = pa.RecordBatchReader.from_batches(schema_src, data)
with pytest.raises(pa.lib.ArrowTypeError, match='Field 0 cannot be cast'):
reader.cast(pa.schema([pa.field('a', pa.list_(pa.int32()))]))


def test_record_batch_reader_cast_nulls():
schema_src = pa.schema([pa.field('a', pa.int64())])
data_with_nulls = [
pa.record_batch([pa.array([1, 2, None], type=pa.int64())], names=['a']),
]
data_without_nulls = [
pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']),
]
table_with_nulls = pa.Table.from_batches(data_with_nulls)
table_without_nulls = pa.Table.from_batches(data_without_nulls)

# Cast to nullable destination should work
reader = pa.RecordBatchReader.from_batches(schema_src, data_with_nulls)
schema_dst = pa.schema([pa.field('a', pa.int32())])
assert reader.cast(schema_dst).read_all() == table_with_nulls.cast(schema_dst)

# Cast to non-nullable destination should work if there are no nulls
reader = pa.RecordBatchReader.from_batches(schema_src, data_without_nulls)
schema_dst = pa.schema([pa.field('a', pa.int32(), nullable=False)])
assert reader.cast(schema_dst).read_all() == table_without_nulls.cast(schema_dst)

# Cast to non-nullable destination should error if there are nulls
# when the batch is pulled
reader = pa.RecordBatchReader.from_batches(schema_src, data_with_nulls)
casted_reader = reader.cast(schema_dst)
with pytest.raises(pa.lib.ArrowInvalid, match="Can't cast array"):
casted_reader.read_all()
32 changes: 30 additions & 2 deletions python/pyarrow/tests/test_table.py
@@ -635,9 +635,18 @@ def __arrow_c_stream__(self, requested_schema=None):
result = pa.table(wrapper, schema=data[0].schema)
assert result == expected

# Passing a different schema will cast
good_schema = pa.schema([pa.field('a', pa.int32())])
result = pa.table(wrapper, schema=good_schema)
assert result == expected.cast(good_schema)

# If schema doesn't match, raises TypeError
- with pytest.raises(NotImplementedError):
- pa.table(wrapper, schema=pa.schema([pa.field('a', pa.int32())]))
+ with pytest.raises(
+ pa.lib.ArrowTypeError, match="Field 0 cannot be cast"
+ ):
+ pa.table(
+ wrapper, schema=pa.schema([pa.field('a', pa.list_(pa.int32()))])
+ )


def test_recordbatch_itercolumns():
@@ -2620,6 +2629,25 @@ def test_record_batch_sort():
assert sorted_rb_dict["c"] == ["foobar", "bar", "foo", "car"]


def test_record_batch_cast():
rb = pa.RecordBatch.from_arrays([
pa.array([None, 1]),
pa.array([False, True])
], names=["a", "b"])
new_schema = pa.schema([pa.field("a", "int64", nullable=True),
pa.field("b", "bool", nullable=False)])

assert rb.cast(new_schema).schema == new_schema

# Casting a nullable field to non-nullable is invalid
rb = pa.RecordBatch.from_arrays([
pa.array([None, 1]),
pa.array([None, True])
], names=["a", "b"])
with pytest.raises(ValueError):
rb.cast(new_schema)


@pytest.mark.parametrize("constructor", [pa.table, pa.record_batch])
def test_numpy_asarray(constructor):
table = constructor([[1, 2, 3], [4.0, 5.0, 6.0]], names=["a", "b"])