Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Oct 23, 2024
1 parent 6d78ea5 commit 6bab9b5
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 23 deletions.
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/exec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@ Status ExecSpanIterator::Init(const ExecBatch& batch, int64_t max_chunksize,
value_offsets_.clear();
value_offsets_.resize(args_->size(), 0);
max_chunksize_ = std::min(length_, max_chunksize);
selection_vector_ = batch.selection_vector.get();
return Status::OK();
}

Expand Down Expand Up @@ -437,13 +438,16 @@ bool ExecSpanIterator::Next(ExecSpan* span) {
span->values[i].scalar = nullptr;
}
have_chunked_arrays_ = true;
DCHECK_EQ(selection_vector_, nullptr);
}
}

if (have_all_scalars_ && promote_if_all_scalars_) {
PromoteExecSpanScalars(span);
}

span->selection_vector = selection_vector_;

initialized_ = true;
} else if (position_ == length_) {
// We've emitted at least one span and we're at the end so we are done
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/exec.h
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,7 @@ struct ARROW_EXPORT ExecSpan {

int64_t length = 0;
std::vector<ExecValue> values;
SelectionVector* selection_vector = NULLPTR;
};

/// \defgroup compute-call-function One-shot calls to compute functions
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/compute/exec_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ namespace detail {
/// \brief Break std::vector<Datum> into a sequence of non-owning
/// ExecSpan for kernel execution. The lifetime of the Datum vector
/// must be longer than the lifetime of this object
// TODO: Can this class detect the presence of a selection vector and, if one is
// present, avoid splitting the ExecBatch?
class ARROW_EXPORT ExecSpanIterator {
public:
ExecSpanIterator() = default;
Expand Down Expand Up @@ -94,6 +96,7 @@ class ARROW_EXPORT ExecSpanIterator {
int64_t position_ = 0;
int64_t length_ = 0;
int64_t max_chunksize_;
SelectionVector* selection_vector_ = NULLPTR;
};

// "Push" / listener API like IPC reader so that consumers can receive
Expand Down
45 changes: 31 additions & 14 deletions cpp/src/arrow/compute/special_form.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ Result<TypeHolder> IfElseSpecialForm::Resolve(std::vector<Expression>* arguments

namespace {

// TODO: Taking from scalar values may not work.
Result<ExecBatch> TakeBySelectionVector(const ExecBatch& input,
const Datum& selection_vector,
ExecContext* exec_context) {
Expand All @@ -62,12 +63,14 @@ std::shared_ptr<ChunkedArray> ChunkedArrayFromDatums(const std::vector<Datum>& d
std::vector<std::shared_ptr<Array>> chunks;
for (const auto& datum : datums) {
DCHECK(datum.is_arraylike());
if (datum.is_array()) {
if (datum.is_array() && datum.length() > 0) {
chunks.push_back(datum.make_array());
} else {
DCHECK(datum.is_chunked_array());
for (const auto& chunk : datum.chunked_array()->chunks()) {
chunks.push_back(chunk);
if (chunk->length() > 0) {
chunks.push_back(chunk);
}
}
}
}
Expand Down Expand Up @@ -129,6 +132,7 @@ Result<Datum> IfElseSpecialForm::Execute(const std::vector<Expression>& argument

Datum if_true = MakeNullScalar(if_true_expr.type()->GetSharedPtr());
Datum if_false = MakeNullScalar(if_false_expr.type()->GetSharedPtr());
bool all_null = true;

if (IsSelectionVectorAwarePathAvailable({if_true_expr, if_false_expr}, input,
exec_context)) {
Expand All @@ -144,19 +148,28 @@ Result<Datum> IfElseSpecialForm::Execute(const std::vector<Expression>& argument
if (sel_true->length() == input.length) {
return if_true;
}
all_null = false;
}

ARROW_ASSIGN_OR_RAISE(auto cond_inverted,
CallFunction("invert", {cond}, exec_context));
DCHECK(cond_inverted.is_array());
ARROW_ASSIGN_OR_RAISE(auto sel_false, SelectionVector::FromMask(*boolean_cond));
ExecBatch input_false = input;
input_false.selection_vector = sel_false;
ARROW_ASSIGN_OR_RAISE(
if_false, ExecuteScalarExpression(if_false_expr, input_false, exec_context));
if (sel_false->length() == input.length) {
return if_false;
if (sel_false->length() > 0) {
ExecBatch input_false = input;
input_false.selection_vector = sel_false;
ARROW_ASSIGN_OR_RAISE(
if_false, ExecuteScalarExpression(if_false_expr, input_false, exec_context));
if (sel_false->length() == input.length) {
return if_false;
}
all_null = false;
}

if (all_null) {
return MakeNullScalar(if_true_expr.type()->GetSharedPtr());
}

return CallFunction("if_else", {cond, if_true, if_false}, exec_context);
}

Expand All @@ -170,17 +183,21 @@ Result<Datum> IfElseSpecialForm::Execute(const std::vector<Expression>& argument
if (sel_true.length() == input.length) {
return if_true;
}
all_null = false;
}

ARROW_ASSIGN_OR_RAISE(auto cond_inverted, CallFunction("invert", {cond}, exec_context));
ARROW_ASSIGN_OR_RAISE(auto sel_false,
CallFunction("indices_nonzero", {cond_inverted}, exec_context));
ARROW_ASSIGN_OR_RAISE(auto input_false,
TakeBySelectionVector(input, sel_false, exec_context));
ARROW_ASSIGN_OR_RAISE(
if_false, ExecuteScalarExpression(if_false_expr, input_false, exec_context));
if (sel_false.length() == input.length) {
return if_false;
if (sel_false.length() > 0) {
ARROW_ASSIGN_OR_RAISE(auto input_false,
TakeBySelectionVector(input, sel_false, exec_context));
ARROW_ASSIGN_OR_RAISE(
if_false, ExecuteScalarExpression(if_false_expr, input_false, exec_context));
if (sel_false.length() == input.length) {
return if_false;
}
all_null = false;
}

auto if_true_false = ChunkedArrayFromDatums({if_true, if_false});
Expand Down
180 changes: 171 additions & 9 deletions cpp/src/arrow/compute/special_form_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,70 @@

#include "arrow/compute/exec.h"
#include "arrow/compute/expression.h"
#include "arrow/compute/function.h"
#include "arrow/compute/kernels/codegen_internal.h"
#include "arrow/compute/registry.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/logging.h"

namespace arrow::compute {

namespace {

/// \brief Test kernel that copies its single array argument into the output.
///
/// \tparam selection_vector_aware when false, the kernel refuses to run on a span
/// that carries a selection vector and returns Status::Invalid instead.
///
/// Debug-checks that the span holds exactly one value and that this value is an
/// array.
template <bool selection_vector_aware>
Status TestKernelExec(KernelContext*, const ExecSpan& span, ExecResult* out) {
  DCHECK_EQ(span.num_values(), 1);
  if constexpr (!selection_vector_aware) {
    const bool has_selection = span.selection_vector != nullptr;
    if (has_selection) {
      return Status::Invalid("There is a selection vector");
    }
  }
  const auto& input = span[0];
  DCHECK(input.is_array());
  // Materialize the input span as ArrayData, then copy it over the output's
  // ArrayData.
  const auto input_data = input.array.ToArrayData();
  *out->array_data_mutable() = *input_data;
  return Status::OK();
}

/// \brief Register the unary test functions "panic_on_selection" and
/// "calm_on_selection" in the registry returned by GetFunctionRegistry().
///
/// Both functions use TestKernelExec, which copies the input array to the output.
/// "panic_on_selection" is registered as not selection-vector aware and its kernel
/// fails if it is handed a selection vector; "calm_on_selection" is registered as
/// selection-vector aware.
///
/// Registration overwrites any previous function of the same name, so calling this
/// more than once in the same process (e.g. with --gtest_repeat) does not fail.
///
/// \return the first error reported while adding a kernel or function, otherwise OK
Status RegisterTestFunctions() {
  auto registry = GetFunctionRegistry();

  auto register_test_func = [registry](const std::string& name,
                                       bool selection_vector_aware) -> Status {
    auto func =
        std::make_shared<ScalarFunction>(name, Arity::Unary(), FunctionDoc::Empty());

    ArrayKernelExec exec =
        selection_vector_aware ? &TestKernelExec<true> : &TestKernelExec<false>;
    ScalarKernel kernel({InputType::Any()}, internal::FirstType, exec);
    kernel.selection_vector_aware = selection_vector_aware;
    // The kernel replaces the output ArrayData wholesale, so the executor must not
    // preallocate buffers or hand it slices of a larger output.
    kernel.can_write_into_slices = false;
    kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
    kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
    RETURN_NOT_OK(func->AddKernel(std::move(kernel)));
    // The registry is shared by the whole process; allow overwriting so that
    // repeated registration is idempotent.
    return registry->AddFunction(std::move(func), /*allow_overwrite=*/true);
  };

  RETURN_NOT_OK(register_test_func("panic_on_selection", false));
  RETURN_NOT_OK(register_test_func("calm_on_selection", true));

  return Status::OK();
}

/// \brief Wrap `arg` in a call to the "panic_on_selection" test function.
Expression panic_on_selection(Expression arg) {
  std::vector<Expression> call_args;
  call_args.push_back(std::move(arg));
  return call("panic_on_selection", std::move(call_args));
}

/// \brief Wrap `arg` in a call to the "calm_on_selection" test function.
Expression calm_on_selection(Expression arg) {
  std::vector<Expression> call_args;
  call_args.push_back(std::move(arg));
  return call("calm_on_selection", std::move(call_args));
}

} // namespace

TEST(IfElseSpecialForm, Basic) {
{
ARROW_SCOPED_TRACE("if (b != 0) then a / b else b");
Expand Down Expand Up @@ -97,28 +157,35 @@ namespace {
// Builds the if_else special form from cond/if_true/if_false, binds it to `schema`,
// evaluates it on `input`, and asserts that the result equals `expected`. The same
// check is then repeated for variants in which the special form and each of its
// operands are wrapped in panic_on_selection(...) or calm_on_selection(...).
//
// NOTE(review): this span looks like a rendered diff in which the pre-change body
// (the four statements directly below) and the post-change body (the loop) are
// interleaved; presumably only the loop is meant to remain - confirm against the
// actual file.
void AssertIfElseEqual(const Datum& expected, Expression cond, Expression if_true,
Expression if_false, const std::shared_ptr<Schema>& schema,
const ExecBatch& input) {
auto if_else_sp = if_else_special(cond, if_true, if_false);
ASSERT_OK_AND_ASSIGN(auto bound, if_else_sp.Bind(*schema));
ASSERT_OK_AND_ASSIGN(auto result, ExecuteScalarExpression(bound, input));
AssertDatumsEqual(expected, result);
// Test using original/panic_on_selection(original)/calm_on_selection(original).
for (auto if_else_sp : {if_else_special(cond, if_true, if_false),
panic_on_selection(if_else_special(
panic_on_selection(cond), panic_on_selection(if_true),
panic_on_selection(if_false))),
calm_on_selection(if_else_special(
calm_on_selection(cond), calm_on_selection(if_true),
calm_on_selection(if_false)))}) {
// Trace the expression text so a failure identifies which variant was evaluated.
ARROW_SCOPED_TRACE(if_else_sp.ToString());
ASSERT_OK_AND_ASSIGN(auto bound, if_else_sp.Bind(*schema));
ASSERT_OK_AND_ASSIGN(auto result, ExecuteScalarExpression(bound, input));
AssertDatumsEqual(expected, result);
}
}

// Evaluates the regular "if_else" function call on cond/if_true/if_false to obtain a
// reference result, then checks that the if_else special form built from the same
// operands produces an equal result on `input`.
//
// NOTE(review): this span looks like a rendered diff with pre-change and post-change
// lines interleaved; presumably the direct if_else_sp bind/execute/compare statements
// were superseded by the final AssertIfElseEqual call - confirm against the actual
// file.
void AssertIfElseEqualWithExpr(Expression cond, Expression if_true, Expression if_false,
const std::shared_ptr<Schema>& schema,
const ExecBatch& input) {
auto if_else = call("if_else", {cond, if_true, if_false});
auto if_else_sp = if_else_special(cond, if_true, if_false);
// Reference result from the ordinary function call.
ASSERT_OK_AND_ASSIGN(auto bound, if_else.Bind(*schema));
ASSERT_OK_AND_ASSIGN(auto result, ExecuteScalarExpression(bound, input));
ASSERT_OK_AND_ASSIGN(auto bound_sp, if_else_sp.Bind(*schema));
ASSERT_OK_AND_ASSIGN(auto result_sp, ExecuteScalarExpression(bound_sp, input));
AssertDatumsEqual(result, result_sp);
AssertIfElseEqual(result, cond, if_true, if_false, schema, input);
}

} // namespace

// TODO: A function to break the selection vector awareness of the expressions.
TEST(IfElseSpecialForm, Shortcuts) {
ASSERT_OK(RegisterTestFunctions());
{
ARROW_SCOPED_TRACE("if (null) then 1 else 0");
AssertIfElseEqual(MakeNullScalar(int32()), literal(MakeNullScalar(boolean())),
Expand All @@ -140,12 +207,17 @@ TEST(IfElseSpecialForm, Shortcuts) {
{
auto schema = arrow::schema({field("a", int32()), field("b", int32())});
std::vector<ExecBatch> batches = {
ExecBatch(*RecordBatchFromJSON(schema, R"([])")),
ExecBatch(*RecordBatchFromJSON(schema, R"([
[1, 0],
[1, 0],
[1, 0]
])")),
ExecBatch(*RecordBatchFromJSON(schema, R"([])")),
ExecBatch(*RecordBatchFromJSON(schema, R"([
[1, 0],
[null, 0],
[1, null]
])")),
};
for (const auto& input : batches) {
{
Expand Down Expand Up @@ -184,16 +256,31 @@ TEST(IfElseSpecialForm, Shortcuts) {
[null, 1, 0],
[null, 1, 0]
])")),
ExecBatch(*RecordBatchFromJSON(schema, R"([
[null, 1, 0],
[null, null, 0],
[null, 1, null]
])")),
ExecBatch(*RecordBatchFromJSON(schema, R"([
[true, 1, 0],
[true, 1, 0],
[true, 1, 0]
])")),
ExecBatch(*RecordBatchFromJSON(schema, R"([
[true, 1, 0],
[true, null, 0],
[true, 1, null]
])")),
ExecBatch(*RecordBatchFromJSON(schema, R"([
[false, 1, 0],
[false, 1, 0],
[false, 1, 0]
])")),
ExecBatch(*RecordBatchFromJSON(schema, R"([
[false, 1, 0],
[false, null, 0],
[false, 1, null]
])")),
};
for (const auto& input : batches) {
{
Expand All @@ -209,4 +296,79 @@ TEST(IfElseSpecialForm, Shortcuts) {
}
}

namespace {

/// \brief Test kernel that writes the int32 value 0 into a one-element output.
///
/// \tparam selection_vector_aware when false, the kernel refuses to run on a span
/// that carries a selection vector and returns Status::Invalid instead.
///
/// Debug-checks that the span holds exactly one value, that the span length is 1,
/// and that the output is an array span of length 1.
template <bool selection_vector_aware>
Status ConstantKernelExec(KernelContext*, const ExecSpan& span, ExecResult* out) {
  DCHECK_EQ(span.num_values(), 1);
  DCHECK_EQ(span.length, 1);
  DCHECK(out->is_array_span());
  DCHECK_EQ(out->length(), 1);
  if constexpr (!selection_vector_aware) {
    const bool has_selection = span.selection_vector != nullptr;
    if (has_selection) {
      return Status::Invalid("There is a selection vector");
    }
  }
  // Buffer index 1 is presumably the int32 values buffer of the output - confirm.
  auto* out_span = out->array_span_mutable();
  int32_t* values = out_span->GetValues<int32_t>(1);
  values[0] = 0;
  return Status::OK();
}

/// \brief Register the unary test functions "zero_panic" and "zero_calm" in the
/// registry returned by GetFunctionRegistry().
///
/// Both functions use ConstantKernelExec, accept any input type and declare an
/// int32 output. "zero_panic" is registered as not selection-vector aware and its
/// kernel fails if it is handed a selection vector; "zero_calm" is registered as
/// selection-vector aware.
///
/// Registration overwrites any previous function of the same name, so calling this
/// more than once in the same process (e.g. with --gtest_repeat) does not fail.
///
/// \return the first error reported while adding a kernel or function, otherwise OK
Status RegisterConstantFunctions() {
  auto registry = GetFunctionRegistry();

  auto register_test_func = [registry](const std::string& name,
                                       bool selection_vector_aware) -> Status {
    auto func =
        std::make_shared<ScalarFunction>(name, Arity::Unary(), FunctionDoc::Empty());

    ArrayKernelExec exec = selection_vector_aware ? &ConstantKernelExec<true>
                                                  : &ConstantKernelExec<false>;
    ScalarKernel kernel({InputType::Any()}, OutputType{int32()}, exec);
    kernel.selection_vector_aware = selection_vector_aware;
    // The kernel writes into an output span it is given, so request preallocated
    // output and declare that it never produces nulls.
    kernel.can_write_into_slices = true;
    kernel.null_handling = NullHandling::OUTPUT_NOT_NULL;
    kernel.mem_allocation = MemAllocation::PREALLOCATE;
    RETURN_NOT_OK(func->AddKernel(std::move(kernel)));
    // The registry is shared by the whole process; allow overwriting so that
    // repeated registration is idempotent.
    return registry->AddFunction(std::move(func), /*allow_overwrite=*/true);
  };

  RETURN_NOT_OK(register_test_func("zero_panic", false));
  RETURN_NOT_OK(register_test_func("zero_calm", true));

  return Status::OK();
}

} // namespace

// Evaluates the non-selection-vector-aware "zero_panic" function on a literal for
// several batches (empty, no nulls, with nulls) and prints each result.
// TODO: this test only checks that binding and execution succeed; add assertions on
// the expected output instead of printing it.
TEST(IfElseSpecialForm, Reference) {
  ASSERT_OK(RegisterConstantFunctions());

  auto schema = arrow::schema({field("a", int32()), field("b", int32())});

  // The expression does not depend on the batch contents, so build and bind it once
  // rather than once per batch.
  auto expr = call("zero_panic", {literal(42)});
  ASSERT_OK_AND_ASSIGN(auto bound, expr.Bind(*schema));

  std::vector<ExecBatch> batches = {
      ExecBatch(*RecordBatchFromJSON(schema, R"([])")),
      ExecBatch(*RecordBatchFromJSON(schema, R"([
[1, 0],
[1, 0],
[1, 0]
])")),
      ExecBatch(*RecordBatchFromJSON(schema, R"([
[1, 0],
[null, 0],
[1, null]
])")),
  };
  for (const auto& input : batches) {
    ASSERT_OK_AND_ASSIGN(auto result, ExecuteScalarExpression(bound, input));
    // '\n' instead of std::endl: no need to flush the stream for every batch.
    std::cout << result.ToString() << '\n';
  }
}

} // namespace arrow::compute

0 comments on commit 6bab9b5

Please sign in to comment.