From 72b539f54233f6610b01ec7381755a84c652d151 Mon Sep 17 00:00:00 2001 From: Vibhatha Lakmal Abeykoon Date: Thu, 15 Sep 2022 01:15:58 +0530 Subject: [PATCH] ARROW-17521: [Python] Add python bindings for NamedTableProvider for Substrait consumer (#14024) This PR includes a basic version to use NamedTable feature in Substrait. The idea is to provide the flexibility to write Python tests with in-memory PyArrow tables. Authored-by: Vibhatha Abeykoon Signed-off-by: Weston Pace --- cpp/src/arrow/compute/exec/exec_plan.cc | 4 + cpp/src/arrow/compute/exec/exec_plan.h | 5 + .../arrow/engine/substrait/function_test.cc | 4 +- .../engine/substrait/relation_internal.cc | 6 + cpp/src/arrow/engine/substrait/serde_test.cc | 13 +- cpp/src/arrow/engine/substrait/util.cc | 23 ++-- cpp/src/arrow/engine/substrait/util.h | 8 +- python/pyarrow/_exec_plan.pyx | 2 +- python/pyarrow/_substrait.pyx | 97 +++++++++++++- python/pyarrow/includes/libarrow.pxd | 1 + .../pyarrow/includes/libarrow_substrait.pxd | 28 +++- python/pyarrow/tests/test_substrait.py | 126 ++++++++++++++++++ 12 files changed, 289 insertions(+), 28 deletions(-) diff --git a/cpp/src/arrow/compute/exec/exec_plan.cc b/cpp/src/arrow/compute/exec/exec_plan.cc index b6a3916de1f63..00415495aa8fd 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.cc +++ b/cpp/src/arrow/compute/exec/exec_plan.cc @@ -643,6 +643,10 @@ Declaration Declaration::Sequence(std::vector decls) { return out; } +bool Declaration::IsValid(ExecFactoryRegistry* registry) const { + return !this->factory_name.empty() && this->options != nullptr; +} + namespace internal { void RegisterSourceNode(ExecFactoryRegistry*); diff --git a/cpp/src/arrow/compute/exec/exec_plan.h b/cpp/src/arrow/compute/exec/exec_plan.h index 263f3634a5aa1..a9481e21a6e6f 100644 --- a/cpp/src/arrow/compute/exec/exec_plan.h +++ b/cpp/src/arrow/compute/exec/exec_plan.h @@ -451,6 +451,8 @@ inline Result MakeExecNode( struct ARROW_EXPORT Declaration { using Input = util::Variant; + Declaration() {} + 
Declaration(std::string factory_name, std::vector inputs, std::shared_ptr options, std::string label) : factory_name{std::move(factory_name)}, @@ -514,6 +516,9 @@ struct ARROW_EXPORT Declaration { Result AddToPlan(ExecPlan* plan, ExecFactoryRegistry* registry = default_exec_factory_registry()) const; + // Validate a declaration + bool IsValid(ExecFactoryRegistry* registry = default_exec_factory_registry()) const; + std::string factory_name; std::vector inputs; std::shared_ptr options; diff --git a/cpp/src/arrow/engine/substrait/function_test.cc b/cpp/src/arrow/engine/substrait/function_test.cc index 0bcb475d310b8..3465f00e13206 100644 --- a/cpp/src/arrow/engine/substrait/function_test.cc +++ b/cpp/src/arrow/engine/substrait/function_test.cc @@ -132,8 +132,8 @@ void CheckValidTestCases(const std::vector& valid_cases) { ASSERT_FINISHES_OK(plan->finished()); // Could also modify the Substrait plan with an emit to drop the leading columns - ASSERT_OK_AND_ASSIGN(output_table, - output_table->SelectColumns({output_table->num_columns() - 1})); + int result_column = output_table->num_columns() - 1; // last column holds result + ASSERT_OK_AND_ASSIGN(output_table, output_table->SelectColumns({result_column})); ASSERT_OK_AND_ASSIGN( std::shared_ptr expected_output, diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index 4213895b61681..3911373b7b74b 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -135,8 +135,14 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& const substrait::ReadRel::NamedTable& named_table = read.named_table(); std::vector table_names(named_table.names().begin(), named_table.names().end()); + if (table_names.empty()) { + return Status::Invalid("names for NamedTable not provided"); + } ARROW_ASSIGN_OR_RAISE(compute::Declaration source_decl, named_table_provider(table_names)); + if 
(!source_decl.IsValid()) { + return Status::Invalid("Invalid NamedTable Source"); + } return ProcessEmit(std::move(read), DeclarationInfo{std::move(source_decl), base_schema}, std::move(base_schema)); diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 251c2bfe35202..b50e1c6084cf8 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -1924,7 +1924,6 @@ TEST(Substrait, BasicPlanRoundTripping) { ASSERT_OK_AND_ASSIGN(auto tempdir, arrow::internal::TemporaryDir::Make("substrait-tempdir-")); - std::cout << "file_path_str " << tempdir->path().ToString() << std::endl; ASSERT_OK_AND_ASSIGN(auto file_path, tempdir->path().Join(file_name)); std::string file_path_str = file_path.ToString(); @@ -2189,7 +2188,7 @@ TEST(Substrait, ProjectRel) { } }, "namedTable": { - "names": [] + "names": ["A"] } } } @@ -2313,7 +2312,7 @@ TEST(Substrait, ProjectRelOnFunctionWithEmit) { } }, "namedTable": { - "names": [] + "names": ["A"] } } } @@ -2396,7 +2395,7 @@ TEST(Substrait, ReadRelWithEmit) { } }, "namedTable": { - "names" : [] + "names" : ["A"] } } } @@ -2501,7 +2500,7 @@ TEST(Substrait, FilterRelWithEmit) { } }, "namedTable": { - "names" : [] + "names" : ["A"] } } } @@ -2885,7 +2884,7 @@ TEST(Substrait, AggregateRel) { } }, "namedTable" : { - "names": [] + "names": ["A"] } } }, @@ -3004,7 +3003,7 @@ TEST(Substrait, AggregateRelEmit) { } }, "namedTable" : { - "names" : [] + "names" : ["A"] } } }, diff --git a/cpp/src/arrow/engine/substrait/util.cc b/cpp/src/arrow/engine/substrait/util.cc index 936bde5c652e5..f51666ef8585b 100644 --- a/cpp/src/arrow/engine/substrait/util.cc +++ b/cpp/src/arrow/engine/substrait/util.cc @@ -63,8 +63,12 @@ class SubstraitSinkConsumer : public compute::SinkNodeConsumer { class SubstraitExecutor { public: explicit SubstraitExecutor(std::shared_ptr plan, - compute::ExecContext exec_context) - : plan_(std::move(plan)), plan_started_(false), 
exec_context_(exec_context) {} + compute::ExecContext exec_context, + const ConversionOptions& conversion_options = {}) + : plan_(std::move(plan)), + plan_started_(false), + exec_context_(exec_context), + conversion_options_(conversion_options) {} ~SubstraitExecutor() { ARROW_UNUSED(this->Close()); } @@ -95,8 +99,8 @@ class SubstraitExecutor { return sink_consumer_; }; ARROW_ASSIGN_OR_RAISE( - declarations_, - engine::DeserializePlans(substrait_buffer, consumer_factory, registry)); + declarations_, engine::DeserializePlans(substrait_buffer, consumer_factory, + registry, nullptr, conversion_options_)); return Status::OK(); } @@ -107,19 +111,20 @@ class SubstraitExecutor { bool plan_started_; compute::ExecContext exec_context_; std::shared_ptr sink_consumer_; + ConversionOptions conversion_options_; }; } // namespace Result> ExecuteSerializedPlan( - const Buffer& substrait_buffer, const ExtensionIdRegistry* extid_registry, - compute::FunctionRegistry* func_registry) { - // TODO(ARROW-15732) + const Buffer& substrait_buffer, const ExtensionIdRegistry* registry, + compute::FunctionRegistry* func_registry, + const ConversionOptions& conversion_options) { compute::ExecContext exec_context(arrow::default_memory_pool(), ::arrow::internal::GetCpuThreadPool(), func_registry); ARROW_ASSIGN_OR_RAISE(auto plan, compute::ExecPlan::Make(&exec_context)); - SubstraitExecutor executor(std::move(plan), exec_context); - RETURN_NOT_OK(executor.Init(substrait_buffer, extid_registry)); + SubstraitExecutor executor(std::move(plan), exec_context, conversion_options); + RETURN_NOT_OK(executor.Init(substrait_buffer, registry)); ARROW_ASSIGN_OR_RAISE(auto sink_reader, executor.Execute()); // check closing here, not in destructor, to expose error to caller RETURN_NOT_OK(executor.Close()); diff --git a/cpp/src/arrow/engine/substrait/util.h b/cpp/src/arrow/engine/substrait/util.h index 3ac9320e1da76..90cb4e3dd2a37 100644 --- a/cpp/src/arrow/engine/substrait/util.h +++ 
b/cpp/src/arrow/engine/substrait/util.h @@ -20,6 +20,7 @@ #include #include "arrow/compute/registry.h" #include "arrow/engine/substrait/api.h" +#include "arrow/engine/substrait/options.h" #include "arrow/util/iterator.h" #include "arrow/util/optional.h" @@ -27,10 +28,13 @@ namespace arrow { namespace engine { -/// \brief Retrieve a RecordBatchReader from a Substrait plan. +using PythonTableProvider = + std::function>(const std::vector&)>; + ARROW_ENGINE_EXPORT Result> ExecuteSerializedPlan( const Buffer& substrait_buffer, const ExtensionIdRegistry* registry = NULLPTR, - compute::FunctionRegistry* func_registry = NULLPTR); + compute::FunctionRegistry* func_registry = NULLPTR, + const ConversionOptions& conversion_options = {}); /// \brief Get a Serialized Plan from a Substrait JSON plan. /// This is a helper method for Python tests. diff --git a/python/pyarrow/_exec_plan.pyx b/python/pyarrow/_exec_plan.pyx index 89e474f43906c..9506caf7d287f 100644 --- a/python/pyarrow/_exec_plan.pyx +++ b/python/pyarrow/_exec_plan.pyx @@ -92,7 +92,7 @@ cdef execplan(inputs, output_type, vector[CDeclaration] plan, c_bool use_threads node_factory = "table_source" c_in_table = pyarrow_unwrap_table(ipt) c_tablesourceopts = make_shared[CTableSourceNodeOptions]( - c_in_table, 1 << 20) + c_in_table) c_input_node_opts = static_pointer_cast[CExecNodeOptions, CTableSourceNodeOptions]( c_tablesourceopts) elif isinstance(ipt, Dataset): diff --git a/python/pyarrow/_substrait.pyx b/python/pyarrow/_substrait.pyx index 05794a95a20ee..47a519cf16b53 100644 --- a/python/pyarrow/_substrait.pyx +++ b/python/pyarrow/_substrait.pyx @@ -17,15 +17,38 @@ # cython: language_level = 3 from cython.operator cimport dereference as deref +from libcpp.vector cimport vector as std_vector from pyarrow import Buffer -from pyarrow.lib import frombytes +from pyarrow.lib import frombytes, tobytes from pyarrow.lib cimport * from pyarrow.includes.libarrow cimport * from pyarrow.includes.libarrow_substrait cimport * -def 
run_query(plan): +cdef CDeclaration _create_named_table_provider(dict named_args, const std_vector[c_string]& names): + cdef: + c_string c_name + shared_ptr[CTable] c_in_table + shared_ptr[CTableSourceNodeOptions] c_tablesourceopts + shared_ptr[CExecNodeOptions] c_input_node_opts + vector[CDeclaration.Input] no_c_inputs + + py_names = [] + for i in range(names.size()): + c_name = names[i] + py_names.append(frombytes(c_name)) + + py_table = named_args["provider"](py_names) + c_in_table = pyarrow_unwrap_table(py_table) + c_tablesourceopts = make_shared[CTableSourceNodeOptions](c_in_table) + c_input_node_opts = static_pointer_cast[CExecNodeOptions, CTableSourceNodeOptions]( + c_tablesourceopts) + return CDeclaration(tobytes("table_source"), + no_c_inputs, c_input_node_opts) + + +def run_query(plan, table_provider=None): """ Execute a Substrait plan and read the results as a RecordBatchReader. @@ -33,6 +56,63 @@ def run_query(plan): ---------- plan : Buffer The serialized Substrait plan to execute. + table_provider : object (optional) + A function to resolve any NamedTable relation to a table. + The function will receive a single argument which will be a list + of strings representing the table name and should return a pyarrow.Table. + + Returns + ------- + RecordBatchReader + A reader containing the result of the executed query + + Examples + -------- + >>> import pyarrow as pa + >>> from pyarrow.lib import tobytes + >>> import pyarrow.substrait as substrait + >>> test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]}) + >>> test_table_2 = pa.Table.from_pydict({"x": [4, 5, 6]}) + >>> def table_provider(names): + ... if not names: + ... raise Exception("No names provided") + ... elif names[0] == "t1": + ... return test_table_1 + ... elif names[0] == "t2": + ... return test_table_2 + ... else: + ... raise Exception("Unrecognized table name") + ... + >>> substrait_query = ''' + ... { + ... "relations": [ + ... {"rel": { + ... "read": { + ... "base_schema": { + ... 
"struct": { + ... "types": [ + ... {"i64": {}} + ... ] + ... }, + ... "names": [ + ... "x" + ... ] + ... }, + ... "namedTable": { + ... "names": ["t1"] + ... } + ... } + ... }} + ... ] + ... } + ... ''' + >>> buf = pa._substrait._parse_json_plan(tobytes(substrait_query)) + >>> reader = pa.substrait.run_query(buf, table_provider) + >>> reader.read_all() + pyarrow.Table + x: int64 + ---- + x: [[1,2,3]] """ cdef: @@ -41,10 +121,21 @@ def run_query(plan): RecordBatchReader reader c_string c_str_plan shared_ptr[CBuffer] c_buf_plan + function[CNamedTableProvider] c_named_table_provider + CConversionOptions c_conversion_options c_buf_plan = pyarrow_unwrap_buffer(plan) + + if table_provider is not None: + named_table_args = { + "provider": table_provider + } + c_conversion_options.named_table_provider = BindFunction[CNamedTableProvider]( + &_create_named_table_provider, named_table_args) + with nogil: - c_res_reader = ExecuteSerializedPlan(deref(c_buf_plan)) + c_res_reader = ExecuteSerializedPlan( + deref(c_buf_plan), default_extension_id_registry(), GetFunctionRegistry(), c_conversion_options) c_reader = GetResultValue(c_res_reader) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index be273975f94bf..489d73bf27e6f 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2574,6 +2574,7 @@ cdef extern from "arrow/compute/exec/exec_plan.h" namespace "arrow::compute" nog c_string label vector[Input] inputs + CDeclaration() CDeclaration(c_string factory_name, CExecNodeOptions options) CDeclaration(c_string factory_name, vector[Input] inputs, shared_ptr[CExecNodeOptions] options) diff --git a/python/pyarrow/includes/libarrow_substrait.pxd b/python/pyarrow/includes/libarrow_substrait.pxd index 0b3ace75d92b0..04990380d97a6 100644 --- a/python/pyarrow/includes/libarrow_substrait.pxd +++ b/python/pyarrow/includes/libarrow_substrait.pxd @@ -22,10 +22,22 @@ from libcpp.vector cimport vector as std_vector 
from pyarrow.includes.common cimport * from pyarrow.includes.libarrow cimport * - -cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine" nogil: - CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan(const CBuffer& substrait_buffer) - CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string& substrait_json) +ctypedef CResult[CDeclaration] CNamedTableProvider(const std_vector[c_string]&) + +cdef extern from "arrow/engine/substrait/options.h" namespace "arrow::engine" nogil: + cdef enum ConversionStrictness \ + "arrow::engine::ConversionStrictness": + EXACT_ROUNDTRIP \ + "arrow::engine::ConversionStrictness::EXACT_ROUNDTRIP" + PRESERVE_STRUCTURE \ + "arrow::engine::ConversionStrictness::PRESERVE_STRUCTURE" + BEST_EFFORT \ + "arrow::engine::ConversionStrictness::BEST_EFFORT" + + cdef cppclass CConversionOptions \ + "arrow::engine::ConversionOptions": + ConversionStrictness conversion_strictness + function[CNamedTableProvider] named_table_provider cdef extern from "arrow/engine/substrait/extension_set.h" \ namespace "arrow::engine" nogil: @@ -34,3 +46,11 @@ cdef extern from "arrow/engine/substrait/extension_set.h" \ std_vector[c_string] GetSupportedSubstraitFunctions() ExtensionIdRegistry* default_extension_id_registry() + + +cdef extern from "arrow/engine/substrait/util.h" namespace "arrow::engine" nogil: + CResult[shared_ptr[CRecordBatchReader]] ExecuteSerializedPlan( + const CBuffer& substrait_buffer, const ExtensionIdRegistry* registry, + CFunctionRegistry* func_registry, const CConversionOptions& conversion_options) + + CResult[shared_ptr[CBuffer]] SerializeJsonPlan(const c_string& substrait_json) diff --git a/python/pyarrow/tests/test_substrait.py b/python/pyarrow/tests/test_substrait.py index c8fa6afcb9ffa..c8fd8048aa492 100644 --- a/python/pyarrow/tests/test_substrait.py +++ b/python/pyarrow/tests/test_substrait.py @@ -165,3 +165,129 @@ def test_get_supported_functions(): 'functions_arithmetic.yaml', 'add') assert 
has_function(supported_functions, 'functions_arithmetic.yaml', 'sum') + + +def test_named_table(): + test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]}) + test_table_2 = pa.Table.from_pydict({"x": [4, 5, 6]}) + + def table_provider(names): + if not names: + raise Exception("No names provided") + elif names[0] == "t1": + return test_table_1 + elif names[0] == "t2": + return test_table_2 + else: + raise Exception("Unrecognized table name") + + substrait_query = """ + { + "relations": [ + {"rel": { + "read": { + "base_schema": { + "struct": { + "types": [ + {"i64": {}} + ] + }, + "names": [ + "x" + ] + }, + "namedTable": { + "names": ["t1"] + } + } + }} + ] + } + """ + + buf = pa._substrait._parse_json_plan(tobytes(substrait_query)) + reader = pa.substrait.run_query(buf, table_provider) + res_tb = reader.read_all() + assert res_tb == test_table_1 + + +def test_named_table_invalid_table_name(): + test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]}) + + def table_provider(names): + if not names: + raise Exception("No names provided") + elif names[0] == "t1": + return test_table_1 + else: + raise Exception("Unrecognized table name") + + substrait_query = """ + { + "relations": [ + {"rel": { + "read": { + "base_schema": { + "struct": { + "types": [ + {"i64": {}} + ] + }, + "names": [ + "x" + ] + }, + "namedTable": { + "names": ["t3"] + } + } + }} + ] + } + """ + + buf = pa._substrait._parse_json_plan(tobytes(substrait_query)) + exec_message = "Invalid NamedTable Source" + with pytest.raises(ArrowInvalid, match=exec_message): + substrait.run_query(buf, table_provider) + + +def test_named_table_empty_names(): + test_table_1 = pa.Table.from_pydict({"x": [1, 2, 3]}) + + def table_provider(names): + if not names: + raise Exception("No names provided") + elif names[0] == "t1": + return test_table_1 + else: + raise Exception("Unrecognized table name") + + substrait_query = """ + { + "relations": [ + {"rel": { + "read": { + "base_schema": { + "struct": { + "types": [ + {"i64": 
{}} + ] + }, + "names": [ + "x" + ] + }, + "namedTable": { + "names": [] + } + } + }} + ] + } + """ + query = tobytes(substrait_query) + buf = pa._substrait._parse_json_plan(query) + exec_message = "names for NamedTable not provided" + with pytest.raises(ArrowInvalid, match=exec_message): + substrait.run_query(buf, table_provider)