Skip to content

Commit

Permalink
Some framework change
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed May 27, 2024
1 parent 49950d6 commit fb7b936
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 48 deletions.
10 changes: 10 additions & 0 deletions cpp/src/arrow/compute/exec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,16 @@ class KernelExecutorImpl : public KernelExecutor {
class ScalarExecutor : public KernelExecutorImpl<ScalarKernel> {
public:
Status Execute(const ExecBatch& batch, ExecListener* listener) override {
if (batch.selection_vector && !kernel_->selection_vector_aware) {
// Slow path for selection vector.
ExecBatch selected_batch;
// Gather selected rows into new batch.
DatumAccumulator new_listener;
RETURN_NOT_OK(Execute(selected_batch, &new_listener));
// Scatter result according to the original selection vector.
return EmitResult(new_listener.values()[0].array(), listener);
}

RETURN_NOT_OK(span_iterator_.Init(batch, exec_context()->exec_chunksize()));

if (batch.length == 0) {
Expand Down
97 changes: 69 additions & 28 deletions cpp/src/arrow/compute/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,12 @@ Expression field_ref(FieldRef ref) {
}

/// Build a call Expression.
///
/// The function name, argument expressions and options are moved into a fresh
/// Expression::Call, the special_form flag is copied onto it, and the call is
/// then wrapped in an Expression and returned.
// NOTE(review): the two consecutive `options` parameter lines below look like the
// before/after versions of the same line from a diff view (without and with the
// `special_form` parameter) -- confirm which one applies before compiling.
Expression call(std::string function, std::vector<Expression> arguments,
std::shared_ptr<compute::FunctionOptions> options) {
std::shared_ptr<compute::FunctionOptions> options, bool special_form) {
Expression::Call call;
call.function_name = std::move(function);
call.arguments = std::move(arguments);
call.options = std::move(options);
// Record whether this call was requested as a special form.
call.special_form = special_form;
return Expression(std::move(call));
}

Expand Down Expand Up @@ -759,42 +760,82 @@ Result<Datum> ExecuteScalarExpression(const Expression& expr, const ExecBatch& i

std::vector<Datum> arguments(call->arguments.size());

bool all_scalar = true;
for (size_t i = 0; i < arguments.size(); ++i) {
ARROW_ASSIGN_OR_RAISE(
arguments[i], ExecuteScalarExpression(call->arguments[i], input, exec_context));
if (arguments[i].is_array()) {
all_scalar = false;
if (!call->special_form) {
bool all_scalar = true;
for (size_t i = 0; i < arguments.size(); ++i) {
ARROW_ASSIGN_OR_RAISE(
arguments[i], ExecuteScalarExpression(call->arguments[i], input, exec_context));
if (arguments[i].is_array()) {
all_scalar = false;
}
}
}

int64_t input_length;
if (!arguments.empty() && all_scalar) {
// all inputs are scalar, so use a 1-long batch to avoid
// computing input.length equivalent outputs
input_length = 1;
int64_t input_length;
if (!arguments.empty() && all_scalar) {
// all inputs are scalar, so use a 1-long batch to avoid
// computing input.length equivalent outputs
input_length = 1;
} else {
input_length = input.length;
}

auto executor = compute::detail::KernelExecutor::MakeScalar();

compute::KernelContext kernel_context(exec_context, call->kernel);
kernel_context.SetState(call->kernel_state.get());

const Kernel* kernel = call->kernel;
std::vector<TypeHolder> types = GetTypes(arguments);
auto options = call->options.get();
RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, types, options}));

compute::detail::DatumAccumulator listener;
RETURN_NOT_OK(
executor->Execute(ExecBatch(std::move(arguments), input_length), &listener));
const auto out = executor->WrapResults(arguments, listener.values());
#ifndef NDEBUG
DCHECK_OK(executor->CheckResultType(out, call->function_name.c_str()));
#endif
return out;
} else {
input_length = input.length;
}
ARROW_ASSIGN_OR_RAISE(
arguments[0], ExecuteScalarExpression(call->arguments[0], input, exec_context));
// Obtain the selection vector from cond.
ExecBatch if_input = input;
ARROW_ASSIGN_OR_RAISE(
if_input.selection_vector,
SelectionVector::FromMask(*arguments[0].array_as<BooleanArray>()));
ARROW_ASSIGN_OR_RAISE(arguments[1], ExecuteScalarExpression(call->arguments[1],
if_input, exec_context));
ExecBatch else_input = input;
// Else input must consider the original selection vector, instead of merely taking
// false rows from cond.
ARROW_ASSIGN_OR_RAISE(
else_input.selection_vector,
SelectionVector::FromMask(*arguments[0].array_as<BooleanArray>()));
ARROW_ASSIGN_OR_RAISE(
arguments[2],
ExecuteScalarExpression(call->arguments[2], else_input, exec_context));

auto executor = compute::detail::KernelExecutor::MakeScalar();
auto executor = compute::detail::KernelExecutor::MakeScalar();

compute::KernelContext kernel_context(exec_context, call->kernel);
kernel_context.SetState(call->kernel_state.get());
compute::KernelContext kernel_context(exec_context, call->kernel);
kernel_context.SetState(call->kernel_state.get());

const Kernel* kernel = call->kernel;
std::vector<TypeHolder> types = GetTypes(arguments);
auto options = call->options.get();
RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, types, options}));
const Kernel* kernel = call->kernel;
std::vector<TypeHolder> types = GetTypes(arguments);
auto options = call->options.get();
RETURN_NOT_OK(executor->Init(&kernel_context, {kernel, types, options}));

compute::detail::DatumAccumulator listener;
RETURN_NOT_OK(
executor->Execute(ExecBatch(std::move(arguments), input_length), &listener));
const auto out = executor->WrapResults(arguments, listener.values());
compute::detail::DatumAccumulator listener;
RETURN_NOT_OK(
executor->Execute(ExecBatch(std::move(arguments), input.length), &listener));
const auto out = executor->WrapResults(arguments, listener.values());
#ifndef NDEBUG
DCHECK_OK(executor->CheckResultType(out, call->function_name.c_str()));
DCHECK_OK(executor->CheckResultType(out, call->function_name.c_str()));
#endif
return out;
return out;
}
}

namespace {
Expand Down
26 changes: 6 additions & 20 deletions cpp/src/arrow/compute/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,22 +48,7 @@ class ARROW_EXPORT Expression {
std::string function_name;
std::vector<Expression> arguments;
std::shared_ptr<FunctionOptions> options;
// Cached hash value
size_t hash;

// post-Bind properties:
std::shared_ptr<Function> function;
const Kernel* kernel = NULLPTR;
std::shared_ptr<KernelState> kernel_state;
TypeHolder type;

void ComputeHash();
};

struct SpecialCall {
std::string function_name;
std::vector<Expression> arguments;
std::shared_ptr<FunctionOptions> options;
bool special_form = false;
// Cached hash value
size_t hash;

Expand Down Expand Up @@ -175,14 +160,15 @@ Expression field_ref(FieldRef ref);

/// Declaration of the call-expression factory.
///
/// \param function name of the function to call
/// \param arguments argument expressions of the call
/// \param options function options; defaults to NULLPTR
/// \param special_form flag stored on the resulting call; defaults to false
// NOTE(review): the `options = NULLPTR);` line and the `options = NULLPTR,` line
// below appear to be the old and new forms of one parameter line from a diff
// view -- confirm which one applies.
ARROW_EXPORT
Expression call(std::string function, std::vector<Expression> arguments,
std::shared_ptr<FunctionOptions> options = NULLPTR);
std::shared_ptr<FunctionOptions> options = NULLPTR,
bool special_form = false);

/// Convenience overload of call() taking the options object by value.
///
/// Only participates in overload resolution when Options derives from
/// FunctionOptions (enforced by the enable_if on the second template parameter).
/// The options object is moved into a std::shared_ptr<Options> and forwarded,
/// together with the other arguments, to the shared_ptr-taking call().
// NOTE(review): the signature and the make_shared line each appear twice below,
// presumably the old and new versions from a diff view (without and with
// `special_form`) -- confirm which pair applies.
template <typename Options, typename = typename std::enable_if<
std::is_base_of<FunctionOptions, Options>::value>::type>
Expression call(std::string function, std::vector<Expression> arguments,
Options options) {
Expression call(std::string function, std::vector<Expression> arguments, Options options,
bool special_form = false) {
return call(std::move(function), std::move(arguments),
std::make_shared<Options>(std::move(options)));
std::make_shared<Options>(std::move(options)), special_form);
}

/// Assemble a list of all fields referenced by an Expression at any depth.
Expand Down

0 comments on commit fb7b936

Please sign in to comment.