From 87e2ad5b2f2cdaf1e469b1cda1a2899a747464b6 Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 21 Sep 2021 13:08:09 +0200 Subject: [PATCH] ARROW-13573: [C++] Support dictionaries natively in case_when This supports dictionaries 'natively', that is, dictionaries are no longer always unpacked. (If mixed dictionary and non-dictionary arguments are given, then they will be unpacked.) For scalar conditions, the output will have the dictionary of whichever input is selected (or no dictionary if the output is null). For array conditions, we unify the dictionaries as we select elements. Closes #11022 from lidavidm/arrow-13573 Authored-by: David Li Signed-off-by: Antoine Pitrou --- .github/workflows/cpp.yml | 2 +- .github/workflows/r.yml | 2 +- ci/scripts/PKGBUILD | 6 +- cpp/src/arrow/array/array_test.cc | 12 +- cpp/src/arrow/array/builder_base.cc | 10 +- cpp/src/arrow/array/builder_base.h | 13 +- cpp/src/arrow/array/builder_dict.cc | 39 ++- cpp/src/arrow/array/builder_dict.h | 110 +++++++ cpp/src/arrow/builder.cc | 269 ++++++++++-------- .../arrow/compute/kernels/scalar_if_else.cc | 37 ++- .../compute/kernels/scalar_if_else_test.cc | 193 +++++++++++++ cpp/src/arrow/compute/kernels/test_util.cc | 92 +++++- cpp/src/arrow/compute/kernels/test_util.h | 15 + cpp/src/arrow/ipc/json_simple.cc | 19 ++ cpp/src/arrow/ipc/json_simple.h | 5 + cpp/src/arrow/ipc/json_simple_test.cc | 24 ++ cpp/src/arrow/scalar.cc | 3 +- cpp/src/arrow/testing/gtest_util.cc | 9 + cpp/src/arrow/testing/gtest_util.h | 5 + 19 files changed, 692 insertions(+), 173 deletions(-) diff --git a/.github/workflows/cpp.yml b/.github/workflows/cpp.yml index 086f45d6fee70..0f19f7351c325 100644 --- a/.github/workflows/cpp.yml +++ b/.github/workflows/cpp.yml @@ -238,7 +238,7 @@ jobs: name: AMD64 Windows MinGW ${{ matrix.mingw-n-bits }} C++ runs-on: windows-latest if: ${{ !contains(github.event.pull_request.title, 'WIP') }} - timeout-minutes: 45 + timeout-minutes: 60 strategy: fail-fast: false matrix: diff --git a/.github/workflows/r.yml b/.github/workflows/r.yml index e160ba8128a3a..3886eafee949a 100644 --- a/.github/workflows/r.yml +++ b/.github/workflows/r.yml @@ -53,7 +53,7 @@ jobs: name: AMD64 Ubuntu ${{ matrix.ubuntu }} R ${{ matrix.r }} runs-on: ubuntu-latest if: ${{ !contains(github.event.pull_request.title, 'WIP') }} - timeout-minutes: 60 + timeout-minutes: 75 strategy: fail-fast: false matrix: diff --git a/ci/scripts/PKGBUILD b/ci/scripts/PKGBUILD index 56d70d83daf16..246b679129a38 100644 --- a/ci/scripts/PKGBUILD +++ b/ci/scripts/PKGBUILD @@ -80,9 +80,13 @@ build() { export LIBS="-L${MINGW_PREFIX}/libs" export ARROW_S3=OFF export ARROW_WITH_RE2=OFF + # Without this, some dataset functionality segfaults + export CMAKE_UNITY_BUILD=ON else export ARROW_S3=ON export ARROW_WITH_RE2=ON + # Without this, some compute functionality segfaults in tests + export CMAKE_UNITY_BUILD=OFF fi MSYS2_ARG_CONV_EXCL="-DCMAKE_INSTALL_PREFIX=" \ @@ -115,7 +119,7 @@ build() { -DARROW_CXXFLAGS="${CPPFLAGS}" \ -DCMAKE_BUILD_TYPE="release" \ -DCMAKE_INSTALL_PREFIX=${MINGW_PREFIX} \ - -DCMAKE_UNITY_BUILD=ON \ + -DCMAKE_UNITY_BUILD=${CMAKE_UNITY_BUILD} \ -DCMAKE_VERBOSE_MAKEFILE=ON make -j3 diff --git a/cpp/src/arrow/array/array_test.cc b/cpp/src/arrow/array/array_test.cc index d9617c4e60321..2e3d40570946c 100644 --- a/cpp/src/arrow/array/array_test.cc +++ b/cpp/src/arrow/array/array_test.cc @@ -456,7 +456,7 @@ TEST_F(TestArray, TestValidateNullCount) { void AssertAppendScalar(MemoryPool* pool, const std::shared_ptr& scalar) { std::unique_ptr builder; auto null_scalar = MakeNullScalar(scalar->type); - ASSERT_OK(MakeBuilder(pool, scalar->type, &builder)); + ASSERT_OK(MakeBuilderExactIndex(pool, scalar->type, &builder)); ASSERT_OK(builder->AppendScalar(*scalar)); ASSERT_OK(builder->AppendScalar(*scalar)); ASSERT_OK(builder->AppendScalar(*null_scalar)); @@ -471,15 +471,18 @@ void AssertAppendScalar(MemoryPool* pool, const std::shared_ptr& scalar) ASSERT_EQ(out->length(), 9); const bool can_check_nulls = internal::HasValidityBitmap(out->type()->id()); + // For a dictionary builder, the output dictionary won't necessarily be the same + const bool can_check_values = !is_dictionary(out->type()->id()); if (can_check_nulls) { ASSERT_EQ(out->null_count(), 4); } + for (const auto index : {0, 1, 3, 5, 6}) { ASSERT_FALSE(out->IsNull(index)); ASSERT_OK_AND_ASSIGN(auto scalar_i, out->GetScalar(index)); ASSERT_OK(scalar_i->ValidateFull()); - AssertScalarsEqual(*scalar, *scalar_i, /*verbose=*/true); + if (can_check_values) AssertScalarsEqual(*scalar, *scalar_i, /*verbose=*/true); } for (const auto index : {2, 4, 7, 8}) { ASSERT_EQ(out->IsNull(index), can_check_nulls); @@ -575,8 +578,6 @@ TEST_F(TestArray, TestMakeArrayFromScalar) { } for (auto scalar : scalars) { - // TODO(ARROW-13197): appending dictionary scalars not implemented - if (is_dictionary(scalar->type->id())) continue; AssertAppendScalar(pool_, scalar); } } @@ -634,9 +635,6 @@ TEST_F(TestArray, TestMakeArrayFromMapScalar) { TEST_F(TestArray, TestAppendArraySlice) { auto scalars = GetScalars(); for (const auto& scalar : scalars) { - // TODO(ARROW-13573): appending dictionary arrays not implemented - if (is_dictionary(scalar->type->id())) continue; - ARROW_SCOPED_TRACE(*scalar->type); ASSERT_OK_AND_ASSIGN(auto array, MakeArrayFromScalar(*scalar, 16)); ASSERT_OK_AND_ASSIGN(auto nulls, MakeArrayOfNull(scalar->type, 16)); diff --git a/cpp/src/arrow/array/builder_base.cc b/cpp/src/arrow/array/builder_base.cc index 2f4e63b546db4..117b9d3763204 100644 --- a/cpp/src/arrow/array/builder_base.cc +++ b/cpp/src/arrow/array/builder_base.cc @@ -22,6 +22,7 @@ #include #include "arrow/array/array_base.h" +#include "arrow/array/builder_dict.h" #include "arrow/array/data.h" #include "arrow/array/util.h" #include "arrow/buffer.h" @@ -268,15 +269,6 @@ struct AppendScalarImpl { } // namespace -Status ArrayBuilder::AppendScalar(const Scalar& scalar) { - if (!scalar.type->Equals(type())) { - return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(), - " to builder for type ", type()->ToString()); - } - std::shared_ptr shared{const_cast(&scalar), [](Scalar*) {}}; - return AppendScalarImpl{&shared, &shared + 1, /*n_repeats=*/1, this}.Convert(); -} - Status ArrayBuilder::AppendScalar(const Scalar& scalar, int64_t n_repeats) { if (!scalar.type->Equals(type())) { return Status::Invalid("Cannot append scalar of type ", scalar.type->ToString(), diff --git a/cpp/src/arrow/array/builder_base.h b/cpp/src/arrow/array/builder_base.h index 67203e790714b..87e39c3fe9fc7 100644 --- a/cpp/src/arrow/array/builder_base.h +++ b/cpp/src/arrow/array/builder_base.h @@ -119,9 +119,9 @@ class ARROW_EXPORT ArrayBuilder { virtual Status AppendEmptyValues(int64_t length) = 0; /// \brief Append a value from a scalar - Status AppendScalar(const Scalar& scalar); - Status AppendScalar(const Scalar& scalar, int64_t n_repeats); - Status AppendScalars(const ScalarVector& scalars); + Status AppendScalar(const Scalar& scalar) { return AppendScalar(scalar, 1); } + virtual Status AppendScalar(const Scalar& scalar, int64_t n_repeats); + virtual Status AppendScalars(const ScalarVector& scalars); /// \brief Append a range of values from an array. /// @@ -282,6 +282,13 @@ ARROW_EXPORT Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, std::unique_ptr* out); +/// \brief Construct an empty ArrayBuilder corresponding to the data +/// type, where any top-level or nested dictionary builders return the +/// exact index type specified by the type. +ARROW_EXPORT +Status MakeBuilderExactIndex(MemoryPool* pool, const std::shared_ptr& type, + std::unique_ptr* out); + /// \brief Construct an empty DictionaryBuilder initialized optionally /// with a pre-existing dictionary /// \param[in] pool the MemoryPool to use for allocations diff --git a/cpp/src/arrow/array/builder_dict.cc b/cpp/src/arrow/array/builder_dict.cc index b13f6a2db34ed..d247316999dc6 100644 --- a/cpp/src/arrow/array/builder_dict.cc +++ b/cpp/src/arrow/array/builder_dict.cc @@ -159,23 +159,32 @@ DictionaryMemoTable::DictionaryMemoTable(MemoryPool* pool, DictionaryMemoTable::~DictionaryMemoTable() = default; -#define GET_OR_INSERT(C_TYPE) \ - Status DictionaryMemoTable::GetOrInsert( \ - const typename CTypeTraits::ArrowType*, C_TYPE value, int32_t* out) { \ - return impl_->GetOrInsert::ArrowType>(value, out); \ +#define GET_OR_INSERT(ARROW_TYPE) \ + Status DictionaryMemoTable::GetOrInsert( \ + const ARROW_TYPE*, typename ARROW_TYPE::c_type value, int32_t* out) { \ + return impl_->GetOrInsert(value, out); \ } -GET_OR_INSERT(bool) -GET_OR_INSERT(int8_t) -GET_OR_INSERT(int16_t) -GET_OR_INSERT(int32_t) -GET_OR_INSERT(int64_t) -GET_OR_INSERT(uint8_t) -GET_OR_INSERT(uint16_t) -GET_OR_INSERT(uint32_t) -GET_OR_INSERT(uint64_t) -GET_OR_INSERT(float) -GET_OR_INSERT(double) +GET_OR_INSERT(BooleanType) +GET_OR_INSERT(Int8Type) +GET_OR_INSERT(Int16Type) +GET_OR_INSERT(Int32Type) +GET_OR_INSERT(Int64Type) +GET_OR_INSERT(UInt8Type) +GET_OR_INSERT(UInt16Type) +GET_OR_INSERT(UInt32Type) +GET_OR_INSERT(UInt64Type) +GET_OR_INSERT(FloatType) +GET_OR_INSERT(DoubleType) +GET_OR_INSERT(DurationType); +GET_OR_INSERT(TimestampType); +GET_OR_INSERT(Date32Type); +GET_OR_INSERT(Date64Type); +GET_OR_INSERT(Time32Type); +GET_OR_INSERT(Time64Type); +GET_OR_INSERT(MonthDayNanoIntervalType); +GET_OR_INSERT(DayTimeIntervalType); +GET_OR_INSERT(MonthIntervalType); #undef GET_OR_INSERT diff --git a/cpp/src/arrow/array/builder_dict.h b/cpp/src/arrow/array/builder_dict.h index 455cb3df7b17c..0637c9722a8ec 100644 --- a/cpp/src/arrow/array/builder_dict.h +++ b/cpp/src/arrow/array/builder_dict.h @@ -37,6 +37,7 @@ #include "arrow/util/decimal.h" #include "arrow/util/macros.h" #include "arrow/util/visibility.h" +#include "arrow/visitor_inline.h" namespace arrow { @@ -97,6 +98,17 @@ class ARROW_EXPORT DictionaryMemoTable { Status GetOrInsert(const UInt16Type*, uint16_t value, int32_t* out); Status GetOrInsert(const UInt32Type*, uint32_t value, int32_t* out); Status GetOrInsert(const UInt64Type*, uint64_t value, int32_t* out); + Status GetOrInsert(const DurationType*, int64_t value, int32_t* out); + Status GetOrInsert(const TimestampType*, int64_t value, int32_t* out); + Status GetOrInsert(const Date32Type*, int32_t value, int32_t* out); + Status GetOrInsert(const Date64Type*, int64_t value, int32_t* out); + Status GetOrInsert(const Time32Type*, int32_t value, int32_t* out); + Status GetOrInsert(const Time64Type*, int64_t value, int32_t* out); + Status GetOrInsert(const MonthDayNanoIntervalType*, + MonthDayNanoIntervalType::MonthDayNanos value, int32_t* out); + Status GetOrInsert(const DayTimeIntervalType*, + DayTimeIntervalType::DayMilliseconds value, int32_t* out); + Status GetOrInsert(const MonthIntervalType*, int32_t value, int32_t* out); Status GetOrInsert(const FloatType*, float value, int32_t* out); Status GetOrInsert(const DoubleType*, double value, int32_t* out); @@ -282,6 +294,73 @@ class DictionaryBuilderBase : public ArrayBuilder { return indices_builder_.AppendEmptyValues(length); } + Status AppendScalar(const Scalar& scalar, int64_t n_repeats) override { + if (!scalar.is_valid) return AppendNulls(n_repeats); + + const auto& dict_ty = internal::checked_cast(*scalar.type); + const DictionaryScalar& dict_scalar = + internal::checked_cast(scalar); + const auto& dict = internal::checked_cast::ArrayType&>( + *dict_scalar.value.dictionary); + ARROW_RETURN_NOT_OK(Reserve(n_repeats)); + switch (dict_ty.index_type()->id()) { + case Type::UINT8: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::INT8: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::UINT16: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::INT16: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::UINT32: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::INT32: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::UINT64: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + case Type::INT64: + return AppendScalarImpl(dict, *dict_scalar.value.index, n_repeats); + default: + return Status::TypeError("Invalid index type: ", dict_ty); + } + return Status::OK(); + } + + Status AppendScalars(const ScalarVector& scalars) override { + for (const auto& scalar : scalars) { + ARROW_RETURN_NOT_OK(AppendScalar(*scalar, /*n_repeats=*/1)); + } + return Status::OK(); + } + + Status AppendArraySlice(const ArrayData& array, int64_t offset, int64_t length) final { + // Visit the indices and insert the unpacked values. + const auto& dict_ty = internal::checked_cast(*array.type); + const typename TypeTraits::ArrayType dict(array.dictionary); + ARROW_RETURN_NOT_OK(Reserve(length)); + switch (dict_ty.index_type()->id()) { + case Type::UINT8: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::INT8: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::UINT16: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::INT16: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::UINT32: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::INT32: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::UINT64: + return AppendArraySliceImpl(dict, array, offset, length); + case Type::INT64: + return AppendArraySliceImpl(dict, array, offset, length); + default: + return Status::TypeError("Invalid index type: ", dict_ty); + } + return Status::OK(); + } + /// \brief Insert values into the dictionary's memo, but do not append any /// indices. Can be used to initialize a new builder with known dictionary /// values @@ -376,6 +455,37 @@ class DictionaryBuilderBase : public ArrayBuilder { } protected: + template + Status AppendArraySliceImpl(const typename TypeTraits::ArrayType& dict, + const ArrayData& array, int64_t offset, int64_t length) { + const c_type* values = array.GetValues(1) + offset; + return VisitBitBlocks( + array.buffers[0], array.offset + offset, length, + [&](const int64_t position) { + const int64_t index = static_cast(values[position]); + if (dict.IsValid(index)) { + return Append(dict.GetView(index)); + } + return AppendNull(); + }, + [&]() { return AppendNull(); }); + } + + template + Status AppendScalarImpl(const typename TypeTraits::ArrayType& dict, + const Scalar& index_scalar, int64_t n_repeats) { + using ScalarType = typename TypeTraits::ScalarType; + const auto index = internal::checked_cast(index_scalar).value; + if (index_scalar.is_valid && dict.IsValid(index)) { + const auto& value = dict.GetView(index); + for (int64_t i = 0; i < n_repeats; i++) { + ARROW_RETURN_NOT_OK(Append(value)); + } + return Status::OK(); + } + return AppendNulls(n_repeats); + } + Status FinishInternal(std::shared_ptr* out) override { std::shared_ptr dictionary; ARROW_RETURN_NOT_OK(FinishWithDictOffset(/*offset=*/0, out, &dictionary)); diff --git a/cpp/src/arrow/builder.cc b/cpp/src/arrow/builder.cc index 37cc9e07ad4b6..115a97e93898f 100644 --- a/cpp/src/arrow/builder.cc +++ b/cpp/src/arrow/builder.cc @@ -41,14 +41,10 @@ struct DictionaryBuilderCase { } Status Visit(const NullType&) { return CreateFor(); } - Status Visit(const BinaryType&) { return Create(); } - Status Visit(const StringType&) { return Create(); } - Status Visit(const LargeBinaryType&) { - return Create>(); - } - Status Visit(const LargeStringType&) { - return Create>(); - } + Status Visit(const BinaryType&) { return CreateFor(); } + Status Visit(const StringType&) { return CreateFor(); } + Status Visit(const LargeBinaryType&) { return CreateFor(); } + Status Visit(const LargeStringType&) { return CreateFor(); } Status Visit(const FixedSizeBinaryType&) { return CreateFor(); } Status Visit(const Decimal128Type&) { return CreateFor(); } Status Visit(const Decimal256Type&) { return CreateFor(); } @@ -63,19 +59,50 @@ struct DictionaryBuilderCase { template Status CreateFor() { - return Create>(); - } - - template - Status Create() { - BuilderType* builder; + using AdaptiveBuilderType = DictionaryBuilder; if (dictionary != nullptr) { - builder = new BuilderType(dictionary, pool); + out->reset(new AdaptiveBuilderType(dictionary, pool)); + } else if (exact_index_type) { + switch (index_type->id()) { + case Type::UINT8: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::INT8: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::UINT16: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::INT16: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::UINT32: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::INT32: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::UINT64: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + case Type::INT64: + out->reset(new internal::DictionaryBuilderBase( + value_type, pool)); + break; + default: + return Status::TypeError("MakeBuilder: invalid index type ", *index_type); + } } else { auto start_int_size = internal::GetByteWidth(*index_type); - builder = new BuilderType(start_int_size, value_type, pool); + out->reset(new AdaptiveBuilderType(start_int_size, value_type, pool)); } - out->reset(builder); return Status::OK(); } @@ -85,138 +112,130 @@ struct DictionaryBuilderCase { const std::shared_ptr& index_type; const std::shared_ptr& value_type; const std::shared_ptr& dictionary; + bool exact_index_type; std::unique_ptr* out; }; -#define BUILDER_CASE(TYPE_CLASS) \ - case TYPE_CLASS##Type::type_id: \ - out->reset(new TYPE_CLASS##Builder(type, pool)); \ +struct MakeBuilderImpl { + template + enable_if_not_nested Visit(const T&) { + out.reset(new typename TypeTraits::BuilderType(type, pool)); return Status::OK(); + } -Result>> FieldBuilders(const DataType& type, - MemoryPool* pool) { - std::vector> field_builders; + Status Visit(const DictionaryType& dict_type) { + DictionaryBuilderCase visitor = {pool, + dict_type.index_type(), + dict_type.value_type(), + /*dictionary=*/nullptr, + exact_index_type, + &out}; + return visitor.Make(); + } - for (const auto& field : type.fields()) { - std::unique_ptr builder; - RETURN_NOT_OK(MakeBuilder(pool, field->type(), &builder)); - field_builders.emplace_back(std::move(builder)); + Status Visit(const ListType& list_type) { + std::shared_ptr value_type = list_type.value_type(); + ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type)); + out.reset(new ListBuilder(pool, std::move(value_builder), type)); + return Status::OK(); } - return field_builders; -} + Status Visit(const LargeListType& list_type) { + std::shared_ptr value_type = list_type.value_type(); + ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type)); + out.reset(new LargeListBuilder(pool, std::move(value_builder), type)); + return Status::OK(); + } -Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, - std::unique_ptr* out) { - switch (type->id()) { - case Type::NA: { - out->reset(new NullBuilder(pool)); - return Status::OK(); - } - BUILDER_CASE(UInt8); - BUILDER_CASE(Int8); - BUILDER_CASE(UInt16); - BUILDER_CASE(Int16); - BUILDER_CASE(UInt32); - BUILDER_CASE(Int32); - BUILDER_CASE(UInt64); - BUILDER_CASE(Int64); - BUILDER_CASE(Date32); - BUILDER_CASE(Date64); - BUILDER_CASE(Duration); - BUILDER_CASE(Time32); - BUILDER_CASE(Time64); - BUILDER_CASE(Timestamp); - BUILDER_CASE(MonthInterval); - BUILDER_CASE(DayTimeInterval); - BUILDER_CASE(MonthDayNanoInterval); - BUILDER_CASE(Boolean); - BUILDER_CASE(HalfFloat); - BUILDER_CASE(Float); - BUILDER_CASE(Double); - BUILDER_CASE(String); - BUILDER_CASE(Binary); - BUILDER_CASE(LargeString); - BUILDER_CASE(LargeBinary); - BUILDER_CASE(FixedSizeBinary); - BUILDER_CASE(Decimal128); - BUILDER_CASE(Decimal256); - - case Type::DICTIONARY: { - const auto& dict_type = static_cast(*type); - DictionaryBuilderCase visitor = {pool, dict_type.index_type(), - dict_type.value_type(), nullptr, out}; - return visitor.Make(); - } + Status Visit(const MapType& map_type) { + ARROW_ASSIGN_OR_RAISE(auto key_builder, ChildBuilder(map_type.key_type())); + ARROW_ASSIGN_OR_RAISE(auto item_builder, ChildBuilder(map_type.item_type())); + out.reset( + new MapBuilder(pool, std::move(key_builder), std::move(item_builder), type)); + return Status::OK(); + } - case Type::LIST: { - std::unique_ptr value_builder; - std::shared_ptr value_type = - internal::checked_cast(*type).value_type(); - RETURN_NOT_OK(MakeBuilder(pool, value_type, &value_builder)); - out->reset(new ListBuilder(pool, std::move(value_builder), type)); - return Status::OK(); - } + Status Visit(const FixedSizeListType& list_type) { + auto value_type = list_type.value_type(); + ARROW_ASSIGN_OR_RAISE(auto value_builder, ChildBuilder(value_type)); + out.reset(new FixedSizeListBuilder(pool, std::move(value_builder), type)); + return Status::OK(); + } - case Type::LARGE_LIST: { - std::unique_ptr value_builder; - std::shared_ptr value_type = - internal::checked_cast(*type).value_type(); - RETURN_NOT_OK(MakeBuilder(pool, value_type, &value_builder)); - out->reset(new LargeListBuilder(pool, std::move(value_builder), type)); - return Status::OK(); - } + Status Visit(const StructType& struct_type) { + ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool)); + out.reset(new StructBuilder(type, pool, std::move(field_builders))); + return Status::OK(); + } - case Type::MAP: { - const auto& map_type = internal::checked_cast(*type); - std::unique_ptr key_builder, item_builder; - RETURN_NOT_OK(MakeBuilder(pool, map_type.key_type(), &key_builder)); - RETURN_NOT_OK(MakeBuilder(pool, map_type.item_type(), &item_builder)); - out->reset( - new MapBuilder(pool, std::move(key_builder), std::move(item_builder), type)); - return Status::OK(); - } + Status Visit(const SparseUnionType&) { + ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool)); + out.reset(new SparseUnionBuilder(pool, std::move(field_builders), type)); + return Status::OK(); + } - case Type::FIXED_SIZE_LIST: { - const auto& list_type = internal::checked_cast(*type); - std::unique_ptr value_builder; - auto value_type = list_type.value_type(); - RETURN_NOT_OK(MakeBuilder(pool, value_type, &value_builder)); - out->reset(new FixedSizeListBuilder(pool, std::move(value_builder), type)); - return Status::OK(); - } + Status Visit(const DenseUnionType&) { + ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool)); + out.reset(new DenseUnionBuilder(pool, std::move(field_builders), type)); + return Status::OK(); + } - case Type::STRUCT: { - ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool)); - out->reset(new StructBuilder(type, pool, std::move(field_builders))); - return Status::OK(); - } + Status Visit(const ExtensionType&) { return NotImplemented(); } + Status Visit(const DataType&) { return NotImplemented(); } - case Type::SPARSE_UNION: { - ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool)); - out->reset(new SparseUnionBuilder(pool, std::move(field_builders), type)); - return Status::OK(); - } + Status NotImplemented() { + return Status::NotImplemented("MakeBuilder: cannot construct builder for type ", + type->ToString()); + } - case Type::DENSE_UNION: { - ARROW_ASSIGN_OR_RAISE(auto field_builders, FieldBuilders(*type, pool)); - out->reset(new DenseUnionBuilder(pool, std::move(field_builders), type)); - return Status::OK(); - } + Result> ChildBuilder( + const std::shared_ptr& type) { + MakeBuilderImpl impl{pool, type, exact_index_type, /*out=*/nullptr}; + RETURN_NOT_OK(VisitTypeInline(*type, &impl)); + return std::move(impl.out); + } - default: - break; + Result>> FieldBuilders(const DataType& type, + MemoryPool* pool) { + std::vector> field_builders; + for (const auto& field : type.fields()) { + std::unique_ptr builder; + MakeBuilderImpl impl{pool, field->type(), exact_index_type, /*out=*/nullptr}; + RETURN_NOT_OK(VisitTypeInline(*field->type(), &impl)); + field_builders.emplace_back(std::move(impl.out)); + } + return field_builders; } - return Status::NotImplemented("MakeBuilder: cannot construct builder for type ", - type->ToString()); + + MemoryPool* pool; + const std::shared_ptr& type; + bool exact_index_type; + std::unique_ptr out; +}; + +Status MakeBuilder(MemoryPool* pool, const std::shared_ptr& type, + std::unique_ptr* out) { + MakeBuilderImpl impl{pool, type, /*exact_index_type=*/false, /*out=*/nullptr}; + RETURN_NOT_OK(VisitTypeInline(*type, &impl)); + *out = std::move(impl.out); + return Status::OK(); +} + +Status MakeBuilderExactIndex(MemoryPool* pool, const std::shared_ptr& type, + std::unique_ptr* out) { + MakeBuilderImpl impl{pool, type, /*exact_index_type=*/true, /*out=*/nullptr}; + RETURN_NOT_OK(VisitTypeInline(*type, &impl)); + *out = std::move(impl.out); + return Status::OK(); } Status MakeDictionaryBuilder(MemoryPool* pool, const std::shared_ptr& type, const std::shared_ptr& dictionary, std::unique_ptr* out) { const auto& dict_type = static_cast(*type); - DictionaryBuilderCase visitor = {pool, dict_type.index_type(), dict_type.value_type(), - dictionary, out}; + DictionaryBuilderCase visitor = { + pool, dict_type.index_type(), dict_type.value_type(), + dictionary, /*exact_index_type=*/false, out}; return visitor.Make(); } diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 4de04da7a813d..35bb6248f2342 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -1222,7 +1222,6 @@ struct CaseWhenFunction : ScalarFunction { // The first function is a struct of booleans, where the number of fields in the // struct is either equal to the number of other arguments or is one less. RETURN_NOT_OK(CheckArity(*values)); - EnsureDictionaryDecoded(values); auto first_type = (*values)[0].type; if (first_type->id() != Type::STRUCT) { return Status::TypeError("case_when: first argument must be STRUCT, not ", @@ -1243,6 +1242,9 @@ struct CaseWhenFunction : ScalarFunction { } } + if (auto kernel = DispatchExactImpl(this, *values)) return kernel; + + EnsureDictionaryDecoded(values); if (auto type = CommonNumeric(values->data() + 1, values->size() - 1)) { for (auto it = values->begin() + 1; it != values->end(); it++) { it->type = type; @@ -1279,6 +1281,15 @@ Status ExecScalarCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out return Status::OK(); } ArrayData* output = out->mutable_array(); + if (is_dictionary_type::value) { + const Datum& dict_from = result.is_value() ? result : batch[1]; + if (dict_from.is_scalar()) { + output->dictionary = checked_cast(*dict_from.scalar()) + .value.dictionary->data(); + } else { + output->dictionary = dict_from.array()->dictionary; + } + } if (!result.is_value()) { // All conditions false, no 'else' argument result = MakeNullScalar(out->type()); @@ -1304,6 +1315,7 @@ Status ExecArrayCaseWhen(KernelContext* ctx, const ExecBatch& batch, Datum* out) static_cast(conds_array.type->num_fields()) < num_value_args; uint8_t* out_valid = output->buffers[0]->mutable_data(); uint8_t* out_values = output->buffers[1]->mutable_data(); + if (have_else_arg) { // Copy 'else' value into output CopyValues(batch.values.back(), /*in_offset=*/0, batch.length, out_valid, @@ -1472,7 +1484,7 @@ static Status ExecVarWidthArrayCaseWhenImpl( const bool have_else_arg = static_cast(conds_array.type->num_fields()) < (batch.values.size() - 1); std::unique_ptr raw_builder; - RETURN_NOT_OK(MakeBuilder(ctx->memory_pool(), out->type(), &raw_builder)); + RETURN_NOT_OK(MakeBuilderExactIndex(ctx->memory_pool(), out->type(), &raw_builder)); RETURN_NOT_OK(raw_builder->Reserve(batch.length)); RETURN_NOT_OK(reserve_data(raw_builder.get())); @@ -1701,6 +1713,24 @@ struct CaseWhenFunctor> { } }; +template <> +struct CaseWhenFunctor { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch[0].null_count() > 0) { + return Status::Invalid("cond struct must not have outer nulls"); + } + if (batch[0].is_scalar()) { + return ExecVarWidthScalarCaseWhen(ctx, batch, out); + } + return ExecArray(ctx, batch, out); + } + + static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + std::function reserve_data = ReserveNoData; + return ExecVarWidthArrayCaseWhen(ctx, batch, out, std::move(reserve_data)); + } +}; + struct CoalesceFunction : ScalarFunction { using ScalarFunction::ScalarFunction; @@ -2446,7 +2476,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { } { auto func = std::make_shared( - "case_when", Arity::VarArgs(/*min_args=*/1), &case_when_doc); + "case_when", Arity::VarArgs(/*min_args=*/2), &case_when_doc); AddPrimitiveCaseWhenKernels(func, NumericTypes()); AddPrimitiveCaseWhenKernels(func, TemporalTypes()); AddPrimitiveCaseWhenKernels(func, IntervalTypes()); @@ -2464,6 +2494,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { AddCaseWhenKernel(func, Type::STRUCT, CaseWhenFunctor::Exec); AddCaseWhenKernel(func, Type::DENSE_UNION, CaseWhenFunctor::Exec); AddCaseWhenKernel(func, Type::SPARSE_UNION, CaseWhenFunctor::Exec); + AddCaseWhenKernel(func, Type::DICTIONARY, CaseWhenFunctor::Exec); DCHECK_OK(registry->AddFunction(std::move(func))); } { diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index b3b0f26cead12..8793cac761956 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -624,6 +624,187 @@ TYPED_TEST(TestCaseWhenNumeric, ListOfType) { ArrayFromJSON(type, R"([null, null, null, [6, null]])")); } +template +class TestCaseWhenDict : public ::testing::Test {}; + +struct JsonDict { + std::shared_ptr type; + std::string value; +}; + +TYPED_TEST_SUITE(TestCaseWhenDict, IntegralArrowTypes); + +TYPED_TEST(TestCaseWhenDict, Simple) { + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + for (const auto& dict : + {JsonDict{utf8(), R"(["a", null, "bc", "def"])"}, + JsonDict{int64(), "[1, null, 2, 3]"}, + JsonDict{decimal256(3, 2), R"(["1.23", null, "3.45", "6.78"])"}}) { + auto type = dictionary(default_type_instance(), dict.type); + auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict.value); + auto values1 = DictArrayFromJSON(type, "[0, null, 3, 1]", dict.value); + auto values2 = DictArrayFromJSON(type, "[2, 1, null, 0]", dict.value); + + // Easy case: all arguments have the same dictionary + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2}); + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}); + CheckDictionary("case_when", + {MakeStruct({cond1, cond2}), values_null, values2, values1}); + } +} + +TYPED_TEST(TestCaseWhenDict, Mixed) { + auto type = dictionary(default_type_instance(), utf8()); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto dict = R"(["a", null, "bc", "def"])"; + auto values_null = DictArrayFromJSON(type, "[null, null, null, null]", dict); + auto values1_dict = DictArrayFromJSON(type, "[0, null, 3, 1]", dict); + auto values1_decoded = ArrayFromJSON(utf8(), R"(["a", null, "def", null])"); + auto values2_dict = DictArrayFromJSON(type, "[2, 1, null, 0]", dict); + auto values2_decoded = ArrayFromJSON(utf8(), R"(["bc", null, null, "a"])"); + + // If we have mixed dictionary/non-dictionary arguments, we decode dictionaries + CheckDictionary("case_when", + {MakeStruct({cond1, cond2}), values1_dict, values2_decoded}, + /*result_is_encoded=*/false); + CheckDictionary("case_when", + {MakeStruct({cond1, cond2}), values1_decoded, values2_dict}, + /*result_is_encoded=*/false); + CheckDictionary( + "case_when", + {MakeStruct({cond1, cond2}), values1_dict, values2_dict, values1_decoded}, + /*result_is_encoded=*/false); + CheckDictionary( + "case_when", + {MakeStruct({cond1, cond2}), values_null, values2_dict, values1_decoded}, + /*result_is_encoded=*/false); +} + +TYPED_TEST(TestCaseWhenDict, NestedSimple) { + auto make_list = [](const std::shared_ptr& indices, + const std::shared_ptr& backing_array) { + EXPECT_OK_AND_ASSIGN(auto result, ListArray::FromArrays(*indices, *backing_array)); + return result; + }; + auto index_type = default_type_instance(); + auto inner_type = dictionary(index_type, utf8()); + auto type = list(inner_type); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, true, null]"); + auto dict = R"(["a", null, "bc", "def"])"; + auto values_null = make_list(ArrayFromJSON(int32(), "[null, null, null, null, 0]"), + DictArrayFromJSON(inner_type, "[]", dict)); + auto values1_backing = DictArrayFromJSON(inner_type, "[0, null, 3, 1]", dict); + auto values2_backing = DictArrayFromJSON(inner_type, "[2, 1, null, 0]", dict); + auto values1 = make_list(ArrayFromJSON(int32(), "[0, 2, 2, 3, 4]"), values1_backing); + auto values2 = make_list(ArrayFromJSON(int32(), "[0, 1, 2, 2, 4]"), values2_backing); + + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2}, + /*result_is_encoded=*/false); + CheckDictionary( + "case_when", + {MakeStruct({cond1, cond2}), values1, + make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing)}, + /*result_is_encoded=*/false); + CheckDictionary( + "case_when", + {MakeStruct({cond1, cond2}), values1, + make_list(ArrayFromJSON(int32(), "[0, 1, null, 2, 4]"), values2_backing), values1}, + /*result_is_encoded=*/false); + + CheckDictionary("case_when", + { + Datum(MakeStruct({cond1, cond2})), + Datum(std::make_shared( + DictArrayFromJSON(inner_type, "[0, 1]", dict))), + Datum(std::make_shared( + DictArrayFromJSON(inner_type, "[2, 3]", dict))), + }, + /*result_is_encoded=*/false); + + CheckDictionary("case_when", + {MakeStruct({Datum(true), Datum(false)}), values1, values2}, + /*result_is_encoded=*/false); + CheckDictionary("case_when", + {MakeStruct({Datum(false), Datum(true)}), values1, values2}, + /*result_is_encoded=*/false); + CheckDictionary("case_when", {MakeStruct({Datum(false)}), values1, values2}, + /*result_is_encoded=*/false); + CheckDictionary("case_when", + {MakeStruct({Datum(false), Datum(false)}), values1, values2}, + /*result_is_encoded=*/false); +} + +TYPED_TEST(TestCaseWhenDict, DifferentDictionaries) { + auto type = dictionary(default_type_instance(), utf8()); + auto cond1 = ArrayFromJSON(boolean(), "[true, true, null, null]"); + auto cond2 = ArrayFromJSON(boolean(), "[true, false, null, true]"); + auto dict1 = R"(["a", null, "bc", "def"])"; + auto dict2 = R"(["bc", "foo", null, "a"])"; + auto dict3 = R"(["def", null, "a", "bc"])"; + auto values1_null = DictArrayFromJSON(type, "[null, null, null, null]", dict1); + auto values2_null = DictArrayFromJSON(type, "[null, null, null, null]", dict2); + auto values1 = DictArrayFromJSON(type, "[null, 0, 3, 1]", dict1); + auto values2 = DictArrayFromJSON(type, "[2, 1, 0, null]", dict2); + auto values3 = DictArrayFromJSON(type, "[0, 1, 2, 3]", dict3); + + CheckDictionary("case_when", + {MakeStruct({Datum(true), Datum(false)}), values1, values2}); + CheckDictionary("case_when", + {MakeStruct({Datum(false), Datum(true)}), values1, values2}); + CheckDictionary("case_when", + {MakeStruct({Datum(false), Datum(false)}), values1, values2}); + CheckDictionary("case_when", + {MakeStruct({Datum(false), Datum(false)}), values2, values1}); + + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2}); + CheckDictionary("case_when", {MakeStruct({cond1, cond2}), values1, values2, values1}); + + CheckDictionary("case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}), + values1, values2}); + CheckDictionary("case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[true, false, false, true]")}), + values1, values2}); + CheckDictionary("case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"), + ArrayFromJSON(boolean(), "[true, false, true, false]")}), + values1, values2}); + CheckDictionary("case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[false, false, false, false]"), + ArrayFromJSON(boolean(), "[true, true, true, true]")}), + values1, values3}); + CheckDictionary("case_when", + {MakeStruct({ArrayFromJSON(boolean(), "[null, null, null, true]"), + ArrayFromJSON(boolean(), "[true, true, true, true]")}), + values1, values3}); + CheckDictionary( + "case_when", + { + MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]")}), + DictScalarFromJSON(type, "0", dict1), + DictScalarFromJSON(type, "0", dict2), + }); + CheckDictionary( + "case_when", + { + MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"), + ArrayFromJSON(boolean(), "[false, false, true, true]")}), + DictScalarFromJSON(type, "0", dict1), + DictScalarFromJSON(type, "0", dict2), + }); + CheckDictionary( + "case_when", + { + MakeStruct({ArrayFromJSON(boolean(), "[true, true, false, false]"), + ArrayFromJSON(boolean(), "[false, false, true, true]")}), + DictScalarFromJSON(type, "null", dict1), + DictScalarFromJSON(type, "0", dict2), + }); +} + TEST(TestCaseWhen, Null) { auto cond_true = ScalarFromJSON(boolean(), "true"); auto cond_false = ScalarFromJSON(boolean(), "false"); @@ -1489,6 +1670,18 @@ TEST(TestCaseWhen, DispatchBest) { CallFunction("case_when", {MakeStruct({ArrayFromJSON(boolean(), "[]")}), ArrayFromJSON(int64(), "[]"), ArrayFromJSON(utf8(), "[]")})); + + // Do not dictionary-decode when we have only dictionary values + CheckDispatchBest("case_when", + {struct_({field("", boolean())}), dictionary(int64(), utf8()), + dictionary(int64(), utf8())}, + {struct_({field("", boolean())}), dictionary(int64(), utf8()), + dictionary(int64(), utf8())}); + + // Dictionary-decode if we have a mix + CheckDispatchBest( + "case_when", {struct_({field("", boolean())}), dictionary(int64(), utf8()), utf8()}, + {struct_({field("", boolean())}), utf8(), utf8()}); } template diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index 4a9215101b11d..cedc03698a10f 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -24,6 +24,7 @@ #include "arrow/array.h" #include "arrow/array/validate.h" #include "arrow/chunked_array.h" +#include "arrow/compute/cast.h" #include "arrow/compute/exec.h" #include "arrow/compute/function.h" #include "arrow/compute/registry.h" @@ -46,13 +47,6 @@ DatumVector GetDatums(const std::vector& inputs) { return datums; } -void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs, - const Datum& expected, const FunctionOptions* options) { - ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, inputs, options)); - ValidateOutput(out); - AssertDatumsEqual(expected, out, /*verbose=*/true); -} - template DatumVector SliceArrays(const DatumVector& inputs, SliceArgs... slice_args) { DatumVector sliced; @@ -80,6 +74,13 @@ ScalarVector GetScalars(const DatumVector& inputs, int64_t index) { } // namespace +void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs, + const Datum& expected, const FunctionOptions* options) { + ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, inputs, options)); + ValidateOutput(out); + AssertDatumsEqual(expected, out, /*verbose=*/true); +} + void CheckScalar(std::string func_name, const ScalarVector& inputs, std::shared_ptr expected, const FunctionOptions* options) { ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, GetDatums(inputs), options)); @@ -170,6 +171,83 @@ void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expecte } } +Datum CheckDictionaryNonRecursive(const std::string& func_name, const DatumVector& args, + bool result_is_encoded) { + EXPECT_OK_AND_ASSIGN(Datum actual, CallFunction(func_name, args)); + ValidateOutput(actual); + + DatumVector decoded_args; + decoded_args.reserve(args.size()); + for (const auto& arg : args) { + if (arg.type()->id() == Type::DICTIONARY) { + const auto& to_type = checked_cast(*arg.type()).value_type(); + EXPECT_OK_AND_ASSIGN(auto decoded, Cast(arg, to_type)); + decoded_args.push_back(decoded); + } else { + decoded_args.push_back(arg); + } + } + EXPECT_OK_AND_ASSIGN(Datum expected, CallFunction(func_name, decoded_args)); + + if (result_is_encoded) { + EXPECT_EQ(Type::DICTIONARY, actual.type()->id()) + << "Result should have been dictionary-encoded"; + // Decode before comparison - we care about equivalent not identical results + const auto& to_type = + checked_cast(*actual.type()).value_type(); + EXPECT_OK_AND_ASSIGN(auto decoded, Cast(actual, to_type)); + AssertDatumsApproxEqual(expected, decoded, /*verbose=*/true); + } else { + AssertDatumsApproxEqual(expected, actual, /*verbose=*/true); + } + return actual; +} + +void CheckDictionary(const std::string& func_name, const DatumVector& args, + bool result_is_encoded) { + auto actual = CheckDictionaryNonRecursive(func_name, args, result_is_encoded); + + if (actual.is_scalar()) return; + ASSERT_TRUE(actual.is_array()); + ASSERT_GE(actual.length(), 0); + + // Check all scalars + for (int64_t i = 0; i < actual.length(); i++) { + CheckDictionaryNonRecursive(func_name, GetDatums(GetScalars(args, i)), + result_is_encoded); + } + + // Check slices of the input + const auto slice_length = actual.length() / 3; + if (slice_length > 0) { + CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, slice_length), + result_is_encoded); + CheckDictionaryNonRecursive(func_name, SliceArrays(args, slice_length, slice_length), + result_is_encoded); + CheckDictionaryNonRecursive(func_name, SliceArrays(args, 2 * slice_length), + result_is_encoded); + } + + // Check empty slice + CheckDictionaryNonRecursive(func_name, SliceArrays(args, 0, 0), result_is_encoded); + + // Check chunked arrays + if (slice_length > 0) { + DatumVector chunked_args; + chunked_args.reserve(args.size()); + for (const auto& arg : args) { + if (arg.is_array()) { + auto arr = arg.make_array(); + ArrayVector chunks{arr->Slice(0, slice_length), arr->Slice(slice_length)}; + chunked_args.push_back(std::make_shared(std::move(chunks))); + } else { + chunked_args.push_back(arg); + } + } + CheckDictionaryNonRecursive(func_name, chunked_args, result_is_encoded); + } +} + void CheckScalarUnary(std::string func_name, Datum input, Datum expected, const FunctionOptions* options) { std::vector input_vector = {std::move(input)}; diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h index 79745b0555259..25ea577a42331 100644 --- a/cpp/src/arrow/compute/kernels/test_util.h +++ b/cpp/src/arrow/compute/kernels/test_util.h @@ -67,6 +67,8 @@ inline std::string CompareOperatorToFunctionName(CompareOperator op) { return function_names[op]; } +// Call the function with the given arguments, as well as slices of +// the arguments and scalars extracted from the arguments. void CheckScalar(std::string func_name, const ScalarVector& inputs, std::shared_ptr expected, const FunctionOptions* options = nullptr); @@ -74,6 +76,19 @@ void CheckScalar(std::string func_name, const ScalarVector& inputs, void CheckScalar(std::string func_name, const DatumVector& inputs, Datum expected, const FunctionOptions* options = nullptr); +// Like CheckScalar, but gets the expected result by +// dictionary-decoding arguments and calling the function again. +// +// result_is_encoded controls whether the result is expected to be a +// dictionary or not. +void CheckDictionary(const std::string& func_name, const DatumVector& args, + bool result_is_encoded = true); + +// Just call the function with the given arguments. +void CheckScalarNonRecursive(const std::string& func_name, const DatumVector& inputs, + const Datum& expected, + const FunctionOptions* options = nullptr); + void CheckScalarUnary(std::string func_name, std::shared_ptr in_ty, std::string json_input, std::shared_ptr out_ty, std::string json_expected, diff --git a/cpp/src/arrow/ipc/json_simple.cc b/cpp/src/arrow/ipc/json_simple.cc index 34b0f3fba598d..8347b871b1f90 100644 --- a/cpp/src/arrow/ipc/json_simple.cc +++ b/cpp/src/arrow/ipc/json_simple.cc @@ -969,6 +969,25 @@ Status ScalarFromJSON(const std::shared_ptr& type, return Status::OK(); } +Status DictScalarFromJSON(const std::shared_ptr& type, + util::string_view index_json, util::string_view dictionary_json, + std::shared_ptr* out) { + if (type->id() != Type::DICTIONARY) { + return Status::TypeError("DictScalarFromJSON requires dictionary type, got ", *type); + } + + const auto& dictionary_type = checked_cast(*type); + + std::shared_ptr index; + std::shared_ptr dictionary; + RETURN_NOT_OK(ScalarFromJSON(dictionary_type.index_type(), index_json, &index)); + RETURN_NOT_OK( + ArrayFromJSON(dictionary_type.value_type(), dictionary_json, &dictionary)); + + *out = DictionaryScalar::Make(std::move(index), std::move(dictionary)); + return Status::OK(); +} + } // namespace json } // namespace internal } // namespace ipc diff --git a/cpp/src/arrow/ipc/json_simple.h b/cpp/src/arrow/ipc/json_simple.h index 4dd3a664aa677..8269bd6532665 100644 --- a/cpp/src/arrow/ipc/json_simple.h +++ b/cpp/src/arrow/ipc/json_simple.h @@ -55,6 +55,11 @@ ARROW_EXPORT Status ScalarFromJSON(const std::shared_ptr&, util::string_view json, std::shared_ptr* out); +ARROW_EXPORT +Status DictScalarFromJSON(const std::shared_ptr&, util::string_view index_json, + util::string_view dictionary_json, + std::shared_ptr* out); + } // namespace json } // namespace internal } // namespace ipc diff --git a/cpp/src/arrow/ipc/json_simple_test.cc b/cpp/src/arrow/ipc/json_simple_test.cc index ce2c37b79574f..34c300faa95b0 100644 --- a/cpp/src/arrow/ipc/json_simple_test.cc +++ b/cpp/src/arrow/ipc/json_simple_test.cc @@ -1385,6 +1385,30 @@ TEST(TestScalarFromJSON, Errors) { ASSERT_RAISES(Invalid, ScalarFromJSON(boolean(), "\"true\"", &scalar)); } +TEST(TestDictScalarFromJSON, Basics) { + auto type = dictionary(int32(), utf8()); + auto dict = R"(["whiskey", "tango", "foxtrot"])"; + auto expected_dictionary = ArrayFromJSON(utf8(), dict); + + for (auto index : {"null", "2", "1", "0"}) { + auto scalar = DictScalarFromJSON(type, index, dict); + auto expected_index = ScalarFromJSON(int32(), index); + AssertScalarsEqual(*DictionaryScalar::Make(expected_index, expected_dictionary), + *scalar, /*verbose=*/true); + ASSERT_OK(scalar->ValidateFull()); + } +} + +TEST(TestDictScalarFromJSON, Errors) { + auto type = dictionary(int32(), utf8()); + std::shared_ptr scalar; + + ASSERT_RAISES(Invalid, + DictScalarFromJSON(type, "\"not a valid index\"", "[\"\"]", &scalar)); + ASSERT_RAISES(Invalid, DictScalarFromJSON(type, "0", "[1]", + &scalar)); // dict value isn't string +} + } // namespace json } // namespace internal } // namespace ipc diff --git a/cpp/src/arrow/scalar.cc b/cpp/src/arrow/scalar.cc index 60ba54f82cc11..adfc50182cb79 100644 --- a/cpp/src/arrow/scalar.cc +++ b/cpp/src/arrow/scalar.cc @@ -599,8 +599,9 @@ Result> DictionaryScalar::GetEncodedValue() const { std::shared_ptr DictionaryScalar::Make(std::shared_ptr index, std::shared_ptr dict) { auto type = dictionary(index->type, dict->type()); + auto is_valid = index->is_valid; return std::make_shared(ValueType{std::move(index), std::move(dict)}, - std::move(type)); + std::move(type), is_valid); } namespace { diff --git a/cpp/src/arrow/testing/gtest_util.cc b/cpp/src/arrow/testing/gtest_util.cc index 587154c1f3048..24f5edcc6cb8f 100644 --- a/cpp/src/arrow/testing/gtest_util.cc +++ b/cpp/src/arrow/testing/gtest_util.cc @@ -446,6 +446,15 @@ std::shared_ptr ScalarFromJSON(const std::shared_ptr& type, return out; } +std::shared_ptr DictScalarFromJSON(const std::shared_ptr& type, + util::string_view index_json, + util::string_view dictionary_json) { + std::shared_ptr out; + ABORT_NOT_OK( + ipc::internal::json::DictScalarFromJSON(type, index_json, dictionary_json, &out)); + return out; +} + std::shared_ptr TableFromJSON(const std::shared_ptr& schema, const std::vector& json) { std::vector> batches; diff --git a/cpp/src/arrow/testing/gtest_util.h b/cpp/src/arrow/testing/gtest_util.h index f0021e056035e..65ab33c5d1fb7 100644 --- a/cpp/src/arrow/testing/gtest_util.h +++ b/cpp/src/arrow/testing/gtest_util.h @@ -338,6 +338,11 @@ ARROW_TESTING_EXPORT std::shared_ptr ScalarFromJSON(const std::shared_ptr&, util::string_view json); +ARROW_TESTING_EXPORT +std::shared_ptr DictScalarFromJSON(const std::shared_ptr&, + util::string_view index_json, + util::string_view dictionary_json); + ARROW_TESTING_EXPORT std::shared_ptr
TableFromJSON(const std::shared_ptr&, const std::vector& json);