Skip to content
This repository has been archived by the owner on Sep 18, 2023. It is now read-only.

Commit

Permalink
Fix issue when spark only count numRows with no input
Browse files Browse the repository at this point in the history
Signed-off-by: Chendi Xue <chendi.xue@intel.com>
  • Loading branch information
xuechendi committed Feb 4, 2021
1 parent 3ab2ed5 commit 8222067
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ import org.apache.spark.sql.execution.vectorized.MutableColumnarRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
import org.apache.spark.sql.util.ArrowUtils
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.KVIterator
import scala.collection.JavaConverters._

import scala.collection.Iterator

Expand Down Expand Up @@ -136,46 +137,80 @@ case class ColumnarHashAggregateExec(
nativeIterator.close
}

var numRowsInput = 0
// now we can return this wholestagecodegen iter
val res = new Iterator[ColumnarBatch] {
var processed = false
var skip_native = false
var count_num_row = 0
def process: Unit = {
while (iter.hasNext) {
val cb = iter.next()
numInputBatches += 1
if (cb.numRows != 0) {
numRowsInput += cb.numRows
val beforeEval = System.nanoTime()
val input_rb =
ConverterUtils.createArrowRecordBatch(cb)
nativeIterator.processAndCacheOne(hash_aggr_input_schema, input_rb)
ConverterUtils.releaseArrowRecordBatch(input_rb)
if (hash_aggr_input_schema.getFields.size == 0) {
// This is a special case used by only do count literal
count_num_row += cb.numRows
skip_native = true
} else {
val input_rb =
ConverterUtils.createArrowRecordBatch(cb)
nativeIterator.processAndCacheOne(hash_aggr_input_schema, input_rb)
ConverterUtils.releaseArrowRecordBatch(input_rb)
}
eval_elapse += System.nanoTime() - beforeEval
}
}
processed = true
}
override def hasNext: Boolean = {
if (!processed) process
nativeIterator.hasNext
if (skip_native) {
count_num_row > 0
} else {
nativeIterator.hasNext
}
}

override def next(): ColumnarBatch = {
if (!processed) process
val beforeEval = System.nanoTime()
val output_rb = nativeIterator.next
if (output_rb == null) {
eval_elapse += System.nanoTime() - beforeEval
if (skip_native) {
// special handling for only count literal in this operator
val out_res = count_num_row
count_num_row = 0
val resultColumnVectors =
ArrowWritableColumnVector.allocateColumns(0, resultStructType).toArray
return new ColumnarBatch(resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 0)
resultColumnVectors.foreach { v =>
{
val numRows = v.dataType match {
case t: IntegerType =>
out_res.asInstanceOf[Number].intValue
case t: LongType =>
out_res.asInstanceOf[Number].longValue
}
v.put(0, numRows)
}
}
return new ColumnarBatch(resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 1)
} else {
val output_rb = nativeIterator.next
if (output_rb == null) {
eval_elapse += System.nanoTime() - beforeEval
val resultColumnVectors =
ArrowWritableColumnVector.allocateColumns(0, resultStructType).toArray
return new ColumnarBatch(resultColumnVectors.map(_.asInstanceOf[ColumnVector]), 0)
}
val outputNumRows = output_rb.getLength
val output = ConverterUtils.fromArrowRecordBatch(hash_aggr_out_schema, output_rb)
ConverterUtils.releaseArrowRecordBatch(output_rb)
eval_elapse += System.nanoTime() - beforeEval
numOutputRows += outputNumRows
numOutputBatches += 1
new ColumnarBatch(output.map(v => v.asInstanceOf[ColumnVector]), outputNumRows)
}
val outputNumRows = output_rb.getLength
val output = ConverterUtils.fromArrowRecordBatch(hash_aggr_out_schema, output_rb)
ConverterUtils.releaseArrowRecordBatch(output_rb)
eval_elapse += System.nanoTime() - beforeEval
numOutputRows += outputNumRows
numOutputBatches += 1
new ColumnarBatch(output.map(v => v.asInstanceOf[ColumnVector]), outputNumRows)
}
}
SparkMemoryUtils.addLeakSafeTaskCompletionListener[Unit](_ => {
Expand All @@ -201,7 +236,7 @@ case class ColumnarHashAggregateExec(
try {
ConverterUtils.checkIfTypeSupported(expr.dataType)
} catch {
case e : UnsupportedOperationException =>
case e: UnsupportedOperationException =>
throw new UnsupportedOperationException(
s"${expr.dataType} is not supported in ColumnarAggregation")
}
Expand Down
20 changes: 14 additions & 6 deletions cpp/src/codegen/arrow_compute/ext/actions_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ arrow::Status ActionBase::Submit(const std::shared_ptr<arrow::Array>& in,
return arrow::Status::NotImplemented("ActionBase Submit is abstract.");
}

arrow::Status ActionBase::EvaluateCountLiteral(const int& len) {
return arrow::Status::NotImplemented("ActionBase EvaluateCountLiteral is abstract.");
}

arrow::Status ActionBase::Evaluate(int dest_group_id) {
return arrow::Status::NotImplemented("ActionBase Evaluate is abstract.");
}
Expand Down Expand Up @@ -287,8 +291,7 @@ class CountAction : public ActionBase {
std::cout << "Construct CountAction" << std::endl;
#endif
std::unique_ptr<arrow::ArrayBuilder> array_builder;
arrow::MakeBuilder(ctx_->memory_pool(), arrow::TypeTraits<DataType>::type_singleton(),
&array_builder);
arrow::MakeBuilder(ctx_->memory_pool(), arrow::int64(), &array_builder);
builder_.reset(
arrow::internal::checked_cast<ResBuilderType*>(array_builder.release()));
}
Expand Down Expand Up @@ -470,15 +473,20 @@ class CountLiteralAction : public ActionBase {
return arrow::Status::OK();
}

arrow::Status Evaluate(const arrow::ArrayVector& in) {
arrow::Status EvaluateCountLiteral(const int& len) {
if (cache_.empty()) {
cache_.resize(1, 0);
length_ = 1;
}
cache_[0] += (in[0]->length() - in[0]->null_count()) * arg_;
cache_[0] += len;
return arrow::Status::OK();
}

arrow::Status Evaluate(const arrow::ArrayVector& in) {
return arrow::Status::NotImplemented(
"CountLiteralAction Non-Groupby Evaluate is unsupported.");
}

arrow::Status Evaluate(int dest_group_id) {
auto target_group_size = dest_group_id + 1;
if (cache_.size() <= target_group_size) GrowByFactor(target_group_size);
Expand Down Expand Up @@ -2379,14 +2387,14 @@ arrow::Status MakeUniqueAction(arrow::compute::FunctionContext* ctx,

arrow::Status MakeCountAction(arrow::compute::FunctionContext* ctx,
std::shared_ptr<ActionBase>* out) {
auto action_ptr = std::make_shared<CountAction<arrow::UInt64Type>>(ctx);
auto action_ptr = std::make_shared<CountAction<arrow::Int64Type>>(ctx);
*out = std::dynamic_pointer_cast<ActionBase>(action_ptr);
return arrow::Status::OK();
}

arrow::Status MakeCountLiteralAction(arrow::compute::FunctionContext* ctx, int arg,
std::shared_ptr<ActionBase>* out) {
auto action_ptr = std::make_shared<CountLiteralAction<arrow::UInt64Type>>(ctx, arg);
auto action_ptr = std::make_shared<CountLiteralAction<arrow::Int64Type>>(ctx, arg);
*out = std::dynamic_pointer_cast<ActionBase>(action_ptr);
return arrow::Status::OK();
}
Expand Down
1 change: 1 addition & 0 deletions cpp/src/codegen/arrow_compute/ext/actions_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class ActionBase {
virtual arrow::Status Submit(const std::shared_ptr<arrow::Array>& in,
std::function<arrow::Status(uint32_t)>* on_valid,
std::function<arrow::Status()>* on_null);
virtual arrow::Status EvaluateCountLiteral(const int& len);
virtual arrow::Status Evaluate(const arrow::ArrayVector& in);
virtual arrow::Status Evaluate(int dest_group_id);
virtual arrow::Status Evaluate(int dest_group_id, void* data);
Expand Down
21 changes: 17 additions & 4 deletions cpp/src/codegen/arrow_compute/ext/hash_aggregate_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,12 @@ class HashAggregateKernel::Impl {
std::shared_ptr<arrow::Schema> schema,
std::shared_ptr<ResultIterator<arrow::RecordBatch>>* out) {
// 1. create pre project
auto pre_process_expr_list = GetGandivaKernel(prepare_function_list_);
auto pre_process_projector = std::make_shared<GandivaProjector>(
ctx_, arrow::schema(input_field_list_), pre_process_expr_list);
std::shared_ptr<GandivaProjector> pre_process_projector;
if (!prepare_function_list_.empty()) {
auto pre_process_expr_list = GetGandivaKernel(prepare_function_list_);
pre_process_projector = std::make_shared<GandivaProjector>(
ctx_, arrow::schema(input_field_list_), pre_process_expr_list);
}

// 2. action_impl_list
std::vector<std::shared_ptr<ActionBase>> action_impl_list;
Expand Down Expand Up @@ -573,6 +576,8 @@ class HashAggregateKernel::Impl {
arrow::ArrayVector in;
if (pre_process_projector_) {
in = pre_process_projector_->Evaluate(orig_in);
} else {
in = orig_in;
}

// 2.1 handle no groupby scenario
Expand All @@ -584,7 +589,13 @@ class HashAggregateKernel::Impl {
for (auto idx : action_prepare_index_list_[i]) {
cols.push_back(in[idx]);
}
RETURN_NOT_OK(action->Evaluate(cols));
if (cols.empty()) {
// There is a special case, when we need to do no groupby count literal
RETURN_NOT_OK(action->EvaluateCountLiteral(in[0]->length()));

} else {
RETURN_NOT_OK(action->Evaluate(cols));
}
}
total_out_length_ = 1;
return arrow::Status::OK();
Expand Down Expand Up @@ -716,6 +727,8 @@ class HashAggregateKernel::Impl {
arrow::ArrayVector in;
if (pre_process_projector_) {
in = pre_process_projector_->Evaluate(orig_in);
} else {
in = orig_in;
}

// 2. handle multiple keys
Expand Down
12 changes: 8 additions & 4 deletions cpp/src/tests/arrow_compute_test_aggregate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ TEST(TestArrowCompute, AggregateTest) {
auto n_max = TreeExprBuilder::MakeFunction("action_max", {arg_0}, int64());
auto n_stddev = TreeExprBuilder::MakeFunction("action_stddev_samp_final",
{arg_2, arg_3, arg_4}, float64());
auto n_count_literal =
TreeExprBuilder::MakeFunction("action_countLiteral_1", {}, int64());

auto f_sum = field("sum", int64());
auto f_count = field("count", int64());
Expand All @@ -58,12 +60,14 @@ TEST(TestArrowCompute, AggregateTest) {
auto f_min = field("min", int64());
auto f_max = field("max", int64());
auto f_stddev = field("stddev", float64());
auto f_count_literal = field("count_all", int64());
auto f_res = field("res", int32());

auto n_proj = TreeExprBuilder::MakeFunction(
"aggregateExpressions", {arg_0, arg_1, arg_2, arg_3, arg_4}, uint32());
auto n_action = TreeExprBuilder::MakeFunction(
"aggregateActions", {n_sum, n_count, n_sum_count, n_avg, n_min, n_max, n_stddev},
"aggregateActions",
{n_sum, n_count, n_sum_count, n_avg, n_min, n_max, n_stddev, n_count_literal},
uint32());

auto n_aggr =
Expand All @@ -72,8 +76,8 @@ TEST(TestArrowCompute, AggregateTest) {
auto aggr_expr = TreeExprBuilder::MakeExpression(n_child, f_res);

auto sch = arrow::schema({f0, f1, f2, f3, f4});
std::vector<std::shared_ptr<Field>> ret_types = {f_sum, f_count, f_sum, f_count,
f_avg, f_min, f_max, f_stddev};
std::vector<std::shared_ptr<Field>> ret_types = {
f_sum, f_count, f_sum, f_count, f_avg, f_min, f_max, f_stddev, f_count_literal};
///////////////////// Calculation //////////////////
std::shared_ptr<CodeGenerator> expr;
arrow::compute::FunctionContext ctx;
Expand Down Expand Up @@ -107,7 +111,7 @@ TEST(TestArrowCompute, AggregateTest) {
std::shared_ptr<arrow::RecordBatch> expected_result;
std::shared_ptr<arrow::RecordBatch> result_batch;
std::vector<std::string> expected_result_string = {
"[221]", "[39]", "[221]", "[39]", "[4.40724]", "[1]", "[10]", "[17.2996]"};
"[221]", "[39]", "[221]", "[39]", "[4.40724]", "[1]", "[10]", "[17.2996]", "[40]"};
auto res_sch = arrow::schema(ret_types);
MakeInputBatch(expected_result_string, res_sch, &expected_result);
if (aggr_result_iterator->HasNext()) {
Expand Down

0 comments on commit 8222067

Please sign in to comment.