diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp index ea50e1232a4c7..5fde8d71cac3e 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp @@ -227,9 +227,6 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) { linalgOp, linalgOp.getDpsInputOperand(1), red); llvm::set_intersect(ra, rb); - if (ac.empty() || bc.empty() || ra.empty()) - return failure(); - // Return each set in sorted order. ContractionDimensions dimensions{ SmallVector(batches.begin(), batches.end()), diff --git a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir index bad6893eaa99e..1da092ab42ad7 100644 --- a/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir +++ b/mlir/test/Dialect/Linalg/match-ops-interpreter.mlir @@ -910,6 +910,19 @@ module attributes { transform.target_tag = "start_here" } { return %result : tensor<10x15xf64> } + func.func @vecmat_simple(%lhs: tensor<20xf32>, %rhs: tensor<20x15xf32>) -> tensor<15xf64> { + %cst = arith.constant 0.0 : f64 + %empty = tensor.empty() : tensor<15xf64> + %fill = linalg.fill ins(%cst : f64) outs(%empty : tensor<15xf64>) -> tensor<15xf64> + // expected-remark @below {{contraction}} + // expected-remark @below {{batch dims}} + // expected-remark @below {{m dims}} + // expected-remark @below {{n dims 0}} + // expected-remark @below {{k dims 1}} + %result = linalg.vecmat ins(%lhs, %rhs: tensor<20xf32>, tensor<20x15xf32>) outs(%fill: tensor<15xf64>) -> tensor<15xf64> + return %result : tensor<15xf64> + } + func.func @double_batch(%lhs: tensor<40x10x50x20xf32>, %rhs: tensor<40x20x50x15xf32>) -> tensor<40x10x50x15xf32> { %cst = arith.constant 0.0 : f32 %empty = tensor.empty() : tensor<40x10x50x15xf32>