Skip to content

Commit

Permalink
Implementation done and basic tests
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Oct 6, 2024
1 parent be88f0c commit b445c36
Show file tree
Hide file tree
Showing 10 changed files with 526 additions and 102 deletions.
2 changes: 1 addition & 1 deletion cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -787,7 +787,7 @@ if(ARROW_COMPUTE)
compute/kernels/vector_array_sort.cc
compute/kernels/vector_cumulative_ops.cc
compute/kernels/vector_pairwise.cc
compute/kernels/vector_permute.cc
compute/kernels/vector_placement.cc
compute/kernels/vector_nested.cc
compute/kernels/vector_rank.cc
compute/kernels/vector_replace.cc
Expand Down
20 changes: 20 additions & 0 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ static auto kPairwiseOptionsType = GetFunctionOptionsType<PairwiseOptions>(
DataMember("periods", &PairwiseOptions::periods));
static auto kListFlattenOptionsType = GetFunctionOptionsType<ListFlattenOptions>(
DataMember("recursive", &ListFlattenOptions::recursive));
// FunctionOptionsType for ReverseIndexOptions, built from its three data members
// (output_length, output_type, output_non_taken), each keyed by its field name.
static auto kReverseIndexOptionsType = GetFunctionOptionsType<ReverseIndexOptions>(
DataMember("output_length", &ReverseIndexOptions::output_length),
DataMember("output_type", &ReverseIndexOptions::output_type),
DataMember("output_non_taken", &ReverseIndexOptions::output_non_taken));
// FunctionOptionsType for PermuteOptions, built from its single data member `bound`.
static auto kPermuteOptionsType =
GetFunctionOptionsType<PermuteOptions>(DataMember("bound", &PermuteOptions::bound));
} // namespace
Expand Down Expand Up @@ -232,6 +236,15 @@ ListFlattenOptions::ListFlattenOptions(bool recursive)
: FunctionOptions(internal::kListFlattenOptionsType), recursive(recursive) {}
constexpr char ListFlattenOptions::kTypeName[];

// Construct the options for the "reverse_index" function.
//
// The base class must be given kReverseIndexOptionsType (the options type that
// lists output_length / output_type / output_non_taken). Passing the
// PermuteOptions type here would make these options identify themselves as
// PermuteOptions.
//
// output_type and output_non_taken are taken by value and moved into place.
ReverseIndexOptions::ReverseIndexOptions(int64_t output_length,
                                         std::shared_ptr<DataType> output_type,
                                         std::shared_ptr<Scalar> output_non_taken)
    : FunctionOptions(internal::kReverseIndexOptionsType),
      output_length(output_length),
      output_type(std::move(output_type)),
      output_non_taken(std::move(output_non_taken)) {}
// Out-of-class definition of the static constexpr member declared in the header.
constexpr char ReverseIndexOptions::kTypeName[];

PermuteOptions::PermuteOptions(int64_t bound)
: FunctionOptions(internal::kPermuteOptionsType), bound(bound) {}
constexpr char PermuteOptions::kTypeName[];
Expand Down Expand Up @@ -439,6 +452,13 @@ Result<Datum> CumulativeMean(const Datum& values, const CumulativeOptions& optio
// ----------------------------------------------------------------------
// Permute functions

// Convenience wrapper that invokes the "reverse_index" compute function.
//
// `indices` is passed as the function's only argument, `options` is forwarded
// to the function, and `ctx` is handed to CallFunction unchanged. On success
// the resulting Datum is converted to an Array; any error from CallFunction is
// propagated through the returned Result.
Result<std::shared_ptr<Array>> ReverseIndex(const Datum& indices,
                                            const ReverseIndexOptions& options,
                                            ExecContext* ctx) {
  // The options must be passed explicitly: the CallFunction overload without
  // an options argument would silently discard the caller's settings.
  ARROW_ASSIGN_OR_RAISE(Datum result,
                        CallFunction("reverse_index", {indices}, &options, ctx));
  return result.make_array();
}

Result<std::shared_ptr<Array>> Permute(const Datum& values, const Datum& indices,
const PermuteOptions& options, ExecContext* ctx) {
ARROW_ASSIGN_OR_RAISE(Datum result, CallFunction("permute", {values, indices}, ctx));
Expand Down
19 changes: 19 additions & 0 deletions cpp/src/arrow/compute/api_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,25 @@ class ARROW_EXPORT ListFlattenOptions : public FunctionOptions {
bool recursive = false;
};

/// \brief Options for reverse_index function
class ARROW_EXPORT ReverseIndexOptions : public FunctionOptions {
public:
/// \brief Construct the options; every argument is stored in the data member of
/// the same name.
explicit ReverseIndexOptions(int64_t output_length = 0,
std::shared_ptr<DataType> output_type = int32(),
std::shared_ptr<Scalar> output_non_taken = NULLPTR);
static constexpr char const kTypeName[] = "ReverseIndexOptions";
/// \brief Return default-constructed options (output_length 0, output_type int32(),
/// output_non_taken NULLPTR).
static ReverseIndexOptions Defaults() { return ReverseIndexOptions(); }

/// \brief The length of the output reverse index. If -1, the output will be sized as
/// the maximum value in the indices array + 1. Otherwise, the output will be of size
/// output_length, and any indices that are greater than or equal to output_length
/// will be ignored.
///
/// NOTE(review): this description was originally phrased in terms of a "bound" and
/// relies on a -1 sentinel, yet the default value here is 0 -- confirm the intended
/// semantics and default against the kernel implementation.
int64_t output_length = 0;
/// \brief The type of the output reverse index. If null, the output type will be the
/// smallest possible integer type that can hold the maximum value in the indices array.
///
/// NOTE(review): the default is int32(), not null -- confirm that the null behavior
/// described above is actually implemented.
std::shared_ptr<DataType> output_type = int32();
/// \brief Presumably the value emitted at output positions that no index refers to;
/// defaults to NULLPTR. TODO confirm the exact semantics (including what NULLPTR
/// means) against the kernel implementation.
std::shared_ptr<Scalar> output_non_taken = NULLPTR;
};

/// \brief Options for permute function
class ARROW_EXPORT PermuteOptions : public FunctionOptions {
public:
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/compute/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ add_arrow_compute_test(vector_selection_test
EXTRA_LINK_LIBS
arrow_compute_kernels_testing)

# Test target for the vector placement kernels, built from vector_placement_test.cc
# and linked against arrow_compute_kernels_testing.
add_arrow_compute_test(vector_placement_test
SOURCES
vector_placement_test.cc
EXTRA_LINK_LIBS
arrow_compute_kernels_testing)

add_arrow_benchmark(vector_hash_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_sort_benchmark PREFIX "arrow-compute")
add_arrow_benchmark(vector_partition_benchmark PREFIX "arrow-compute")
Expand Down
5 changes: 3 additions & 2 deletions cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -985,8 +985,9 @@ ArrayKernelExec GenerateFloatingPoint(detail::GetTypeId get_id) {
// Generate a kernel given a templated functor for integer types
//
// See "Numeric" above for description of the generator functor
template <template <typename...> class Generator, typename Type0, typename... Args>
ArrayKernelExec GenerateInteger(detail::GetTypeId get_id) {
template <template <typename...> class Generator, typename Type0,
typename KernelType = ArrayKernelExec, typename... Args>
KernelType GenerateInteger(detail::GetTypeId get_id) {
switch (get_id.id) {
case Type::INT8:
return Generator<Type0, Int8Type, Args...>::Exec;
Expand Down
99 changes: 0 additions & 99 deletions cpp/src/arrow/compute/kernels/vector_permute.cc

This file was deleted.

Loading

0 comments on commit b445c36

Please sign in to comment.