diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp index 8fffabf11f3fdd..2e6079e1402e1d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ConstantFold.cpp @@ -23,21 +23,21 @@ using namespace mlir; using namespace mlir::linalg; namespace { -/// Base class for constant folding linalg.generic ops with N inputs, 1 output, -/// and permutation indexing maps. +/// Base class for constant folding linalg structured ops with N inputs, 1 +/// output, and permutation indexing maps. /// /// `ConcreteType` should provide methods with signatures /// /// ```c++ -/// bool matchIndexingMaps(GenericOp genericOp) const; -/// RegionComputationFn getRegionComputeFn(GenericOp) const; +/// bool matchIndexingMaps(LinalgOp linalgOp) const; +/// RegionComputationFn getRegionComputeFn(LinalgOp) const; /// ``` /// /// The latter inspects the region and returns the computation inside as a /// functor. The functor will be invoked with constant elements for all inputs /// and should return the corresponding computed constant element for output. template -class FoldConstantBase : public OpRewritePattern { +class FoldConstantBase : public OpInterfaceRewritePattern { public: struct APIntOrFloat { std::optional apInt; @@ -52,25 +52,26 @@ class FoldConstantBase : public OpRewritePattern { FoldConstantBase(MLIRContext *context, const ControlFusionFn &controlFn, PatternBenefit benefit = 1) - : OpRewritePattern(context, benefit), controlFn(controlFn) {} + : OpInterfaceRewritePattern(context, benefit), + controlFn(controlFn) {} - LogicalResult matchAndRewrite(GenericOp genericOp, + LogicalResult matchAndRewrite(LinalgOp linalgOp, PatternRewriter &rewriter) const override { // Mixed and buffer sematics aren't supported. - if (!genericOp.hasPureTensorSemantics()) + if (!linalgOp.hasPureTensorSemantics()) return failure(); // Only support ops generating one output for now. - if (genericOp.getNumDpsInits() != 1) + if (linalgOp.getNumDpsInits() != 1) return failure(); - auto outputType = dyn_cast(genericOp.getResultTypes().front()); + auto outputType = dyn_cast(linalgOp->getResultTypes().front()); // Require the output types to be static given that we are generating // constants. if (!outputType || !outputType.hasStaticShape()) return failure(); - if (!llvm::all_of(genericOp.getInputs(), [](Value input) { + if (!llvm::all_of(linalgOp.getDpsInputs(), [](Value input) { return isa(input.getType()); })) return failure(); @@ -80,7 +81,7 @@ class FoldConstantBase : public OpRewritePattern { return cast(value.getType()).getElementType(); }; if (!llvm::all_equal( - llvm::map_range(genericOp->getOperands(), getOperandElementType))) + llvm::map_range(linalgOp->getOperands(), getOperandElementType))) return failure(); // We can only handle the case where we have int/float elements. @@ -93,30 +94,30 @@ class FoldConstantBase : public OpRewritePattern { // entirely in the compiler, without needing to turn all indices into // Values, and then do affine apply on them, and then match back the // constant again. - if (!llvm::all_of(genericOp.getIndexingMapsArray(), + if (!llvm::all_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) { return map.isPermutation(); })) return failure(); - for (OpOperand &operand : genericOp.getDpsInitsMutable()) { - if (genericOp.payloadUsesValueFromOperand(&operand)) + for (OpOperand &operand : linalgOp.getDpsInitsMutable()) { + if (linalgOp.payloadUsesValueFromOperand(&operand)) return failure(); } // Further check the indexing maps are okay for the ConcreteType. - if (!static_cast(this)->matchIndexingMaps(genericOp)) + if (!static_cast(this)->matchIndexingMaps(linalgOp)) return failure(); // Defer to the concrete type to check the region and discover the // computation inside. RegionComputationFn computeFn = - static_cast(this)->getRegionComputeFn(genericOp); + static_cast(this)->getRegionComputeFn(linalgOp); if (!computeFn) return failure(); // All inputs should be constants. - int numInputs = genericOp.getNumDpsInputs(); + int numInputs = linalgOp.getNumDpsInputs(); SmallVector inputValues(numInputs); - for (const auto &en : llvm::enumerate(genericOp.getDpsInputOperands())) { + for (const auto &en : llvm::enumerate(linalgOp.getDpsInputOperands())) { if (!matchPattern(en.value()->get(), m_Constant(&inputValues[en.index()]))) return failure(); @@ -124,12 +125,11 @@ class FoldConstantBase : public OpRewritePattern { // Identified this as a potential candidate for folding. Now check the // policy to see whether we are allowed to proceed. - for (OpOperand *operand : genericOp.getDpsInputOperands()) { + for (OpOperand *operand : linalgOp.getDpsInputOperands()) { if (!controlFn(operand)) return failure(); } - auto linalgOp = cast(genericOp.getOperation()); SmallVector loopBounds = linalgOp.computeStaticLoopSizes(); int64_t numElements = outputType.getNumElements(); @@ -155,8 +155,8 @@ class FoldConstantBase : public OpRewritePattern { SmallVector> inputDims; for (int i = 0; i < numInputs; ++i) - inputDims.push_back(getDimPositions(genericOp.getIndexingMapsArray()[i])); - auto outputDims = getDimPositions(genericOp.getIndexingMapsArray().back()); + inputDims.push_back(getDimPositions(linalgOp.getIndexingMapsArray()[i])); + auto outputDims = getDimPositions(linalgOp.getIndexingMapsArray().back()); auto outputShape = outputType.getShape(); // Allocate small vectors for index delinearization. Initial values do not @@ -173,7 +173,7 @@ class FoldConstantBase : public OpRewritePattern { APIntOrFloatArray computeFnInputs; auto inputShapes = llvm::to_vector<4>( - llvm::map_range(genericOp.getInputs(), [](Value value) { + llvm::map_range(linalgOp.getDpsInputs(), [](Value value) { return cast(value.getType()).getShape(); })); @@ -254,7 +254,7 @@ class FoldConstantBase : public OpRewritePattern { isFloat ? DenseElementsAttr::get(outputType, fpOutputValues) : DenseElementsAttr::get(outputType, intOutputValues); - rewriter.replaceOpWithNewOp(genericOp, outputAttr); + rewriter.replaceOpWithNewOp(linalgOp, outputAttr); return success(); } @@ -262,18 +262,20 @@ class FoldConstantBase : public OpRewritePattern { ControlFusionFn controlFn; }; -// Folds linalg.generic ops that are actually transposes on constant values. +// Folds linalg.transpose (and linalg.generic ops that are actually transposes) +// on constant values. struct FoldConstantTranspose : public FoldConstantBase { + using FoldConstantBase::FoldConstantBase; - bool matchIndexingMaps(GenericOp genericOp) const { + bool matchIndexingMaps(LinalgOp linalgOp) const { // We should have one input and one output. - return genericOp.getIndexingMapsArray().size() == 2; + return linalgOp.getIndexingMapsArray().size() == 2; } - RegionComputationFn getRegionComputeFn(GenericOp genericOp) const { + RegionComputationFn getRegionComputeFn(LinalgOp linalgOp) const { // Make sure the region only contains a yield op. - Block &body = genericOp.getRegion().front(); + Block &body = linalgOp->getRegion(0).front(); if (!llvm::hasSingleElement(body)) return nullptr; auto yieldOp = dyn_cast(body.getTerminator()); diff --git a/mlir/test/Dialect/Linalg/constant-fold.mlir b/mlir/test/Dialect/Linalg/constant-fold.mlir new file mode 100644 index 00000000000000..3929c26a3382f4 --- /dev/null +++ b/mlir/test/Dialect/Linalg/constant-fold.mlir @@ -0,0 +1,148 @@ +// RUN: mlir-opt %s -linalg-fuse-elementwise-ops -split-input-file | FileCheck %s + +// CHECK-LABEL: @transpose_fold_2d_fp32 +func.func @transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> { + %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32> + // CHECK: %[[CST:.+]] = arith.constant + // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32> + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + linalg.yield %arg1 : f32 + } -> tensor<3x2xf32> + // CHECK: return %[[CST]] + return %1 : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_fold_2d_fp64 +func.func @transpose_fold_2d_fp64(%init: tensor<3x2xf64>) -> tensor<3x2xf64> { + %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf64> + // CHECK: %[[CST:.+]] = arith.constant + // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf64> + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%input : tensor<2x3xf64>) outs(%init : tensor<3x2xf64>) { + ^bb0(%arg1: f64, %arg2: f64): + linalg.yield %arg1 : f64 + } -> tensor<3x2xf64> + // CHECK: return %[[CST]] + return %1 : tensor<3x2xf64> +} + +// ----- + +// CHECK-LABEL: @transpose_fold_4d_i32 +func.func @transpose_fold_4d_i32(%init: tensor<3x1x4x2xi32>) -> tensor<3x1x4x2xi32> { + %input = arith.constant dense<[[ + [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], + [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]] + ]]> : tensor<1x2x3x4xi32> + // CHECK: %[[CST:.+]] = arith.constant dense<[ + // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]], + // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]], + // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]] + // CHECK-SAME{LITERAL}: ]> + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"] + } ins(%input : tensor<1x2x3x4xi32>) outs(%init : tensor<3x1x4x2xi32>) { + ^bb0(%arg1: i32, %arg2: i32): + linalg.yield %arg1 : i32 + } -> tensor<3x1x4x2xi32> + // CHECK: return %[[CST]] + return %1 : tensor<3x1x4x2xi32> +} + +// ----- + +// CHECK-LABEL: @transpose_fold_4d_i16 +func.func @transpose_fold_4d_i16(%init: tensor<3x1x4x2xi16>) -> tensor<3x1x4x2xi16> { + %input = arith.constant dense<[[ + [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], + [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]] + ]]> : tensor<1x2x3x4xi16> + // CHECK: %[[CST:.+]] = arith.constant dense<[ + // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]], + // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]], + // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]] + // CHECK-SAME{LITERAL}: ]> + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>], + iterator_types = ["parallel", "parallel", "parallel", "parallel"] + } ins(%input : tensor<1x2x3x4xi16>) outs(%init : tensor<3x1x4x2xi16>) { + ^bb0(%arg1: i16, %arg2: i16): + linalg.yield %arg1 : i16 + } -> tensor<3x1x4x2xi16> + // CHECK: return %[[CST]] + return %1 : tensor<3x1x4x2xi16> +} + +// ----- + +// CHECK-LABEL: @transpose_nofold_non_cst_input +func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>, %init: tensor<3x2xf32>) -> tensor<3x2xf32> { + // CHECK: linalg.generic + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + linalg.yield %arg1 : f32 + } -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_nofold_yield_const +func.func @transpose_nofold_yield_const(%init: tensor<3x2xf32>) -> tensor<3x2xf32> { + %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32> + %cst = arith.constant 8.0 : f32 + // CHECK: linalg.generic + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + linalg.yield %cst : f32 + } -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @transpose_nofold_multi_ops_in_region +func.func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2xf32> { + %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32> + // CHECK: linalg.generic + %1 = linalg.generic { + indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], + iterator_types = ["parallel", "parallel"] + } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) { + ^bb0(%arg1: f32, %arg2: f32): + %add = arith.addf %arg1, %arg1 : f32 + linalg.yield %add : f32 + } -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @named_transpose_fold_2d_fp32 +func.func @named_transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> { + %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32> + // CHECK: %[[CST:.+]] = arith.constant + // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32> + %1 = linalg.transpose ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) permutation = [1, 0] + // CHECK: return %[[CST]] + return %1 : tensor<3x2xf32> +} + +// ----- + + diff --git a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir index 15a4f6cdd3bbe4..e45a9fbb1052c1 100644 --- a/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir +++ b/mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir @@ -777,139 +777,6 @@ func.func @fuse_scalar_constant(%arg0 : tensor) -> (tensor, te // ----- -// CHECK-LABEL: @transpose_fold_2d_fp32 -func.func @transpose_fold_2d_fp32(%init: tensor<3x2xf32>) -> tensor<3x2xf32> { - %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32> - // CHECK: %[[CST:.+]] = arith.constant - // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32> - %1 = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"] - } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - linalg.yield %arg1 : f32 - } -> tensor<3x2xf32> - // CHECK: return %[[CST]] - return %1 : tensor<3x2xf32> -} - -// ----- - -// CHECK-LABEL: @transpose_fold_2d_fp64 -func.func @transpose_fold_2d_fp64(%init: tensor<3x2xf64>) -> tensor<3x2xf64> { - %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf64> - // CHECK: %[[CST:.+]] = arith.constant - // CHECK-SAME{LITERAL}: dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf64> - %1 = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"] - } ins(%input : tensor<2x3xf64>) outs(%init : tensor<3x2xf64>) { - ^bb0(%arg1: f64, %arg2: f64): - linalg.yield %arg1 : f64 - } -> tensor<3x2xf64> - // CHECK: return %[[CST]] - return %1 : tensor<3x2xf64> -} - -// ----- - -// CHECK-LABEL: @transpose_fold_4d_i32 -func.func @transpose_fold_4d_i32(%init: tensor<3x1x4x2xi32>) -> tensor<3x1x4x2xi32> { - %input = arith.constant dense<[[ - [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], - [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]] - ]]> : tensor<1x2x3x4xi32> - // CHECK: %[[CST:.+]] = arith.constant dense<[ - // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]], - // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]], - // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]] - // CHECK-SAME{LITERAL}: ]> - %1 = linalg.generic { - indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>], - iterator_types = ["parallel", "parallel", "parallel", "parallel"] - } ins(%input : tensor<1x2x3x4xi32>) outs(%init : tensor<3x1x4x2xi32>) { - ^bb0(%arg1: i32, %arg2: i32): - linalg.yield %arg1 : i32 - } -> tensor<3x1x4x2xi32> - // CHECK: return %[[CST]] - return %1 : tensor<3x1x4x2xi32> -} - -// ----- - -// CHECK-LABEL: @transpose_fold_4d_i16 -func.func @transpose_fold_4d_i16(%init: tensor<3x1x4x2xi16>) -> tensor<3x1x4x2xi16> { - %input = arith.constant dense<[[ - [[ 0, 1, 2, 3], [ 4, 5, 6, 7], [ 8, 9, 10, 11]], - [[12, 13, 14, 15], [16, 17, 18, 19], [20, 21, 22, 23]] - ]]> : tensor<1x2x3x4xi16> - // CHECK: %[[CST:.+]] = arith.constant dense<[ - // CHECK-SAME{LITERAL}: [[[0, 12], [1, 13], [2, 14], [3, 15]]], - // CHECK-SAME{LITERAL}: [[[4, 16], [5, 17], [6, 18], [7, 19]]], - // CHECK-SAME{LITERAL}: [[[8, 20], [9, 21], [10, 22], [11, 23]]] - // CHECK-SAME{LITERAL}: ]> - %1 = linalg.generic { - indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d2, d0, d3, d1)>], - iterator_types = ["parallel", "parallel", "parallel", "parallel"] - } ins(%input : tensor<1x2x3x4xi16>) outs(%init : tensor<3x1x4x2xi16>) { - ^bb0(%arg1: i16, %arg2: i16): - linalg.yield %arg1 : i16 - } -> tensor<3x1x4x2xi16> - // CHECK: return %[[CST]] - return %1 : tensor<3x1x4x2xi16> -} - -// ----- - -// CHECK-LABEL: @transpose_nofold_non_cst_input -func.func @transpose_nofold_non_cst_input(%input: tensor<2x3xf32>, %init: tensor<3x2xf32>) -> tensor<3x2xf32> { - // CHECK: linalg.generic - %1 = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"] - } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - linalg.yield %arg1 : f32 - } -> tensor<3x2xf32> - return %1 : tensor<3x2xf32> -} - -// ----- - -// CHECK-LABEL: @transpose_nofold_yield_const -func.func @transpose_nofold_yield_const(%init: tensor<3x2xf32>) -> tensor<3x2xf32> { - %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32> - %cst = arith.constant 8.0 : f32 - // CHECK: linalg.generic - %1 = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"] - } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - linalg.yield %cst : f32 - } -> tensor<3x2xf32> - return %1 : tensor<3x2xf32> -} - -// ----- - -// CHECK-LABEL: @transpose_nofold_multi_ops_in_region -func.func @transpose_nofold_multi_ops_in_region(%init: tensor<3x2xf32>) -> tensor<3x2xf32> { - %input = arith.constant dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32> - // CHECK: linalg.generic - %1 = linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"] - } ins(%input : tensor<2x3xf32>) outs(%init : tensor<3x2xf32>) { - ^bb0(%arg1: f32, %arg2: f32): - %add = arith.addf %arg1, %arg1 : f32 - linalg.yield %add : f32 - } -> tensor<3x2xf32> - return %1 : tensor<3x2xf32> -} - -// ----- - // Fusing the broadcast into a reduction would require to insert extra knowledge // about the size of the reduction dimension. As long, as this is not // implemented, we check that two linalg operations remain.