Skip to content

Commit

Permalink
WIP: Vector function UDF
Browse files Browse the repository at this point in the history
  • Loading branch information
icexelloss committed Jul 10, 2023
1 parent 12f45ba commit 09301a5
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 0 deletions.
9 changes: 9 additions & 0 deletions python/pyarrow/_compute.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2705,6 +2705,10 @@ cdef get_register_aggregate_function():
reg.register_func = RegisterAggregateFunction
return reg

cdef get_register_vector_function():
    # Build a RegisterUdf holder whose register_func points at the C++
    # RegisterVectorFunction entry point. __new__ is used so that no
    # __init__ logic runs on the holder.
    cdef RegisterUdf vector_registrar
    vector_registrar = RegisterUdf.__new__(RegisterUdf)
    vector_registrar.register_func = RegisterVectorFunction
    return vector_registrar

def register_scalar_function(func, function_name, function_doc, in_types, out_type,
func_registry=None):
Expand Down Expand Up @@ -2786,6 +2790,11 @@ def register_scalar_function(func, function_name, function_doc, in_types, out_ty
func, function_name, function_doc, in_types,
out_type, func_registry)

def register_vector_function(func, function_name, function_doc, in_types, out_type,
                             func_registry=None):
    """
    Register a user-defined vector function.

    Parameters
    ----------
    func : callable
        The Python function to register. It is called with a context
        object as its first argument, followed by one argument per
        entry of ``in_types``.
    function_name : str
        Name under which the function is registered.
    function_doc : dict
        Documentation for the function, forwarded unchanged to the
        shared registration helper.
    in_types : Dict[str, DataType]
        Mapping of input argument names to their data types.
    out_type : DataType
        Data type of the function's output.
    func_registry : FunctionRegistry, optional
        Registry to add the function to. Defaults to None, which is
        forwarded as-is to the shared registration helper.
    """
    vector_registrar = get_register_vector_function()
    return _register_user_defined_function(vector_registrar,
                                           func, function_name, function_doc,
                                           in_types, out_type, func_registry)

def register_aggregate_function(func, function_name, function_doc, in_types, out_type,
func_registry=None):
Expand Down
1 change: 1 addition & 0 deletions python/pyarrow/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
register_scalar_function,
register_tabular_function,
register_aggregate_function,
register_vector_function,
UdfContext,
# Expressions
Expression,
Expand Down
4 changes: 4 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2815,5 +2815,9 @@ cdef extern from "arrow/python/udf.h" namespace "arrow::py" nogil:
function[CallbackUdf] wrapper, const CUdfOptions& options,
CFunctionRegistry* registry)

# Vector UDF registration entry point declared in arrow/python/udf.h.
CStatus RegisterVectorFunction(PyObject* function,
function[CallbackUdf] wrapper, const CUdfOptions& options,
CFunctionRegistry* registry)

CResult[shared_ptr[CRecordBatchReader]] CallTabularFunction(
const c_string& func_name, const vector[CDatum]& args, CFunctionRegistry* registry)
43 changes: 43 additions & 0 deletions python/pyarrow/src/arrow/python/udf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,40 @@ Status RegisterUdf(PyObject* user_function, compute::KernelInit kernel_init,
return Status::OK();
}

/// \brief Build a VectorFunction backed by a Python callable and add it to a registry
///
/// \param user_function the Python object to wrap; rejected unless PyCallable_Check
///        accepts it
/// \param kernel_init kernel init routine attached to the generated kernel
/// \param wrapper callback stored next to the callable in the kernel's PythonUdf data
/// \param options supplies the function name, arity, documentation, input types and
///        output type
/// \param registry destination registry; when NULLPTR, the registry returned by
///        compute::GetFunctionRegistry() is used
/// \return TypeError if user_function is not callable; otherwise the first failing
///         status from AddKernel / AddFunction, or OK
Status RegisterVectorUdf(PyObject* user_function, compute::KernelInit kernel_init,
                         UdfWrapperCallback wrapper, const UdfOptions& options,
                         compute::FunctionRegistry* registry) {
  if (!PyCallable_Check(user_function)) {
    return Status::TypeError("Expected a callable Python object.");
  }
  auto vector_func = std::make_shared<compute::VectorFunction>(
      options.func_name, options.arity, options.func_doc);

  std::vector<compute::InputType> input_types;
  input_types.reserve(options.input_types.size());
  for (const auto& in_dtype : options.input_types) {
    input_types.emplace_back(in_dtype);
  }
  compute::OutputType output_type(options.output_type);

  // Bump the refcount right where the pointer gets wrapped, so the registered
  // callable stays alive even if every other reference to it goes away, and so
  // no fallible work sits between the INCREF and the wrapper that holds it.
  Py_INCREF(user_function);
  auto udf_data = std::make_shared<PythonUdf>(
      std::make_shared<OwnedRefNoGIL>(user_function), wrapper,
      TypeHolder::FromTypes(options.input_types), options.output_type);

  // NOTE(review): VectorKernel may default to running chunk-by-chunk on chunked
  // input -- confirm that is intended for UDFs that must see the whole column.
  compute::VectorKernel kernel(
      compute::KernelSignature::Make(std::move(input_types), std::move(output_type),
                                     options.arity.is_varargs),
      PythonUdfExec, kernel_init, NULLPTR);
  kernel.data = std::move(udf_data);

  // The kernel is configured with no preallocated output or validity buffers.
  kernel.mem_allocation = compute::MemAllocation::NO_PREALLOCATE;
  kernel.null_handling = compute::NullHandling::COMPUTED_NO_PREALLOCATE;
  RETURN_NOT_OK(vector_func->AddKernel(std::move(kernel)));

  if (registry == NULLPTR) {
    registry = compute::GetFunctionRegistry();
  }
  RETURN_NOT_OK(registry->AddFunction(std::move(vector_func)));
  return Status::OK();
}

} // namespace

Status RegisterScalarFunction(PyObject* function, UdfWrapperCallback cb,
Expand All @@ -527,6 +561,14 @@ Status RegisterScalarFunction(PyObject* function, UdfWrapperCallback cb,
options, registry);
}

// Public entry point for vector UDF registration: wraps the callable in an
// OwnedRefNoGIL for the kernel-init functor, then delegates to RegisterVectorUdf.
Status RegisterVectorFunction(PyObject* user_function, UdfWrapperCallback wrapper,
                              const UdfOptions& options,
                              compute::FunctionRegistry* registry) {
  auto owned_callable = std::make_shared<OwnedRefNoGIL>(user_function);
  PythonUdfKernelInit init_functor{std::move(owned_callable)};
  return RegisterVectorUdf(user_function, std::move(init_functor), wrapper, options,
                           registry);
}

Status RegisterTabularFunction(PyObject* function, UdfWrapperCallback cb,
const UdfOptions& options,
compute::FunctionRegistry* registry) {
Expand All @@ -553,6 +595,7 @@ Status RegisterScalarAggregateFunction(PyObject* function, UdfWrapperCallback cb
}

// Py_INCREF here so that once a function is registered
// its refcount gets increased by 1 and it doesn't get garbage-collected
// if all existing refs are gone
Py_INCREF(function);
Expand Down
5 changes: 5 additions & 0 deletions python/pyarrow/src/arrow/python/udf.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ Status ARROW_PYTHON_EXPORT RegisterAggregateFunction(
PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options,
compute::FunctionRegistry* registry = NULLPTR);

/// \brief Register a vector user-defined function implemented in Python
///
/// \param user_function the Python callable to register; a TypeError status is
///        returned if it is not callable
/// \param wrapper callback stored with the callable in the kernel's data
/// \param options function name, arity, documentation, input types and output type
/// \param registry registry to add the function to; when NULLPTR, the registry
///        returned by compute::GetFunctionRegistry() is used
/// \return OK on success, otherwise the error raised while building or adding
///         the function
Status ARROW_PYTHON_EXPORT RegisterVectorFunction(
PyObject* user_function, UdfWrapperCallback wrapper, const UdfOptions& options,
compute::FunctionRegistry* registry = NULLPTR);

Result<std::shared_ptr<RecordBatchReader>> ARROW_PYTHON_EXPORT
CallTabularFunction(const std::string& func_name, const std::vector<Datum>& args,
compute::FunctionRegistry* registry = NULLPTR);
Expand Down
59 changes: 59 additions & 0 deletions python/pyarrow/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,43 @@ def raising_func(ctx):
return raising_func, func_name


@pytest.fixture(scope="session")
def unary_vector_func_fixture():
    """
    Register a unary vector function that computes the percentile rank
    of its float64 input.

    Returns the Python implementation together with the registered name.
    """
    def pct_rank(ctx, x):
        # Ranking is delegated to pandas; ctx is not used.
        return pa.array(x.to_pandas().rank(pct=True))

    func_name = "y=pct_rank(x)"
    doc = empty_udf_doc
    pc.register_vector_function(
        pct_rank, func_name, doc, {'x': pa.float64()}, pa.float64())

    return pct_rank, func_name


@pytest.fixture(scope="session")
def struct_vector_func_fixture():
    """
    Register a three-argument vector function whose result is a struct array.

    Returns the Python implementation together with the registered name.
    """
    def pivot(ctx, k, v, c):
        # Round-trip through pandas to pivot the values in 'v' into one
        # column per distinct 'c', keyed by 'k'.
        batch = pa.RecordBatch.from_arrays([k, v, c], names=['k', 'v', 'c'])
        frame = batch.to_pandas()
        pivoted = frame.pivot(columns='c', values='v', index='k')
        pivoted = pivoted.reset_index()
        return pa.RecordBatch.from_pandas(pivoted).to_struct_array()

    func_name = "y=pivot(x)"
    doc = empty_udf_doc
    input_types = {'k': pa.int64(), 'v': pa.float64(), 'c': pa.utf8()}
    output_type = pa.struct(
        [('k', pa.int64()), ('v1', pa.float64()), ('v2', pa.float64())])
    pc.register_vector_function(
        pivot, func_name, doc, input_types, output_type)

    return pivot, func_name


def check_scalar_function(func_fixture,
inputs, *,
run_in_dataset=True,
Expand Down Expand Up @@ -797,3 +834,25 @@ def test_hash_agg_random(sum_agg_func_fixture):
[("value", "sum")]).rename_columns(['id', 'value_sum_udf'])

assert result.sort_by('id') == expected.sort_by('id')


def test_vector_basic(unary_vector_func_fixture):
    # Calling the registered UDF must give the same result as ranking the
    # same values directly with pandas.
    values = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64())
    want = pa.array(values.to_pandas().rank(pct=True))
    got = pc.call_function("y=pct_rank(x)", [values])
    assert got == want


def test_vector_struct(struct_vector_func_fixture):
    # Calling the registered pivot UDF must agree with calling the
    # underlying Python function directly (with no context object).
    pivot_impl = struct_vector_func_fixture[0]
    k = pa.array([1, 1, 2, 2], pa.int64())
    v = pa.array([1.0, 2.0, 3.0, 4.0], pa.float64())
    c = pa.array(['v1', 'v2', 'v1', 'v2'])
    result = pc.call_function("y=pivot(x)", [k, v, c])
    expected = pivot_impl(None, k, v, c)
    assert result == expected

0 comments on commit 09301a5

Please sign in to comment.