Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADT] Make concat able to handle ranges with iterators that return by value (such as zip) #112783

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions llvm/include/llvm/ADT/STLExtras.h
Original file line number Diff line number Diff line change
Expand Up @@ -1023,6 +1023,16 @@ class concat_iterator
std::forward_iterator_tag, ValueT> {
using BaseT = typename concat_iterator::iterator_facade_base;

static constexpr bool ReturnsByValue =
!(std::is_reference_v<decltype(*std::declval<IterTs>())> && ...);

using reference_type =
typename std::conditional_t<ReturnsByValue, ValueT, ValueT &>;

using handle_type =
typename std::conditional_t<ReturnsByValue, std::optional<ValueT>,
ValueT *>;

/// We store both the current and end iterators for each concatenated
/// sequence in a tuple of pairs.
///
Expand Down Expand Up @@ -1065,27 +1075,30 @@ class concat_iterator
/// Returns null if the specified iterator is at the end. Otherwise,
/// dereferences the iterator and returns the address of the resulting
/// reference.
template <size_t Index> ValueT *getHelper() const {
template <size_t Index> handle_type getHelper() const {
auto &Begin = std::get<Index>(Begins);
auto &End = std::get<Index>(Ends);
if (Begin == End)
return nullptr;
return {};
kuhar marked this conversation as resolved.
Show resolved Hide resolved

return &*Begin;
if constexpr (ReturnsByValue)
return *Begin;
else
return &*Begin;
}

/// Finds the first non-end iterator, dereferences, and returns the resulting
/// reference.
///
/// It is an error to call this with all iterators at the end.
template <size_t... Ns> ValueT &get(std::index_sequence<Ns...>) const {
template <size_t... Ns> reference_type get(std::index_sequence<Ns...>) const {
// Build a sequence of functions to get from iterator if possible.
ValueT *(concat_iterator::*GetHelperFns[])() const = {
&concat_iterator::getHelper<Ns>...};
handle_type (concat_iterator::*GetHelperFns[])()
const = {&concat_iterator::getHelper<Ns>...};

// Loop over them, and return the first result we find.
for (auto &GetHelperFn : GetHelperFns)
if (ValueT *P = (this->*GetHelperFn)())
if (auto P = (this->*GetHelperFn)())
return *P;

llvm_unreachable("Attempted to get a pointer from an end concat iterator!");
Expand All @@ -1107,7 +1120,7 @@ class concat_iterator
return *this;
}

ValueT &operator*() const {
reference_type operator*() const {
return get(std::index_sequence_for<IterTs...>());
}

Expand Down
37 changes: 37 additions & 0 deletions llvm/unittests/ADT/STLExtrasTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,43 @@ TEST(STLExtrasTest, ConcatRange) {
EXPECT_EQ(Expected, Test);
}

template <typename T> struct Iterator {
int i = 0;
T operator*() const { return i; }
Iterator &operator++() {
++i;
return *this;
}
bool operator==(Iterator RHS) const { return i == RHS.i; }
};

template <typename T> struct RangeWithValueType {
int i;
RangeWithValueType(int i) : i(i) {}
Iterator<T> begin() { return Iterator<T>{0}; }
Iterator<T> end() { return Iterator<T>{i}; }
};

TEST(STLExtrasTest, ValueReturn) {
RangeWithValueType<int> R(1);
auto C = concat<int>(R, R);
auto I = C.begin();
ASSERT_NE(I, C.end());
static_assert(std::is_same_v<decltype((*I)), int>);
auto V = *I;
ASSERT_EQ(V, 0);
}

TEST(STLExtrasTest, ReferenceReturn) {
RangeWithValueType<const int&> R(1);
auto C = concat<const int>(R, R);
auto I = C.begin();
ASSERT_NE(I, C.end());
static_assert(std::is_same_v<decltype((*I)), const int &>);
auto V = *I;
ASSERT_EQ(V, 0);
}

TEST(STLExtrasTest, PartitionAdaptor) {
std::vector<int> V = {1, 2, 3, 4, 5, 6, 7, 8};

Expand Down
Loading