Skip to content

Commit

Permalink
fix ownership of c stream error (#4660)
Browse files Browse the repository at this point in the history
* fix ownership of c stream error

* add pyarrow integration test
  • Loading branch information
wjones127 authored Aug 7, 2023
1 parent f16ceed commit 50f161e
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 19 deletions.
13 changes: 13 additions & 0 deletions arrow-pyarrow-integration-testing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
use std::sync::Arc;

use arrow::array::new_empty_array;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;

Expand Down Expand Up @@ -140,6 +141,17 @@ fn round_trip_record_batch_reader(
Ok(obj)
}

#[pyfunction]
fn reader_return_errors(obj: PyArrowType<ArrowArrayStreamReader>) -> PyResult<()> {
    // Drain the whole reader by value. Collecting into a `Result` stops at the
    // first failing batch, and the resulting `ArrowError` is converted only
    // after the reader has been consumed, so the error must be able to outlive
    // the stream that produced it.
    obj.0
        .collect::<Result<Vec<RecordBatch>, ArrowError>>()
        .map(|_| ())
        .map_err(|err| PyValueError::new_err(err.to_string()))
}

#[pymodule]
fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(double))?;
Expand All @@ -153,5 +165,6 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()>
m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
m.add_wrapped(wrap_pyfunction!(reader_return_errors))?;
Ok(())
}
12 changes: 12 additions & 0 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,18 @@ def test_record_batch_reader():
got_batches = list(b)
assert got_batches == batches

def test_record_batch_reader_error():
    """An error raised while producing batches must reach the caller as a ValueError."""
    schema = pa.schema([('ints', pa.list_(pa.int32()))])
    ints_column = [[1], [2, 42]]

    def failing_batches():
        # Emit one valid batch first so the failure happens mid-stream.
        yield pa.record_batch([ints_column], schema)
        raise ValueError("test error")

    reader = pa.RecordBatchReader.from_batches(schema, failing_batches())

    with pytest.raises(ValueError, match="test error"):
        rust.reader_return_errors(reader)

def test_reject_other_classes():
# Arbitrary type that is not a PyArrow type
not_pyarrow = ["hello"]
Expand Down
75 changes: 56 additions & 19 deletions arrow/src/ffi_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
//! }
//! ```

use std::ffi::CStr;
use std::ptr::addr_of;
use std::{
convert::TryFrom,
Expand Down Expand Up @@ -120,7 +121,7 @@ unsafe extern "C" fn release_stream(stream: *mut FFI_ArrowArrayStream) {

/// Private state boxed by `FFI_ArrowArrayStream::new`: the wrapped reader
/// plus the most recent error recorded for it.
// NOTE(review): this span is a rendered diff without +/- markers, so `last_error`
// is listed twice. `Option<CString>` looks like the replacement for `String`
// (the constructor initializes the field to `None`) — confirm against the repository.
struct StreamPrivateData {
// The reader whose batches are exported through the stream.
batch_reader: Box<dyn RecordBatchReader + Send>,
last_error: String,
// `None` until an error has been recorded.
last_error: Option<CString>,
}

// The callback used to get array schema
Expand All @@ -142,8 +143,12 @@ unsafe extern "C" fn get_next(
// The callback used to get the error from the last operation on the `FFI_ArrowArrayStream`.
// In the `match` form below it returns a pointer into the stored error string,
// or a null pointer when no error has been recorded.
// NOTE(review): this span is a rendered diff without +/- markers. The two lines ending in
// `into_raw()` appear to be the previous implementation (which handed a freshly allocated
// string to the caller), and the `match` appears to be its replacement — confirm against
// the repository.
unsafe extern "C" fn get_last_error(stream: *mut FFI_ArrowArrayStream) -> *const c_char {
let mut ffi_stream = ExportedArrayStream { stream };
let last_error = ffi_stream.get_last_error();
CString::new(last_error.as_str()).unwrap().into_raw()
// The consumer should not take ownership of this string, we should return
// a const pointer to it.
match ffi_stream.get_last_error() {
// `as_ptr` borrows: the `CString` stays owned by the stream's private data.
Some(err_string) => err_string.as_ptr(),
None => std::ptr::null(),
}
}

impl Drop for FFI_ArrowArrayStream {
Expand All @@ -160,7 +165,7 @@ impl FFI_ArrowArrayStream {
pub fn new(batch_reader: Box<dyn RecordBatchReader + Send>) -> Self {
let private_data = Box::new(StreamPrivateData {
batch_reader,
last_error: String::new(),
last_error: None,
});

Self {
Expand Down Expand Up @@ -206,7 +211,10 @@ impl ExportedArrayStream {
0
}
Err(ref err) => {
private_data.last_error = err.to_string();
private_data.last_error = Some(
CString::new(err.to_string())
.expect("Error string has a null byte in it."),
);
get_error_code(err)
}
}
Expand All @@ -231,15 +239,18 @@ impl ExportedArrayStream {
0
} else {
let err = &next_batch.unwrap_err();
private_data.last_error = err.to_string();
private_data.last_error = Some(
CString::new(err.to_string())
.expect("Error string has a null byte in it."),
);
get_error_code(err)
}
}
}
}

/// Returns the last error recorded in the stream's private data.
// NOTE(review): this span is a rendered diff without +/- markers, so two signatures
// are shown. The `&String` version appears to be the previous one and the
// `Option<&CString>` version its replacement — confirm against the repository.
pub fn get_last_error(&mut self) -> &String {
&self.get_private_data().last_error
pub fn get_last_error(&mut self) -> Option<&CString> {
// Borrow the stored error without moving it out of the private data.
self.get_private_data().last_error.as_ref()
}
}

Expand Down Expand Up @@ -312,19 +323,15 @@ impl ArrowArrayStreamReader {

/// Get the last error from `ArrowArrayStreamReader`.
///
/// Returns `None` when the stream has no `get_last_error` callback; the
/// null-pointer check below also returns `None` when the callback yields null.
// NOTE(review): this span is a rendered diff without +/- markers, so two
// implementations are interleaved. The `CString::from_raw` / `if let Err` lines
// appear to be the previous version and the `CStr::from_ptr` lines its
// replacement — confirm against the repository.
fn get_stream_last_error(&mut self) -> Option<String> {
self.stream.get_last_error?;

// `CString::from_raw` takes ownership of the returned pointer.
let error_str = unsafe {
let c_str =
self.stream.get_last_error.unwrap()(&mut self.stream) as *mut c_char;
CString::from_raw(c_str).into_string()
};
// Bail out with `None` if the producer did not supply the callback.
let get_last_error = self.stream.get_last_error?;

if let Err(err) = error_str {
Some(err.to_string())
} else {
Some(error_str.unwrap())
let error_str = unsafe { get_last_error(&mut self.stream) };
if error_str.is_null() {
return None;
}

// `CStr::from_ptr` only borrows the bytes; they are copied into an owned
// `String` (lossily, if not valid UTF-8) before returning.
let error_str = unsafe { CStr::from_ptr(error_str) };
Some(error_str.to_string_lossy().to_string())
}
}

Expand Down Expand Up @@ -381,6 +388,8 @@ pub unsafe fn export_reader_into_raw(

#[cfg(test)]
mod tests {
use arrow_schema::DataType;

use super::*;

use crate::array::Int32Array;
Expand Down Expand Up @@ -503,4 +512,32 @@ mod tests {

_test_round_trip_import(vec![array.clone(), array.clone(), array])
}

#[test]
fn test_error_import() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));

    // A source reader whose only item is an error.
    let failing_iter =
        Box::new(vec![Err(ArrowError::MemoryError("".to_string()))].into_iter());
    let source = TestRecordBatchReader::new(schema.clone(), failing_iter);

    // Export as `FFI_ArrowArrayStream`, then import it back as an
    // `ArrowArrayStreamReader`.
    let exported = FFI_ArrowArrayStream::new(source);
    let stream_reader = ArrowArrayStreamReader::try_new(exported).unwrap();

    let imported_schema = stream_reader.schema();
    assert_eq!(imported_schema, schema);

    // Draining consumes the reader, so the collected results (including the
    // error) have to outlive the stream they came from.
    let produced_batches: Vec<_> = stream_reader.collect();

    assert_eq!(produced_batches.len(), 1);
    assert!(produced_batches[0].is_err());

    Ok(())
}
}

0 comments on commit 50f161e

Please sign in to comment.