From 5f8cefebd900bbbd96961162ed9b80056e2ab95f Mon Sep 17 00:00:00 2001 From: Thomas Raoux Date: Tue, 12 Jul 2022 22:44:39 +0000 Subject: [PATCH] [mlir][vector] Fix crash in vector.reduction canonicalization since vector.reduce support accumulator in all the cases remove the assert assuming old definition. Differential Revision: https://reviews.llvm.org/D129602 --- .../mlir/Dialect/Vector/IR/VectorOps.h | 5 ++ .../mlir/Dialect/Vector/Utils/VectorUtils.h | 5 -- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 66 +++++++++++++++---- mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp | 50 -------------- mlir/test/Dialect/Vector/canonicalize.mlir | 12 ++++ 5 files changed, 70 insertions(+), 68 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h index 24c2ff5f636d..d51c5592ee3b 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -182,6 +182,11 @@ bool isDisjointTransferIndices(VectorTransferOpInterface transferA, /// memory. bool isDisjointTransferSet(VectorTransferOpInterface transferA, VectorTransferOpInterface transferB); + +/// Return the result value of reducing two scalar/vector values with the +/// corresponding arith operation. +Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, + Value v1, Value v2); } // namespace vector } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h index f6b84f1e28cd..b5e6bc1ae574 100644 --- a/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h +++ b/mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h @@ -34,11 +34,6 @@ namespace vector { /// Helper function that creates a memref::DimOp or tensor::DimOp depending on /// the type of `source`. Value createOrFoldDimOp(OpBuilder &b, Location loc, Value source, int64_t dim); - -/// Return the result value of reducing two scalar/vector values with the -/// corresponding arith operation. -Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind, - Value v1, Value v2); } // namespace vector /// Return the number of elements of basis, `0` if empty. diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index f803868c2150..c50359af87b0 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -501,19 +501,9 @@ struct ElideSingleElementReduction : public OpRewritePattern { reductionOp.getVector(), rewriter.getI64ArrayAttr(0)); - if (Value acc = reductionOp.getAcc()) { - assert(reductionOp.getType().isa()); - switch (reductionOp.getKind()) { - case CombiningKind::ADD: - result = rewriter.create(loc, result, acc); - break; - case CombiningKind::MUL: - result = rewriter.create(loc, result, acc); - break; - default: - assert(false && "invalid op!"); - } - } + if (Value acc = reductionOp.getAcc()) + result = vector::makeArithReduction(rewriter, loc, reductionOp.getKind(), + result, acc); rewriter.replaceOp(reductionOp, result); return success(); @@ -5007,6 +4997,56 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) { verifyDistributedType(lhs, rhs, getWarpSize(), getOperation())); } +Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc, + CombiningKind kind, Value v1, Value v2) { + Type t1 = getElementTypeOrSelf(v1.getType()); + Type t2 = getElementTypeOrSelf(v2.getType()); + switch (kind) { + case CombiningKind::ADD: + if (t1.isIntOrIndex() && t2.isIntOrIndex()) + return b.createOrFold(loc, v1, v2); + else if (t1.isa() && t2.isa()) + return b.createOrFold(loc, v1, v2); + llvm_unreachable("invalid value types for ADD reduction"); + case CombiningKind::AND: + assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); + return b.createOrFold(loc, v1, v2); + case CombiningKind::MAXF: + assert(t1.isa() && t2.isa() && + "expected float values"); + return b.createOrFold(loc, v1, v2); + case CombiningKind::MINF: + assert(t1.isa() && t2.isa() && + "expected float values"); + return b.createOrFold(loc, v1, v2); + case CombiningKind::MAXSI: + assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); + return b.createOrFold(loc, v1, v2); + case CombiningKind::MINSI: + assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); + return b.createOrFold(loc, v1, v2); + case CombiningKind::MAXUI: + assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); + return b.createOrFold(loc, v1, v2); + case CombiningKind::MINUI: + assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); + return b.createOrFold(loc, v1, v2); + case CombiningKind::MUL: + if (t1.isIntOrIndex() && t2.isIntOrIndex()) + return b.createOrFold(loc, v1, v2); + else if (t1.isa() && t2.isa()) + return b.createOrFold(loc, v1, v2); + llvm_unreachable("invalid value types for MUL reduction"); + case CombiningKind::OR: + assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); + return b.createOrFold(loc, v1, v2); + case CombiningKind::XOR: + assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); + return b.createOrFold(loc, v1, v2); + }; + llvm_unreachable("unknown CombiningKind"); +} + //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp index 7e6d56aa622e..b979033ab471 100644 --- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp +++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp @@ -43,56 +43,6 @@ Value mlir::vector::createOrFoldDimOp(OpBuilder &b, Location loc, Value source, llvm_unreachable("Expected MemRefType or TensorType"); } -Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc, - CombiningKind kind, Value v1, Value v2) { - Type t1 = getElementTypeOrSelf(v1.getType()); - Type t2 = getElementTypeOrSelf(v2.getType()); - switch (kind) { - case CombiningKind::ADD: - if (t1.isIntOrIndex() && t2.isIntOrIndex()) - return b.createOrFold(loc, v1, v2); - else if (t1.isa() && t2.isa()) - return b.createOrFold(loc, v1, v2); - llvm_unreachable("invalid value types for ADD reduction"); - case CombiningKind::AND: - assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); - return b.createOrFold(loc, v1, v2); - case CombiningKind::MAXF: - assert(t1.isa() && t2.isa() && - "expected float values"); - return b.createOrFold(loc, v1, v2); - case CombiningKind::MINF: - assert(t1.isa() && t2.isa() && - "expected float values"); - return b.createOrFold(loc, v1, v2); - case CombiningKind::MAXSI: - assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); - return b.createOrFold(loc, v1, v2); - case CombiningKind::MINSI: - assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); - return b.createOrFold(loc, v1, v2); - case CombiningKind::MAXUI: - assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); - return b.createOrFold(loc, v1, v2); - case CombiningKind::MINUI: - assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); - return b.createOrFold(loc, v1, v2); - case CombiningKind::MUL: - if (t1.isIntOrIndex() && t2.isIntOrIndex()) - return b.createOrFold(loc, v1, v2); - else if (t1.isa() && t2.isa()) - return b.createOrFold(loc, v1, v2); - llvm_unreachable("invalid value types for MUL reduction"); - case CombiningKind::OR: - assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); - return b.createOrFold(loc, v1, v2); - case CombiningKind::XOR: - assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); - return b.createOrFold(loc, v1, v2); - }; - llvm_unreachable("unknown CombiningKind"); -} - /// Return the number of elements of basis, `0` if empty. int64_t mlir::computeMaxLinearIndex(ArrayRef basis) { if (basis.empty()) diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir index 702670095c8d..54025a626f00 100644 --- a/mlir/test/Dialect/Vector/canonicalize.mlir +++ b/mlir/test/Dialect/Vector/canonicalize.mlir @@ -1619,6 +1619,18 @@ func.func @dont_reduce_one_element_vector(%a : vector<4xf32>) -> f32 { // ----- +// CHECK-LABEL: func @reduce_one_element_vector_maxf +// CHECK-SAME: (%[[V:.+]]: vector<1xf32>, %[[B:.+]]: f32) +// CHECK: %[[A:.+]] = vector.extract %[[V]][0] : vector<1xf32> +// CHECK: %[[S:.+]] = arith.maxf %[[A]], %[[B]] : f32 +// CHECK: return %[[S]] +func.func @reduce_one_element_vector_maxf(%a : vector<1xf32>, %b: f32) -> f32 { + %s = vector.reduction , %a, %b : vector<1xf32> into f32 + return %s : f32 +} + +// ----- + // CHECK-LABEL: func @bitcast( // CHECK-SAME: %[[ARG:.*]]: vector<4x8xf32>) -> vector<4x16xi16> { // CHECK: vector.bitcast %[[ARG:.*]] : vector<4x8xf32> to vector<4x16xi16>