diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h index eb441bb31c9bc8..43c9b80edff78e 100644 --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -1023,6 +1023,16 @@ class concat_iterator std::forward_iterator_tag, ValueT> { using BaseT = typename concat_iterator::iterator_facade_base; + static constexpr bool ReturnsByValue = + !(std::is_reference_v())> && ...); + + using reference_type = + typename std::conditional_t; + + using handle_type = + typename std::conditional_t, + ValueT *>; + /// We store both the current and end iterators for each concatenated /// sequence in a tuple of pairs. /// @@ -1065,27 +1075,30 @@ class concat_iterator /// Returns null if the specified iterator is at the end. Otherwise, /// dereferences the iterator and returns the address of the resulting /// reference. - template ValueT *getHelper() const { + template handle_type getHelper() const { auto &Begin = std::get(Begins); auto &End = std::get(Ends); if (Begin == End) - return nullptr; + return {}; - return &*Begin; + if constexpr (ReturnsByValue) + return *Begin; + else + return &*Begin; } /// Finds the first non-end iterator, dereferences, and returns the resulting /// reference. /// /// It is an error to call this with all iterators at the end. - template ValueT &get(std::index_sequence) const { + template reference_type get(std::index_sequence) const { // Build a sequence of functions to get from iterator if possible. - ValueT *(concat_iterator::*GetHelperFns[])() const = { - &concat_iterator::getHelper...}; + handle_type (concat_iterator::*GetHelperFns[])() + const = {&concat_iterator::getHelper...}; // Loop over them, and return the first result we find. for (auto &GetHelperFn : GetHelperFns) - if (ValueT *P = (this->*GetHelperFn)()) + if (auto P = (this->*GetHelperFn)()) return *P; llvm_unreachable("Attempted to get a pointer from an end concat iterator!"); @@ -1107,7 +1120,7 @@ class concat_iterator return *this; } - ValueT &operator*() const { + reference_type operator*() const { return get(std::index_sequence_for()); } diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp index ee8299c9b48612..406ff2bc16073b 100644 --- a/llvm/unittests/ADT/STLExtrasTest.cpp +++ b/llvm/unittests/ADT/STLExtrasTest.cpp @@ -504,6 +504,43 @@ TEST(STLExtrasTest, ConcatRange) { EXPECT_EQ(Expected, Test); } +template struct Iterator { + int i = 0; + T operator*() const { return i; } + Iterator &operator++() { + ++i; + return *this; + } + bool operator==(Iterator RHS) const { return i == RHS.i; } +}; + +template struct RangeWithValueType { + int i; + RangeWithValueType(int i) : i(i) {} + Iterator begin() { return Iterator{0}; } + Iterator end() { return Iterator{i}; } +}; + +TEST(STLExtrasTest, ValueReturn) { + RangeWithValueType R(1); + auto C = concat(R, R); + auto I = C.begin(); + ASSERT_NE(I, C.end()); + static_assert(std::is_same_v); + auto V = *I; + ASSERT_EQ(V, 0); +} + +TEST(STLExtrasTest, ReferenceReturn) { + RangeWithValueType R(1); + auto C = concat(R, R); + auto I = C.begin(); + ASSERT_NE(I, C.end()); + static_assert(std::is_same_v); + auto V = *I; + ASSERT_EQ(V, 0); +} + TEST(STLExtrasTest, PartitionAdaptor) { std::vector V = {1, 2, 3, 4, 5, 6, 7, 8};