From 2ad7793dd2e72df75e0cd3780b0bb5e5a8d882e0 Mon Sep 17 00:00:00 2001 From: Ruoxi Sun Date: Thu, 13 Jun 2024 02:11:03 +0800 Subject: [PATCH] Add files --- cpp/src/arrow/CMakeLists.txt | 3 + cpp/src/arrow/compute/kernels/CMakeLists.txt | 7 + .../arrow/compute/kernels/vector_scatter.cc | 92 + .../kernels/vector_scatter_benchmark.cc | 577 ++++ .../vector_scatter_by_mask_internal.cc | 1132 +++++++ .../kernels/vector_scatter_by_mask_internal.h | 37 + .../kernels/vector_scatter_internal.cc | 888 ++++++ .../compute/kernels/vector_scatter_internal.h | 71 + .../compute/kernels/vector_scatter_test.cc | 2723 +++++++++++++++++ cpp/src/arrow/compute/registry_internal.h | 1 + 10 files changed, 5531 insertions(+) create mode 100644 cpp/src/arrow/compute/kernels/vector_scatter.cc create mode 100644 cpp/src/arrow/compute/kernels/vector_scatter_benchmark.cc create mode 100644 cpp/src/arrow/compute/kernels/vector_scatter_by_mask_internal.cc create mode 100644 cpp/src/arrow/compute/kernels/vector_scatter_by_mask_internal.h create mode 100644 cpp/src/arrow/compute/kernels/vector_scatter_internal.cc create mode 100644 cpp/src/arrow/compute/kernels/vector_scatter_internal.h create mode 100644 cpp/src/arrow/compute/kernels/vector_scatter_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 6dc8358f502f5..61617c9fdf001 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -734,6 +734,9 @@ set(ARROW_COMPUTE_SRCS compute/kernels/scalar_cast_temporal.cc compute/kernels/util_internal.cc compute/kernels/vector_hash.cc + compute/kernels/vector_scatter.cc + compute/kernels/vector_scatter_by_mask_internal.cc + compute/kernels/vector_scatter_internal.cc compute/kernels/vector_selection.cc compute/kernels/vector_selection_filter_internal.cc compute/kernels/vector_selection_internal.cc diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index afb30996eac15..01f19d7e7d30d 100644 --- 
a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -114,12 +114,19 @@ add_arrow_compute_test(vector_selection_test EXTRA_LINK_LIBS arrow_compute_kernels_testing) +add_arrow_compute_test(vector_scatter_test + SOURCES + vector_scatter_test.cc + EXTRA_LINK_LIBS + arrow_compute_kernels_testing) + add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_partition_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_topk_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_replace_benchmark PREFIX "arrow-compute") add_arrow_benchmark(vector_selection_benchmark PREFIX "arrow-compute") +add_arrow_benchmark(vector_scatter_benchmark PREFIX "arrow-compute") # ---------------------------------------------------------------------- # Aggregate kernels diff --git a/cpp/src/arrow/compute/kernels/vector_scatter.cc b/cpp/src/arrow/compute/kernels/vector_scatter.cc new file mode 100644 index 0000000000000..46a72d83ed15e --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_scatter.cc @@ -0,0 +1,92 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include +#include +#include +#include +#include + +#include "arrow/array/array_binary.h" +#include "arrow/array/array_dict.h" +#include "arrow/array/array_nested.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/array/concatenate.h" +#include "arrow/buffer_builder.h" +#include "arrow/chunked_array.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/kernels/common_internal.h" +#include "arrow/compute/kernels/util_internal.h" +#include "arrow/compute/kernels/vector_scatter_by_mask_internal.h" +#include "arrow/extension_type.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/table.h" +#include "arrow/type.h" +#include "arrow/util/bit_block_counter.h" +#include "arrow/util/bit_run_reader.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/bitmap_reader.h" +#include "arrow/util/int_util.h" + +namespace arrow { + +using internal::BinaryBitBlockCounter; +using internal::BitBlockCount; +using internal::BitBlockCounter; +using internal::CheckIndexBounds; +using internal::CopyBitmap; +using internal::CountSetBits; +using internal::OptionalBitBlockCounter; +using internal::OptionalBitIndexer; + +namespace compute { +namespace internal { + +namespace { + +// ---------------------------------------------------------------------- + +const FunctionDoc array_scatter_by_mask_doc( + "Scatter with a boolean positional mask", + ("The values of the input `array` will be placed into the output at positions where " + "the `positional_mask` is non-zero. 
The rest positions of the output will be " + "populated by `null`s.\n"), + {"array", "positional_mask"}); + +} // namespace + +void RegisterVectorScatter(FunctionRegistry* registry) { + // Scatter by mask kernels + std::vector scatter_by_mask_kernels; + PopulateScatterByMaskKernels(&scatter_by_mask_kernels); + + VectorKernel scatter_by_mask_base; + scatter_by_mask_base.can_execute_chunkwise = false; + scatter_by_mask_base.output_chunked = false; + RegisterScatterFunction("array_scatter_by_mask", array_scatter_by_mask_doc, + scatter_by_mask_base, std::move(scatter_by_mask_kernels), + NULLPTR, registry); + + DCHECK_OK(registry->AddFunction(MakeScatterByMaskMetaFunction())); +} + +} // namespace internal +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_scatter_benchmark.cc b/cpp/src/arrow/compute/kernels/vector_scatter_benchmark.cc new file mode 100644 index 0000000000000..c2a27dfe43488 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_scatter_benchmark.cc @@ -0,0 +1,577 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "benchmark/benchmark.h" + +#include +#include + +#include "arrow/compute/api_vector.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/util/benchmark_util.h" + +namespace arrow { +namespace compute { + +constexpr auto kSeed = 0x0ff1ce; + +struct FilterParams { + // proportion of nulls in the values array + const double values_null_proportion; + + // proportion of true in filter + const double selected_proportion; + + // proportion of nulls in the filter + const double filter_null_proportion; +}; + +std::vector g_data_sizes = {kL2Size}; + +// The benchmark state parameter references this vector of cases. Test high and +// low selectivity filters. + +// clang-format off +std::vector g_filter_params = { + {0., 0.999, 0.05}, + {0., 0.50, 0.05}, + {0., 0.01, 0.05}, + {0.001, 0.999, 0.05}, + {0.001, 0.50, 0.05}, + {0.001, 0.01, 0.05}, + {0.01, 0.999, 0.05}, + {0.01, 0.50, 0.05}, + {0.01, 0.01, 0.05}, + {0.1, 0.999, 0.05}, + {0.1, 0.50, 0.05}, + {0.1, 0.01, 0.05}, + {0.9, 0.999, 0.05}, + {0.9, 0.50, 0.05}, + {0.9, 0.01, 0.05} +}; +// clang-format on + +// RAII struct to handle some of the boilerplate in filter +struct FilterArgs { + // size of memory tested (per iteration) in bytes + int64_t size; + + // What to call the "size" that's reported in the console output, for result + // interpretability. + std::string size_name = "size"; + + double values_null_proportion = 0.; + double selected_proportion = 0.; + double filter_null_proportion = 0.; + + FilterArgs(benchmark::State& state, bool filter_has_nulls) + : size(state.range(0)), state_(state) { + auto params = g_filter_params[state.range(1)]; + values_null_proportion = params.values_null_proportion; + selected_proportion = params.selected_proportion; + filter_null_proportion = filter_has_nulls ? 
params.filter_null_proportion : 0; + } + + ~FilterArgs() { + state_.counters[size_name] = static_cast(size); + state_.counters["select%"] = selected_proportion * 100; + state_.counters["data null%"] = values_null_proportion * 100; + state_.counters["mask null%"] = filter_null_proportion * 100; + state_.SetBytesProcessed(state_.iterations() * size); + } + + private: + benchmark::State& state_; +}; + +struct TakeBenchmark { + benchmark::State& state; + RegressionArgs args; + random::RandomArrayGenerator rand; + bool indices_have_nulls; + bool monotonic_indices = false; + + TakeBenchmark(benchmark::State& state, bool indices_have_nulls, + bool monotonic_indices = false) + : state(state), + args(state, /*size_is_bytes=*/false), + rand(kSeed), + indices_have_nulls(indices_have_nulls), + monotonic_indices(monotonic_indices) {} + + static constexpr int kStringMinLength = 0; + static constexpr int kStringMaxLength = 32; + static constexpr int kByteWidthRange = 2; + + template + std::shared_ptr GenChunkedArray(int64_t num_chunks, + GenChunk&& gen_chunk) { + const int64_t chunk_length = + std::llround(args.size / static_cast(num_chunks)); + ArrayVector chunks; + for (int64_t i = 0; i < num_chunks; ++i) { + const int64_t fitting_chunk_length = + std::min(chunk_length, args.size - i * chunk_length); + chunks.push_back(gen_chunk(fitting_chunk_length)); + } + return std::make_shared(std::move(chunks)); + } + + void Int64() { + auto values = rand.Int64(args.size, -100, 100, args.null_proportion); + Bench(values); + } + + void FSLInt64() { + auto int_array = rand.Int64(args.size, -100, 100, args.null_proportion); + auto values = std::make_shared( + fixed_size_list(int64(), 1), args.size, int_array, int_array->null_bitmap(), + int_array->null_count()); + Bench(values); + } + + void FixedSizeBinary() { + const auto byte_width = static_cast(state.range(kByteWidthRange)); + auto values = rand.FixedSizeBinary(args.size, byte_width, args.null_proportion); + Bench(values); + 
state.counters["byte_width"] = byte_width; + } + + void String() { + auto values = std::static_pointer_cast( + rand.String(args.size, kStringMinLength, kStringMaxLength, args.null_proportion)); + Bench(values); + } + + void ChunkedInt64(int64_t num_chunks, bool chunk_indices_too) { + auto chunked_array = GenChunkedArray(num_chunks, [this](int64_t chunk_length) { + return rand.Int64(chunk_length, -100, 100, args.null_proportion); + }); + BenchChunked(chunked_array, chunk_indices_too); + } + + void ChunkedFSB(int64_t num_chunks, bool chunk_indices_too) { + const auto byte_width = static_cast(state.range(kByteWidthRange)); + auto chunked_array = + GenChunkedArray(num_chunks, [this, byte_width](int64_t chunk_length) { + return rand.FixedSizeBinary(chunk_length, byte_width, args.null_proportion); + }); + BenchChunked(chunked_array, chunk_indices_too); + state.counters["byte_width"] = byte_width; + } + + void ChunkedString(int64_t num_chunks, bool chunk_indices_too) { + auto chunked_array = GenChunkedArray(num_chunks, [this](int64_t chunk_length) { + return std::static_pointer_cast(rand.String( + chunk_length, kStringMinLength, kStringMaxLength, args.null_proportion)); + }); + BenchChunked(chunked_array, chunk_indices_too); + } + + void Bench(const std::shared_ptr& values) { + double indices_null_proportion = indices_have_nulls ? args.null_proportion : 0; + auto indices = + rand.Int32(values->length(), 0, static_cast(values->length() - 1), + indices_null_proportion); + + if (monotonic_indices) { + auto arg_sorter = *SortIndices(*indices); + indices = *Take(*indices, *arg_sorter); + } + + for (auto _ : state) { + ABORT_NOT_OK(Take(values, indices).status()); + } + state.SetItemsProcessed(state.iterations() * values->length()); + } + + void BenchChunked(const std::shared_ptr& values, bool chunk_indices_too) { + double indices_null_proportion = indices_have_nulls ? 
args.null_proportion : 0; + auto indices = + rand.Int32(values->length(), 0, static_cast(values->length() - 1), + indices_null_proportion); + + if (monotonic_indices) { + auto arg_sorter = *SortIndices(*indices); + indices = *Take(*indices, *arg_sorter); + } + std::shared_ptr chunked_indices; + if (chunk_indices_too) { + std::vector> indices_chunks; + int64_t offset = 0; + for (int i = 0; i < values->num_chunks(); ++i) { + auto chunk = indices->Slice(offset, values->chunk(i)->length()); + indices_chunks.push_back(std::move(chunk)); + offset += values->chunk(i)->length(); + } + chunked_indices = std::make_shared(std::move(indices_chunks)); + } + + if (chunk_indices_too) { + for (auto _ : state) { + ABORT_NOT_OK(Take(values, chunked_indices).status()); + } + } else { + for (auto _ : state) { + ABORT_NOT_OK(Take(values, indices).status()); + } + } + state.SetItemsProcessed(state.iterations() * values->length()); + } +}; + +struct FilterBenchmark { + benchmark::State& state; + FilterArgs args; + random::RandomArrayGenerator rand; + bool filter_has_nulls; + + FilterBenchmark(benchmark::State& state, bool filter_has_nulls) + : state(state), + args(state, filter_has_nulls), + rand(kSeed), + filter_has_nulls(filter_has_nulls) {} + + void Int64() { + const int64_t array_size = args.size / sizeof(int64_t); + auto values = rand.Int64(array_size, -100, 100, args.values_null_proportion); + Bench(values); + } + + void FSLInt64() { + const int64_t array_size = args.size / sizeof(int64_t); + auto int_array = std::static_pointer_cast>( + rand.Int64(array_size, -100, 100, args.values_null_proportion)); + auto values = std::make_shared( + fixed_size_list(int64(), 1), array_size, int_array, int_array->null_bitmap(), + int_array->null_count()); + Bench(values); + } + + void FixedSizeBinary() { + const int32_t byte_width = static_cast(state.range(2)); + const int64_t array_size = args.size / byte_width; + auto values = + rand.FixedSizeBinary(array_size, byte_width, 
args.values_null_proportion); + Bench(values); + state.counters["byte_width"] = byte_width; + } + + void String() { + int32_t string_min_length = 0, string_max_length = 32; + int32_t string_mean_length = (string_max_length + string_min_length) / 2; + // for an array of 50% null strings, we need to generate twice as many strings + // to ensure that they have an average of args.size total characters + int64_t array_size = args.size; + if (args.values_null_proportion < 1) { + array_size = static_cast(args.size / string_mean_length / + (1 - args.values_null_proportion)); + } + auto values = std::static_pointer_cast(rand.String( + array_size, string_min_length, string_max_length, args.values_null_proportion)); + Bench(values); + } + + void Bench(const std::shared_ptr& values) { + auto filter = rand.Boolean(values->length(), args.selected_proportion, + args.filter_null_proportion); + for (auto _ : state) { + ABORT_NOT_OK(Filter(values, filter).status()); + } + state.SetItemsProcessed(state.iterations() * values->length()); + } + + void BenchRecordBatch() { + const int64_t total_data_cells = 10000000; + const int64_t num_columns = state.range(0); + const int64_t num_rows = total_data_cells / num_columns; + + auto col_data = rand.Float64(num_rows, 0, 1); + + auto filter = + rand.Boolean(num_rows, args.selected_proportion, args.filter_null_proportion); + + int64_t output_length = + internal::GetFilterOutputSize(*filter->data(), FilterOptions::DROP); + + // HACK: set FilterArgs.size to the number of selected data cells * + // sizeof(double) for accurate memory processing performance + args.size = output_length * num_columns * sizeof(double); + args.size_name = "extracted_size"; + state.counters["num_cols"] = static_cast(num_columns); + + std::vector> columns; + std::vector> fields; + for (int64_t i = 0; i < num_columns; ++i) { + std::stringstream ss; + ss << "f" << i; + fields.push_back(::arrow::field(ss.str(), float64())); + columns.push_back(col_data); + } + + auto batch = 
RecordBatch::Make(schema(fields), num_rows, columns); + for (auto _ : state) { + ABORT_NOT_OK(Filter(batch, filter).status()); + } + state.SetItemsProcessed(state.iterations() * num_rows); + } +}; + +static void FilterInt64FilterNoNulls(benchmark::State& state) { + FilterBenchmark(state, false).Int64(); +} + +static void FilterInt64FilterWithNulls(benchmark::State& state) { + FilterBenchmark(state, true).Int64(); +} + +static void FilterFSLInt64FilterNoNulls(benchmark::State& state) { + FilterBenchmark(state, false).FSLInt64(); +} + +static void FilterFSLInt64FilterWithNulls(benchmark::State& state) { + FilterBenchmark(state, true).FSLInt64(); +} + +static void FilterFixedSizeBinaryFilterNoNulls(benchmark::State& state) { + FilterBenchmark(state, false).FixedSizeBinary(); +} + +static void FilterFixedSizeBinaryFilterWithNulls(benchmark::State& state) { + FilterBenchmark(state, true).FixedSizeBinary(); +} + +static void FilterStringFilterNoNulls(benchmark::State& state) { + FilterBenchmark(state, false).String(); +} + +static void FilterStringFilterWithNulls(benchmark::State& state) { + FilterBenchmark(state, true).String(); +} + +static void FilterRecordBatchNoNulls(benchmark::State& state) { + FilterBenchmark(state, false).BenchRecordBatch(); +} + +static void FilterRecordBatchWithNulls(benchmark::State& state) { + FilterBenchmark(state, true).BenchRecordBatch(); +} + +static void TakeInt64RandomIndicesNoNulls(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/false).Int64(); +} + +static void TakeInt64RandomIndicesWithNulls(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/true).Int64(); +} + +static void TakeInt64MonotonicIndices(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true).Int64(); +} + +static void TakeFixedSizeBinaryRandomIndicesNoNulls(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/false).FixedSizeBinary(); +} + +static void 
TakeFixedSizeBinaryRandomIndicesWithNulls(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/true).FixedSizeBinary(); +} + +static void TakeFixedSizeBinaryMonotonicIndices(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true) + .FixedSizeBinary(); +} + +static void TakeFSLInt64RandomIndicesNoNulls(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/false).FSLInt64(); +} + +static void TakeFSLInt64RandomIndicesWithNulls(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/true).FSLInt64(); +} + +static void TakeFSLInt64MonotonicIndices(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true).FSLInt64(); +} + +static void TakeStringRandomIndicesNoNulls(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/false).String(); +} + +static void TakeStringRandomIndicesWithNulls(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/true).String(); +} + +static void TakeStringMonotonicIndices(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true).String(); +} + +static void TakeChunkedChunkedInt64RandomIndicesNoNulls(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/false) + .ChunkedInt64(/*num_chunks=*/100, /*chunk_indices_too=*/true); +} + +static void TakeChunkedChunkedInt64RandomIndicesWithNulls(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/true) + .ChunkedInt64(/*num_chunks=*/100, /*chunk_indices_too=*/true); +} + +static void TakeChunkedChunkedInt64MonotonicIndices(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true) + .ChunkedInt64( + /*num_chunks=*/100, /*chunk_indices_too=*/true); +} + +static void TakeChunkedChunkedFSBRandomIndicesNoNulls(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/false) + 
.ChunkedFSB(/*num_chunks=*/100, /*chunk_indices_too=*/true); +} + +static void TakeChunkedChunkedFSBRandomIndicesWithNulls(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/true) + .ChunkedFSB(/*num_chunks=*/100, /*chunk_indices_too=*/true); +} + +static void TakeChunkedChunkedFSBMonotonicIndices(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true) + .ChunkedFSB(/*num_chunks=*/100, /*chunk_indices_too=*/true); +} + +static void TakeChunkedChunkedStringRandomIndicesNoNulls(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/false) + .ChunkedString(/*num_chunks=*/100, /*chunk_indices_too=*/true); +} + +static void TakeChunkedChunkedStringRandomIndicesWithNulls(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/true) + .ChunkedString(/*num_chunks=*/100, /*chunk_indices_too=*/true); +} + +static void TakeChunkedChunkedStringMonotonicIndices(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true) + .ChunkedString(/*num_chunks=*/100, /*chunk_indices_too=*/true); +} + +static void TakeChunkedFlatInt64RandomIndicesNoNulls(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/false) + .ChunkedInt64(/*num_chunks=*/100, /*chunk_indices_too=*/false); +} + +static void TakeChunkedFlatInt64RandomIndicesWithNulls(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/true) + .ChunkedInt64(/*num_chunks=*/100, /*chunk_indices_too=*/false); +} + +static void TakeChunkedFlatInt64MonotonicIndices(benchmark::State& state) { + TakeBenchmark(state, /*indices_with_nulls=*/false, /*monotonic=*/true) + .ChunkedInt64( + /*num_chunks=*/100, /*chunk_indices_too=*/false); +} + +void FilterSetArgs(benchmark::internal::Benchmark* bench) { + for (int64_t size : g_data_sizes) { + for (int i = 0; i < static_cast(g_filter_params.size()); ++i) { + bench->Args({static_cast(size), i}); + } + } +} + +void 
FilterFSBSetArgs(benchmark::internal::Benchmark* bench) { + for (int64_t size : g_data_sizes) { + for (int i = 0; i < static_cast(g_filter_params.size()); ++i) { + // FixedSizeBinary of primitive sizes (powers of two up to 32) + // have a faster path. + for (int32_t byte_width : {8, 9}) { + bench->Args({static_cast(size), i, byte_width}); + } + } + } +} + +BENCHMARK(FilterInt64FilterNoNulls)->Apply(FilterSetArgs); +BENCHMARK(FilterInt64FilterWithNulls)->Apply(FilterSetArgs); +BENCHMARK(FilterFixedSizeBinaryFilterNoNulls)->Apply(FilterFSBSetArgs); +BENCHMARK(FilterFixedSizeBinaryFilterWithNulls)->Apply(FilterFSBSetArgs); +BENCHMARK(FilterFSLInt64FilterNoNulls)->Apply(FilterSetArgs); +BENCHMARK(FilterFSLInt64FilterWithNulls)->Apply(FilterSetArgs); +BENCHMARK(FilterStringFilterNoNulls)->Apply(FilterSetArgs); +BENCHMARK(FilterStringFilterWithNulls)->Apply(FilterSetArgs); + +void FilterRecordBatchSetArgs(benchmark::internal::Benchmark* bench) { + for (auto num_cols : std::vector({10, 50, 100})) { + for (int i = 0; i < static_cast(g_filter_params.size()); ++i) { + bench->Args({num_cols, i}); + } + } +} +BENCHMARK(FilterRecordBatchNoNulls)->Apply(FilterRecordBatchSetArgs); +BENCHMARK(FilterRecordBatchWithNulls)->Apply(FilterRecordBatchSetArgs); + +void TakeSetArgs(benchmark::internal::Benchmark* bench) { + for (int64_t size : g_data_sizes) { + for (auto nulls : std::vector({1000, 10, 2, 1, 0})) { + bench->Args({static_cast(size), nulls}); + } + } +} + +void TakeFSBSetArgs(benchmark::internal::Benchmark* bench) { + for (int64_t size : g_data_sizes) { + for (auto nulls : std::vector({1000, 10, 2, 1, 0})) { + // FixedSizeBinary of primitive sizes (powers of two up to 32) + // have a faster path. 
+ for (int32_t byte_width : {8, 9}) { + bench->Args({static_cast(size), nulls, byte_width}); + } + } + } +} + +// Flat values x Flat indices +BENCHMARK(TakeInt64RandomIndicesNoNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeInt64RandomIndicesWithNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeInt64MonotonicIndices)->Apply(TakeSetArgs); +BENCHMARK(TakeFixedSizeBinaryRandomIndicesNoNulls)->Apply(TakeFSBSetArgs); +BENCHMARK(TakeFixedSizeBinaryRandomIndicesWithNulls)->Apply(TakeFSBSetArgs); +BENCHMARK(TakeFixedSizeBinaryMonotonicIndices)->Apply(TakeFSBSetArgs); +BENCHMARK(TakeFSLInt64RandomIndicesNoNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeFSLInt64RandomIndicesWithNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeFSLInt64MonotonicIndices)->Apply(TakeSetArgs); +BENCHMARK(TakeStringRandomIndicesNoNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeStringRandomIndicesWithNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeStringMonotonicIndices)->Apply(TakeSetArgs); + +// Chunked values x Chunked indices +BENCHMARK(TakeChunkedChunkedInt64RandomIndicesNoNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedChunkedInt64RandomIndicesWithNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedChunkedInt64MonotonicIndices)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedChunkedFSBRandomIndicesNoNulls)->Apply(TakeFSBSetArgs); +BENCHMARK(TakeChunkedChunkedFSBRandomIndicesWithNulls)->Apply(TakeFSBSetArgs); +BENCHMARK(TakeChunkedChunkedFSBMonotonicIndices)->Apply(TakeFSBSetArgs); +BENCHMARK(TakeChunkedChunkedStringRandomIndicesNoNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedChunkedStringRandomIndicesWithNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedChunkedStringMonotonicIndices)->Apply(TakeSetArgs); + +// Chunked values x Flat indices +BENCHMARK(TakeChunkedFlatInt64RandomIndicesNoNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedFlatInt64RandomIndicesWithNulls)->Apply(TakeSetArgs); +BENCHMARK(TakeChunkedFlatInt64MonotonicIndices)->Apply(TakeSetArgs); + +} // namespace compute +} // namespace arrow diff --git 
a/cpp/src/arrow/compute/kernels/vector_scatter_by_mask_internal.cc b/cpp/src/arrow/compute/kernels/vector_scatter_by_mask_internal.cc new file mode 100644 index 0000000000000..86da8c1f889e6 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_scatter_by_mask_internal.cc @@ -0,0 +1,1132 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include +#include +#include +#include +#include + +#include "arrow/array/concatenate.h" +#include "arrow/array/data.h" +#include "arrow/buffer_builder.h" +#include "arrow/chunked_array.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/codegen_internal.h" +#include "arrow/compute/kernels/vector_selection_filter_internal.h" +#include "arrow/compute/kernels/vector_selection_internal.h" +#include "arrow/datum.h" +#include "arrow/extension_type.h" +#include "arrow/record_batch.h" +#include "arrow/table.h" +#include "arrow/type.h" +#include "arrow/util/bit_block_counter.h" +#include "arrow/util/bit_run_reader.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/bitmap_ops.h" +#include "arrow/util/fixed_width_internal.h" + +namespace arrow { + +using internal::BinaryBitBlockCounter; +using internal::BitBlockCount; +using internal::BitBlockCounter; +using internal::CopyBitmap; +using internal::CountSetBits; +using internal::OptionalBitBlockCounter; + +namespace compute::internal { + +namespace { + +using FilterState = OptionsWrapper; + +int64_t GetBitmapFilterOutputSize(const ArraySpan& filter, + FilterOptions::NullSelectionBehavior null_selection) { + int64_t output_size = 0; + + if (filter.MayHaveNulls()) { + const uint8_t* filter_is_valid = filter.buffers[0].data; + BinaryBitBlockCounter bit_counter(filter.buffers[1].data, filter.offset, + filter_is_valid, filter.offset, filter.length); + int64_t position = 0; + if (null_selection == FilterOptions::EMIT_NULL) { + while (position < filter.length) { + BitBlockCount block = bit_counter.NextOrNotWord(); + output_size += block.popcount; + position += block.length; + } + } else { + while (position < filter.length) { + BitBlockCount block = bit_counter.NextAndWord(); + output_size += block.popcount; + position += block.length; + } + } + } else { + // The filter has no nulls, so we can use CountSetBits + output_size 
= CountSetBits(filter.buffers[1].data, filter.offset, filter.length); + } + return output_size; +} + +int64_t GetREEFilterOutputSize(const ArraySpan& filter, + FilterOptions::NullSelectionBehavior null_selection) { + const auto& ree_type = checked_cast(*filter.type); + DCHECK_EQ(ree_type.value_type()->id(), Type::BOOL); + int64_t output_size = 0; + VisitPlainxREEFilterOutputSegments( + filter, /*filter_may_have_nulls=*/true, null_selection, + [&output_size](int64_t, int64_t segment_length, bool) { + output_size += segment_length; + return true; + }); + return output_size; +} + +} // namespace + +int64_t GetFilterOutputSize(const ArraySpan& filter, + FilterOptions::NullSelectionBehavior null_selection) { + if (filter.type->id() == Type::BOOL) { + return GetBitmapFilterOutputSize(filter, null_selection); + } + DCHECK_EQ(filter.type->id(), Type::RUN_END_ENCODED); + return GetREEFilterOutputSize(filter, null_selection); +} + +namespace { + +// ---------------------------------------------------------------------- +// Optimized and streamlined filter for primitive types + +// Use either BitBlockCounter or BinaryBitBlockCounter to quickly scan filter a +// word at a time for the DROP selection type. 
+class DropNullCounter { + public: + // validity bitmap may be null + DropNullCounter(const uint8_t* validity, const uint8_t* data, int64_t offset, + int64_t length) + : data_counter_(data, offset, length), + data_and_validity_counter_(data, offset, validity, offset, length), + has_validity_(validity != nullptr) {} + + BitBlockCount NextBlock() { + if (has_validity_) { + // filter is true AND not null + return data_and_validity_counter_.NextAndWord(); + } else { + return data_counter_.NextWord(); + } + } + + private: + // For when just data is present, but no validity bitmap + BitBlockCounter data_counter_; + + // For when both validity bitmap and data are present + BinaryBitBlockCounter data_and_validity_counter_; + const bool has_validity_; +}; + +/// \brief The Filter implementation for primitive (fixed-width) types does not +/// use the logical Arrow type but rather the physical C type. This way we only +/// generate one take function for each byte width. +/// +/// We use compile-time specialization for two variations: +/// - operating on boolean data (using kIsBoolean = true) +/// - operating on fixed-width data of arbitrary width (using kByteWidth = -1), +/// with the actual width only known at runtime +template +class PrimitiveFilterImpl { + public: + PrimitiveFilterImpl(const ArraySpan& values, const ArraySpan& filter, + FilterOptions::NullSelectionBehavior null_selection, + ArrayData* out_arr) + : byte_width_(util::FixedWidthInBytes(*values.type)), + values_is_valid_(values.buffers[0].data), + // No offset applied for boolean because it's a bitmap + values_data_(kIsBoolean ? 
values.buffers[1].data + : util::OffsetPointerOfFixedByteWidthValues(values)), + values_null_count_(values.null_count), + values_offset_(values.offset), + values_length_(values.length), + filter_(filter), + null_selection_(null_selection) { + if constexpr (kByteWidth >= 0 && !kIsBoolean) { + DCHECK_EQ(kByteWidth, byte_width_); + } + + DCHECK_EQ(out_arr->offset, 0); + if (out_arr->buffers[0] != nullptr) { + // May be unallocated if neither filter nor values contain nulls + out_is_valid_ = out_arr->buffers[0]->mutable_data(); + } + out_data_ = util::MutableFixedWidthValuesPointer(out_arr); + out_length_ = out_arr->length; + out_position_ = 0; + } + + void ExecREEFilter() { + if (filter_.child_data[1].null_count == 0 && values_null_count_ == 0) { + DCHECK(!out_is_valid_); + // Fastest: no nulls in either filter or values + return VisitPlainxREEFilterOutputSegments( + filter_, /*filter_may_have_nulls=*/false, null_selection_, + [&](int64_t position, int64_t segment_length, bool filter_valid) { + // Fastest path: all values in range are included and not null + WriteValueSegment(position, segment_length); + DCHECK(filter_valid); + return true; + }); + } + if (values_is_valid_) { + DCHECK(out_is_valid_); + // Slower path: values can be null, so the validity bitmap should be copied + return VisitPlainxREEFilterOutputSegments( + filter_, /*filter_may_have_nulls=*/true, null_selection_, + [&](int64_t position, int64_t segment_length, bool filter_valid) { + if (filter_valid) { + CopyBitmap(values_is_valid_, values_offset_ + position, segment_length, + out_is_valid_, out_position_); + WriteValueSegment(position, segment_length); + } else { + bit_util::SetBitsTo(out_is_valid_, out_position_, segment_length, false); + WriteNullSegment(segment_length); + } + return true; + }); + } + // Faster path: only write to out_is_valid_ if filter contains nulls and + // null_selection is EMIT_NULL + if (out_is_valid_) { + // Set all to valid, so only if nulls are produced by EMIT_NULL, we 
need + // to set out_is_valid[i] to false. + bit_util::SetBitsTo(out_is_valid_, 0, out_length_, true); + } + return VisitPlainxREEFilterOutputSegments( + filter_, /*filter_may_have_nulls=*/true, null_selection_, + [&](int64_t position, int64_t segment_length, bool filter_valid) { + if (filter_valid) { + WriteValueSegment(position, segment_length); + } else { + bit_util::SetBitsTo(out_is_valid_, out_position_, segment_length, false); + WriteNullSegment(segment_length); + } + return true; + }); + } + + void Exec() { + if (filter_.type->id() == Type::RUN_END_ENCODED) { + return ExecREEFilter(); + } + const auto* filter_is_valid = filter_.buffers[0].data; + const auto* filter_data = filter_.buffers[1].data; + const auto filter_offset = filter_.offset; + if (filter_.null_count == 0 && values_null_count_ == 0) { + // Fast filter when values and filter are not null + ::arrow::internal::VisitSetBitRunsVoid( + filter_data, filter_.offset, values_length_, + [&](int64_t position, int64_t length) { WriteValueSegment(position, length); }); + return; + } + + // Bit counters used for both null_selection behaviors + DropNullCounter drop_null_counter(filter_is_valid, filter_data, filter_offset, + values_length_); + OptionalBitBlockCounter data_counter(values_is_valid_, values_offset_, + values_length_); + OptionalBitBlockCounter filter_valid_counter(filter_is_valid, filter_offset, + values_length_); + + auto WriteNotNull = [&](int64_t index) { + bit_util::SetBit(out_is_valid_, out_position_); + // Increments out_position_ + WriteValue(index); + }; + + auto WriteMaybeNull = [&](int64_t index) { + bit_util::SetBitTo(out_is_valid_, out_position_, + bit_util::GetBit(values_is_valid_, values_offset_ + index)); + // Increments out_position_ + WriteValue(index); + }; + + int64_t in_position = 0; + while (in_position < values_length_) { + BitBlockCount filter_block = drop_null_counter.NextBlock(); + BitBlockCount filter_valid_block = filter_valid_counter.NextWord(); + BitBlockCount 
data_block = data_counter.NextWord(); + if (filter_block.AllSet() && data_block.AllSet()) { + // Fastest path: all values in block are included and not null + bit_util::SetBitsTo(out_is_valid_, out_position_, filter_block.length, true); + WriteValueSegment(in_position, filter_block.length); + in_position += filter_block.length; + } else if (filter_block.AllSet()) { + // Faster: all values are selected, but some values are null + // Batch copy bits from values validity bitmap to output validity bitmap + CopyBitmap(values_is_valid_, values_offset_ + in_position, filter_block.length, + out_is_valid_, out_position_); + WriteValueSegment(in_position, filter_block.length); + in_position += filter_block.length; + } else if (filter_block.NoneSet() && null_selection_ == FilterOptions::DROP) { + // For this exceedingly common case in low-selectivity filters we can + // skip further analysis of the data and move on to the next block. + in_position += filter_block.length; + } else { + // Some filter values are false or null + if (data_block.AllSet()) { + // No values are null + if (filter_valid_block.AllSet()) { + // Filter is non-null but some values are false + for (int64_t i = 0; i < filter_block.length; ++i) { + if (bit_util::GetBit(filter_data, filter_offset + in_position)) { + WriteNotNull(in_position); + } + ++in_position; + } + } else if (null_selection_ == FilterOptions::DROP) { + // If any values are selected, they ARE NOT null + for (int64_t i = 0; i < filter_block.length; ++i) { + if (bit_util::GetBit(filter_is_valid, filter_offset + in_position) && + bit_util::GetBit(filter_data, filter_offset + in_position)) { + WriteNotNull(in_position); + } + ++in_position; + } + } else { // null_selection == FilterOptions::EMIT_NULL + // Data values in this block are not null + for (int64_t i = 0; i < filter_block.length; ++i) { + const bool is_valid = + bit_util::GetBit(filter_is_valid, filter_offset + in_position); + if (is_valid && + bit_util::GetBit(filter_data, 
filter_offset + in_position)) { + // Filter slot is non-null and set + WriteNotNull(in_position); + } else if (!is_valid) { + // Filter slot is null, so we have a null in the output + bit_util::ClearBit(out_is_valid_, out_position_); + WriteNull(); + } + ++in_position; + } + } + } else { // !data_block.AllSet() + // Some data values in this block are null + if (filter_valid_block.AllSet()) { + // Filter has no nulls in this block, but some filter slots are false + for (int64_t i = 0; i < filter_block.length; ++i) { + if (bit_util::GetBit(filter_data, filter_offset + in_position)) { + WriteMaybeNull(in_position); + } + ++in_position; + } + } else if (null_selection_ == FilterOptions::DROP) { + // Null filter slots are dropped; a selected value may itself be null + for (int64_t i = 0; i < filter_block.length; ++i) { + if (bit_util::GetBit(filter_is_valid, filter_offset + in_position) && + bit_util::GetBit(filter_data, filter_offset + in_position)) { + WriteMaybeNull(in_position); + } + ++in_position; + } + } else { // null_selection == FilterOptions::EMIT_NULL + // Some data values in this block are null, as may be some filter slots + for (int64_t i = 0; i < filter_block.length; ++i) { + const bool is_valid = + bit_util::GetBit(filter_is_valid, filter_offset + in_position); + if (is_valid && + bit_util::GetBit(filter_data, filter_offset + in_position)) { + // Filter slot is non-null and set + WriteMaybeNull(in_position); + } else if (!is_valid) { + // Filter slot is null, so we have a null in the output + bit_util::ClearBit(out_is_valid_, out_position_); + WriteNull(); + } + ++in_position; + } + } + } + } // !filter_block.AllSet() + } // while(in_position < values_length_) + } + + // Write the next out_position given the selected in_position for the input + // data and advance out_position + void WriteValue(int64_t in_position) { + if constexpr (kIsBoolean) { + bit_util::SetBitTo(out_data_, out_position_, + bit_util::GetBit(values_data_, values_offset_ + in_position)); + } else { + memcpy(out_data_ + out_position_ * byte_width(), + values_data_ + 
in_position * byte_width(), byte_width()); + } + ++out_position_; + } + + void WriteValueSegment(int64_t in_start, int64_t length) { + if constexpr (kIsBoolean) { + CopyBitmap(values_data_, values_offset_ + in_start, length, out_data_, + out_position_); + } else { + memcpy(out_data_ + out_position_ * byte_width(), + values_data_ + in_start * byte_width(), length * byte_width()); + } + out_position_ += length; + } + + void WriteNull() { + if constexpr (kIsBoolean) { + // Zero the bit + bit_util::ClearBit(out_data_, out_position_); + } else { + // Zero the memory + memset(out_data_ + out_position_ * byte_width(), 0, byte_width()); + } + ++out_position_; + } + + void WriteNullSegment(int64_t length) { + if constexpr (kIsBoolean) { + // Zero the bits + bit_util::SetBitsTo(out_data_, out_position_, length, false); + } else { + // Zero the memory + memset(out_data_ + out_position_ * byte_width(), 0, length * byte_width()); + } + out_position_ += length; + } + + constexpr int64_t byte_width() const { + if constexpr (kByteWidth >= 0) { + return kByteWidth; + } else { + return byte_width_; + } + } + + private: + int64_t byte_width_; + const uint8_t* values_is_valid_; + const uint8_t* values_data_; + int64_t values_null_count_; + int64_t values_offset_; + int64_t values_length_; + const ArraySpan& filter_; + FilterOptions::NullSelectionBehavior null_selection_; + uint8_t* out_is_valid_ = NULLPTR; + uint8_t* out_data_; + int64_t out_length_; + int64_t out_position_; +}; + +} // namespace + +Status PrimitiveFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + const ArraySpan& values = batch[0].array; + const ArraySpan& filter = batch[1].array; + const bool is_ree_filter = filter.type->id() == Type::RUN_END_ENCODED; + FilterOptions::NullSelectionBehavior null_selection = + FilterState::Get(ctx).null_selection_behavior; + + int64_t output_length = GetFilterOutputSize(filter, null_selection); + + ArrayData* out_arr = out->array_data().get(); + + const bool 
filter_null_count_is_zero = + is_ree_filter ? filter.child_data[1].null_count == 0 : filter.null_count == 0; + + // The output precomputed null count is unknown except in the narrow + // condition that all the values are non-null and the filter will not cause + // any new nulls to be created. + if (values.null_count == 0 && + (null_selection == FilterOptions::DROP || filter_null_count_is_zero)) { + out_arr->null_count = 0; + } else { + out_arr->null_count = kUnknownNullCount; + } + + // When neither the values nor filter is known to have any nulls, we will + // elect the optimized non-null path where there is no need to populate a + // validity bitmap. + const bool allocate_validity = values.null_count != 0 || !filter_null_count_is_zero; + + DCHECK(util::IsFixedWidthLike(values)); + const int64_t bit_width = util::FixedWidthInBits(*values.type); + RETURN_NOT_OK(util::internal::PreallocateFixedWidthArrayData( + ctx, output_length, /*source=*/values, allocate_validity, out_arr)); + + switch (bit_width) { + case 1: + PrimitiveFilterImpl<1, /*kIsBoolean=*/true>(values, filter, null_selection, out_arr) + .Exec(); + break; + case 8: + PrimitiveFilterImpl<1>(values, filter, null_selection, out_arr).Exec(); + break; + case 16: + PrimitiveFilterImpl<2>(values, filter, null_selection, out_arr).Exec(); + break; + case 32: + PrimitiveFilterImpl<4>(values, filter, null_selection, out_arr).Exec(); + break; + case 64: + PrimitiveFilterImpl<8>(values, filter, null_selection, out_arr).Exec(); + break; + case 128: + // For INTERVAL_MONTH_DAY_NANO, DECIMAL128 + PrimitiveFilterImpl<16>(values, filter, null_selection, out_arr).Exec(); + break; + case 256: + // For DECIMAL256 + PrimitiveFilterImpl<32>(values, filter, null_selection, out_arr).Exec(); + break; + default: + // Non-specializing on byte width + PrimitiveFilterImpl<-1>(values, filter, null_selection, out_arr).Exec(); + break; + } + return Status::OK(); +} + +namespace { + +// 
---------------------------------------------------------------------- +// Optimized filter for base binary types (32-bit and 64-bit) + +// Declares raw_offsets, raw_data, offset_builder, data_builder, space_available and +// offset in the expanding scope. Reads `values`, `ctx`, `output_length` and +// `offset_type` from that scope and uses RETURN_NOT_OK; both functions that expand +// it in this file return Status. +#define BINARY_FILTER_SETUP_COMMON() \ + const auto raw_offsets = values.GetValues(1); \ + const uint8_t* raw_data = values.buffers[2].data; \ + \ + TypedBufferBuilder offset_builder(ctx->memory_pool()); \ + TypedBufferBuilder data_builder(ctx->memory_pool()); \ + RETURN_NOT_OK(offset_builder.Reserve(output_length + 1)); \ + \ + /* Presize the data builder with a rough estimate */ \ + if (values.length > 0) { \ + const double mean_value_length = (raw_offsets[values.length] - raw_offsets[0]) / \ + static_cast(values.length); \ + RETURN_NOT_OK( \ + data_builder.Reserve(static_cast(mean_value_length * output_length))); \ + } \ + int64_t space_available = data_builder.capacity(); \ + offset_type offset = 0; + +// Appends NBYTES bytes starting at DATA to data_builder, reserving more room first +// when NBYTES exceeds the tracked space_available. +// NOTE(review): expands to several statements that are not wrapped in +// do { } while (0), and NBYTES is evaluated more than once; use it only as a +// standalone statement with a side-effect-free NBYTES. +#define APPEND_RAW_DATA(DATA, NBYTES) \ + if (ARROW_PREDICT_FALSE(NBYTES > space_available)) { \ + RETURN_NOT_OK(data_builder.Reserve(NBYTES)); \ + space_available = data_builder.capacity() - data_builder.length(); \ + } \ + data_builder.UnsafeAppend(DATA, NBYTES); \ + space_available -= NBYTES + +// Appends the bytes of the value at `in_position` to data_builder and advances +// `offset` by that value's size. +#define APPEND_SINGLE_VALUE() \ + do { \ + offset_type val_size = raw_offsets[in_position + 1] - raw_offsets[in_position]; \ + APPEND_RAW_DATA(raw_data + raw_offsets[in_position], val_size); \ + offset += val_size; \ + } while (0) + +// Optimized binary filter for the case where neither values nor filter have +// nulls +template +Status BinaryFilterNonNullImpl(KernelContext* ctx, const ArraySpan& values, + const ArraySpan& filter, int64_t output_length, + FilterOptions::NullSelectionBehavior null_selection, + ArrayData* out) { + using offset_type = typename ArrowType::offset_type; + const bool is_ree_filter = filter.type->id() == Type::RUN_END_ENCODED; + + BINARY_FILTER_SETUP_COMMON(); + + auto emit_segment = [&](int64_t position, int64_t length) { + // Bulk-append raw data + const offset_type run_data_bytes = + (raw_offsets[position + 
length] - raw_offsets[position]); + APPEND_RAW_DATA(raw_data + raw_offsets[position], run_data_bytes); + // Append offsets + for (int64_t i = 0; i < length; ++i) { + offset_builder.UnsafeAppend(offset); + offset += raw_offsets[i + position + 1] - raw_offsets[i + position]; + } + return Status::OK(); + }; + if (is_ree_filter) { + Status status; + VisitPlainxREEFilterOutputSegments( + filter, /*filter_may_have_nulls=*/false, null_selection, + [&status, emit_segment = std::move(emit_segment)]( + int64_t position, int64_t segment_length, bool filter_valid) { + DCHECK(filter_valid); + status = emit_segment(position, segment_length); + return status.ok(); + }); + RETURN_NOT_OK(std::move(status)); + } else { + const auto filter_data = filter.buffers[1].data; + RETURN_NOT_OK(arrow::internal::VisitSetBitRuns( + filter_data, filter.offset, filter.length, std::move(emit_segment))); + } + + offset_builder.UnsafeAppend(offset); + out->length = output_length; + RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1])); + return data_builder.Finish(&out->buffers[2]); +} + +template +Status BinaryFilterImpl(KernelContext* ctx, const ArraySpan& values, + const ArraySpan& filter, int64_t output_length, + FilterOptions::NullSelectionBehavior null_selection, + ArrayData* out) { + using offset_type = typename ArrowType::offset_type; + + const bool is_ree_filter = filter.type->id() == Type::RUN_END_ENCODED; + + BINARY_FILTER_SETUP_COMMON(); + + const uint8_t* values_is_valid = values.buffers[0].data; + const int64_t values_offset = values.offset; + + const int64_t out_offset = out->offset; + uint8_t* out_is_valid = out->buffers[0]->mutable_data(); + // Zero bits and then only have to set valid values to true + bit_util::SetBitsTo(out_is_valid, out_offset, output_length, false); + + int64_t in_position = 0; + int64_t out_position = 0; + if (is_ree_filter) { + auto emit_segment = [&](int64_t position, int64_t segment_length, bool filter_valid) { + in_position = position; + if (filter_valid) 
{ + // Filter values are all true and not null + // Some of the values in the block may be null + for (int64_t i = 0; i < segment_length; ++i, ++in_position, ++out_position) { + offset_builder.UnsafeAppend(offset); + if (bit_util::GetBit(values_is_valid, values_offset + in_position)) { + bit_util::SetBit(out_is_valid, out_offset + out_position); + APPEND_SINGLE_VALUE(); + } + } + } else { + offset_builder.UnsafeAppend(segment_length, offset); + out_position += segment_length; + } + return Status::OK(); + }; + Status status; + VisitPlainxREEFilterOutputSegments( + filter, /*filter_may_have_nulls=*/true, null_selection, + [&status, emit_segment = std::move(emit_segment)]( + int64_t position, int64_t segment_length, bool filter_valid) { + status = emit_segment(position, segment_length, filter_valid); + return status.ok(); + }); + RETURN_NOT_OK(std::move(status)); + } else { + const auto filter_data = filter.buffers[1].data; + const uint8_t* filter_is_valid = filter.buffers[0].data; + const int64_t filter_offset = filter.offset; + + // We use 3 block counters for fast scanning of the filter + // + // * values_valid_counter: for values null/not-null + // * filter_valid_counter: for filter null/not-null + // * filter_counter: for filter true/false + OptionalBitBlockCounter values_valid_counter(values_is_valid, values_offset, + values.length); + OptionalBitBlockCounter filter_valid_counter(filter_is_valid, filter_offset, + filter.length); + BitBlockCounter filter_counter(filter_data, filter_offset, filter.length); + + while (in_position < filter.length) { + BitBlockCount filter_valid_block = filter_valid_counter.NextWord(); + BitBlockCount values_valid_block = values_valid_counter.NextWord(); + BitBlockCount filter_block = filter_counter.NextWord(); + if (filter_block.NoneSet() && null_selection == FilterOptions::DROP) { + // For this exceedingly common case in low-selectivity filters we can + // skip further analysis of the data and move on to the next block. 
+ in_position += filter_block.length; + } else if (filter_valid_block.AllSet()) { + // Simpler path: no filter values are null + if (filter_block.AllSet()) { + // Fastest path: filter values are all true and not null + if (values_valid_block.AllSet()) { + // The values aren't null either + bit_util::SetBitsTo(out_is_valid, out_offset + out_position, + filter_block.length, true); + + // Bulk-append raw data + offset_type block_data_bytes = + (raw_offsets[in_position + filter_block.length] - + raw_offsets[in_position]); + APPEND_RAW_DATA(raw_data + raw_offsets[in_position], block_data_bytes); + // Append offsets + for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { + offset_builder.UnsafeAppend(offset); + offset += raw_offsets[in_position + 1] - raw_offsets[in_position]; + } + out_position += filter_block.length; + } else { + // Some of the values in this block are null + for (int64_t i = 0; i < filter_block.length; + ++i, ++in_position, ++out_position) { + offset_builder.UnsafeAppend(offset); + if (bit_util::GetBit(values_is_valid, values_offset + in_position)) { + bit_util::SetBit(out_is_valid, out_offset + out_position); + APPEND_SINGLE_VALUE(); + } + } + } + } else { // !filter_block.AllSet() + // Some of the filter values are false, but all not null + if (values_valid_block.AllSet()) { + // All the values are not-null, so we can skip null checking for + // them + for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { + if (bit_util::GetBit(filter_data, filter_offset + in_position)) { + offset_builder.UnsafeAppend(offset); + bit_util::SetBit(out_is_valid, out_offset + out_position++); + APPEND_SINGLE_VALUE(); + } + } + } else { + // Some of the values in the block are null, so we have to check + // each one + for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { + if (bit_util::GetBit(filter_data, filter_offset + in_position)) { + offset_builder.UnsafeAppend(offset); + if (bit_util::GetBit(values_is_valid, values_offset + 
in_position)) { + bit_util::SetBit(out_is_valid, out_offset + out_position); + APPEND_SINGLE_VALUE(); + } + ++out_position; + } + } + } + } + } else { // !filter_valid_block.AllSet() + // Some of the filter values are null, so we have to handle the DROP + // versus EMIT_NULL null selection behavior. + if (null_selection == FilterOptions::DROP) { + // Filter null values are treated as false. + if (values_valid_block.AllSet()) { + for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { + if (bit_util::GetBit(filter_is_valid, filter_offset + in_position) && + bit_util::GetBit(filter_data, filter_offset + in_position)) { + offset_builder.UnsafeAppend(offset); + bit_util::SetBit(out_is_valid, out_offset + out_position++); + APPEND_SINGLE_VALUE(); + } + } + } else { + for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { + if (bit_util::GetBit(filter_is_valid, filter_offset + in_position) && + bit_util::GetBit(filter_data, filter_offset + in_position)) { + offset_builder.UnsafeAppend(offset); + if (bit_util::GetBit(values_is_valid, values_offset + in_position)) { + bit_util::SetBit(out_is_valid, out_offset + out_position); + APPEND_SINGLE_VALUE(); + } + ++out_position; + } + } + } + } else { + // EMIT_NULL + + // Filter null values are appended to output as null whether the + // value in the corresponding slot is valid or not + if (values_valid_block.AllSet()) { + for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { + const bool filter_not_null = + bit_util::GetBit(filter_is_valid, filter_offset + in_position); + if (filter_not_null && + bit_util::GetBit(filter_data, filter_offset + in_position)) { + offset_builder.UnsafeAppend(offset); + bit_util::SetBit(out_is_valid, out_offset + out_position++); + APPEND_SINGLE_VALUE(); + } else if (!filter_not_null) { + offset_builder.UnsafeAppend(offset); + ++out_position; + } + } + } else { + for (int64_t i = 0; i < filter_block.length; ++i, ++in_position) { + const bool filter_not_null = + 
bit_util::GetBit(filter_is_valid, filter_offset + in_position); + if (filter_not_null && + bit_util::GetBit(filter_data, filter_offset + in_position)) { + offset_builder.UnsafeAppend(offset); + if (bit_util::GetBit(values_is_valid, values_offset + in_position)) { + bit_util::SetBit(out_is_valid, out_offset + out_position); + APPEND_SINGLE_VALUE(); + } + ++out_position; + } else if (!filter_not_null) { + offset_builder.UnsafeAppend(offset); + ++out_position; + } + } + } + } + } + } + } + offset_builder.UnsafeAppend(offset); + out->length = output_length; + RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1])); + return data_builder.Finish(&out->buffers[2]); +} + +#undef BINARY_FILTER_SETUP_COMMON +#undef APPEND_RAW_DATA +#undef APPEND_SINGLE_VALUE + +Status BinaryFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + FilterOptions::NullSelectionBehavior null_selection = + FilterState::Get(ctx).null_selection_behavior; + + const ArraySpan& values = batch[0].array; + const ArraySpan& filter = batch[1].array; + const bool is_ree_filter = filter.type->id() == Type::RUN_END_ENCODED; + int64_t output_length = GetFilterOutputSize(filter, null_selection); + + ArrayData* out_arr = out->array_data().get(); + + const bool filter_null_count_is_zero = + is_ree_filter ? filter.child_data[1].null_count == 0 : filter.null_count == 0; + + // The output precomputed null count is unknown except in the narrow + // condition that all the values are non-null and the filter will not cause + // any new nulls to be created. 
+ if (values.null_count == 0 && + (null_selection == FilterOptions::DROP || filter_null_count_is_zero)) { + out_arr->null_count = 0; + } else { + out_arr->null_count = kUnknownNullCount; + } + Type::type type_id = values.type->id(); + if (values.null_count == 0 && filter_null_count_is_zero) { + // Faster no-nulls case + if (is_binary_like(type_id)) { + RETURN_NOT_OK(BinaryFilterNonNullImpl( + ctx, values, filter, output_length, null_selection, out_arr)); + } else if (is_large_binary_like(type_id)) { + RETURN_NOT_OK(BinaryFilterNonNullImpl( + ctx, values, filter, output_length, null_selection, out_arr)); + } else { + DCHECK(false); + } + } else { + // Output may have nulls + RETURN_NOT_OK(ctx->AllocateBitmap(output_length).Value(&out_arr->buffers[0])); + if (is_binary_like(type_id)) { + RETURN_NOT_OK(BinaryFilterImpl(ctx, values, filter, output_length, + null_selection, out_arr)); + } else if (is_large_binary_like(type_id)) { + RETURN_NOT_OK(BinaryFilterImpl(ctx, values, filter, output_length, + null_selection, out_arr)); + } else { + DCHECK(false); + } + } + + return Status::OK(); +} + +// ---------------------------------------------------------------------- +// Null filter + +Status NullFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + int64_t output_length = + GetFilterOutputSize(batch[1].array, FilterState::Get(ctx).null_selection_behavior); + out->value = std::make_shared(output_length)->data(); + return Status::OK(); +} + +// ---------------------------------------------------------------------- +// Dictionary filter + +Status DictionaryFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + DictionaryArray dict_values(batch[0].array.ToArrayData()); + Datum result; + RETURN_NOT_OK(Filter(Datum(dict_values.indices()), batch[1].array.ToArrayData(), + FilterState::Get(ctx), ctx->exec_context()) + .Value(&result)); + DictionaryArray filtered_values(dict_values.type(), result.make_array(), + dict_values.dictionary()); + 
out->value = filtered_values.data(); + return Status::OK(); +} + +// ---------------------------------------------------------------------- +// Extension filter + +Status ExtensionFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + ExtensionArray ext_values(batch[0].array.ToArrayData()); + Datum result; + RETURN_NOT_OK(Filter(Datum(ext_values.storage()), batch[1].array.ToArrayData(), + FilterState::Get(ctx), ctx->exec_context()) + .Value(&result)); + ExtensionArray filtered_values(ext_values.type(), result.make_array()); + out->value = filtered_values.data(); + return Status::OK(); +} + +// Transform filter to selection indices and then use Take. +Status FilterWithTakeExec(const ArrayKernelExec& take_exec, KernelContext* ctx, + const ExecSpan& batch, ExecResult* out) { + std::shared_ptr indices; + RETURN_NOT_OK(GetTakeIndices(batch[1].array, + FilterState::Get(ctx).null_selection_behavior, + ctx->memory_pool()) + .Value(&indices)); + KernelContext take_ctx(*ctx); + TakeState state{TakeOptions::NoBoundsCheck()}; + take_ctx.SetState(&state); + ExecSpan take_batch({batch[0], ArraySpan(*indices)}, batch.length); + return take_exec(&take_ctx, take_batch, out); +} + +// Due to the special treatment with their Take kernels, we filter Struct and SparseUnion +// arrays by transforming filter to selection indices and call Take. 
+Status StructFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + return FilterWithTakeExec(StructTakeExec, ctx, batch, out); +} + +Status SparseUnionFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + return FilterWithTakeExec(SparseUnionTakeExec, ctx, batch, out); +} + +// ---------------------------------------------------------------------- +// Implement Filter metafunction + +Result> FilterRecordBatch(const RecordBatch& batch, + const Datum& filter, + const FunctionOptions* options, + ExecContext* ctx) { + if (batch.num_rows() != filter.length()) { + return Status::Invalid("Filter inputs must all be the same length"); + } + + // Fetch filter + const auto& filter_opts = *static_cast(options); + ArrayData filter_array; + switch (filter.kind()) { + case Datum::ARRAY: + filter_array = *filter.array(); + break; + case Datum::CHUNKED_ARRAY: { + ARROW_ASSIGN_OR_RAISE(auto combined, Concatenate(filter.chunked_array()->chunks())); + filter_array = *combined->data(); + break; + } + default: + return Status::TypeError("Filter should be array-like"); + } + + // Convert filter to selection vector/indices and use Take + ARROW_ASSIGN_OR_RAISE(std::shared_ptr indices, + GetTakeIndices(filter_array, filter_opts.null_selection_behavior, + ctx->memory_pool())); + std::vector> columns(batch.num_columns()); + for (int i = 0; i < batch.num_columns(); ++i) { + ARROW_ASSIGN_OR_RAISE(Datum out, Take(batch.column(i)->data(), Datum(indices), + TakeOptions::NoBoundsCheck(), ctx)); + columns[i] = out.make_array(); + } + return RecordBatch::Make(batch.schema(), indices->length, std::move(columns)); +} + +Result> FilterTable(const Table& table, const Datum& filter, + const FunctionOptions* options, + ExecContext* ctx) { + if (table.num_rows() != filter.length()) { + return Status::Invalid("Filter inputs must all be the same length"); + } + if (table.num_rows() == 0) { + return Table::Make(table.schema(), table.columns(), 0); + } + + // Last 
input element will be the filter array + const int num_columns = table.num_columns(); + std::vector inputs(num_columns + 1); + + // Fetch table columns + for (int i = 0; i < num_columns; ++i) { + inputs[i] = table.column(i)->chunks(); + } + // Fetch filter + const auto& filter_opts = *static_cast(options); + switch (filter.kind()) { + case Datum::ARRAY: + inputs.back().push_back(filter.make_array()); + break; + case Datum::CHUNKED_ARRAY: + inputs.back() = filter.chunked_array()->chunks(); + break; + default: + return Status::TypeError("Filter should be array-like"); + } + + // Rechunk inputs to allow consistent iteration over their respective chunks + inputs = arrow::internal::RechunkArraysConsistently(inputs); + + // Instead of filtering each column with the boolean filter + // (which would be slow if the table has a large number of columns: ARROW-10569), + // convert each filter chunk to indices, and take() the column. + const int64_t num_chunks = static_cast(inputs.back().size()); + std::vector out_columns(num_columns); + int64_t out_num_rows = 0; + + for (int64_t i = 0; i < num_chunks; ++i) { + const ArrayData& filter_chunk = *inputs.back()[i]->data(); + ARROW_ASSIGN_OR_RAISE( + const auto indices, + GetTakeIndices(filter_chunk, filter_opts.null_selection_behavior, + ctx->memory_pool())); + + // Filter chunks that select nothing contribute no output chunk + if (indices->length > 0) { + // Take from all input columns + // NOTE(review): `indices` is const, so std::move() below copies instead of + // moving; the later read of indices->length depends on that. + Datum indices_datum{std::move(indices)}; + for (int col = 0; col < num_columns; ++col) { + const auto& column_chunk = inputs[col][i]; + ARROW_ASSIGN_OR_RAISE(Datum out, Take(column_chunk, indices_datum, + TakeOptions::NoBoundsCheck(), ctx)); + out_columns[col].push_back(std::move(out).make_array()); + } + out_num_rows += indices->length; + } + } + + ChunkedArrayVector out_chunks(num_columns); + for (int i = 0; i < num_columns; ++i) { + out_chunks[i] = std::make_shared(std::move(out_columns[i]), + table.column(i)->type()); + } + return Table::Make(table.schema(), std::move(out_chunks), out_num_rows); +} + +const 
FunctionDoc filter_doc( + "Filter with a boolean selection filter", + ("The output is populated with values from the input at positions\n" + "where the selection filter is non-zero. Nulls in the selection filter\n" + "are handled based on FilterOptions."), + {"input", "selection_filter"}, "FilterOptions"); + +class FilterMetaFunction : public MetaFunction { + public: + FilterMetaFunction() + : MetaFunction("filter", Arity::Binary(), filter_doc, GetDefaultFilterOptions()) {} + + Result ExecuteImpl(const std::vector& args, + const FunctionOptions* options, + ExecContext* ctx) const override { + if (args[1].kind() != Datum::ARRAY && args[1].kind() != Datum::CHUNKED_ARRAY) { + return Status::TypeError("Filter should be array-like"); + } + + const auto& filter_type = *args[1].type(); + const bool filter_is_plain_bool = filter_type.id() == Type::BOOL; + const bool filter_is_ree_bool = + filter_type.id() == Type::RUN_END_ENCODED && + checked_cast(filter_type).value_type()->id() == + Type::BOOL; + if (!filter_is_plain_bool && !filter_is_ree_bool) { + return Status::NotImplemented("Filter argument must be boolean type"); + } + + if (args[0].kind() == Datum::RECORD_BATCH) { + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr out_batch, + FilterRecordBatch(*args[0].record_batch(), args[1], options, ctx)); + return Datum(out_batch); + } else if (args[0].kind() == Datum::TABLE) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_table, + FilterTable(*args[0].table(), args[1], options, ctx)); + return Datum(out_table); + } else { + return CallFunction("array_filter", args, options, ctx); + } + } +}; + +// ---------------------------------------------------------------------- + +} // namespace + +// NOTE(review): GetDefaultFilterOptions() is left commented out below, yet the +// FilterMetaFunction constructor above still calls it; presumably it is declared +// elsewhere -- TODO confirm. +// const FilterOptions* GetDefaultFilterOptions() { +// static const auto kDefaultFilterOptions = FilterOptions::Defaults(); +// return &kDefaultFilterOptions; +// } + +// NOTE(review): the metafunction defined above is still named "filter" and forwards +// plain arrays to "array_filter"; confirm these are the intended names for a +// scatter-by-mask function. +std::unique_ptr MakeScatterByMaskMetaFunction() { + return std::make_unique(); +} + +void PopulateScatterByMaskKernels(std::vector* 
out) { + auto plain_filter = InputType(Type::BOOL); + auto ree_filter = InputType(match::RunEndEncoded(Type::BOOL)); + + *out = { + // * x Boolean + {InputType(match::Primitive()), plain_filter, PrimitiveFilterExec}, + {InputType(match::BinaryLike()), plain_filter, BinaryFilterExec}, + {InputType(match::LargeBinaryLike()), plain_filter, BinaryFilterExec}, + {InputType(null()), plain_filter, NullFilterExec}, + {InputType(Type::FIXED_SIZE_BINARY), plain_filter, PrimitiveFilterExec}, + {InputType(Type::DECIMAL128), plain_filter, PrimitiveFilterExec}, + {InputType(Type::DECIMAL256), plain_filter, PrimitiveFilterExec}, + {InputType(Type::DICTIONARY), plain_filter, DictionaryFilterExec}, + {InputType(Type::EXTENSION), plain_filter, ExtensionFilterExec}, + {InputType(Type::LIST), plain_filter, ListFilterExec}, + {InputType(Type::LARGE_LIST), plain_filter, LargeListFilterExec}, + {InputType(Type::FIXED_SIZE_LIST), plain_filter, FSLFilterExec}, + {InputType(Type::DENSE_UNION), plain_filter, DenseUnionFilterExec}, + {InputType(Type::SPARSE_UNION), plain_filter, SparseUnionFilterExec}, + {InputType(Type::STRUCT), plain_filter, StructFilterExec}, + {InputType(Type::MAP), plain_filter, MapFilterExec}, + + // * x REE(Boolean) + {InputType(match::Primitive()), ree_filter, PrimitiveFilterExec}, + {InputType(match::BinaryLike()), ree_filter, BinaryFilterExec}, + {InputType(match::LargeBinaryLike()), ree_filter, BinaryFilterExec}, + {InputType(null()), ree_filter, NullFilterExec}, + {InputType(Type::FIXED_SIZE_BINARY), ree_filter, PrimitiveFilterExec}, + {InputType(Type::DECIMAL128), ree_filter, PrimitiveFilterExec}, + {InputType(Type::DECIMAL256), ree_filter, PrimitiveFilterExec}, + {InputType(Type::DICTIONARY), ree_filter, DictionaryFilterExec}, + {InputType(Type::EXTENSION), ree_filter, ExtensionFilterExec}, + {InputType(Type::LIST), ree_filter, ListFilterExec}, + {InputType(Type::LARGE_LIST), ree_filter, LargeListFilterExec}, + {InputType(Type::FIXED_SIZE_LIST), ree_filter, 
FSLFilterExec}, + {InputType(Type::DENSE_UNION), ree_filter, DenseUnionFilterExec}, + {InputType(Type::SPARSE_UNION), ree_filter, SparseUnionFilterExec}, + {InputType(Type::STRUCT), ree_filter, StructFilterExec}, + {InputType(Type::MAP), ree_filter, MapFilterExec}, + }; +} + +} // namespace compute::internal + +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_scatter_by_mask_internal.h b/cpp/src/arrow/compute/kernels/vector_scatter_by_mask_internal.h new file mode 100644 index 0000000000000..ec1e3f849a7e0 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_scatter_by_mask_internal.h @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+#pragma once
+
+#include
+#include
+// NOTE(review): header names and template arguments look stripped in this text - confirm.
+#include "arrow/array/data.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/kernels/vector_scatter_internal.h"
+
+namespace arrow {
+namespace compute {
+namespace internal {
+/// \brief Create the scatter-by-mask meta function; the caller owns the returned object.
+std::unique_ptr MakeScatterByMaskMetaFunction();
+/// \brief Fill `*out` with the scatter-by-mask kernel entries.
+void PopulateScatterByMaskKernels(std::vector* out);
+
+}  // namespace internal
+}  // namespace compute
+}  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/vector_scatter_internal.cc b/cpp/src/arrow/compute/kernels/vector_scatter_internal.cc
new file mode 100644
index 0000000000000..2e95132c8735e
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/vector_scatter_internal.cc
@@ -0,0 +1,888 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include
+#include
+#include
+#include
+#include
+// NOTE(review): header names and template arguments look stripped in this text - confirm.
+#include "arrow/array/array_binary.h"
+#include "arrow/array/array_nested.h"
+#include "arrow/array/builder_primitive.h"
+#include "arrow/buffer_builder.h"
+#include "arrow/chunked_array.h"
+#include "arrow/compute/api_vector.h"
+#include "arrow/compute/function.h"
+#include "arrow/compute/kernel.h"
+#include "arrow/compute/kernels/codegen_internal.h"
+#include "arrow/compute/kernels/vector_scatter_internal.h"
+#include "arrow/compute/registry.h"
+#include "arrow/type.h"
+#include "arrow/type_traits.h"
+#include "arrow/util/bit_block_counter.h"
+#include "arrow/util/bit_run_reader.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/fixed_width_internal.h"
+#include "arrow/util/int_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/ree_util.h"
+
+namespace arrow {
+
+using internal::CheckIndexBounds;
+
+namespace compute::internal {
+// Creates binary function `name`, adds one kernel per entry of `kernels`, then registers it.
+void RegisterScatterFunction(const std::string& name, FunctionDoc doc,
+                             VectorKernel base_kernel,  // by value: reused for every kernel
+                             std::vector&& kernels,  // consumed: moved from, then cleared
+                             const FunctionOptions* default_options,
+                             FunctionRegistry* registry) {
+  auto func = std::make_shared(name, Arity::Binary(), std::move(doc),
+                               default_options);
+  for (auto&& kernel_data : kernels) {
+    // Only the signature and exec differ per kernel; the rest comes from base_kernel.
+    base_kernel.signature = KernelSignature::Make(
+        {std::move(kernel_data.value_type), std::move(kernel_data.selection_type)},
+        OutputType(FirstType));
+    base_kernel.exec = kernel_data.exec;
+    DCHECK_OK(func->AddKernel(base_kernel));
+  }
+  kernels.clear();  // the input types were moved out of every element above
+  DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+// namespace {
+
+// /// \brief Iterate over a REE filter, emitting ranges of a plain values array that
+// /// would pass the filter.
+// ///
+// /// Differently from REExREE, and REExPlain filtering, PlainxREE filtering
+// /// does not produce a REE output, but rather a plain output array. As such it's
+// /// much simpler.
+// ///
+// /// \param filter_may_have_nulls Only pass false if you know the filter has no nulls.
+// template +// void VisitPlainxREEFilterOutputSegmentsImpl( +// const ArraySpan& filter, bool filter_may_have_nulls, +// FilterOptions::NullSelectionBehavior null_selection, +// const EmitREEFilterSegment& emit_segment) { +// using FilterRunEndCType = typename FilterRunEndType::c_type; +// const ArraySpan& filter_values = arrow::ree_util::ValuesArray(filter); +// const int64_t filter_values_offset = filter_values.offset; +// const uint8_t* filter_is_valid = filter_values.buffers[0].data; +// const uint8_t* filter_selection = filter_values.buffers[1].data; +// filter_may_have_nulls = filter_may_have_nulls && filter_is_valid != nullptr && +// filter_values.null_count != 0; + +// const arrow::ree_util::RunEndEncodedArraySpan filter_span(filter); +// auto it = filter_span.begin(); +// if (filter_may_have_nulls) { +// if (null_selection == FilterOptions::EMIT_NULL) { +// while (!it.is_end(filter_span)) { +// const int64_t i = filter_values_offset + it.index_into_array(); +// const bool valid = bit_util::GetBit(filter_is_valid, i); +// const bool emit = !valid || bit_util::GetBit(filter_selection, i); +// if (ARROW_PREDICT_FALSE( +// emit && !emit_segment(it.logical_position(), it.run_length(), valid))) { +// break; +// } +// ++it; +// } +// } else { // DROP nulls +// while (!it.is_end(filter_span)) { +// const int64_t i = filter_values_offset + it.index_into_array(); +// const bool emit = +// bit_util::GetBit(filter_is_valid, i) && bit_util::GetBit(filter_selection, i); +// if (ARROW_PREDICT_FALSE( +// emit && !emit_segment(it.logical_position(), it.run_length(), true))) { +// break; +// } +// ++it; +// } +// } +// } else { +// while (!it.is_end(filter_span)) { +// const int64_t i = filter_values_offset + it.index_into_array(); +// const bool emit = bit_util::GetBit(filter_selection, i); +// if (ARROW_PREDICT_FALSE( +// emit && !emit_segment(it.logical_position(), it.run_length(), true))) { +// break; +// } +// ++it; +// } +// } +// } + +// } // namespace + +// void 
VisitPlainxREEFilterOutputSegments( +// const ArraySpan& filter, bool filter_may_have_nulls, +// FilterOptions::NullSelectionBehavior null_selection, +// const EmitREEFilterSegment& emit_segment) { +// if (filter.length == 0) { +// return; +// } +// const auto& ree_type = checked_cast(*filter.type); +// switch (ree_type.run_end_type()->id()) { +// case Type::INT16: +// return VisitPlainxREEFilterOutputSegmentsImpl( +// filter, filter_may_have_nulls, null_selection, emit_segment); +// case Type::INT32: +// return VisitPlainxREEFilterOutputSegmentsImpl( +// filter, filter_may_have_nulls, null_selection, emit_segment); +// default: +// DCHECK(ree_type.run_end_type()->id() == Type::INT64); +// return VisitPlainxREEFilterOutputSegmentsImpl( +// filter, filter_may_have_nulls, null_selection, emit_segment); +// } +// } + +// namespace { + +// // ---------------------------------------------------------------------- +// // Implement take for other data types where there is less performance +// // sensitivity by visiting the selected indices. + +// // Use CRTP to dispatch to type-specific processing of take indices for each +// // unsigned integer type. 
+// template +// struct Selection { +// using ValuesArrayType = typename TypeTraits::ArrayType; + +// // Forwards the generic value visitors to the VisitFilter template +// struct FilterAdapter { +// static constexpr bool is_take = false; + +// Impl* impl; +// explicit FilterAdapter(Impl* impl) : impl(impl) {} +// template +// Status Generate(ValidVisitor&& visit_valid, NullVisitor&& visit_null) { +// return impl->VisitFilter(std::forward(visit_valid), +// std::forward(visit_null)); +// } +// }; + +// // Forwards the generic value visitors to the take index visitor template +// template +// struct TakeAdapter { +// static constexpr bool is_take = true; + +// Impl* impl; +// explicit TakeAdapter(Impl* impl) : impl(impl) {} +// template +// Status Generate(ValidVisitor&& visit_valid, NullVisitor&& visit_null) { +// return impl->template VisitTake(std::forward(visit_valid), +// std::forward(visit_null)); +// } +// }; + +// KernelContext* ctx; +// const ArraySpan& values; +// const ArraySpan& selection; +// int64_t output_length; +// ArrayData* out; +// TypedBufferBuilder validity_builder; + +// Selection(KernelContext* ctx, const ExecSpan& batch, int64_t output_length, +// ExecResult* out) +// : ctx(ctx), +// values(batch[0].array), +// selection(batch[1].array), +// output_length(output_length), +// out(out->array_data().get()), +// validity_builder(ctx->memory_pool()) {} + +// virtual ~Selection() = default; + +// Status FinishCommon() { +// out->buffers.resize(values.num_buffers()); +// out->length = validity_builder.length(); +// out->null_count = validity_builder.false_count(); +// return validity_builder.Finish(&out->buffers[0]); +// } + +// template +// Status VisitTake(ValidVisitor&& visit_valid, NullVisitor&& visit_null) { +// const auto indices_values = selection.GetValues(1); +// const uint8_t* is_valid = selection.buffers[0].data; +// arrow::internal::OptionalBitIndexer indices_is_valid(is_valid, selection.offset); +// arrow::internal::OptionalBitIndexer 
values_is_valid(values.buffers[0].data, +// values.offset); + +// const bool values_have_nulls = values.MayHaveNulls(); +// arrow::internal::OptionalBitBlockCounter bit_counter(is_valid, selection.offset, +// selection.length); +// int64_t position = 0; +// while (position < selection.length) { +// BitBlockCount block = bit_counter.NextBlock(); +// const bool indices_have_nulls = block.popcount < block.length; +// if (!indices_have_nulls && !values_have_nulls) { +// // Fastest path, neither indices nor values have nulls +// validity_builder.UnsafeAppend(block.length, true); +// for (int64_t i = 0; i < block.length; ++i) { +// RETURN_NOT_OK(visit_valid(indices_values[position++])); +// } +// } else if (block.popcount > 0) { +// // Since we have to branch on whether the indices are null or not, we +// // combine the "non-null indices block but some values null" and +// // "some-null indices block but values non-null" into a single loop. +// for (int64_t i = 0; i < block.length; ++i) { +// if ((!indices_have_nulls || indices_is_valid[position]) && +// values_is_valid[indices_values[position]]) { +// validity_builder.UnsafeAppend(true); +// RETURN_NOT_OK(visit_valid(indices_values[position])); +// } else { +// validity_builder.UnsafeAppend(false); +// RETURN_NOT_OK(visit_null()); +// } +// ++position; +// } +// } else { +// // The whole block is null +// validity_builder.UnsafeAppend(block.length, false); +// for (int64_t i = 0; i < block.length; ++i) { +// RETURN_NOT_OK(visit_null()); +// } +// position += block.length; +// } +// } +// return Status::OK(); +// } + +// // We use the NullVisitor both for "selected" nulls as well as "emitted" +// // nulls coming from the filter when using FilterOptions::EMIT_NULL +// template +// Status VisitFilter(ValidVisitor&& visit_valid, NullVisitor&& visit_null) { +// const bool is_ree_filter = selection.type->id() == Type::RUN_END_ENCODED; +// const auto null_selection = FilterState::Get(ctx).null_selection_behavior; + +// 
arrow::internal::OptionalBitIndexer values_is_valid(values.buffers[0].data, +// values.offset); + +// auto AppendNotNull = [&](int64_t index) -> Status { +// validity_builder.UnsafeAppend(true); +// return visit_valid(index); +// }; + +// auto AppendNull = [&]() -> Status { +// validity_builder.UnsafeAppend(false); +// return visit_null(); +// }; + +// auto AppendMaybeNull = [&](int64_t index) -> Status { +// if (values_is_valid[index]) { +// return AppendNotNull(index); +// } else { +// return AppendNull(); +// } +// }; + +// if (is_ree_filter) { +// Status status; +// VisitPlainxREEFilterOutputSegments( +// selection, /*filter_may_have_nulls=*/true, null_selection, +// [&](int64_t position, int64_t segment_length, bool filter_valid) { +// if (filter_valid) { +// for (int64_t i = 0; i < segment_length; ++i) { +// status = AppendMaybeNull(position + i); +// } +// } else { +// for (int64_t i = 0; i < segment_length; ++i) { +// status = AppendNull(); +// } +// } +// return status.ok(); +// }); +// return status; +// } + +// const uint8_t* filter_data = selection.buffers[1].data; +// const uint8_t* filter_is_valid = selection.buffers[0].data; +// const int64_t filter_offset = selection.offset; +// // We use 3 block counters for fast scanning of the filter +// // +// // * values_valid_counter: for values null/not-null +// // * filter_valid_counter: for filter null/not-null +// // * filter_counter: for filter true/false +// arrow::internal::OptionalBitBlockCounter values_valid_counter( +// values.buffers[0].data, values.offset, values.length); +// arrow::internal::OptionalBitBlockCounter filter_valid_counter( +// filter_is_valid, filter_offset, selection.length); +// arrow::internal::BitBlockCounter filter_counter(filter_data, filter_offset, +// selection.length); + +// int64_t in_position = 0; +// while (in_position < selection.length) { +// arrow::internal::BitBlockCount filter_valid_block = filter_valid_counter.NextWord(); +// arrow::internal::BitBlockCount 
values_valid_block = values_valid_counter.NextWord(); +// arrow::internal::BitBlockCount filter_block = filter_counter.NextWord(); +// if (filter_block.NoneSet() && null_selection == FilterOptions::DROP) { +// // For this exceedingly common case in low-selectivity filters we can +// // skip further analysis of the data and move on to the next block. +// in_position += filter_block.length; +// } else if (filter_valid_block.AllSet()) { +// // Simpler path: no filter values are null +// if (filter_block.AllSet()) { +// // Fastest path: filter values are all true and not null +// if (values_valid_block.AllSet()) { +// // The values aren't null either +// validity_builder.UnsafeAppend(filter_block.length, true); +// for (int64_t i = 0; i < filter_block.length; ++i) { +// RETURN_NOT_OK(visit_valid(in_position++)); +// } +// } else { +// // Some of the values in this block are null +// for (int64_t i = 0; i < filter_block.length; ++i) { +// RETURN_NOT_OK(AppendMaybeNull(in_position++)); +// } +// } +// } else { // !filter_block.AllSet() +// // Some of the filter values are false, but all not null +// if (values_valid_block.AllSet()) { +// // All the values are not-null, so we can skip null checking for +// // them +// for (int64_t i = 0; i < filter_block.length; ++i) { +// if (bit_util::GetBit(filter_data, filter_offset + in_position)) { +// RETURN_NOT_OK(AppendNotNull(in_position)); +// } +// ++in_position; +// } +// } else { +// // Some of the values in the block are null, so we have to check +// // each one +// for (int64_t i = 0; i < filter_block.length; ++i) { +// if (bit_util::GetBit(filter_data, filter_offset + in_position)) { +// RETURN_NOT_OK(AppendMaybeNull(in_position)); +// } +// ++in_position; +// } +// } +// } +// } else { // !filter_valid_block.AllSet() +// // Some of the filter values are null, so we have to handle the DROP +// // versus EMIT_NULL null selection behavior. 
+// if (null_selection == FilterOptions::DROP) { +// // Filter null values are treated as false. +// for (int64_t i = 0; i < filter_block.length; ++i) { +// if (bit_util::GetBit(filter_is_valid, filter_offset + in_position) && +// bit_util::GetBit(filter_data, filter_offset + in_position)) { +// RETURN_NOT_OK(AppendMaybeNull(in_position)); +// } +// ++in_position; +// } +// } else { +// // Filter null values are appended to output as null whether the +// // value in the corresponding slot is valid or not +// for (int64_t i = 0; i < filter_block.length; ++i) { +// const bool filter_not_null = +// bit_util::GetBit(filter_is_valid, filter_offset + in_position); +// if (filter_not_null && +// bit_util::GetBit(filter_data, filter_offset + in_position)) { +// RETURN_NOT_OK(AppendMaybeNull(in_position)); +// } else if (!filter_not_null) { +// // EMIT_NULL case +// RETURN_NOT_OK(AppendNull()); +// } +// ++in_position; +// } +// } +// } +// } +// return Status::OK(); +// } + +// virtual Status Init() { return Status::OK(); } + +// // Implementation specific finish logic +// virtual Status Finish() = 0; + +// Status ExecTake() { +// RETURN_NOT_OK(this->validity_builder.Reserve(output_length)); +// RETURN_NOT_OK(Init()); +// int index_width = this->selection.type->byte_width(); + +// // CRTP dispatch here +// switch (index_width) { +// case 1: { +// Status s = +// static_cast(this)->template GenerateOutput>(); +// RETURN_NOT_OK(s); +// } break; +// case 2: { +// Status s = +// static_cast(this)->template GenerateOutput>(); +// RETURN_NOT_OK(s); +// } break; +// case 4: { +// Status s = +// static_cast(this)->template GenerateOutput>(); +// RETURN_NOT_OK(s); +// } break; +// case 8: { +// Status s = +// static_cast(this)->template GenerateOutput>(); +// RETURN_NOT_OK(s); +// } break; +// default: +// DCHECK(false) << "Invalid index width"; +// break; +// } +// RETURN_NOT_OK(this->FinishCommon()); +// return Finish(); +// } + +// Status ExecFilter() { +// 
RETURN_NOT_OK(this->validity_builder.Reserve(output_length)); +// RETURN_NOT_OK(Init()); +// // CRTP dispatch +// Status s = static_cast(this)->template GenerateOutput(); +// RETURN_NOT_OK(s); +// RETURN_NOT_OK(this->FinishCommon()); +// return Finish(); +// } +// }; + +// #define LIFT_BASE_MEMBERS() \ +// using ValuesArrayType = typename Base::ValuesArrayType; \ +// using Base::ctx; \ +// using Base::values; \ +// using Base::selection; \ +// using Base::output_length; \ +// using Base::out; \ +// using Base::validity_builder + +// inline Status VisitNoop() { return Status::OK(); } + +// // A selection implementation for 32-bit and 64-bit variable binary +// // types. Common generated kernels are shared between Binary/String and +// // LargeBinary/LargeString +// template +// struct VarBinarySelectionImpl : public Selection, Type> { +// using offset_type = typename Type::offset_type; + +// using Base = Selection, Type>; +// LIFT_BASE_MEMBERS(); + +// TypedBufferBuilder offset_builder; +// TypedBufferBuilder data_builder; + +// static constexpr int64_t kOffsetLimit = std::numeric_limits::max() - 1; + +// VarBinarySelectionImpl(KernelContext* ctx, const ExecSpan& batch, int64_t output_length, +// ExecResult* out) +// : Base(ctx, batch, output_length, out), +// offset_builder(ctx->memory_pool()), +// data_builder(ctx->memory_pool()) {} + +// template +// Status GenerateOutput() { +// const auto raw_offsets = this->values.template GetValues(1); +// const uint8_t* raw_data = this->values.buffers[2].data; + +// // Presize the data builder with a rough estimate of the required data size +// if (this->values.length > 0) { +// int64_t data_length = raw_offsets[this->values.length] - raw_offsets[0]; +// const double mean_value_length = +// data_length / static_cast(this->values.length); + +// // TODO: See if possible to reduce output_length for take/filter cases +// // where there are nulls in the selection array +// RETURN_NOT_OK( +// 
data_builder.Reserve(static_cast(mean_value_length * output_length))); +// } +// int64_t space_available = data_builder.capacity(); + +// offset_type offset = 0; +// Adapter adapter(this); +// RETURN_NOT_OK(adapter.Generate( +// [&](int64_t index) { +// offset_builder.UnsafeAppend(offset); +// offset_type val_offset = raw_offsets[index]; +// offset_type val_size = raw_offsets[index + 1] - val_offset; + +// // Use static property to prune this code from the filter path in +// // optimized builds +// if (Adapter::is_take && +// ARROW_PREDICT_FALSE(static_cast(offset) + +// static_cast(val_size)) > kOffsetLimit) { +// return Status::Invalid("Take operation overflowed binary array capacity"); +// } +// offset += val_size; +// if (ARROW_PREDICT_FALSE(val_size > space_available)) { +// RETURN_NOT_OK(data_builder.Reserve(val_size)); +// space_available = data_builder.capacity() - data_builder.length(); +// } +// data_builder.UnsafeAppend(raw_data + val_offset, val_size); +// space_available -= val_size; +// return Status::OK(); +// }, +// [&]() { +// offset_builder.UnsafeAppend(offset); +// return Status::OK(); +// })); +// offset_builder.UnsafeAppend(offset); +// return Status::OK(); +// } + +// Status Init() override { return offset_builder.Reserve(output_length + 1); } + +// Status Finish() override { +// RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1])); +// return data_builder.Finish(&out->buffers[2]); +// } +// }; + +// template +// struct ListSelectionImpl : public Selection, Type> { +// using offset_type = typename Type::offset_type; + +// using Base = Selection, Type>; +// LIFT_BASE_MEMBERS(); + +// TypedBufferBuilder offset_builder; +// typename TypeTraits::OffsetBuilderType child_index_builder; + +// ListSelectionImpl(KernelContext* ctx, const ExecSpan& batch, int64_t output_length, +// ExecResult* out) +// : Base(ctx, batch, output_length, out), +// offset_builder(ctx->memory_pool()), +// child_index_builder(ctx->memory_pool()) {} + +// template +// 
Status GenerateOutput() { +// ValuesArrayType typed_values(this->values.ToArrayData()); + +// // TODO presize child_index_builder with a similar heuristic as VarBinarySelectionImpl + +// offset_type offset = 0; +// Adapter adapter(this); +// RETURN_NOT_OK(adapter.Generate( +// [&](int64_t index) { +// offset_builder.UnsafeAppend(offset); +// offset_type value_offset = typed_values.value_offset(index); +// offset_type value_length = typed_values.value_length(index); +// offset += value_length; +// RETURN_NOT_OK(child_index_builder.Reserve(value_length)); +// for (offset_type j = value_offset; j < value_offset + value_length; ++j) { +// child_index_builder.UnsafeAppend(j); +// } +// return Status::OK(); +// }, +// [&]() { +// offset_builder.UnsafeAppend(offset); +// return Status::OK(); +// })); +// offset_builder.UnsafeAppend(offset); +// return Status::OK(); +// } + +// Status Init() override { +// RETURN_NOT_OK(offset_builder.Reserve(output_length + 1)); +// return Status::OK(); +// } + +// Status Finish() override { +// std::shared_ptr child_indices; +// RETURN_NOT_OK(child_index_builder.Finish(&child_indices)); + +// ValuesArrayType typed_values(this->values.ToArrayData()); + +// // No need to boundscheck the child values indices +// ARROW_ASSIGN_OR_RAISE(std::shared_ptr taken_child, +// Take(*typed_values.values(), *child_indices, +// TakeOptions::NoBoundsCheck(), ctx->exec_context())); +// RETURN_NOT_OK(offset_builder.Finish(&out->buffers[1])); +// out->child_data = {taken_child->data()}; +// return Status::OK(); +// } +// }; + +// struct DenseUnionSelectionImpl +// : public Selection { +// using Base = Selection; +// LIFT_BASE_MEMBERS(); + +// TypedBufferBuilder value_offset_buffer_builder_; +// TypedBufferBuilder child_id_buffer_builder_; +// std::vector type_codes_; +// std::vector child_indices_builders_; + +// DenseUnionSelectionImpl(KernelContext* ctx, const ExecSpan& batch, +// int64_t output_length, ExecResult* out) +// : Base(ctx, batch, 
output_length, out), +// value_offset_buffer_builder_(ctx->memory_pool()), +// child_id_buffer_builder_(ctx->memory_pool()), +// type_codes_(checked_cast(*this->values.type).type_codes()), +// child_indices_builders_(type_codes_.size()) { +// for (auto& child_indices_builder : child_indices_builders_) { +// child_indices_builder = Int32Builder(ctx->memory_pool()); +// } +// } + +// template +// Status GenerateOutput() { +// DenseUnionArray typed_values(this->values.ToArrayData()); +// Adapter adapter(this); +// RETURN_NOT_OK(adapter.Generate( +// [&](int64_t index) { +// int8_t child_id = typed_values.child_id(index); +// child_id_buffer_builder_.UnsafeAppend(type_codes_[child_id]); +// int32_t value_offset = typed_values.value_offset(index); +// value_offset_buffer_builder_.UnsafeAppend( +// static_cast(child_indices_builders_[child_id].length())); +// RETURN_NOT_OK(child_indices_builders_[child_id].Reserve(1)); +// child_indices_builders_[child_id].UnsafeAppend(value_offset); +// return Status::OK(); +// }, +// [&]() { +// int8_t child_id = 0; +// child_id_buffer_builder_.UnsafeAppend(type_codes_[child_id]); +// value_offset_buffer_builder_.UnsafeAppend( +// static_cast(child_indices_builders_[child_id].length())); +// RETURN_NOT_OK(child_indices_builders_[child_id].Reserve(1)); +// child_indices_builders_[child_id].UnsafeAppendNull(); +// return Status::OK(); +// })); +// return Status::OK(); +// } + +// Status Init() override { +// RETURN_NOT_OK(child_id_buffer_builder_.Reserve(output_length)); +// RETURN_NOT_OK(value_offset_buffer_builder_.Reserve(output_length)); +// return Status::OK(); +// } + +// Status Finish() override { +// ARROW_ASSIGN_OR_RAISE(auto child_ids_buffer, child_id_buffer_builder_.Finish()); +// ARROW_ASSIGN_OR_RAISE(auto value_offsets_buffer, +// value_offset_buffer_builder_.Finish()); +// DenseUnionArray typed_values(this->values.ToArrayData()); +// auto num_fields = typed_values.num_fields(); +// auto num_rows = child_ids_buffer->size(); 
+// BufferVector buffers{nullptr, std::move(child_ids_buffer), +// std::move(value_offsets_buffer)}; +// *out = ArrayData(typed_values.type(), num_rows, std::move(buffers), 0); +// for (auto i = 0; i < num_fields; i++) { +// ARROW_ASSIGN_OR_RAISE(auto child_indices_array, +// child_indices_builders_[i].Finish()); +// ARROW_ASSIGN_OR_RAISE(std::shared_ptr child_array, +// Take(*typed_values.field(i), *child_indices_array)); +// out->child_data.push_back(child_array->data()); +// } +// return Status::OK(); +// } +// }; + +// // We need a slightly different approach for SparseUnion. For Take, we can +// // invoke Take on each child's data with boundschecking disabled. For +// // Filter on the other hand, if we naively call Filter on each child, then the +// // filter output length will have to be redundantly computed. Thus, for Filter +// // we instead convert the filter to selection indices and then invoke take. + +// // SparseUnion selection implementation. ONLY used for Take +// struct SparseUnionSelectionImpl +// : public Selection { +// using Base = Selection; +// LIFT_BASE_MEMBERS(); + +// TypedBufferBuilder child_id_buffer_builder_; +// const int8_t type_code_for_null_; + +// SparseUnionSelectionImpl(KernelContext* ctx, const ExecSpan& batch, +// int64_t output_length, ExecResult* out) +// : Base(ctx, batch, output_length, out), +// child_id_buffer_builder_(ctx->memory_pool()), +// type_code_for_null_( +// checked_cast(*this->values.type).type_codes()[0]) {} + +// template +// Status GenerateOutput() { +// SparseUnionArray typed_values(this->values.ToArrayData()); +// Adapter adapter(this); +// RETURN_NOT_OK(adapter.Generate( +// [&](int64_t index) { +// child_id_buffer_builder_.UnsafeAppend(typed_values.type_code(index)); +// return Status::OK(); +// }, +// [&]() { +// child_id_buffer_builder_.UnsafeAppend(type_code_for_null_); +// return Status::OK(); +// })); +// return Status::OK(); +// } + +// Status Init() override { +// 
RETURN_NOT_OK(child_id_buffer_builder_.Reserve(output_length)); +// return Status::OK(); +// } + +// Status Finish() override { +// ARROW_ASSIGN_OR_RAISE(auto child_ids_buffer, child_id_buffer_builder_.Finish()); +// SparseUnionArray typed_values(this->values.ToArrayData()); +// auto num_fields = typed_values.num_fields(); +// auto num_rows = child_ids_buffer->size(); +// BufferVector buffers{nullptr, std::move(child_ids_buffer)}; +// *out = ArrayData(typed_values.type(), num_rows, std::move(buffers), 0); +// out->child_data.reserve(num_fields); +// for (auto i = 0; i < num_fields; i++) { +// ARROW_ASSIGN_OR_RAISE(auto child_datum, +// Take(*typed_values.field(i), *this->selection.ToArrayData())); +// out->child_data.emplace_back(std::move(child_datum).array()); +// } +// return Status::OK(); +// } +// }; + +// struct FSLSelectionImpl : public Selection { +// Int64Builder child_index_builder; + +// using Base = Selection; +// LIFT_BASE_MEMBERS(); + +// FSLSelectionImpl(KernelContext* ctx, const ExecSpan& batch, int64_t output_length, +// ExecResult* out) +// : Base(ctx, batch, output_length, out), child_index_builder(ctx->memory_pool()) {} + +// template +// Status GenerateOutput() { +// ValuesArrayType typed_values(this->values.ToArrayData()); +// const int32_t list_size = typed_values.list_type()->list_size(); +// const int64_t base_offset = typed_values.offset(); + +// // We must take list_size elements even for null elements of +// // indices. 
+// RETURN_NOT_OK(child_index_builder.Reserve(output_length * list_size)); + +// Adapter adapter(this); +// return adapter.Generate( +// [&](int64_t index) { +// int64_t offset = (base_offset + index) * list_size; +// for (int64_t j = offset; j < offset + list_size; ++j) { +// child_index_builder.UnsafeAppend(j); +// } +// return Status::OK(); +// }, +// [&]() { return child_index_builder.AppendNulls(list_size); }); +// } + +// Status Finish() override { +// std::shared_ptr child_indices; +// RETURN_NOT_OK(child_index_builder.Finish(&child_indices)); + +// ValuesArrayType typed_values(this->values.ToArrayData()); + +// // No need to boundscheck the child values indices +// ARROW_ASSIGN_OR_RAISE(std::shared_ptr taken_child, +// Take(*typed_values.values(), *child_indices, +// TakeOptions::NoBoundsCheck(), ctx->exec_context())); +// out->child_data = {taken_child->data()}; +// return Status::OK(); +// } +// }; + +// // ---------------------------------------------------------------------- +// // Struct selection implementations + +// // We need a slightly different approach for StructType. For Take, we can +// // invoke Take on each struct field's data with boundschecking disabled. For +// // Filter on the other hand, if we naively call Filter on each field, then the +// // filter output length will have to be redundantly computed. Thus, for Filter +// // we instead convert the filter to selection indices and then invoke take. + +// // Struct selection implementation. 
ONLY used for Take +// struct StructSelectionImpl : public Selection { +// using Base = Selection; +// LIFT_BASE_MEMBERS(); +// using Base::Base; + +// template +// Status GenerateOutput() { +// StructArray typed_values(this->values.ToArrayData()); +// Adapter adapter(this); +// // There's nothing to do for Struct except to generate the validity bitmap +// return adapter.Generate([&](int64_t index) { return Status::OK(); }, +// /*visit_null=*/VisitNoop); +// } + +// Status Finish() override { +// StructArray typed_values(this->values.ToArrayData()); + +// // Select from children without boundschecking +// out->child_data.resize(this->values.type->num_fields()); +// for (int field_index = 0; field_index < this->values.type->num_fields(); +// ++field_index) { +// ARROW_ASSIGN_OR_RAISE(Datum taken_field, +// Take(Datum(typed_values.field(field_index)), +// Datum(this->selection.ToArrayData()), +// TakeOptions::NoBoundsCheck(), ctx->exec_context())); +// out->child_data[field_index] = taken_field.array(); +// } +// return Status::OK(); +// } +// }; + +// #undef LIFT_BASE_MEMBERS + +// // ---------------------------------------------------------------------- + +template +Status ScatterByMaskExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + int64_t output_length = + GetFilterOutputSize(batch[1].array, FilterState::Get(ctx).null_selection_behavior); + Impl kernel(ctx, batch, output_length, out); + return kernel.ExecFilter(); +} + +// } // namespace + +Status ListFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + return FilterExec>(ctx, batch, out); +} + +Status LargeListFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + return FilterExec>(ctx, batch, out); +} + +Status FSLFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + const ArraySpan& values = batch[0].array; + + // If a FixedSizeList wraps a fixed-width type we can, in some cases, use + // PrimitiveFilterExec for a fixed-size 
list array. + if (util::IsFixedWidthLike(values, + /*force_null_count=*/true, + /*exclude_bool_and_dictionary=*/true)) { + const auto byte_width = util::FixedWidthInBytes(*values.type); + // 0 is a valid byte width for FixedSizeList, but PrimitiveFilterExec + // might not handle it correctly. + if (byte_width > 0) { + return PrimitiveFilterExec(ctx, batch, out); + } + } + return FilterExec(ctx, batch, out); +} + +Status DenseUnionFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + return FilterExec(ctx, batch, out); +} + +Status MapFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + return FilterExec>(ctx, batch, out); +} + +} // namespace compute::internal +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_scatter_internal.h b/cpp/src/arrow/compute/kernels/vector_scatter_internal.h new file mode 100644 index 0000000000000..d7493997159da --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_scatter_internal.h @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#pragma once + +#include +#include +#include + +#include "arrow/array/data.h" +#include "arrow/compute/api_vector.h" +#include "arrow/compute/exec.h" +#include "arrow/compute/function.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/codegen_internal.h" + +namespace arrow::compute::internal { + +struct ScatterKernelData { + InputType value_type; + InputType selection_type; + ArrayKernelExec exec; +}; + +void RegisterScatterFunction(const std::string& name, FunctionDoc doc, + VectorKernel base_kernel, + std::vector&& kernels, + const FunctionOptions* default_options, + FunctionRegistry* registry); + +/// \brief Callback type for VisitPlainxREEFilterOutputSegments. +/// +/// position is the logical position in the values array relative to its offset. +/// +/// segment_length is the number of elements that should be emitted. +/// +/// filter_valid is true if the filter run value is non-NULL. This value can +/// only be false if null_selection is NullSelectionBehavior::EMIT_NULL. For +/// NullSelectionBehavior::DROP, NULL values from the filter are simply skipped. +/// +/// Return true if iteration should continue, false if iteration should stop. 
+// using EmitREEFilterSegment = +// std::function; + +// void VisitPlainxREEFilterOutputSegments( +// const ArraySpan& filter, bool filter_may_have_nulls, +// FilterOptions::NullSelectionBehavior null_selection, +// const EmitREEFilterSegment& emit_segment); + +Status PrimitiveScatterByMaskExec(KernelContext*, const ExecSpan&, ExecResult*); +Status ListScatterByMaskExec(KernelContext*, const ExecSpan&, ExecResult*); +Status LargeListScatterByMaskExec(KernelContext*, const ExecSpan&, ExecResult*); +Status FSLScatterByMaskExec(KernelContext*, const ExecSpan&, ExecResult*); +Status DenseUnionScatterByMaskExec(KernelContext*, const ExecSpan&, ExecResult*); +Status MapScatterByMaskExec(KernelContext*, const ExecSpan&, ExecResult*); + +} // namespace arrow::compute::internal diff --git a/cpp/src/arrow/compute/kernels/vector_scatter_test.cc b/cpp/src/arrow/compute/kernels/vector_scatter_test.cc new file mode 100644 index 0000000000000..cafd88901576c --- /dev/null +++ b/cpp/src/arrow/compute/kernels/vector_scatter_test.cc @@ -0,0 +1,2723 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/array/builder_nested.h" +#include "arrow/array/concatenate.h" +#include "arrow/chunked_array.h" +#include "arrow/compute/api.h" +#include "arrow/compute/kernels/test_util.h" +#include "arrow/table.h" +#include "arrow/testing/builder.h" +#include "arrow/testing/fixed_width_test_util.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/testing/util.h" +#include "arrow/util/logging.h" + +namespace arrow { + +using internal::checked_cast; +using internal::checked_pointer_cast; +using std::string_view; + +namespace compute { + +namespace { + +template +Result> REEncode(const T& array) { + ARROW_ASSIGN_OR_RAISE(auto datum, RunEndEncode(array)); + return datum.make_array(); +} + +Result> REEFromJSON(const std::shared_ptr& ree_type, + const std::string& json) { + auto ree_type_ptr = checked_cast(ree_type.get()); + auto array = ArrayFromJSON(ree_type_ptr->value_type(), json); + ARROW_ASSIGN_OR_RAISE( + auto datum, RunEndEncode(array, RunEndEncodeOptions{ree_type_ptr->run_end_type()})); + return datum.make_array(); +} + +Result> FilterFromJSON( + const std::shared_ptr& filter_type, const std::string& json) { + if (filter_type->id() == Type::RUN_END_ENCODED) { + return REEFromJSON(filter_type, json); + } else { + return ArrayFromJSON(filter_type, json); + } +} + +Result> REEncode(const std::shared_ptr& array) { + ARROW_ASSIGN_OR_RAISE(auto datum, RunEndEncode(array)); + return datum.make_array(); +} + +void CheckTakeIndicesCase(const BooleanArray& filter, + const std::shared_ptr& expected_indices, + FilterOptions::NullSelectionBehavior null_selection) { + ASSERT_OK_AND_ASSIGN(auto indices, + internal::GetTakeIndices(*filter.data(), null_selection)); + auto indices_array = MakeArray(indices); + ValidateOutput(indices); + AssertArraysEqual(*expected_indices, *indices_array, /*verbose=*/true); + + ASSERT_OK_AND_ASSIGN(auto ree_filter, 
REEncode(filter)); + ASSERT_OK_AND_ASSIGN(auto indices_from_ree, + internal::GetTakeIndices(*ree_filter->data(), null_selection)); + auto indices_from_ree_array = MakeArray(indices); + ValidateOutput(indices_from_ree); + AssertArraysEqual(*expected_indices, *indices_from_ree_array, /*verbose=*/true); +} + +void CheckTakeIndicesCase(const std::string& filter_json, const std::string& indices_json, + FilterOptions::NullSelectionBehavior null_selection, + const std::shared_ptr& indices_type = uint16()) { + auto filter = ArrayFromJSON(boolean(), filter_json); + auto expected_indices = ArrayFromJSON(indices_type, indices_json); + const auto& boolean_filter = checked_cast(*filter); + CheckTakeIndicesCase(boolean_filter, expected_indices, null_selection); +} + +} // namespace + +// ---------------------------------------------------------------------- + +TEST(GetTakeIndices, Basics) { + // Drop null cases + CheckTakeIndicesCase("[]", "[]", FilterOptions::DROP); + CheckTakeIndicesCase("[null]", "[]", FilterOptions::DROP); + CheckTakeIndicesCase("[null, false, true, true, false, true]", "[2, 3, 5]", + FilterOptions::DROP); + + // Emit null cases + CheckTakeIndicesCase("[]", "[]", FilterOptions::EMIT_NULL); + CheckTakeIndicesCase("[null]", "[null]", FilterOptions::EMIT_NULL); + CheckTakeIndicesCase("[null, false, true, true]", "[null, 2, 3]", + FilterOptions::EMIT_NULL); +} + +TEST(GetTakeIndices, NullValidityBuffer) { + BooleanArray filter(1, *AllocateEmptyBitmap(1), /*null_bitmap=*/nullptr); + auto expected_indices = ArrayFromJSON(uint16(), "[]"); + + CheckTakeIndicesCase(filter, expected_indices, FilterOptions::DROP); + CheckTakeIndicesCase(filter, expected_indices, FilterOptions::EMIT_NULL); +} + +template +void CheckGetTakeIndicesCase(const Array& untyped_filter) { + const auto& filter = checked_cast(untyped_filter); + ASSERT_OK_AND_ASSIGN(auto ree_filter, REEncode(*filter.data())); + + ASSERT_OK_AND_ASSIGN(std::shared_ptr drop_indices, + 
internal::GetTakeIndices(*filter.data(), FilterOptions::DROP)); + ASSERT_OK_AND_ASSIGN( + std::shared_ptr drop_indices_from_ree, + internal::GetTakeIndices(*ree_filter->data(), FilterOptions::DROP)); + // Verify DROP indices + { + IndexArrayType indices(drop_indices); + IndexArrayType indices_from_ree(drop_indices); + ValidateOutput(indices); + ValidateOutput(indices_from_ree); + + int64_t out_position = 0; + for (int64_t i = 0; i < filter.length(); ++i) { + if (filter.IsValid(i)) { + if (filter.Value(i)) { + ASSERT_EQ(indices.Value(out_position), i); + ASSERT_EQ(indices_from_ree.Value(out_position), i); + ++out_position; + } + } + } + ASSERT_EQ(out_position, indices.length()); + ASSERT_EQ(out_position, indices_from_ree.length()); + + // Check that the end length agrees with the output of GetFilterOutputSize + ASSERT_EQ(out_position, + internal::GetFilterOutputSize(*filter.data(), FilterOptions::DROP)); + ASSERT_EQ(out_position, + internal::GetFilterOutputSize(*ree_filter->data(), FilterOptions::DROP)); + } + + ASSERT_OK_AND_ASSIGN( + std::shared_ptr emit_indices, + internal::GetTakeIndices(*filter.data(), FilterOptions::EMIT_NULL)); + ASSERT_OK_AND_ASSIGN( + std::shared_ptr emit_indices_from_ree, + internal::GetTakeIndices(*ree_filter->data(), FilterOptions::EMIT_NULL)); + // Verify EMIT_NULL indices + { + IndexArrayType indices(emit_indices); + IndexArrayType indices_from_ree(emit_indices); + ValidateOutput(indices); + ValidateOutput(indices_from_ree); + + int64_t out_position = 0; + for (int64_t i = 0; i < filter.length(); ++i) { + if (filter.IsValid(i)) { + if (filter.Value(i)) { + ASSERT_EQ(indices.Value(out_position), i); + ASSERT_EQ(indices_from_ree.Value(out_position), i); + ++out_position; + } + } else { + ASSERT_TRUE(indices.IsNull(out_position)); + ASSERT_TRUE(indices_from_ree.IsNull(out_position)); + ++out_position; + } + } + + ASSERT_EQ(out_position, indices.length()); + ASSERT_EQ(out_position, indices_from_ree.length()); + + // Check that the end 
length agrees with the output of GetFilterOutputSize + ASSERT_EQ(out_position, + internal::GetFilterOutputSize(*filter.data(), FilterOptions::EMIT_NULL)); + ASSERT_EQ(out_position, internal::GetFilterOutputSize(*ree_filter->data(), + FilterOptions::EMIT_NULL)); + } +} + +TEST(GetTakeIndices, RandomlyGenerated) { + random::RandomArrayGenerator rng(kRandomSeed); + + // Multiple of word size + 1 + const int64_t length = 6401; + for (auto null_prob : {0.0, 0.01, 0.999, 1.0}) { + for (auto true_prob : {0.0, 0.01, 0.999, 1.0}) { + auto filter = rng.Boolean(length, true_prob, null_prob); + CheckGetTakeIndicesCase(*filter); + CheckGetTakeIndicesCase(*filter->Slice(7)); + } + } + + // Check that the uint32 path is traveled successfully + const int64_t uint16_max = std::numeric_limits::max(); + auto filter = + std::static_pointer_cast(rng.Boolean(uint16_max + 1, 0.99, 0.01)); + CheckGetTakeIndicesCase(*filter->Slice(1)); + CheckGetTakeIndicesCase(*filter); +} + +// ---------------------------------------------------------------------- +// Filter tests + +std::shared_ptr CoalesceNullToFalse(std::shared_ptr filter) { + const bool is_ree = filter->type_id() == Type::RUN_END_ENCODED; + // Work directly on run values array in case of REE + const ArrayData& data = is_ree ? 
*filter->data()->child_data[1] : *filter->data(); + if (data.GetNullCount() == 0) { + return filter; + } + auto is_true = std::make_shared(data.length, data.buffers[1], nullptr, 0, + data.offset); + auto is_valid = std::make_shared(data.length, data.buffers[0], nullptr, 0, + data.offset); + EXPECT_OK_AND_ASSIGN(Datum out_datum, And(is_true, is_valid)); + if (is_ree) { + const auto& ree_filter = checked_cast(*filter); + EXPECT_OK_AND_ASSIGN( + auto new_ree_filter, + RunEndEncodedArray::Make(ree_filter.length(), ree_filter.run_ends(), + /*values=*/out_datum.make_array(), ree_filter.offset())); + return new_ree_filter; + } + return out_datum.make_array(); +} + +class TestFilterKernel : public ::testing::Test { + protected: + TestFilterKernel() : emit_null_(FilterOptions::EMIT_NULL), drop_(FilterOptions::DROP) {} + + void DoAssertFilter(const std::shared_ptr& values, + const std::shared_ptr& filter, + const std::shared_ptr& expected) { + // test with EMIT_NULL + { + ARROW_SCOPED_TRACE("with EMIT_NULL"); + ASSERT_OK_AND_ASSIGN(Datum out_datum, Filter(values, filter, emit_null_)); + auto actual = out_datum.make_array(); + ValidateOutput(*actual); + AssertArraysEqual(*expected, *actual, /*verbose=*/true); + } + + // test with DROP using EMIT_NULL and a coalesced filter + { + ARROW_SCOPED_TRACE("with DROP"); + auto coalesced_filter = CoalesceNullToFalse(filter); + ASSERT_OK_AND_ASSIGN(Datum out_datum, Filter(values, coalesced_filter, emit_null_)); + auto expected_for_drop = out_datum.make_array(); + ASSERT_OK_AND_ASSIGN(out_datum, Filter(values, filter, drop_)); + auto actual = out_datum.make_array(); + ValidateOutput(*actual); + AssertArraysEqual(*expected_for_drop, *actual, /*verbose=*/true); + } + } + + void AssertFilter(const std::shared_ptr& values, + const std::shared_ptr& filter, + const std::shared_ptr& expected) { + DoAssertFilter(values, filter, expected); + + // Check slicing: add M(=3) dummy values at the start and end of `values`, + // add N(=2) dummy values 
at the start and end of `filter`. + ARROW_SCOPED_TRACE("for sliced values and filter"); + ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(values->type(), 3)); + ASSERT_OK_AND_ASSIGN(auto filter_filler, + FilterFromJSON(filter->type(), "[true, false]")); + ASSERT_OK_AND_ASSIGN(auto values_with_filler, + Concatenate({values_filler, values, values_filler})); + ASSERT_OK_AND_ASSIGN(auto filter_with_filler, + Concatenate({filter_filler, filter, filter_filler})); + auto values_sliced = values_with_filler->Slice(3, values->length()); + auto filter_sliced = filter_with_filler->Slice(2, filter->length()); + DoAssertFilter(values_sliced, filter_sliced, expected); + } + + void AssertFilter(const std::shared_ptr& type, const std::string& values, + const std::string& filter, const std::string& expected) { + auto values_array = ArrayFromJSON(type, values); + auto filter_array = ArrayFromJSON(boolean(), filter); + auto expected_array = ArrayFromJSON(type, expected); + AssertFilter(values_array, filter_array, expected_array); + + ASSERT_OK_AND_ASSIGN(auto ree_filter, REEncode(filter_array)); + ARROW_SCOPED_TRACE("for plain values and REE filter"); + AssertFilter(values_array, ree_filter, expected_array); + } + + void TestNumericBasics(const std::shared_ptr& type) { + ARROW_SCOPED_TRACE("type = ", *type); + AssertFilter(type, "[]", "[]", "[]"); + + AssertFilter(type, "[9]", "[0]", "[]"); + AssertFilter(type, "[9]", "[1]", "[9]"); + AssertFilter(type, "[9]", "[null]", "[null]"); + AssertFilter(type, "[null]", "[0]", "[]"); + AssertFilter(type, "[null]", "[1]", "[null]"); + AssertFilter(type, "[null]", "[null]", "[null]"); + + AssertFilter(type, "[7, 8, 9]", "[0, 1, 0]", "[8]"); + AssertFilter(type, "[7, 8, 9]", "[1, 0, 1]", "[7, 9]"); + AssertFilter(type, "[null, 8, 9]", "[0, 1, 0]", "[8]"); + AssertFilter(type, "[7, 8, 9]", "[null, 1, 0]", "[null, 8]"); + AssertFilter(type, "[7, 8, 9]", "[1, null, 1]", "[7, null, 9]"); + + AssertFilter(ArrayFromJSON(type, "[7, 8, 9]"), + 
ArrayFromJSON(boolean(), "[0, 1, 1, 1, 0, 1]")->Slice(3, 3), + ArrayFromJSON(type, "[7, 9]")); + + ASSERT_RAISES(Invalid, Filter(ArrayFromJSON(type, "[7, 8, 9]"), + ArrayFromJSON(boolean(), "[]"), emit_null_)); + ASSERT_RAISES(Invalid, Filter(ArrayFromJSON(type, "[7, 8, 9]"), + ArrayFromJSON(boolean(), "[]"), drop_)); + } + + const FilterOptions emit_null_, drop_; +}; + +void ValidateFilter(const std::shared_ptr& values, + const std::shared_ptr& filter_boxed) { + FilterOptions emit_null(FilterOptions::EMIT_NULL); + FilterOptions drop(FilterOptions::DROP); + + ASSERT_OK_AND_ASSIGN(Datum out_datum, Filter(values, filter_boxed, emit_null)); + auto filtered_emit_null = out_datum.make_array(); + ValidateOutput(*filtered_emit_null); + + ASSERT_OK_AND_ASSIGN(out_datum, Filter(values, filter_boxed, drop)); + auto filtered_drop = out_datum.make_array(); + ValidateOutput(*filtered_drop); + + // Create the expected arrays using Take + ASSERT_OK_AND_ASSIGN( + std::shared_ptr drop_indices, + internal::GetTakeIndices(*filter_boxed->data(), FilterOptions::DROP)); + ASSERT_OK_AND_ASSIGN(Datum expected_drop, Take(values, Datum(drop_indices))); + + ASSERT_OK_AND_ASSIGN( + std::shared_ptr emit_null_indices, + internal::GetTakeIndices(*filter_boxed->data(), FilterOptions::EMIT_NULL)); + ASSERT_OK_AND_ASSIGN(Datum expected_emit_null, Take(values, Datum(emit_null_indices))); + + AssertArraysEqual(*expected_drop.make_array(), *filtered_drop, + /*verbose=*/true); + AssertArraysEqual(*expected_emit_null.make_array(), *filtered_emit_null, + /*verbose=*/true); +} + +TEST_F(TestFilterKernel, Temporal) { + this->TestNumericBasics(time32(TimeUnit::MILLI)); + this->TestNumericBasics(time64(TimeUnit::MICRO)); + this->TestNumericBasics(timestamp(TimeUnit::NANO, "Europe/Paris")); + this->TestNumericBasics(duration(TimeUnit::SECOND)); + this->TestNumericBasics(date32()); + this->AssertFilter(date64(), "[0, 86400000, null]", "[null, 1, 0]", "[null, 86400000]"); +} + +TEST_F(TestFilterKernel, 
Duration) { + for (auto type : DurationTypes()) { + this->TestNumericBasics(type); + } +} + +TEST_F(TestFilterKernel, Interval) { + this->TestNumericBasics(month_interval()); + + auto type = day_time_interval(); + this->AssertFilter(type, "[[1, -600], [2, 3000], null]", "[null, 1, 0]", + "[null, [2, 3000]]"); + type = month_day_nano_interval(); + this->AssertFilter(type, + "[[1, -2, 34567890123456789], [2, 3, -34567890123456789], null]", + "[null, 1, 0]", "[null, [2, 3, -34567890123456789]]"); +} + +class TestFilterKernelWithNull : public TestFilterKernel { + protected: + void AssertFilter(const std::string& values, const std::string& filter, + const std::string& expected) { + TestFilterKernel::AssertFilter(ArrayFromJSON(null(), values), + ArrayFromJSON(boolean(), filter), + ArrayFromJSON(null(), expected)); + } +}; + +TEST_F(TestFilterKernelWithNull, FilterNull) { + this->AssertFilter("[]", "[]", "[]"); + + this->AssertFilter("[null, null, null]", "[0, 1, 0]", "[null]"); + this->AssertFilter("[null, null, null]", "[1, 1, 0]", "[null, null]"); +} + +class TestFilterKernelWithBoolean : public TestFilterKernel { + protected: + void AssertFilter(const std::string& values, const std::string& filter, + const std::string& expected) { + TestFilterKernel::AssertFilter(ArrayFromJSON(boolean(), values), + ArrayFromJSON(boolean(), filter), + ArrayFromJSON(boolean(), expected)); + } +}; + +TEST_F(TestFilterKernelWithBoolean, FilterBoolean) { + this->AssertFilter("[]", "[]", "[]"); + + this->AssertFilter("[true, false, true]", "[0, 1, 0]", "[false]"); + this->AssertFilter("[null, false, true]", "[0, 1, 0]", "[false]"); + this->AssertFilter("[true, false, true]", "[null, 1, 0]", "[null, false]"); +} + +TEST_F(TestFilterKernelWithBoolean, DefaultOptions) { + auto values = ArrayFromJSON(int8(), "[7, 8, null, 9]"); + auto filter = ArrayFromJSON(boolean(), "[1, 1, 0, null]"); + + ASSERT_OK_AND_ASSIGN(auto no_options_provided, + CallFunction("filter", {values, filter})); + + auto 
default_options = FilterOptions::Defaults(); + ASSERT_OK_AND_ASSIGN(auto explicit_defaults, + CallFunction("filter", {values, filter}, &default_options)); + + AssertDatumsEqual(explicit_defaults, no_options_provided); +} + +template +class TestFilterKernelWithNumeric : public TestFilterKernel { + protected: + std::shared_ptr type_singleton() { + return TypeTraits::type_singleton(); + } +}; + +TYPED_TEST_SUITE(TestFilterKernelWithNumeric, NumericArrowTypes); +TYPED_TEST(TestFilterKernelWithNumeric, FilterNumeric) { + this->TestNumericBasics(this->type_singleton()); +} + +template +using Comparator = bool(CType, CType); + +template +Comparator* GetComparator(CompareOperator op) { + static Comparator* cmp[] = { + // EQUAL + [](CType l, CType r) { return l == r; }, + // NOT_EQUAL + [](CType l, CType r) { return l != r; }, + // GREATER + [](CType l, CType r) { return l > r; }, + // GREATER_EQUAL + [](CType l, CType r) { return l >= r; }, + // LESS + [](CType l, CType r) { return l < r; }, + // LESS_EQUAL + [](CType l, CType r) { return l <= r; }, + }; + return cmp[op]; +} + +template ::CType> +std::shared_ptr CompareAndFilter(const CType* data, int64_t length, Fn&& fn) { + std::vector filtered; + filtered.reserve(length); + std::copy_if(data, data + length, std::back_inserter(filtered), std::forward(fn)); + std::shared_ptr filtered_array; + ArrayFromVector(filtered, &filtered_array); + return filtered_array; +} + +template ::CType> +std::shared_ptr CompareAndFilter(const CType* data, int64_t length, CType val, + CompareOperator op) { + auto cmp = GetComparator(op); + return CompareAndFilter(data, length, [&](CType e) { return cmp(e, val); }); +} + +template ::CType> +std::shared_ptr CompareAndFilter(const CType* data, int64_t length, + const CType* other, CompareOperator op) { + auto cmp = GetComparator(op); + return CompareAndFilter(data, length, [&](CType e) { return cmp(e, *other++); }); +} + +TYPED_TEST(TestFilterKernelWithNumeric, 
CompareScalarAndFilterRandomNumeric) { + using ScalarType = typename TypeTraits::ScalarType; + using ArrayType = typename TypeTraits::ArrayType; + using CType = typename TypeTraits::CType; + + auto rand = random::RandomArrayGenerator(kRandomSeed); + for (size_t i = 3; i < 10; i++) { + const int64_t length = static_cast(1ULL << i); + // TODO(bkietz) rewrite with some nulls + auto array = + checked_pointer_cast(rand.Numeric(length, 0, 100, 0)); + CType c_fifty = 50; + auto fifty = std::make_shared(c_fifty); + for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { + ASSERT_OK_AND_ASSIGN( + Datum selection, + CallFunction(CompareOperatorToFunctionName(op), {array, Datum(fifty)})); + ASSERT_OK_AND_ASSIGN(Datum filtered, Filter(array, selection)); + auto filtered_array = filtered.make_array(); + ValidateOutput(*filtered_array); + auto expected = + CompareAndFilter(array->raw_values(), array->length(), c_fifty, op); + ASSERT_ARRAYS_EQUAL(*filtered_array, *expected); + } + } +} + +TYPED_TEST(TestFilterKernelWithNumeric, CompareArrayAndFilterRandomNumeric) { + using ArrayType = typename TypeTraits::ArrayType; + + auto rand = random::RandomArrayGenerator(kRandomSeed); + for (size_t i = 3; i < 10; i++) { + const int64_t length = static_cast(1ULL << i); + auto lhs = checked_pointer_cast( + rand.Numeric(length, 0, 100, /*null_probability=*/0.0)); + auto rhs = checked_pointer_cast( + rand.Numeric(length, 0, 100, /*null_probability=*/0.0)); + for (auto op : {EQUAL, NOT_EQUAL, GREATER, LESS_EQUAL}) { + ASSERT_OK_AND_ASSIGN(Datum selection, + CallFunction(CompareOperatorToFunctionName(op), {lhs, rhs})); + ASSERT_OK_AND_ASSIGN(Datum filtered, Filter(lhs, selection)); + auto filtered_array = filtered.make_array(); + ValidateOutput(*filtered_array); + auto expected = CompareAndFilter(lhs->raw_values(), lhs->length(), + rhs->raw_values(), op); + ASSERT_ARRAYS_EQUAL(*filtered_array, *expected); + } + } +} + +TYPED_TEST(TestFilterKernelWithNumeric, ScalarInRangeAndFilterRandomNumeric) 
{ + using ScalarType = typename TypeTraits::ScalarType; + using ArrayType = typename TypeTraits::ArrayType; + using CType = typename TypeTraits::CType; + + auto rand = random::RandomArrayGenerator(kRandomSeed); + for (size_t i = 3; i < 10; i++) { + const int64_t length = static_cast(1ULL << i); + auto array = checked_pointer_cast( + rand.Numeric(length, 0, 100, /*null_probability=*/0.0)); + CType c_fifty = 50, c_hundred = 100; + auto fifty = std::make_shared(c_fifty); + auto hundred = std::make_shared(c_hundred); + ASSERT_OK_AND_ASSIGN(Datum greater_than_fifty, + CallFunction("greater", {array, Datum(fifty)})); + ASSERT_OK_AND_ASSIGN(Datum less_than_hundred, + CallFunction("less", {array, Datum(hundred)})); + ASSERT_OK_AND_ASSIGN(Datum selection, And(greater_than_fifty, less_than_hundred)); + ASSERT_OK_AND_ASSIGN(Datum filtered, Filter(array, selection)); + auto filtered_array = filtered.make_array(); + ValidateOutput(*filtered_array); + auto expected = CompareAndFilter( + array->raw_values(), array->length(), + [&](CType e) { return (e > c_fifty) && (e < c_hundred); }); + ASSERT_ARRAYS_EQUAL(*filtered_array, *expected); + } +} + +template +class TestFilterKernelWithDecimal : public TestFilterKernel { + protected: + std::shared_ptr type_singleton() { return std::make_shared(3, 2); } +}; + +TYPED_TEST_SUITE(TestFilterKernelWithDecimal, DecimalArrowTypes); +TYPED_TEST(TestFilterKernelWithDecimal, FilterNumeric) { + auto type = this->type_singleton(); + this->AssertFilter(type, R"([])", "[]", R"([])"); + + this->AssertFilter(type, R"(["9.00"])", "[0]", R"([])"); + this->AssertFilter(type, R"(["9.00"])", "[1]", R"(["9.00"])"); + this->AssertFilter(type, R"(["9.00"])", "[null]", R"([null])"); + this->AssertFilter(type, R"([null])", "[0]", R"([])"); + this->AssertFilter(type, R"([null])", "[1]", R"([null])"); + this->AssertFilter(type, R"([null])", "[null]", R"([null])"); + + this->AssertFilter(type, R"(["7.12", "8.00", "9.87"])", "[0, 1, 0]", R"(["8.00"])"); + 
this->AssertFilter(type, R"(["7.12", "8.00", "9.87"])", "[1, 0, 1]", + R"(["7.12", "9.87"])"); + this->AssertFilter(type, R"([null, "8.00", "9.87"])", "[0, 1, 0]", R"(["8.00"])"); + this->AssertFilter(type, R"(["7.12", "8.00", "9.87"])", "[null, 1, 0]", + R"([null, "8.00"])"); + this->AssertFilter(type, R"(["7.12", "8.00", "9.87"])", "[1, null, 1]", + R"(["7.12", null, "9.87"])"); + + this->AssertFilter(ArrayFromJSON(type, R"(["7.12", "8.00", "9.87"])"), + ArrayFromJSON(boolean(), "[0, 1, 1, 1, 0, 1]")->Slice(3, 3), + ArrayFromJSON(type, R"(["7.12", "9.87"])")); + + ASSERT_RAISES(Invalid, Filter(ArrayFromJSON(type, R"(["7.12", "8.00", "9.87"])"), + ArrayFromJSON(boolean(), "[]"), this->emit_null_)); + ASSERT_RAISES(Invalid, Filter(ArrayFromJSON(type, R"(["7.12", "8.00", "9.87"])"), + ArrayFromJSON(boolean(), "[]"), this->drop_)); +} + +TEST_F(TestFilterKernel, NoValidityBitmapButUnknownNullCount) { + auto values = ArrayFromJSON(int32(), "[1, 2, 3, 4]"); + auto filter = ArrayFromJSON(boolean(), "[true, true, false, true]"); + + auto expected = (*Filter(values, filter)).make_array(); + + filter->data()->null_count = kUnknownNullCount; + auto result = (*Filter(values, filter)).make_array(); + + AssertArraysEqual(*expected, *result); +} + +template +class TestFilterKernelWithString : public TestFilterKernel { + protected: + std::shared_ptr value_type() { + return TypeTraits::type_singleton(); + } + + void AssertFilter(const std::string& values, const std::string& filter, + const std::string& expected) { + TestFilterKernel::AssertFilter(ArrayFromJSON(value_type(), values), + ArrayFromJSON(boolean(), filter), + ArrayFromJSON(value_type(), expected)); + } + + void AssertFilterDictionary(const std::string& dictionary_values, + const std::string& dictionary_filter, + const std::string& filter, + const std::string& expected_filter) { + auto dict = ArrayFromJSON(value_type(), dictionary_values); + auto type = dictionary(int8(), value_type()); + ASSERT_OK_AND_ASSIGN(auto 
values, + DictionaryArray::FromArrays( + type, ArrayFromJSON(int8(), dictionary_filter), dict)); + ASSERT_OK_AND_ASSIGN( + auto expected, + DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_filter), dict)); + auto take_filter = ArrayFromJSON(boolean(), filter); + TestFilterKernel::AssertFilter(values, take_filter, expected); + } +}; + +TYPED_TEST_SUITE(TestFilterKernelWithString, BaseBinaryArrowTypes); + +TYPED_TEST(TestFilterKernelWithString, FilterString) { + this->AssertFilter(R"(["a", "b", "c"])", "[0, 1, 0]", R"(["b"])"); + this->AssertFilter(R"([null, "b", "c"])", "[0, 1, 0]", R"(["b"])"); + this->AssertFilter(R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b"])"); +} + +TYPED_TEST(TestFilterKernelWithString, FilterDictionary) { + auto dict = R"(["a", "b", "c", "d", "e"])"; + this->AssertFilterDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", "[4]"); + this->AssertFilterDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", "[4]"); + this->AssertFilterDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4]"); +} + +class TestFilterKernelWithList : public TestFilterKernel { + public: +}; + +TEST_F(TestFilterKernelWithList, FilterListInt32) { + std::string list_json = "[[], [1,2], null, [3]]"; + this->AssertFilter(list(int32()), list_json, "[0, 0, 0, 0]", "[]"); + this->AssertFilter(list(int32()), list_json, "[0, 1, 1, null]", "[[1,2], null, null]"); + this->AssertFilter(list(int32()), list_json, "[0, 0, 1, null]", "[null, null]"); + this->AssertFilter(list(int32()), list_json, "[1, 0, 0, 1]", "[[], [3]]"); + this->AssertFilter(list(int32()), list_json, "[1, 1, 1, 1]", list_json); + this->AssertFilter(list(int32()), list_json, "[0, 1, 0, 1]", "[[1,2], [3]]"); +} + +TEST_F(TestFilterKernelWithList, FilterListListInt32) { + std::string list_json = R"([ + [], + [[1], [2, null, 2], []], + null, + [[3, null], null] + ])"; + auto type = list(list(int32())); + this->AssertFilter(type, list_json, "[0, 0, 0, 0]", "[]"); + this->AssertFilter(type, list_json, "[0, 1, 1, 
null]", R"([ + [[1], [2, null, 2], []], + null, + null + ])"); + this->AssertFilter(type, list_json, "[0, 0, 1, null]", "[null, null]"); + this->AssertFilter(type, list_json, "[1, 0, 0, 1]", R"([ + [], + [[3, null], null] + ])"); + this->AssertFilter(type, list_json, "[1, 1, 1, 1]", list_json); + this->AssertFilter(type, list_json, "[0, 1, 0, 1]", R"([ + [[1], [2, null, 2], []], + [[3, null], null] + ])"); +} + +/* Fixture for large_list filter tests. */ class TestFilterKernelWithLargeList : public TestFilterKernel {}; + +/* Two mask cases (all-false, and mixed with a null slot) on a large_list of int32. */ TEST_F(TestFilterKernelWithLargeList, FilterListInt32) { + std::string list_json = "[[], [1,2], null, [3]]"; + this->AssertFilter(large_list(int32()), list_json, "[0, 0, 0, 0]", "[]"); + this->AssertFilter(large_list(int32()), list_json, "[0, 1, 1, null]", + "[[1,2], null, null]"); +} + +/* Fixture for fixed_size_list filter tests; checks FSL results against the List implementation. */ class TestFilterKernelWithFixedSizeList : public TestFilterKernel { + protected: + /* Masks of length 5 (all-false, all-true, mixed, with nulls) applied by AssertFilterOnNestedLists. */ std::vector> five_length_filters_ = { + ArrayFromJSON(boolean(), "[false, false, false, false, false]"), + ArrayFromJSON(boolean(), "[true, true, true, true, true]"), + ArrayFromJSON(boolean(), "[false, true, true, false, true]"), + ArrayFromJSON(boolean(), "[null, true, null, false, true]"), + }; + + /* Builds a nested FixedSizeList array and an equivalent List array of kLength rows, filters the List one with each mask, casts that result to the FSL type and asserts the FSL filter matches it. The trace label names this helper so failures are attributed correctly. */ void AssertFilterOnNestedLists(const std::shared_ptr& inner_type, + const std::vector& list_sizes) { + using NLG = ::arrow::util::internal::NestedListGenerator; + constexpr int64_t kLength = 5; + // Create two equivalent lists: one as a FixedSizeList and another as a List. + ASSERT_OK_AND_ASSIGN(auto fsl_list, + NLG::NestedFSLArray(inner_type, list_sizes, kLength)); + ASSERT_OK_AND_ASSIGN(auto list, + NLG::NestedListArray(inner_type, list_sizes, kLength)); + + ARROW_SCOPED_TRACE("AssertFilterOnNestedLists of type `", *fsl_list->type(), "`"); + + for (auto& filter : five_length_filters_) { + // Use the Filter on ListType as the reference implementation. 
+ ASSERT_OK_AND_ASSIGN(auto expected_list, + Filter(*list, *filter, /*options=*/emit_null_)); + ASSERT_OK_AND_ASSIGN(auto expected_fsl, Cast(expected_list, fsl_list->type())); + auto expected_fsl_array = expected_fsl.make_array(); + this->AssertFilter(fsl_list, filter, expected_fsl_array); + } + } +}; + +/* Filters a fixed_size_list of int32 (width 3) whose first row is null and whose rows contain null elements. */ TEST_F(TestFilterKernelWithFixedSizeList, FilterFixedSizeListInt32) { + std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]"; + this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 0, 0, 0]", "[]"); + this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 1, 1, null]", + "[[1, null, 3], [4, 5, 6], null]"); + this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 0, 1, null]", + "[[4, 5, 6], null]"); + this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[1, 1, 1, 1]", list_json); + this->AssertFilter(fixed_size_list(int32(), 3), list_json, "[0, 1, 0, 1]", + "[[1, null, 3], [7, 8, null]]"); +} + +/* Same mask cases on a fixed_size_list of utf8 (width 3), i.e. variable-width child values including empty strings. */ TEST_F(TestFilterKernelWithFixedSizeList, FilterFixedSizeListVarWidth) { + std::string list_json = + R"([["zero", "one", ""], ["two", "", "three"], ["four", "five", "six"], ["seven", "eight", ""]])"; + this->AssertFilter(fixed_size_list(utf8(), 3), list_json, "[0, 0, 0, 0]", "[]"); + this->AssertFilter(fixed_size_list(utf8(), 3), list_json, "[0, 1, 1, null]", + R"([["two", "", "three"], ["four", "five", "six"], null])"); + this->AssertFilter(fixed_size_list(utf8(), 3), list_json, "[0, 0, 1, null]", + R"([["four", "five", "six"], null])"); + this->AssertFilter(fixed_size_list(utf8(), 3), list_json, "[1, 1, 1, 1]", list_json); + this->AssertFilter(fixed_size_list(utf8(), 3), list_json, "[0, 1, 0, 1]", + R"([["two", "", "three"], ["seven", "eight", ""]])"); +} + +/* Runs AssertFilterOnNestedLists for every configuration visited by NestedListGenerator over int16, int32 and int64 value types. */ TEST_F(TestFilterKernelWithFixedSizeList, FilterFixedSizeListModuloNesting) { + using NLG = ::arrow::util::internal::NestedListGenerator; + const std::vector> value_types = { + int16(), + int32(), + int64(), + }; + NLG::VisitAllNestedListConfigurations( + 
value_types, [this](const std::shared_ptr& inner_type, + const std::vector& list_sizes) { + this->AssertFilterOnNestedLists(inner_type, list_sizes); + }); +} + +/* Fixture for map filter tests. */ class TestFilterKernelWithMap : public TestFilterKernel {}; + +/* Filters a map of utf8 to int32 that contains a null map, an empty map and a null item value. */ TEST_F(TestFilterKernelWithMap, FilterMapStringToInt32) { + std::string map_json = R"([ + [["joe", 0], ["mark", null]], + null, + [["cap", 8]], + [] + ])"; + this->AssertFilter(map(utf8(), int32()), map_json, "[0, 0, 0, 0]", "[]"); + this->AssertFilter(map(utf8(), int32()), map_json, "[0, 1, 1, null]", R"([ + null, + [["cap", 8]], + null + ])"); + this->AssertFilter(map(utf8(), int32()), map_json, "[1, 1, 1, 1]", map_json); + this->AssertFilter(map(utf8(), int32()), map_json, "[0, 1, 0, 1]", "[null, []]"); +} + +/* Fixture for struct filter tests. */ class TestFilterKernelWithStruct : public TestFilterKernel {}; + +/* Filters a struct array with fields a (int32) and b (utf8) whose first row is null. */ TEST_F(TestFilterKernelWithStruct, FilterStruct) { + auto struct_type = struct_({field("a", int32()), field("b", utf8())}); + auto struct_json = R"([ + null, + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ])"; + this->AssertFilter(struct_type, struct_json, "[0, 0, 0, 0]", "[]"); + this->AssertFilter(struct_type, struct_json, "[0, 1, 1, null]", R"([ + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + null + ])"); + this->AssertFilter(struct_type, struct_json, "[1, 1, 1, 1]", struct_json); + this->AssertFilter(struct_type, struct_json, "[1, 0, 1, 0]", R"([ + null, + {"a": 2, "b": "hello"} + ])"); +} + +/* Fixture for union filter tests. */ class TestFilterKernelWithUnion : public TestFilterKernel {}; + +/* Runs the same filter cases against a dense and a sparse union with type codes 2 (int32) and 5 (utf8). */ TEST_F(TestFilterKernelWithUnion, FilterUnion) { + for (const auto& union_type : + {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}), + sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) { + auto union_json = R"([ + [2, null], + [2, 222], + [5, "hello"], + [5, "eh"], + [2, null], + [2, 111], + [5, null] + ])"; + this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0, 0]", "[]"); + this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1, 1]", R"([ + 
[2, 222], + [5, "hello"], + [2, null], + [2, 111], + [5, null] + ])"); + this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0, 0]", R"([ + [2, null], + [5, "hello"], + [2, null] + ])"); + this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", union_json); + } +} + +/* Fixture that filters whole record batches built from JSON. */ class TestFilterKernelWithRecordBatch : public TestFilterKernel { + public: + /* Filters the batch via DoFilter, validates the output and compares it with the batch parsed from expected_batch. */ void AssertFilter(const std::shared_ptr& schm, const std::string& batch_json, + const std::string& selection, FilterOptions options, + const std::string& expected_batch) { + std::shared_ptr actual; + + ASSERT_OK(this->DoFilter(schm, batch_json, selection, options, &actual)); + ValidateOutput(actual); + ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *actual); + } + + /* Parses batch_json and a boolean selection, calls Filter with the given options and stores the resulting record batch in *out; a Filter error is returned as the Status. */ Status DoFilter(const std::shared_ptr& schm, const std::string& batch_json, + const std::string& selection, FilterOptions options, + std::shared_ptr* out) { + auto batch = RecordBatchFromJSON(schm, batch_json); + ARROW_ASSIGN_OR_RAISE(Datum out_datum, + Filter(batch, ArrayFromJSON(boolean(), selection), options)); + *out = out_datum.record_batch(); + return Status::OK(); + } +}; + +/* Null-free masks are checked under both emit_null_ and drop_; the mask containing a null is then checked separately for each option, since the expected rows differ. */ TEST_F(TestFilterKernelWithRecordBatch, FilterRecordBatch) { + std::vector> fields = {field("a", int32()), field("b", utf8())}; + auto schm = schema(fields); + + auto batch_json = R"([ + {"a": null, "b": "yo"}, + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ])"; + for (auto options : {this->emit_null_, this->drop_}) { + this->AssertFilter(schm, batch_json, "[0, 0, 0, 0]", options, "[]"); + this->AssertFilter(schm, batch_json, "[1, 1, 1, 1]", options, batch_json); + this->AssertFilter(schm, batch_json, "[1, 0, 1, 0]", options, R"([ + {"a": null, "b": "yo"}, + {"a": 2, "b": "hello"} + ])"); + } + + this->AssertFilter(schm, batch_json, "[0, 1, 1, null]", this->drop_, R"([ + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"} + ])"); + + this->AssertFilter(schm, batch_json, "[0, 1, 1, null]", this->emit_null_, R"([ + {"a": 1, "b": ""}, + 
{"a": 2, "b": "hello"}, + {"a": null, "b": null} + ])"); +} + +/* Fixture that filters chunked arrays with either a plain array mask or a chunked mask. */ class TestFilterKernelWithChunkedArray : public TestFilterKernel { + public: + /* Filters chunked values with a single-array mask and compares against the expected chunks. */ void AssertFilter(const std::shared_ptr& type, + const std::vector& values, const std::string& filter, + const std::vector& expected) { + std::shared_ptr actual; + ASSERT_OK(this->FilterWithArray(type, values, filter, &actual)); + ValidateOutput(actual); + AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual); + } + + /* Filters chunked values with a chunked mask and compares against the expected chunks. */ void AssertChunkedFilter(const std::shared_ptr& type, + const std::vector& values, + const std::vector& filter, + const std::vector& expected) { + std::shared_ptr actual; + ASSERT_OK(this->FilterWithChunkedArray(type, values, filter, &actual)); + ValidateOutput(actual); + AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual); + } + + /* Calls Filter (no explicit options) on chunked values with an array mask; stores the chunked result in *out or returns the error Status. */ Status FilterWithArray(const std::shared_ptr& type, + const std::vector& values, + const std::string& filter, std::shared_ptr* out) { + ARROW_ASSIGN_OR_RAISE(Datum out_datum, Filter(ChunkedArrayFromJSON(type, values), + ArrayFromJSON(boolean(), filter))); + *out = out_datum.chunked_array(); + return Status::OK(); + } + + /* Same as FilterWithArray but the mask is itself a chunked boolean array. */ Status FilterWithChunkedArray(const std::shared_ptr& type, + const std::vector& values, + const std::vector& filter, + std::shared_ptr* out) { + ARROW_ASSIGN_OR_RAISE(Datum out_datum, + Filter(ChunkedArrayFromJSON(type, values), + ChunkedArrayFromJSON(boolean(), filter))); + *out = out_datum.chunked_array(); + return Status::OK(); + } +}; + +/* Covers empty input, masks chunked differently from the values, and masks longer than the values, which must raise Invalid. */ TEST_F(TestFilterKernelWithChunkedArray, FilterChunkedArray) { + this->AssertFilter(int8(), {"[]"}, "[]", {}); + this->AssertChunkedFilter(int8(), {"[]"}, {"[]"}, {}); + + this->AssertFilter(int8(), {"[7]", "[8, 9]"}, "[0, 1, 0]", {"[8]"}); + this->AssertChunkedFilter(int8(), {"[7]", "[8, 9]"}, {"[0]", "[1, 0]"}, {"[8]"}); + this->AssertChunkedFilter(int8(), {"[7]", "[8, 9]"}, {"[0, 1]", "[0]"}, {"[8]"}); + + std::shared_ptr arr; + ASSERT_RAISES( + Invalid, this->FilterWithArray(int8(), {"[7]", "[8, 
9]"}, "[0, 1, 0, 1, 1]", &arr)); + ASSERT_RAISES(Invalid, this->FilterWithChunkedArray(int8(), {"[7]", "[8, 9]"}, + {"[0, 1, 0]", "[1, 1]"}, &arr)); +} + +/* Fixture that filters tables with either a plain array mask or a chunked mask. */ class TestFilterKernelWithTable : public TestFilterKernel { + public: + /* Filters the table with a single-array mask and requires exact table equality with the expected table. */ void AssertFilter(const std::shared_ptr& schm, + const std::vector& table_json, const std::string& filter, + FilterOptions options, + const std::vector& expected_table) { + std::shared_ptr
actual; + + ASSERT_OK(this->FilterWithArray(schm, table_json, filter, options, &actual)); + ValidateOutput(actual); + ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual); + } + + /* Filters the table with a chunked mask; the comparison passes same_chunk_layout=false, so chunk boundaries of the result are not compared. */ void AssertChunkedFilter(const std::shared_ptr& schm, + const std::vector& table_json, + const std::vector& filter, FilterOptions options, + const std::vector& expected_table) { + std::shared_ptr
actual; + + ASSERT_OK(this->FilterWithChunkedArray(schm, table_json, filter, options, &actual)); + ValidateOutput(actual); + AssertTablesEqual(*TableFromJSON(schm, expected_table), *actual, + /*same_chunk_layout=*/false); + } + + /* Calls Filter on a table with an array mask and the given options; stores the resulting table in *out or returns the error Status. */ Status FilterWithArray(const std::shared_ptr& schm, + const std::vector& values, + const std::string& filter, FilterOptions options, + std::shared_ptr
* out) { + ARROW_ASSIGN_OR_RAISE( + Datum out_datum, + Filter(TableFromJSON(schm, values), ArrayFromJSON(boolean(), filter), options)); + *out = out_datum.table(); + return Status::OK(); + } + + /* Same as FilterWithArray but the mask is a chunked boolean array. */ Status FilterWithChunkedArray(const std::shared_ptr& schm, + const std::vector& values, + const std::vector& filter, + FilterOptions options, std::shared_ptr
* out) { + ARROW_ASSIGN_OR_RAISE(Datum out_datum, + Filter(TableFromJSON(schm, values), + ChunkedArrayFromJSON(boolean(), filter), options)); + *out = out_datum.table(); + return Status::OK(); + } +}; + +/* Null-free masks are checked under both options; the mask with a trailing null is checked separately for emit_null_ and drop_. */ TEST_F(TestFilterKernelWithTable, FilterTable) { + std::vector> fields = {field("a", int32()), field("b", utf8())}; + auto schm = schema(fields); + + std::vector table_json = {R"([ + {"a": null, "b": "yo"}, + {"a": 1, "b": ""} + ])", + R"([ + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ])"}; + for (auto options : {this->emit_null_, this->drop_}) { + this->AssertFilter(schm, table_json, "[0, 0, 0, 0]", options, {}); + this->AssertChunkedFilter(schm, table_json, {"[0]", "[0, 0, 0]"}, options, {}); + this->AssertFilter(schm, table_json, "[1, 1, 1, 1]", options, table_json); + this->AssertChunkedFilter(schm, table_json, {"[1]", "[1, 1, 1]"}, options, + table_json); + } + + std::vector expected_emit_null = {R"([ + {"a": 1, "b": ""} + ])", + R"([ + {"a": 2, "b": "hello"}, + {"a": null, "b": null} + ])"}; + this->AssertFilter(schm, table_json, "[0, 1, 1, null]", this->emit_null_, + expected_emit_null); + this->AssertChunkedFilter(schm, table_json, {"[0, 1, 1]", "[null]"}, this->emit_null_, + expected_emit_null); + + std::vector expected_drop = {R"([{"a": 1, "b": ""}])", + R"([{"a": 2, "b": "hello"}])"}; + this->AssertFilter(schm, table_json, "[0, 1, 1, null]", this->drop_, expected_drop); + this->AssertChunkedFilter(schm, table_json, {"[0, 1, 1]", "[null]"}, this->drop_, + expected_drop); +} + +/* Calling the "filter" function with an empty ExecBatch must raise Invalid. */ TEST(TestFilterMetaFunction, ArityChecking) { + ASSERT_RAISES(Invalid, CallFunction("filter", ExecBatch({}, 0))); +} + +// ---------------------------------------------------------------------- +// Take tests + +/* Runs Take(values, indices), validates the output and asserts it equals `expected`. */ void AssertTakeArrays(const std::shared_ptr& values, + const std::shared_ptr& indices, + const std::shared_ptr& expected) { + ASSERT_OK_AND_ASSIGN(std::shared_ptr actual, Take(*values, *indices)); + ValidateOutput(actual); + AssertArraysEqual(*expected, *actual, 
/*verbose=*/true); +} + +/* Parses values and indices from JSON, runs Take and stores the result in *out; returns the Take error Status on failure. */ Status TakeJSON(const std::shared_ptr& type, const std::string& values, + const std::shared_ptr& index_type, const std::string& indices, + std::shared_ptr* out) { + return Take(*ArrayFromJSON(type, values), *ArrayFromJSON(index_type, indices)) + .Value(out); +} + +/* Checks Take three ways: on the arrays as given, on values re-created as a slice at offset 2 (padded with nulls on both sides), and on indices re-created as a slice at offset 3 (padded with zeros). All three must equal `expected`. */ void DoCheckTake(const std::shared_ptr& values, + const std::shared_ptr& indices, + const std::shared_ptr& expected) { + AssertTakeArrays(values, indices, expected); + + // Check sliced values + ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(values->type(), 2)); + ASSERT_OK_AND_ASSIGN(auto values_sliced, + Concatenate({values_filler, values, values_filler})); + values_sliced = values_sliced->Slice(2, values->length()); + AssertTakeArrays(values_sliced, indices, expected); + + // Check sliced indices + ASSERT_OK_AND_ASSIGN(auto zero, MakeScalar(indices->type(), int8_t{0})); + ASSERT_OK_AND_ASSIGN(auto indices_filler, MakeArrayFromScalar(*zero, 3)); + ASSERT_OK_AND_ASSIGN(auto indices_sliced, + Concatenate({indices_filler, indices, indices_filler})); + indices_sliced = indices_sliced->Slice(3, indices->length()); + AssertTakeArrays(values, indices_sliced, expected); +} + +/* JSON front end for DoCheckTake; every case is run once with int8 indices and once with uint32 indices. */ void CheckTake(const std::shared_ptr& type, const std::string& values_json, + const std::string& indices_json, const std::string& expected_json) { + auto values = ArrayFromJSON(type, values_json); + auto expected = ArrayFromJSON(type, expected_json); + for (auto index_type : {int8(), uint32()}) { + auto indices = ArrayFromJSON(index_type, indices_json); + DoCheckTake(values, indices, expected); + } +} + +/* CheckTake specialised to the null type. */ void AssertTakeNull(const std::string& values, const std::string& indices, + const std::string& expected) { + CheckTake(null(), values, indices, expected); +} + +/* CheckTake specialised to the boolean type. */ void AssertTakeBoolean(const std::string& values, const std::string& indices, + const std::string& expected) { + CheckTake(boolean(), values, indices, expected); +} + +template +void ValidateTakeImpl(const std::shared_ptr& values, + const 
std::shared_ptr& indices, + const std::shared_ptr& result) { + using ValuesArrayType = typename TypeTraits::ArrayType; + using IndexArrayType = typename TypeTraits::ArrayType; + auto typed_values = checked_pointer_cast(values); + auto typed_result = checked_pointer_cast(result); + auto typed_indices = checked_pointer_cast(indices); + /* Per output slot: a null index or a null source value must give a null result; otherwise the result must equal the source value at that index. */ for (int64_t i = 0; i < indices->length(); ++i) { + if (typed_indices->IsNull(i) || typed_values->IsNull(typed_indices->Value(i))) { + ASSERT_TRUE(result->IsNull(i)) << i; + // The value of a null element is undefined, but right + // out of the Take kernel it is expected to be 0. + if constexpr (is_primitive(ValuesType::type_id)) { + if constexpr (ValuesType::type_id == Type::BOOL) { + ASSERT_EQ(typed_result->Value(i), false); + } else { + ASSERT_EQ(typed_result->Value(i), 0); + } + } + } else { + ASSERT_FALSE(result->IsNull(i)) << i; + ASSERT_EQ(typed_result->GetView(i), typed_values->GetView(typed_indices->Value(i))) + << i; + } + } +} + +/* Runs Take, checks the output length equals the index count, then dispatches to ValidateTakeImpl on the runtime index type (signed and unsigned 8/16/32/64-bit); any other index type fails the test. */ template +void ValidateTake(const std::shared_ptr& values, + const std::shared_ptr& indices) { + ASSERT_OK_AND_ASSIGN(Datum out, Take(values, indices)); + auto taken = out.make_array(); + ValidateOutput(taken); + ASSERT_EQ(indices->length(), taken->length()); + switch (indices->type_id()) { + case Type::INT8: + ValidateTakeImpl(values, indices, taken); + break; + case Type::INT16: + ValidateTakeImpl(values, indices, taken); + break; + case Type::INT32: + ValidateTakeImpl(values, indices, taken); + break; + case Type::INT64: + ValidateTakeImpl(values, indices, taken); + break; + case Type::UINT8: + ValidateTakeImpl(values, indices, taken); + break; + case Type::UINT16: + ValidateTakeImpl(values, indices, taken); + break; + case Type::UINT32: + ValidateTakeImpl(values, indices, taken); + break; + case Type::UINT64: + ValidateTakeImpl(values, indices, taken); + break; + default: + FAIL() << "Invalid index type"; + break; + } +} + +/* Returns values_length - 1, clamped to the maximum of the target numeric type. */ template +T GetMaxIndex(int64_t values_length) { + int64_t max_index = 
values_length - 1; + if (max_index > static_cast(std::numeric_limits::max())) { + max_index = std::numeric_limits::max(); + } + return static_cast(max_index); +} + +/* uint64_t specialization: returns values_length - 1 without clamping. */ template <> +uint64_t GetMaxIndex(int64_t values_length) { + return static_cast(values_length - 1); +} + +/* Base fixture for the Take tests; provides shared checks used by the typed fixtures. */ class TestTakeKernel : public ::testing::Test { + public: + /* Requires both inputs to have zero nulls, runs Take once as-is, then again on copies whose buffers[0] is dropped and whose null_count is set to kUnknownNullCount; both results must be equal. */ void TestNoValidityBitmapButUnknownNullCount(const std::shared_ptr& values, + const std::shared_ptr& indices) { + ASSERT_EQ(values->null_count(), 0); + ASSERT_EQ(indices->null_count(), 0); + auto expected = (*Take(values, indices)).make_array(); + + auto new_values = MakeArray(values->data()->Copy()); + new_values->data()->buffers[0].reset(); + new_values->data()->null_count = kUnknownNullCount; + auto new_indices = MakeArray(indices->data()->Copy()); + new_indices->data()->buffers[0].reset(); + new_indices->data()->null_count = kUnknownNullCount; + auto result = (*Take(new_values, new_indices)).make_array(); + + AssertArraysEqual(*expected, *result); + } + + /* JSON overload of the check above; indices are always parsed as int16. */ void TestNoValidityBitmapButUnknownNullCount(const std::shared_ptr& type, + const std::string& values, + const std::string& indices) { + TestNoValidityBitmapButUnknownNullCount(ArrayFromJSON(type, values), + ArrayFromJSON(int16(), indices)); + } + + /* Shared cases for types whose JSON values are plain integers: empty indices, repeated indices, null values, null indices, and out-of-range or negative indices raising IndexError. */ void TestNumericBasics(const std::shared_ptr& type) { + ARROW_SCOPED_TRACE("type = ", *type); + CheckTake(type, "[7, 8, 9]", "[]", "[]"); + CheckTake(type, "[7, 8, 9]", "[0, 1, 0]", "[7, 8, 7]"); + CheckTake(type, "[null, 8, 9]", "[0, 1, 0]", "[null, 8, null]"); + CheckTake(type, "[7, 8, 9]", "[null, 1, 0]", "[null, 8, 7]"); + CheckTake(type, "[null, 8, 9]", "[]", "[]"); + CheckTake(type, "[7, 8, 9]", "[0, 0, 0, 0, 0, 0, 2]", "[7, 7, 7, 7, 7, 7, 9]"); + + std::shared_ptr arr; + ASSERT_RAISES(IndexError, TakeJSON(type, "[7, 8, 9]", int8(), "[0, 9, 0]", &arr)); + ASSERT_RAISES(IndexError, TakeJSON(type, "[7, 8, 9]", int8(), "[0, -1, 0]", &arr)); + } +}; + +/* Templated alias of TestTakeKernel used as the base of the typed fixtures. */ template +class TestTakeKernelTyped : public TestTakeKernel {}; +
+TEST_F(TestTakeKernel, TakeNull) { + AssertTakeNull("[null, null, null]", "[0, 1, 0]", "[null, null, null]"); + AssertTakeNull("[null, null, null]", "[0, 2]", "[null, null]"); + + std::shared_ptr arr; + ASSERT_RAISES(IndexError, + TakeJSON(null(), "[null, null, null]", int8(), "[0, 9, 0]", &arr)); + ASSERT_RAISES(IndexError, + TakeJSON(boolean(), "[null, null, null]", int8(), "[0, -1, 0]", &arr)); +} + +TEST_F(TestTakeKernel, InvalidIndexType) { + std::shared_ptr arr; + ASSERT_RAISES(NotImplemented, TakeJSON(null(), "[null, null, null]", float32(), + "[0.0, 1.0, 0.1]", &arr)); +} + +TEST_F(TestTakeKernel, TakeCCEmptyIndices) { + Datum dat = ChunkedArrayFromJSON(int8(), {"[]"}); + Datum idx = ChunkedArrayFromJSON(int32(), {}); + ASSERT_OK_AND_ASSIGN(auto out, Take(dat, idx)); + ValidateOutput(out); + AssertDatumsEqual(ChunkedArrayFromJSON(int8(), {"[]"}), out, true); +} + +TEST_F(TestTakeKernel, TakeACEmptyIndices) { + Datum dat = ArrayFromJSON(int8(), {"[]"}); + Datum idx = ChunkedArrayFromJSON(int32(), {}); + ASSERT_OK_AND_ASSIGN(auto out, Take(dat, idx)); + ValidateOutput(out); + AssertDatumsEqual(ChunkedArrayFromJSON(int8(), {"[]"}), out, true); +} + +TEST_F(TestTakeKernel, DefaultOptions) { + auto indices = ArrayFromJSON(int8(), "[null, 2, 0, 3]"); + auto values = ArrayFromJSON(int8(), "[7, 8, 9, null]"); + ASSERT_OK_AND_ASSIGN(auto no_options_provided, CallFunction("take", {values, indices})); + + auto default_options = TakeOptions::Defaults(); + ASSERT_OK_AND_ASSIGN(auto explicit_defaults, + CallFunction("take", {values, indices}, &default_options)); + + AssertDatumsEqual(explicit_defaults, no_options_provided); +} + +TEST_F(TestTakeKernel, TakeBoolean) { + AssertTakeBoolean("[7, 8, 9]", "[]", "[]"); + AssertTakeBoolean("[true, false, true]", "[0, 1, 0]", "[true, false, true]"); + AssertTakeBoolean("[null, false, true]", "[0, 1, 0]", "[null, false, null]"); + AssertTakeBoolean("[true, false, true]", "[null, 1, 0]", "[null, false, true]"); + + 
TestNoValidityBitmapButUnknownNullCount(boolean(), "[true, false, true]", "[1, 0, 0]"); + + std::shared_ptr arr; + ASSERT_RAISES(IndexError, + TakeJSON(boolean(), "[true, false, true]", int8(), "[0, 9, 0]", &arr)); + ASSERT_RAISES(IndexError, + TakeJSON(boolean(), "[true, false, true]", int8(), "[0, -1, 0]", &arr)); +} + +/* Runs the shared numeric cases on time32, time64, timestamp, duration and date32; date64 gets its own case with values that are multiples of 86400000. */ TEST_F(TestTakeKernel, Temporal) { + this->TestNumericBasics(time32(TimeUnit::MILLI)); + this->TestNumericBasics(time64(TimeUnit::MICRO)); + this->TestNumericBasics(timestamp(TimeUnit::NANO, "Europe/Paris")); + this->TestNumericBasics(duration(TimeUnit::SECOND)); + this->TestNumericBasics(date32()); + CheckTake(date64(), "[0, 86400000, null]", "[null, 1, 1, 0]", + "[null, 86400000, 86400000, 0]"); +} + +/* Runs the shared numeric cases on every type returned by DurationTypes(). */ TEST_F(TestTakeKernel, Duration) { + for (auto type : DurationTypes()) { + this->TestNumericBasics(type); + } +} + +/* month_interval uses the shared numeric cases; day_time and month_day_nano intervals use dedicated cases because their JSON values are arrays. */ TEST_F(TestTakeKernel, Interval) { + this->TestNumericBasics(month_interval()); + + auto type = day_time_interval(); + CheckTake(type, "[[1, -600], [2, 3000], null]", "[0, null, 2, 1]", + "[[1, -600], null, null, [2, 3000]]"); + type = month_day_nano_interval(); + CheckTake(type, "[[1, -2, 34567890123456789], [2, 3, -34567890123456789], null]", + "[0, null, 2, 1]", + "[[1, -2, 34567890123456789], null, null, [2, 3, -34567890123456789]]"); +} + +/* Typed fixture for numeric Take tests; type_singleton() supplies the Arrow type under test. */ template +class TestTakeKernelWithNumeric : public TestTakeKernelTyped { + protected: + void AssertTake(const std::string& values, const std::string& indices, + const std::string& expected) { + CheckTake(type_singleton(), values, indices, expected); + } + + std::shared_ptr type_singleton() { + return TypeTraits::type_singleton(); + } +}; + +TYPED_TEST_SUITE(TestTakeKernelWithNumeric, NumericArrowTypes); +/* Runs the shared numeric cases for each type in NumericArrowTypes. */ TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) { + this->TestNumericBasics(this->type_singleton()); +} + +/* Typed fixture for string/binary Take tests; value_type() supplies the Arrow type under test. */ template +class TestTakeKernelWithString : public TestTakeKernelTyped { + public: + std::shared_ptr value_type() { + return TypeTraits::type_singleton(); + } + + void AssertTake(const 
std::string& values, const std::string& indices, + const std::string& expected) { + CheckTake(value_type(), values, indices, expected); + } + + /* Builds a dictionary array (int8 indices over dictionary_values), takes `indices` from it and expects a dictionary array over the same dictionary with expected_indices. */ void AssertTakeDictionary(const std::string& dictionary_values, + const std::string& dictionary_indices, + const std::string& indices, + const std::string& expected_indices) { + auto dict = ArrayFromJSON(value_type(), dictionary_values); + auto type = dictionary(int8(), value_type()); + ASSERT_OK_AND_ASSIGN(auto values, + DictionaryArray::FromArrays( + type, ArrayFromJSON(int8(), dictionary_indices), dict)); + ASSERT_OK_AND_ASSIGN( + auto expected, + DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_indices), dict)); + auto take_indices = ArrayFromJSON(int8(), indices); + AssertTakeArrays(values, take_indices, expected); + } +}; + +TYPED_TEST_SUITE(TestTakeKernelWithString, BaseBinaryArrowTypes); + +/* Take on string values: repeated indices, null values, null indices, the unknown-null-count check, and out-of-range indices (int8 and int64) raising IndexError. */ TYPED_TEST(TestTakeKernelWithString, TakeString) { + this->AssertTake(R"(["a", "b", "c"])", "[0, 1, 0]", R"(["a", "b", "a"])"); + this->AssertTake(R"([null, "b", "c"])", "[0, 1, 0]", "[null, \"b\", null]"); + this->AssertTake(R"(["a", "b", "c"])", "[null, 1, 0]", R"([null, "b", "a"])"); + + this->TestNoValidityBitmapButUnknownNullCount(this->value_type(), R"(["a", "b", "c"])", + "[0, 1, 0]"); + + std::shared_ptr type = this->value_type(); + std::shared_ptr arr; + ASSERT_RAISES(IndexError, + TakeJSON(type, R"(["a", "b", "c"])", int8(), "[0, 9, 0]", &arr)); + ASSERT_RAISES(IndexError, TakeJSON(type, R"(["a", "b", null, "ddd", "ee"])", int64(), + "[2, 5]", &arr)); +} + +/* Take on dictionary arrays; expectations are given as dictionary indices. */ TYPED_TEST(TestTakeKernelWithString, TakeDictionary) { + auto dict = R"(["a", "b", "c", "d", "e"])"; + this->AssertTakeDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", "[3, 4, 3]"); + this->AssertTakeDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", "[null, 4, null]"); + this->AssertTakeDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", "[null, 4, 3]"); +} + +/* Fixture for Take on fixed_size_binary values of width 3. */ class TestTakeKernelFSB : public TestTakeKernelTyped { + public: + std::shared_ptr value_type() { return 
fixed_size_binary(3); } + + void AssertTake(const std::string& values, const std::string& indices, + const std::string& expected) { + CheckTake(value_type(), values, indices, expected); + } +}; + +/* Mirrors the string Take cases with 3-byte fixed-size binary values. */ TEST_F(TestTakeKernelFSB, TakeFixedSizeBinary) { + this->AssertTake(R"(["aaa", "bbb", "ccc"])", "[0, 1, 0]", R"(["aaa", "bbb", "aaa"])"); + this->AssertTake(R"([null, "bbb", "ccc"])", "[0, 1, 0]", "[null, \"bbb\", null]"); + this->AssertTake(R"(["aaa", "bbb", "ccc"])", "[null, 1, 0]", R"([null, "bbb", "aaa"])"); + + this->TestNoValidityBitmapButUnknownNullCount(this->value_type(), + R"(["aaa", "bbb", "ccc"])", "[0, 1, 0]"); + + std::shared_ptr type = this->value_type(); + std::shared_ptr arr; + ASSERT_RAISES(IndexError, + TakeJSON(type, R"(["aaa", "bbb", "ccc"])", int8(), "[0, 9, 0]", &arr)); + ASSERT_RAISES(IndexError, TakeJSON(type, R"(["aaa", "bbb", null, "ddd", "eee"])", + int64(), "[2, 5]", &arr)); +} + +/* Fixture for Take on list values. */ class TestTakeKernelWithList : public TestTakeKernelTyped {}; + +/* Take on a list of int32 containing an empty list and a null list: empty, reversed, null, repeated and identity index sets. */ TEST_F(TestTakeKernelWithList, TakeListInt32) { + std::string list_json = "[[], [1,2], null, [3]]"; + CheckTake(list(int32()), list_json, "[]", "[]"); + CheckTake(list(int32()), list_json, "[3, 2, 1]", "[[3], null, [1,2]]"); + CheckTake(list(int32()), list_json, "[null, 3, 0]", "[null, [3], []]"); + CheckTake(list(int32()), list_json, "[null, null]", "[null, null]"); + CheckTake(list(int32()), list_json, "[3, 0, 0, 3]", "[[3], [], [], [3]]"); + CheckTake(list(int32()), list_json, "[0, 1, 2, 3]", list_json); + CheckTake(list(int32()), list_json, "[0, 0, 0, 0, 0, 0, 1]", + "[[], [], [], [], [], [], [1, 2]]"); + + this->TestNoValidityBitmapButUnknownNullCount(list(int32()), "[[], [1,2], [3]]", + "[0, 1, 0]"); +} + +/* Same index sets on a list of list of int32 with nulls at both nesting levels. */ TEST_F(TestTakeKernelWithList, TakeListListInt32) { + std::string list_json = R"([ + [], + [[1], [2, null, 2], []], + null, + [[3, null], null] + ])"; + auto type = list(list(int32())); + CheckTake(type, list_json, "[]", "[]"); + CheckTake(type, list_json, "[3, 2, 1]", R"([ + [[3, null], null], 
+ null, + [[1], [2, null, 2], []] + ])"); + CheckTake(type, list_json, "[null, 3, 0]", R"([ + null, + [[3, null], null], + [] + ])"); + CheckTake(type, list_json, "[null, null]", "[null, null]"); + CheckTake(type, list_json, "[3, 0, 0, 3]", + "[[[3, null], null], [], [], [[3, null], null]]"); + CheckTake(type, list_json, "[0, 1, 2, 3]", list_json); + CheckTake(type, list_json, "[0, 0, 0, 0, 0, 0, 1]", + "[[], [], [], [], [], [], [[1], [2, null, 2], []]]"); + + this->TestNoValidityBitmapButUnknownNullCount( + type, "[[[1], [2, null, 2], []], [[3, null]]]", "[0, 1, 0]"); +} + +/* Fixture for Take on large_list values. */ class TestTakeKernelWithLargeList : public TestTakeKernelTyped {}; + +/* Two Take cases (empty indices, and indices with a null) on a large_list of int32. */ TEST_F(TestTakeKernelWithLargeList, TakeLargeListInt32) { + std::string list_json = "[[], [1,2], null, [3]]"; + CheckTake(large_list(int32()), list_json, "[]", "[]"); + CheckTake(large_list(int32()), list_json, "[null, 1, 2, 0]", "[null, [1,2], null, []]"); +} + +/* Fixture for Take on fixed_size_list values; checks FSL results against the List implementation. */ class TestTakeKernelWithFixedSizeList : public TestTakeKernelTyped { + protected: + /* Builds a nested FixedSizeList array and an equivalent List array of `length` rows, takes rows [1, 2, 4] from the List one, casts that to the FSL type and checks Take on the FSL array against it. NOTE(review): the fixed indices look like they require length >= 5 - confirm against callers. */ void CheckTakeOnNestedLists(const std::shared_ptr& inner_type, + const std::vector& list_sizes, int64_t length) { + using NLG = ::arrow::util::internal::NestedListGenerator; + // Create two equivalent lists: one as a FixedSizeList and another as a List. + ASSERT_OK_AND_ASSIGN(auto fsl_list, + NLG::NestedFSLArray(inner_type, list_sizes, length)); + ASSERT_OK_AND_ASSIGN(auto list, NLG::NestedListArray(inner_type, list_sizes, length)); + + ARROW_SCOPED_TRACE("CheckTakeOnNestedLists of type `", *fsl_list->type(), "`"); + + auto indices = ArrayFromJSON(int64(), "[1, 2, 4]"); + // Use the Take on ListType as the reference implementation. 
+ ASSERT_OK_AND_ASSIGN(auto expected_list, Take(*list, *indices)); + ASSERT_OK_AND_ASSIGN(auto expected_fsl, Cast(*expected_list, fsl_list->type())); + DoCheckTake(fsl_list, indices, expected_fsl); + } +}; + +/* Take on a fixed_size_list of int32 (width 3): first with a null row and null elements, then with a null-free array, then the unknown-null-count check. */ TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListInt32) { + std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]"; + CheckTake(fixed_size_list(int32(), 3), list_json, "[]", "[]"); + CheckTake(fixed_size_list(int32(), 3), list_json, "[3, 2, 1]", + "[[7, 8, null], [4, 5, 6], [1, null, 3]]"); + CheckTake(fixed_size_list(int32(), 3), list_json, "[null, 2, 0]", + "[null, [4, 5, 6], null]"); + CheckTake(fixed_size_list(int32(), 3), list_json, "[null, null]", "[null, null]"); + CheckTake(fixed_size_list(int32(), 3), list_json, "[3, 0, 0, 3]", + "[[7, 8, null], null, null, [7, 8, null]]"); + CheckTake(fixed_size_list(int32(), 3), list_json, "[0, 1, 2, 3]", list_json); + + // No nulls in inner list values trigger the use of FixedWidthTakeExec() in + // FSLTakeExec() + std::string no_nulls_list_json = "[[0, 0, 0], [1, 2, 3], [4, 5, 6], [7, 8, 9]]"; + CheckTake( + fixed_size_list(int32(), 3), no_nulls_list_json, "[2, 2, 2, 2, 2, 2, 1]", + "[[4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [4, 5, 6], [1, 2, 3]]"); + + this->TestNoValidityBitmapButUnknownNullCount(fixed_size_list(int32(), 3), + "[[1, null, 3], [4, 5, 6], [7, 8, null]]", + "[0, 1, 0]"); +} + +/* Same index sets on a fixed_size_list of utf8 (width 3), i.e. variable-width child values including empty strings. */ TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListVarWidth) { + std::string list_json = + R"([["zero", "one", ""], ["two", "", "three"], ["four", "five", "six"], ["seven", "eight", ""]])"; + CheckTake(fixed_size_list(utf8(), 3), list_json, "[]", "[]"); + CheckTake(fixed_size_list(utf8(), 3), list_json, "[3, 2, 1]", + R"([["seven", "eight", ""], ["four", "five", "six"], ["two", "", "three"]])"); + CheckTake(fixed_size_list(utf8(), 3), list_json, "[null, 2, 0]", + R"([null, ["four", "five", "six"], ["zero", "one", ""]])"); + CheckTake(fixed_size_list(utf8(), 3), list_json, R"([null, 
null])", "[null, null]"); + CheckTake( + fixed_size_list(utf8(), 3), list_json, "[3, 0, 0,3]", + R"([["seven", "eight", ""], ["zero", "one", ""], ["zero", "one", ""], ["seven", "eight", ""]])"); + CheckTake(fixed_size_list(utf8(), 3), list_json, "[0, 1, 2, 3]", list_json); + CheckTake(fixed_size_list(utf8(), 3), list_json, "[2, 2, 2, 2, 2, 2, 1]", + R"([ + ["four", "five", "six"], ["four", "five", "six"], + ["four", "five", "six"], ["four", "five", "six"], + ["four", "five", "six"], ["four", "five", "six"], + ["two", "", "three"] + ])"); +} + +/* Runs CheckTakeOnNestedLists with length 5 for every configuration visited by NestedListGenerator over int16, int32 and int64 value types. */ TEST_F(TestTakeKernelWithFixedSizeList, TakeFixedSizeListModuloNesting) { + using NLG = ::arrow::util::internal::NestedListGenerator; + const std::vector> value_types = { + int16(), + int32(), + int64(), + }; + NLG::VisitAllNestedListConfigurations( + value_types, [this](const std::shared_ptr& inner_type, + const std::vector& list_sizes) { + this->CheckTakeOnNestedLists(inner_type, list_sizes, /*length=*/5); + }); +} + +/* Fixture for Take on map values. */ class TestTakeKernelWithMap : public TestTakeKernelTyped {}; + +/* Take on a map of utf8 to int32 that contains a null map, an empty map and a null item value. */ TEST_F(TestTakeKernelWithMap, TakeMapStringToInt32) { + std::string map_json = R"([ + [["joe", 0], ["mark", null]], + null, + [["cap", 8]], + [] + ])"; + CheckTake(map(utf8(), int32()), map_json, "[]", "[]"); + CheckTake(map(utf8(), int32()), map_json, "[3, 1, 3, 1, 3]", + "[[], null, [], null, []]"); + CheckTake(map(utf8(), int32()), map_json, "[2, 1, null]", R"([ + [["cap", 8]], + null, + null + ])"); + CheckTake(map(utf8(), int32()), map_json, "[2, 1, 0]", R"([ + [["cap", 8]], + null, + [["joe", 0], ["mark", null]] + ])"); + CheckTake(map(utf8(), int32()), map_json, "[0, 1, 2, 3]", map_json); + CheckTake(map(utf8(), int32()), map_json, "[0, 0, 0, 0, 0, 0, 3]", R"([ + [["joe", 0], ["mark", null]], + [["joe", 0], ["mark", null]], + [["joe", 0], ["mark", null]], + [["joe", 0], ["mark", null]], + [["joe", 0], ["mark", null]], + [["joe", 0], ["mark", null]], + [] + ])"); +} + +/* Fixture for Take on struct values. */ class TestTakeKernelWithStruct : public TestTakeKernelTyped {}; +
// Overview of the definitions in the following run of collapsed patch lines, in order:
//
//  * TestTakeKernelWithStruct / TestTakeKernelWithUnion: Take() on struct arrays and on
//    dense and sparse unions, covering empty, repeated and identity index lists.
//  * TestPermutationsWithTake: helpers that treat an Int16Array as a permutation.
//    DoTakeN() starts from the identity and applies the permutation n times using
//    repeated squaring. Inverse() records for which exponents 1..length the power of the
//    permutation has a fixed point, multiplies together those exponents that do not
//    already divide the running product, and returns permutation^(product - 1); it
//    returns nullptr when that product would overflow. InvertPermutation stops at the
//    first length for which Inverse() returns nullptr.
//  * TestTakeKernelWithRecordBatch / ...WithChunkedArray / ...WithTable: Take() on
//    RecordBatch, ChunkedArray and Table datums. The record-batch helper runs every case
//    with both int8 and uint32 indices; the chunked-array test also expects IndexError
//    for out-of-range indices.
//  * TestTakeMetaFunction.ArityChecking: calling "take" with an empty ExecBatch is Invalid.
//  * FilterRandomTest / CheckTakeRandom / TakeRandomTest and the TEST(TestFilter, ...) /
//    TEST(TestTake, ...) cases: randomized checks over several null probabilities. The
//    filter is generated one element longer than the values so that sliced values and
//    sliced filters with different offsets can be compared; random indices are checked
//    with and without nulls, and again after slicing.
//  * AssertDropNullArrays / DropNullJSON / CheckDropNull: free helpers for DropNull tests.
//  * TestDropNullKernel and its subclasses (Typed, WithNumeric, WithString, FSB):
//    TestNoValidityBitmapButUnknownNullCount() drops the validity buffer of a null-free
//    array, sets null_count to kUnknownNullCount and expects the same DropNull result.
//
// NOTE(review): template argument lists (e.g. after std::shared_ptr, static_cast,
// std::vector, TakeRandomTest) appear to be missing throughout this copy, presumably
// stripped during extraction -- confirm against the original patch before relying on it.
+TEST_F(TestTakeKernelWithStruct, TakeStruct) { + auto struct_type = struct_({field("a", int32()), field("b", utf8())}); + auto struct_json = R"([ + null, + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ])"; + CheckTake(struct_type, struct_json, "[]", "[]"); + CheckTake(struct_type, struct_json, "[3, 1, 3, 1, 3]", R"([ + {"a": 4, "b": "eh"}, + {"a": 1, "b": ""}, + {"a": 4, "b": "eh"}, + {"a": 1, "b": ""}, + {"a": 4, "b": "eh"} + ])"); + CheckTake(struct_type, struct_json, "[3, 1, 0]", R"([ + {"a": 4, "b": "eh"}, + {"a": 1, "b": ""}, + null + ])"); + CheckTake(struct_type, struct_json, "[0, 1, 2, 3]", struct_json); + CheckTake(struct_type, struct_json, "[0, 2, 2, 2, 2, 2, 2]", R"([ + null, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"} + ])"); + + this->TestNoValidityBitmapButUnknownNullCount( + struct_type, R"([{"a": 1}, {"a": 2, "b": "hello"}])", "[0, 1, 0]"); +} + +class TestTakeKernelWithUnion : public TestTakeKernelTyped {}; + +TEST_F(TestTakeKernelWithUnion, TakeUnion) { + for (const auto& union_type : + {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}), + sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) { + auto union_json = R"([ + [2, 222], + [2, null], + [5, "hello"], + [5, "eh"], + [2, null], + [2, 111], + [5, null] + ])"; + CheckTake(union_type, union_json, "[]", "[]"); + CheckTake(union_type, union_json, "[3, 0, 3, 0, 3]", R"([ + [5, "eh"], + [2, 222], + [5, "eh"], + [2, 222], + [5, "eh"] + ])"); + CheckTake(union_type, union_json, "[4, 2, 0, 6]", R"([ + [2, null], + [5, "hello"], + [2, 222], + [5, null] + ])"); + CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json); + CheckTake(union_type, union_json, "[1, 2, 2, 2, 2, 2, 2]", R"([ + [2, null], + [5, "hello"], + [5, "hello"], + [5, "hello"], + [5, "hello"], + [5, "hello"], + [5, "hello"] + ])"); + CheckTake(union_type, 
union_json, "[0, null, 1, null, 2, 2, 2]", R"([ + [2, 222], + [2, null], + [2, null], + [2, null], + [5, "hello"], + [5, "hello"], + [5, "hello"] + ])"); + } +} + +class TestPermutationsWithTake : public ::testing::Test { + protected: + void DoTake(const Int16Array& values, const Int16Array& indices, + std::shared_ptr* out) { + ASSERT_OK_AND_ASSIGN(std::shared_ptr boxed_out, Take(values, indices)); + ValidateOutput(boxed_out); + *out = checked_pointer_cast(std::move(boxed_out)); + } + + std::shared_ptr DoTake(const Int16Array& values, + const Int16Array& indices) { + std::shared_ptr out; + DoTake(values, indices, &out); + return out; + } + + std::shared_ptr DoTakeN(uint64_t n, std::shared_ptr array) { + auto power_of_2 = array; + array = Identity(array->length()); + while (n != 0) { + if (n & 1) { + array = DoTake(*array, *power_of_2); + } + power_of_2 = DoTake(*power_of_2, *power_of_2); + n >>= 1; + } + return array; + } + + template + void Shuffle(const Int16Array& array, Rng& gen, std::shared_ptr* shuffled) { + auto byte_length = array.length() * sizeof(int16_t); + ASSERT_OK_AND_ASSIGN(auto data, array.values()->CopySlice(0, byte_length)); + auto mutable_data = reinterpret_cast(data->mutable_data()); + std::shuffle(mutable_data, mutable_data + array.length(), gen); + shuffled->reset(new Int16Array(array.length(), data)); + } + + template + std::shared_ptr Shuffle(const Int16Array& array, Rng& gen) { + std::shared_ptr out; + Shuffle(array, gen, &out); + return out; + } + + void Identity(int64_t length, std::shared_ptr* identity) { + Int16Builder identity_builder; + ASSERT_OK(identity_builder.Resize(length)); + for (int16_t i = 0; i < length; ++i) { + identity_builder.UnsafeAppend(i); + } + ASSERT_OK(identity_builder.Finish(identity)); + } + + std::shared_ptr Identity(int64_t length) { + std::shared_ptr out; + Identity(length, &out); + return out; + } + + std::shared_ptr Inverse(const std::shared_ptr& permutation) { + auto length = 
static_cast(permutation->length()); + + std::vector cycle_lengths(length + 1, false); + auto permutation_to_the_i = permutation; + for (int16_t cycle_length = 1; cycle_length <= length; ++cycle_length) { + cycle_lengths[cycle_length] = HasTrivialCycle(*permutation_to_the_i); + permutation_to_the_i = DoTake(*permutation, *permutation_to_the_i); + } + + uint64_t cycle_to_identity_length = 1; + for (int16_t cycle_length = length; cycle_length > 1; --cycle_length) { + if (!cycle_lengths[cycle_length]) { + continue; + } + if (cycle_to_identity_length % cycle_length == 0) { + continue; + } + if (cycle_to_identity_length > + std::numeric_limits::max() / cycle_length) { + // overflow, can't compute Inverse + return nullptr; + } + cycle_to_identity_length *= cycle_length; + } + + return DoTakeN(cycle_to_identity_length - 1, permutation); + } + + bool HasTrivialCycle(const Int16Array& permutation) { + for (int64_t i = 0; i < permutation.length(); ++i) { + if (permutation.Value(i) == static_cast(i)) { + return true; + } + } + return false; + } +}; + +TEST_F(TestPermutationsWithTake, InvertPermutation) { + for (auto seed : std::vector({0, kRandomSeed, kRandomSeed * 2 - 1})) { + std::default_random_engine gen(seed); + for (int16_t length = 0; length < 1 << 10; ++length) { + auto identity = Identity(length); + auto permutation = Shuffle(*identity, gen); + auto inverse = Inverse(permutation); + if (inverse == nullptr) { + break; + } + ASSERT_TRUE(DoTake(*inverse, *permutation)->Equals(identity)); + } + } +} + +class TestTakeKernelWithRecordBatch : public TestTakeKernelTyped { + public: + void AssertTake(const std::shared_ptr& schm, const std::string& batch_json, + const std::string& indices, const std::string& expected_batch) { + std::shared_ptr actual; + + for (auto index_type : {int8(), uint32()}) { + ASSERT_OK(TakeJSON(schm, batch_json, index_type, indices, &actual)); + ValidateOutput(actual); + ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *actual); + } + } 
+ + Status TakeJSON(const std::shared_ptr& schm, const std::string& batch_json, + const std::shared_ptr& index_type, const std::string& indices, + std::shared_ptr* out) { + auto batch = RecordBatchFromJSON(schm, batch_json); + ARROW_ASSIGN_OR_RAISE(Datum result, + Take(Datum(batch), Datum(ArrayFromJSON(index_type, indices)))); + *out = result.record_batch(); + return Status::OK(); + } +}; + +TEST_F(TestTakeKernelWithRecordBatch, TakeRecordBatch) { + std::vector> fields = {field("a", int32()), field("b", utf8())}; + auto schm = schema(fields); + + auto struct_json = R"([ + {"a": null, "b": "yo"}, + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ])"; + this->AssertTake(schm, struct_json, "[]", "[]"); + this->AssertTake(schm, struct_json, "[3, 1, 3, 1, 3]", R"([ + {"a": 4, "b": "eh"}, + {"a": 1, "b": ""}, + {"a": 4, "b": "eh"}, + {"a": 1, "b": ""}, + {"a": 4, "b": "eh"} + ])"); + this->AssertTake(schm, struct_json, "[3, 1, 0]", R"([ + {"a": 4, "b": "eh"}, + {"a": 1, "b": ""}, + {"a": null, "b": "yo"} + ])"); + this->AssertTake(schm, struct_json, "[0, 1, 2, 3]", struct_json); + this->AssertTake(schm, struct_json, "[0, 2, 2, 2, 2, 2, 2]", R"([ + {"a": null, "b": "yo"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"}, + {"a": 2, "b": "hello"} + ])"); +} + +class TestTakeKernelWithChunkedArray : public TestTakeKernelTyped { + public: + void AssertTake(const std::shared_ptr& type, + const std::vector& values, const std::string& indices, + const std::vector& expected) { + std::shared_ptr actual; + ASSERT_OK(this->TakeWithArray(type, values, indices, &actual)); + ValidateOutput(actual); + AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual); + } + + void AssertChunkedTake(const std::shared_ptr& type, + const std::vector& values, + const std::vector& indices, + const std::vector& expected) { + std::shared_ptr actual; + ASSERT_OK(this->TakeWithChunkedArray(type, 
values, indices, &actual)); + ValidateOutput(actual); + AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual); + } + + Status TakeWithArray(const std::shared_ptr& type, + const std::vector& values, const std::string& indices, + std::shared_ptr* out) { + ARROW_ASSIGN_OR_RAISE(Datum result, Take(ChunkedArrayFromJSON(type, values), + ArrayFromJSON(int8(), indices))); + *out = result.chunked_array(); + return Status::OK(); + } + + Status TakeWithChunkedArray(const std::shared_ptr& type, + const std::vector& values, + const std::vector& indices, + std::shared_ptr* out) { + ARROW_ASSIGN_OR_RAISE(Datum result, Take(ChunkedArrayFromJSON(type, values), + ChunkedArrayFromJSON(int8(), indices))); + *out = result.chunked_array(); + return Status::OK(); + } +}; + +TEST_F(TestTakeKernelWithChunkedArray, TakeChunkedArray) { + this->AssertTake(int8(), {"[]"}, "[]", {"[]"}); + this->AssertChunkedTake(int8(), {}, {}, {}); + this->AssertChunkedTake(int8(), {}, {"[]"}, {"[]"}); + this->AssertChunkedTake(int8(), {}, {"[null]"}, {"[null]"}); + this->AssertChunkedTake(int8(), {"[]"}, {}, {}); + this->AssertChunkedTake(int8(), {"[]"}, {"[]"}, {"[]"}); + this->AssertChunkedTake(int8(), {"[]"}, {"[null]"}, {"[null]"}); + + this->AssertTake(int8(), {"[7]", "[8, 9]"}, "[0, 1, 0, 2]", {"[7, 8, 7, 9]"}); + this->AssertChunkedTake(int8(), {"[7]", "[8, 9]"}, {"[0, 1, 0]", "[]", "[2]"}, + {"[7, 8, 7]", "[]", "[9]"}); + this->AssertTake(int8(), {"[7]", "[8, 9]"}, "[2, 1]", {"[9, 8]"}); + + std::shared_ptr arr; + ASSERT_RAISES(IndexError, + this->TakeWithArray(int8(), {"[7]", "[8, 9]"}, "[0, 5]", &arr)); + ASSERT_RAISES(IndexError, this->TakeWithChunkedArray(int8(), {"[7]", "[8, 9]"}, + {"[0, 1, 0]", "[5, 1]"}, &arr)); + ASSERT_RAISES(IndexError, this->TakeWithChunkedArray(int8(), {}, {"[0]"}, &arr)); + ASSERT_RAISES(IndexError, this->TakeWithChunkedArray(int8(), {"[]"}, {"[0]"}, &arr)); +} + +class TestTakeKernelWithTable : public TestTakeKernelTyped
{ + public: + void AssertTake(const std::shared_ptr& schm, + const std::vector& table_json, const std::string& filter, + const std::vector& expected_table) { + std::shared_ptr
actual; + + ASSERT_OK(this->TakeWithArray(schm, table_json, filter, &actual)); + ValidateOutput(actual); + ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual); + } + + void AssertChunkedTake(const std::shared_ptr& schm, + const std::vector& table_json, + const std::vector& filter, + const std::vector& expected_table) { + std::shared_ptr
actual; + + ASSERT_OK(this->TakeWithChunkedArray(schm, table_json, filter, &actual)); + ValidateOutput(actual); + ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual); + } + + Status TakeWithArray(const std::shared_ptr& schm, + const std::vector& values, const std::string& indices, + std::shared_ptr
* out) { + ARROW_ASSIGN_OR_RAISE(Datum result, Take(Datum(TableFromJSON(schm, values)), + Datum(ArrayFromJSON(int8(), indices)))); + *out = result.table(); + return Status::OK(); + } + + Status TakeWithChunkedArray(const std::shared_ptr& schm, + const std::vector& values, + const std::vector& indices, + std::shared_ptr
* out) { + ARROW_ASSIGN_OR_RAISE(Datum result, + Take(Datum(TableFromJSON(schm, values)), + Datum(ChunkedArrayFromJSON(int8(), indices)))); + *out = result.table(); + return Status::OK(); + } +}; + +TEST_F(TestTakeKernelWithTable, TakeTable) { + std::vector> fields = {field("a", int32()), field("b", utf8())}; + auto schm = schema(fields); + + std::vector table_json = { + "[{\"a\": null, \"b\": \"yo\"},{\"a\": 1, \"b\": \"\"}]", + "[{\"a\": 2, \"b\": \"hello\"},{\"a\": 4, \"b\": \"eh\"}]"}; + + this->AssertTake(schm, table_json, "[]", {"[]"}); + std::vector expected_310 = { + "[{\"a\": 4, \"b\": \"eh\"},{\"a\": 1, \"b\": \"\"},{\"a\": null, \"b\": \"yo\"}]"}; + this->AssertTake(schm, table_json, "[3, 1, 0]", expected_310); + this->AssertChunkedTake(schm, table_json, {"[0, 1]", "[2, 3]"}, table_json); +} + +TEST(TestTakeMetaFunction, ArityChecking) { + ASSERT_RAISES(Invalid, CallFunction("take", ExecBatch({}, 0))); +} + +// ---------------------------------------------------------------------- +// Random data tests + +template +struct FilterRandomTest { + static void Test(const std::shared_ptr& type) { + ARROW_SCOPED_TRACE("type = ", *type); + auto rand = random::RandomArrayGenerator(kRandomSeed); + const int64_t length = static_cast(1ULL << 10); + for (auto null_probability : {0.0, 0.01, 0.1, 0.999, 1.0}) { + for (auto true_probability : {0.0, 0.1, 0.999, 1.0}) { + auto values = rand.ArrayOf(type, length, null_probability); + auto filter = rand.Boolean(length + 1, true_probability, null_probability); + auto filter_no_nulls = rand.Boolean(length + 1, true_probability, 0.0); + ValidateFilter(values, filter->Slice(0, values->length())); + ValidateFilter(values, filter_no_nulls->Slice(0, values->length())); + // Test values and filter have different offsets + ValidateFilter(values->Slice(3), filter->Slice(4)); + ValidateFilter(values->Slice(3), filter_no_nulls->Slice(4)); + } + } + } +}; + +template +void CheckTakeRandom(const std::shared_ptr& values, int64_t 
indices_length, + double null_probability, random::RandomArrayGenerator* rand) { + using IndexCType = typename IndexType::c_type; + IndexCType max_index = GetMaxIndex(values->length()); + auto indices = rand->Numeric(indices_length, static_cast(0), + max_index, null_probability); + auto indices_no_nulls = rand->Numeric( + indices_length, static_cast(0), max_index, /*null_probability=*/0.0); + ValidateTake(values, indices); + ValidateTake(values, indices_no_nulls); + // Sliced indices array + if (indices_length >= 2) { + indices = indices->Slice(1, indices_length - 2); + indices_no_nulls = indices_no_nulls->Slice(1, indices_length - 2); + ValidateTake(values, indices); + ValidateTake(values, indices_no_nulls); + } +} + +template +struct TakeRandomTest { + static void Test(const std::shared_ptr& type) { + ARROW_SCOPED_TRACE("type = ", *type); + auto rand = random::RandomArrayGenerator(kRandomSeed); + const int64_t values_length = 64 * 16 + 1; + const int64_t indices_length = 64 * 4 + 1; + for (const auto null_probability : {0.0, 0.001, 0.05, 0.25, 0.95, 0.999, 1.0}) { + auto values = rand.ArrayOf(type, values_length, null_probability); + CheckTakeRandom(values, indices_length, null_probability, + &rand); + CheckTakeRandom(values, indices_length, null_probability, + &rand); + CheckTakeRandom(values, indices_length, null_probability, + &rand); + CheckTakeRandom(values, indices_length, null_probability, + &rand); + CheckTakeRandom(values, indices_length, null_probability, + &rand); + CheckTakeRandom(values, indices_length, null_probability, + &rand); + CheckTakeRandom(values, indices_length, null_probability, + &rand); + CheckTakeRandom(values, indices_length, null_probability, + &rand); + // Sliced values array + if (values_length > 2) { + values = values->Slice(1, values_length - 2); + CheckTakeRandom(values, indices_length, null_probability, + &rand); + } + } + } +}; + +TEST(TestFilter, PrimitiveRandom) { TestRandomPrimitiveCTypes(); } + +TEST(TestFilter, 
RandomBoolean) { FilterRandomTest<>::Test(boolean()); } + +TEST(TestFilter, RandomString) { + FilterRandomTest<>::Test(utf8()); + FilterRandomTest<>::Test(large_utf8()); +} + +TEST(TestFilter, RandomFixedSizeBinary) { + // FixedSizeBinary filter is special-cased for some widths + for (int32_t width : {0, 1, 16, 32, 35}) { + FilterRandomTest<>::Test(fixed_size_binary(width)); + } +} + +TEST(TestTake, PrimitiveRandom) { TestRandomPrimitiveCTypes(); } + +TEST(TestTake, RandomBoolean) { TakeRandomTest::Test(boolean()); } + +TEST(TestTake, RandomString) { + TakeRandomTest::Test(utf8()); + TakeRandomTest::Test(large_utf8()); +} + +TEST(TestTake, RandomFixedSizeBinary) { + // FixedSizeBinary take is special-cased for some widths + for (int32_t width : {0, 1, 16, 32, 35}) { + TakeRandomTest::Test(fixed_size_binary(width)); + } +} + +// ---------------------------------------------------------------------- +// DropNull tests + +void AssertDropNullArrays(const std::shared_ptr& values, + const std::shared_ptr& expected) { + ASSERT_OK_AND_ASSIGN(std::shared_ptr actual, DropNull(*values)); + ValidateOutput(actual); + AssertArraysEqual(*expected, *actual, /*verbose=*/true); +} + +Status DropNullJSON(const std::shared_ptr& type, const std::string& values, + std::shared_ptr* out) { + return DropNull(*ArrayFromJSON(type, values)).Value(out); +} + +void CheckDropNull(const std::shared_ptr& type, const std::string& values, + const std::string& expected) { + std::shared_ptr actual; + + ASSERT_OK(DropNullJSON(type, values, &actual)); + ValidateOutput(actual); + AssertArraysEqual(*ArrayFromJSON(type, expected), *actual, /*verbose=*/true); +} + +struct TestDropNullKernel : public ::testing::Test { + void TestNoValidityBitmapButUnknownNullCount(const std::shared_ptr& values) { + ASSERT_EQ(values->null_count(), 0); + auto expected = (*DropNull(values)).make_array(); + + auto new_values = MakeArray(values->data()->Copy()); + new_values->data()->buffers[0].reset(); + 
new_values->data()->null_count = kUnknownNullCount; + auto result = (*DropNull(new_values)).make_array(); + AssertArraysEqual(*expected, *result); + } + + void TestNoValidityBitmapButUnknownNullCount(const std::shared_ptr& type, + const std::string& values) { + TestNoValidityBitmapButUnknownNullCount(ArrayFromJSON(type, values)); + } +}; + +TEST_F(TestDropNullKernel, DropNull) { + CheckDropNull(null(), "[null, null, null]", "[]"); + CheckDropNull(null(), "[null]", "[]"); +} + +TEST_F(TestDropNullKernel, DropNullBoolean) { + CheckDropNull(boolean(), "[true, false, true]", "[true, false, true]"); + CheckDropNull(boolean(), "[null, false, true]", "[false, true]"); + CheckDropNull(boolean(), "[]", "[]"); + CheckDropNull(boolean(), "[null, null]", "[]"); + + TestNoValidityBitmapButUnknownNullCount(boolean(), "[true, false, true]"); +} + +template +struct TestDropNullKernelTyped : public TestDropNullKernel { + TestDropNullKernelTyped() : rng_(seed_) {} + + std::shared_ptr Offsets(int32_t length, int32_t slice_count) { + return checked_pointer_cast(rng_.Offsets(slice_count, 0, length)); + } + + // Slice `array` into multiple chunks along `offsets` + ArrayVector Slices(const std::shared_ptr& array, + const std::shared_ptr& offsets) { + ArrayVector slices(offsets->length() - 1); + for (int64_t i = 0; i != static_cast(slices.size()); ++i) { + slices[i] = + array->Slice(offsets->Value(i), offsets->Value(i + 1) - offsets->Value(i)); + } + return slices; + } + + random::SeedType seed_ = 0xdeadbeef; + random::RandomArrayGenerator rng_; +}; + +template +class TestDropNullKernelWithNumeric : public TestDropNullKernelTyped { + protected: + void AssertDropNull(const std::string& values, const std::string& expected) { + CheckDropNull(type_singleton(), values, expected); + } + + std::shared_ptr type_singleton() { + return TypeTraits::type_singleton(); + } +}; + +TYPED_TEST_SUITE(TestDropNullKernelWithNumeric, NumericArrowTypes); +TYPED_TEST(TestDropNullKernelWithNumeric, 
DropNullNumeric) { + this->AssertDropNull("[7, 8, 9]", "[7, 8, 9]"); + this->AssertDropNull("[null, 8, 9]", "[8, 9]"); + this->AssertDropNull("[null, null, null]", "[]"); +} + +template +class TestDropNullKernelWithString : public TestDropNullKernelTyped { + public: + std::shared_ptr value_type() { + return TypeTraits::type_singleton(); + } + + void AssertDropNull(const std::string& values, const std::string& expected) { + CheckDropNull(value_type(), values, expected); + } + + void AssertDropNullDictionary(const std::string& dictionary_values, + const std::string& dictionary_indices, + const std::string& expected_indices) { + auto dict = ArrayFromJSON(value_type(), dictionary_values); + auto type = dictionary(int8(), value_type()); + ASSERT_OK_AND_ASSIGN(auto values, + DictionaryArray::FromArrays( + type, ArrayFromJSON(int8(), dictionary_indices), dict)); + ASSERT_OK_AND_ASSIGN( + auto expected, + DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_indices), dict)); + AssertDropNullArrays(values, expected); + } +}; + +TYPED_TEST_SUITE(TestDropNullKernelWithString, BaseBinaryArrowTypes); + +TYPED_TEST(TestDropNullKernelWithString, DropNullString) { + this->AssertDropNull(R"(["a", "b", "c"])", R"(["a", "b", "c"])"); + this->AssertDropNull(R"([null, "b", "c"])", "[\"b\", \"c\"]"); + this->AssertDropNull(R"(["a", "b", null])", R"(["a", "b"])"); + + this->TestNoValidityBitmapButUnknownNullCount(this->value_type(), R"(["a", "b", "c"])"); +} + +TYPED_TEST(TestDropNullKernelWithString, DropNullDictionary) { + auto dict = R"(["a", "b", "c", "d", "e"])"; + this->AssertDropNullDictionary(dict, "[3, 4, 2]", "[3, 4, 2]"); + this->AssertDropNullDictionary(dict, "[null, 4, 2]", "[4, 2]"); +} + +class TestDropNullKernelFSB : public TestDropNullKernelTyped { + public: + std::shared_ptr value_type() { return fixed_size_binary(3); } + + void AssertDropNull(const std::string& values, const std::string& expected) { + CheckDropNull(value_type(), values, expected); + } +}; + 
+// DropNull on fixed-size binary values: a null-free input is returned unchanged and a
+// leading null is removed. Also checks an input without a validity bitmap whose null
+// count is unknown.
+TEST_F(TestDropNullKernelFSB, DropNullFixedSizeBinary) {
+  this->AssertDropNull(R"(["aaa", "bbb", "ccc"])", R"(["aaa", "bbb", "ccc"])");
+  this->AssertDropNull(R"([null, "bbb", "ccc"])", "[\"bbb\", \"ccc\"]");
+
+  this->TestNoValidityBitmapButUnknownNullCount(this->value_type(),
+                                                R"(["aaa", "bbb", "ccc"])");
+}
+
+// NOTE(review): the template argument of TestDropNullKernelTyped is missing on the
+// fixture declarations below (presumably stripped when this patch was extracted) --
+// confirm against the original file.
+class TestDropNullKernelWithList : public TestDropNullKernelTyped {};
+
+// DropNull on list<int32>: only the top-level null entry is removed; the empty list
+// is kept.
+TEST_F(TestDropNullKernelWithList, DropNullListInt32) {
+  std::string list_json = "[[], [1,2], null, [3]]";
+  CheckDropNull(list(int32()), list_json, "[[], [1,2], [3]]");
+  this->TestNoValidityBitmapButUnknownNullCount(list(int32()), "[[], [1,2], [3]]");
+}
+
+// DropNull on list<list<int32>>: only the top-level null is removed; nulls nested
+// inside the child lists are preserved.
+TEST_F(TestDropNullKernelWithList, DropNullListListInt32) {
+  std::string list_json = R"([
+    [],
+    [[1], [2, null, 2], []],
+    null,
+    [[3, null], null]
+  ])";
+  auto type = list(list(int32()));
+  CheckDropNull(type, list_json, R"([
+    [],
+    [[1], [2, null, 2], []],
+    [[3, null], null]
+  ])");
+
+  this->TestNoValidityBitmapButUnknownNullCount(type,
+                                                "[[[1], [2, null, 2], []], [[3, null]]]");
+}
+
+class TestDropNullKernelWithLargeList : public TestDropNullKernelTyped {};
+
+// DropNull on large_list<int32>, mirroring the list<int32> test above.
+TEST_F(TestDropNullKernelWithLargeList, DropNullLargeListInt32) {
+  std::string list_json = "[[], [1,2], null, [3]]";
+  CheckDropNull(large_list(int32()), list_json, "[[], [1,2], [3]]");
+
+  // Exercise the large_list type here. This check previously passed
+  // fixed_size_list(int32(), 3), duplicating the FixedSizeList test below and leaving
+  // the large_list no-validity-bitmap path untested.
+  this->TestNoValidityBitmapButUnknownNullCount(large_list(int32()), "[[], [1,2], [3]]");
+}
+
+class TestDropNullKernelWithFixedSizeList
+    : public TestDropNullKernelTyped {};
+
+// DropNull on fixed_size_list<int32>[3]: the top-level null is removed while nulls
+// inside the remaining lists are preserved.
+TEST_F(TestDropNullKernelWithFixedSizeList, DropNullFixedSizeListInt32) {
+  std::string list_json = "[null, [1, null, 3], [4, 5, 6], [7, 8, null]]";
+  CheckDropNull(fixed_size_list(int32(), 3), list_json,
+                "[[1, null, 3], [4, 5, 6], [7, 8, null]]");
+
+  this->TestNoValidityBitmapButUnknownNullCount(
+      fixed_size_list(int32(), 3), "[[1, null, 3], [4, 5, 6], [7, 8, null]]");
+}
+
+class TestDropNullKernelWithMap : public TestDropNullKernelTyped {};
+
+TEST_F(TestDropNullKernelWithMap, DropNullMapStringToInt32) { + std::string map_json = R"([ + [["joe", 0], ["mark", null]], + null, + [["cap", 8]], + [] + ])"; + std::string expected_json = R"([ + [["joe", 0], ["mark", null]], + [["cap", 8]], + [] + ])"; + CheckDropNull(map(utf8(), int32()), map_json, expected_json); +} + +class TestDropNullKernelWithStruct : public TestDropNullKernelTyped {}; + +TEST_F(TestDropNullKernelWithStruct, DropNullStruct) { + auto struct_type = struct_({field("a", int32()), field("b", utf8())}); + auto struct_json = R"([ + null, + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ])"; + auto expected_struct_json = R"([ + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ])"; + CheckDropNull(struct_type, struct_json, expected_struct_json); + this->TestNoValidityBitmapButUnknownNullCount(struct_type, expected_struct_json); +} + +class TestDropNullKernelWithUnion : public TestDropNullKernelTyped {}; + +TEST_F(TestDropNullKernelWithUnion, DropNullUnion) { + for (const auto& union_type : + {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}), + sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) { + auto union_json = R"([ + [2, null], + [2, 222], + [5, "hello"], + [5, "eh"], + [2, null], + [2, 111], + [5, null] + ])"; + CheckDropNull(union_type, union_json, union_json); + } +} + +class TestDropNullKernelWithRecordBatch : public TestDropNullKernelTyped { + public: + void AssertDropNull(const std::shared_ptr& schm, const std::string& batch_json, + const std::string& expected_batch) { + std::shared_ptr actual; + + ASSERT_OK(this->DoDropNull(schm, batch_json, &actual)); + ValidateOutput(actual); + ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *actual); + } + + Status DoDropNull(const std::shared_ptr& schm, const std::string& batch_json, + std::shared_ptr* out) { + auto batch = RecordBatchFromJSON(schm, batch_json); + ARROW_ASSIGN_OR_RAISE(Datum out_datum, 
DropNull(batch)); + *out = out_datum.record_batch(); + return Status::OK(); + } +}; + +TEST_F(TestDropNullKernelWithRecordBatch, DropNullRecordBatch) { + std::vector> fields = {field("a", int32()), field("b", utf8())}; + auto schm = schema(fields); + + auto batch_json = R"([ + {"a": null, "b": "yo"}, + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ])"; + this->AssertDropNull(schm, batch_json, R"([ + {"a": 1, "b": ""}, + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ])"); + + batch_json = R"([ + {"a": null, "b": "yo"}, + {"a": 1, "b": null}, + {"a": null, "b": "hello"}, + {"a": 4, "b": null} + ])"; + this->AssertDropNull(schm, batch_json, R"([])"); + this->AssertDropNull(schm, R"([])", R"([])"); +} + +class TestDropNullKernelWithChunkedArray : public TestDropNullKernelTyped { + public: + TestDropNullKernelWithChunkedArray() + : sizes_({0, 1, 2, 4, 16, 31, 1234}), + null_probabilities_({0.0, 0.1, 0.5, 0.9, 1.0}) {} + + void AssertDropNull(const std::shared_ptr& type, + const std::vector& values, + const std::vector& expected) { + std::shared_ptr actual; + ASSERT_OK(this->DoDropNull(type, values, &actual)); + ValidateOutput(actual); + + AssertChunkedEqual(*ChunkedArrayFromJSON(type, expected), *actual); + } + + Status DoDropNull(const std::shared_ptr& type, + const std::vector& values, + std::shared_ptr* out) { + ARROW_ASSIGN_OR_RAISE(Datum out_datum, DropNull(ChunkedArrayFromJSON(type, values))); + *out = out_datum.chunked_array(); + return Status::OK(); + } + + template + void CheckDropNullWithSlices(ArrayFactory&& factory) { + for (auto size : this->sizes_) { + for (auto null_probability : this->null_probabilities_) { + std::shared_ptr concatenated_array; + std::shared_ptr chunked_array; + factory(size, null_probability, &chunked_array, &concatenated_array); + + ASSERT_OK_AND_ASSIGN(auto out_datum, DropNull(chunked_array)); + auto actual_chunked_array = out_datum.chunked_array(); + ASSERT_OK_AND_ASSIGN(auto actual, 
Concatenate(actual_chunked_array->chunks())); + + ASSERT_OK_AND_ASSIGN(out_datum, DropNull(*concatenated_array)); + auto expected = out_datum.make_array(); + + AssertArraysEqual(*expected, *actual); + } + } + } + + std::vector sizes_; + std::vector null_probabilities_; +}; + +TEST_F(TestDropNullKernelWithChunkedArray, DropNullChunkedArray) { + this->AssertDropNull(int8(), {"[]"}, {"[]"}); + this->AssertDropNull(int8(), {"[null]", "[8, null]"}, {"[8]"}); + + this->AssertDropNull(int8(), {"[null]", "[null, null]"}, {"[]"}); + this->AssertDropNull(int8(), {"[7]", "[8, 9]"}, {"[7]", "[8, 9]"}); + this->AssertDropNull(int8(), {"[]", "[]"}, {"[]", "[]"}); +} + +TEST_F(TestDropNullKernelWithChunkedArray, DropNullChunkedArrayWithSlices) { + // With Null Arrays + this->CheckDropNullWithSlices([this](int32_t size, double null_probability, + std::shared_ptr* out_chunked_array, + std::shared_ptr* out_concatenated_array) { + auto array = std::make_shared(size); + auto offsets = this->Offsets(size, 3); + auto slices = this->Slices(array, offsets); + *out_chunked_array = std::make_shared(std::move(slices)); + + ASSERT_OK_AND_ASSIGN(*out_concatenated_array, + Concatenate((*out_chunked_array)->chunks())); + }); + // Without Null Arrays + this->CheckDropNullWithSlices([this](int32_t size, double null_probability, + std::shared_ptr* out_chunked_array, + std::shared_ptr* out_concatenated_array) { + auto array = this->rng_.ArrayOf(int16(), size, null_probability); + auto offsets = this->Offsets(size, 3); + auto slices = this->Slices(array, offsets); + *out_chunked_array = std::make_shared(std::move(slices)); + + ASSERT_OK_AND_ASSIGN(*out_concatenated_array, + Concatenate((*out_chunked_array)->chunks())); + }); +} + +class TestDropNullKernelWithTable : public TestDropNullKernelTyped
{ + public: + TestDropNullKernelWithTable() + : sizes_({0, 1, 4, 31, 1234}), null_probabilities_({0.0, 0.1, 0.5, 0.9, 1.0}) {} + + void AssertDropNull(const std::shared_ptr& schm, + const std::vector& table_json, + const std::vector& expected_table) { + std::shared_ptr
actual; + ASSERT_OK(this->DoDropNull(schm, table_json, &actual)); + ValidateOutput(actual); + ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected_table), *actual); + } + + Status DoDropNull(const std::shared_ptr& schm, + const std::vector& values, std::shared_ptr
* out) { + ARROW_ASSIGN_OR_RAISE(Datum out_datum, DropNull(TableFromJSON(schm, values))); + *out = out_datum.table(); + return Status::OK(); + } + + template + void CheckDropNullWithSlices(ArrayFactory&& factory) { + for (auto size : this->sizes_) { + for (auto null_probability : this->null_probabilities_) { + std::shared_ptr
table_w_slices; + std::shared_ptr
table_wo_slices; + + factory(size, null_probability, &table_w_slices, &table_wo_slices); + + ASSERT_OK_AND_ASSIGN(auto out_datum, DropNull(table_w_slices)); + ValidateOutput(out_datum); + auto actual = out_datum.table(); + + ASSERT_OK_AND_ASSIGN(out_datum, DropNull(table_wo_slices)); + ValidateOutput(out_datum); + auto expected = out_datum.table(); + if (actual->num_rows() > 0) { + ASSERT_TRUE(actual->num_rows() == expected->num_rows()); + for (int index = 0; index < actual->num_columns(); index++) { + ASSERT_OK_AND_ASSIGN(auto actual_col, + Concatenate(actual->column(index)->chunks())); + ASSERT_OK_AND_ASSIGN(auto expected_col, + Concatenate(expected->column(index)->chunks())); + AssertArraysEqual(*actual_col, *expected_col); + } + } + } + } + } + + std::vector sizes_; + std::vector null_probabilities_; +}; + +TEST_F(TestDropNullKernelWithTable, DropNullTable) { /* DropNull on a two-chunk table with columns a and b. First case: the single row holding a null is removed and the other three rows are kept. Second case: every row has a null in a or b, so the result keeps the input schema and has zero rows. */ + std::vector> fields = {field("a", int32()), field("b", utf8())}; + auto schm = schema(fields); + + { + std::vector table_json = {R"([ + {"a": null, "b": "yo"}, + {"a": 1, "b": ""} + ])", + R"([ + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ])"}; + std::vector expected_table_json = {R"([ + {"a": 1, "b": ""} + ])", + R"([ + {"a": 2, "b": "hello"}, + {"a": 4, "b": "eh"} + ])"}; + this->AssertDropNull(schm, table_json, expected_table_json); + } + { + std::vector table_json = {R"([ + {"a": null, "b": "yo"}, + {"a": 1, "b": null} + ])", + R"([ + {"a": 2, "b": null}, + {"a": null, "b": "eh"} + ])"}; + std::shared_ptr
actual; + ASSERT_OK(this->DoDropNull(schm, table_json, &actual)); + AssertSchemaEqual(schm, actual->schema()); + ASSERT_EQ(actual->num_rows(), 0); + } +} + +TEST_F(TestDropNullKernelWithTable, DropNullTableWithSlices) { /* Calls CheckDropNullWithSlices with two factories. Each factory builds columns a and b once and emits them as two tables: one whose columns are re-chunked via Slices and Offsets(size, 3), and one built from the un-sliced columns. The first factory fills the columns with MakeArrayOfNull, the second with rng_.ArrayOf using the given null probability. */ + // With Null Arrays + this->CheckDropNullWithSlices([this](int32_t size, double null_probability, + std::shared_ptr
* out_table_w_slices, + std::shared_ptr
* out_table_wo_slices) { + FieldVector fields = {field("a", int32()), field("b", utf8())}; + auto schm = schema(fields); + ASSERT_OK_AND_ASSIGN(auto col_a, MakeArrayOfNull(int32(), size)); + ASSERT_OK_AND_ASSIGN(auto col_b, MakeArrayOfNull(utf8(), size)); + + // Compute random chunkings of columns `a` and `b` + auto slices_a = this->Slices(col_a, this->Offsets(size, 3)); + auto slices_b = this->Slices(col_b, this->Offsets(size, 3)); + + ChunkedArrayVector table_content_w_slices{ + std::make_shared(std::move(slices_a)), + std::make_shared(std::move(slices_b))}; + *out_table_w_slices = Table::Make(schm, std::move(table_content_w_slices), size); + + ChunkedArrayVector table_content_wo_slices{std::make_shared(col_a), + std::make_shared(col_b)}; + *out_table_wo_slices = Table::Make(schm, std::move(table_content_wo_slices), size); + }); + + // Without Null Arrays + this->CheckDropNullWithSlices([this](int32_t size, double null_probability, + std::shared_ptr
* out_table_w_slices, + std::shared_ptr
* out_table_wo_slices) { + FieldVector fields = {field("a", int32()), field("b", utf8())}; + auto schm = schema(fields); + auto col_a = this->rng_.ArrayOf(int32(), size, null_probability); + auto col_b = this->rng_.ArrayOf(utf8(), size, null_probability); + + // Compute random chunkings of columns `a` and `b` + auto slices_a = this->Slices(col_a, this->Offsets(size, 3)); + auto slices_b = this->Slices(col_b, this->Offsets(size, 3)); + + ChunkedArrayVector table_content_w_slices{ + std::make_shared(std::move(slices_a)), + std::make_shared(std::move(slices_b))}; + *out_table_w_slices = Table::Make(schm, std::move(table_content_w_slices), size); + + ChunkedArrayVector table_content_wo_slices{std::make_shared(col_a), + std::make_shared(col_b)}; + *out_table_wo_slices = Table::Make(schm, std::move(table_content_wo_slices), size); + }); +} + +TEST(TestIndicesNonZero, IndicesNonZero) { /* For every type in NumericTypes(): indices_nonzero returns the uint64 positions of entries that are neither null nor zero ([null, 50, 0, 10] gives [1, 3]), returns an empty array for empty input, and for chunked input reports positions counted across all chunks, also when one chunk is empty. */ + Datum actual; + std::shared_ptr result; + + for (const auto& type : NumericTypes()) { + ARROW_SCOPED_TRACE("Input type = ", type->ToString()); + + ASSERT_OK_AND_ASSIGN( + actual, + CallFunction("indices_nonzero", {ArrayFromJSON(type, "[null, 50, 0, 10]")})); + result = actual.make_array(); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[1, 3]"), *result, /*verbose*/ true); + + // empty + ASSERT_OK_AND_ASSIGN(actual, + CallFunction("indices_nonzero", {ArrayFromJSON(type, "[]")})); + result = actual.make_array(); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[]"), *result, /*verbose*/ true); + + // chunked + ChunkedArray chunked_arr( + {ArrayFromJSON(type, "[1, 0, 3]"), ArrayFromJSON(type, "[4, 0, 6]")}); + ASSERT_OK_AND_ASSIGN( + actual, CallFunction("indices_nonzero", {static_cast(chunked_arr)})); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 2, 3, 5]"), *actual.make_array(), + /*verbose*/ true); + + // empty chunked + ChunkedArray chunked_arr_empty({ArrayFromJSON(type, "[1, 0, 3]"), + ArrayFromJSON(type, "[]"), + ArrayFromJSON(type, "[4, 0, 6]")}); + ASSERT_OK_AND_ASSIGN( + actual,
CallFunction("indices_nonzero", {static_cast(chunked_arr_empty)})); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 2, 3, 5]"), *actual.make_array(), + /*verbose*/ true); + } +} + +TEST(TestIndicesNonZero, IndicesNonZeroBoolean) { /* Boolean input: only the true entries are reported; null and false are skipped, so [null, true, false, true] gives [1, 3]. */ + Datum actual; + std::shared_ptr result; + + // bool + ASSERT_OK_AND_ASSIGN( + actual, CallFunction("indices_nonzero", + {ArrayFromJSON(boolean(), "[null, true, false, true]")})); + result = actual.make_array(); + AssertArraysEqual(*result, *ArrayFromJSON(uint64(), "[1, 3]"), /*verbose*/ true); +} + +TEST(TestIndicesNonZero, IndicesNonZeroDecimal) { /* Same check for arrays built with the decimal128 and decimal256 factories: zero and null entries are left out of the result, non-zero entries (including a negative one) are reported. */ + Datum actual; + std::shared_ptr result; + + for (const auto& decimal_factory : {decimal128, decimal256}) { + ASSERT_OK_AND_ASSIGN( + actual, CallFunction("indices_nonzero", + {DecimalArrayFromJSON(decimal_factory(2, -2), + R"(["12E2",null,"0","0"])")})); + result = actual.make_array(); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0]"), *result, /*verbose*/ true); + + ASSERT_OK_AND_ASSIGN( + actual, + CallFunction( + "indices_nonzero", + {DecimalArrayFromJSON( + decimal_factory(6, 9), + R"(["765483.999999999","0.000000000",null,"-987645.000000001"])")})); + result = actual.make_array(); + AssertArraysEqual(*ArrayFromJSON(uint64(), "[0, 3]"), *result, /*verbose*/ true); + } +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h index cdc9f804e72f1..76e5436366784 100644 --- a/cpp/src/arrow/compute/registry_internal.h +++ b/cpp/src/arrow/compute/registry_internal.h @@ -50,6 +50,7 @@ void RegisterVectorHash(FunctionRegistry* registry); void RegisterVectorNested(FunctionRegistry* registry); void RegisterVectorRank(FunctionRegistry* registry); void RegisterVectorReplace(FunctionRegistry* registry); +void RegisterVectorScatter(FunctionRegistry* registry); void RegisterVectorSelectK(FunctionRegistry* registry); void RegisterVectorSelection(FunctionRegistry* registry); void
RegisterVectorSort(FunctionRegistry* registry);