diff --git a/cpp/src/arrow/array/scatter.cc b/cpp/src/arrow/array/scatter.cc index 7402ac7260128..52caf6f06645f 100644 --- a/cpp/src/arrow/array/scatter.cc +++ b/cpp/src/arrow/array/scatter.cc @@ -42,6 +42,24 @@ Status ScatterBitmap(const uint8_t* in_bitmap, const uint8_t* mask_bitmap, return Status::OK(); } +Result> ScatterBuffer(const Buffer& src, const uint8_t* mask, + int64_t byte_width, int64_t length, + MemoryPool* pool) { + ARROW_ASSIGN_OR_RAISE(auto out, AllocateBuffer(length * byte_width, pool)); + auto* in_data = src.data(); + auto* out_data = out->mutable_data(); + int64_t i_in = 0, i_out = 0; + VisitNullBitmapInline( + mask, /*valid_bits_offset=*/0, length, kUnknownNullCount, + [&] { + std::memcpy(out_data + i_out++ * byte_width, in_data + i_in++ * byte_width, + byte_width); + }, + [&] { ++i_out; }); + + return Status::OK(); +} + struct ScatterImpl { explicit ScatterImpl(const ArrayData& in, const BooleanArray& mask, MemoryPool* pool) : in_(in), mask_(mask), pool_(pool), out_(std::shared_ptr()) { @@ -98,6 +116,13 @@ struct ScatterImpl { return Status::OK(); } + Status Visit(const FixedWidthType& type) { + DCHECK_EQ(type.bit_width() % 8, 0); + return ScatterBuffer(*in_.buffers[1], out_->buffers[0]->data(), type.bit_width() / 8, + out_->length, pool_) + .Value(&out_->buffers[1]); + } + Status Visit(const DataType&) { return Status::NotImplemented("Scatter not implemented for type ", *out_->type); }