diff --git a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
index 54c1c13fd029db..c5eeeaf58a7b4f 100644
--- a/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
+++ b/mlir/include/mlir/Interfaces/InferTypeOpInterface.td
@@ -223,7 +223,7 @@ def InferTypeOpAdaptorWithIsCompatible : InferTypeOpAdaptorBase<
 >;
 
 // Convenient trait to define a wrapper to inferReturnTypeComponents that passes
-// in the Op Adaptor directly
+// in the Op Adaptor directly. Only uses the current types of the operands.
 class InferShapedTypeOpAdaptorBase<list<string> overridenMethods = []> : TraitList<
   [
     // Op implements infer type op interface.
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
index 9c49cd55788571..3cc16a91edce74 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaInferShapes.cpp
@@ -41,8 +41,7 @@ namespace {
 
 void propagateShapesInRegion(Region &region);
 
-void propagateShapesToTosaIf(
-    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
+void propagateShapesToTosaIf(Operation &op) {
   IfOp ifOp = dyn_cast<IfOp>(op);
   if (!ifOp)
     return;
@@ -53,12 +52,12 @@ void propagateShapesToTosaIf(
       return;
 
     for (unsigned int i = 1, s = op.getNumOperands(); i < s; i++) {
-      auto inferredTy = shapesStorage[op.getOperand(i)];
+      auto inferredTy = cast<ShapedType>(op.getOperand(i).getType());
       auto blockArg = frontBlock.getArgument(i - 1);
       auto oldType = cast<ShapedType>(blockArg.getType());
 
       if (inferredTy.hasRank()) {
-        Type newType = oldType.clone(inferredTy.getDims());
+        Type newType = oldType.clone(inferredTy.getShape());
         blockArg.setType(newType);
       }
     }
@@ -79,8 +78,7 @@ void propagateShapesToTosaIf(
   }
 }
 
-void propagateShapesToTosaWhile(
-    Operation &op, DenseMap<Value, ShapedTypeComponents> &shapesStorage) {
+void propagateShapesToTosaWhile(Operation &op) {
   WhileOp whileOp = dyn_cast<WhileOp>(op);
   if (!whileOp)
     return;
@@ -91,9 +89,8 @@ void propagateShapesToTosaWhile(
   llvm::SmallVector<Type> argTypes;
   for (auto operand : op.getOperands()) {
     auto operandTy = cast<ShapedType>(operand.getType());
-    auto shapedTypeComponent = shapesStorage[operand];
-    if (shapedTypeComponent.hasRank()) {
-      auto newTy = operandTy.clone(shapedTypeComponent.getDims());
+    if (operandTy.hasRank()) {
+      auto newTy = operandTy.clone(operandTy.getShape());
       argTypes.push_back(newTy);
     } else {
       argTypes.push_back(operand.getType());
@@ -187,21 +184,6 @@ void propagateShapesToTosaWhile(
 }
 
 void propagateShapesInRegion(Region &region) {
-  DenseMap<Value, ShapedTypeComponents> shapesStorage;
-  auto setShapes = [&](Value val, Type t) {
-    if (auto st = dyn_cast<ShapedType>(t))
-      shapesStorage[val] = st;
-    else
-      shapesStorage[val] = t;
-  };
-  auto operandShape = [&](Value val) -> ShapeAdaptor {
-    // Query the WIP mapping rather than the type if set.
-    auto it = shapesStorage.find(val);
-    if (it == shapesStorage.end())
-      return nullptr;
-    return it->second;
-  };
-
   // Check whether this use case is replaceable. We define an op as
   // being replaceable if it is used by a ReturnOp, a TosaOp, or an op with a
   // type-inference related interface.
@@ -217,8 +199,8 @@ void propagateShapesInRegion(Region &region) {
       if (op.getDialect()->getNamespace() != TosaDialect::getDialectNamespace())
         continue;
 
-      propagateShapesToTosaIf(op, shapesStorage);
-      propagateShapesToTosaWhile(op, shapesStorage);
+      propagateShapesToTosaIf(op);
+      propagateShapesToTosaWhile(op);
 
       InferShapedTypeOpInterface shapeInterface =
           dyn_cast<InferShapedTypeOpInterface>(op);
@@ -227,12 +209,11 @@ void propagateShapesInRegion(Region &region) {
 
       SmallVector<ShapedTypeComponents> returnedShapes;
 
-      ValueShapeRange range(op.getOperands(), operandShape);
       if (shapeInterface
-              .inferReturnTypeComponents(op.getContext(), op.getLoc(), range,
-                                         op.getDiscardableAttrDictionary(),
-                                         op.getPropertiesStorage(),
-                                         op.getRegions(), returnedShapes)
+              .inferReturnTypeComponents(
+                  op.getContext(), op.getLoc(), op.getOperands(),
+                  op.getDiscardableAttrDictionary(), op.getPropertiesStorage(),
+                  op.getRegions(), returnedShapes)
               .succeeded()) {
         for (auto it : llvm::zip(op.getResults(), returnedShapes)) {
           Value result = std::get<0>(it);
@@ -262,20 +243,13 @@ void propagateShapesInRegion(Region &region) {
               ValueKnowledge::join(currentKnowledge, inferredKnowledge);
           if (!newKnowledge)
             continue;
-          setShapes(result, newKnowledge.getType());
+
+          // Set new type
+          result.setType(newKnowledge.getType());
         }
       }
     }
   }
-
-  // Actually update types with updated shape knowledge.
-  for (auto it : shapesStorage) {
-    auto result = it.second;
-    if (result.hasRank()) {
-      Type t = cast<ShapedType>(it.first.getType()).clone(result.getDims());
-      it.first.setType(t);
-    }
-  }
 }
 
 /// Pass that performs shape propagation across TOSA operations. This includes
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index cb96b2a8a0d193..d468ba582483cb 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -1259,3 +1259,16 @@ func.func @test_non_tosa_consumer_extract(%arg0: tensor<4x4xf32>, %arg1: index)
   %1 = tensor.extract %0[%arg1, %arg1] : tensor<?x?xf32>
   return %1 : f32
 }
+
+// -----
+
+// CHECK-LABEL: test_tosa_use_def_chain
+func.func @test_tosa_use_def_chain(%arg0: tensor<1x32x32x3xf32>, %arg1: tensor<16x3x3x3xf32>, %arg2: tensor<16xf32>) -> tensor<?x16x16x16xf32> {
+  // CHECK: [[CONV:%.+]] = tosa.conv2d %arg0, %arg1, %arg2
+  // CHECK: (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<1x32x32x16xf32>
+  %0 = tosa.conv2d %arg0, %arg1, %arg2 {dilation = array<i64: 1, 1>, pad = array<i64: 1, 1, 1, 1>, stride = array<i64: 1, 1>} : (tensor<1x32x32x3xf32>, tensor<16x3x3x3xf32>, tensor<16xf32>) -> tensor<?x32x32x16xf32>
+  // CHECK: tosa.max_pool2d [[CONV]]
+  // CHECK: (tensor<1x32x32x16xf32>) -> tensor<1x16x16x16xf32>
+  %1 = tosa.max_pool2d %0 {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<?x32x32x16xf32>) -> tensor<?x16x16x16xf32>
+  return %1 : tensor<?x16x16x16xf32>
+}
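Note (commentary, not part of the patch): the deleted `shapesStorage` map cached inferred shapes and only rewrote types in a final fix-up loop; the patch instead calls `result.setType(...)` as soon as a shape is inferred, so ops later in the region read already-refined operand types, which is what the new `test_tosa_use_def_chain` test exercises. Below is a minimal standalone sketch of that idea using toy value/op types rather than the MLIR API; all names in it are hypothetical and the ops are assumed single-input and shape-preserving for brevity.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy stand-ins for SSA values and ops; -1 marks a dynamic dimension.
struct ToyValue {
  std::vector<int64_t> shape;
};
struct ToyOp {
  ToyValue *operand; // single-input, shape-preserving op for simplicity
  ToyValue *result;
};

// Like the patched pass: read the operand's *current* (possibly already
// refined) shape and write the inferred shape straight into the result,
// with no side map and no deferred fix-up loop.
void propagateShapes(std::vector<ToyOp> &ops) {
  for (ToyOp &op : ops) // ops are assumed to be in topological order
    for (size_t i = 0; i < op.result->shape.size(); ++i)
      if (op.result->shape[i] < 0) // dynamic dim: refine from the operand
        op.result->shape[i] = op.operand->shape[i];
}

int main() {
  ToyValue a{{1, 32, 32, 3}};   // fully static function argument
  ToyValue b{{-1, 32, 32, 3}};  // first result, dynamic batch dim
  ToyValue c{{-1, -1, -1, -1}}; // second result, fully dynamic
  std::vector<ToyOp> ops = {{&a, &b}, {&b, &c}};
  propagateShapes(ops);
  assert(c.shape[0] == 1 && c.shape[1] == 32); // refined down the use-def chain
  return 0;
}
```

Because the region is walked top-down, each operand's defining op has already been visited when an op is processed, so reading shapes from the current types is equivalent to the old map lookup while keeping the IR type-consistent at every step.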