diff --git a/cpp/src/arrow/compute/expression.cc b/cpp/src/arrow/compute/expression.cc index f7ccd27fd645d..f3d9d1388d911 100644 --- a/cpp/src/arrow/compute/expression.cc +++ b/cpp/src/arrow/compute/expression.cc @@ -296,8 +296,8 @@ size_t Expression::hash() const { return ref->hash(); } - if (auto call = call()) { - return call->hash; + if (auto c = call()) { + return c->hash; } return SpecialNotNull(*this)->hash; @@ -633,76 +633,15 @@ Result BindNonRecursive(Expression::Call call, bool insert_implicit_ return Expression(std::move(call)); } -Result BindNonRecursive(Expression::Special special, bool insert_implicit_casts, +Result BindNonRecursive(Expression::Special special, + bool insert_implicit_casts, compute::ExecContext* exec_context) { DCHECK(std::all_of(special.arguments.begin(), special.arguments.end(), [](const Expression& argument) { return argument.IsBound(); })); - std::vector types = GetTypes(special.arguments); - ARROW_ASSIGN_OR_RAISE(call.function, GetFunction(call, exec_context)); - - auto FinishBind = [&] { - compute::KernelContext kernel_context(exec_context, call.kernel); - if (call.kernel->init) { - const FunctionOptions* options = - call.options ? call.options.get() : call.function->default_options(); - ARROW_ASSIGN_OR_RAISE( - call.kernel_state, - call.kernel->init(&kernel_context, {call.kernel, types, options})); - - kernel_context.SetState(call.kernel_state.get()); - } - - ARROW_ASSIGN_OR_RAISE( - call.type, call.kernel->signature->out_type().Resolve(&kernel_context, types)); - return Status::OK(); - }; - - // First try and bind exactly - Result maybe_exact_match = call.function->DispatchExact(types); - if (maybe_exact_match.ok()) { - call.kernel = *maybe_exact_match; - if (FinishBind().ok()) { - return Expression(std::move(call)); - } - } - - if (!insert_implicit_casts) { - return maybe_exact_match.status(); - } - - // If exact binding fails, and we are allowed to cast, then prefer casting literals - // first. Since DispatchBest generally prefers up-casting the best way to do this is - // first down-cast the literals as much as possible - types = GetTypesWithSmallestLiteralRepresentation(call.arguments); - ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchBest(&types)); - - for (size_t i = 0; i < types.size(); ++i) { - if (types[i] == call.arguments[i].type()) continue; - - if (const Datum* lit = call.arguments[i].literal()) { - ARROW_ASSIGN_OR_RAISE(Datum new_lit, compute::Cast(*lit, types[i].GetSharedPtr())); - call.arguments[i] = literal(std::move(new_lit)); - continue; - } - - // construct an implicit cast Expression with which to replace this argument - Expression::Call implicit_cast; - implicit_cast.function_name = "cast"; - implicit_cast.arguments = {std::move(call.arguments[i])}; - - // TODO(wesm): Use TypeHolder in options - implicit_cast.options = std::make_shared( - compute::CastOptions::Safe(types[i].GetSharedPtr())); - - ARROW_ASSIGN_OR_RAISE( - call.arguments[i], - BindNonRecursive(std::move(implicit_cast), - /*insert_implicit_casts=*/false, exec_context)); - } - - RETURN_NOT_OK(FinishBind()); - return Expression(std::move(call)); + ARROW_ASSIGN_OR_RAISE(special.type, + special.special_form->Resolve(&special.arguments, exec_context)); + return Expression(std::move(special)); } template @@ -726,18 +665,19 @@ Result BindImpl(Expression expr, const TypeOrSchema& in, return Expression{std::move(param)}; } - if (const Call* call = expr.call()) { + if (expr.call()) { + auto call = *expr.call(); for (auto& argument : call.arguments) { ARROW_ASSIGN_OR_RAISE(argument, BindImpl(std::move(argument), in, exec_context)); } - return BindNonRecursive(std::move(call), - /*insert_implicit_casts=*/true, exec_context); + return BindNonRecursive(call, /*insert_implicit_casts=*/true, exec_context); } auto special = *SpecialNotNull(expr); - for (auto& argument : call.arguments) { + for (auto& argument : special.arguments) { ARROW_ASSIGN_OR_RAISE(argument, BindImpl(std::move(argument), in, exec_context)); } + return BindNonRecursive(special, /*insert_implicit_casts=*/true, exec_context); } } // namespace @@ -866,6 +806,10 @@ Result ExecuteScalarExpression(const Expression& expr, const ExecBatch& i return field; } + if (auto special = expr.special()) { + return special->special_form->Execute(special->arguments, input, exec_context); + } + auto call = CallNotNull(expr); std::vector arguments(call->arguments.size()); diff --git a/cpp/src/arrow/compute/special_form.h b/cpp/src/arrow/compute/special_form.h index 47fcf667eb616..f96ec01be3690 100644 --- a/cpp/src/arrow/compute/special_form.h +++ b/cpp/src/arrow/compute/special_form.h @@ -4,7 +4,15 @@ namespace arrow::compute { class SpecialForm { public: - SpecialForm(const std::string& name) {} + SpecialForm(std::string name) : name(std::move(name)) {} + virtual ~SpecialForm() = default; + + virtual Result Resolve(std::vector* arguments, + compute::ExecContext* exec_context) const = 0; + + virtual Result Execute(const std::vector& arguments, + const ExecBatch& input, + compute::ExecContext* exec_context) const = 0; public: const std::string name; @@ -13,6 +21,28 @@ class SpecialForm { class IfElseSpecialForm : public SpecialForm { public: IfElseSpecialForm() : SpecialForm("if_else") {} + + Result Resolve(std::vector* arguments, + compute::ExecContext* exec_context) const override { + ARROW_ASSIGN_OR_RAISE(auto function, + exec_context->func_registry()->GetFunction("if_else")); + std::vector types = GetTypes(*arguments); + + // TODO: DispatchBest and implicit cast. + ARROW_ASSIGN_OR_RAISE(auto maybe_exact_match, function->DispatchExact(types)); + compute::KernelContext kernel_context(exec_context, maybe_exact_match); + if (maybe_exact_match->init) { + const FunctionOptions* options = function->default_options(); + ARROW_ASSIGN_OR_RAISE( + auto kernel_state, + maybe_exact_match->init(&kernel_context, {maybe_exact_match, types, options})); + kernel_context.SetState(kernel_state.get()); + } + return maybe_exact_match->signature->out_type().Resolve(&kernel_context, types); + } + + Result Execute(const std::vector& arguments, const ExecBatch& input, + compute::ExecContext* exec_context) const override {} }; } // namespace arrow::compute