[mlir][TOSA] Fix shape inference when operand was inferred #66906
Conversation
@eric-k256 @amanda849 @rsuderman I cannot add reviewers because I don't have write access.
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir

Changes

057fc8e introduces a bug in the TosaInferShapesPass when an operand type was already inferred. The code referenced in mlir/include/mlir/Interfaces/InferTypeOpInterface.td (line 248 at f7bfa58) passes the ValueShapeRange on as a normal ValueRange and loses the information of the inference.

This PR changes the logic of the shape inference a bit. Instead of saving the type information in a DenseMap and updating the types after the whole analysis for a region, it now updates the types directly in each iteration. That way the operands always have the inferred type.

Full diff: https://github.com/llvm/llvm-project/pull/66906.diff

3 Files Affected:
diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 54c1c13fd029dbc..c5eeeaf58a7b4f8 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -223,7 +223,7 @@ def InferTypeOpAdaptorWithIsCompatible : InferTypeOpAdaptorBase<
>;
// Convenient trait to define a wrapper to inferReturnTypeComponents that passes
-// in the Op Adaptor directly
+// in the Op Adaptor directly. Only uses the current types of the operands.
class InferShapedTypeOpAdaptorBase<list<string> overridenMethods = []> : TraitList<
[
// Op implements infer type op interface.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 9c49cd55788571b..3cc16a91edce747 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -41,8 +41,7 @@ namespace {
void propagateShapesInRegion(Region &region);
-void propagateShapesToTosaIf(
- Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
+void propagateShapesToTosaIf(Operation &op) {
IfOp ifOp = dyn_cast<IfOp>(op);
if (!ifOp)
return;
@@ -53,12 +52,12 @@ void propagateShapesToTosaIf(
return;
for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
- auto inferredTy = shapesStorage[op.getOperand(i)];
+ auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
auto blockArg = frontBlock.getArgument(i - 1);
auto oldType = cast<ShapedType>(blockArg.getType());
if (inferredTy.hasRank()) {
- Type newType = oldType.clone(inferredTy.getDims());
+ Type newType = oldType.clone(inferredTy.getShape());
blockArg.setType(newType);
}
}
@@ -79,8 +78,7 @@ void propagateShapesToTosaIf(
}
}
-void propagateShapesToTosaWhile(
- Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
+void propagateShapesToTosaWhile(Operation &op) {
WhileOp whileOp = dyn_cast<WhileOp>(op);
if (!whileOp)
return;
@@ -91,9 +89,8 @@ void propagateShapesToTosaWhile(
llvm::SmallVector<Type> argTypes;
for (auto operand : op.getOperands()) {
auto operandTy = cast<ShapedType>(operand.getType());
- auto shapedTypeComponent = shapesStorage[operand];
- if (shapedTypeComponent.hasRank()) {
- auto newTy = operandTy.clone(shapedTypeComponent.getDims());
+ if (operandTy.hasRank()) {
+ auto newTy = operandTy.clone(operandTy.getShape());
argTypes.push_back(newTy);
} else {
argTypes.push_back(operand.getType());
@@ -187,21 +184,6 @@ void propagateShapesToTosaWhile(
}
void propagateShapesInRegion(Region &region) {
- DenseMap<Value, ShapedTypeComponents> shapesStorage;
- auto setShapes = [&](Value val, Type t) {
- if (auto st = dyn_cast<ShapedType>(t))
- shapesStorage[val] = st;
- else
- shapesStorage[val] = t;
- };
- auto operandShape = [&](Value val) -> ShapeAdaptor {
- // Query the WIP mapping rather than the type if set.
- auto it = shapesStorage.find(val);
- if (it == shapesStorage.end())
- return nullptr;
- return it->second;
- };
-
// Check whether this use case is replaceable. We define an op as
// being replaceable if it is used by a ReturnOp, a TosaOp, or an op with a
// type-inference related interface.
@@ -217,8 +199,8 @@ void propagateShapesInRegion(Region &region) {
if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
continue;
- propagateShapesToTosaIf(op, shapesStorage);
- propagateShapesToTosaWhile(op, shapesStorage);
+ propagateShapesToTosaIf(op);
+ propagateShapesToTosaWhile(op);
InferShapedTypeOpInterface shapeInterface =
dyn_cast<InferShapedTypeOpInterface>(op);
@@ -227,12 +209,11 @@ void propagateShapesInRegion(Region ®ion) {
SmallVector<ShapedTypeComponents> returnedShapes;
- ValueShapeRange range(op.getOperands(), operandShape);
if (shapeInterface
- .inferReturnTypeComponents(op.getContext(), op.getLoc(), range,
- op.getDiscardableAttrDictionary(),
- op.getPropertiesStorage(),
- op.getRegions(), returnedShapes)
+ .inferReturnTypeComponents(
+ op.getContext(), op.getLoc(), op.getOperands(),
+ op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
+ op.getRegions(), returnedShapes)
.succeeded()) {
for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
Value result = std::get<0>(it);
@@ -262,20 +243,13 @@ void propagateShapesInRegion(Region &region) {
ValueKnowledge::join(currentKnowledge, inferredKnowledge);
if (!newKnowledge)
continue;
- setShapes(result, newKnowledge.getType());
+
+ // Set new type
+ result.setType(newKnowledge.getType());
}
}
}
}
-
- // Actually update types with updated shape knowledge.
- for (auto it : shapesStorage) {
- auto result = it.second;
- if (result.hasRank()) {
- Type t = cast<ShapedType>(it.first.getType()).clone(result.getDims());
- it.first.setType(t);
- }
- }
}
/// Pass that performs shape propagation across TOSA operations. This includes
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index cb96b2a8a0d193b..d468ba582483cbe 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1259,3 +1259,16 @@ func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index)
%1 = tensor.extract %0[%arg1, %arg1] : tensor<?x?xf32>
return %1 : f32
}
+
+// -----
+
+// CHECK-LABEL: test_tosa_use_def_chain
+func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<?x16x16x16xf32> {
+ // CHECK: [[CONV:%.+]] = tosa.conv2d %arg0, %arg1, %arg2
+ // CHECK: (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+ %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<?x32x32x16xf32>
+ // CHECK: tosa.max_pool2d [[CONV]]
+ // CHECK: (tensor<1x32x32x16xf32>) -> tensor<1x16x16x16xf32>
+ %1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
+ return %1 : tensor<?x16x16x16xf32>
+}
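In short, the pass no longer parks inferred shapes in a DenseMap<Value, ShapedTypeComponents> side table that is only materialized after the whole region is walked; each result type is now refined in place via result.setType(...), so later inference queries that only look at the operands' current types already observe the refinement. A minimal stand-alone sketch of why the immediate write-back matters (stand-in types only, not the real MLIR classes):

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Stand-ins only; the real pass works on mlir::Value / ShapedTypeComponents.
struct Node {
  std::string type;                  // the type currently stored "in the IR"
  std::vector<Node *> operands;
};

// Toy inference rule standing in for inferReturnTypeComponents: the result
// type simply follows the first operand's *current* type.
std::string infer(const Node &n) {
  return n.operands.empty() ? n.type : n.operands[0]->type;
}

int main() {
  Node a{"tensor<1x32x32x16xf32>", {}};
  Node b{"tensor<?x32x32x16xf32>", {&a}};
  Node c{"tensor<?x16x16x16xf32>", {&b}};

  // Broken flow after 057fc8e: refinements sit in a side table while the
  // inference itself reads only the stale in-IR types.
  std::map<Node *, std::string> sideTable;
  for (Node *n : {&b, &c})
    sideTable[n] = infer(*n);          // c still sees b's stale "?" type
  std::cout << sideTable[&c] << "\n";  // tensor<?x32x32x16xf32>

  // Flow with this PR: write each refined type back immediately.
  for (Node *n : {&b, &c})
    n->type = infer(*n);
  std::cout << c.type << "\n";         // tensor<1x32x32x16xf32>
}
```

The new regression test covers exactly such a use-def chain (tosa.conv2d feeding tosa.max_pool2d); assuming the usual pass registration, it can be exercised with mlir-opt --split-input-file --tosa-infer-shapes.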
Yes, the inability to assign reviewers is a pain point with the new GitHub PR system. What is a bit unusual is that llvmbot didn't add the comment until I approved running the checks. I've been using the llvmbot notification to try to watch TOSA changes and would have missed this without your comment. As for the change itself, it looks good to me. I'll give others a chance to review and, if there are no objections, approve tomorrow.
@eric-k256 Thank you for the review! Could you maybe land the changes? I don't have write access to the repository.
Thank you! |
You are correct, this is a bit of a footgun (this needs to be improved).
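For anyone hitting this later: the footgun is that a ValueShapeRange degrades into a plain ValueRange without complaint, and the WIP shape-lookup callback it carries is not part of ValueRange, so the extra information is silently dropped. A hypothetical stand-alone illustration with stand-in types (not the real MLIR declarations):

```cpp
#include <functional>
#include <iostream>
#include <vector>

// Stand-ins for mlir::Value / mlir::ValueRange / mlir::ValueShapeRange.
struct Value { const char *irType; };
using ValueRange = std::vector<Value>;

struct ValueShapeRange {
  ValueRange values;
  std::function<const char *(const Value &)> operandShape;  // WIP shape map
  operator ValueRange() const { return values; }            // silent narrowing
};

// Anything typed against the plain range can only see the in-IR types.
void inferFromValueRange(const ValueRange &operands) {
  std::cout << operands[0].irType << "\n";                  // "?x32x32x16"
}

int main() {
  ValueShapeRange rich{{{"?x32x32x16"}},
                       [](const Value &) { return "1x32x32x16"; }};
  std::cout << rich.operandShape(rich.values[0]) << "\n";   // refined shape
  inferFromValueRange(rich);  // compiles cleanly; the refinement is lost
}
```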
Hi!

Command:
Input:
Previous output:
Current output:

Do you think the previous output was wrong? Or maybe something broke with this change?
I would expect a |
I think @maxbartel is right based on what we have today. It does point out a problem that there is no way to reshape down to a rank0 tensor with the current |