diff --git a/python/pyarrow/_csv.pyx b/python/pyarrow/_csv.pyx index e532d8d8ab22a..508488c0c3b3c 100644 --- a/python/pyarrow/_csv.pyx +++ b/python/pyarrow/_csv.pyx @@ -26,8 +26,7 @@ from collections.abc import Mapping from pyarrow.includes.common cimport * from pyarrow.includes.libarrow cimport * -from pyarrow.includes.libarrow_python cimport (MakeInvalidRowHandler, - PyInvalidRowCallback) +from pyarrow.includes.libarrow_python cimport * from pyarrow.lib cimport (check_status, Field, MemoryPool, Schema, RecordBatchReader, ensure_type, maybe_unbox_memory_pool, get_input_stream, @@ -1251,7 +1250,7 @@ def read_csv(input_file, read_options=None, parse_options=None, CCSVParseOptions c_parse_options CCSVConvertOptions c_convert_options CIOContext io_context - shared_ptr[CCSVReader] reader + SharedPtrNoGIL[CCSVReader] reader shared_ptr[CTable] table _get_reader(input_file, read_options, &stream) diff --git a/python/pyarrow/_dataset.pxd b/python/pyarrow/_dataset.pxd index 210e5558009ec..bee9fc1f0987a 100644 --- a/python/pyarrow/_dataset.pxd +++ b/python/pyarrow/_dataset.pxd @@ -31,7 +31,7 @@ cdef CFileSource _make_file_source(object file, FileSystem filesystem=*) cdef class DatasetFactory(_Weakrefable): cdef: - shared_ptr[CDatasetFactory] wrapped + SharedPtrNoGIL[CDatasetFactory] wrapped CDatasetFactory* factory cdef init(self, const shared_ptr[CDatasetFactory]& sp) @@ -45,7 +45,7 @@ cdef class DatasetFactory(_Weakrefable): cdef class Dataset(_Weakrefable): cdef: - shared_ptr[CDataset] wrapped + SharedPtrNoGIL[CDataset] wrapped CDataset* dataset public dict _scan_options @@ -59,7 +59,7 @@ cdef class Dataset(_Weakrefable): cdef class Scanner(_Weakrefable): cdef: - shared_ptr[CScanner] wrapped + SharedPtrNoGIL[CScanner] wrapped CScanner* scanner cdef void init(self, const shared_ptr[CScanner]& sp) @@ -122,7 +122,7 @@ cdef class FileWriteOptions(_Weakrefable): cdef class Fragment(_Weakrefable): cdef: - shared_ptr[CFragment] wrapped + SharedPtrNoGIL[CFragment] wrapped CFragment* fragment cdef void init(self, const shared_ptr[CFragment]& sp) diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index 48ee676915311..d7d69965d000a 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -3227,7 +3227,7 @@ cdef class RecordBatchIterator(_Weakrefable): object iterator_owner # Iterator is a non-POD type and Cython uses offsetof, leading # to a compiler warning unless wrapped like so - shared_ptr[CRecordBatchIterator] iterator + SharedPtrNoGIL[CRecordBatchIterator] iterator def __init__(self): _forbid_instantiation(self.__class__, subclasses_instead=False) @@ -3273,7 +3273,7 @@ cdef class TaggedRecordBatchIterator(_Weakrefable): """An iterator over a sequence of record batches with fragments.""" cdef: object iterator_owner - shared_ptr[CTaggedRecordBatchIterator] iterator + SharedPtrNoGIL[CTaggedRecordBatchIterator] iterator def __init__(self): _forbid_instantiation(self.__class__, subclasses_instead=False) diff --git a/python/pyarrow/_parquet.pyx b/python/pyarrow/_parquet.pyx index 48091367b2ff8..089ed7c75ce58 100644 --- a/python/pyarrow/_parquet.pyx +++ b/python/pyarrow/_parquet.pyx @@ -24,6 +24,7 @@ import warnings from cython.operator cimport dereference as deref from pyarrow.includes.common cimport * from pyarrow.includes.libarrow cimport * +from pyarrow.includes.libarrow_python cimport * from pyarrow.lib cimport (_Weakrefable, Buffer, Schema, check_status, MemoryPool, maybe_unbox_memory_pool, @@ -1165,7 +1166,7 @@ cdef class ParquetReader(_Weakrefable): cdef: object source CMemoryPool* pool - unique_ptr[FileReader] reader + UniquePtrNoGIL[FileReader] reader FileMetaData _metadata shared_ptr[CRandomAccessFile] rd_handle @@ -1334,7 +1335,7 @@ cdef class ParquetReader(_Weakrefable): vector[int] c_row_groups vector[int] c_column_indices shared_ptr[CRecordBatch] record_batch - unique_ptr[CRecordBatchReader] recordbatchreader + UniquePtrNoGIL[CRecordBatchReader] recordbatchreader self.set_batch_size(batch_size) @@ -1366,7 +1367,6 @@ cdef class ParquetReader(_Weakrefable): check_status( recordbatchreader.get().ReadNext(&record_batch) ) - if record_batch.get() == NULL: break diff --git a/python/pyarrow/includes/libarrow_python.pxd b/python/pyarrow/includes/libarrow_python.pxd index 4d109fc660e08..e1df2fe61d8d2 100644 --- a/python/pyarrow/includes/libarrow_python.pxd +++ b/python/pyarrow/includes/libarrow_python.pxd @@ -261,6 +261,13 @@ cdef extern from "arrow/python/common.h" namespace "arrow::py": void RestorePyError(const CStatus& status) except * +cdef extern from "arrow/python/common.h" namespace "arrow::py" nogil: + cdef cppclass SharedPtrNoGIL[T](shared_ptr[T]): + pass + cdef cppclass UniquePtrNoGIL[T,DELETER=*](unique_ptr[T,DELETER]): + pass + + cdef extern from "arrow/python/inference.h" namespace "arrow::py": c_bool IsPyBool(object o) c_bool IsPyInt(object o) diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index fcb9eb729ef04..5d20a4f8b72cb 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -977,7 +977,7 @@ cdef _wrap_record_batch_with_metadata(CRecordBatchWithMetadata c): cdef class _RecordBatchFileReader(_Weakrefable): cdef: - shared_ptr[CRecordBatchFileReader] reader + SharedPtrNoGIL[CRecordBatchFileReader] reader shared_ptr[CRandomAccessFile] file CIpcReadOptions options diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 63ebe6aea8233..ae197eca1ca6b 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -552,12 +552,12 @@ cdef class CompressedOutputStream(NativeFile): cdef class _CRecordBatchWriter(_Weakrefable): cdef: - shared_ptr[CRecordBatchWriter] writer + SharedPtrNoGIL[CRecordBatchWriter] writer cdef class RecordBatchReader(_Weakrefable): cdef: - shared_ptr[CRecordBatchReader] reader + SharedPtrNoGIL[CRecordBatchReader] reader cdef class Codec(_Weakrefable): diff --git a/python/pyarrow/src/arrow/python/common.h b/python/pyarrow/src/arrow/python/common.h index e36c0834fd424..0d9c46a997ba3 100644 --- a/python/pyarrow/src/arrow/python/common.h +++ b/python/pyarrow/src/arrow/python/common.h @@ -235,6 +235,50 @@ class ARROW_PYTHON_EXPORT OwnedRefNoGIL : public OwnedRef { } }; +template +class SharedPtrNoGIL : public std::shared_ptr { + using Base = std::shared_ptr; + public: + template + SharedPtrNoGIL(Args&&... args) + : Base(std::forward(args)...) {} + + ~SharedPtrNoGIL() { + if (Py_IsInitialized() && PyGILState_Check()) { + PyReleaseGIL release; + Base::reset(); + } + } + + template + SharedPtrNoGIL& operator=(V&& v) { + Base::operator=(std::forward(v)); + return *this; + } +}; + +template +class UniquePtrNoGIL : public std::unique_ptr { + using Base = std::unique_ptr; + public: + template + UniquePtrNoGIL(Args&&... args) + : Base(std::forward(args)...) {} + + ~UniquePtrNoGIL() { + if (Py_IsInitialized() && PyGILState_Check()) { + PyReleaseGIL release; + Base::reset(); + } + } + + template + UniquePtrNoGIL& operator=(V&& v) { + Base::operator=(std::forward(v)); + return *this; + } +}; + template struct BoundFunction;