GH-40066: [Python] Support requested_schema in `__arrow_c_stream__()` (apache#40070)

### Rationale for this change

The `requested_schema` argument of the `__arrow_c_stream__()` protocol method errored whenever it was passed a schema that differed from the reader's own. There was a TODO about figuring out how to validate a cast before performing it, and a comment in apache#40066 about how the cast should be done lazily. This PR (hopefully) solves both!
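
For context, a minimal consumer-side sketch of what this enables, mirroring the tests added below (the schemas and values are illustrative):

```python
import pyarrow as pa

# A stream of int64 batches.
batch = pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=["a"])
reader = pa.RecordBatchReader.from_batches(batch.schema, [batch])

# A consumer may now request any castable schema through the capsule
# protocol; previously any schema mismatch raised NotImplementedError.
requested = pa.schema([pa.field("a", pa.int32())])
capsule = reader.__arrow_c_stream__(requested.__arrow_c_schema__())
imported = pa.RecordBatchReader._import_from_c_capsule(capsule)
assert imported.schema == requested
```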

### What changes are included in this PR?

- Added `arrow::py::CastingRecordBatchReader`, which wraps an `arrow::RecordBatchReader` and casts each batch as it is pulled; a Python sketch of the idea follows below.
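
Conceptually, the new reader behaves like the following Python sketch (an illustration of the lazy wrapping idea, not the actual C++ implementation):

```python
import pyarrow as pa

def casting_reader(parent: pa.RecordBatchReader,
                   target_schema: pa.Schema) -> pa.RecordBatchReader:
    # Illustrative only: the real CastingRecordBatchReader also validates
    # up front (via compute::CanCast) that every field is castable.
    def gen():
        for src in parent:
            # Each batch is cast only when the consumer pulls it.
            yield src.cast(target_schema)
    return pa.RecordBatchReader.from_batches(target_schema, gen())
```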

### Are these changes tested?

Yes.

### Are there any user-facing changes?

Yes: this PR adds `RecordBatchReader.cast()` as the way to access the casting reader; see the usage sketch below.
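
A quick usage sketch of the new method (values illustrative):

```python
import pyarrow as pa

data = [pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=["a"])]
reader = pa.RecordBatchReader.from_batches(data[0].schema, data)

# Wrap with a reader that casts each batch lazily as it is pulled.
casted = reader.cast(pa.schema([pa.field("a", pa.int32())]))
print(casted.read_all().schema)  # a: int32
```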

* Closes: apache#40066
* GitHub Issue: apache#40066

Lead-authored-by: Dewey Dunnington <dewey@fishandwhistle.net>
Co-authored-by: Dewey Dunnington <dewey@voltrondata.com>
Co-authored-by: Antoine Pitrou <pitrou@free.fr>
Signed-off-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
3 people authored and thisisnic committed Mar 8, 2024
1 parent c35c32a commit 1eb21c2
Showing 8 changed files with 261 additions and 13 deletions.
8 changes: 8 additions & 0 deletions python/pyarrow/includes/libarrow_python.pxd
@@ -283,6 +283,14 @@ cdef extern from "arrow/python/ipc.h" namespace "arrow::py":
object)


cdef extern from "arrow/python/ipc.h" namespace "arrow::py" nogil:
cdef cppclass CCastingRecordBatchReader" arrow::py::CastingRecordBatchReader" \
(CRecordBatchReader):
@staticmethod
CResult[shared_ptr[CRecordBatchReader]] Make(shared_ptr[CRecordBatchReader],
shared_ptr[CSchema])


cdef extern from "arrow/python/extension_type.h" namespace "arrow::py":
cdef cppclass CPyExtensionType \
" arrow::py::PyExtensionType"(CExtensionType):
39 changes: 33 additions & 6 deletions python/pyarrow/ipc.pxi
@@ -772,6 +772,38 @@ cdef class RecordBatchReader(_Weakrefable):
def __exit__(self, exc_type, exc_val, exc_tb):
self.close()

def cast(self, target_schema):
"""
Wrap this reader with one that casts each batch lazily as it is pulled.
Currently only a safe cast to target_schema is implemented.
Parameters
----------
target_schema : Schema
Schema to cast to, the names and order of fields must match.
Returns
-------
RecordBatchReader
"""
cdef:
shared_ptr[CSchema] c_schema
shared_ptr[CRecordBatchReader] c_reader
RecordBatchReader out

if self.schema.names != target_schema.names:
raise ValueError("Target schema's field names do not match "
f"the reader's field names: {self.schema.names}, "
f"{target_schema.names}")

c_schema = pyarrow_unwrap_schema(target_schema)
c_reader = GetResultValue(CCastingRecordBatchReader.Make(
self.reader, c_schema))

out = RecordBatchReader.__new__(RecordBatchReader)
out.reader = c_reader
return out

def _export_to_c(self, out_ptr):
"""
Export to a C ArrowArrayStream struct, given its pointer.
@@ -827,8 +859,6 @@ cdef class RecordBatchReader(_Weakrefable):
The schema to which the stream should be cast, passed as a
PyCapsule containing a C ArrowSchema representation of the
requested schema.
Currently, this is not supported and will raise a
NotImplementedError if the schema doesn't match the current schema.
Returns
-------
@@ -840,11 +870,8 @@

if requested_schema is not None:
out_schema = Schema._import_from_c_capsule(requested_schema)
# TODO: figure out a way to check if one schema is castable to
# another. Once we have that, we can perform validation here and
# if successful creating a wrapping reader that casts each batch.
if self.schema != out_schema:
raise NotImplementedError("Casting to requested_schema")
return self.cast(out_schema).__arrow_c_stream__()

stream_capsule = alloc_c_stream(&c_stream)

66 changes: 66 additions & 0 deletions python/pyarrow/src/arrow/python/ipc.cc
@@ -19,6 +19,7 @@

#include <memory>

#include "arrow/compute/cast.h"
#include "arrow/python/pyarrow.h"

namespace arrow {
@@ -63,5 +64,70 @@ Result<std::shared_ptr<RecordBatchReader>> PyRecordBatchReader::Make(
return reader;
}

CastingRecordBatchReader::CastingRecordBatchReader() = default;

Status CastingRecordBatchReader::Init(std::shared_ptr<RecordBatchReader> parent,
std::shared_ptr<Schema> schema) {
std::shared_ptr<Schema> src = parent->schema();

// The check for names has already been done in Python where it's easier to
// generate a nice error message.
int num_fields = schema->num_fields();
if (src->num_fields() != num_fields) {
return Status::Invalid("Number of fields not equal");
}

// Ensure all columns can be cast before succeeding
for (int i = 0; i < num_fields; i++) {
if (!compute::CanCast(*src->field(i)->type(), *schema->field(i)->type())) {
return Status::TypeError("Field ", i, " cannot be cast from ",
src->field(i)->type()->ToString(), " to ",
schema->field(i)->type()->ToString());
}
}

parent_ = std::move(parent);
schema_ = std::move(schema);

return Status::OK();
}

std::shared_ptr<Schema> CastingRecordBatchReader::schema() const { return schema_; }

Status CastingRecordBatchReader::ReadNext(std::shared_ptr<RecordBatch>* batch) {
std::shared_ptr<RecordBatch> out;
ARROW_RETURN_NOT_OK(parent_->ReadNext(&out));
if (!out) {
batch->reset();
return Status::OK();
}

auto num_columns = out->num_columns();
auto options = compute::CastOptions::Safe();
ArrayVector columns(num_columns);
for (int i = 0; i < num_columns; i++) {
const Array& src = *out->column(i);
if (!schema_->field(i)->nullable() && src.null_count() > 0) {
return Status::Invalid(
"Can't cast array that contains nulls to non-nullable field at index ", i);
}

ARROW_ASSIGN_OR_RAISE(columns[i],
compute::Cast(src, schema_->field(i)->type(), options));
}

*batch = RecordBatch::Make(schema_, out->num_rows(), std::move(columns));
return Status::OK();
}

Result<std::shared_ptr<RecordBatchReader>> CastingRecordBatchReader::Make(
std::shared_ptr<RecordBatchReader> parent, std::shared_ptr<Schema> schema) {
auto reader = std::shared_ptr<CastingRecordBatchReader>(new CastingRecordBatchReader());
ARROW_RETURN_NOT_OK(reader->Init(parent, schema));
return reader;
}

Status CastingRecordBatchReader::Close() { return parent_->Close(); }

} // namespace py
} // namespace arrow
20 changes: 20 additions & 0 deletions python/pyarrow/src/arrow/python/ipc.h
@@ -48,5 +48,25 @@ class ARROW_PYTHON_EXPORT PyRecordBatchReader : public RecordBatchReader {
OwnedRefNoGIL iterator_;
};

class ARROW_PYTHON_EXPORT CastingRecordBatchReader : public RecordBatchReader {
public:
std::shared_ptr<Schema> schema() const override;

Status ReadNext(std::shared_ptr<RecordBatch>* batch) override;

static Result<std::shared_ptr<RecordBatchReader>> Make(
std::shared_ptr<RecordBatchReader> parent, std::shared_ptr<Schema> schema);

Status Close() override;

protected:
CastingRecordBatchReader();

Status Init(std::shared_ptr<RecordBatchReader> parent, std::shared_ptr<Schema> schema);

std::shared_ptr<RecordBatchReader> parent_;
std::shared_ptr<Schema> schema_;
};

} // namespace py
} // namespace arrow
23 changes: 23 additions & 0 deletions python/pyarrow/table.pxi
@@ -2742,6 +2742,29 @@ cdef class RecordBatch(_Tabular):

return pyarrow_wrap_batch(c_batch)

def cast(self, Schema target_schema, safe=None, options=None):
"""
Cast batch values to another schema.
Parameters
----------
target_schema : Schema
Schema to cast to, the names and order of fields must match.
safe : bool, default True
Check for overflows or other unsafe conversions.
options : CastOptions, default None
Additional checks pass by CastOptions
Returns
-------
RecordBatch
"""
# Wrap the more general Table cast implementation
tbl = Table.from_batches([self])
casted_tbl = tbl.cast(target_schema, safe=safe, options=options)
casted_batch, = casted_tbl.to_batches()
return casted_batch

def _to_pandas(self, options, **kwargs):
return Table.from_batches([self])._to_pandas(options, **kwargs)

18 changes: 16 additions & 2 deletions python/pyarrow/tests/test_cffi.py
@@ -633,9 +633,8 @@ def test_roundtrip_reader_capsule(constructor):

obj = constructor(schema, batches)

# TODO: turn this to ValueError once we implement validation.
bad_schema = pa.schema({'ints': pa.int32()})
with pytest.raises(NotImplementedError):
with pytest.raises(pa.lib.ArrowTypeError, match="Field 0 cannot be cast"):
obj.__arrow_c_stream__(bad_schema.__arrow_c_schema__())

# Can work with matching schema
@@ -647,6 +646,21 @@
assert batch.equals(expected)


def test_roundtrip_batch_reader_capsule_requested_schema():
batch = make_batch()
requested_schema = pa.schema([('ints', pa.list_(pa.int64()))])
requested_capsule = requested_schema.__arrow_c_schema__()
batch_as_requested = batch.cast(requested_schema)

capsule = batch.__arrow_c_stream__(requested_capsule)
assert PyCapsule_IsValid(capsule, b"arrow_array_stream") == 1
imported_reader = pa.RecordBatchReader._import_from_c_capsule(capsule)
assert imported_reader.schema == requested_schema
assert imported_reader.read_next_batch().equals(batch_as_requested)
with pytest.raises(StopIteration):
imported_reader.read_next_batch()


def test_roundtrip_batch_reader_capsule():
batch = make_batch()

68 changes: 65 additions & 3 deletions python/pyarrow/tests/test_ipc.py
@@ -1226,10 +1226,15 @@ def __arrow_c_stream__(self, requested_schema=None):
reader = pa.RecordBatchReader.from_stream(wrapper, schema=data[0].schema)
assert reader.read_all() == expected

# If schema doesn't match, raises NotImplementedError
with pytest.raises(NotImplementedError):
# Passing a different but castable schema works
good_schema = pa.schema([pa.field("a", pa.int32())])
reader = pa.RecordBatchReader.from_stream(wrapper, schema=good_schema)
assert reader.read_all() == expected.cast(good_schema)

# If schema doesn't match, raises TypeError
with pytest.raises(pa.lib.ArrowTypeError, match='Field 0 cannot be cast'):
pa.RecordBatchReader.from_stream(
wrapper, schema=pa.schema([pa.field('a', pa.int32())])
wrapper, schema=pa.schema([pa.field('a', pa.list_(pa.int32()))])
)

# Proper type errors for wrong input
@@ -1238,3 +1243,60 @@ def __arrow_c_stream__(self, requested_schema=None):

with pytest.raises(TypeError):
pa.RecordBatchReader.from_stream(expected, schema=data[0])


def test_record_batch_reader_cast():
schema_src = pa.schema([pa.field('a', pa.int64())])
data = [
pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']),
pa.record_batch([pa.array([4, 5, 6], type=pa.int64())], names=['a']),
]
table_src = pa.Table.from_batches(data)

# Cast to same type should always work
reader = pa.RecordBatchReader.from_batches(schema_src, data)
assert reader.cast(schema_src).read_all() == table_src

# Check non-trivial cast
schema_dst = pa.schema([pa.field('a', pa.int32())])
reader = pa.RecordBatchReader.from_batches(schema_src, data)
assert reader.cast(schema_dst).read_all() == table_src.cast(schema_dst)

# Check error for field name/length mismatch
reader = pa.RecordBatchReader.from_batches(schema_src, data)
with pytest.raises(ValueError, match="Target schema's field names"):
reader.cast(pa.schema([]))

# Check error for impossible cast in call to .cast()
reader = pa.RecordBatchReader.from_batches(schema_src, data)
with pytest.raises(pa.lib.ArrowTypeError, match='Field 0 cannot be cast'):
reader.cast(pa.schema([pa.field('a', pa.list_(pa.int32()))]))


def test_record_batch_reader_cast_nulls():
schema_src = pa.schema([pa.field('a', pa.int64())])
data_with_nulls = [
pa.record_batch([pa.array([1, 2, None], type=pa.int64())], names=['a']),
]
data_without_nulls = [
pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']),
]
table_with_nulls = pa.Table.from_batches(data_with_nulls)
table_without_nulls = pa.Table.from_batches(data_without_nulls)

# Cast to nullable destination should work
reader = pa.RecordBatchReader.from_batches(schema_src, data_with_nulls)
schema_dst = pa.schema([pa.field('a', pa.int32())])
assert reader.cast(schema_dst).read_all() == table_with_nulls.cast(schema_dst)

# Cast to non-nullable destination should work if there are no nulls
reader = pa.RecordBatchReader.from_batches(schema_src, data_without_nulls)
schema_dst = pa.schema([pa.field('a', pa.int32(), nullable=False)])
assert reader.cast(schema_dst).read_all() == table_without_nulls.cast(schema_dst)

# Cast to non-nullable destination should error if there are nulls
# when the batch is pulled
reader = pa.RecordBatchReader.from_batches(schema_src, data_with_nulls)
casted_reader = reader.cast(schema_dst)
with pytest.raises(pa.lib.ArrowInvalid, match="Can't cast array"):
casted_reader.read_all()
32 changes: 30 additions & 2 deletions python/pyarrow/tests/test_table.py
@@ -635,9 +635,18 @@ def __arrow_c_stream__(self, requested_schema=None):
result = pa.table(wrapper, schema=data[0].schema)
assert result == expected

# Passing a different schema will cast
good_schema = pa.schema([pa.field('a', pa.int32())])
result = pa.table(wrapper, schema=good_schema)
assert result == expected.cast(good_schema)

# If schema doesn't match, raises NotImplementedError
with pytest.raises(NotImplementedError):
pa.table(wrapper, schema=pa.schema([pa.field('a', pa.int32())]))
with pytest.raises(
pa.lib.ArrowTypeError, match="Field 0 cannot be cast"
):
pa.table(
wrapper, schema=pa.schema([pa.field('a', pa.list_(pa.int32()))])
)


def test_recordbatch_itercolumns():
@@ -2620,6 +2629,25 @@ def test_record_batch_sort():
assert sorted_rb_dict["c"] == ["foobar", "bar", "foo", "car"]


def test_record_batch_cast():
rb = pa.RecordBatch.from_arrays([
pa.array([None, 1]),
pa.array([False, True])
], names=["a", "b"])
new_schema = pa.schema([pa.field("a", "int64", nullable=True),
pa.field("b", "bool", nullable=False)])

assert rb.cast(new_schema).schema == new_schema

# Casting a nullable field to non-nullable is invalid
rb = pa.RecordBatch.from_arrays([
pa.array([None, 1]),
pa.array([None, True])
], names=["a", "b"])
with pytest.raises(ValueError):
rb.cast(new_schema)


@pytest.mark.parametrize("constructor", [pa.table, pa.record_batch])
def test_numpy_asarray(constructor):
table = constructor([[1, 2, 3], [4.0, 5.0, 6.0]], names=["a", "b"])