Skip to content

Commit

Permalink
Merge pull request #1 from zanmato1984/fix-39332-2
Browse files Browse the repository at this point in the history
Give explicit error in ExecBatchBuilder when appending data exceeds o…
  • Loading branch information
zanmato1984 authored Jan 18, 2024
2 parents b03c71c + c7c6295 commit a0204d6
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 56 deletions.
54 changes: 26 additions & 28 deletions cpp/src/arrow/compute/light_array.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@

#include <type_traits>

#include "arrow/array/builder_binary.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/int_util_overflow.h"
#include "arrow/util/macros.h"

namespace arrow {
Expand Down Expand Up @@ -470,7 +470,7 @@ void ExecBatchBuilder::Visit(const std::shared_ptr<ArrayData>& column, int num_r
for (int i = 0; i < num_rows; ++i) {
uint16_t row_id = row_ids[i];
const uint8_t* field_ptr = ptr_base + offsets[row_id];
uint32_t field_length = offsets[row_id + 1] - offsets[row_id];
int32_t field_length = offsets[row_id + 1] - offsets[row_id];
process_value_fn(i, field_ptr, field_length);
}
} else {
Expand All @@ -480,7 +480,7 @@ void ExecBatchBuilder::Visit(const std::shared_ptr<ArrayData>& column, int num_r
const uint8_t* field_ptr =
column->buffers[1]->data() +
(column->offset + row_id) * static_cast<int64_t>(metadata.fixed_length);
process_value_fn(i, field_ptr, metadata.fixed_length);
process_value_fn(i, field_ptr, static_cast<int32_t>(metadata.fixed_length));
}
}
}
Expand Down Expand Up @@ -511,30 +511,30 @@ Status ExecBatchBuilder::AppendSelected(const std::shared_ptr<ArrayData>& source
break;
case 1:
Visit(source, num_rows_to_append, row_ids,
[&](int i, const uint8_t* ptr, uint32_t num_bytes) {
[&](int i, const uint8_t* ptr, int32_t num_bytes) {
target->mutable_data(1)[num_rows_before + i] = *ptr;
});
break;
case 2:
Visit(
source, num_rows_to_append, row_ids,
[&](int i, const uint8_t* ptr, uint32_t num_bytes) {
[&](int i, const uint8_t* ptr, int32_t num_bytes) {
reinterpret_cast<uint16_t*>(target->mutable_data(1))[num_rows_before + i] =
*reinterpret_cast<const uint16_t*>(ptr);
});
break;
case 4:
Visit(
source, num_rows_to_append, row_ids,
[&](int i, const uint8_t* ptr, uint32_t num_bytes) {
[&](int i, const uint8_t* ptr, int32_t num_bytes) {
reinterpret_cast<uint32_t*>(target->mutable_data(1))[num_rows_before + i] =
*reinterpret_cast<const uint32_t*>(ptr);
});
break;
case 8:
Visit(
source, num_rows_to_append, row_ids,
[&](int i, const uint8_t* ptr, uint32_t num_bytes) {
[&](int i, const uint8_t* ptr, int32_t num_bytes) {
reinterpret_cast<uint64_t*>(target->mutable_data(1))[num_rows_before + i] =
*reinterpret_cast<const uint64_t*>(ptr);
});
Expand All @@ -544,7 +544,7 @@ Status ExecBatchBuilder::AppendSelected(const std::shared_ptr<ArrayData>& source
num_rows_to_append -
NumRowsToSkip(source, num_rows_to_append, row_ids, sizeof(uint64_t));
Visit(source, num_rows_to_process, row_ids,
[&](int i, const uint8_t* ptr, uint32_t num_bytes) {
[&](int i, const uint8_t* ptr, int32_t num_bytes) {
uint64_t* dst = reinterpret_cast<uint64_t*>(
target->mutable_data(1) +
static_cast<int64_t>(num_bytes) * (num_rows_before + i));
Expand All @@ -558,7 +558,7 @@ Status ExecBatchBuilder::AppendSelected(const std::shared_ptr<ArrayData>& source
if (num_rows_to_append > num_rows_to_process) {
Visit(source, num_rows_to_append - num_rows_to_process,
row_ids + num_rows_to_process,
[&](int i, const uint8_t* ptr, uint32_t num_bytes) {
[&](int i, const uint8_t* ptr, int32_t num_bytes) {
uint64_t* dst = reinterpret_cast<uint64_t*>(
target->mutable_data(1) +
static_cast<int64_t>(num_bytes) *
Expand All @@ -576,25 +576,23 @@ Status ExecBatchBuilder::AppendSelected(const std::shared_ptr<ArrayData>& source
// Step 1: calculate target offsets
//
int32_t* offsets = reinterpret_cast<int32_t*>(target->mutable_data(1));
{
int64_t sum = num_rows_before == 0 ? 0 : offsets[num_rows_before];
Visit(source, num_rows_to_append, row_ids,
[&](int i, const uint8_t* ptr, uint32_t num_bytes) {
offsets[num_rows_before + i] = num_bytes;
});
for (int i = 0; i < num_rows_to_append; ++i) {
int32_t length = offsets[num_rows_before + i];
offsets[num_rows_before + i] = static_cast<int32_t>(sum);
if (ARROW_PREDICT_FALSE(sum + length > BinaryBuilder::memory_limit())) {
return Status::Invalid("ExecBatchBuilder cannot contain more than ",
BinaryBuilder::memory_limit(), " bytes, current ", sum,
", appending ", num_rows_before + i + 1,
"-th element of length ", length);
}
sum += length;
int32_t sum = num_rows_before == 0 ? 0 : offsets[num_rows_before];
Visit(source, num_rows_to_append, row_ids,
[&](int i, const uint8_t* ptr, int32_t num_bytes) {
offsets[num_rows_before + i] = num_bytes;
});
for (int i = 0; i < num_rows_to_append; ++i) {
int32_t length = offsets[num_rows_before + i];
offsets[num_rows_before + i] = sum;
int32_t new_sum_maybe_overflow = 0;
if (ARROW_PREDICT_FALSE(internal::AddWithOverflow(sum, length, &new_sum_maybe_overflow))) {
return Status::Invalid("Overflow detected in ExecBatchBuilder when appending ",
num_rows_before + i + 1, "-th element of length ", length,
" bytes to current length ", sum, " bytes");
}
offsets[num_rows_before + num_rows_to_append] = static_cast<int32_t>(sum);
sum = new_sum_maybe_overflow;
}
offsets[num_rows_before + num_rows_to_append] = sum;

// Step 2: resize output buffers
//
Expand All @@ -606,7 +604,7 @@ Status ExecBatchBuilder::AppendSelected(const std::shared_ptr<ArrayData>& source
num_rows_to_append -
NumRowsToSkip(source, num_rows_to_append, row_ids, sizeof(uint64_t));
Visit(source, num_rows_to_process, row_ids,
[&](int i, const uint8_t* ptr, uint32_t num_bytes) {
[&](int i, const uint8_t* ptr, int32_t num_bytes) {
uint64_t* dst = reinterpret_cast<uint64_t*>(target->mutable_data(2) +
offsets[num_rows_before + i]);
const uint64_t* src = reinterpret_cast<const uint64_t*>(ptr);
Expand All @@ -616,7 +614,7 @@ Status ExecBatchBuilder::AppendSelected(const std::shared_ptr<ArrayData>& source
}
});
Visit(source, num_rows_to_append - num_rows_to_process, row_ids + num_rows_to_process,
[&](int i, const uint8_t* ptr, uint32_t num_bytes) {
[&](int i, const uint8_t* ptr, int32_t num_bytes) {
uint64_t* dst = reinterpret_cast<uint64_t*>(
target->mutable_data(2) +
offsets[num_rows_before + num_rows_to_process + i]);
Expand Down
58 changes: 30 additions & 28 deletions cpp/src/arrow/compute/light_array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include <gtest/gtest.h>
#include <numeric>

#include "arrow/array/builder_binary.h"
#include "arrow/testing/generator.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/type.h"
Expand Down Expand Up @@ -409,56 +408,59 @@ TEST(ExecBatchBuilder, AppendValuesBeyondLimit) {
}

TEST(ExecBatchBuilder, AppendVarLengthBeyondLimit) {
// Corresponds to GH-39332.
std::unique_ptr<MemoryPool> owned_pool = MemoryPool::CreateDefault();
MemoryPool* pool = owned_pool.get();
constexpr auto eight_mb = 8 * 1024 * 1024;
constexpr auto eight_mb_minus_one = eight_mb - 1;
// String of size 8mb, appended repeatedly to fill the leading multiple-of-8mb portion
// of an array of int32_max bytes.
std::string str_8mb(eight_mb, 'a');
std::string str_8mb_minus_1(eight_mb - 1, 'b');
std::string str_8mb_minus_2(eight_mb - 2,
'b'); // BaseBinaryBuilder::memory_limit()
// String of size (8mb - 1) to be the last element of an array of int32_max bytes.
std::string str_8mb_minus_1(eight_mb_minus_one, 'b');
std::shared_ptr<Array> values_8mb = ConstantArrayGenerator::String(1, str_8mb);
std::shared_ptr<Array> values_8mb_minus_1 =
ConstantArrayGenerator::String(1, str_8mb_minus_1);
std::shared_ptr<Array> values_8mb_minus_2 =
ConstantArrayGenerator::String(1, str_8mb_minus_2);

ExecBatch batch_8mb({values_8mb}, 1);
ExecBatch batch_8mb_minus_1({values_8mb_minus_1}, 1);
ExecBatch batch_8mb_minus_2({values_8mb_minus_2}, 1);

auto num_rows = std::numeric_limits<int32_t>::max() / eight_mb;
StringBuilder values_ref_builder;
ASSERT_OK(values_ref_builder.Reserve(num_rows + 1));
for (int i = 0; i < num_rows; i++) {
ASSERT_OK(values_ref_builder.Append(str_8mb));
}
ASSERT_OK(values_ref_builder.Append(str_8mb_minus_2));
ASSERT_OK_AND_ASSIGN(auto values_ref, values_ref_builder.Finish());
auto values_ref_1 = values_ref->Slice(0, num_rows);
ExecBatch batch_ref_1({values_ref_1}, num_rows);
ExecBatch batch_ref_2({values_ref}, num_rows + 1);

std::vector<uint16_t> first_set_row_ids(num_rows, 0);
std::vector<uint16_t> second_set_row_ids(1, 0);
std::vector<uint16_t> body_row_ids(num_rows, 0);
std::vector<uint16_t> tail_row_id(1, 0);

{
// Building an array of (int32_max + 1) = (8mb * num_rows + 8mb) bytes should raise an
// overflow error.
ExecBatchBuilder builder;
ASSERT_OK(builder.AppendSelected(pool, batch_8mb, num_rows, first_set_row_ids.data(),
ASSERT_OK(builder.AppendSelected(pool, batch_8mb, num_rows, body_row_ids.data(),
/*num_cols=*/1));
ASSERT_RAISES(Invalid, builder.AppendSelected(pool, batch_8mb_minus_1, 1,
second_set_row_ids.data(),
/*num_cols=*/1));
std::stringstream ss;
ss << "Invalid: Overflow detected in ExecBatchBuilder when appending " << num_rows + 1
<< "-th element of length " << eight_mb << " bytes to current length "
<< eight_mb * num_rows << " bytes";
ASSERT_RAISES_WITH_MESSAGE(
Invalid, ss.str(),
builder.AppendSelected(pool, batch_8mb, 1, tail_row_id.data(),
/*num_cols=*/1));
}

{
// Building an array of int32_max = (8mb * num_rows + 8mb - 1) bytes should succeed.
ExecBatchBuilder builder;
ASSERT_OK(builder.AppendSelected(pool, batch_8mb, num_rows, first_set_row_ids.data(),
ASSERT_OK(builder.AppendSelected(pool, batch_8mb, num_rows, body_row_ids.data(),
/*num_cols=*/1));
ASSERT_OK(builder.AppendSelected(pool, batch_8mb_minus_2, 1,
second_set_row_ids.data(),
ASSERT_OK(builder.AppendSelected(pool, batch_8mb_minus_1, 1, tail_row_id.data(),
/*num_cols=*/1));
ExecBatch built = builder.Flush();
ASSERT_EQ(batch_ref_2, built);
auto datum = built[0];
ASSERT_TRUE(datum.is_array());
auto array = datum.array_as<StringArray>();
ASSERT_EQ(array->length(), num_rows + 1);
for (int i = 0; i < num_rows; ++i) {
ASSERT_EQ(array->GetString(i), str_8mb);
}
ASSERT_EQ(array->GetString(num_rows), str_8mb_minus_1);
ASSERT_NE(0, pool->bytes_allocated());
}

Expand Down

0 comments on commit a0204d6

Please sign in to comment.