Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][SCF] ForOp: Remove getIterArgNumberForOpOperand #66629

Merged

Conversation

matthias-springer
Copy link
Member

This function was inconsistent with the remaining API because it accepted OpOperand & that do not belong to the op. All the other functions assert. This helper function is also not really necessary, as the iter_arg number is identical to the result number.

Depends on #66622. Review only the top commit.

@llvmbot
Copy link
Collaborator

llvmbot commented Sep 18, 2023

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir-bufferization
@llvm/pr-subscribers-mlir-cf
@llvm/pr-subscribers-mlir-scf
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Changes

This function was inconsistent with the remaining API because it accepted OpOperand & that do not belong to the op. All the other functions assert. This helper function is also not really necessary, as the iter_arg number is identical to the result number.

Depends on #66622. Review only the top commit.


Full diff: https://github.com/llvm/llvm-project/pull/66629.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (-17)
  • (modified) mlir/include/mlir/IR/ValueRange.h (+3-7)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp (+5-1)
  • (modified) mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp (+5-5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp (+9-13)
  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+2-2)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp (+3-3)
  • (modified) mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp (+9-10)
  • (modified) mlir/lib/IR/OperationSupport.cpp (+8)
  • (modified) mlir/lib/Transforms/Utils/CFGToSCF.cpp (+20-8)
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 6d8aaf64e3263b9..89c1a06412947b2 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -250,17 +250,10 @@ def ForOp : SCF_Op<"for",
         "expected an index less than the number of region iter args");
       return getBody()->getArguments().drop_front(getNumInductionVars())[index];
     }
-    MutableArrayRef<OpOperand> getIterOpOperands() {
-      return
-        getOperation()->getOpOperands().drop_front(getNumControlOperands());
-    }
 
     void setLowerBound(Value bound) { getOperation()->setOperand(0, bound); }
     void setUpperBound(Value bound) { getOperation()->setOperand(1, bound); }
     void setStep(Value step) { getOperation()->setOperand(2, step); }
-    void setIterArg(unsigned iterArgNum, Value iterArgValue) {
-      getOperation()->setOperand(iterArgNum + getNumControlOperands(), iterArgValue);
-    }
 
     /// Number of induction variables, always 1 for scf::ForOp.
     unsigned getNumInductionVars() { return 1; }
@@ -270,16 +263,6 @@ def ForOp : SCF_Op<"for",
     }
     /// Number of operands controlling the loop: lb, ub, step
     unsigned getNumControlOperands() { return 3; }
-    /// Get the iter arg number for an operand. If it isnt an iter arg
-    /// operand return std::nullopt.
-    std::optional<unsigned> getIterArgNumberForOpOperand(OpOperand &opOperand) {
-      if (opOperand.getOwner() != getOperation())
-        return std::nullopt;
-      unsigned operandNumber = opOperand.getOperandNumber();
-      if (operandNumber < getNumControlOperands())
-        return std::nullopt;
-      return operandNumber - getNumControlOperands();
-    }
 
     /// Get the region iter arg that corresponds to an OpOperand.
     /// This helper prevents internal op implementation detail leakage to
diff --git a/mlir/include/mlir/IR/ValueRange.h b/mlir/include/mlir/IR/ValueRange.h
index f1a1f1841f179e7..9c11178f9cd9cae 100644
--- a/mlir/include/mlir/IR/ValueRange.h
+++ b/mlir/include/mlir/IR/ValueRange.h
@@ -165,13 +165,9 @@ class MutableOperandRange {
   /// Returns the OpOperand at the given index.
   OpOperand &operator[](unsigned index) const;
 
-  OperandRange::iterator begin() const {
-    return static_cast<OperandRange>(*this).begin();
-  }
-
-  OperandRange::iterator end() const {
-    return static_cast<OperandRange>(*this).end();
-  }
+  /// Iterators enumerate OpOperands.
+  MutableArrayRef<OpOperand>::iterator begin() const;
+  MutableArrayRef<OpOperand>::iterator end() const;
 
 private:
   /// Update the length of this range to the one provided.
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
index 43ba11cf132cb92..09d30835828084d 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp
@@ -47,6 +47,10 @@ static Value buildBoolValue(OpBuilder &builder, Location loc, bool value) {
 
 static bool isMemref(Value v) { return v.getType().isa<BaseMemRefType>(); }
 
+static bool isMemrefOperand(OpOperand &operand) {
+  return isMemref(operand.get());
+}
+
 //===----------------------------------------------------------------------===//
 // Backedges analysis
 //===----------------------------------------------------------------------===//
@@ -937,7 +941,7 @@ BufferDeallocation::handleInterface(RegionBranchTerminatorOpInterface op) {
 
   // Add an additional operand for every MemRef for the ownership indicator.
   if (!funcWithoutDynamicOwnership) {
-    unsigned numMemRefs = llvm::count_if(operands, isMemref);
+    unsigned numMemRefs = llvm::count_if(operands, isMemrefOperand);
     SmallVector<Value> newOperands{OperandRange(operands)};
     auto ownershipValues =
         deallocOp.getUpdatedConditions().take_front(numMemRefs);
diff --git a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
index e847e946eef1b5d..9423af2542690d9 100644
--- a/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/ControlFlow/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -96,12 +96,12 @@ struct CondBranchOpInterface
         mapping[retained] = ownership;
       }
       SmallVector<Value> replacements, ownerships;
-      for (Value operand : destOperands) {
-        replacements.push_back(operand);
-        if (isMemref(operand)) {
-          assert(mapping.contains(operand) &&
+      for (OpOperand &operand : destOperands) {
+        replacements.push_back(operand.get());
+        if (isMemref(operand.get())) {
+          assert(mapping.contains(operand.get()) &&
                  "Should be contained at this point");
-          ownerships.push_back(mapping[operand]);
+          ownerships.push_back(mapping[operand.get()]);
         }
       }
       replacements.append(ownerships);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 21bc0554e717692..a9debb7bbc489a4 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -810,13 +810,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
 
-  std::optional<unsigned> maybeOperandNumber =
-      forOp.getIterArgNumberForOpOperand(*pUse);
-  assert(maybeOperandNumber.has_value() && "expected a proper iter arg number");
-
-  int64_t operandNumber = maybeOperandNumber.value();
+  unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
   auto yieldOp = cast<scf::YieldOp>(forOp.getBody(0)->getTerminator());
-  auto yieldingExtractSliceOp = yieldOp->getOperand(operandNumber)
+  auto yieldingExtractSliceOp = yieldOp->getOperand(iterArgNumber)
                                     .getDefiningOp<tensor::ExtractSliceOp>();
   if (!yieldingExtractSliceOp)
     return tensor::ExtractSliceOp();
@@ -829,9 +825,9 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
     return tensor::ExtractSliceOp();
 
   SmallVector<Value> initArgs = forOp.getInitArgs();
-  initArgs[operandNumber] = hoistedPackedTensor;
+  initArgs[iterArgNumber] = hoistedPackedTensor;
   SmallVector<Value> yieldOperands = yieldOp.getOperands();
-  yieldOperands[operandNumber] = yieldingExtractSliceOp.getSource();
+  yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();
 
   int64_t numOriginalForOpResults = initArgs.size();
   LLVM_DEBUG(DBGS() << "numOriginalForOpResults: " << numOriginalForOpResults
@@ -844,7 +840,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
         hoistedPackedTensor.getLoc(), hoistedPackedTensor,
         outerSliceOp.getMixedOffsets(), outerSliceOp.getMixedSizes(),
         outerSliceOp.getMixedStrides());
-    rewriter.replaceAllUsesWith(forOp.getResult(operandNumber), extracted);
+    rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted);
   }
   scf::ForOp newForOp =
       replaceLoopWithNewYields(rewriter, forOp, initArgs, yieldOperands);
@@ -853,20 +849,20 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
                     << "\n");
   LLVM_DEBUG(DBGS() << "replace source of: " << extracted << "\n");
   LLVM_DEBUG(DBGS() << "with result #"
-                    << numOriginalForOpResults + operandNumber
+                    << numOriginalForOpResults + iterArgNumber
                     << " of forOp, giving us: " << extracted << "\n");
   rewriter.startRootUpdate(extracted);
   extracted.getSourceMutable().assign(
-      newForOp.getResult(numOriginalForOpResults + operandNumber));
+      newForOp.getResult(numOriginalForOpResults + iterArgNumber));
   rewriter.finalizeRootUpdate(extracted);
 
   LLVM_DEBUG(DBGS() << "replace uses of: " << paddedValueBeforeHoisting
                     << "\n");
   LLVM_DEBUG(DBGS() << "with region iter arg #"
-                    << numOriginalForOpResults + operandNumber << "\n");
+                    << numOriginalForOpResults + iterArgNumber << "\n");
   rewriter.replaceAllUsesWith(
       paddedValueBeforeHoisting,
-      newForOp.getRegionIterArg(numOriginalForOpResults + operandNumber));
+      newForOp.getRegionIterArg(numOriginalForOpResults + iterArgNumber));
 
   return extracted;
 }
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 5565aefbad18db5..2a760c76d2f6867 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -932,7 +932,7 @@ replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
   assert(operand.get().getType() != replacement.getType() &&
          "Expected a different type");
   SmallVector<Value> newIterOperands;
-  for (OpOperand &opOperand : forOp.getIterOpOperands()) {
+  for (OpOperand &opOperand : forOp.getInitArgsMutable()) {
     if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
       newIterOperands.push_back(replacement);
       continue;
@@ -1015,7 +1015,7 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
 
   LogicalResult matchAndRewrite(ForOp op,
                                 PatternRewriter &rewriter) const override {
-    for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
+    for (auto it : llvm::zip(op.getInitArgsMutable(), op.getResults())) {
       OpOperand &iterOpOperand = std::get<0>(it);
       auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
       if (!incomingCast ||
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index 11cfefed890c669..8c04a8887013c8f 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -332,7 +332,7 @@ DenseSet<int64_t> getEquivalentBuffers(Block::BlockArgListType bbArgs,
 /// Helper function for loop bufferization. Return the bufferized values of the
 /// given OpOperands. If an operand is not a tensor, return the original value.
 static FailureOr<SmallVector<Value>>
-getBuffers(RewriterBase &rewriter, MutableArrayRef<OpOperand> operands,
+getBuffers(RewriterBase &rewriter, MutableOperandRange operands,
            const BufferizationOptions &options) {
   SmallVector<Value> result;
   for (OpOperand &opOperand : operands) {
@@ -606,7 +606,7 @@ struct ForOpInterface
 
     // The new memref init_args of the loop.
     FailureOr<SmallVector<Value>> maybeInitArgs =
-        getBuffers(rewriter, forOp.getIterOpOperands(), options);
+        getBuffers(rewriter, forOp.getInitArgsMutable(), options);
     if (failed(maybeInitArgs))
       return failure();
     SmallVector<Value> initArgs = *maybeInitArgs;
@@ -825,7 +825,7 @@ struct WhileOpInterface
 
     // The new memref init_args of the loop.
     FailureOr<SmallVector<Value>> maybeInitArgs =
-        getBuffers(rewriter, whileOp->getOpOperands(), options);
+        getBuffers(rewriter, whileOp.getInitsMutable(), options);
     if (failed(maybeInitArgs))
       return failure();
     SmallVector<Value> initArgs = *maybeInitArgs;
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 1ce25565edcaf61..d7b986e052ae575 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -508,7 +508,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
                                       MutableArrayRef<scf::ForOp> loops) {
   // 1. Get the producer of the source (potentially walking through
   // `iter_args` of nested `scf.for`)
-  auto [fusableProducer, destinationIterArg] =
+  auto [fusableProducer, destinationInitArg] =
       getUntiledProducerFromSliceSource(&candidateSliceOp.getSourceMutable()[0],
                                         loops);
   if (!fusableProducer)
@@ -575,17 +575,16 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
   // TODO: This can be modeled better if the `DestinationStyleOpInterface`.
   // Update to use that when it does become available.
   scf::ForOp outerMostLoop = loops.front();
-  std::optional<unsigned> iterArgNumber;
-  if (destinationIterArg) {
-    iterArgNumber =
-        outerMostLoop.getIterArgNumberForOpOperand(*destinationIterArg.value());
-  }
-  if (iterArgNumber) {
+  if (destinationInitArg &&
+      (*destinationInitArg)->getOwner() == outerMostLoop) {
+    unsigned iterArgNumber =
+        outerMostLoop.getResultForOpOperand(**destinationInitArg)
+            .getResultNumber();
     int64_t resultNumber = fusableProducer.getResultNumber();
     if (auto dstOp =
             dyn_cast<DestinationStyleOpInterface>(fusableProducer.getOwner())) {
-      outerMostLoop.setIterArg(iterArgNumber.value(),
-                               dstOp.getTiedOpOperand(fusableProducer)->get());
+      (*destinationInitArg)
+          ->set(dstOp.getTiedOpOperand(fusableProducer)->get());
     }
     for (auto tileAndFusedOp : tileAndFuseResult->tiledOps) {
       auto dstOp = dyn_cast<DestinationStyleOpInterface>(tileAndFusedOp);
@@ -594,7 +593,7 @@ mlir::scf::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
       scf::ForOp innerMostLoop = loops.back();
       updateDestinationOperandsForTiledOp(
           rewriter, dstOp.getDpsInitOperand(resultNumber)->get(),
-          innerMostLoop.getRegionIterArgs()[iterArgNumber.value()]);
+          innerMostLoop.getRegionIterArgs()[iterArgNumber]);
     }
   }
   return scf::SCFFuseProducerOfSliceResult{fusableProducer,
diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp
index 7b17e231ce1065f..b0c50f3d6e29855 100644
--- a/mlir/lib/IR/OperationSupport.cpp
+++ b/mlir/lib/IR/OperationSupport.cpp
@@ -522,6 +522,14 @@ OpOperand &MutableOperandRange::operator[](unsigned index) const {
   return owner->getOpOperand(start + index);
 }
 
+MutableArrayRef<OpOperand>::iterator MutableOperandRange::begin() const {
+  return owner->getOpOperands().slice(start, length).begin();
+}
+
+MutableArrayRef<OpOperand>::iterator MutableOperandRange::end() const {
+  return owner->getOpOperands().slice(start, length).end();
+}
+
 //===----------------------------------------------------------------------===//
 // MutableOperandRangeRange
 
diff --git a/mlir/lib/Transforms/Utils/CFGToSCF.cpp b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
index 9aab89ed7553600..e7bf6628ccbd7e0 100644
--- a/mlir/lib/Transforms/Utils/CFGToSCF.cpp
+++ b/mlir/lib/Transforms/Utils/CFGToSCF.cpp
@@ -137,6 +137,13 @@ getMutableSuccessorOperands(Block *block, unsigned successorIndex) {
   return succOps.getMutableForwardedOperands();
 }
 
+/// Return the operand range used to transfer operands from `block` to its
+/// successor with the given index.
+static OperandRange getSuccessorOperands(Block *block,
+                                         unsigned successorIndex) {
+  return getMutableSuccessorOperands(block, successorIndex);
+}
+
 /// Appends all the block arguments from `other` to the block arguments of
 /// `block`, copying their types and locations.
 static void addBlockArgumentsFromOther(Block *block, Block *other) {
@@ -175,8 +182,14 @@ class Edge {
 
   /// Returns the arguments of this edge that are passed to the block arguments
   /// of the successor.
-  MutableOperandRange getSuccessorOperands() const {
-    return getMutableSuccessorOperands(fromBlock, successorIndex);
+  MutableOperandRange getMutableSuccessorOperands() const {
+    return ::getMutableSuccessorOperands(fromBlock, successorIndex);
+  }
+
+  /// Returns the arguments of this edge that are passed to the block arguments
+  /// of the successor.
+  OperandRange getSuccessorOperands() const {
+    return ::getSuccessorOperands(fromBlock, successorIndex);
   }
 };
 
@@ -262,7 +275,7 @@ class EdgeMultiplexer {
     assert(result != blockArgMapping.end() &&
            "Edge was not originally passed to `create` method.");
 
-    MutableOperandRange successorOperands = edge.getSuccessorOperands();
+    MutableOperandRange successorOperands = edge.getMutableSuccessorOperands();
 
     // Extra arguments are always appended at the end of the block arguments.
     unsigned extraArgsBeginIndex =
@@ -666,7 +679,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
   // invalidated when mutating the operands through a different
   // `MutableOperandRange` of the same operation.
   SmallVector<Value> loopHeaderSuccessorOperands =
-      llvm::to_vector(getMutableSuccessorOperands(latch, loopHeaderIndex));
+      llvm::to_vector(getSuccessorOperands(latch, loopHeaderIndex));
 
   // Add all values used in the next iteration to the exit block. Replace
   // any uses that are outside the loop with the newly created exit block.
@@ -742,7 +755,7 @@ transformToReduceLoop(Block *loopHeader, Block *exitBlock,
 
           loopHeaderSuccessorOperands.push_back(argument);
           for (Edge edge : successorEdges(latch))
-            edge.getSuccessorOperands().append(argument);
+            edge.getMutableSuccessorOperands().append(argument);
         }
 
         use.set(blockArgument);
@@ -939,9 +952,8 @@ static FailureOr<SmallVector<Block *>> transformToStructuredCFBranches(
   if (regionEntry->getNumSuccessors() == 1) {
     // Single successor we can just splice together.
     Block *successor = regionEntry->getSuccessor(0);
-    for (auto &&[oldValue, newValue] :
-         llvm::zip(successor->getArguments(),
-                   getMutableSuccessorOperands(regionEntry, 0)))
+    for (auto &&[oldValue, newValue] : llvm::zip(
+             successor->getArguments(), getSuccessorOperands(regionEntry, 0)))
       oldValue.replaceAllUsesWith(newValue);
     regionEntry->getTerminator()->erase();
 

This function was inconsistent with the remaining API because it accepted `OpOperand &` that do not belong to the op. All the other functions assert. This helper function is also not really necessary, as the iter_arg number is identical to the result number.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:bufferization Bufferization infrastructure mlir:cf mlir:core MLIR Core Infrastructure mlir:linalg mlir:scf mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants