
[mlir][TOSA] Fix shape inference when operand was inferred #66906

Merged
1 commit merged into llvm:main on Sep 22, 2023

Conversation

maxbartel
Contributor

Commit 057fc8e introduces a bug in the TosaInferShapesPass when an operand type was already inferred.

The generated line
$cppClass::Adaptor adaptor(operands, attributes, properties, regions);
interprets the ValueShapeRange as a normal ValueRange and loses the information from the inference.

This PR changes the logic of the shape inference slightly. Instead of saving the type information in a DenseMap and updating the types only after the whole region has been analyzed, it now updates the types directly in each iteration. That way the operands always have the inferred type.
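As a rough illustration of the new flow (a simplified sketch of the per-block loop in propagateShapesInRegion from this PR's diff; the ValueKnowledge join and the checks on the value's uses are elided, so treat it as a sketch rather than the actual implementation):

// Inside the per-block loop of propagateShapesInRegion (simplified sketch).
for (Operation &op : block) {
  InferShapedTypeOpInterface shapeInterface =
      dyn_cast<InferShapedTypeOpInterface>(op);
  if (!shapeInterface)
    continue;

  SmallVector<ShapedTypeComponents> returnedShapes;
  // Operands are passed as-is: their types already reflect any shapes
  // inferred in earlier iterations, so no side table is needed.
  if (shapeInterface
          .inferReturnTypeComponents(op.getContext(), op.getLoc(),
                                     op.getOperands(),
                                     op.getDiscardableAttrDictionary(),
                                     op.getPropertiesStorage(),
                                     op.getRegions(), returnedShapes)
          .failed())
    continue;

  for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
    Value result = std::get<0>(it);
    ShapedTypeComponents inferred = std::get<1>(it);
    if (!inferred.hasRank())
      continue;
    // Write the refined type back immediately instead of recording it in a
    // DenseMap<Value, ShapedTypeComponents> and rewriting all types at the end.
    result.setType(cast<ShapedType>(result.getType()).clone(inferred.getDims()));
  }
}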

maxbartel changed the title from "[TOSA] Fix shape inference when operand was inferred" to "[mlir][TOSA] Fix shape inference when operand was inferred" on Sep 20, 2023
@maxbartel
Contributor Author

@eric-k256 @amanda849 @rsuderman I cannot add reviewers because I don't have write access.

@llvmbot
Member

llvmbot commented Sep 20, 2023

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Full diff: https://github.com/llvm/llvm-project/pull/66906.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/InferTypeOpInterface.td (+1-1)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp (+15-41)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+13)
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 54c1c13fd029dbc..c5eeeaf58a7b4f8 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -223,7 +223,7 @@ def InferTypeOpAdaptorWithIsCompatible : InferTypeOpAdaptorBase<
 >;
 
 // Convenient trait to define a wrapper to inferReturnTypeComponents that passes
-// in the Op Adaptor directly
+// in the Op Adaptor directly. Only uses the current types of the operands.
 class InferShapedTypeOpAdaptorBase<list<string> overridenMethods = []> : TraitList<
   [
     // Op implements infer type op interface.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 9c49cd55788571b..3cc16a91edce747 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -41,8 +41,7 @@ namespace {
 
 void propagateShapesInRegion(Region &region);
 
-void propagateShapesToTosaIf(
-    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
+void propagateShapesToTosaIf(Operation &op) {
   IfOp ifOp = dyn_cast<IfOp>(op);
   if (!ifOp)
     return;
@@ -53,12 +52,12 @@ void propagateShapesToTosaIf(
       return;
 
     for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
-      auto inferredTy = shapesStorage[op.getOperand(i)];
+      auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
       auto blockArg = frontBlock.getArgument(i - 1);
       auto oldType = cast<ShapedType>(blockArg.getType());
 
       if (inferredTy.hasRank()) {
-        Type newType = oldType.clone(inferredTy.getDims());
+        Type newType = oldType.clone(inferredTy.getShape());
         blockArg.setType(newType);
       }
     }
@@ -79,8 +78,7 @@ void propagateShapesToTosaIf(
   }
 }
 
-void propagateShapesToTosaWhile(
-    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
+void propagateShapesToTosaWhile(Operation &op) {
   WhileOp whileOp = dyn_cast<WhileOp>(op);
   if (!whileOp)
     return;
@@ -91,9 +89,8 @@ void propagateShapesToTosaWhile(
   llvm::SmallVector<Type> argTypes;
   for (auto operand : op.getOperands()) {
     auto operandTy = cast<ShapedType>(operand.getType());
-    auto shapedTypeComponent = shapesStorage[operand];
-    if (shapedTypeComponent.hasRank()) {
-      auto newTy = operandTy.clone(shapedTypeComponent.getDims());
+    if (operandTy.hasRank()) {
+      auto newTy = operandTy.clone(operandTy.getShape());
       argTypes.push_back(newTy);
     } else {
       argTypes.push_back(operand.getType());
@@ -187,21 +184,6 @@ void propagateShapesToTosaWhile(
 }
 
 void propagateShapesInRegion(Region &region) {
-  DenseMap<Value, ShapedTypeComponents> shapesStorage;
-  auto setShapes = [&](Value val, Type t) {
-    if (auto st = dyn_cast<ShapedType>(t))
-      shapesStorage[val] = st;
-    else
-      shapesStorage[val] = t;
-  };
-  auto operandShape = [&](Value val) -> ShapeAdaptor {
-    // Query the WIP mapping rather than the type if set.
-    auto it = shapesStorage.find(val);
-    if (it == shapesStorage.end())
-      return nullptr;
-    return it->second;
-  };
-
   // Check whether this use case is replaceable. We define an op as
   // being replaceable if it is used by a ReturnOp, a TosaOp, or an op with a
   // type-inference related interface.
@@ -217,8 +199,8 @@ void propagateShapesInRegion(Region &region) {
       if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
         continue;
 
-      propagateShapesToTosaIf(op, shapesStorage);
-      propagateShapesToTosaWhile(op, shapesStorage);
+      propagateShapesToTosaIf(op);
+      propagateShapesToTosaWhile(op);
 
       InferShapedTypeOpInterface shapeInterface =
           dyn_cast<InferShapedTypeOpInterface>(op);
@@ -227,12 +209,11 @@ void propagateShapesInRegion(Region &region) {
 
       SmallVector<ShapedTypeComponents> returnedShapes;
 
-      ValueShapeRange range(op.getOperands(), operandShape);
       if (shapeInterface
-              .inferReturnTypeComponents(op.getContext(), op.getLoc(), range,
-                                         op.getDiscardableAttrDictionary(),
-                                         op.getPropertiesStorage(),
-                                         op.getRegions(), returnedShapes)
+              .inferReturnTypeComponents(
+                  op.getContext(), op.getLoc(), op.getOperands(),
+                  op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
+                  op.getRegions(), returnedShapes)
               .succeeded()) {
         for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
           Value result = std::get<0>(it);
@@ -262,20 +243,13 @@ void propagateShapesInRegion(Region &region) {
               ValueKnowledge::join(currentKnowledge, inferredKnowledge);
           if (!newKnowledge)
             continue;
-          setShapes(result, newKnowledge.getType());
+
+          // Set new type
+          result.setType(newKnowledge.getType());
         }
       }
     }
   }
-
-  // Actually update types with updated shape knowledge.
-  for (auto it : shapesStorage) {
-    auto result = it.second;
-    if (result.hasRank()) {
-      Type t = cast<ShapedType>(it.first.getType()).clone(result.getDims());
-      it.first.setType(t);
-    }
-  }
 }
 
 /// Pass that performs shape propagation across TOSA operations. This includes
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index cb96b2a8a0d193b..d468ba582483cbe 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1259,3 +1259,16 @@ func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index)
   %1 = tensor.extract %0[%arg1, %arg1] : tensor<?x?xf32>
   return %1 : f32
 }
+
+// -----
+
+// CHECK-LABEL: test_tosa_use_def_chain
+func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<?x16x16x16xf32> {
+  // CHECK: [[CONV:%.+]] = tosa.conv2d %arg0, %arg1, %arg2
+  // CHECK: (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<?x32x32x16xf32>
+  // CHECK: tosa.max_pool2d [[CONV]]
+  // CHECK: (tensor<1x32x32x16xf32>) -> tensor<1x16x16x16xf32>
+  %1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
+  return %1 : tensor<?x16x16x16xf32>
+}

@eric-k256
Contributor

Yes, the inability to assign reviewers is a pain point with the new GitHub PR system. What is a bit unusual is that llvmbot didn't add the comment until I approved running the checks. I've been using the llvmbot notification to try to watch TOSA changes and would have missed this without your comment.

As for the change itself, it looks good to me. I'll give others a chance to review and, if there are no objections, approve tomorrow.

@maxbartel
Contributor Author

@eric-k256 Thank you for the review! Could you maybe land the changes? I don't have write access to the repository.

eric-k256 merged commit e9cb582 into llvm:main on Sep 22, 2023
@maxbartel
Contributor Author

Thank you!

@jpienaar
Member

jpienaar commented Sep 23, 2023

interprets the ValueShapeRange as a normal ValueRange and loses the information from the inference

You are correct, this is a bit of a footgun (this needs to be improved).
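For anyone hitting the same footgun, a rough sketch of how the shape information gets dropped (refinedShapeOf is a hypothetical stand-in for the pass's old DenseMap lookup, and the adaptor line paraphrases the generated wrapper rather than quoting it):

// The old pass exposed work-in-progress shapes through a callback, so the
// interface could see shapes that were not yet written into the IR:
auto operandShape = [&](Value val) -> ShapeAdaptor {
  return refinedShapeOf(val); // hypothetical lookup into the old DenseMap
};
ValueShapeRange range(op.getOperands(), operandShape);

// Ops declared with InferShapedTypeOpAdaptor implement the interface through a
// generated wrapper that builds the op adaptor from the operands, roughly:
//   $cppClass::Adaptor adaptor(operands, attributes, properties, regions);
// The adaptor constructor takes a plain ValueRange, so the ValueShapeRange is
// treated as an ordinary range of values: the operandShape callback is never
// queried, and the adaptor only sees each operand's current IR type. Updating
// the types in place, as this PR does, avoids that mismatch.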

@amrami
Contributor

amrami commented Nov 1, 2023

Hi!
I have a test that now fails after rebasing on master. When I revert this commit, it passes.
I minimized the test case to show the failure:

Command:
mlir-opt --tosa-infer-shapes

Input:

func.func @a(%arg0: tensor<1x1x1xf32>) -> (tensor<f32>) {
    %0 = tosa.reshape %arg0 {new_shape = array<i64: 1>} : (tensor<1x1x1xf32>) -> tensor<f32>
    return %0 : tensor<f32>
}

Previous output:

  func.func @a(%arg0: tensor<1x1x1xf32>) -> tensor<f32> {
    %0 = tosa.reshape %arg0 {new_shape = array<i64: 1>} : (tensor<1x1x1xf32>) -> tensor<f32>
    return %0 : tensor<f32>
  }

Current output:

<stdin>:2:10: error: 'tosa.reshape' op result #0 must be ranked tensor of number values, but got 'tensor<*xf32>'
    %0 = tosa.reshape %arg0 {new_shape = array<i64: 1>} : (tensor<1x1x1xf32>) -> tensor<f32>
         ^
<stdin>:2:10: note: see current operation: %0 = "tosa.reshape"(%arg0) <{new_shape = array<i64: 1>}> : (tensor<1x1x1xf32>) -> tensor<*xf32>

Do you think the previous output was wrong? Or maybe something broke with this change?
Thanks in advance,
Maya

@maxbartel
Contributor Author

I would expect tensor<?xf32> as the correct input for the test and tensor<1xf32> as the correct output. What do you think, @eric-k256?

@eric-k256
Contributor

I think @maxbartel is right based on what we have today. It does point out a problem: there is no way to reshape down to a rank-0 tensor with the current new_shape argument. The spec says rank 1 to MAX_RANK today, but I think it needs to support rank 0 as well (and I plan to change the spec to allow it). Maybe a new_shape of array<i64: 0> as a special case for rank 0, but there are other options.
