Pdll #976 (Draft)

wants to merge 6 commits into main
Conversation

KavithaTipturMadhu (Contributor)
No description provided.


int main(int argc, char **argv) {
mlir::registerAllPasses();
mlir::tpp::registerTppCompilerPasses();
mlir::tpp::registerConvertVectorToXsmmPass();
Collaborator:

I'd expect the pass registration to be already included in registerTppCompilerPasses.
Is a separate one really needed?
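To illustrate the concern, here is a minimal MLIR-free sketch (hypothetical bodies, not the actual tpp-mlir implementation): if the umbrella `registerTppCompilerPasses()` already registers the individual pass, a second explicit call in `main()` is redundant.

```cpp
#include <cassert>
#include <set>
#include <string>

// Stand-in registry; in MLIR this is handled by pass registration machinery.
static std::set<std::string> registeredPasses;

void registerConvertVectorToXsmmPass() {
  registeredPasses.insert("convert-vector-to-xsmm");
}

void registerTppCompilerPasses() {
  // Umbrella registration already covers the vector-to-xsmm pass,
  // making a separate call in main() unnecessary.
  registerConvertVectorToXsmmPass();
}
```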

// to the callee to specify the expected rank in the VNNI layout as the rank
// depends on the operations we are dealing with.
bool isInVnniLayout(VnniOperandRank expectedRank, VectorType vector) {
return isInVnniLayout((int64_t)expectedRank, vector);
Collaborator:

nit: use static_cast
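For reference, a scoped enum converts to its underlying integer type via `static_cast` rather than a C-style cast; a standalone sketch (the `VnniOperandRank` values here are illustrative, not the real ones):

```cpp
#include <cassert>
#include <cstdint>

// Illustrative stand-in for VnniOperandRank; the enumerator values are
// assumptions for this example.
enum class VnniOperandRank : int64_t { GEMM = 3, BRGEMM_INS = 4 };

int64_t toRank(VnniOperandRank rank) {
  // static_cast makes the enum-to-integer conversion explicit and checked,
  // unlike the C-style (int64_t)expectedRank cast.
  return static_cast<int64_t>(rank);
}
```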

@@ -122,17 +213,17 @@ struct IntelAMXTileConfigInsertionPass
: public impl::IntelAMXTileConfigInsertionPassBase<
IntelAMXTileConfigInsertionPass> {
void populateCombinePatterns(RewritePatternSet &patterns) {
patterns.add<IntelAMXTileConfig<xsmm::BrgemmOp, xsmm::BrgemmDispatchOp>>(
Collaborator:

Are we ready to already retire the old lowering?

return xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16);
return xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::F32);
// Callable object to verify if `operand` has static shape.
struct HasStaticShape {
Collaborator:

Why not just include StructuredOpMatcher.h for these?
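For context, the callable in question can be sketched without MLIR: a shape is static when no dimension is dynamic (here `-1` stands in for `ShapedType::kDynamic`). This is an illustrative sketch, not the header's actual matcher.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Callable object to verify that a shape has no dynamic dimensions.
struct HasStaticShape {
  bool operator()(const std::vector<long> &shape) const {
    // A negative extent marks a dynamic dimension in this sketch.
    return std::none_of(shape.begin(), shape.end(),
                        [](long d) { return d < 0; });
  }
};
```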

}

static std::pair<Operation *, Operation *>
buildOpImpl(PatternRewriter &rewriter, Operation *contractOp, Operation *input0,
Collaborator:

Since you have access to the rewriter, can you erase the other ops here as well?
I wonder if we could have a single buildOp that tries to fuse consumers on its own, to avoid a combinatorial explosion of patterns.

rewrite root with{
let replacement = BuildOp(root, input0, input1, input2);
replace root with (replacement.dispatch, replacement.invoke);
let user = GetUser(replacement.dispatch);
Collaborator:

Not a PDLL expert, but AFAIK there's no validation or matching on the user here.
I don't think we can just blindly erase it when, for example:

%0 = vector.contract
%1 = arith.subf %0, ...
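The guard being asked for can be sketched in plain C++ (`Op` is a stand-in for MLIR's `Operation`, not the real API): erase the user only when it is the result's single user and has the expected kind, e.g. a `vector.transfer_write`; an `arith.subf` consumer would be rejected.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Minimal stand-in for an operation with a list of users of its result.
struct Op {
  std::string name;
  std::vector<Op *> users;
};

// Only erase the consumer when it is the sole user and the expected kind.
bool safeToEraseUser(const Op &producer, const std::string &expectedKind) {
  return producer.users.size() == 1 &&
         producer.users.front()->name == expectedKind;
}
```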

Collaborator:

I think you already match a transfer_write as a consumer in some other patterns, so it's probably just missing here.

FailureOr<vector::ContractionOp>
makeMinorDimensionsInnerMost(RewriterBase &rewriter,
vector::ContractionOp contractOp, unsigned m,
unsigned n, unsigned k, xsmm::DataTypeAttr type);
Collaborator:

xsmm::DataTypeAttr needs to be removed, as it still couples this interface to the Xsmm dialect.

Collaborator:

Actually, I see it still relies on XsmmEnum in general.
I guess we can keep that for now, as it's easy to generate them, and refactor later.

@adam-smnk (Collaborator)

I'm getting quite different results between the linalg and the vector paths to xsmm:

Linalg test case:

tpp-opt ../test.mlir -convert-linalg-to-xsmm -convert-xsmm-to-func | tpp-run -e entry --entry-point-result=void -seed 123 -print

func.func @entry(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) 
  -> memref<32x32xf32> {
  linalg.matmul ins(%arg0, %arg1: memref<32x32xf32>, memref<32x32xf32>)
    outs(%arg2: memref<32x32xf32>)
  return %arg2 : memref<32x32xf32>
}

Vector test case:

tpp-opt ../test.mlir -convert-vector-to-xsmm-pass | tpp-run -e entry --entry-point-result=void -seed 123 -print

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d2, d1)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
module {
  func.func @entry(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>)
  -> memref<32x32xf32> {
    %c0 = arith.constant 0 : index
    %cst = arith.constant 0.000000e+00 : f32
    %0 = vector.transfer_read %arg0[%c0, %c0],
      %cst {in_bounds = [true, true]} : memref<32x32xf32>, vector<32x32xf32>
    %1 = vector.transfer_read %arg1[%c0, %c0],
      %cst {in_bounds = [true, true]} : memref<32x32xf32>, vector<32x32xf32>
    %2 = vector.transfer_read %arg2[%c0, %c0],
      %cst {in_bounds = [true, true]} : memref<32x32xf32>, vector<32x32xf32>
    %3 = vector.contract {indexing_maps = [#map, #map1, #map2],
      iterator_types = ["parallel", "parallel", "reduction"],
      kind = #vector.kind<add>} %0, %1, %2
      : vector<32x32xf32>, vector<32x32xf32> into vector<32x32xf32>
    vector.transfer_write %3, %arg2[%c0, %c0]
      {in_bounds = [true, true]} : vector<32x32xf32>, memref<32x32xf32>
    return %arg2 : memref<32x32xf32>
  }
}

@adam-smnk (Collaborator)

After a bit of IR diffing, the current difference comes from the dispatch call:

  • created from linalg:
    %0 = call @xsmm_gemm_dispatch(%c1_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c0_i64) : (i64, i64, i64, i64, i64, i64, i64, i64) -> i64
  • created from vector:
    %0 = call @xsmm_gemm_dispatch(%c1_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c32_i64, %c1_i64, %c1_i64, %c0_i64) : (i64, i64, i64, i64, i64, i64, i64, i64, i64, i64) -> i64

The new call builder adds two extra arguments (the 8th and 9th: %c1_i64, %c1_i64), which correspond to unit strides.
Now, I'm not really sure how this "just works"™, as we don't have a wrapper for a 10-argument gemm_dispatch. It somehow runs, but produces invalid results.
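One way to surface this kind of mismatch early is an arity check before emitting the runtime call. The sketch below is hypothetical (the table and function names are illustrative, not the tpp-mlir API): a 10-argument xsmm_gemm_dispatch call would be rejected instead of silently miscalling the 8-argument wrapper.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>

// Known wrapper arities; xsmm_gemm_dispatch takes 8 i64 arguments on the
// linalg path observed above.
static const std::map<std::string, std::size_t> kKnownWrapperArity = {
    {"xsmm_gemm_dispatch", 8},
};

// Return true only when the callee exists and the argument count matches.
bool arityMatchesWrapper(const std::string &callee, std::size_t numArgs) {
  auto it = kKnownWrapperArity.find(callee);
  return it != kKnownWrapperArity.end() && it->second == numArgs;
}
```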

@KavithaTipturMadhu (Contributor, Author)


Fixed the issue in the last commit @adam-smnk

namespace xegpu {
class XeGPUDialect;
} // namespace xegpu

} // namespace mlir

#include "TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.h"
Contributor:

This seems out of place. Why is this include needed here when the others are not?
