Skip to content

Commit

Permalink
ARROW-13573: [C++] Support dictionaries natively in case_when
Browse files Browse the repository at this point in the history
This supports dictionaries 'natively', that is, dictionaries are no longer always unpacked. (If mixed dictionary and non-dictionary arguments are given, then they will be unpacked.)

For scalar conditions, the output will have the dictionary of whichever input is selected (or no dictionary if the output is null). For array conditions, we unify the dictionaries as we select elements.

Closes #11022 from lidavidm/arrow-13573

Authored-by: David Li <li.davidm96@gmail.com>
Signed-off-by: Antoine Pitrou <antoine@python.org>
  • Loading branch information
lidavidm authored and pitrou committed Sep 21, 2021
1 parent 6a6b464 commit 87e2ad5
Show file tree
Hide file tree
Showing 19 changed files with 692 additions and 173 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cpp.yml
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ jobs:
name: AMD64 Windows MinGW ${{ matrix.mingw-n-bits }} C++
runs-on: windows-latest
if: ${{ !contains(github.event.pull_request.title, 'WIP') }}
timeout-minutes: 45
timeout-minutes: 60
strategy:
fail-fast: false
matrix:
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/r.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ jobs:
name: AMD64 Ubuntu ${{ matrix.ubuntu }} R ${{ matrix.r }}
runs-on: ubuntu-latest
if: ${{ !contains(github.event.pull_request.title, 'WIP') }}
timeout-minutes: 60
timeout-minutes: 75
strategy:
fail-fast: false
matrix:
Expand Down
6 changes: 5 additions & 1 deletion ci/scripts/PKGBUILD
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,13 @@ build() {
export LIBS="-L${MINGW_PREFIX}/libs"
export ARROW_S3=OFF
export ARROW_WITH_RE2=OFF
# Without this, some dataset functionality segfaults
export CMAKE_UNITY_BUILD=ON
else
export ARROW_S3=ON
export ARROW_WITH_RE2=ON
# Without this, some compute functionality segfaults in tests
export CMAKE_UNITY_BUILD=OFF
fi

MSYS2_ARG_CONV_EXCL="-DCMAKE_INSTALL_PREFIX=" \
Expand Down Expand Up @@ -115,7 +119,7 @@ build() {
-DARROW_CXXFLAGS="${CPPFLAGS}" \
-DCMAKE_BUILD_TYPE="release" \
-DCMAKE_INSTALL_PREFIX=${MINGW_PREFIX} \
-DCMAKE_UNITY_BUILD=ON \
-DCMAKE_UNITY_BUILD=${CMAKE_UNITY_BUILD} \
-DCMAKE_VERBOSE_MAKEFILE=ON

make -j3
Expand Down
12 changes: 5 additions & 7 deletions cpp/src/arrow/array/array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -456,7 +456,7 @@ TEST_F(TestArray, TestValidateNullCount) {
void AssertAppendScalar(MemoryPool* pool, const std::shared_ptr<Scalar>& scalar) {
std::unique_ptr<arrow::ArrayBuilder> builder;
auto null_scalar = MakeNullScalar(scalar->type);
ASSERT_OK(MakeBuilder(pool, scalar->type, &builder));
ASSERT_OK(MakeBuilderExactIndex(pool, scalar->type, &builder));
ASSERT_OK(builder->AppendScalar(*scalar));
ASSERT_OK(builder->AppendScalar(*scalar));
ASSERT_OK(builder->AppendScalar(*null_scalar));
Expand All @@ -471,15 +471,18 @@ void AssertAppendScalar(MemoryPool* pool, const std::shared_ptr<Scalar>& scalar)
ASSERT_EQ(out->length(), 9);

const bool can_check_nulls = internal::HasValidityBitmap(out->type()->id());
// For a dictionary builder, the output dictionary won't necessarily be the same
const bool can_check_values = !is_dictionary(out->type()->id());

if (can_check_nulls) {
ASSERT_EQ(out->null_count(), 4);
}

for (const auto index : {0, 1, 3, 5, 6}) {
ASSERT_FALSE(out->IsNull(index));
ASSERT_OK_AND_ASSIGN(auto scalar_i, out->GetScalar(index));
ASSERT_OK(scalar_i->ValidateFull());
AssertScalarsEqual(*scalar, *scalar_i, /*verbose=*/true);
if (can_check_values) AssertScalarsEqual(*scalar, *scalar_i, /*verbose=*/true);
}
for (const auto index : {2, 4, 7, 8}) {
ASSERT_EQ(out->IsNull(index), can_check_nulls);
Expand Down Expand Up @@ -575,8 +578,6 @@ TEST_F(TestArray, TestMakeArrayFromScalar) {
}

for (auto scalar : scalars) {
// TODO(ARROW-13197): appending dictionary scalars not implemented
if (is_dictionary(scalar->type->id())) continue;
AssertAppendScalar(pool_, scalar);
}
}
Expand Down Expand Up @@ -634,9 +635,6 @@ TEST_F(TestArray, TestMakeArrayFromMapScalar) {
TEST_F(TestArray, TestAppendArraySlice) {
auto scalars = GetScalars();
for (const auto& scalar : scalars) {
// TODO(ARROW-13573): appending dictionary arrays not implemented
if (is_dictionary(scalar->type->id())) continue;

ARROW_SCOPED_TRACE(*scalar->type);
ASSERT_OK_AND_ASSIGN(auto array, MakeArrayFromScalar(*scalar, 16));
ASSERT_OK_AND_ASSIGN(auto nulls, MakeArrayOfNull(scalar->type, 16));
Expand Down
10 changes: 1 addition & 9 deletions cpp/src/arrow/array/builder_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <vector>

#include "arrow/array/array_base.h"
#include "arrow/array/builder_dict.h"
#include "arrow/array/data.h"
#include "arrow/array/util.h"
#include "arrow/buffer.h"
Expand Down Expand Up @@ -268,15 +269,6 @@ struct AppendScalarImpl {

} // namespace

// Append one value taken from `scalar`.
// Returns Invalid when the scalar's type does not equal the builder's type.
Status ArrayBuilder::AppendScalar(const Scalar& scalar) {
  const bool same_type = scalar.type->Equals(type());
  if (!same_type) {
    return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(),
                           " to builder for type ", type()->ToString());
  }
  // Non-owning handle: the deleter is a no-op, so the caller's scalar is never freed here.
  std::shared_ptr<Scalar> borrowed{const_cast<Scalar*>(&scalar), [](Scalar*) {}};
  // A one-element [begin, end) range over the borrowed handle, appended once.
  AppendScalarImpl impl{&borrowed, &borrowed + 1, /*n_repeats=*/1, this};
  return impl.Convert();
}

Status ArrayBuilder::AppendScalar(const Scalar& scalar, int64_t n_repeats) {
if (!scalar.type->Equals(type())) {
return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(),
Expand Down
13 changes: 10 additions & 3 deletions cpp/src/arrow/array/builder_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,9 @@ class ARROW_EXPORT ArrayBuilder {
virtual Status AppendEmptyValues(int64_t length) = 0;

/// \brief Append a value from a scalar
Status AppendScalar(const Scalar& scalar);
Status AppendScalar(const Scalar& scalar, int64_t n_repeats);
Status AppendScalars(const ScalarVector& scalars);
Status AppendScalar(const Scalar& scalar) { return AppendScalar(scalar, 1); }
virtual Status AppendScalar(const Scalar& scalar, int64_t n_repeats);
virtual Status AppendScalars(const ScalarVector& scalars);

/// \brief Append a range of values from an array.
///
Expand Down Expand Up @@ -282,6 +282,13 @@ ARROW_EXPORT
Status MakeBuilder(MemoryPool* pool, const std::shared_ptr<DataType>& type,
std::unique_ptr<ArrayBuilder>* out);

/// \brief Construct an empty ArrayBuilder corresponding to the data
/// type, where any top-level or nested dictionary builders return the
/// exact index type specified by the type.
ARROW_EXPORT
Status MakeBuilderExactIndex(MemoryPool* pool, const std::shared_ptr<DataType>& type,
std::unique_ptr<ArrayBuilder>* out);

/// \brief Construct an empty DictionaryBuilder initialized optionally
/// with a pre-existing dictionary
/// \param[in] pool the MemoryPool to use for allocations
Expand Down
39 changes: 24 additions & 15 deletions cpp/src/arrow/array/builder_dict.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,23 +159,32 @@ DictionaryMemoTable::DictionaryMemoTable(MemoryPool* pool,

DictionaryMemoTable::~DictionaryMemoTable() = default;

#define GET_OR_INSERT(C_TYPE) \
Status DictionaryMemoTable::GetOrInsert( \
const typename CTypeTraits<C_TYPE>::ArrowType*, C_TYPE value, int32_t* out) { \
return impl_->GetOrInsert<typename CTypeTraits<C_TYPE>::ArrowType>(value, out); \
#define GET_OR_INSERT(ARROW_TYPE) \
Status DictionaryMemoTable::GetOrInsert( \
const ARROW_TYPE*, typename ARROW_TYPE::c_type value, int32_t* out) { \
return impl_->GetOrInsert<ARROW_TYPE>(value, out); \
}

GET_OR_INSERT(bool)
GET_OR_INSERT(int8_t)
GET_OR_INSERT(int16_t)
GET_OR_INSERT(int32_t)
GET_OR_INSERT(int64_t)
GET_OR_INSERT(uint8_t)
GET_OR_INSERT(uint16_t)
GET_OR_INSERT(uint32_t)
GET_OR_INSERT(uint64_t)
GET_OR_INSERT(float)
GET_OR_INSERT(double)
// One DictionaryMemoTable::GetOrInsert overload per supported value type.
// GET_OR_INSERT expands to a complete function definition, so the invocations
// take no trailing semicolon (one would be a stray empty declaration).
GET_OR_INSERT(BooleanType)
GET_OR_INSERT(Int8Type)
GET_OR_INSERT(Int16Type)
GET_OR_INSERT(Int32Type)
GET_OR_INSERT(Int64Type)
GET_OR_INSERT(UInt8Type)
GET_OR_INSERT(UInt16Type)
GET_OR_INSERT(UInt32Type)
GET_OR_INSERT(UInt64Type)
GET_OR_INSERT(FloatType)
GET_OR_INSERT(DoubleType)
GET_OR_INSERT(DurationType)
GET_OR_INSERT(TimestampType)
GET_OR_INSERT(Date32Type)
GET_OR_INSERT(Date64Type)
GET_OR_INSERT(Time32Type)
GET_OR_INSERT(Time64Type)
GET_OR_INSERT(MonthDayNanoIntervalType)
GET_OR_INSERT(DayTimeIntervalType)
GET_OR_INSERT(MonthIntervalType)

#undef GET_OR_INSERT

Expand Down
110 changes: 110 additions & 0 deletions cpp/src/arrow/array/builder_dict.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "arrow/util/decimal.h"
#include "arrow/util/macros.h"
#include "arrow/util/visibility.h"
#include "arrow/visitor_inline.h"

namespace arrow {

Expand Down Expand Up @@ -97,6 +98,17 @@ class ARROW_EXPORT DictionaryMemoTable {
Status GetOrInsert(const UInt16Type*, uint16_t value, int32_t* out);
Status GetOrInsert(const UInt32Type*, uint32_t value, int32_t* out);
Status GetOrInsert(const UInt64Type*, uint64_t value, int32_t* out);
Status GetOrInsert(const DurationType*, int64_t value, int32_t* out);
Status GetOrInsert(const TimestampType*, int64_t value, int32_t* out);
Status GetOrInsert(const Date32Type*, int32_t value, int32_t* out);
Status GetOrInsert(const Date64Type*, int64_t value, int32_t* out);
Status GetOrInsert(const Time32Type*, int32_t value, int32_t* out);
Status GetOrInsert(const Time64Type*, int64_t value, int32_t* out);
Status GetOrInsert(const MonthDayNanoIntervalType*,
MonthDayNanoIntervalType::MonthDayNanos value, int32_t* out);
Status GetOrInsert(const DayTimeIntervalType*,
DayTimeIntervalType::DayMilliseconds value, int32_t* out);
Status GetOrInsert(const MonthIntervalType*, int32_t value, int32_t* out);
Status GetOrInsert(const FloatType*, float value, int32_t* out);
Status GetOrInsert(const DoubleType*, double value, int32_t* out);

Expand Down Expand Up @@ -282,6 +294,73 @@ class DictionaryBuilderBase : public ArrayBuilder {
return indices_builder_.AppendEmptyValues(length);
}

/// \brief Append the value of a dictionary scalar `n_repeats` times.
///
/// A null scalar appends `n_repeats` nulls. Otherwise the scalar's index is
/// resolved against the scalar's own dictionary and the referenced value is
/// appended to this builder, dispatching on the scalar's index type.
///
/// \param[in] scalar the scalar to append; must be dictionary-typed when valid
/// \param[in] n_repeats how many times to append the value
/// \return Invalid if a valid scalar is not dictionary-typed, TypeError if its
/// index type is not one of the integer types handled below
Status AppendScalar(const Scalar& scalar, int64_t n_repeats) override {
  if (!scalar.is_valid) return AppendNulls(n_repeats);

  // Fail cleanly rather than reaching the downcasts below with a
  // non-dictionary scalar.
  if (scalar.type->id() != Type::DICTIONARY) {
    return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(),
                           " to a dictionary builder");
  }

  const auto& dict_ty = internal::checked_cast<const DictionaryType&>(*scalar.type);
  const DictionaryScalar& dict_scalar =
      internal::checked_cast<const DictionaryScalar&>(scalar);
  const auto& dict = internal::checked_cast<const typename TypeTraits<T>::ArrayType&>(
      *dict_scalar.value.dictionary);
  ARROW_RETURN_NOT_OK(Reserve(n_repeats));
  // Every branch returns, so nothing follows the switch.
  switch (dict_ty.index_type()->id()) {
    case Type::UINT8:
      return AppendScalarImpl<UInt8Type>(dict, *dict_scalar.value.index, n_repeats);
    case Type::INT8:
      return AppendScalarImpl<Int8Type>(dict, *dict_scalar.value.index, n_repeats);
    case Type::UINT16:
      return AppendScalarImpl<UInt16Type>(dict, *dict_scalar.value.index, n_repeats);
    case Type::INT16:
      return AppendScalarImpl<Int16Type>(dict, *dict_scalar.value.index, n_repeats);
    case Type::UINT32:
      return AppendScalarImpl<UInt32Type>(dict, *dict_scalar.value.index, n_repeats);
    case Type::INT32:
      return AppendScalarImpl<Int32Type>(dict, *dict_scalar.value.index, n_repeats);
    case Type::UINT64:
      return AppendScalarImpl<UInt64Type>(dict, *dict_scalar.value.index, n_repeats);
    case Type::INT64:
      return AppendScalarImpl<Int64Type>(dict, *dict_scalar.value.index, n_repeats);
    default:
      return Status::TypeError("Invalid index type: ", dict_ty);
  }
}

/// \brief Append each scalar in `scalars` once, in order.
///
/// Stops and returns the failing status as soon as one append fails.
Status AppendScalars(const ScalarVector& scalars) override {
  for (auto it = scalars.begin(); it != scalars.end(); ++it) {
    ARROW_RETURN_NOT_OK(AppendScalar(**it, /*n_repeats=*/1));
  }
  return Status::OK();
}

/// \brief Append `length` elements of a dictionary-encoded array, starting
/// `offset` elements into it.
///
/// The slice's indices are resolved against the array's own dictionary and
/// the unpacked values are appended to this builder, dispatching on the
/// array's index type.
///
/// \param[in] array the dictionary-encoded data to read from
/// \param[in] offset first element of the slice
/// \param[in] length number of elements to append
/// \return Invalid if `array` is not dictionary-typed, TypeError if its index
/// type is not one of the integer types handled below
Status AppendArraySlice(const ArrayData& array, int64_t offset, int64_t length) final {
  // Fail cleanly rather than reaching the downcast below with a
  // non-dictionary array.
  if (array.type->id() != Type::DICTIONARY) {
    return Status::Invalid("Cannot append array of type ", array.type->ToString(),
                           " to a dictionary builder");
  }

  // Visit the indices and insert the unpacked values.
  const auto& dict_ty = internal::checked_cast<const DictionaryType&>(*array.type);
  const typename TypeTraits<T>::ArrayType dict(array.dictionary);
  ARROW_RETURN_NOT_OK(Reserve(length));
  // Every branch returns, so nothing follows the switch.
  switch (dict_ty.index_type()->id()) {
    case Type::UINT8:
      return AppendArraySliceImpl<uint8_t>(dict, array, offset, length);
    case Type::INT8:
      return AppendArraySliceImpl<int8_t>(dict, array, offset, length);
    case Type::UINT16:
      return AppendArraySliceImpl<uint16_t>(dict, array, offset, length);
    case Type::INT16:
      return AppendArraySliceImpl<int16_t>(dict, array, offset, length);
    case Type::UINT32:
      return AppendArraySliceImpl<uint32_t>(dict, array, offset, length);
    case Type::INT32:
      return AppendArraySliceImpl<int32_t>(dict, array, offset, length);
    case Type::UINT64:
      return AppendArraySliceImpl<uint64_t>(dict, array, offset, length);
    case Type::INT64:
      return AppendArraySliceImpl<int64_t>(dict, array, offset, length);
    default:
      return Status::TypeError("Invalid index type: ", dict_ty);
  }
}

/// \brief Insert values into the dictionary's memo, but do not append any
/// indices. Can be used to initialize a new builder with known dictionary
/// values
Expand Down Expand Up @@ -376,6 +455,37 @@ class DictionaryBuilderBase : public ArrayBuilder {
}

protected:
/// \brief Append a slice of dictionary indices, unpacked through `dict`.
///
/// Indices are read as `c_type` from buffer 1 of `array`. Slots that the
/// bitmap in `array.buffers[0]` marks as null, and slots whose index points
/// at a null dictionary entry, are appended as null.
template <typename c_type>
Status AppendArraySliceImpl(const typename TypeTraits<T>::ArrayType& dict,
                            const ArrayData& array, int64_t offset, int64_t length) {
  const c_type* indices = array.GetValues<c_type>(1) + offset;

  // Called with the slot's position relative to the start of the slice.
  auto on_valid = [&](const int64_t slot) -> Status {
    const int64_t dict_pos = static_cast<int64_t>(indices[slot]);
    return dict.IsValid(dict_pos) ? Append(dict.GetView(dict_pos)) : AppendNull();
  };
  auto on_null = [&]() -> Status { return AppendNull(); };

  return VisitBitBlocks(array.buffers[0], array.offset + offset, length, on_valid,
                        on_null);
}

/// \brief Append the dictionary value referenced by `index_scalar`,
/// `n_repeats` times.
///
/// Appends `n_repeats` nulls instead when the index scalar is null or the
/// dictionary entry it references is null.
template <typename IndexType>
Status AppendScalarImpl(const typename TypeTraits<T>::ArrayType& dict,
                        const Scalar& index_scalar, int64_t n_repeats) {
  using ScalarType = typename TypeTraits<IndexType>::ScalarType;
  const auto index = internal::checked_cast<const ScalarType&>(index_scalar).value;

  const bool resolves_to_null = !index_scalar.is_valid || !dict.IsValid(index);
  if (resolves_to_null) {
    return AppendNulls(n_repeats);
  }

  // Look the value up once and reuse it for every repetition.
  const auto& value = dict.GetView(index);
  for (int64_t remaining = n_repeats; remaining > 0; --remaining) {
    ARROW_RETURN_NOT_OK(Append(value));
  }
  return Status::OK();
}

Status FinishInternal(std::shared_ptr<ArrayData>* out) override {
std::shared_ptr<ArrayData> dictionary;
ARROW_RETURN_NOT_OK(FinishWithDictOffset(/*offset=*/0, out, &dictionary));
Expand Down
Loading

0 comments on commit 87e2ad5

Please sign in to comment.