From cbb4aa13b6ee7473c6ef607aeb953dba46a5dbb0 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Mon, 7 Aug 2023 13:01:22 -0700 Subject: [PATCH] fix ownership of c stream error --- arrow/src/ffi_stream.rs | 75 ++++++++++++++++++++++++++++++----------- 1 file changed, 56 insertions(+), 19 deletions(-) diff --git a/arrow/src/ffi_stream.rs b/arrow/src/ffi_stream.rs index 7d6689a89058..a9d2e8ab6bf2 100644 --- a/arrow/src/ffi_stream.rs +++ b/arrow/src/ffi_stream.rs @@ -54,6 +54,7 @@ //! } //! ``` +use std::ffi::CStr; use std::ptr::addr_of; use std::{ convert::TryFrom, @@ -120,7 +121,7 @@ unsafe extern "C" fn release_stream(stream: *mut FFI_ArrowArrayStream) { struct StreamPrivateData { batch_reader: Box, - last_error: String, + last_error: Option, } // The callback used to get array schema @@ -142,8 +143,12 @@ unsafe extern "C" fn get_next( // The callback used to get the error from last operation on the `FFI_ArrowArrayStream` unsafe extern "C" fn get_last_error(stream: *mut FFI_ArrowArrayStream) -> *const c_char { let mut ffi_stream = ExportedArrayStream { stream }; - let last_error = ffi_stream.get_last_error(); - CString::new(last_error.as_str()).unwrap().into_raw() + // The consumer should not take ownership of this string, we should return + // a const pointer to it. + match ffi_stream.get_last_error() { + Some(err_string) => err_string.as_ptr(), + None => std::ptr::null(), + } } impl Drop for FFI_ArrowArrayStream { @@ -160,7 +165,7 @@ impl FFI_ArrowArrayStream { pub fn new(batch_reader: Box) -> Self { let private_data = Box::new(StreamPrivateData { batch_reader, - last_error: String::new(), + last_error: None, }); Self { @@ -206,7 +211,10 @@ impl ExportedArrayStream { 0 } Err(ref err) => { - private_data.last_error = err.to_string(); + private_data.last_error = Some( + CString::new(err.to_string()) + .expect("Error string has a null byte in it."), + ); get_error_code(err) } } @@ -231,15 +239,18 @@ impl ExportedArrayStream { 0 } else { let err = &next_batch.unwrap_err(); - private_data.last_error = err.to_string(); + private_data.last_error = Some( + CString::new(err.to_string()) + .expect("Error string has a null byte in it."), + ); get_error_code(err) } } } } - pub fn get_last_error(&mut self) -> &String { - &self.get_private_data().last_error + pub fn get_last_error(&mut self) -> Option<&CString> { + self.get_private_data().last_error.as_ref() } } @@ -312,19 +323,15 @@ impl ArrowArrayStreamReader { /// Get the last error from `ArrowArrayStreamReader` fn get_stream_last_error(&mut self) -> Option { - self.stream.get_last_error?; - - let error_str = unsafe { - let c_str = - self.stream.get_last_error.unwrap()(&mut self.stream) as *mut c_char; - CString::from_raw(c_str).into_string() - }; + let get_last_error = self.stream.get_last_error?; - if let Err(err) = error_str { - Some(err.to_string()) - } else { - Some(error_str.unwrap()) + let error_str = unsafe { get_last_error(&mut self.stream) }; + if error_str.is_null() { + return None; } + + let error_str = unsafe { CStr::from_ptr(error_str) }; + Some(error_str.to_string_lossy().to_string()) } } @@ -381,6 +388,8 @@ pub unsafe fn export_reader_into_raw( #[cfg(test)] mod tests { + use arrow_schema::DataType; + use super::*; use crate::array::Int32Array; @@ -503,4 +512,32 @@ mod tests { _test_round_trip_import(vec![array.clone(), array.clone(), array]) } + + #[test] + fn test_error_import() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + + let iter = + Box::new(vec![Err(ArrowError::MemoryError("".to_string()))].into_iter()); + + let reader = TestRecordBatchReader::new(schema.clone(), iter); + + // Import through `FFI_ArrowArrayStream` as `ArrowArrayStreamReader` + let stream = FFI_ArrowArrayStream::new(reader); + let stream_reader = ArrowArrayStreamReader::try_new(stream).unwrap(); + + let imported_schema = stream_reader.schema(); + assert_eq!(imported_schema, schema); + + let mut produced_batches = vec![]; + for batch in stream_reader { + produced_batches.push(batch); + } + + // The results should outlive the lifetime of the stream itself. + assert_eq!(produced_batches.len(), 1); + assert!(produced_batches[0].is_err()); + + Ok(()) + } }