Skip to content

Commit

Permalink
ARROW-17521: [Python] Add python bindings for NamedTableProvider for …
Browse files Browse the repository at this point in the history
…Substrait consumer (apache#14024)

This PR includes a basic version to use NamedTable feature in Substrait. The idea is to provide the flexibility to write Python tests with in-memory PyArrow tables. 

Authored-by: Vibhatha Abeykoon <vibhatha@gmail.com>
Signed-off-by: Weston Pace <weston.pace@gmail.com>
  • Loading branch information
vibhatha authored Sep 14, 2022
1 parent eb00620 commit 72b539f
Show file tree
Hide file tree
Showing 12 changed files with 289 additions and 28 deletions.
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/exec/exec_plan.cc
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,10 @@ Declaration Declaration::Sequence(std::vector<Declaration> decls) {
return out;
}

bool Declaration::IsValid(ExecFactoryRegistry* registry) const {
return !this->factory_name.empty() && this->options != nullptr;
}

namespace internal {

void RegisterSourceNode(ExecFactoryRegistry*);
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/arrow/compute/exec/exec_plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,8 @@ inline Result<ExecNode*> MakeExecNode(
struct ARROW_EXPORT Declaration {
using Input = util::Variant<ExecNode*, Declaration>;

Declaration() {}

Declaration(std::string factory_name, std::vector<Input> inputs,
std::shared_ptr<ExecNodeOptions> options, std::string label)
: factory_name{std::move(factory_name)},
Expand Down Expand Up @@ -514,6 +516,9 @@ struct ARROW_EXPORT Declaration {
Result<ExecNode*> AddToPlan(ExecPlan* plan, ExecFactoryRegistry* registry =
default_exec_factory_registry()) const;

// Validate a declaration
bool IsValid(ExecFactoryRegistry* registry = default_exec_factory_registry()) const;

std::string factory_name;
std::vector<Input> inputs;
std::shared_ptr<ExecNodeOptions> options;
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/engine/substrait/function_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ void CheckValidTestCases(const std::vector<FunctionTestCase>& valid_cases) {
ASSERT_FINISHES_OK(plan->finished());

// Could also modify the Substrait plan with an emit to drop the leading columns
ASSERT_OK_AND_ASSIGN(output_table,
output_table->SelectColumns({output_table->num_columns() - 1}));
int result_column = output_table->num_columns() - 1; // last column holds result
ASSERT_OK_AND_ASSIGN(output_table, output_table->SelectColumns({result_column}));

ASSERT_OK_AND_ASSIGN(
std::shared_ptr<Table> expected_output,
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/engine/substrait/relation_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,14 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
const substrait::ReadRel::NamedTable& named_table = read.named_table();
std::vector<std::string> table_names(named_table.names().begin(),
named_table.names().end());
if (table_names.empty()) {
return Status::Invalid("names for NamedTable not provided");
}
ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl,
named_table_provider(table_names));
if (!source_decl.IsValid()) {
return Status::Invalid("Invalid NamedTable Source");
}
return ProcessEmit(std::move(read),
DeclarationInfo{std::move(source_decl), base_schema},
std::move(base_schema));
Expand Down
13 changes: 6 additions & 7 deletions cpp/src/arrow/engine/substrait/serde_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1924,7 +1924,6 @@ TEST(Substrait, BasicPlanRoundTripping) {

ASSERT_OK_AND_ASSIGN(auto tempdir,
arrow::internal::TemporaryDir::Make("substrait-tempdir-"));
std::cout << "file_path_str " << tempdir->path().ToString() << std::endl;
ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name));
std::string file_path_str = file_path.ToString();

Expand Down Expand Up @@ -2189,7 +2188,7 @@ TEST(Substrait, ProjectRel) {
}
},
"namedTable": {
"names": []
"names": ["A"]
}
}
}
Expand Down Expand Up @@ -2313,7 +2312,7 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) {
}
},
"namedTable": {
"names": []
"names": ["A"]
}
}
}
Expand Down Expand Up @@ -2396,7 +2395,7 @@ TEST(Substrait, ReadRelWithEmit) {
}
},
"namedTable": {
"names" : []
"names" : ["A"]
}
}
}
Expand Down Expand Up @@ -2501,7 +2500,7 @@ TEST(Substrait, FilterRelWithEmit) {
}
},
"namedTable": {
"names" : []
"names" : ["A"]
}
}
}
Expand Down Expand Up @@ -2885,7 +2884,7 @@ TEST(Substrait, AggregateRel) {
}
},
"namedTable" : {
"names": []
"names": ["A"]
}
}
},
Expand Down Expand Up @@ -3004,7 +3003,7 @@ TEST(Substrait, AggregateRelEmit) {
}
},
"namedTable" : {
"names" : []
"names" : ["A"]
}
}
},
Expand Down
23 changes: 14 additions & 9 deletions cpp/src/arrow/engine/substrait/util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,12 @@ class SubstraitSinkConsumer : public compute::SinkNodeConsumer {
class SubstraitExecutor {
public:
explicit SubstraitExecutor(std::shared_ptr<compute::ExecPlan> plan,
compute::ExecContext exec_context)
: plan_(std::move(plan)), plan_started_(false), exec_context_(exec_context) {}
compute::ExecContext exec_context,
const ConversionOptions& conversion_options = {})
: plan_(std::move(plan)),
plan_started_(false),
exec_context_(exec_context),
conversion_options_(conversion_options) {}

~SubstraitExecutor() { ARROW_UNUSED(this->Close()); }

Expand Down Expand Up @@ -95,8 +99,8 @@ class SubstraitExecutor {
return sink_consumer_;
};
ARROW_ASSIGN_OR_RAISE(
declarations_,
engine::DeserializePlans(substrait_buffer, consumer_factory, registry));
declarations_, engine::DeserializePlans(substrait_buffer, consumer_factory,
registry, nullptr, conversion_options_));
return Status::OK();
}

Expand All @@ -107,19 +111,20 @@ class SubstraitExecutor {
bool plan_started_;
compute::ExecContext exec_context_;
std::shared_ptr<SubstraitSinkConsumer> sink_consumer_;
const ConversionOptions& conversion_options_;
};

} // namespace

Result<std::shared_ptr<RecordBatchReader>> ExecuteSerializedPlan(
const Buffer& substrait_buffer, const ExtensionIdRegistry* extid_registry,
compute::FunctionRegistry* func_registry) {
// TODO(ARROW-15732)
const Buffer& substrait_buffer, const ExtensionIdRegistry* registry,
compute::FunctionRegistry* func_registry,
const ConversionOptions& conversion_options) {
compute::ExecContext exec_context(arrow::default_memory_pool(),
::arrow::internal::GetCpuThreadPool(), func_registry);
ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(&exec_context));
SubstraitExecutor executor(std::move(plan), exec_context);
RETURN_NOT_OK(executor.Init(substrait_buffer, extid_registry));
SubstraitExecutor executor(std::move(plan), exec_context, conversion_options);
RETURN_NOT_OK(executor.Init(substrait_buffer, registry));
ARROW_ASSIGN_OR_RAISE(auto sink_reader, executor.Execute());
// check closing here, not in destructor, to expose error to caller
RETURN_NOT_OK(executor.Close());
Expand Down
8 changes: 6 additions & 2 deletions cpp/src/arrow/engine/substrait/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,21 @@
#include <memory>
#include "arrow/compute/registry.h"
#include "arrow/engine/substrait/api.h"
#include "arrow/engine/substrait/options.h"
#include "arrow/util/iterator.h"
#include "arrow/util/optional.h"

namespace arrow {

namespace engine {

/// \brief Retrieve a RecordBatchReader from a Substrait plan.
using PythonTableProvider =
std::function<Result<std::shared_ptr<Table>>(const std::vector<std::string>&)>;

ARROW_ENGINE_EXPORT Result<std::shared_ptr<RecordBatchReader>> ExecuteSerializedPlan(
const Buffer& substrait_buffer, const ExtensionIdRegistry* registry = NULLPTR,
compute::FunctionRegistry* func_registry = NULLPTR);
compute::FunctionRegistry* func_registry = NULLPTR,
const ConversionOptions& conversion_options = {});

/// \brief Get a Serialized Plan from a Substrait JSON plan.
/// This is a helper method for Python tests.
Expand Down
2 changes: 1 addition & 1 deletion python/pyarrow/_exec_plan.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads
node_factory = "table_source"
c_in_table = pyarrow_unwrap_table(ipt)
c_tablesourceopts = make_shared[CTableSourceNodeOptions](
c_in_table, 1 << 20)
c_in_table)
c_input_node_opts = static_pointer_cast[CExecNodeOptions, CTableSourceNodeOptions](
c_tablesourceopts)
elif isinstance(ipt, Dataset):
Expand Down
97 changes: 94 additions & 3 deletions python/pyarrow/_substrait.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,102 @@

# cython: language_level = 3
from cython.operator cimport dereference as deref
from libcpp.vector cimport vector as std_vector

from pyarrow import Buffer
from pyarrow.lib import frombytes
from pyarrow.lib import frombytes, tobytes
from pyarrow.lib cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow_substrait cimport *


def run_query(plan):
cdef CDeclaration _create_named_table_provider(dict named_args, const std_vector[c_string]& names):
cdef:
c_string c_name
shared_ptr[CTable] c_in_table
shared_ptr[CTableSourceNodeOptions] c_tablesourceopts
shared_ptr[CExecNodeOptions] c_input_node_opts
vector[CDeclaration.Input] no_c_inputs

py_names = []
for i in range(names.size()):
c_name = names[i]
py_names.append(frombytes(c_name))

py_table = named_args["provider"](py_names)
c_in_table = pyarrow_unwrap_table(py_table)
c_tablesourceopts = make_shared[CTableSourceNodeOptions](c_in_table)
c_input_node_opts = static_pointer_cast[CExecNodeOptions, CTableSourceNodeOptions](
c_tablesourceopts)
return CDeclaration(tobytes("table_source"),
no_c_inputs, c_input_node_opts)


def run_query(plan, table_provider=None):
"""
Execute a Substrait plan and read the results as a RecordBatchReader.
Parameters
----------
plan : Buffer
The serialized Substrait plan to execute.
table_provider : object (optional)
A function to resolve any NamedTable relation to a table.
The function will receive a single argument which will be a list
of strings representing the table name and should return a pyarrow.Table.
Returns
-------
RecordBatchReader
A reader containing the result of the executed query
Examples
--------
>>> import pyarrow as pa
>>> from pyarrow.lib import tobytes
>>> import pyarrow.substrait as substrait
>>> test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]})
>>> test_table_2 = pa.Table.from_pydict({"x": [4, 5, 6]})
>>> def table_provider(names):
... if not names:
... raise Exception("No names provided")
... elif names[0] == "t1":
... return test_table_1
... elif names[1] == "t2":
... return test_table_2
... else:
... raise Exception("Unrecognized table name")
...
>>> substrait_query = '''
... {
... "relations": [
... {"rel": {
... "read": {
... "base_schema": {
... "struct": {
... "types": [
... {"i64": {}}
... ]
... },
... "names": [
... "x"
... ]
... },
... "namedTable": {
... "names": ["t1"]
... }
... }
... }}
... ]
... }
... '''
>>> buf = pa._substrait._parse_json_plan(tobytes(substrait_query))
>>> reader = pa.substrait.run_query(buf, table_provider)
>>> reader.read_all()
pyarrow.Table
x: int64
----
x: [[1,2,3]]
"""

cdef:
Expand All @@ -41,10 +121,21 @@ def run_query(plan):
RecordBatchReader reader
c_string c_str_plan
shared_ptr[CBuffer] c_buf_plan
function[CNamedTableProvider] c_named_table_provider
CConversionOptions c_conversion_options

c_buf_plan = pyarrow_unwrap_buffer(plan)

if table_provider is not None:
named_table_args = {
"provider": table_provider
}
c_conversion_options.named_table_provider = BindFunction[CNamedTableProvider](
&_create_named_table_provider, named_table_args)

with nogil:
c_res_reader = ExecuteSerializedPlan(deref(c_buf_plan))
c_res_reader = ExecuteSerializedPlan(
deref(c_buf_plan), default_extension_id_registry(), GetFunctionRegistry(), c_conversion_options)

c_reader = GetResultValue(c_res_reader)

Expand Down
1 change: 1 addition & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2574,6 +2574,7 @@ cdef extern from "arrow/compute/exec/exec_plan.h" namespace "arrow::compute" nog
c_string label
vector[Input] inputs

CDeclaration()
CDeclaration(c_string factory_name, CExecNodeOptions options)
CDeclaration(c_string factory_name, vector[Input] inputs, shared_ptr[CExecNodeOptions] options)

Expand Down
28 changes: 24 additions & 4 deletions python/pyarrow/includes/libarrow_substrait.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,22 @@ from libcpp.vector cimport vector as std_vector
from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *


cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine" nogil:
CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(const CBuffer& substrait_buffer)
CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string& substrait_json)
ctypedef CResult[CDeclaration] CNamedTableProvider(const std_vector[c_string]&)

cdef extern from "arrow/engine/substrait/options.h" namespace "arrow::engine" nogil:
cdef enum ConversionStrictness \
"arrow::engine::ConversionStrictness":
EXACT_ROUNDTRIP \
"arrow::engine::ConversionStrictness::EXACT_ROUNDTRIP"
PRESERVE_STRUCTURE \
"arrow::engine::ConversionStrictness::PRESERVE_STRUCTURE"
BEST_EFFORT \
"arrow::engine::ConversionStrictness::BEST_EFFORT"

cdef cppclass CConversionOptions \
"arrow::engine::ConversionOptions":
ConversionStrictness conversion_strictness
function[CNamedTableProvider] named_table_provider

cdef extern from "arrow/engine/substrait/extension_set.h" \
namespace "arrow::engine" nogil:
Expand All @@ -34,3 +46,11 @@ cdef extern from "arrow/engine/substrait/extension_set.h" \
std_vector[c_string] GetSupportedSubstraitFunctions()

ExtensionIdRegistry* default_extension_id_registry()


cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine" nogil:
CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(
const CBuffer& substrait_buffer, const ExtensionIdRegistry* registry,
CFunctionRegistry* func_registry, const CConversionOptions& conversion_options)

CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string& substrait_json)
Loading

0 comments on commit 72b539f

Please sign in to comment.