From c22df0503f39d348481e43c0a1274dca0b16209f Mon Sep 17 00:00:00 2001 From: Antoine Pitrou Date: Mon, 5 Aug 2024 14:27:52 +0200 Subject: [PATCH] GH-43487: [Python] Sanitize Python reference handling in UDF implementation 1. Remove spurious increfs (the function object is already incref'ed at an upper level) 2. Add unit test with an ephemeral Python function object 3. Streamline and improve Python reference handling --- python/pyarrow/src/arrow/python/udf.cc | 149 +++++++++---------------- python/pyarrow/tests/test_udf.py | 31 +++++ 2 files changed, 85 insertions(+), 95 deletions(-) diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index b6a862af8ca07..2c1e97c3ea03d 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -43,35 +43,18 @@ namespace py { namespace { struct PythonUdfKernelState : public compute::KernelState { + // NOTE: this KernelState constructor doesn't require the GIL. + // If it did, the corresponding KernelInit::operator() should be wrapped + // within SafeCallIntoPython (GH-43487). explicit PythonUdfKernelState(std::shared_ptr function) - : function(function) { - Py_INCREF(function->obj()); - } - - // function needs to be destroyed at process exit - // and Python may no longer be initialized. - ~PythonUdfKernelState() { - if (Py_IsFinalizing()) { - function->detach(); - } - } + : function(std::move(function)) {} std::shared_ptr function; }; struct PythonUdfKernelInit { explicit PythonUdfKernelInit(std::shared_ptr function) - : function(function) { - Py_INCREF(function->obj()); - } - - // function needs to be destroyed at process exit - // and Python may no longer be initialized. - ~PythonUdfKernelInit() { - if (Py_IsFinalizing()) { - function->detach(); - } - } + : function(std::move(function)) {} Result> operator()( compute::KernelContext*, const compute::KernelInitArgs&) { @@ -94,68 +77,56 @@ struct HashUdfAggregator : public compute::KernelState { virtual Status Finalize(KernelContext* ctx, Datum* out) = 0; }; -arrow::Status AggregateUdfConsume(compute::KernelContext* ctx, - const compute::ExecSpan& batch) { +Status AggregateUdfConsume(compute::KernelContext* ctx, const compute::ExecSpan& batch) { return checked_cast(ctx->state())->Consume(ctx, batch); } -arrow::Status AggregateUdfMerge(compute::KernelContext* ctx, compute::KernelState&& src, - compute::KernelState* dst) { +Status AggregateUdfMerge(compute::KernelContext* ctx, compute::KernelState&& src, + compute::KernelState* dst) { return checked_cast(dst)->MergeFrom(ctx, std::move(src)); } -arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* out) { +Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* out) { return checked_cast(ctx->state())->Finalize(ctx, out); } -arrow::Status HashAggregateUdfResize(KernelContext* ctx, int64_t size) { +Status HashAggregateUdfResize(KernelContext* ctx, int64_t size) { return checked_cast(ctx->state())->Resize(ctx, size); } -arrow::Status HashAggregateUdfConsume(KernelContext* ctx, const ExecSpan& batch) { +Status HashAggregateUdfConsume(KernelContext* ctx, const ExecSpan& batch) { return checked_cast(ctx->state())->Consume(ctx, batch); } -arrow::Status HashAggregateUdfMerge(KernelContext* ctx, KernelState&& src, - const ArrayData& group_id_mapping) { +Status HashAggregateUdfMerge(KernelContext* ctx, KernelState&& src, + const ArrayData& group_id_mapping) { return checked_cast(ctx->state()) ->Merge(ctx, std::move(src), group_id_mapping); } -arrow::Status HashAggregateUdfFinalize(KernelContext* ctx, Datum* out) { +Status HashAggregateUdfFinalize(KernelContext* ctx, Datum* out) { return checked_cast(ctx->state())->Finalize(ctx, out); } struct PythonTableUdfKernelInit { PythonTableUdfKernelInit(std::shared_ptr function_maker, UdfWrapperCallback cb) - : function_maker(function_maker), cb(cb) { - Py_INCREF(function_maker->obj()); - } - - // function needs to be destroyed at process exit - // and Python may no longer be initialized. - ~PythonTableUdfKernelInit() { - if (Py_IsFinalizing()) { - function_maker->detach(); - } - } + : function_maker(std::move(function_maker)), cb(std::move(cb)) {} Result> operator()( compute::KernelContext* ctx, const compute::KernelInitArgs&) { - UdfContext udf_context{ctx->memory_pool(), /*batch_length=*/0}; - std::unique_ptr function; - RETURN_NOT_OK(SafeCallIntoPython([this, &udf_context, &function] { - OwnedRef empty_tuple(PyTuple_New(0)); - function = std::make_unique( - cb(function_maker->obj(), udf_context, empty_tuple.obj())); - RETURN_NOT_OK(CheckPyError()); - return Status::OK(); - })); - if (!PyCallable_Check(function->obj())) { - return Status::TypeError("Expected a callable Python object."); - } - return std::make_unique(std::move(function)); + return SafeCallIntoPython( + [this, ctx]() -> Result> { + UdfContext udf_context{ctx->memory_pool(), /*batch_length=*/0}; + OwnedRef empty_tuple(PyTuple_New(0)); + auto function = std::make_shared( + cb(function_maker->obj(), udf_context, empty_tuple.obj())); + RETURN_NOT_OK(CheckPyError()); + if (!PyCallable_Check(function->obj())) { + return Status::TypeError("Expected a callable Python object."); + } + return std::make_unique(std::move(function)); + }); } std::shared_ptr function_maker; @@ -167,8 +138,9 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { UdfWrapperCallback cb, std::vector> input_types, std::shared_ptr output_type) - : function(function), cb(std::move(cb)), output_type(std::move(output_type)) { - Py_INCREF(function->obj()); + : function(std::move(function)), + cb(std::move(cb)), + output_type(std::move(output_type)) { std::vector> fields; for (size_t i = 0; i < input_types.size(); i++) { fields.push_back(field("", input_types[i])); @@ -176,12 +148,6 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { input_schema = schema(std::move(fields)); }; - ~PythonUdfScalarAggregatorImpl() override { - if (Py_IsFinalizing()) { - function->detach(); - } - } - Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) override { ARROW_ASSIGN_OR_RAISE( auto rb, batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool())); @@ -263,8 +229,9 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { UdfWrapperCallback cb, std::vector> input_types, std::shared_ptr output_type) - : function(function), cb(std::move(cb)), output_type(std::move(output_type)) { - Py_INCREF(function->obj()); + : function(std::move(function)), + cb(std::move(cb)), + output_type(std::move(output_type)) { std::vector> fields; fields.reserve(input_types.size()); for (size_t i = 0; i < input_types.size(); i++) { @@ -273,12 +240,6 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { input_schema = schema(std::move(fields)); }; - ~PythonUdfHashAggregatorImpl() override { - if (Py_IsFinalizing()) { - function->detach(); - } - } - // same as ApplyGrouping in partition.cc // replicated the code here to avoid complicating the dependencies static Result ApplyGroupings( @@ -416,10 +377,10 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator { struct PythonUdf : public PythonUdfKernelState { PythonUdf(std::shared_ptr function, UdfWrapperCallback cb, std::vector input_types, compute::OutputType output_type) - : PythonUdfKernelState(function), - cb(cb), - input_types(input_types), - output_type(output_type) {} + : PythonUdfKernelState(std::move(function)), + cb(std::move(cb)), + input_types(std::move(input_types)), + output_type(std::move(output_type)) {} UdfWrapperCallback cb; std::vector input_types; @@ -440,7 +401,7 @@ struct PythonUdf : public PythonUdfKernelState { Status Exec(compute::KernelContext* ctx, const compute::ExecSpan& batch, compute::ExecResult* out) { auto state = arrow::internal::checked_cast(ctx->state()); - std::shared_ptr& function = state->function; + PyObject* function = state->function->obj(); const int num_args = batch.num_values(); UdfContext udf_context{ctx->memory_pool(), batch.length}; @@ -458,7 +419,7 @@ struct PythonUdf : public PythonUdfKernelState { } } - OwnedRef result(cb(function->obj(), udf_context, arg_tuple.obj())); + OwnedRef result(cb(function, udf_context, arg_tuple.obj())); RETURN_NOT_OK(CheckPyError()); // unwrapping the output for expected output type if (is_array(result.obj())) { @@ -497,12 +458,13 @@ Status RegisterUdf(PyObject* function, compute::KernelInit kernel_init, } auto scalar_func = std::make_shared(options.func_name, options.arity, options.func_doc); - Py_INCREF(function); std::vector input_types; for (const auto& in_dtype : options.input_types) { input_types.emplace_back(in_dtype); } compute::OutputType output_type(options.output_type); + // Take reference before wrapping with OwnedRefNoGIL + Py_INCREF(function); auto udf_data = std::make_shared( std::make_shared(function), cb, TypeHolder::FromTypes(options.input_types), options.output_type); @@ -565,11 +527,6 @@ Status RegisterScalarAggregateFunction(PyObject* function, UdfWrapperCallback cb registry = compute::GetFunctionRegistry(); } - // Py_INCREF here so that once a function is registered - // its refcount gets increased by 1 and doesn't get gced - // if all existing refs are gone - Py_INCREF(function); - static auto default_scalar_aggregate_options = compute::ScalarAggregateOptions::Defaults(); auto aggregate_func = std::make_shared( @@ -582,12 +539,16 @@ Status RegisterScalarAggregateFunction(PyObject* function, UdfWrapperCallback cb } compute::OutputType output_type(options.output_type); - compute::KernelInit init = [cb, function, options](compute::KernelContext* ctx, - const compute::KernelInitArgs& args) + // Take reference before wrapping with OwnedRefNoGIL + Py_INCREF(function); + auto function_ref = std::make_shared(function); + + compute::KernelInit init = [cb, function_ref, options]( + compute::KernelContext* ctx, + const compute::KernelInitArgs& args) -> Result> { return std::make_unique( - std::make_shared(function), cb, options.input_types, - options.output_type); + function_ref, cb, options.input_types, options.output_type); }; auto sig = compute::KernelSignature::Make( @@ -638,10 +599,6 @@ Status RegisterHashAggregateFunction(PyObject* function, UdfWrapperCallback cb, registry = compute::GetFunctionRegistry(); } - // Py_INCREF here so that once a function is registered - // its refcount gets increased by 1 and doesn't get gced - // if all existing refs are gone - Py_INCREF(function); UdfOptions hash_options = AdjustForHashAggregate(options); std::vector input_types; @@ -656,13 +613,15 @@ Status RegisterHashAggregateFunction(PyObject* function, UdfWrapperCallback cb, hash_options.func_name, hash_options.arity, hash_options.func_doc, &default_hash_aggregate_options); - compute::KernelInit init = [function, cb, hash_options]( + // Take reference before wrapping with OwnedRefNoGIL + Py_INCREF(function); + auto function_ref = std::make_shared(function); + compute::KernelInit init = [function_ref, cb, hash_options]( compute::KernelContext* ctx, const compute::KernelInitArgs& args) -> Result> { return std::make_unique( - std::make_shared(function), cb, hash_options.input_types, - hash_options.output_type); + function_ref, cb, hash_options.input_types, hash_options.output_type); }; auto sig = compute::KernelSignature::Make( diff --git a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index c8e376fefb3b8..22fefbbb58ba9 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -219,6 +219,31 @@ def nullary_func(context): return nullary_func, func_name +@pytest.fixture(scope="session") +def ephemeral_nullary_func_fixture(): + """ + Register a nullary scalar function with an ephemeral Python function. + This stresses that the Python function object is properly kept alive by the + registered function. + """ + def nullary_func(context): + return pa.array([42] * context.batch_length, type=pa.int64(), + memory_pool=context.memory_pool) + + func_doc = { + "summary": "random function", + "description": "generates a random value" + } + func_name = "test_ephemeral_nullary_func" + pc.register_scalar_function(nullary_func, + func_name, + func_doc, + {}, + pa.int64()) + + return func_name + + @pytest.fixture(scope="session") def wrong_output_type_func_fixture(): """ @@ -505,6 +530,12 @@ def test_nullary_function(nullary_func_fixture): batch_length=1) +def test_ephemeral_function(ephemeral_nullary_func_fixture): + name = ephemeral_nullary_func_fixture + result = pc.call_function(name, [], length=1) + assert result.to_pylist() == [42] + + def test_wrong_output_type(wrong_output_type_func_fixture): _, func_name = wrong_output_type_func_fixture