Allow empty dimension arrays in linalg::inferContractionDims #69496

Merged (2 commits into llvm:main, Oct 19, 2023)

Conversation

@bjacob (Contributor) commented Oct 18, 2023

This function was returning failure when any of the intersection sets was empty, but this is actually legitimate in "matrix times vector" cases, where some of the operands have lower dimensionality, implying unit-dimension semantics for the "missing" dimensions.

Example:

func.func @transpose_extend_batch_matmul(
    %vec: tensor<32x128xi16>,
    %mat: tensor<11008x32x128xi4>) -> tensor<11008x32xi32> {
  %c0_i32 = arith.constant 0 : i32
  %cst_0 = arith.constant 0.000000e+00 : f32
  %0 = tensor.empty() : tensor<11008x32xi32>
  %1 = linalg.fill ins(%c0_i32 : i32) outs(%0 : tensor<11008x32xi32>) -> tensor<11008x32xi32>
  %2 = tensor.empty() : tensor<11008xf32>
  %3 = linalg.fill ins(%cst_0 : f32) outs(%2 : tensor<11008xf32>) -> tensor<11008xf32>
  %batch_matmul_result = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d2)>, 
                                                          affine_map<(d0, d1, d2) -> (d0, d1, d2)>, 
                                                          affine_map<(d0, d1, d2) -> (d0, d1)>], 
                                         iterator_types = ["parallel", "parallel", "reduction"]} 
                                         ins(%vec, %mat : tensor<32x128xi16>, tensor<11008x32x128xi4>) 
                                         outs(%1 : tensor<11008x32xi32>) {
  ^bb0(%in: i16, %in_3: i4, %out: i32):
      %19 = arith.extsi %in : i16 to i32
      %20 = arith.extui %in_3 : i4 to i32
      %21 = arith.muli %19, %20 : i32
      %22 = arith.addi %21, %out : i32
      linalg.yield %22 : i32
  } -> tensor<11008x32xi32>
  return %batch_matmul_result : tensor<11008x32xi32>
}

Here, we were returning failure because ac (the set of m dimensions) is empty. With this PR, we instead return this useful information:

batch: [ 1 ]
m: [ ]
n: [ 0 ]
k: [ 2 ]
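As a side note, here is a minimal C++ sketch (not part of this PR; the helper name printContractionDims is made up for illustration) of how a caller could retrieve and print these sets through the public inferContractionDims API:

#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"

// Hypothetical helper: prints the inferred contraction dimensions of `op`.
// For the linalg.generic above this would print batch = [1], m = [],
// n = [0], k = [2] once empty sets are allowed.
static void printContractionDims(mlir::linalg::LinalgOp op) {
  mlir::FailureOr<mlir::linalg::ContractionDimensions> dims =
      mlir::linalg::inferContractionDims(op);
  if (mlir::failed(dims)) {
    llvm::outs() << "not a contraction\n";
    return;
  }
  auto printSet = [](llvm::StringRef name, llvm::ArrayRef<unsigned> set) {
    llvm::outs() << name << " = [";
    llvm::interleaveComma(set, llvm::outs());
    llvm::outs() << "]\n";
  };
  printSet("batch", dims->batch);
  printSet("m", dims->m);
  printSet("n", dims->n);
  printSet("k", dims->k);
}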

@bjacob marked this pull request as ready for review on October 18, 2023 at 18:43
@llvmbot (Member) commented Oct 18, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-linalg

Author: None (bjacob)

Full diff: https://github.com/llvm/llvm-project/pull/69496.diff

1 file affected:

  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp (-3)
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
index ea50e1232a4c74a..5fde8d71cac3e75 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -227,9 +227,6 @@ mlir::linalg::inferContractionDims(LinalgOp linalgOp) {
       linalgOp, linalgOp.getDpsInputOperand(1), red);
   llvm::set_intersect(ra, rb);
 
-  if (ac.empty() || bc.empty() || ra.empty())
-    return failure();
-
   // Return each set in sorted order.
   ContractionDimensions dimensions{
       SmallVector<unsigned, 2>(batches.begin(), batches.end()),

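With the early return gone, any caller that still requires full matmul semantics can enforce the old restriction itself. A hypothetical caller-side check, sketched against the public inferContractionDims API (nothing like this is added by the PR):

#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"

// Succeeds only for contractions whose m, n, and k dimension sets are all
// non-empty, i.e. roughly the condition that inferContractionDims itself
// enforced before this change.
static mlir::LogicalResult isStrictMatmulLikeOp(mlir::linalg::LinalgOp op) {
  mlir::FailureOr<mlir::linalg::ContractionDimensions> dims =
      mlir::linalg::inferContractionDims(op);
  if (mlir::failed(dims) || dims->m.empty() || dims->n.empty() ||
      dims->k.empty())
    return mlir::failure();
  return mlir::success();
}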
@nicolasvasilache (Contributor) left a comment


Could we please add a test?

Approving conditioned on a test.

Thank you, @bjacob!

@bjacob (Contributor, Author) commented Oct 18, 2023

Thanks @nicolasvasilache for the review. How should I test this? This C++ helper function seems to be used in bool mlir::linalg::isaContractionOpInterface(LinalgOp linalgOp) and in linalg::packMatmulGreedily, and it's not readily clear to me how to test it through those.

@nicolasvasilache (Contributor)

It should be easy to add a simple test transform op to the test transform dialect that just prints something.

@bjacob (Contributor, Author) commented Oct 19, 2023

Thanks! Actually, I figured out that the existing match ops were enough to test this. This works: the new test passes with this change and fails without it. WDYT?

@bjacob (Contributor, Author) commented Oct 19, 2023

ok, going to merge this.

@bjacob merged commit 2ae37be into llvm:main on Oct 19, 2023
@nicolasvasilache (Contributor)

Yeah, any proper test is good, great thanks!
