Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex-PLACET committed Nov 19, 2024
1 parent 9100ca0 commit c6174cf
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 35 deletions.
24 changes: 18 additions & 6 deletions include/sparrow/array_api.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,14 +168,26 @@ namespace sparrow

/**
* Slices the array to keep only the elements between the given \p start and \p end.
* The \ref array is modified in place. The data is not modified, only the ArrowArray.offset and
* ArrowArray.length are updated. If \p end is greater than the size of the buffers, the following elements will be invalid.
* A copy of the \ref array is modified. The data is not modified, only the ArrowArray.offset and
* ArrowArray.length are updated. If \p end is greater than the size of the buffers, the following
* elements will be invalid.
*
* @param start The index of the first element to keep.
* @param end The index of the first element to discard.
* @param start The index of the first element to keep. Must be less than \p end.
* @param end The index of the first element to discard. Must be less than the size of the buffers.
*/
SPARROW_API void slice(size_type start, size_type end);

SPARROW_API array slice(size_type start, size_type end) const;

/**
* Slices the array to keep only the elements between the given \p start and \p end.
* A view of the \ref array is returned. The data is not modified, only the ArrowArray.offset and
* ArrowArray.length are updated. If \p end is greater than the size of the buffers, the following
* elements will be invalid.
*
* @param start The index of the first element to keep. Must be less than \p end.
* @param end The index of the first element to discard. Must be less than the size of the buffers.
*/
SPARROW_API array slice_view(size_type start, size_type end) const ;

private:

SPARROW_API arrow_proxy& get_arrow_proxy();
Expand Down
20 changes: 17 additions & 3 deletions src/array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,25 @@ namespace sparrow
return (*this)[size() - 1];
}

void array::slice(size_type start, size_type end)
array array::slice(size_type start, size_type end) const
{
SPARROW_ASSERT_TRUE(start <= end);
get_arrow_proxy().set_offset(start);
get_arrow_proxy().set_length(end - start);
array copy = *this;
arrow_proxy& arrow_proxy_copy = copy.get_arrow_proxy();
arrow_proxy_copy.set_offset(start);
arrow_proxy_copy.set_length(end - start);
return copy;
}

array array::slice_view(size_type start, size_type end) const
{
SPARROW_ASSERT_TRUE(start <= end);
const arrow_proxy& arrow_proxy_copy = get_arrow_proxy();
ArrowSchema as = arrow_proxy_copy.schema();
ArrowArray ar = arrow_proxy_copy.array();
ar.offset = static_cast<int64_t>(start);
ar.length = static_cast<int64_t>(end - start);
return {std::move(ar), std::move(as)};
}

arrow_proxy& array::get_arrow_proxy()
Expand Down
79 changes: 53 additions & 26 deletions test/test_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,6 @@ namespace sparrow
}
TEST_CASE_TEMPLATE_APPLY(visit_id, testing_types);


TEST_CASE_TEMPLATE_DEFINE("slice", AR, slice_id)
{
using const_reference = typename AR::const_reference;
Expand All @@ -317,34 +316,62 @@ namespace sparrow
array ar = test::make_array<scalar_value_type>(size);

REQUIRE_EQ(ar.size(), size);
CHECK_EQ(std::get<const_reference>(ar[0]), make_nullable<scalar_value_type>(0));
CHECK_EQ(std::get<const_reference>(ar[1]), make_nullable<scalar_value_type>(1));
CHECK_EQ(std::get<const_reference>(ar[2]), make_nullable<scalar_value_type>(2));
CHECK_EQ(std::get<const_reference>(ar[3]), make_nullable<scalar_value_type>(3));
CHECK_EQ(std::get<const_reference>(ar[4]), make_nullable<scalar_value_type>(4));
CHECK_EQ(std::get<const_reference>(ar[5]), make_nullable<scalar_value_type>(5));
CHECK_EQ(std::get<const_reference>(ar[6]), make_nullable<scalar_value_type>(6));
CHECK_EQ(std::get<const_reference>(ar[7]), make_nullable<scalar_value_type>(7));
CHECK_EQ(std::get<const_reference>(ar[8]), make_nullable<scalar_value_type>(8));
CHECK_EQ(std::get<const_reference>(ar[9]), make_nullable<scalar_value_type>(9));

ar.slice(1, 5);
REQUIRE_EQ(ar.size(), 4);
CHECK_EQ(std::get<const_reference>(ar[0]), make_nullable<scalar_value_type>(1));
CHECK_EQ(std::get<const_reference>(ar[1]), make_nullable<scalar_value_type>(2));
CHECK_EQ(std::get<const_reference>(ar[2]), make_nullable<scalar_value_type>(3));
CHECK_EQ(std::get<const_reference>(ar[3]), make_nullable<scalar_value_type>(4));
scalar_value_type scalar_value = 0;
for (size_t i = 0; i < size; ++i, ++scalar_value)
{
CHECK_EQ(std::get<const_reference>(ar[i]), make_nullable(scalar_value));
}

ar.slice(2, 8);
const auto slice_1_5 = ar.slice(1, 5);
REQUIRE_EQ(slice_1_5.size(), 4);
scalar_value = static_cast<scalar_value_type>(1);
for (size_t i = 0; i < slice_1_5.size(); ++i, ++scalar_value)
{
CHECK_EQ(std::get<const_reference>(slice_1_5[i]).get(), scalar_value);
}

REQUIRE_EQ(ar.size(), 6);
CHECK_EQ(std::get<const_reference>(ar[0]), make_nullable<scalar_value_type>(2));
CHECK_EQ(std::get<const_reference>(ar[1]), make_nullable<scalar_value_type>(3));
CHECK_EQ(std::get<const_reference>(ar[2]), make_nullable<scalar_value_type>(4));
CHECK_EQ(std::get<const_reference>(ar[3]), make_nullable<scalar_value_type>(5));
CHECK_EQ(std::get<const_reference>(ar[4]), make_nullable<scalar_value_type>(6));
CHECK_EQ(std::get<const_reference>(ar[5]), make_nullable<scalar_value_type>(7));
ar.slice(2, 8);
const auto slice_2_8 = ar.slice(2, 8);
REQUIRE_EQ(slice_2_8.size(), 6);
scalar_value = static_cast<scalar_value_type>(2);
for (size_t i = 0; i < slice_2_8.size(); ++i, ++scalar_value)
{
CHECK_EQ(std::get<const_reference>(slice_2_8[i]).get(), scalar_value);
}
}
TEST_CASE_TEMPLATE_APPLY(slice_id, testing_types);

TEST_CASE_TEMPLATE_DEFINE("slice_view", AR, slice_view_id)
{
using const_reference = typename AR::const_reference;
using scalar_value_type = typename AR::inner_value_type;

constexpr size_t size = 10;
array ar = test::make_array<scalar_value_type>(size);

REQUIRE_EQ(ar.size(), size);
scalar_value_type scalar_value = 0;
for (size_t i = 0; i < size; ++i, ++scalar_value)
{
CHECK_EQ(std::get<const_reference>(ar[i]).get(), scalar_value);
}

const auto slice_1_5 = ar.slice_view(1, 5);
REQUIRE_EQ(slice_1_5.size(), 4);
scalar_value = static_cast<scalar_value_type>(1);
for (size_t i = 0; i < slice_1_5.size(); ++i, ++scalar_value)
{
CHECK_EQ(std::get<const_reference>(slice_1_5[i]).get(), scalar_value);
}

ar.slice_view(2, 8);
const auto slice_2_8 = ar.slice_view(2, 8);
REQUIRE_EQ(slice_2_8.size(), 6);
scalar_value = static_cast<scalar_value_type>(2);
for (size_t i = 0; i < slice_2_8.size(); ++i, ++scalar_value)
{
CHECK_EQ(std::get<const_reference>(slice_2_8[i]).get(), scalar_value);
}
}
}
}

0 comments on commit c6174cf

Please sign in to comment.