diff --git a/python/pyarrow/includes/libarrow_python.pxd b/python/pyarrow/includes/libarrow_python.pxd
index 906f0b7d28e59..136d6bc8b14cd 100644
--- a/python/pyarrow/includes/libarrow_python.pxd
+++ b/python/pyarrow/includes/libarrow_python.pxd
@@ -283,6 +283,14 @@ cdef extern from "arrow/python/ipc.h" namespace "arrow::py":
                 object)


+cdef extern from "arrow/python/ipc.h" namespace "arrow::py" nogil:
+    cdef cppclass CCastingRecordBatchReader" arrow::py::CastingRecordBatchReader" \
+            (CRecordBatchReader):
+        @staticmethod
+        CResult[shared_ptr[CRecordBatchReader]] Make(shared_ptr[CRecordBatchReader],
+                                                     shared_ptr[CSchema])
+
+
 cdef extern from "arrow/python/extension_type.h" namespace "arrow::py":
     cdef cppclass CPyExtensionType \
             " arrow::py::PyExtensionType"(CExtensionType):
diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi
index 0bb0fe073cc59..617e25a14235d 100644
--- a/python/pyarrow/ipc.pxi
+++ b/python/pyarrow/ipc.pxi
@@ -772,6 +772,38 @@ cdef class RecordBatchReader(_Weakrefable):
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.close()

+    def cast(self, target_schema):
+        """
+        Wrap this reader with one that casts each batch lazily as it is pulled.
+        Currently only a safe cast to target_schema is implemented.
+
+        Parameters
+        ----------
+        target_schema : Schema
+            Schema to cast to; the names and order of fields must match.
+
+        Returns
+        -------
+        RecordBatchReader
+        """
+        cdef:
+            shared_ptr[CSchema] c_schema
+            shared_ptr[CRecordBatchReader] c_reader
+            RecordBatchReader out
+
+        if self.schema.names != target_schema.names:
+            raise ValueError("Target schema's field names do not match "
+                             f"the reader's field names: {self.schema.names}, "
+                             f"{target_schema.names}")
+
+        c_schema = pyarrow_unwrap_schema(target_schema)
+        c_reader = GetResultValue(CCastingRecordBatchReader.Make(
+            self.reader, c_schema))
+
+        out = RecordBatchReader.__new__(RecordBatchReader)
+        out.reader = c_reader
+        return out
+
     def _export_to_c(self, out_ptr):
         """
         Export to a C ArrowArrayStream struct, given its pointer.
@@ -827,8 +859,6 @@ cdef class RecordBatchReader(_Weakrefable):
            The schema to which the stream should be casted, passed as a
            PyCapsule containing a C ArrowSchema representation of the
            requested schema.
-           Currently, this is not supported and will raise a
-           NotImplementedError if the schema doesn't match the current schema.

        Returns
        -------
@@ -840,11 +870,8 @@ cdef class RecordBatchReader(_Weakrefable):

        if requested_schema is not None:
            out_schema = Schema._import_from_c_capsule(requested_schema)
-            # TODO: figure out a way to check if one schema is castable to
-            # another. Once we have that, we can perform validation here and
-            # if successful creating a wrapping reader that casts each batch.
            if self.schema != out_schema:
-                raise NotImplementedError("Casting to requested_schema")
+                return self.cast(out_schema).__arrow_c_stream__()

        stream_capsule = alloc_c_stream(&c_stream)
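
The following is a usage sketch, not part of the patch; the data and schemas
are illustrative. It shows how the new RecordBatchReader.cast wraps an
existing reader and defers all casting until batches are actually pulled:

    import pyarrow as pa

    batches = [pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a'])]
    reader = pa.RecordBatchReader.from_batches(batches[0].schema, batches)

    # cast() only validates names and castability up front; no data moves yet.
    casted = reader.cast(pa.schema([pa.field('a', pa.int32())]))
    assert casted.schema == pa.schema([pa.field('a', pa.int32())])

    # Each batch is cast lazily as it is read from the wrapped reader.
    table = casted.read_all()
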
diff --git a/python/pyarrow/src/arrow/python/ipc.cc b/python/pyarrow/src/arrow/python/ipc.cc
index 93481822475db..0ed152242425d 100644
--- a/python/pyarrow/src/arrow/python/ipc.cc
+++ b/python/pyarrow/src/arrow/python/ipc.cc
@@ -19,6 +19,7 @@

 #include <memory>

+#include "arrow/compute/cast.h"
 #include "arrow/python/pyarrow.h"

 namespace arrow {
@@ -63,5 +64,70 @@ Result<std::shared_ptr<RecordBatchReader>> PyRecordBatchReader::Make(
   return reader;
 }

+CastingRecordBatchReader::CastingRecordBatchReader() = default;
+
+Status CastingRecordBatchReader::Init(std::shared_ptr<RecordBatchReader> parent,
+                                      std::shared_ptr<Schema> schema) {
+  std::shared_ptr<Schema> src = parent->schema();
+
+  // The check for names has already been done in Python where it's easier to
+  // generate a nice error message.
+  int num_fields = schema->num_fields();
+  if (src->num_fields() != num_fields) {
+    return Status::Invalid("Number of fields not equal");
+  }
+
+  // Ensure all columns can be cast before succeeding
+  for (int i = 0; i < num_fields; i++) {
+    if (!compute::CanCast(*src->field(i)->type(), *schema->field(i)->type())) {
+      return Status::TypeError("Field ", i, " cannot be cast from ",
+                               src->field(i)->type()->ToString(), " to ",
+                               schema->field(i)->type()->ToString());
+    }
+  }
+
+  parent_ = std::move(parent);
+  schema_ = std::move(schema);
+
+  return Status::OK();
+}
+
+std::shared_ptr<Schema> CastingRecordBatchReader::schema() const { return schema_; }
+
+Status CastingRecordBatchReader::ReadNext(std::shared_ptr<RecordBatch>* batch) {
+  std::shared_ptr<RecordBatch> out;
+  ARROW_RETURN_NOT_OK(parent_->ReadNext(&out));
+  if (!out) {
+    batch->reset();
+    return Status::OK();
+  }
+
+  auto num_columns = out->num_columns();
+  auto options = compute::CastOptions::Safe();
+  ArrayVector columns(num_columns);
+  for (int i = 0; i < num_columns; i++) {
+    const Array& src = *out->column(i);
+    if (!schema_->field(i)->nullable() && src.null_count() > 0) {
+      return Status::Invalid(
+          "Can't cast array that contains nulls to non-nullable field at index ", i);
+    }
+
+    ARROW_ASSIGN_OR_RAISE(columns[i],
+                          compute::Cast(src, schema_->field(i)->type(), options));
+  }
+
+  *batch = RecordBatch::Make(schema_, out->num_rows(), std::move(columns));
+  return Status::OK();
+}
+
+Result<std::shared_ptr<RecordBatchReader>> CastingRecordBatchReader::Make(
+    std::shared_ptr<RecordBatchReader> parent, std::shared_ptr<Schema> schema) {
+  auto reader = std::shared_ptr<CastingRecordBatchReader>(new CastingRecordBatchReader());
+  ARROW_RETURN_NOT_OK(reader->Init(parent, schema));
+  return reader;
+}
+
+Status CastingRecordBatchReader::Close() { return parent_->Close(); }
+
 }  // namespace py
 }  // namespace arrow
diff --git a/python/pyarrow/src/arrow/python/ipc.h b/python/pyarrow/src/arrow/python/ipc.h
index 92232ed830093..2c16d8c967ff0 100644
--- a/python/pyarrow/src/arrow/python/ipc.h
+++ b/python/pyarrow/src/arrow/python/ipc.h
@@ -48,5 +48,25 @@ class ARROW_PYTHON_EXPORT PyRecordBatchReader : public RecordBatchReader {
   OwnedRefNoGIL iterator_;
 };

+class ARROW_PYTHON_EXPORT CastingRecordBatchReader : public RecordBatchReader {
+ public:
+  std::shared_ptr<Schema> schema() const override;
+
+  Status ReadNext(std::shared_ptr<RecordBatch>* batch) override;
+
+  static Result<std::shared_ptr<RecordBatchReader>> Make(
+      std::shared_ptr<RecordBatchReader> parent, std::shared_ptr<Schema> schema);
+
+  Status Close() override;
+
+ protected:
+  CastingRecordBatchReader();
+
+  Status Init(std::shared_ptr<RecordBatchReader> parent, std::shared_ptr<Schema> schema);
+
+  std::shared_ptr<RecordBatchReader> parent_;
+  std::shared_ptr<Schema> schema_;
+};
+
 }  // namespace py
 }  // namespace arrow
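
To make the two failure modes above concrete, here is a small Python sketch
(illustrative values; it assumes the bindings added in this patch). The
CanCast check in Init() runs eagerly when the wrapping reader is built,
while the nullability check in ReadNext() only fires once an offending
batch is pulled:

    import pyarrow as pa

    batches = [pa.record_batch([pa.array([1, None], type=pa.int64())], names=['a'])]
    reader = pa.RecordBatchReader.from_batches(batches[0].schema, batches)

    # An impossible cast fails immediately, at cast() time.
    try:
        reader.cast(pa.schema([pa.field('a', pa.list_(pa.int32()))]))
    except pa.lib.ArrowTypeError as e:
        print(e)  # roughly: "Field 0 cannot be cast from int64 to list<item: int32>"

    # Casting to a non-nullable field succeeds up front; the error surfaces
    # lazily, only when a batch that actually contains nulls is read.
    casted = reader.cast(pa.schema([pa.field('a', pa.int32(), nullable=False)]))
    try:
        casted.read_all()
    except pa.lib.ArrowInvalid as e:
        print(e)  # roughly: "Can't cast array that contains nulls to non-nullable field"
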
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index 40d22494e6ffb..d7f7895b538e8 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -2742,6 +2742,29 @@ cdef class RecordBatch(_Tabular):

         return pyarrow_wrap_batch(c_batch)

+    def cast(self, Schema target_schema, safe=None, options=None):
+        """
+        Cast batch values to another schema.
+
+        Parameters
+        ----------
+        target_schema : Schema
+            Schema to cast to; the names and order of fields must match.
+        safe : bool, default True
+            Check for overflows or other unsafe conversions.
+        options : CastOptions, default None
+            Additional checks passed by CastOptions.
+
+        Returns
+        -------
+        RecordBatch
+        """
+        # Delegate to the more general Table.cast implementation
+        tbl = Table.from_batches([self])
+        casted_tbl = tbl.cast(target_schema, safe=safe, options=options)
+        casted_batch, = casted_tbl.to_batches()
+        return casted_batch
+
     def _to_pandas(self, options, **kwargs):
         return Table.from_batches([self])._to_pandas(options, **kwargs)
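
A brief usage sketch of the wrapper above (illustrative data, not from the
patch): RecordBatch.cast round-trips through Table.cast and yields a single
batch back:

    import pyarrow as pa

    batch = pa.record_batch([pa.array([1, 2, 3])], names=['a'])  # int64 column

    # Same field names and order, different type: the batch is wrapped in a
    # one-batch Table, cast, and unpacked again.
    target = pa.schema([pa.field('a', pa.float64())])
    assert batch.cast(target).schema == target
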
diff --git a/python/pyarrow/tests/test_cffi.py b/python/pyarrow/tests/test_cffi.py
index ce50fe6a6f81d..f8b2ea15d31ad 100644
--- a/python/pyarrow/tests/test_cffi.py
+++ b/python/pyarrow/tests/test_cffi.py
@@ -633,9 +633,8 @@ def test_roundtrip_reader_capsule(constructor):

     obj = constructor(schema, batches)

-    # TODO: turn this to ValueError once we implement validation.
     bad_schema = pa.schema({'ints': pa.int32()})
-    with pytest.raises(NotImplementedError):
+    with pytest.raises(pa.lib.ArrowTypeError, match="Field 0 cannot be cast"):
         obj.__arrow_c_stream__(bad_schema.__arrow_c_schema__())

     # Can work with matching schema
@@ -647,6 +646,21 @@ def test_roundtrip_reader_capsule(constructor):
         assert batch.equals(expected)


+def test_roundtrip_batch_reader_capsule_requested_schema():
+    batch = make_batch()
+    requested_schema = pa.schema([('ints', pa.list_(pa.int64()))])
+    requested_capsule = requested_schema.__arrow_c_schema__()
+    batch_as_requested = batch.cast(requested_schema)
+
+    capsule = batch.__arrow_c_stream__(requested_capsule)
+    assert PyCapsule_IsValid(capsule, b"arrow_array_stream") == 1
+    imported_reader = pa.RecordBatchReader._import_from_c_capsule(capsule)
+    assert imported_reader.schema == requested_schema
+    assert imported_reader.read_next_batch().equals(batch_as_requested)
+    with pytest.raises(StopIteration):
+        imported_reader.read_next_batch()
+
+
 def test_roundtrip_batch_reader_capsule():
     batch = make_batch()
diff --git a/python/pyarrow/tests/test_ipc.py b/python/pyarrow/tests/test_ipc.py
index 407011d90b734..d38f45b5feff4 100644
--- a/python/pyarrow/tests/test_ipc.py
+++ b/python/pyarrow/tests/test_ipc.py
@@ -1226,10 +1226,15 @@ def __arrow_c_stream__(self, requested_schema=None):
     reader = pa.RecordBatchReader.from_stream(wrapper, schema=data[0].schema)
     assert reader.read_all() == expected

-    # If schema doesn't match, raises NotImplementedError
-    with pytest.raises(NotImplementedError):
+    # Passing a different but castable schema works
+    good_schema = pa.schema([pa.field("a", pa.int32())])
+    reader = pa.RecordBatchReader.from_stream(wrapper, schema=good_schema)
+    assert reader.read_all() == expected.cast(good_schema)
+
+    # If the schema can't be cast to, raises an ArrowTypeError
+    with pytest.raises(pa.lib.ArrowTypeError, match='Field 0 cannot be cast'):
         pa.RecordBatchReader.from_stream(
-            wrapper, schema=pa.schema([pa.field('a', pa.int32())])
+            wrapper, schema=pa.schema([pa.field('a', pa.list_(pa.int32()))])
         )

     # Proper type errors for wrong input
@@ -1238,3 +1243,60 @@ def __arrow_c_stream__(self, requested_schema=None):

     with pytest.raises(TypeError):
         pa.RecordBatchReader.from_stream(expected, schema=data[0])
+
+
+def test_record_batch_reader_cast():
+    schema_src = pa.schema([pa.field('a', pa.int64())])
+    data = [
+        pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']),
+        pa.record_batch([pa.array([4, 5, 6], type=pa.int64())], names=['a']),
+    ]
+    table_src = pa.Table.from_batches(data)
+
+    # Cast to the same type should always work
+    reader = pa.RecordBatchReader.from_batches(schema_src, data)
+    assert reader.cast(schema_src).read_all() == table_src
+
+    # Check non-trivial cast
+    schema_dst = pa.schema([pa.field('a', pa.int32())])
+    reader = pa.RecordBatchReader.from_batches(schema_src, data)
+    assert reader.cast(schema_dst).read_all() == table_src.cast(schema_dst)
+
+    # Check error for field name/length mismatch
+    reader = pa.RecordBatchReader.from_batches(schema_src, data)
+    with pytest.raises(ValueError, match="Target schema's field names"):
+        reader.cast(pa.schema([]))
+
+    # Check error for an impossible cast in the call to .cast()
+    reader = pa.RecordBatchReader.from_batches(schema_src, data)
+    with pytest.raises(pa.lib.ArrowTypeError, match='Field 0 cannot be cast'):
+        reader.cast(pa.schema([pa.field('a', pa.list_(pa.int32()))]))
+
+
+def test_record_batch_reader_cast_nulls():
+    schema_src = pa.schema([pa.field('a', pa.int64())])
+    data_with_nulls = [
+        pa.record_batch([pa.array([1, 2, None], type=pa.int64())], names=['a']),
+    ]
+    data_without_nulls = [
+        pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a']),
+    ]
+    table_with_nulls = pa.Table.from_batches(data_with_nulls)
+    table_without_nulls = pa.Table.from_batches(data_without_nulls)
+
+    # Cast to a nullable destination should work
+    reader = pa.RecordBatchReader.from_batches(schema_src, data_with_nulls)
+    schema_dst = pa.schema([pa.field('a', pa.int32())])
+    assert reader.cast(schema_dst).read_all() == table_with_nulls.cast(schema_dst)
+
+    # Cast to a non-nullable destination should work if there are no nulls
+    reader = pa.RecordBatchReader.from_batches(schema_src, data_without_nulls)
+    schema_dst = pa.schema([pa.field('a', pa.int32(), nullable=False)])
+    assert reader.cast(schema_dst).read_all() == table_without_nulls.cast(schema_dst)
+
+    # Cast to a non-nullable destination should error if there are nulls,
+    # but only when the offending batch is pulled
+    reader = pa.RecordBatchReader.from_batches(schema_src, data_with_nulls)
+    casted_reader = reader.cast(schema_dst)
+    with pytest.raises(pa.lib.ArrowInvalid, match="Can't cast array"):
+        casted_reader.read_all()
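
For context, a minimal sketch of the from_stream path exercised above
(StreamWrapper is a hypothetical stand-in for any object exporting
__arrow_c_stream__; data is illustrative). A castable schema now yields a
casting reader instead of raising NotImplementedError:

    import pyarrow as pa

    class StreamWrapper:
        """Hypothetical stand-in for an object exporting an Arrow C stream."""

        def __init__(self, reader):
            self._reader = reader

        def __arrow_c_stream__(self, requested_schema=None):
            # Forward the requested-schema capsule to the wrapped reader.
            return self._reader.__arrow_c_stream__(requested_schema)

    batch = pa.record_batch([pa.array([1, 2, 3], type=pa.int64())], names=['a'])
    reader = pa.RecordBatchReader.from_batches(batch.schema, [batch])

    # from_stream passes the requested schema through the capsule protocol,
    # so the producer casts each batch on the way out.
    casted = pa.RecordBatchReader.from_stream(
        StreamWrapper(reader), schema=pa.schema([pa.field('a', pa.int32())]))
    assert casted.read_all() == pa.table({'a': pa.array([1, 2, 3], type=pa.int32())})
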
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index d6def54570581..f0fd5518de067 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -635,9 +635,18 @@ def __arrow_c_stream__(self, requested_schema=None):
     result = pa.table(wrapper, schema=data[0].schema)
     assert result == expected

-    # If schema doesn't match, raises NotImplementedError
-    with pytest.raises(NotImplementedError):
-        pa.table(wrapper, schema=pa.schema([pa.field('a', pa.int32())]))
+    # Passing a different schema will cast
+    good_schema = pa.schema([pa.field('a', pa.int32())])
+    result = pa.table(wrapper, schema=good_schema)
+    assert result == expected.cast(good_schema)
+
+    # If the schema can't be cast to, raises an ArrowTypeError
+    with pytest.raises(
+        pa.lib.ArrowTypeError, match="Field 0 cannot be cast"
+    ):
+        pa.table(
+            wrapper, schema=pa.schema([pa.field('a', pa.list_(pa.int32()))])
+        )


 def test_recordbatch_itercolumns():
@@ -2620,6 +2629,25 @@ def test_record_batch_sort():
     assert sorted_rb_dict["c"] == ["foobar", "bar", "foo", "car"]


+def test_record_batch_cast():
+    rb = pa.RecordBatch.from_arrays([
+        pa.array([None, 1]),
+        pa.array([False, True])
+    ], names=["a", "b"])
+    new_schema = pa.schema([pa.field("a", "int64", nullable=True),
+                            pa.field("b", "bool", nullable=False)])
+
+    assert rb.cast(new_schema).schema == new_schema
+
+    # Casting a nullable field to a non-nullable one is invalid
+    rb = pa.RecordBatch.from_arrays([
+        pa.array([None, 1]),
+        pa.array([None, True])
+    ], names=["a", "b"])
+    with pytest.raises(ValueError):
+        rb.cast(new_schema)
+
+
 @pytest.mark.parametrize("constructor", [pa.table, pa.record_batch])
 def test_numpy_asarray(constructor):
     table = constructor([[1, 2, 3], [4.0, 5.0, 6.0]], names=["a", "b"])