Commit c22df05

GH-43487: [Python] Sanitize Python reference handling in UDF implementation

1. Remove spurious increfs (the function object is already incref'ed at an upper level)
2. Add unit test with an ephemeral Python function object
3. Streamline and improve Python reference handling
pitrou committed Aug 5, 2024
1 parent 66cb749 commit c22df05
Showing 2 changed files with 85 additions and 95 deletions.
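
The ownership pattern that items 1 and 3 of the commit message describe can be summarized in a minimal, self-contained sketch. This is not Arrow's code: OwnedPyRef, MakeOwner and KernelState below are simplified stand-ins for OwnedRefNoGIL, the registration path, and PythonUdfKernelState. The point is that Py_INCREF happens exactly once, where the borrowed PyObject* is first wrapped; every kernel state then shares that one owner through a shared_ptr instead of incref'ing again.

```cpp
#include <Python.h>

#include <memory>
#include <utility>

// Simplified stand-in for Arrow's OwnedRefNoGIL: owns exactly one strong
// reference and releases it on destruction (GIL handling elided).
class OwnedPyRef {
 public:
  explicit OwnedPyRef(PyObject* obj) : obj_(obj) {}
  ~OwnedPyRef() { Py_XDECREF(obj_); }
  OwnedPyRef(const OwnedPyRef&) = delete;
  OwnedPyRef& operator=(const OwnedPyRef&) = delete;
  PyObject* obj() const { return obj_; }

 private:
  PyObject* obj_;
};

// Analogue of PythonUdfKernelState after the change: it only shares the
// existing owner; no extra Py_INCREF in the constructor and no custom
// destructor are needed.
struct KernelState {
  explicit KernelState(std::shared_ptr<OwnedPyRef> function)
      : function(std::move(function)) {}
  std::shared_ptr<OwnedPyRef> function;
};

// Mirrors the "take reference before wrapping" step: the caller passes a
// borrowed reference, so it is incremented exactly once before wrapping.
std::shared_ptr<OwnedPyRef> MakeOwner(PyObject* borrowed) {
  Py_INCREF(borrowed);
  return std::make_shared<OwnedPyRef>(borrowed);
}

int main() {
  Py_Initialize();
  {
    PyObject* fn = PyDict_New();     // any object stands in for the UDF
    auto owner = MakeOwner(fn);      // refcount bumped once, wrapped once
    Py_DECREF(fn);                   // drop the creation reference
    KernelState a(owner), b(owner);  // shared ownership, no further INCREFs
  }  // last shared_ptr released here -> single Py_XDECREF, before finalization
  Py_Finalize();
  return 0;
}
```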
149 changes: 54 additions & 95 deletions python/pyarrow/src/arrow/python/udf.cc
@@ -43,35 +43,18 @@ namespace py {
namespace {

struct PythonUdfKernelState : public compute::KernelState {
// NOTE: this KernelState constructor doesn't require the GIL.
// If it did, the corresponding KernelInit::operator() should be wrapped
// within SafeCallIntoPython (GH-43487).
explicit PythonUdfKernelState(std::shared_ptr<OwnedRefNoGIL> function)
: function(function) {
Py_INCREF(function->obj());
}

// function needs to be destroyed at process exit
// and Python may no longer be initialized.
~PythonUdfKernelState() {
if (Py_IsFinalizing()) {
function->detach();
}
}
: function(std::move(function)) {}

std::shared_ptr<OwnedRefNoGIL> function;
};

struct PythonUdfKernelInit {
explicit PythonUdfKernelInit(std::shared_ptr<OwnedRefNoGIL> function)
: function(function) {
Py_INCREF(function->obj());
}

// function needs to be destroyed at process exit
// and Python may no longer be initialized.
~PythonUdfKernelInit() {
if (Py_IsFinalizing()) {
function->detach();
}
}
: function(std::move(function)) {}

Result<std::unique_ptr<compute::KernelState>> operator()(
compute::KernelContext*, const compute::KernelInitArgs&) {
@@ -94,68 +77,56 @@ struct HashUdfAggregator : public compute::KernelState {
virtual Status Finalize(KernelContext* ctx, Datum* out) = 0;
};

arrow::Status AggregateUdfConsume(compute::KernelContext* ctx,
const compute::ExecSpan& batch) {
Status AggregateUdfConsume(compute::KernelContext* ctx, const compute::ExecSpan& batch) {
return checked_cast<ScalarUdfAggregator*>(ctx->state())->Consume(ctx, batch);
}

arrow::Status AggregateUdfMerge(compute::KernelContext* ctx, compute::KernelState&& src,
compute::KernelState* dst) {
Status AggregateUdfMerge(compute::KernelContext* ctx, compute::KernelState&& src,
compute::KernelState* dst) {
return checked_cast<ScalarUdfAggregator*>(dst)->MergeFrom(ctx, std::move(src));
}

arrow::Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* out) {
Status AggregateUdfFinalize(compute::KernelContext* ctx, arrow::Datum* out) {
return checked_cast<ScalarUdfAggregator*>(ctx->state())->Finalize(ctx, out);
}

arrow::Status HashAggregateUdfResize(KernelContext* ctx, int64_t size) {
Status HashAggregateUdfResize(KernelContext* ctx, int64_t size) {
return checked_cast<HashUdfAggregator*>(ctx->state())->Resize(ctx, size);
}

arrow::Status HashAggregateUdfConsume(KernelContext* ctx, const ExecSpan& batch) {
Status HashAggregateUdfConsume(KernelContext* ctx, const ExecSpan& batch) {
return checked_cast<HashUdfAggregator*>(ctx->state())->Consume(ctx, batch);
}

arrow::Status HashAggregateUdfMerge(KernelContext* ctx, KernelState&& src,
const ArrayData& group_id_mapping) {
Status HashAggregateUdfMerge(KernelContext* ctx, KernelState&& src,
const ArrayData& group_id_mapping) {
return checked_cast<HashUdfAggregator*>(ctx->state())
->Merge(ctx, std::move(src), group_id_mapping);
}

arrow::Status HashAggregateUdfFinalize(KernelContext* ctx, Datum* out) {
Status HashAggregateUdfFinalize(KernelContext* ctx, Datum* out) {
return checked_cast<HashUdfAggregator*>(ctx->state())->Finalize(ctx, out);
}

struct PythonTableUdfKernelInit {
PythonTableUdfKernelInit(std::shared_ptr<OwnedRefNoGIL> function_maker,
UdfWrapperCallback cb)
: function_maker(function_maker), cb(cb) {
Py_INCREF(function_maker->obj());
}

// function needs to be destroyed at process exit
// and Python may no longer be initialized.
~PythonTableUdfKernelInit() {
if (Py_IsFinalizing()) {
function_maker->detach();
}
}
: function_maker(std::move(function_maker)), cb(std::move(cb)) {}

Result<std::unique_ptr<compute::KernelState>> operator()(
compute::KernelContext* ctx, const compute::KernelInitArgs&) {
UdfContext udf_context{ctx->memory_pool(), /*batch_length=*/0};
std::unique_ptr<OwnedRefNoGIL> function;
RETURN_NOT_OK(SafeCallIntoPython([this, &udf_context, &function] {
OwnedRef empty_tuple(PyTuple_New(0));
function = std::make_unique<OwnedRefNoGIL>(
cb(function_maker->obj(), udf_context, empty_tuple.obj()));
RETURN_NOT_OK(CheckPyError());
return Status::OK();
}));
if (!PyCallable_Check(function->obj())) {
return Status::TypeError("Expected a callable Python object.");
}
return std::make_unique<PythonUdfKernelState>(std::move(function));
return SafeCallIntoPython(
[this, ctx]() -> Result<std::unique_ptr<compute::KernelState>> {
UdfContext udf_context{ctx->memory_pool(), /*batch_length=*/0};
OwnedRef empty_tuple(PyTuple_New(0));
auto function = std::make_shared<OwnedRefNoGIL>(
cb(function_maker->obj(), udf_context, empty_tuple.obj()));
RETURN_NOT_OK(CheckPyError());
if (!PyCallable_Check(function->obj())) {
return Status::TypeError("Expected a callable Python object.");
}
return std::make_unique<PythonUdfKernelState>(std::move(function));
});
}

std::shared_ptr<OwnedRefNoGIL> function_maker;
@@ -167,21 +138,16 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
UdfWrapperCallback cb,
std::vector<std::shared_ptr<DataType>> input_types,
std::shared_ptr<DataType> output_type)
: function(function), cb(std::move(cb)), output_type(std::move(output_type)) {
Py_INCREF(function->obj());
: function(std::move(function)),
cb(std::move(cb)),
output_type(std::move(output_type)) {
std::vector<std::shared_ptr<Field>> fields;
for (size_t i = 0; i < input_types.size(); i++) {
fields.push_back(field("", input_types[i]));
}
input_schema = schema(std::move(fields));
};

~PythonUdfScalarAggregatorImpl() override {
if (Py_IsFinalizing()) {
function->detach();
}
}

Status Consume(compute::KernelContext* ctx, const compute::ExecSpan& batch) override {
ARROW_ASSIGN_OR_RAISE(
auto rb, batch.ToExecBatch().ToRecordBatch(input_schema, ctx->memory_pool()));
@@ -263,8 +229,9 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
UdfWrapperCallback cb,
std::vector<std::shared_ptr<DataType>> input_types,
std::shared_ptr<DataType> output_type)
: function(function), cb(std::move(cb)), output_type(std::move(output_type)) {
Py_INCREF(function->obj());
: function(std::move(function)),
cb(std::move(cb)),
output_type(std::move(output_type)) {
std::vector<std::shared_ptr<Field>> fields;
fields.reserve(input_types.size());
for (size_t i = 0; i < input_types.size(); i++) {
@@ -273,12 +240,6 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
input_schema = schema(std::move(fields));
};

~PythonUdfHashAggregatorImpl() override {
if (Py_IsFinalizing()) {
function->detach();
}
}

// same as ApplyGrouping in partition.cc
// replicated the code here to avoid complicating the dependencies
static Result<RecordBatchVector> ApplyGroupings(
@@ -416,10 +377,10 @@ struct PythonUdfHashAggregatorImpl : public HashUdfAggregator {
struct PythonUdf : public PythonUdfKernelState {
PythonUdf(std::shared_ptr<OwnedRefNoGIL> function, UdfWrapperCallback cb,
std::vector<TypeHolder> input_types, compute::OutputType output_type)
: PythonUdfKernelState(function),
cb(cb),
input_types(input_types),
output_type(output_type) {}
: PythonUdfKernelState(std::move(function)),
cb(std::move(cb)),
input_types(std::move(input_types)),
output_type(std::move(output_type)) {}

UdfWrapperCallback cb;
std::vector<TypeHolder> input_types;
@@ -440,7 +401,7 @@ struct PythonUdf : public PythonUdfKernelState {
Status Exec(compute::KernelContext* ctx, const compute::ExecSpan& batch,
compute::ExecResult* out) {
auto state = arrow::internal::checked_cast<PythonUdfKernelState*>(ctx->state());
std::shared_ptr<OwnedRefNoGIL>& function = state->function;
PyObject* function = state->function->obj();
const int num_args = batch.num_values();
UdfContext udf_context{ctx->memory_pool(), batch.length};

@@ -458,7 +419,7 @@ struct PythonUdf : public PythonUdfKernelState {
}
}

OwnedRef result(cb(function->obj(), udf_context, arg_tuple.obj()));
OwnedRef result(cb(function, udf_context, arg_tuple.obj()));
RETURN_NOT_OK(CheckPyError());
// unwrapping the output for expected output type
if (is_array(result.obj())) {
@@ -497,12 +458,13 @@ Status RegisterUdf(PyObject* function, compute::KernelInit kernel_init,
}
auto scalar_func =
std::make_shared<Function>(options.func_name, options.arity, options.func_doc);
Py_INCREF(function);
std::vector<compute::InputType> input_types;
for (const auto& in_dtype : options.input_types) {
input_types.emplace_back(in_dtype);
}
compute::OutputType output_type(options.output_type);
// Take reference before wrapping with OwnedRefNoGIL
Py_INCREF(function);
auto udf_data = std::make_shared<PythonUdf>(
std::make_shared<OwnedRefNoGIL>(function), cb,
TypeHolder::FromTypes(options.input_types), options.output_type);
@@ -565,11 +527,6 @@ Status RegisterScalarAggregateFunction(PyObject* function, UdfWrapperCallback cb
registry = compute::GetFunctionRegistry();
}

// Py_INCREF here so that once a function is registered
// its refcount gets increased by 1 and doesn't get gced
// if all existing refs are gone
Py_INCREF(function);

static auto default_scalar_aggregate_options =
compute::ScalarAggregateOptions::Defaults();
auto aggregate_func = std::make_shared<compute::ScalarAggregateFunction>(
@@ -582,12 +539,16 @@ }
}
compute::OutputType output_type(options.output_type);

compute::KernelInit init = [cb, function, options](compute::KernelContext* ctx,
const compute::KernelInitArgs& args)
// Take reference before wrapping with OwnedRefNoGIL
Py_INCREF(function);
auto function_ref = std::make_shared<OwnedRefNoGIL>(function);

compute::KernelInit init = [cb, function_ref, options](
compute::KernelContext* ctx,
const compute::KernelInitArgs& args)
-> Result<std::unique_ptr<compute::KernelState>> {
return std::make_unique<PythonUdfScalarAggregatorImpl>(
std::make_shared<OwnedRefNoGIL>(function), cb, options.input_types,
options.output_type);
function_ref, cb, options.input_types, options.output_type);
};

auto sig = compute::KernelSignature::Make(
@@ -638,10 +599,6 @@ Status RegisterHashAggregateFunction(PyObject* function, UdfWrapperCallback cb,
registry = compute::GetFunctionRegistry();
}

// Py_INCREF here so that once a function is registered
// its refcount gets increased by 1 and doesn't get gced
// if all existing refs are gone
Py_INCREF(function);
UdfOptions hash_options = AdjustForHashAggregate(options);

std::vector<compute::InputType> input_types;
@@ -656,13 +613,15 @@ Status RegisterHashAggregateFunction(PyObject* function, UdfWrapperCallback cb,
hash_options.func_name, hash_options.arity, hash_options.func_doc,
&default_hash_aggregate_options);

compute::KernelInit init = [function, cb, hash_options](
// Take reference before wrapping with OwnedRefNoGIL
Py_INCREF(function);
auto function_ref = std::make_shared<OwnedRefNoGIL>(function);
compute::KernelInit init = [function_ref, cb, hash_options](
compute::KernelContext* ctx,
const compute::KernelInitArgs& args)
-> Result<std::unique_ptr<compute::KernelState>> {
return std::make_unique<PythonUdfHashAggregatorImpl>(
std::make_shared<OwnedRefNoGIL>(function), cb, hash_options.input_types,
hash_options.output_type);
function_ref, cb, hash_options.input_types, hash_options.output_type);
};

auto sig = compute::KernelSignature::Make(
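
A side note on the registration changes above: the KernelInit lambdas now capture a std::shared_ptr<OwnedRefNoGIL> (function_ref) by value, so the wrapper created once at registration co-owns the Python function for every kernel instantiation. A generic illustration of that capture semantics in plain standard C++ (no Python involved; Resource and the other names are purely illustrative):

```cpp
#include <cstdio>
#include <functional>
#include <memory>

struct Resource {
  ~Resource() { std::puts("Resource released"); }
};

int main() {
  std::function<void()> init;
  {
    auto shared = std::make_shared<Resource>();
    // Captured by value: the lambda co-owns the Resource, so it outlives this scope.
    init = [shared] { std::printf("use_count=%ld\n", shared.use_count()); };
  }
  init();  // Resource is still alive here; prints use_count=1
  init();
  return 0;  // "Resource released" appears only when the last owner (init) is destroyed
}
```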
31 changes: 31 additions & 0 deletions python/pyarrow/tests/test_udf.py
@@ -219,6 +219,31 @@ def nullary_func(context):
return nullary_func, func_name


@pytest.fixture(scope="session")
def ephemeral_nullary_func_fixture():
"""
Register a nullary scalar function with an ephemeral Python function.
This stresses that the Python function object is properly kept alive by the
registered function.
"""
def nullary_func(context):
return pa.array([42] * context.batch_length, type=pa.int64(),
memory_pool=context.memory_pool)

func_doc = {
"summary": "random function",
"description": "generates a random value"
}
func_name = "test_ephemeral_nullary_func"
pc.register_scalar_function(nullary_func,
func_name,
func_doc,
{},
pa.int64())

return func_name


@pytest.fixture(scope="session")
def wrong_output_type_func_fixture():
"""
@@ -505,6 +530,12 @@ def test_nullary_function(nullary_func_fixture):
batch_length=1)


def test_ephemeral_function(ephemeral_nullary_func_fixture):
name = ephemeral_nullary_func_fixture
result = pc.call_function(name, [], length=1)
assert result.to_pylist() == [42]


def test_wrong_output_type(wrong_output_type_func_fixture):
_, func_name = wrong_output_type_func_fixture

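
For reference, the behavior exercised by the new test, written as a standalone sketch rather than a pytest fixture (the function name and registration name here are illustrative, not part of the commit):

```python
import pyarrow as pa
import pyarrow.compute as pc


def register_ephemeral_udf():
    # The Python function object becomes unreachable once this helper returns;
    # the registry must keep it alive on its own.
    def answer(ctx):
        return pa.array([42] * ctx.batch_length, type=pa.int64(),
                        memory_pool=ctx.memory_pool)

    doc = {"summary": "constant 42", "description": "returns 42 for each row"}
    pc.register_scalar_function(answer, "ephemeral_answer", doc, {}, pa.int64())


register_ephemeral_udf()
result = pc.call_function("ephemeral_answer", [], length=3)
assert result.to_pylist() == [42, 42, 42]
```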
