Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Oct 15, 2024
1 parent 945505a commit 13aa602
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 73 deletions.
88 changes: 16 additions & 72 deletions cpp/src/arrow/compute/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,8 @@ size_t Expression::hash() const {
return ref->hash();
}

if (auto call = call()) {
return call->hash;
if (auto c = call()) {
return c->hash;
}

return SpecialNotNull(*this)->hash;
Expand Down Expand Up @@ -633,76 +633,15 @@ Result<Expression> BindNonRecursive(Expression::Call call, bool insert_implicit_
return Expression(std::move(call));
}

Result<Expression> BindNonRecursive(Expression::Special special, bool insert_implicit_casts,
Result<Expression> BindNonRecursive(Expression::Special special,
bool insert_implicit_casts,
compute::ExecContext* exec_context) {
DCHECK(std::all_of(special.arguments.begin(), special.arguments.end(),
[](const Expression& argument) { return argument.IsBound(); }));

std::vector<TypeHolder> types = GetTypes(special.arguments);
ARROW_ASSIGN_OR_RAISE(call.function, GetFunction(call, exec_context));

auto FinishBind = [&] {
compute::KernelContext kernel_context(exec_context, call.kernel);
if (call.kernel->init) {
const FunctionOptions* options =
call.options ? call.options.get() : call.function->default_options();
ARROW_ASSIGN_OR_RAISE(
call.kernel_state,
call.kernel->init(&kernel_context, {call.kernel, types, options}));

kernel_context.SetState(call.kernel_state.get());
}

ARROW_ASSIGN_OR_RAISE(
call.type, call.kernel->signature->out_type().Resolve(&kernel_context, types));
return Status::OK();
};

// First try and bind exactly
Result<const Kernel*> maybe_exact_match = call.function->DispatchExact(types);
if (maybe_exact_match.ok()) {
call.kernel = *maybe_exact_match;
if (FinishBind().ok()) {
return Expression(std::move(call));
}
}

if (!insert_implicit_casts) {
return maybe_exact_match.status();
}

// If exact binding fails, and we are allowed to cast, then prefer casting literals
// first. Since DispatchBest generally prefers up-casting the best way to do this is
// first down-cast the literals as much as possible
types = GetTypesWithSmallestLiteralRepresentation(call.arguments);
ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchBest(&types));

for (size_t i = 0; i < types.size(); ++i) {
if (types[i] == call.arguments[i].type()) continue;

if (const Datum* lit = call.arguments[i].literal()) {
ARROW_ASSIGN_OR_RAISE(Datum new_lit, compute::Cast(*lit, types[i].GetSharedPtr()));
call.arguments[i] = literal(std::move(new_lit));
continue;
}

// construct an implicit cast Expression with which to replace this argument
Expression::Call implicit_cast;
implicit_cast.function_name = "cast";
implicit_cast.arguments = {std::move(call.arguments[i])};

// TODO(wesm): Use TypeHolder in options
implicit_cast.options = std::make_shared<compute::CastOptions>(
compute::CastOptions::Safe(types[i].GetSharedPtr()));

ARROW_ASSIGN_OR_RAISE(
call.arguments[i],
BindNonRecursive(std::move(implicit_cast),
/*insert_implicit_casts=*/false, exec_context));
}

RETURN_NOT_OK(FinishBind());
return Expression(std::move(call));
ARROW_ASSIGN_OR_RAISE(special.type,
special.special_form->Resolve(&special.arguments, exec_context));
return Expression(std::move(special));
}

template <typename TypeOrSchema>
Expand All @@ -726,18 +665,19 @@ Result<Expression> BindImpl(Expression expr, const TypeOrSchema& in,
return Expression{std::move(param)};
}

if (const Call* call = expr.call()) {
if (expr.call()) {
auto call = *expr.call();
for (auto& argument : call.arguments) {
ARROW_ASSIGN_OR_RAISE(argument, BindImpl(std::move(argument), in, exec_context));
}
return BindNonRecursive(std::move(call),
/*insert_implicit_casts=*/true, exec_context);
return BindNonRecursive(call, /*insert_implicit_casts=*/true, exec_context);
}

auto special = *SpecialNotNull(expr);
for (auto& argument : call.arguments) {
for (auto& argument : special.arguments) {
ARROW_ASSIGN_OR_RAISE(argument, BindImpl(std::move(argument), in, exec_context));
}
return BindNonRecursive(special, /*insert_implicit_casts=*/true, exec_context);
}

} // namespace
Expand Down Expand Up @@ -866,6 +806,10 @@ Result<Datum> ExecuteScalarExpression(const Expression& expr, const ExecBatch& i
return field;
}

if (auto special = expr.special()) {
return special->special_form->Execute(special->arguments, input, exec_context);
}

auto call = CallNotNull(expr);

std::vector<Datum> arguments(call->arguments.size());
Expand Down
32 changes: 31 additions & 1 deletion cpp/src/arrow/compute/special_form.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,15 @@ namespace arrow::compute {

class SpecialForm {
public:
SpecialForm(const std::string& name) {}
SpecialForm(std::string name) : name(std::move(name)) {}
virtual ~SpecialForm() = default;

virtual Result<TypeHolder> Resolve(std::vector<Expression>* arguments,
compute::ExecContext* exec_context) const = 0;

virtual Result<Datum> Execute(const std::vector<Expression>& arguments,
const ExecBatch& input,
compute::ExecContext* exec_context) const = 0;

public:
const std::string name;
Expand All @@ -13,6 +21,28 @@ class SpecialForm {
class IfElseSpecialForm : public SpecialForm {
public:
IfElseSpecialForm() : SpecialForm("if_else") {}

Result<TypeHolder> Resolve(std::vector<Expression>* arguments,
compute::ExecContext* exec_context) const override {
ARROW_ASSIGN_OR_RAISE(auto function,
exec_context->func_registry()->GetFunction("if_else"));
std::vector<TypeHolder> types = GetTypes(*arguments);

// TODO: DispatchBest and implicit cast.
ARROW_ASSIGN_OR_RAISE(auto maybe_exact_match, function->DispatchExact(types));
compute::KernelContext kernel_context(exec_context, maybe_exact_match);
if (maybe_exact_match->init) {
const FunctionOptions* options = function->default_options();
ARROW_ASSIGN_OR_RAISE(
auto kernel_state,
maybe_exact_match->init(&kernel_context, {maybe_exact_match, types, options}));
kernel_context.SetState(kernel_state.get());
}
return maybe_exact_match->signature->out_type().Resolve(&kernel_context, types);
}

Result<Datum> Execute(const std::vector<Expression>& arguments, const ExecBatch& input,
compute::ExecContext* exec_context) const override {}
};

} // namespace arrow::compute

0 comments on commit 13aa602

Please sign in to comment.