Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADT] Make concat able to handle ranges with iterators that return by value (such as zip) #112783

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

dwblaikie
Copy link
Collaborator

If any iterator in the concatenation returns by value, the result must return by value otherwise it'll produce dangling references.

(some context that may or may not be relevant to this part of the code may be in 981ce8f )

An alternative to #112441

… value (such as zip)

If any iterator in the concatenation returns by value, the result must
return by value otherwise it'll produce dangling references.

An alternative to llvm#112441
@llvmbot
Copy link
Collaborator

llvmbot commented Oct 17, 2024

@llvm/pr-subscribers-llvm-adt

Author: David Blaikie (dwblaikie)

Changes

If any iterator in the concatenation returns by value, the result must return by value otherwise it'll produce dangling references.

(some context that may or may not be relevant to this part of the code may be in 981ce8f )

An alternative to #112441


Full diff: https://github.com/llvm/llvm-project/pull/112783.diff

2 Files Affected:

  • (modified) llvm/include/llvm/ADT/STLExtras.h (+21-8)
  • (modified) llvm/unittests/ADT/STLExtrasTest.cpp (+37)
diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h
index eb441bb31c9bc8..43c9b80edff78e 100644
--- a/llvm/include/llvm/ADT/STLExtras.h
+++ b/llvm/include/llvm/ADT/STLExtras.h
@@ -1023,6 +1023,16 @@ class concat_iterator
                                   std::forward_iterator_tag, ValueT> {
   using BaseT = typename concat_iterator::iterator_facade_base;
 
+  static constexpr bool ReturnsByValue =
+      !(std::is_reference_v<decltype(*std::declval<IterTs>())> && ...);
+
+  using reference_type =
+      typename std::conditional_t<ReturnsByValue, ValueT, ValueT &>;
+
+  using handle_type =
+      typename std::conditional_t<ReturnsByValue, std::optional<ValueT>,
+                                  ValueT *>;
+
   /// We store both the current and end iterators for each concatenated
   /// sequence in a tuple of pairs.
   ///
@@ -1065,27 +1075,30 @@ class concat_iterator
   /// Returns null if the specified iterator is at the end. Otherwise,
   /// dereferences the iterator and returns the address of the resulting
   /// reference.
-  template <size_t Index> ValueT *getHelper() const {
+  template <size_t Index> handle_type getHelper() const {
     auto &Begin = std::get<Index>(Begins);
     auto &End = std::get<Index>(Ends);
     if (Begin == End)
-      return nullptr;
+      return {};
 
-    return &*Begin;
+    if constexpr (ReturnsByValue)
+      return *Begin;
+    else
+      return &*Begin;
   }
 
   /// Finds the first non-end iterator, dereferences, and returns the resulting
   /// reference.
   ///
   /// It is an error to call this with all iterators at the end.
-  template <size_t... Ns> ValueT &get(std::index_sequence<Ns...>) const {
+  template <size_t... Ns> reference_type get(std::index_sequence<Ns...>) const {
     // Build a sequence of functions to get from iterator if possible.
-    ValueT *(concat_iterator::*GetHelperFns[])() const = {
-        &concat_iterator::getHelper<Ns>...};
+    handle_type (concat_iterator::*GetHelperFns[])()
+        const = {&concat_iterator::getHelper<Ns>...};
 
     // Loop over them, and return the first result we find.
     for (auto &GetHelperFn : GetHelperFns)
-      if (ValueT *P = (this->*GetHelperFn)())
+      if (auto P = (this->*GetHelperFn)())
         return *P;
 
     llvm_unreachable("Attempted to get a pointer from an end concat iterator!");
@@ -1107,7 +1120,7 @@ class concat_iterator
     return *this;
   }
 
-  ValueT &operator*() const {
+  reference_type operator*() const {
     return get(std::index_sequence_for<IterTs...>());
   }
 
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index ee8299c9b48612..406ff2bc16073b 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -504,6 +504,43 @@ TEST(STLExtrasTest, ConcatRange) {
   EXPECT_EQ(Expected, Test);
 }
 
+template <typename T> struct Iterator {
+  int i = 0;
+  T operator*() const { return i; }
+  Iterator &operator++() {
+    ++i;
+    return *this;
+  }
+  bool operator==(Iterator RHS) const { return i == RHS.i; }
+};
+
+template <typename T> struct RangeWithValueType {
+  int i;
+  RangeWithValueType(int i) : i(i) {}
+  Iterator<T> begin() { return Iterator<T>{0}; }
+  Iterator<T> end() { return Iterator<T>{i}; }
+};
+
+TEST(STLExtrasTest, ValueReturn) {
+  RangeWithValueType<int> R(1);
+  auto C = concat<int>(R, R);
+  auto I = C.begin();
+  ASSERT_NE(I, C.end());
+  static_assert(std::is_same_v<decltype((*I)), int>);
+  auto V = *I;
+  ASSERT_EQ(V, 0);
+}
+
+TEST(STLExtrasTest, ReferenceReturn) {
+  RangeWithValueType<const int&> R(1);
+  auto C = concat<const int>(R, R);
+  auto I = C.begin();
+  ASSERT_NE(I, C.end());
+  static_assert(std::is_same_v<decltype((*I)), const int &>);
+  auto V = *I;
+  ASSERT_EQ(V, 0);
+}
+
 TEST(STLExtrasTest, PartitionAdaptor) {
   std::vector<int> V = {1, 2, 3, 4, 5, 6, 7, 8};
 

Copy link

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff 23da16933b8ad48a967905369f576e5ec45b985f dfb8a466d49f79db9b9b775a927cc4c4b1c28da3 --extensions h,cpp -- llvm/include/llvm/ADT/STLExtras.h llvm/unittests/ADT/STLExtrasTest.cpp
View the diff from clang-format here.
diff --git a/llvm/unittests/ADT/STLExtrasTest.cpp b/llvm/unittests/ADT/STLExtrasTest.cpp
index 406ff2bc16..4f37087663 100644
--- a/llvm/unittests/ADT/STLExtrasTest.cpp
+++ b/llvm/unittests/ADT/STLExtrasTest.cpp
@@ -532,7 +532,7 @@ TEST(STLExtrasTest, ValueReturn) {
 }
 
 TEST(STLExtrasTest, ReferenceReturn) {
-  RangeWithValueType<const int&> R(1);
+  RangeWithValueType<const int &> R(1);
   auto C = concat<const int>(R, R);
   auto I = C.begin();
   ASSERT_NE(I, C.end());

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants