Skip to content

[mlir][Vector] add vector.insert canonicalization pattern to convert a chain of insertions to vector.from_elements #142944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: main
Choose a base branch
from

Conversation

yangtetris
Copy link
Contributor

@yangtetris yangtetris commented Jun 5, 2025

Description

This change introduces a new canonicalization pattern for the MLIR Vector dialect that optimizes chains of insertions. The optimization identifies when a vector is completely initialized through a series of vector.insert operations and replaces the entire chain with a single vector.from_elements operation.

Please be aware that the new pattern doesn't work for poison vectors where only some elements are set, as MLIR doesn't support partial poison vectors for now.

New Pattern: InsertChainFullyInitialized

  • Detects chains of vector.insert operations.
  • Validates that all insertions are at static positions, and all intermediate insertions have only one use.
  • Ensures the entire vector is completely initialized.
  • Replaces the entire chain with a single vector.from_elementts operation.

Refactored Helper Function

  • Extracted calculateInsertPosition from foldDenseElementsAttrDestInsertOp to avoid code duplication.

Example

// Before:
%v1 = vector.insert %c10, %v0[0] : i64 into vector<2xi64>
%v2 = vector.insert %c20, %v1[1] : i64 into vector<2xi64>

// After:
%v2 = vector.from_elements %c10, %c20 : vector<2xi64>

It also works for multidimensional vectors.

// Before:
%v1 = vector.insert %cv0, %v0[0] : vector<3xi64> into vector<2x3xi64>
%v2 = vector.insert %cv1, %v1[1] : vector<3xi64> into vector<2x3xi64>

// After:
%0:3 = vector.to_elements %arg1 : vector<3xi64>
%1:3 = vector.to_elements %arg2 : vector<3xi64>
%v2 = vector.from_elements %0#0, %0#1, %0#2, %1#0, %1#1, %1#2 : vector<2x3xi64>

Copy link

github-actions bot commented Jun 5, 2025

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Member

llvmbot commented Jun 5, 2025

@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Yang Bai (yangtetris)

Changes

Description

This change introduces a new canonicalization pattern for the MLIR Vector dialect that optimizes chains of constant insertions into vectors initialized with ub.poison. The optimization identifies when a vector is completely initialized through a series of vector.insert operations and replaces the entire chain with a single arith.constant operation.

Please be aware that the new pattern doesn't work for poison vectors where only some elements are set, as MLIR doesn't support partial poison vectors for now.

New Pattern: InsertConstantToPoison

  • Detects chains of vector.insert operations that start from an ub.poison operation.
  • Validates that all insertions use constant values at static positions.
  • Ensures the entire vector is completely initialized.
  • Replaces the entire chain with a single arith.constant operation containing a DenseElementsAttr.

Refactored Helper Function

  • Extracted calculateInsertPositionAndExtractValues from foldDenseElementsAttrDestInsertOp to avoid code duplication.

Example

// Before:
%poison = ub.poison : vector&lt;2xi64&gt;
%v1 = vector.insert %c10, %poison[0] : i64 into vector&lt;2xi64&gt;
%v2 = vector.insert %c20, %v1[1] : i64 into vector&lt;2xi64&gt;

// After:
%result = arith.constant dense&lt;[10, 20]&gt; : vector&lt;2xi64&gt;

It also works for multidimensional vectors.

// Before:
%poison = ub.poison : vector&lt;2x3xi64&gt;
%cv0 = arith.constant dense&lt;[1, 2, 3]&gt; : vector&lt;3xi64&gt;
%cv1 = arith.constant dense&lt;[4, 5, 6]&gt; : vector&lt;3xi64&gt;
%v1 = vector.insert %cv0, %poison[0] : vector&lt;3xi64&gt; into vector&lt;2x3xi64&gt;
%v2 = vector.insert %cv1, %v1[1] : vector&lt;3xi64&gt; into vector&lt;2x3xi64&gt;

// After:
%result = arith.constant dense&lt;[[1, 2, 3], [4, 5, 6]]&gt; : vector&lt;2x3xi64&gt;

---
Full diff: https://github.com/llvm/llvm-project/pull/142944.diff


2 Files Affected:

- (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+145-29) 
- (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+32) 


``````````diff
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index fcfb401fd9867..253d148072dc0 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -3149,6 +3149,42 @@ LogicalResult InsertOp::verify() {
   return success();
 }
 
+// Calculate the linearized position for inserting elements and extract values
+// from the source attribute. Returns the starting position in the destination
+// vector where elements should be inserted.
+static int64_t calculateInsertPositionAndExtractValues(
+    VectorType destTy, const ArrayRef<int64_t> &positions, Attribute srcAttr,
+    SmallVector<Attribute> &valueToInsert) {
+  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
+  copy(positions, completePositions.begin());
+  int64_t insertBeginPosition =
+      linearize(completePositions, computeStrides(destTy.getShape()));
+
+  Type destEltType = destTy.getElementType();
+
+  /// Converts the expected type to an IntegerAttr if there's
+  /// a mismatch.
+  auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
+    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
+      if (intAttr.getType() != expectedType)
+        return IntegerAttr::get(expectedType, intAttr.getInt());
+    }
+    return attr;
+  };
+
+  // The `convertIntegerAttr` method specifically handles the case
+  // for `llvm.mlir.constant` which can hold an attribute with a
+  // different type than the return type.
+  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
+    for (auto value : denseSource.getValues<Attribute>())
+      valueToInsert.push_back(convertIntegerAttr(value, destEltType));
+  } else {
+    valueToInsert.push_back(convertIntegerAttr(srcAttr, destEltType));
+  }
+
+  return insertBeginPosition;
+}
+
 namespace {
 
 // If insertOp is only inserting unit dimensions it can be transformed to a
@@ -3191,6 +3227,109 @@ class InsertSplatToSplat final : public OpRewritePattern<InsertOp> {
   }
 };
 
+// Pattern to optimize a chain of constant insertions into a poison vector.
+//
+// This pattern identifies chains of vector.insert operations that:
+// 1. Start from an ub.poison operation.
+// 2. Insert only constant values at static positions.
+// 3. Completely initialize all elements in the resulting vector.
+//
+// When these conditions are met, the entire chain can be replaced with a
+// single arith.constant operation containing a dense elements attribute.
+//
+// Example transformation:
+//   %poison = ub.poison : vector<2xi32>
+//   %0 = vector.insert %c1, %poison[0] : i32 into vector<2xi32>
+//   %1 = vector.insert %c2, %0[1] : i32 into vector<2xi32>
+// ->
+//   %result = arith.constant dense<[1, 2]> : vector<2xi32>
+
+// TODO: Support the case where only some elements of the poison vector are set.
+//       Currently, MLIR doesn't support partial poison vectors.
+
+class InsertConstantToPoison final : public OpRewritePattern<InsertOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(InsertOp op,
+                                PatternRewriter &rewriter) const override {
+
+    VectorType destTy = op.getDestVectorType();
+    if (destTy.isScalable())
+      return failure();
+    // Check if the result is used as the dest operand of another vector.insert
+    // Only care about the last op in a chain of insertions.
+    for (Operation *user : op.getResult().getUsers())
+      if (auto insertOp = dyn_cast<InsertOp>(user))
+        if (insertOp.getDest() == op.getResult())
+          return failure();
+
+    InsertOp firstInsertOp;
+    InsertOp previousInsertOp = op;
+    SmallVector<InsertOp> chainInsertOps;
+    SmallVector<Attribute> srcAttrs;
+    while (previousInsertOp) {
+      // Dynamic position is not supported.
+      if (previousInsertOp.hasDynamicPosition())
+        return failure();
+
+      // The inserted content must be constant.
+      chainInsertOps.push_back(previousInsertOp);
+      srcAttrs.push_back(Attribute());
+      matchPattern(previousInsertOp.getValueToStore(),
+                   m_Constant(&srcAttrs.back()));
+      if (!srcAttrs.back())
+        return failure();
+
+      // An insertion at poison index makes the entire chain poisoned.
+      if (is_contained(previousInsertOp.getStaticPosition(),
+                       InsertOp::kPoisonIndex))
+        return failure();
+
+      firstInsertOp = previousInsertOp;
+      previousInsertOp = previousInsertOp.getDest().getDefiningOp<InsertOp>();
+    }
+
+    if (!firstInsertOp.getDest().getDefiningOp<ub::PoisonOp>())
+      return failure();
+
+    // Need to make sure all elements are initialized.
+    int64_t vectorSize = destTy.getNumElements();
+    int64_t initializedCount = 0;
+    SmallVector<bool> initialized(vectorSize, false);
+    SmallVector<Attribute> initValues(vectorSize);
+
+    for (auto [insertOp, srcAttr] : llvm::zip(chainInsertOps, srcAttrs)) {
+      // Calculate the linearized position for inserting elements, as well as
+      // convert the source attribute to the proper type.
+      SmallVector<Attribute> valueToInsert;
+      int64_t insertBeginPosition = calculateInsertPositionAndExtractValues(
+          destTy, insertOp.getStaticPosition(), srcAttr, valueToInsert);
+      for (auto index :
+           llvm::seq<int64_t>(insertBeginPosition,
+                              insertBeginPosition + valueToInsert.size())) {
+        if (initialized[index])
+          continue;
+
+        initialized[index] = true;
+        ++initializedCount;
+        initValues[index] = valueToInsert[index - insertBeginPosition];
+      }
+      // If all elements in the vector have been initialized, we can stop
+      // processing the remaining insert operations in the chain.
+      if (initializedCount == vectorSize)
+        break;
+    }
+
+    // some positions are not initialized.
+    if (initializedCount != vectorSize)
+      return failure();
+
+    auto newAttr = DenseElementsAttr::get(destTy, initValues);
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, destTy, newAttr);
+    return success();
+  }
+};
+
 } // namespace
 
 static Attribute
@@ -3217,35 +3356,11 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
       !insertOp->hasOneUse())
     return {};
 
-  // Calculate the linearized position of the continuous chunk of elements to
-  // insert.
-  llvm::SmallVector<int64_t> completePositions(destTy.getRank(), 0);
-  copy(insertOp.getStaticPosition(), completePositions.begin());
-  int64_t insertBeginPosition =
-      linearize(completePositions, computeStrides(destTy.getShape()));
-
+  // Calculate the linearized position for inserting elements, as well as
+  // convert the source attribute to the proper type.
   SmallVector<Attribute> insertedValues;
-  Type destEltType = destTy.getElementType();
-
-  /// Converts the expected type to an IntegerAttr if there's
-  /// a mismatch.
-  auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
-    if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
-      if (intAttr.getType() != expectedType)
-        return IntegerAttr::get(expectedType, intAttr.getInt());
-    }
-    return attr;
-  };
-
-  // The `convertIntegerAttr` method specifically handles the case
-  // for `llvm.mlir.constant` which can hold an attribute with a
-  // different type than the return type.
-  if (auto denseSource = llvm::dyn_cast<DenseElementsAttr>(srcAttr)) {
-    for (auto value : denseSource.getValues<Attribute>())
-      insertedValues.push_back(convertIntegerAttr(value, destEltType));
-  } else {
-    insertedValues.push_back(convertIntegerAttr(srcAttr, destEltType));
-  }
+  int64_t insertBeginPosition = calculateInsertPositionAndExtractValues(
+      destTy, insertOp.getStaticPosition(), srcAttr, insertedValues);
 
   auto allValues = llvm::to_vector(denseDst.getValues<Attribute>());
   copy(insertedValues, allValues.begin() + insertBeginPosition);
@@ -3256,7 +3371,8 @@ foldDenseElementsAttrDestInsertOp(InsertOp insertOp, Attribute srcAttr,
 
 void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
-  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat>(context);
+  results.add<InsertToBroadcast, BroadcastFolder, InsertSplatToSplat,
+              InsertConstantToPoison>(context);
 }
 
 OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index a06a9f67d54dc..36f3d7196bb93 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2320,6 +2320,38 @@ func.func @insert_2d_constant() -> (vector<2x3xi32>, vector<2x3xi32>, vector<2x3
 
 // -----
 
+// CHECK-LABEL: func.func @fully_insert_scalar_constant_to_poison_vector
+//       CHECK: %[[VAL0:.+]] = arith.constant dense<[10, 20]> : vector<2xi64>
+//  CHECK-NEXT: return %[[VAL0]]
+func.func @fully_insert_scalar_constant_to_poison_vector() -> vector<2xi64> {
+  %poison = ub.poison : vector<2xi64>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %e0 = arith.constant 10 : i64
+  %e1 = arith.constant 20 : i64
+  %v1 = vector.insert %e0, %poison[%c0] : i64 into vector<2xi64>
+  %v2 = vector.insert %e1, %v1[%c1] : i64 into vector<2xi64>
+  return %v2 : vector<2xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @fully_insert_vector_constant_to_poison_vector
+//       CHECK: %[[VAL0:.+]] = arith.constant dense<{{\[\[1, 2, 3\], \[4, 5, 6\]\]}}> : vector<2x3xi64>
+//  CHECK-NEXT: return %[[VAL0]]
+func.func @fully_insert_vector_constant_to_poison_vector() -> vector<2x3xi64> {
+  %poison = ub.poison : vector<2x3xi64>
+  %cv0 = arith.constant dense<[1, 2, 3]> : vector<3xi64>
+  %cv1 = arith.constant dense<[4, 5, 6]> : vector<3xi64>
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %v1 = vector.insert %cv0, %poison[%c0] : vector<3xi64> into vector<2x3xi64>
+  %v2 = vector.insert %cv1, %v1[%c1] : vector<3xi64> into vector<2x3xi64>
+  return %v2 : vector<2x3xi64>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @insert_2d_splat_constant
 //   CHECK-DAG: %[[ACST:.*]] = arith.constant dense<0> : vector<2x3xi32>
 //   CHECK-DAG: %[[BCST:.*]] = arith.constant dense<{{\[\[99, 0, 0\], \[0, 0, 0\]\]}}> : vector<2x3xi32>

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor comments. Otherwise, it LGTM!

return failure();

auto newAttr = DenseElementsAttr::get(destTy, initValues);
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, destTy, newAttr);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use vector.from_elements here and let it canonicalize to arith.constant if all inputs are constant?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great idea! That would indeed remove the constant-only restriction. Let me try implementing it. Thanks!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just learned that currently vector.from_elements does not have a folder to fold it into an arith.constant op. Do you happen to know if there is any ongoing PR implementing this? If not, I would be interested in creating one.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to create one, we should have one like that if it doesnt exist already.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, I think that's is a great idea and something that we would need anyways. Let's implement that first!

Comment on lines 3152 to 3154
// Calculate the linearized position for inserting elements and extract values
// from the source attribute. Returns the starting position in the destination
// vector where elements should be inserted.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calculat the linearized position based on what? I cannot understand what this is trying to say.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the confusion. I'll updated the function’s documentation to clarify this. Since we adopted a new implementation using from_elements, I also refactored this function. It now focuses solely on calculating the linearized position of the continuous chunk of elements to insert within the destination vector.

Comment on lines 3167 to 3173
auto convertIntegerAttr = [](Attribute attr, Type expectedType) -> Attribute {
if (auto intAttr = mlir::dyn_cast<IntegerAttr>(attr)) {
if (intAttr.getType() != expectedType)
return IntegerAttr::get(expectedType, intAttr.getInt());
}
return attr;
};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this safe?

Copy link
Contributor Author

@yangtetris yangtetris Jun 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, this code originates from #88314 which addressed issues related to mismatches between index types and i64. I think this background is enough to reassures us about the safety. But, it also reminded me that we might need a similar attribute conversion in the from_elements to constant PR. I'll verify later whether that pattern encounters similar issues.

/// This pattern identifies chains of vector.insert operations that:
/// 1. Start from an ub.poison operation.
/// 2. Insert only constant values at static positions.
/// 3. Completely initialize all elements in the resulting vector.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If all elements are initialized, why does it matter if the operation started from a ub.poison?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the vector didn't start from ub.poison, the existing foldDenseElementsAttrDestInsertOp in InsertOp::fold should already be able to handle the folding. However, it does not guarantee all elements are initialized. That’s why I want to add this new pattern explicitly targets ub.poison.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, this feels like an awkward special case. Why can't the existing folder handle this case instead?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I chose not to use the folder is that this method requires traversing the entire chain of insert operations, which results in linear time complexity for each traversal. If someone uses createOrFold to create a sequence of insert operations, the access pattern would look like this:

insert op 1 -> ub.poison  
insert op 2 -> insert op 1 -> ub.poison  
insert op 3 -> insert op 2 -> insert op 1 -> ub.poison  
...
insert op n -> insert op n - 1 -> ... -> inset op 1 -> ub.poison

The condition for being fully initialized is only satisfied at the very last line, but the overall complexity becomes O(n²) due to repeated traversals.

However, I also admit that this is an awkward case. I’m completely okay if you think it’s not worth introducing a new canonicalization pattern. Perhaps it would be much easier to address this once partial poisoned vectors can be easily represented.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't matter at all if the start is a ub.poison. If you think that the pattern isn't a canonicalization if the start is something else, then it shouldn't be a canonicalization at all. Personally, this seems like a cleanup and we could live with this in cleanup before lowering to backends.

But the initial start shouldn't matter at all whatever the case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It shouldn't matter at all if the start is a ub.poison.

After a second thought, I think it indeed makes sense. We don’t need to worry about whether a canonicalization pattern may overlap with a folder in some cases. I'll remove the check for ub.poison.

if (previousInsertOp.hasDynamicPosition())
return failure();

// The inserted content must be constant.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you use from_elements this is not required.

if (!srcAttrs.back())
return failure();

// An insertion at poison index makes the entire chain poisoned.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would happen in this case? Can you add a comment if this will get folded by the folder?

Copy link
Contributor Author

@yangtetris yangtetris Jun 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The insert operation’s folder will fold this case into a ub.poison, effectively truncating the insert chain’s backward traversal. If the vector is not fully initialized by that point, this pattern will definitively fail. I will add a comment to clarify this behavior in the code.

Update: I also moved this check to another place to give this pattern a better chance of succeeding.

Comment on lines 3299 to 3300
// Currently, MLIR doesn't support partial poison vectors, so we can only
// optimize when the entire vector is completely initialized.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not exactly, you can always do a shufflevector instruction to get these things in the right order, but we can ignore this for now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My understanding is that this is about representing a constant vector with poison and non-poison values. AFAIK, we can't represent that in MLIR right now

SmallVector<bool> initialized(vectorSize, false);
SmallVector<Attribute> initValues(vectorSize);

for (auto [insertOp, srcAttr] : llvm::zip(chainInsertOps, srcAttrs)) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

use zip_equal if they are equal

/// This pattern identifies chains of vector.insert operations that:
/// 1. Start from an ub.poison operation.
/// 2. Insert only constant values at static positions.
/// 3. Completely initialize all elements in the resulting vector.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1, this feels like an awkward special case. Why can't the existing folder handle this case instead?

@yangtetris
Copy link
Contributor Author

Almost all of the changes have been completed. However, while updating the test base, I found that VectorFromElementsLowering currently does not support vectors with rank > 1. So, replacing insert chains with from_elements would break some lower-to-llvm tests. Let's pause until that issue is fixed.

@dcaballe
Copy link
Contributor

What is the current state of this? Any blockers?

@yangtetris
Copy link
Contributor Author

What is the current state of this? Any blockers?

It is still blocked due to the missing from_elements to llvm conversion for multi-dim vectors. I think we can

  1. Support lowering multi-dim vectors to llvm first.
  2. Or, restrict this pattern to make it only work for 1-dim vectors.

@dcaballe
Copy link
Contributor

Support lowering multi-dim vectors to llvm first.

It sounds like this shouldn't be too complicated. Is that something you could help with?

@yangtetris
Copy link
Contributor Author

It sounds like this shouldn't be too complicated. Is that something you could help with?

I'm not familiar with that part, but I think I can give it a try.

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice!

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates. I am wondering whether this is robust enough to handle overlapping vector.insert operations?

Also, I am not sure whether this would be beneficial? Also, are there any tests for this case?

// Before:
%poison = ub.poison : vector<2x3xi64>
%v1 = vector.insert %cv0, %poison[0] : vector<3xi64> into vector<2x3xi64>
%result = vector.insert %cv1, %v1[1] : vector<3xi64> into vector<2x3xi64>

// After:
%v1 = vector.extract %cv0[0] : i64 from vector<3xi64>
%v2 = vector.extract %cv0[1] : i64 from vector<3xi64>
%v3 = vector.extract %cv0[2] : i64 from vector<3xi64>
%v4 = vector.extract %cv1[0] : i64 from vector<3xi64>
%v5 = vector.extract %cv1[1] : i64 from vector<3xi64>
%v6 = vector.extract %cv1[2] : i64 from vector<3xi64>
%result = vector.from_elements %v1, %v2, %v3, %v4, %v5, %v6 : vector<2x3xi64>

Comment on lines 3291 to 3292
// Check if the result is used as the dest operand of another vector.insert
// Only care about the last op in a chain of insertions.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please document why we need the check as opposed to what is being checked.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense. Document updated.

Comment on lines +3309 to +3310
if (currentOp && !currentOp->hasOneUse())
return failure();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this condition guarantee that conditions checked in lines L3293-L3296 are always met?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These are two different checks.
The check on lines L3293-L3296 is to ensure that there are no more insert ops after the current op. The purpose of that check is to skip the O(n) time complexity part for intermediate insert ops.
In contrast, this hasOneUse check is to handle the following case:

%v1 = %vector.insert %c1, %v0[0] : i64 into vector<3xi64>
%v2 = %vector.insert %c2, %v1[1] : i64 into vector<3xi64>
%v3_3 = %vector.insert %c3, %v2[2] : i64 into vector<3xi64>
%v3_4 = %vector.insert %c4, %v2[2] : i64 into vector<3xi64>
%v3_5 = %vector.insert %c5, %v2[2] : i64 into vector<3xi64>

The key point is that when %v1 or %v2 has multiple users, we should not introduce new from_elements ops because the insert chain cannot be completely eliminated. This IR is also an example of the 'explosion' you asked about later. We want to avoid generating three new form_elements ops for %v3_3, %v3_4 and %v3_5.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also added a negative test for such case.

Comment on lines 3307 to 3308
// Check that intermediate inserts have only one use to avoid an explosion
// of vectors.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite follow what you mean by "explosion of vectors". With multiple users there simply wouldn't be a "chain" of vector.inserts. Could you provide an example of where such an "explosion" could happen?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please refer to the previous question for an example.

Comment on lines 3321 to 3322
// The insert op folder will fold an insert at poison index into a
// ub.poison, which truncates the insert chain's backward traversal.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't the reason to "fail" simply that there's nothing that this transformation can do with an index that's poison? I'm not sure whether other transformations are relevant here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fold function is relevant. However, the reason you gave is also true and easier to understand. Will update the document, thanks.

Comment on lines 3342 to 3343
if (initialized[index])
continue;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there are multiple vector.insert inserting at the same "index", how do you make sure to select the right one (i.e. the one inserting "last")? Are there any tests for this scenario?

Copy link
Contributor Author

@yangtetris yangtetris Aug 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pattern uses two reverse traversals to ensure the last insert op takes effect.
It follows the chain backwards to collect insert ops into a list, then a llvm::reverse(L3364) is applied to the list. So the later insert ops in the chain will override the earlier ones. The continue statement you quoted only affects the initialization count and won't skip setting values.
It is better to have a test. I just added fully_insert_to_vector_overlap to cover it.

Co-authored-by: Andrzej Warzyński <andrzej.warzynski@gmail.com>
Copy link

github-actions bot commented Aug 4, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@yangtetris yangtetris changed the title [mlir][Vector] add vector.insert canonicalization pattern for vectors created from ub.poison [mlir][Vector] add vector.insert canonicalization pattern to convert a chain of insertions to vector.from_elements Aug 4, 2025
@yangtetris
Copy link
Contributor Author

Thanks for the updates. I am wondering whether this is robust enough to handle overlapping vector.insert operations?

Please refer to another thread on this page.

Also, I am not sure whether this would be beneficial? Also, are there any tests for this case?

Sorry I didn't update the PR description in time. With @dcaballe 's advice, we now use vector.to_elements, which looks much better. The case fully_insert_to_vector_overlap covers this scenario.
@banach-space Please take another look at the new PR description. Thanks!

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the updates!

Overall looks good and is very clearly crafted! I've left some final suggestions inline.

Comment on lines 3291 to 3297
// This pattern has linear time complexity with respect to the length of the
// insert chain. So we only care about the last insert op which has the
// highest probability of success.
for (Operation *user : op.getResult().getUsers())
if (auto insertOp = dyn_cast<InsertOp>(user))
if (insertOp.getDest() == op.getResult())
return failure();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Initially, I found this block a bit confusing - below is my current understanding for confirmation.


To identify a valid chain of vector.insert operations, the pattern first needs to locate the trailing vector.insert op. Consider the following two example chains:

  1. insert -> insert -> insert -> insert
  2. insert -> insert -> insert -> insert

Only Option 1 qualifies under the current design. That makes sense to me - it’s a conscious design choice that simplifies the pattern logic. Supporting Option 2 would be possible, but would likely add significant complexity.

This code block ensures that the given op is the trailing vector.insert in a chain that matches this pattern.

If that's accurate, I’d suggest updating the in-code comment to something like:

Ensure this is the trailing vector.insert op in a chain of inserts.

I'd also recommend adding a note about this constraint in the high-level comment for the pattern.


As for this comment ...

// This pattern has linear time complexity with respect to the length of the
// insert chain. So we only care about the last insert op which has the
// highest probability of success.

IMHO, you want to avoid matching overly complex or fragmented insert chains, and focusing on the last op is a clean and efficient approach. That's the design and that's fine. To me, everything else is secondary and can be skipped.

I mostly want to avoid our "future selves" getting into a discussion on "probability" and "cost" 😅

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for your understanding. This is exactly what this code is trying to express! The comment you proposed looks much more reasonable, let me update it.

Comment on lines +3269 to +3271
/// 1. Only insert values at static positions.
/// 2. Completely initialize all elements in the resulting vector.
/// 3. All intermediate insert operations have only one use.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] It would be helpful if you made references to these high level design points within the implementation (e.g. "Check Cond 1. (only static indices are used)").

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point!


// -----

// Test the case where multiple ops insert to overlapped indices.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not immediately obvious that ARG0 is completely skipped - could you add a comment?

I would also add more examples of overlapping vector.insert Ops. Perhaps:

insert scalar
insert scalar
insert scalar
insert vector

where "insert vector" overwrites all of "insert scalar".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Let me add a detailed comment and a new test.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants