Skip to content

Commit

Permalink
Add permute function options
Browse files Browse the repository at this point in the history
  • Loading branch information
zanmato1984 committed Oct 2, 2024
1 parent 216e217 commit f3c73ea
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 9 deletions.
16 changes: 16 additions & 0 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,8 @@ static auto kPairwiseOptionsType = GetFunctionOptionsType<PairwiseOptions>(
DataMember("periods", &PairwiseOptions::periods));
// Function-options type object for ListFlattenOptions, built from its single
// `recursive` data member.
static auto kListFlattenOptionsType = GetFunctionOptionsType<ListFlattenOptions>(
DataMember("recursive", &ListFlattenOptions::recursive));
// Function-options type object for PermuteOptions, built from its single
// `bound` data member.
static auto kPermuteOptionsType =
GetFunctionOptionsType<PermuteOptions>(DataMember("bound", &PermuteOptions::bound));
} // namespace
} // namespace internal

Expand Down Expand Up @@ -230,6 +232,10 @@ ListFlattenOptions::ListFlattenOptions(bool recursive)
: FunctionOptions(internal::kListFlattenOptionsType), recursive(recursive) {}
constexpr char ListFlattenOptions::kTypeName[];

// Stores `bound` and hands internal::kPermuteOptionsType to the FunctionOptions base.
PermuteOptions::PermuteOptions(int64_t bound)
: FunctionOptions(internal::kPermuteOptionsType), bound(bound) {}
// Out-of-class definition of the static constexpr member declared in the class.
constexpr char PermuteOptions::kTypeName[];

namespace internal {
void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kFilterOptionsType));
Expand All @@ -244,6 +250,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPairwiseOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kListFlattenOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPermuteOptionsType));
}
} // namespace internal

Expand Down Expand Up @@ -429,5 +436,14 @@ Result<Datum> CumulativeMean(const Datum& values, const CumulativeOptions& optio
return CallFunction("cumulative_mean", {Datum(values)}, &options, ctx);
}

// ----------------------------------------------------------------------
// Permute functions

// Calls the registered "permute" function on `values` and `indices` and returns
// the result as an Array.
//
// values:  the input values
// indices: the output position for each input value
// options: permute options (see PermuteOptions::bound)
// ctx:     execution context forwarded to CallFunction
//
// Any error from CallFunction is propagated to the caller.
Result<std::shared_ptr<Array>> Permute(const Datum& values, const Datum& indices,
                                       const PermuteOptions& options, ExecContext* ctx) {
  // Forward `options` explicitly so the caller's PermuteOptions is not dropped.
  ARROW_ASSIGN_OR_RAISE(Datum result,
                        CallFunction("permute", {values, indices}, &options, ctx));
  return result.make_array();
}

} // namespace compute
} // namespace arrow
18 changes: 18 additions & 0 deletions cpp/src/arrow/compute/api_vector.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,19 @@ class ARROW_EXPORT ListFlattenOptions : public FunctionOptions {
bool recursive = false;
};

/// \brief Options for permute function
class ARROW_EXPORT PermuteOptions : public FunctionOptions {
public:
/// \brief Construct options with the given output bound (defaults to -1).
explicit PermuteOptions(int64_t bound = -1);
/// \brief Type name string for this options class.
static constexpr char const kTypeName[] = "PermuteOptions";
/// \brief Return default-constructed options (bound == -1).
static PermuteOptions Defaults() { return PermuteOptions(); }

/// \brief The upper bound of the permutation. If -1, the output will be sized as the
/// maximum value in the indices array + 1. Otherwise, the output will be of size bound,
/// and any indices that are greater than or equal to bound will be ignored.
int64_t bound = -1;
};

/// @}

/// \brief Filter with a boolean selection filter
Expand Down Expand Up @@ -705,5 +718,10 @@ Result<std::shared_ptr<Array>> PairwiseDiff(const Array& array,
bool check_overflow = false,
ExecContext* ctx = NULLPTR);

/// \brief Place each input value into the output array at the position given by
/// `indices`
///
/// \param[in] values the input values
/// \param[in] indices the output position of each input value
/// \param[in] options see PermuteOptions; defaults to PermuteOptions::Defaults()
/// \param[in] ctx the function execution context, optional
/// \return the resulting array
///
/// NOTE(review): this declaration carries no ARROW_EXPORT, unlike the
/// PermuteOptions class in this header -- confirm whether the symbol must be
/// exported from the shared library.
Result<std::shared_ptr<Array>> Permute(
const Datum& values, const Datum& indices,
const PermuteOptions& options = PermuteOptions::Defaults(),
ExecContext* ctx = NULLPTR);

} // namespace compute
} // namespace arrow
26 changes: 17 additions & 9 deletions cpp/src/arrow/compute/kernels/vector_permute.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// #include "arrow/compute/kernels/vector_gather_internal.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/function.h"
#include "arrow/compute/kernels/codegen_internal.h"
#include "arrow/compute/registry.h"
Expand Down Expand Up @@ -125,10 +126,11 @@ Status FixedWidthPermuteExec(KernelContext* ctx, const ExecSpan& batch, ExecResu
// 0-length fixed-size binary or lists were handled above on `case 0`
DCHECK_GT(byte_width, 0);
return PermuteIndexDispatch<FixedWidthTakeImpl,
/*ValueBitWidth=*/std::integral_constant<int, 8>,
/*OutputIsZeroInitialized=*/std::false_type,
/*WithFactor=*/std::true_type>(ctx, values, indices, out_arr,
/*factor=*/byte_width);
/*ValueBitWidth=*/std::integral_constant<int, 8>,
/*OutputIsZeroInitialized=*/std::false_type,
/*WithFactor=*/std::true_type>(ctx, values, indices,
out_arr,
/*factor=*/byte_width);
}
return Status::NotImplemented("Unsupported primitive type for permute: ", *values.type);
return Status::OK();
Expand All @@ -141,9 +143,10 @@ struct PermuteKernelSignature {
};

std::unique_ptr<Function> MakePermuteFunction(
std::string name, std::vector<PermuteKernelSignature>&& signatures, FunctionDoc doc) {
auto func =
std::make_unique<VectorFunction>(std::move(name), Arity::Binary(), std::move(doc));
std::string name, std::vector<PermuteKernelSignature>&& signatures, FunctionDoc doc,
const FunctionOptions* default_options) {
auto func = std::make_unique<VectorFunction>(std::move(name), Arity::Binary(),
std::move(doc), default_options);
for (auto& signature : signatures) {
auto kernel = VectorKernel{};
kernel.signature = KernelSignature::Make(
Expand All @@ -161,15 +164,20 @@ const FunctionDoc permute_doc(
"Place each input value to the output array at position specified by `indices`",
{"input", "indices"});

// Returns the address of a function-local static PermuteOptions initialized
// from PermuteOptions::Defaults(); every call yields the same pointer.
const PermuteOptions* GetDefaultPermuteOptions() {
  static const PermuteOptions default_options = PermuteOptions::Defaults();
  return &default_options;
}

} // namespace

// Registers the "permute" vector function with `registry`.
//
// A single kernel signature is declared: primitive-typed values with integer
// indices, executed by FixedWidthPermuteExec. The function is created with the
// default PermuteOptions returned by GetDefaultPermuteOptions().
void RegisterVectorPermute(FunctionRegistry* registry) {
  auto permute_indices = match::Integer();
  std::vector<PermuteKernelSignature> signatures = {
      {InputType(match::Primitive()), permute_indices, FixedWidthPermuteExec},
  };
  // Register exactly once, using the four-argument MakePermuteFunction that
  // takes the default options. `signatures` is moved from here, so it must not
  // be reused by a second registration.
  DCHECK_OK(registry->AddFunction(MakePermuteFunction(
      "permute", std::move(signatures), permute_doc, GetDefaultPermuteOptions())));
}

} // namespace arrow::compute::internal

0 comments on commit f3c73ea

Please sign in to comment.