From 50f161eafb4062ffa13e3399c49d8f98e8dbfb6d Mon Sep 17 00:00:00 2001
From: Will Jones
Date: Mon, 7 Aug 2023 15:16:23 -0700
Subject: [PATCH] fix ownership of c stream error (#4660)

* fix ownership of c stream error

* add pyarrow integration test
---
 arrow-pyarrow-integration-testing/src/lib.rs | 13 ++++
 .../tests/test_sql.py                        | 12 +++
 arrow/src/ffi_stream.rs                      | 75 ++++++++++++++-----
 3 files changed, 81 insertions(+), 19 deletions(-)

diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs
index 89395bd2ed08..adcec769f247 100644
--- a/arrow-pyarrow-integration-testing/src/lib.rs
+++ b/arrow-pyarrow-integration-testing/src/lib.rs
@@ -21,6 +21,7 @@ use std::sync::Arc;
 
 use arrow::array::new_empty_array;
+use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 use pyo3::wrap_pyfunction;
 
@@ -140,6 +141,17 @@ fn round_trip_record_batch_reader(
     Ok(obj)
 }
 
+#[pyfunction]
+fn reader_return_errors(obj: PyArrowType<ArrowArrayStreamReader>) -> PyResult<()> {
+    // This makes sure we can correctly consume a RBR and return the error,
+    // ensuring the error can live beyond the lifetime of the RBR.
+    let batches = obj.0.collect::<Result<Vec<RecordBatch>, ArrowError>>();
+    match batches {
+        Ok(_) => Ok(()),
+        Err(err) => Err(PyValueError::new_err(err.to_string())),
+    }
+}
+
 #[pymodule]
 fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> {
     m.add_wrapped(wrap_pyfunction!(double))?;
@@ -153,5 +165,6 @@ fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()>
     m.add_wrapped(wrap_pyfunction!(round_trip_array))?;
     m.add_wrapped(wrap_pyfunction!(round_trip_record_batch))?;
     m.add_wrapped(wrap_pyfunction!(round_trip_record_batch_reader))?;
+    m.add_wrapped(wrap_pyfunction!(reader_return_errors))?;
     Ok(())
 }
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index a7c6b34a4474..92782b9ed473 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -409,6 +409,18 @@ def test_record_batch_reader():
     got_batches = list(b)
     assert got_batches == batches
 
+def test_record_batch_reader_error():
+    schema = pa.schema([('ints', pa.list_(pa.int32()))])
+
+    def iter_batches():
+        yield pa.record_batch([[[1], [2, 42]]], schema)
+        raise ValueError("test error")
+
+    reader = pa.RecordBatchReader.from_batches(schema, iter_batches())
+
+    with pytest.raises(ValueError, match="test error"):
+        rust.reader_return_errors(reader)
+
 def test_reject_other_classes():
     # Arbitrary type that is not a PyArrow type
     not_pyarrow = ["hello"]
diff --git a/arrow/src/ffi_stream.rs b/arrow/src/ffi_stream.rs
index 7d6689a89058..a9d2e8ab6bf2 100644
--- a/arrow/src/ffi_stream.rs
+++ b/arrow/src/ffi_stream.rs
@@ -54,6 +54,7 @@
 //! }
 //! ```
 
+use std::ffi::CStr;
 use std::ptr::addr_of;
 use std::{
     convert::TryFrom,
@@ -120,7 +121,7 @@ unsafe extern "C" fn release_stream(stream: *mut FFI_ArrowArrayStream) {
 
 struct StreamPrivateData {
     batch_reader: Box<dyn RecordBatchReader + Send>,
-    last_error: String,
+    last_error: Option<CString>,
 }
 
 // The callback used to get array schema
@@ -142,8 +143,12 @@ unsafe extern "C" fn get_next(
 // The callback used to get the error from last operation on the `FFI_ArrowArrayStream`
 unsafe extern "C" fn get_last_error(stream: *mut FFI_ArrowArrayStream) -> *const c_char {
     let mut ffi_stream = ExportedArrayStream { stream };
-    let last_error = ffi_stream.get_last_error();
-    CString::new(last_error.as_str()).unwrap().into_raw()
+    // The consumer should not take ownership of this string, we should return
+    // a const pointer to it.
+    match ffi_stream.get_last_error() {
+        Some(err_string) => err_string.as_ptr(),
+        None => std::ptr::null(),
+    }
 }
 
 impl Drop for FFI_ArrowArrayStream {
@@ -160,7 +165,7 @@ impl FFI_ArrowArrayStream {
     pub fn new(batch_reader: Box<dyn RecordBatchReader + Send>) -> Self {
         let private_data = Box::new(StreamPrivateData {
             batch_reader,
-            last_error: String::new(),
+            last_error: None,
         });
 
         Self {
@@ -206,7 +211,10 @@ impl ExportedArrayStream {
                 0
             }
             Err(ref err) => {
-                private_data.last_error = err.to_string();
+                private_data.last_error = Some(
+                    CString::new(err.to_string())
+                        .expect("Error string has a null byte in it."),
+                );
                 get_error_code(err)
             }
         }
@@ -231,15 +239,18 @@ impl ExportedArrayStream {
                     0
                 } else {
                     let err = &next_batch.unwrap_err();
-                    private_data.last_error = err.to_string();
+                    private_data.last_error = Some(
+                        CString::new(err.to_string())
+                            .expect("Error string has a null byte in it."),
+                    );
                     get_error_code(err)
                 }
            }
        }
    }
 
-    pub fn get_last_error(&mut self) -> &String {
-        &self.get_private_data().last_error
+    pub fn get_last_error(&mut self) -> Option<&CString> {
+        self.get_private_data().last_error.as_ref()
     }
 }
 
@@ -312,19 +323,15 @@ impl ArrowArrayStreamReader {
 
     /// Get the last error from `ArrowArrayStreamReader`
     fn get_stream_last_error(&mut self) -> Option<String> {
-        self.stream.get_last_error?;
-
-        let error_str = unsafe {
-            let c_str =
-                self.stream.get_last_error.unwrap()(&mut self.stream) as *mut c_char;
-            CString::from_raw(c_str).into_string()
-        };
+        let get_last_error = self.stream.get_last_error?;
 
-        if let Err(err) = error_str {
-            Some(err.to_string())
-        } else {
-            Some(error_str.unwrap())
+        let error_str = unsafe { get_last_error(&mut self.stream) };
+        if error_str.is_null() {
+            return None;
         }
+
+        let error_str = unsafe { CStr::from_ptr(error_str) };
+        Some(error_str.to_string_lossy().to_string())
     }
 }
 
@@ -381,6 +388,8 @@ pub unsafe fn export_reader_into_raw(
 
 #[cfg(test)]
 mod tests {
+    use arrow_schema::DataType;
+
     use super::*;
     use crate::array::Int32Array;
@@ -503,4 +512,32 @@ mod tests {
         _test_round_trip_import(vec![array.clone(), array.clone(), array])
     }
+
+    #[test]
+    fn test_error_import() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));
+
+        let iter =
+            Box::new(vec![Err(ArrowError::MemoryError("".to_string()))].into_iter());
+
+        let reader = TestRecordBatchReader::new(schema.clone(), iter);
+
+        // Import through `FFI_ArrowArrayStream` as `ArrowArrayStreamReader`
+        let stream = FFI_ArrowArrayStream::new(reader);
+        let stream_reader = ArrowArrayStreamReader::try_new(stream).unwrap();
+
+        let imported_schema = stream_reader.schema();
+        assert_eq!(imported_schema, schema);
+
+        let mut produced_batches = vec![];
+        for batch in stream_reader {
+            produced_batches.push(batch);
+        }
+
+        // The results should outlive the lifetime of the stream itself.
+        assert_eq!(produced_batches.len(), 1);
+        assert!(produced_batches[0].is_err());
+
+        Ok(())
+    }
 }
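
Note: to make the ownership contract concrete, here is a small standalone sketch (not part of the patch) that drives the raw stream callbacks the way an FFI consumer would. It assumes the `arrow` crate as of this change; `RecordBatchIterator` stands in for any failing `RecordBatchReader`, and the "out of memory" message is illustrative only.

```rust
use std::ffi::CStr;
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::ffi::FFI_ArrowArray;
use arrow::ffi_stream::FFI_ArrowArrayStream;
use arrow::record_batch::RecordBatchIterator;

fn main() {
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)]));

    // A reader whose first `next()` yields an error.
    let reader = RecordBatchIterator::new(
        vec![Err(ArrowError::MemoryError("out of memory".to_string()))],
        schema,
    );
    let mut stream = FFI_ArrowArrayStream::new(Box::new(reader));

    unsafe {
        // `get_next` fails, so it returns a non-zero errno-compatible code.
        let mut array = FFI_ArrowArray::empty();
        let code = stream.get_next.unwrap()(&mut stream, &mut array);
        assert_ne!(code, 0);

        // The pointer returned by `get_last_error` is *borrowed* from the
        // stream's private data: copy the message out, and never free it or
        // reconstruct a CString from it.
        let err_ptr = stream.get_last_error.unwrap()(&mut stream);
        assert!(!err_ptr.is_null());
        let msg = CStr::from_ptr(err_ptr).to_string_lossy();
        assert!(msg.contains("out of memory"));
    }
}
```

This is the crux of the fix: the C stream interface only guarantees the error string is valid until the next stream call, and the consumer must not free it. The old export leaked a fresh `CString::into_raw` allocation on every call (or required the consumer to free a buffer it did not own), and the old import called `CString::from_raw` on a pointer it did not own. Storing `Option<CString>` in the stream's private data and returning `as_ptr()` ties the string's lifetime to the stream, and the import side now just copies out of the borrowed `CStr`.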