diff --git a/test/Conversion/LinalgToXsmm/linalg-to-binary.mlir b/test/Conversion/LinalgToXsmm/linalg-to-binary.mlir deleted file mode 100644 index ece08e063..000000000 --- a/test/Conversion/LinalgToXsmm/linalg-to-binary.mlir +++ /dev/null @@ -1,721 +0,0 @@ -// RUN: tpp-opt %s -convert-linalg-to-xsmm -split-input-file | FileCheck %s - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (0, d1)> - -func.func @add_bcast_col_operand_1(%arg0: memref<256x1024xf32>, %arg1: memref<1x1024xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<256x1024xf32>, memref<1x1024xf32>) - outs(%arg0 : memref<256x1024xf32>) { - ^bb0(%in: f32, %in_6: f32, %out: f32): - %6 = arith.addf %in, %in_6 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: add_bcast_col_operand_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<1x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch add [256, 1024, 1024, 1024, 1024] -// CHECK-SAME: flags = (bcast_col_in1) data_type = f32 -// CHECK: xsmm.binary add(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG0]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (0, 0)> - -func.func @add_1(%arg0: memref<256x1024xf32>, %arg1: memref<1x1xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<256x1024xf32>, memref<1x1xf32>) - outs(%arg0 : memref<256x1024xf32>) { - ^bb0(%in: f32, %in_6: f32, %out: f32): - %6 = arith.addf %in, %in_6 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: add_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<1x1xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch add [256, 1024, 1024, 1, 1024] -// CHECK-SAME: flags = (bcast_scalar_in1) data_type = f32 -// CHECK: xsmm.binary add(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG0]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0, 0)> - -func.func @add_bcast_row_operand_1(%arg0: memref<256x1024xf32>, %arg1: memref<256x1xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<256x1024xf32>, memref<256x1xf32>) - outs(%arg0 : memref<256x1024xf32>) { - ^bb0(%in: f32, %in_6: f32, %out: f32): - %6 = arith.addf %in, %in_6 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: add_bcast_row_operand_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<256x1xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch add [256, 1024, 1024, 1, 1024] -// CHECK-SAME: flags = (bcast_row_in1) data_type = f32 -// CHECK: xsmm.binary add(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG0]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0, 0)> - -func.func @add_bcast_row_operand_0(%arg0: memref<256x1024xf32>, %arg1: memref<256x1xf32>) { - linalg.generic { - indexing_maps = [#map1, #map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg1, %arg0 : memref<256x1xf32>, memref<256x1024xf32>) - outs(%arg0 : memref<256x1024xf32>) { - ^bb0(%in: f32, %in_6: f32, %out: f32): - %6 = arith.addf %in, %in_6 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: add_bcast_row_operand_0 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<256x1xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch add 
[256, 1024, 1, 1024, 1024] -// CHECK-SAME: flags = (bcast_row_in0) data_type = f32 -// CHECK: xsmm.binary add(data_type = f32, %[[DIS]], %[[ARG1]], %[[ARG0]], %[[ARG0]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d1)> - -func.func @add_3(%arg0: memref<256x1024xf32>, %arg1: memref<1024xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<256x1024xf32>, memref<1024xf32>) - outs(%arg0 : memref<256x1024xf32>) { - ^bb0(%in: f32, %in_6: f32, %out: f32): - %6 = arith.addf %in, %in_6 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: add_3 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch add [256, 1024, 1024, 1024, 1024] -// CHECK-SAME: flags = (bcast_col_in1) data_type = f32 -// CHECK: xsmm.binary add(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG0]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> ()> - -func.func @add_bcast_scalar_operand_1(%arg0: memref<256x1024xf32>, %arg1: f32) { - linalg.generic { - indexing_maps = [#map, #map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<256x1024xf32>, f32) - outs(%arg0 : memref<256x1024xf32>) { - ^bb0(%in: f32, %in_6: f32, %out: f32): - %6 = arith.addf %in, %in_6 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: add_bcast_scalar_operand_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: f32 -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch add [256, 1024, 1024, 1, 1024] flags = (bcast_scalar_in1) data_type = f32 -// CHECK: xsmm.binary add(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG0]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> ()> - -func.func @add_bcast_scalar_operand_0(%arg0: memref<256x1024xf32>, %arg1: f32) { - linalg.generic { - indexing_maps = [#map1, #map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg1, %arg0 : f32, memref<256x1024xf32>) - outs(%arg0 : memref<256x1024xf32>) { - ^bb0(%in: f32, %in_6: f32, %out: f32): - %6 = arith.addf %in, %in_6 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: add_bcast_scalar_operand_0 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: f32 -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch add [256, 1024, 1, 1024, 1024] flags = (bcast_scalar_in0) data_type = f32 -// CHECK: xsmm.binary add(data_type = f32, %[[DIS]], %[[ARG1]], %[[ARG0]], %[[ARG0]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (0, d1)> - -func.func @add_bcast_col_operand_0(%arg0: memref<1x1024xf32>, %arg1: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map1, #map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<1x1024xf32>, memref<256x1024xf32>) - outs(%arg1 : memref<256x1024xf32>) { - ^bb0(%in: f32, %in_6: f32, %out: f32): - %6 = arith.addf %in, %in_6 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: add_bcast_col_operand_0 -// CHECK-SAME: %[[ARG0:.+]]: memref<1x1024xf32>, %[[ARG1:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch add [256, 1024, 1024, 1024, 1024] -// CHECK-SAME: flags = (bcast_col_in0) data_type = f32 -// CHECK: xsmm.binary add(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG1]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (0, 
d1)> - -func.func @add_6(%arg0: memref<1x1024xf32>, %arg1: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map1, #map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg0 : memref<1x1024xf32>, memref<1x1024xf32>) - outs(%arg1 : memref<256x1024xf32>) { - ^bb0(%in: f32, %in_6: f32, %out: f32): - %6 = arith.addf %in, %in_6 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: add_6 -// CHECK-SAME: %[[ARG0:.+]]: memref<1x1024xf32>, %[[ARG1:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch add [256, 1024, 1024, 1024, 1024] -// CHECK-SAME: flags = (bcast_col_in0, bcast_col_in1) data_type = f32 -// CHECK: xsmm.binary add(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG0]], %[[ARG1]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d1, d0)> - -func.func @add_7(%arg0: memref<256x1024xf32>, %arg1: memref<1024x256xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<256x1024xf32>, memref<1024x256xf32>) - outs(%arg0 : memref<256x1024xf32>) { - ^bb0(%in: f32, %in_6: f32, %out: f32): - %6 = arith.addf %in, %in_6 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: add_7 -// CHECK-NOT: xsmm.binary.dispatch add -// CHECK: linalg.generic - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> - -func.func @add_8(%arg0: memref<256x1024xf32>, %arg1: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map, #map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<256x1024xf32>, memref<256x1024xf32>) - outs(%arg0 : memref<256x1024xf32>) { - ^bb0(%in: f32, %in_6: f32, %out: f32): - %6 = arith.addf %in, %in_6 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: add_8 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch add [256, 1024, 1024, 1024, 1024] -// CHECK-SAME: flags = (none) data_type = f32 -// CHECK: xsmm.binary add(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG0]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> - -func.func @add_9(%arg0: memref<256x1024xf32>, %arg1: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : memref<256x1024xf32>) - outs(%arg1 : memref<256x1024xf32>) { - ^bb0(%in: f32, %out: f32): - %6 = arith.addf %in, %out : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: add_9 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch add [256, 1024, 1024, 1024, 1024] -// CHECK-SAME: flags = (none) data_type = f32 -// CHECK: xsmm.binary add(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG1]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0)> - -func.func @add_10(%arg0: memref<256x1024xf32>, %arg1: memref<256xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<256x1024xf32>, memref<256xf32>) - outs(%arg0 : memref<256x1024xf32>) { - ^bb0(%in: f32, %in_6: f32, %out: f32): - %6 = arith.addf %in, %in_6 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: add_10 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<256xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch add [256, 1024, 1024, 1, 1024] -// CHECK-SAME: flags 
= (bcast_row_in1) data_type = f32 -// CHECK: xsmm.binary add(data_type = f32, %[[DIS]], %[[ARG0]], %{{.+}}, %[[ARG0]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> - -func.func @trivial_sub(%arg0: memref<256x1024xf32>, %arg1: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map, #map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg0 : memref<256x1024xf32>, memref<256x1024xf32>) - outs(%arg1: memref<256x1024xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %19 = arith.subf %in, %in_4 : f32 - linalg.yield %19 : f32 - } - return -} - -// CHECK-LABEL: trivial_sub -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch sub [256, 1024, 1024, 1024, 1024] flags = (none) data_type = f32 -// CHECK: xsmm.binary sub(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG0]], %[[ARG1]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> ()> - -func.func @sub_bcast_scalar_operand_1(%arg0: memref<256x1024xf32>, %arg1: f32, %arg2: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<256x1024xf32>, f32) - outs(%arg2: memref<256x1024xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %19 = arith.subf %in, %in_4 : f32 - linalg.yield %19 : f32 - } - return -} - -// CHECK-LABEL: sub_bcast_scalar_operand_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch sub [256, 1024, 1024, 1, 1024] flags = (bcast_scalar_in1) data_type = f32 -// CHECK: xsmm.binary sub(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> ()> - -func.func @sub_bcast_scalar_operand_0(%arg0: memref<256x1024xf32>, %arg1: f32, %arg2: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map1, #map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg1, %arg0 : f32, memref<256x1024xf32>) - outs(%arg2: memref<256x1024xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %19 = arith.subf %in, %in_4 : f32 - linalg.yield %19 : f32 - } - return -} - -// CHECK-LABEL: sub_bcast_scalar_operand_0 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch sub [256, 1024, 1, 1024, 1024] flags = (bcast_scalar_in0) data_type = f32 -// CHECK: xsmm.binary sub(data_type = f32, %[[DIS]], %[[ARG1]], %[[ARG0]], %[[ARG2]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d1)> - -func.func @sub_bcast_col_operand_1(%arg0: memref<256x1024xf32>, %arg1: memref<1024xf32>, %arg2: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<256x1024xf32>, memref<1024xf32>) - outs(%arg2: memref<256x1024xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %19 = arith.subf %in, %in_4 : f32 - linalg.yield %19 : f32 - } - return -} - -// CHECK-LABEL: sub_bcast_col_operand_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<1024xf32>, %[[ARG2:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch sub [256, 1024, 1024, 1024, 1024] flags = (bcast_col_in1) data_type = f32 -// CHECK: xsmm.binary sub(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) - -// 
----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d1)> - -func.func @sub_bcast_col_operand_0(%arg0: memref<256x1024xf32>, %arg1: memref<1024xf32>, %arg2: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map1, #map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg1, %arg0 : memref<1024xf32>, memref<256x1024xf32>) - outs(%arg2: memref<256x1024xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %19 = arith.subf %in, %in_4 : f32 - linalg.yield %19 : f32 - } - return -} - -// CHECK-LABEL: sub_bcast_col_operand_0 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<1024xf32>, %[[ARG2:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch sub [256, 1024, 1024, 1024, 1024] flags = (bcast_col_in0) data_type = f32 -// CHECK: xsmm.binary sub(data_type = f32, %[[DIS]], %[[ARG1]], %[[ARG0]], %[[ARG2]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0, 0)> - -func.func @sub_bcast_row_operand_1(%arg0: memref<256x1024xf32>, %arg1: memref<256x1xf32>, %arg2: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<256x1024xf32>, memref<256x1xf32>) - outs(%arg2: memref<256x1024xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %19 = arith.subf %in, %in_4 : f32 - linalg.yield %19 : f32 - } - return -} - -// CHECK-LABEL: sub_bcast_row_operand_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<256x1xf32>, %[[ARG2:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch sub [256, 1024, 1024, 1, 1024] flags = (bcast_row_in1) data_type = f32 -// CHECK: xsmm.binary sub(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0, 0)> - -func.func @sub_bcast_row_operand_0(%arg0: memref<256x1024xf32>, %arg1: memref<256x1xf32>, %arg2: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map1, #map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg1, %arg0 : memref<256x1xf32>, memref<256x1024xf32>) - outs(%arg2: memref<256x1024xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %19 = arith.subf %in, %in_4 : f32 - linalg.yield %19 : f32 - } - return -} - -// CHECK-LABEL: sub_bcast_row_operand_0 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<256x1xf32>, %[[ARG2:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch sub [256, 1024, 1, 1024, 1024] flags = (bcast_row_in0) data_type = f32 -// CHECK: xsmm.binary sub(data_type = f32, %[[DIS]], %[[ARG1]], %[[ARG0]], %[[ARG2]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0)> - -func.func @sub_bcast_row_1(%arg0: memref<256x1024xf32>, %arg1: memref<256xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<256x1024xf32>, memref<256xf32>) - outs(%arg0 : memref<256x1024xf32>) { - ^bb0(%in: f32, %in_6: f32, %out: f32): - %6 = arith.subf %in, %in_6 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: sub_bcast_row_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<256xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch sub [256, 1024, 1024, 1, 1024] -// CHECK-SAME: flags = (bcast_row_in1) data_type = f32 -// CHECK: xsmm.binary sub(data_type = f32, %[[DIS]], %[[ARG0]], %{{.+}}, %[[ARG0]]) - -// 
----- - -#map = affine_map<(d0, d1) -> (d0, d1)> - -func.func @trivial_mul(%arg0: memref<256x1024xf32>, %arg1: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map, #map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg0 : memref<256x1024xf32>, memref<256x1024xf32>) - outs(%arg1: memref<256x1024xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %19 = arith.mulf %in, %in_4 : f32 - linalg.yield %19 : f32 - } - return -} - -// CHECK-LABEL: trivial_mul -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch mul [256, 1024, 1024, 1024, 1024] flags = (none) data_type = f32 -// CHECK: xsmm.binary mul(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG0]], %[[ARG1]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> ()> - -func.func @mul_bcast_scalar_operand_1(%arg0: memref<256x1024xf32>, %arg1: f32, %arg2: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<256x1024xf32>, f32) - outs(%arg2: memref<256x1024xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %19 = arith.mulf %in, %in_4 : f32 - linalg.yield %19 : f32 - } - return -} - -// CHECK-LABEL: mul_bcast_scalar_operand_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch mul [256, 1024, 1024, 1, 1024] flags = (bcast_scalar_in1) data_type = f32 -// CHECK: xsmm.binary mul(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> ()> - -func.func @mul_bcast_scalar_operand_0(%arg0: memref<256x1024xf32>, %arg1: f32, %arg2: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map1, #map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg1, %arg0 : f32, memref<256x1024xf32>) - outs(%arg2: memref<256x1024xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %19 = arith.mulf %in, %in_4 : f32 - linalg.yield %19 : f32 - } - return -} - -// CHECK-LABEL: mul_bcast_scalar_operand_0 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch mul [256, 1024, 1, 1024, 1024] flags = (bcast_scalar_in0) data_type = f32 -// CHECK: xsmm.binary mul(data_type = f32, %[[DIS]], %[[ARG1]], %[[ARG0]], %[[ARG2]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d1)> - -func.func @mul_bcast_col_operand_1(%arg0: memref<256x1024xf32>, %arg1: memref<1024xf32>, %arg2: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<256x1024xf32>, memref<1024xf32>) - outs(%arg2: memref<256x1024xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %19 = arith.mulf %in, %in_4 : f32 - linalg.yield %19 : f32 - } - return -} - -// CHECK-LABEL: mul_bcast_col_operand_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<1024xf32>, %[[ARG2:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch mul [256, 1024, 1024, 1024, 1024] flags = (bcast_col_in1) data_type = f32 -// CHECK: xsmm.binary mul(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d1)> - -func.func 
@mul_bcast_col_operand_0(%arg0: memref<256x1024xf32>, %arg1: memref<1024xf32>, %arg2: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map1, #map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg1, %arg0 : memref<1024xf32>, memref<256x1024xf32>) - outs(%arg2: memref<256x1024xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %19 = arith.mulf %in, %in_4 : f32 - linalg.yield %19 : f32 - } - return -} - -// CHECK-LABEL: mul_bcast_col_operand_0 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<1024xf32>, %[[ARG2:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch mul [256, 1024, 1024, 1024, 1024] flags = (bcast_col_in0) data_type = f32 -// CHECK: xsmm.binary mul(data_type = f32, %[[DIS]], %[[ARG1]], %[[ARG0]], %[[ARG2]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0, 0)> - -func.func @mul_bcast_row_operand_1(%arg0: memref<256x1024xf32>, %arg1: memref<256x1xf32>, %arg2: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<256x1024xf32>, memref<256x1xf32>) - outs(%arg2: memref<256x1024xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %19 = arith.mulf %in, %in_4 : f32 - linalg.yield %19 : f32 - } - return -} - -// CHECK-LABEL: mul_bcast_row_operand_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<256x1xf32>, %[[ARG2:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch mul [256, 1024, 1024, 1, 1024] flags = (bcast_row_in1) data_type = f32 -// CHECK: xsmm.binary mul(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0, 0)> - -func.func @mul_bcast_row_operand_0(%arg0: memref<256x1024xf32>, %arg1: memref<256x1xf32>, %arg2: memref<256x1024xf32>) { - linalg.generic { - indexing_maps = [#map1, #map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg1, %arg0 : memref<256x1xf32>, memref<256x1024xf32>) - outs(%arg2: memref<256x1024xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %19 = arith.mulf %in, %in_4 : f32 - linalg.yield %19 : f32 - } - return -} - -// CHECK-LABEL: mul_bcast_row_operand_0 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<256x1xf32>, %[[ARG2:.+]]: memref<256x1024xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch mul [256, 1024, 1, 1024, 1024] flags = (bcast_row_in0) data_type = f32 -// CHECK: xsmm.binary mul(data_type = f32, %[[DIS]], %[[ARG1]], %[[ARG0]], %[[ARG2]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0)> - -func.func @mul_bcast_row_1(%arg0: memref<256x1024xf32>, %arg1: memref<256xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<256x1024xf32>, memref<256xf32>) - outs(%arg0 : memref<256x1024xf32>) { - ^bb0(%in: f32, %in_6: f32, %out: f32): - %6 = arith.mulf %in, %in_6 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: mul_bcast_row_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<256x1024xf32>, %[[ARG1:.+]]: memref<256xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch mul [256, 1024, 1024, 1, 1024] -// CHECK-SAME: flags = (bcast_row_in1) data_type = f32 -// CHECK: xsmm.binary mul(data_type = f32, %[[DIS]], %[[ARG0]], %{{.+}}, %[[ARG0]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0)> - -func.func 
@mul_bcast_row_in0(%arg0: memref<10xf32>, %arg1: memref<10x10xf32>) { - linalg.generic { - indexing_maps = [#map1, #map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0, %arg1 : memref<10xf32>, memref<10x10xf32>) - outs(%arg1 : memref<10x10xf32>) { - ^bb0(%in: f32, %in_6: f32, %out: f32): - %6 = arith.mulf %in, %in_6 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: mul_bcast_row_in0 -// CHECK-SAME: %[[ARG0:.+]]: memref<10xf32>, %[[ARG1:.+]]: memref<10x10xf32> -// CHECK: %[[EXP:.+]] = memref.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [10, 1] : memref<10xf32> into memref<10x1xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch mul [10, 10, 1, 10, 10] flags = (bcast_row_in0) data_type = f32 -// CHECK: xsmm.binary mul(data_type = f32, %[[DIS]], %[[EXP]], %[[ARG1]], %[[ARG1]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0)> - -func.func @mul_bcast_row_in1(%arg0: memref<10xf32>, %arg1: memref<10x10xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg1, %arg0 : memref<10x10xf32>, memref<10xf32>) - outs(%arg1 : memref<10x10xf32>) { - ^bb0(%in: f32, %in_6: f32, %out: f32): - %6 = arith.mulf %in, %in_6 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: mul_bcast_row_in1 -// CHECK-SAME: %[[ARG0:.+]]: memref<10xf32>, %[[ARG1:.+]]: memref<10x10xf32> -// CHECK: %[[EXP:.+]] = memref.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [10, 1] : memref<10xf32> into memref<10x1xf32> -// CHECK: %[[DIS:.+]] = xsmm.binary.dispatch mul [10, 10, 10, 1, 10] flags = (bcast_row_in1) data_type = f32 -// CHECK: xsmm.binary mul(data_type = f32, %[[DIS]], %[[ARG1]], %[[EXP]], %[[ARG1]]) diff --git a/test/Conversion/LinalgToXsmm/linalg-to-brgemm.mlir b/test/Conversion/LinalgToXsmm/linalg-to-brgemm.mlir deleted file mode 100644 index a37ab961c..000000000 --- a/test/Conversion/LinalgToXsmm/linalg-to-brgemm.mlir +++ /dev/null @@ -1,416 +0,0 @@ -// RUN: tpp-opt %s -convert-linalg-to-xsmm -split-input-file | FileCheck %s - -#map = affine_map<(i, k, kk, j) -> (i, k, kk)> -#map1 = affine_map<(i, k, kk, j) -> (k, kk, j)> -#map2 = affine_map<(i, k, kk, j) -> (i, j)> - -func.func @brgemm(%arg0: memref<2x2x2x4xf32>, %arg1: memref<2x4x8x2xf32>, - %arg2: memref<2x2x8x2xf32>) { - scf.forall (%arg3, %arg4) in (2, 8) { - %subview = memref.subview %arg0[%arg3, 0, 0, 0] [1, 2, 2, 4] [1, 1, 1, 1] - : memref<2x2x2x4xf32> to memref<2x2x4xf32, strided<[8, 4, 1], offset: ?>> - %subview_2 = memref.subview %arg1[0, 0, %arg4, 0] [2, 4, 1, 2] [1, 1, 1, 1] - : memref<2x4x8x2xf32> to memref<2x4x2xf32, strided<[64, 16, 1], offset: ?>> - %subview_3 = memref.subview %arg2[%arg3, 0, %arg4, 0] [1, 2, 1, 2] [1, 1, 1, 1] - : memref<2x2x8x2xf32> to memref<2x2xf32, strided<[16, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["parallel", "reduction", "reduction", "parallel"]} - ins(%subview, %subview_2 : memref<2x2x4xf32, strided<[8, 4, 1], offset: ?>>, memref<2x4x2xf32, strided<[64, 16, 1], offset: ?>>) - outs(%subview_3 : memref<2x2xf32, strided<[16, 1], offset: ?>>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %1 = arith.mulf %in, %in_4 : f32 - %2 = arith.addf %out, %1 : f32 - linalg.yield %2 : f32 - } - } - return -} - -// CHECK-LABEL: brgemm -// CHECK-SAME: %[[ARG0:.+]]: memref<2x2x2x4xf32>, %[[ARG1:.+]]: memref<2x4x8x2xf32>, %[[ARG2:.+]]: memref<2x2x8x2xf32> -// CHECK: %[[C2:.+]] = arith.constant 2 : i64 -// CHECK: scf.forall (%[[ARG3:.+]], 
%[[ARG4:.+]]) in (2, 8) { -// CHECK: %[[SUB:.+]] = memref.subview %[[ARG0]][%[[ARG3]], 0, 0, 0] [1, 2, 2, 4] [1, 1, 1, 1] -// CHECK-SAME: : memref<2x2x2x4xf32> to memref<2x2x4xf32, strided<[8, 4, 1], offset: ?>> -// CHECK: %[[SUB_0:.+]] = memref.subview %[[ARG1]][0, 0, %[[ARG4]], 0] [2, 4, 1, 2] [1, 1, 1, 1] -// CHECK-SAME: : memref<2x4x8x2xf32> to memref<2x4x2xf32, strided<[64, 16, 1], offset: ?>> -// CHECK: %[[SUB_1:.+]] = memref.subview %[[ARG2]][%[[ARG3]], 0, %[[ARG4]], 0] [1, 2, 1, 2] [1, 1, 1, 1] -// CHECK-SAME: : memref<2x2x8x2xf32> to memref<2x2xf32, strided<[16, 1], offset: ?>> -// CHECK: %[[DIS:.+]] = xsmm.brgemm.dispatch [2, 2, 4, 8, 16, 16, 4, 64] flags = (none) data_type = f32 -// CHECK: xsmm.brgemm(data_type = f32, %[[DIS]], %[[SUB]], %[[SUB_0]], %[[SUB_1]], %[[C2]]) - -// m = 2 -// n = 2 -// k = 4 -// lda = 8 -// ldb = 16 -// ldc = 16 -// stride_a = 4 -// stride_b = 64 - -// ----- - -#map = affine_map<(i, j, kk, k) -> (kk, i, k)> -#map1 = affine_map<(i, j, kk, k) -> (kk, k, j)> -#map2 = affine_map<(i, j, kk, k) -> (i, j)> - -func.func @brgemm_1(%arg0: memref<9x4x5xf32>, %arg1: memref<9x5x8xf32>, %arg2: memref<4x8xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["parallel", "parallel", "reduction", "reduction"]} - ins(%arg0, %arg1 : memref<9x4x5xf32>, memref<9x5x8xf32>) - outs(%arg2: memref<4x8xf32>) { - ^bb0(%in: f32, %in_8: f32, %out: f32): - %5 = arith.mulf %in, %in_8 : f32 - %6 = arith.addf %out, %5 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: brgemm_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<9x4x5xf32>, %[[ARG1:.+]]: memref<9x5x8xf32>, %[[ARG2:.+]]: memref<4x8xf32> -// CHECK: %[[C9:.+]] = arith.constant 9 : i64 -// CHECK: %[[DIS:.+]] = xsmm.brgemm.dispatch [4, 8, 5, 5, 8, 8, 20, 40] flags = (none) data_type = f32 -// CHECK: xsmm.brgemm(data_type = f32, %0, %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[C9]]) - -// ----- - -#map = affine_map<(kk, k, i, j) -> (kk, i, k)> -#map1 = affine_map<(kk, k, i, j) -> (kk, k, j)> -#map2 = affine_map<(kk, k, i, j) -> (i, j)> - -func.func @brgemm_2(%arg0: memref<9x4x5xf32>, %arg1: memref<9x5x8xf32>, %arg2: memref<4x8xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "reduction", "parallel", "parallel"]} - ins(%arg0, %arg1 : memref<9x4x5xf32>, memref<9x5x8xf32>) - outs(%arg2: memref<4x8xf32>) { - ^bb0(%in: f32, %in_8: f32, %out: f32): - %5 = arith.mulf %in, %in_8 : f32 - %6 = arith.addf %out, %5 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: brgemm_2 -// CHECK-SAME: %[[ARG0:.+]]: memref<9x4x5xf32>, %[[ARG1:.+]]: memref<9x5x8xf32>, %[[ARG2:.+]]: memref<4x8xf32> -// CHECK: %[[C9:.+]] = arith.constant 9 : i64 -// CHECK: %[[DIS:.+]] = xsmm.brgemm.dispatch [4, 8, 5, 5, 8, 8, 20, 40] flags = (none) data_type = f32 -// CHECK: xsmm.brgemm(data_type = f32, %0, %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[C9]]) - -// ----- - -#map = affine_map<(kk, k, i, j) -> (kk, i, k)> -#map1 = affine_map<(kk, k, i, j) -> (kk, k, j)> -#map2 = affine_map<(kk, k, i, j) -> (i, j)> - -// non unit stride. 
-func.func @brgemm_3(%arg0: memref<9x4x5xf32>, %arg1: memref<9x5x8xf32, strided<[40, 8, 2], offset: ?>>, %arg2: memref<4x8xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "reduction", "parallel", "parallel"]} - ins(%arg0, %arg1 : memref<9x4x5xf32>, memref<9x5x8xf32, strided<[40, 8, 2], offset: ?>>) - outs(%arg2: memref<4x8xf32>) { - ^bb0(%in: f32, %in_8: f32, %out: f32): - %5 = arith.mulf %in, %in_8 : f32 - %6 = arith.addf %out, %5 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: brgemm_3 -// CHECK-NOT: xsmm.brgemm -// CHECK: linalg.generic - -// ----- - -#map = affine_map<(i, j, kk, k) -> (kk, i, k)> -#map1 = affine_map<(i, j, kk, k) -> (kk, j, k)> -#map2 = affine_map<(i, j, kk, k) -> (i, j)> - -func.func @brgemm_5(%arg0: memref<9x4x5xf32>, %arg1: memref<9x8x5xf32>, %arg2: memref<4x8xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["parallel", "parallel", "reduction", "reduction"]} - ins(%arg0, %arg1 : memref<9x4x5xf32>, memref<9x8x5xf32>) - outs(%arg2: memref<4x8xf32>) { - ^bb0(%in: f32, %in_8: f32, %out: f32): - %5 = arith.mulf %in, %in_8 : f32 - %6 = arith.addf %out, %5 : f32 - linalg.yield %6 : f32 - } - return -} - -// CHECK-LABEL: brgemm_5 -// CHECK-SAME: %[[ARG0:.+]]: memref<9x4x5xf32>, %[[ARG1:.+]]: memref<9x8x5xf32>, %[[ARG2:.+]]: memref<4x8xf32> -// CHECK: %[[C9:.+]] = arith.constant 9 : i64 -// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<9x5x8xf32> -// CHECK: linalg.transpose ins(%[[ARG1]] : memref<9x8x5xf32>) -// CHECK-SAME: outs(%[[ALLOC]] : memref<9x5x8xf32>) permutation = [0, 2, 1] -// CHECK: %[[DIS:.+]] = xsmm.brgemm.dispatch [4, 8, 5, 5, 8, 8, 20, 40] flags = (none) data_type = f32 -// CHECK: xsmm.brgemm(data_type = f32, %[[DIS]], %[[ARG0]], %[[ALLOC]], %[[ARG2]], %[[C9]]) - - -// ----- - -#map = affine_map<(i, j, k) -> (i, k)> -#map1 = affine_map<(i, j, k) -> (k, j)> -#map2 = affine_map<(i, j, k) -> (i, j)> - -func.func @gemm_1(%arg0: memref<64x32xf32>, %arg1: memref<32x64xf32>, %arg2: memref<64x64xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["parallel", "parallel", "reduction"]} - ins(%arg0, %arg1: memref<64x32xf32>, memref<32x64xf32>) - outs(%arg2: memref<64x64xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %1 = arith.mulf %in, %in_4 : f32 - %2 = arith.addf %out, %1 : f32 - linalg.yield %2 : f32 - } - return -} - -// CHECK-LABEL: gemm_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<64x32xf32>, %[[ARG1:.+]]: memref<32x64xf32>, %[[ARG2:.+]]: memref<64x64xf32> -// CHECK: %[[DIS:.+]] = xsmm.gemm.dispatch [64, 64, 32, 32, 64, 64] flags = (none) data_type = f32 -// CHECK: xsmm.gemm(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) - -// ----- - -#map = affine_map<(k, i, j) -> (i, k)> -#map1 = affine_map<(k, i, j) -> (k, j)> -#map2 = affine_map<(k, i, j) -> (i, j)> - -// permutation on outerloop is not relevant. 
-func.func @gemm_2(%arg0: memref<64x32xf32>, %arg1: memref<32x64xf32>, %arg2: memref<64x64xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel"]} - ins(%arg0, %arg1: memref<64x32xf32>, memref<32x64xf32>) - outs(%arg2: memref<64x64xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %1 = arith.mulf %in, %in_4 : f32 - %2 = arith.addf %out, %1 : f32 - linalg.yield %2 : f32 - } - return -} - -// CHECK-LABEL: gemm_2 -// CHECK-SAME: %[[ARG0:.+]]: memref<64x32xf32>, %[[ARG1:.+]]: memref<32x64xf32>, %[[ARG2:.+]]: memref<64x64xf32> -// CHECK: %[[DIS:.+]] = xsmm.gemm.dispatch [64, 64, 32, 32, 64, 64] flags = (none) data_type = f32 -// CHECK: xsmm.gemm(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) - -// ----- - -#map = affine_map<(i, j, k) -> (i, k)> -#map1 = affine_map<(i, j, k) -> (k, j)> -#map2 = affine_map<(i, j, k) -> (j, i)> - -func.func @gemm_3(%arg0: memref<64x32xf32>, %arg1: memref<32x64xf32>, %arg2: memref<64x64xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["parallel", "parallel", "reduction"]} - ins(%arg0, %arg1: memref<64x32xf32>, memref<32x64xf32>) - outs(%arg2: memref<64x64xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %1 = arith.mulf %in, %in_4 : f32 - %2 = arith.addf %out, %1 : f32 - linalg.yield %2 : f32 - } - return -} - -// CHECK-LABEL: gemm_3 -// CHECK-SAME: %[[ARG0:.+]]: memref<64x32xf32>, %[[ARG1:.+]]: memref<32x64xf32>, %[[ARG2:.+]]: memref<64x64xf32> -// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x64xf32> -// CHECK: %[[DIS_TRANS:.+]] = xsmm.unary.dispatch transpose [64, 32, 32, 64] flags = (none) data_type = f32 -// CHECK: xsmm.unary transpose(data_type = f32, %[[DIS_TRANS]], %[[ARG0]], %[[ALLOC]]) -// CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<64x32xf32> -// CHECK: %[[DIS_TRANS_1:.+]] = xsmm.unary.dispatch transpose [32, 64, 64, 32] flags = (none) data_type = f32 -// CHECK: xsmm.unary transpose(data_type = f32, %[[DIS_TRANS_1]], %[[ARG1]], %[[ALLOC_0]]) -// CHECK: %[[DIS:.+]] = xsmm.gemm.dispatch [64, 64, 32, 32, 64, 64] flags = (none) data_type = f32 -// CHECK: xsmm.gemm(data_type = f32, %[[DIS]], %[[ALLOC_0]], %[[ALLOC]], %[[ARG2]]) - -// ----- - -#map = affine_map<(i, j, k) -> (i, k)> -#map1 = affine_map<(i, j, k) -> (j, k)> -#map2 = affine_map<(i, j, k) -> (j, i)> - -func.func @gemm_4(%arg0: memref<64x32xf32>, %arg1: memref<64x32xf32>, %arg2: memref<64x64xf32>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["parallel", "parallel", "reduction"]} - ins(%arg0, %arg1: memref<64x32xf32>, memref<64x32xf32>) - outs(%arg2: memref<64x64xf32>) { - ^bb0(%in: f32, %in_4: f32, %out: f32): - %1 = arith.mulf %in, %in_4 : f32 - %2 = arith.addf %out, %1 : f32 - linalg.yield %2 : f32 - } - return -} - -// CHECK-LABEL: gemm_4 -// CHECK-SAME: %[[ARG0:.+]]: memref<64x32xf32>, %[[ARG1:.+]]: memref<64x32xf32>, %[[ARG2:.+]]: memref<64x64xf32> -// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<32x64xf32> -// CHECK: %[[DIS_TRAN:.+]] = xsmm.unary.dispatch transpose [64, 32, 32, 64] flags = (none) data_type = f32 -// CHECK: xsmm.unary transpose(data_type = f32, %[[DIS_TRAN]], %[[ARG0]], %[[ALLOC]]) -// CHECK: %[[DIS:.+]] = xsmm.gemm.dispatch [64, 64, 32, 32, 64, 64] flags = (none) data_type = f32 -// CHECK: xsmm.gemm(data_type = f32, %[[DIS]], %[[ARG1]], %[[ALLOC]], %[[ARG2]]) - -// ----- - -func.func @simple_brgemm(%arg0: memref<2x32x32xf32>, %arg1: memref<2x32x32xf32>, %arg2: memref<32x32xf32>) { - linalg.batch_reduce_matmul 
ins(%arg0, %arg1 : memref<2x32x32xf32>, memref<2x32x32xf32>) - outs(%arg2: memref<32x32xf32>) - return -} - -// CHECK-LABEL: simple_brgemm -// CHECK-SAME: %[[ARG0:.+]]: memref<2x32x32xf32>, %[[ARG1:.+]]: memref<2x32x32xf32>, %[[ARG2:.+]]: memref<32x32xf32> -// CHECK: %[[C2:.+]] = arith.constant 2 : i64 -// CHECK: %[[DIS:.+]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (none) data_type = f32 -// CHECK: xsmm.brgemm(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[C2]]) - -// ----- - -#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)> -#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4 floordiv 2, d3, d1)> -#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)> - -func.func @vnni_brgemm_interchanged(%arg0: memref<16x32x32xbf16>, %arg1: memref<16x16x32x2xbf16>, %arg2: memref<32x32xbf16>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : memref<16x32x32xbf16>, memref<16x16x32x2xbf16>) - outs(%arg2 : memref<32x32xbf16>) { - ^bb0(%in: bf16, %in_5: bf16, %out: bf16): - %5 = arith.mulf %in, %in_5 : bf16 - %6 = arith.addf %out, %5 : bf16 - linalg.yield %6 : bf16 - } - return -} - -// CHECK-LABEL: vnni_brgemm_interchanged -// CHECK-SAME: %[[ARG0:.+]]: memref<16x32x32xbf16>, %[[ARG1:.+]]: memref<16x16x32x2xbf16>, -// CHECK-SAME: %[[ARG2:.+]]: memref<32x32xbf16> -// CHECK: %[[C16:.+]] = arith.constant 16 : i64 -// CHECK: %[[DIS:.+]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] -// CHECK-SAME: flags = (vnni_b) data_type = bf16 -// CHECK: xsmm.brgemm(data_type = bf16, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[C16]]) - -// ----- - -#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> -#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3 floordiv 2, d2, d4)> -#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)> - -func.func @vnni_brgemm(%arg0: memref<16x32x32xbf16>, %arg1: memref<16x16x32x2xbf16>, %arg2: memref<32x32xbf16>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} - ins(%arg0, %arg1 : memref<16x32x32xbf16>, memref<16x16x32x2xbf16>) - outs(%arg2 : memref<32x32xbf16>) { - ^bb0(%in: bf16, %in_5: bf16, %out: bf16): - %5 = arith.mulf %in, %in_5 : bf16 - %6 = arith.addf %out, %5 : bf16 - linalg.yield %6 : bf16 - } - return -} - -// CHECK-LABEL: vnni_brgemm -// CHECK-SAME: %[[ARG0:.+]]: memref<16x32x32xbf16>, %[[ARG1:.+]]: memref<16x16x32x2xbf16>, -// CHECK-SAME: %[[ARG2:.+]]: memref<32x32xbf16> -// CHECK: %[[C16:.+]] = arith.constant 16 : i64 -// CHECK: %[[DIS:.+]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] -// CHECK-SAME: flags = (vnni_b) data_type = bf16 -// CHECK: xsmm.brgemm(data_type = bf16, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[C16]]) - -// ----- - -#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)> -#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3 floordiv 2, d2, d4)> -#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d2)> - -func.func @vnni_brgemm_strided(%arg0: memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>, - %arg1: memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>, - %arg2: memref<8x8xbf16>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction", "reduction"]} - ins(%arg0, %arg1 : memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>, memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>) - outs(%arg2 : 
memref<8x8xbf16>) { - ^bb0(%in: bf16, %in_9: bf16, %out: bf16): - %11 = arith.mulf %in, %in_9 : bf16 - %12 = arith.addf %out, %11 : bf16 - linalg.yield %12 : bf16 - } - return -} - -// CHECK-LABEL: vnni_brgemm_strided -// CHECK-SAME: %[[ARG0:.+]]: memref<8x8x8xbf16, strided<[64, 8, 1], offset: ?>>, -// CHECK-SAME: %[[ARG1:.+]]: memref<8x4x8x2xbf16, strided<[64, 16, 2, 1], offset: ?>>, -// CHECK-SAME: %[[ARG2:.+]]: memref<8x8xbf16> -// CHECK: %[[C8:.+]] = arith.constant 8 : i64 -// CHECK: %[[DIS:.+]] = xsmm.brgemm.dispatch [8, 8, 8, 8, 8, 8, 64, 64] -// CHECK-SAME: flags = (vnni_b) data_type = bf16 -// CHECK: xsmm.brgemm(data_type = bf16, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[C8]]) - -// ----- - -#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)> -#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4 floordiv 2, d3, d1)> -#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d2)> - -func.func @vnni_brgemm_require_transpose_on_C(%arg0: memref<16x32x32xbf16>, %arg1: memref<16x16x32x2xbf16>, %arg2: memref<32x32xbf16>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : memref<16x32x32xbf16>, memref<16x16x32x2xbf16>) - outs(%arg2 : memref<32x32xbf16>) { - ^bb0(%in: bf16, %in_5: bf16, %out: bf16): - %5 = arith.mulf %in, %in_5 : bf16 - %6 = arith.addf %out, %5 : bf16 - linalg.yield %6 : bf16 - } - return -} - -// CHECK-LABEL: vnni_brgemm_require_transpose_on_C -// CHECK-NOT: xsmm.brgemm -// CHECK: linalg.generic - -// ----- - -#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4)> -#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4 floordiv 5, d3, d1)> -#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d2)> - -func.func @brgemm_not_vnni(%arg0: memref<16x32x32xbf16>, %arg1: memref<16x16x32x2xbf16>, %arg2: memref<32x32xbf16>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : memref<16x32x32xbf16>, memref<16x16x32x2xbf16>) - outs(%arg2 : memref<32x32xbf16>) { - ^bb0(%in: bf16, %in_5: bf16, %out: bf16): - %5 = arith.mulf %in, %in_5 : bf16 - %6 = arith.addf %out, %5 : bf16 - linalg.yield %6 : bf16 - } - return -} - -// CHECK-LABEL: brgemm_not_vnni -// CHECK-NOT: xsmm.brgemm -// CHECK: linalg.generic diff --git a/test/Conversion/LinalgToXsmm/linalg-to-gemm.mlir b/test/Conversion/LinalgToXsmm/linalg-to-gemm.mlir deleted file mode 100644 index 89740180a..000000000 --- a/test/Conversion/LinalgToXsmm/linalg-to-gemm.mlir +++ /dev/null @@ -1,364 +0,0 @@ -// RUN: tpp-opt %s -convert-linalg-to-xsmm -split-input-file | FileCheck %s - -func.func @simple_gemm(%arg0: memref<32x64xf32, strided<[64, 1], offset: ?>>, - %arg1: memref<64x32xf32, strided<[32, 1], offset: ?>>, - %arg2: memref<32x32xf32, strided<[32, 1], offset: ?>>) { - linalg.matmul ins(%arg0, %arg1 : memref<32x64xf32, strided<[64, 1], offset: ?>>, - memref<64x32xf32, strided<[32, 1], offset: ?>>) - outs(%arg2 : memref<32x32xf32, strided<[32, 1], offset: ?>>) - return -} - -// CHECK-LABEL: simple_gemm -// CHECK-SAME: %[[ARG0:.+]]: memref<32x64xf32, strided<[64, 1], offset: ?>>, -// CHECK-SAME: %[[ARG1:.+]]: memref<64x32xf32, strided<[32, 1], offset: ?>>, -// CHECK-SAME: %[[ARG2:.+]]: memref<32x32xf32, strided<[32, 1], offset: ?>> -// CHECK: %[[DIS:.+]] = xsmm.gemm.dispatch [32, 32, 64, 64, 32, 32] flags = (none) data_type = f32 -// CHECK: xsmm.gemm(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) - -// ----- - 
-#map = affine_map<(d0, d1, d2) -> (d0, d1)> -#map1 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d2, d0)> - -func.func @mha_query_times_key(%arg0: memref<64x32x8x64xf32>, %arg1: memref<64x32x8x64xf32>, - %arg2: memref<64x8x32x32xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - scf.forall (%arg3, %arg4) in (64, 8) { - %subview = memref.subview %arg2[%arg3, %arg4, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<64x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - linalg.fill ins(%cst : f32) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>) - %subview_0 = memref.subview %arg0[%arg3, 0, %arg4, 0] [1, 32, 1, 64] [1, 1, 1, 1] : memref<64x32x8x64xf32> to memref<32x64xf32, strided<[512, 1], offset: ?>> - %subview_1 = memref.subview %arg1[%arg3, 0, %arg4, 0] [1, 32, 1, 64] [1, 1, 1, 1] : memref<64x32x8x64xf32> to memref<32x64xf32, strided<[512, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["parallel", "reduction", "parallel"]} - ins(%subview_0, %subview_1 : memref<32x64xf32, strided<[512, 1], offset: ?>>, memref<32x64xf32, strided<[512, 1], offset: ?>>) outs(%subview : memref<32x32xf32, strided<[32, 1], offset: ?>>) { - ^bb0(%in: f32, %in_2: f32, %out: f32): - %0 = arith.mulf %in, %in_2 : f32 - %1 = arith.addf %out, %0 : f32 - linalg.yield %1 : f32 - } - } - return -} - -// CHECK-LABEL: mha_query_times_key -// CHECK-SAME: %[[ARG0:.+]]: memref<64x32x8x64xf32>, %[[ARG1:.+]]: memref<64x32x8x64xf32>, %[[ARG2:.+]]: memref<64x8x32x32xf32> -// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: scf.forall (%[[ARG3:.+]], %[[ARG4:.+]]) in (64, 8) -// CHECK: %[[SUB:.+]] = memref.subview %[[ARG2]][%[[ARG3]], %[[ARG4]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] -// CHECK-SAME: : memref<64x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> -// CHECK: %[[FILL:.+]] = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = f32 -// CHECK: xsmm.unary zero(data_type = f32, %[[FILL]], %[[CST]], %[[SUB]]) -// CHECK: %[[SUB_0:.+]] = memref.subview %[[ARG0]][%[[ARG3]], 0, %[[ARG4]], 0] [1, 32, 1, 64] [1, 1, 1, 1] -// CHECK-SAME: : memref<64x32x8x64xf32> to memref<32x64xf32, strided<[512, 1], offset: ?>> -// CHECK: %[[SUB_1:.+]] = memref.subview %[[ARG1]][%[[ARG3]], 0, %[[ARG4]], 0] [1, 32, 1, 64] [1, 1, 1, 1] -// CHECK-SAME: : memref<64x32x8x64xf32> to memref<32x64xf32, strided<[512, 1], offset: ?>> -// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<64x32xf32> -// CHECK: %[[TRAN:.+]] = xsmm.unary.dispatch transpose [32, 64, 512, 32] flags = (none) data_type = f32 -// CHECK: xsmm.unary transpose(data_type = f32, %[[TRAN]], %[[SUB_0]], %[[ALLOC]]) -// CHECK: %[[GEMM:.+]] = xsmm.gemm.dispatch [32, 32, 64, 512, 32, 32] flags = (none) data_type = f32 -// CHECK: xsmm.gemm(data_type = f32, %[[GEMM]], %[[SUB_1]], %[[ALLOC]], %[[SUB]]) - -// ----- - -#map = affine_map<(d0, d1, d2) -> (d0, d1)> -#map1 = affine_map<(d0, d1, d2) -> (d1, d2)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d2)> - -func.func @mha_out_softmax_times_value(%arg0: memref<64x8x32x32xf32>, %arg1: memref<64x32x8x64xf32>, - %arg2: memref<64x32x8x64xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - scf.forall (%arg3, %arg4) in (64, 8) { - %subview = memref.subview %arg2[%arg3, 0, %arg4, 0] [1, 32, 1, 64] [1, 1, 1, 1] : memref<64x32x8x64xf32> to memref<32x64xf32, strided<[512, 1], offset: ?>> - linalg.fill ins(%cst : f32) outs(%subview : memref<32x64xf32, strided<[512, 1], offset: ?>>) - %subview_0 = memref.subview 
%arg0[%arg3, %arg4, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<64x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - %subview_1 = memref.subview %arg1[%arg3, 0, %arg4, 0] [1, 32, 1, 64] [1, 1, 1, 1] : memref<64x32x8x64xf32> to memref<32x64xf32, strided<[512, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["parallel", "reduction", "parallel"]} - ins(%subview_0, %subview_1 : memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x64xf32, strided<[512, 1], offset: ?>>) outs(%subview : memref<32x64xf32, strided<[512, 1], offset: ?>>) { - ^bb0(%in: f32, %in_2: f32, %out: f32): - %0 = arith.mulf %in, %in_2 : f32 - %1 = arith.addf %out, %0 : f32 - linalg.yield %1 : f32 - } - } - return -} - -// CHECK-LABEL: mha_out_softmax_times_value -// CHECK-SAME: %[[ARG0:.+]]: memref<64x8x32x32xf32>, %[[ARG1:.+]]: memref<64x32x8x64xf32>, %[[ARG2:.+]]: memref<64x32x8x64xf32> -// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: scf.forall (%[[ARG3:.+]], %[[ARG4:.+]]) in (64, 8) -// CHECK: %[[SUB:.+]] = memref.subview %[[ARG2]][%[[ARG3]], 0, %[[ARG4]], 0] [1, 32, 1, 64] [1, 1, 1, 1] -// CHECK-SAME: : memref<64x32x8x64xf32> to memref<32x64xf32, strided<[512, 1], offset: ?>> -// CHECK: %[[FILL:.+]] = xsmm.unary.dispatch zero [32, 64, 1, 512] flags = (bcast_scalar) data_type = f32 -// CHECK: xsmm.unary zero(data_type = f32, %[[FILL]], %[[CST]], %[[SUB]]) -// CHECK: %[[SUB_0:.+]] = memref.subview %[[ARG0]][%[[ARG3]], %[[ARG4]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] -// CHECK-SAME: : memref<64x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> -// CHECK: %[[SUB_1:.+]] = memref.subview %[[ARG1]][%[[ARG3]], 0, %[[ARG4]], 0] [1, 32, 1, 64] [1, 1, 1, 1] -// CHECK-SAME: : memref<64x32x8x64xf32> to memref<32x64xf32, strided<[512, 1], offset: ?>> -// CHECK: %[[GEMM:.+]] = xsmm.gemm.dispatch [32, 64, 32, 32, 512, 512] flags = (none) data_type = f32 -// CHECK: xsmm.gemm(data_type = f32, %[[GEMM]], %[[SUB_0]], %[[SUB_1]], %[[SUB]]) - -// ----- - -#map = affine_map<(d0, d1, d2) -> (d0, d1)> -#map1 = affine_map<(d0, d1, d2) -> (d1, d2)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d2)> - -func.func @mha_projection(%arg0: memref<512x8x64xf32>, %arg1: memref<64x32x512xf32>, %arg2: memref<64x32x8x64xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - scf.forall (%arg3, %arg4) in (64, 8) { - %subview = memref.subview %arg2[%arg3, 0, %arg4, 0] [1, 32, 1, 64] [1, 1, 1, 1] : memref<64x32x8x64xf32> to memref<32x64xf32, strided<[512, 1], offset: ?>> - linalg.fill ins(%cst : f32) outs(%subview : memref<32x64xf32, strided<[512, 1], offset: ?>>) - %subview_0 = memref.subview %arg1[%arg3, 0, 0] [1, 32, 512] [1, 1, 1] : memref<64x32x512xf32> to memref<32x512xf32, strided<[512, 1], offset: ?>> - %subview_1 = memref.subview %arg0[0, %arg4, 0] [512, 1, 64] [1, 1, 1] : memref<512x8x64xf32> to memref<512x64xf32, strided<[512, 1], offset: ?>> - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["parallel", "reduction", "parallel"]} - ins(%subview_0, %subview_1 : memref<32x512xf32, strided<[512, 1], offset: ?>>, memref<512x64xf32, strided<[512, 1], offset: ?>>) outs(%subview : memref<32x64xf32, strided<[512, 1], offset: ?>>) { - ^bb0(%in: f32, %in_2: f32, %out: f32): - %0 = arith.mulf %in, %in_2 : f32 - %1 = arith.addf %out, %0 : f32 - linalg.yield %1 : f32 - } - } - return -} - -// CHECK-LABEL: mha_projection -// CHECK-SAME: %[[ARG0:.+]]: memref<512x8x64xf32>, %[[ARG1:.+]]: memref<64x32x512xf32>, %[[ARG2:.+]]: memref<64x32x8x64xf32> -// 
CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: scf.forall (%[[ARG3:.+]], %[[ARG4:.+]]) in (64, 8) -// CHECK: %[[SUB:.+]] = memref.subview %[[ARG2]][%[[ARG3]], 0, %[[ARG4]], 0] [1, 32, 1, 64] [1, 1, 1, 1] -// CHECK-SAME: : memref<64x32x8x64xf32> to memref<32x64xf32, strided<[512, 1], offset: ?>> -// CHECK: %[[FILL:.+]] = xsmm.unary.dispatch zero [32, 64, 1, 512] flags = (bcast_scalar) data_type = f32 -// CHECK: xsmm.unary zero(data_type = f32, %[[FILL]], %[[CST]], %[[SUB]]) -// CHECK: %[[SUB_0:.+]] = memref.subview %[[ARG1]][%[[ARG3]], 0, 0] [1, 32, 512] [1, 1, 1] -// CHECK-SAME: : memref<64x32x512xf32> to memref<32x512xf32, strided<[512, 1], offset: ?>> -// CHECK: %[[SUB_1:.+]] = memref.subview %[[ARG0]][0, %[[ARG4]], 0] [512, 1, 64] [1, 1, 1] -// CHECK-SAME: : memref<512x8x64xf32> to memref<512x64xf32, strided<[512, 1], offset: ?>> -// CHECK: %[[GEMM:.+]] = xsmm.gemm.dispatch [32, 64, 512, 512, 512, 512] flags = (none) data_type = f32 -// CHECK: xsmm.gemm(data_type = f32, %[[GEMM]], %[[SUB_0]], %[[SUB_1]], %[[SUB]]) - -// ----- - -#map = affine_map<(d0, d1, d2, d3) -> (d1, d3)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d3 floordiv 2, d2, d0)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> - -func.func @square_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : memref<64x64xbf16, strided<[64, 1], offset: ?>>, memref<32x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return -} - -// CHECK-LABEL: square_vnni_gemm -// CHECK-SAME: %[[ARG0:.+]]: memref<64x64xbf16, strided<[64, 1], offset: ?>>, -// CHECK-SAME: %[[ARG1:.+]]: memref<32x64x2xbf16>, -// CHECK-SAME: %[[ARG2:.+]]: memref<64x64xbf16, strided<[64, 1], offset: ?>> -// CHECK: %[[DIS:.+]] = xsmm.gemm.dispatch [64, 64, 64, 64, 64, 64] flags = (vnni_b) data_type = bf16 -// CHECK: xsmm.gemm(data_type = bf16, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) - -// ----- - -#map = affine_map<(d0, d1, d2, d3) -> (d1, d3)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d3 floordiv 2, d2, d0)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d1)> - -// Require a transpose on C, before mapping to vnni Gemm. -func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : memref<64x64xbf16, strided<[64, 1], offset: ?>>, memref<32x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return -} - -// CHECK-LABEL: expect_not_to_match_vnni_gemm -// CHECK-NOT: xsmm.gemm -// CHECK: linalg.generic - -// ----- - -#map = affine_map<(d0, d1, d2, d3) -> (d1, d3)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d3 floordiv 5, d2, d0)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d1)> - -// Not VNNI layout. A factor of 5 is not VNNI. 
-func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : memref<64x64xbf16, strided<[64, 1], offset: ?>>, memref<32x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return -} - -// CHECK-LABEL: expect_not_to_match_vnni_gemm -// CHECK-NOT: xsmm.gemm -// CHECK: linalg.generic - -// ----- - -#map = affine_map<(d0, d1, d2, d3) -> (d3, d1)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d3 floordiv 2, d2, d0)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d1)> - -// Require a transpose on A, before mapping to vnni Gemm. -func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : memref<64x64xbf16, strided<[64, 1], offset: ?>>, memref<32x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return -} - -// CHECK-LABEL: expect_not_to_match_vnni_gemm -// CHECK-NOT: xsmm.gemm -// CHECK: linalg.generic - -// ----- - -#map = affine_map<(d0, d1, d2, d3) -> (d0, d2)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d2 floordiv 2, d1, d3)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)> - -// Make sure we can handle interchange on the iterators, but with the right -// access patterns. -func.func @vnni_gemm_interchanged(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<32x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["parallel", "parallel", "reduction", "reduction"]} - ins(%arg0, %arg1 : memref<64x64xbf16, strided<[64, 1], offset: ?>>, memref<32x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return -} - -// CHECK-LABEL: vnni_gemm_interchanged -// CHECK-SAME: %[[ARG0:.+]]: memref<64x64xbf16, strided<[64, 1], offset: ?>>, -// CHECK-SAME: %[[ARG1:.+]]: memref<32x64x2xbf16>, -// CHECK-SAME: %[[ARG2:.+]]: memref<64x64xbf16, strided<[64, 1], offset: ?>> -// CHECK: %[[DIS:.+]] = xsmm.gemm.dispatch [64, 64, 64, 64, 64, 64] flags = (vnni_b) data_type = bf16 -// CHECK: xsmm.gemm(data_type = bf16, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) - -// ----- - -#map = affine_map<(d0, d1, d2, d3) -> (d1, d3)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d3 floordiv 2)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d2, d1)> - -// Not VNNI layout. The VNNI is not innermost in the access pattern for B. 
-func.func @expect_not_to_match_vnni_gemm(%arg0: memref<64x64xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<2x64x32xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : memref<64x64xbf16, strided<[64, 1], offset: ?>>, memref<2x64x32xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return -} - -// CHECK-LABEL: expect_not_to_match_vnni_gemm -// CHECK-NOT: xsmm.gemm -// CHECK: linalg.generic - - -// ----- - -#map = affine_map<(d0, d1, d2, d3) -> (d1, d3)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d3 floordiv 2, d2, d0)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> - -func.func @non_square_vnni_gemm(%arg0: memref<64x16xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<8x64x2xbf16>, %arg2: memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : memref<64x16xbf16, strided<[64, 1], offset: ?>>, memref<8x64x2xbf16>) - outs(%arg2 : memref<64x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return -} - -// CHECK-LABEL: non_square_vnni_gemm -// CHECK-SAME: %[[ARG0:.+]]: memref<64x16xbf16, strided<[64, 1], offset: ?>>, -// CHECK-SAME: %[[ARG1:.+]]: memref<8x64x2xbf16>, -// CHECK-SAME: %[[ARG2:.+]]: memref<64x64xbf16, strided<[64, 1], offset: ?>> -// CHECK: %[[DIS:.+]] = xsmm.gemm.dispatch [64, 64, 16, 64, 64, 64] flags = (vnni_b) data_type = bf16 -// CHECK: xsmm.gemm(data_type = bf16, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) - -// ----- - -#map = affine_map<(d0, d1, d2, d3) -> (d1, d3)> -#map1 = affine_map<(d0, d1, d2, d3) -> (d3 floordiv 2, d2, d0)> -#map2 = affine_map<(d0, d1, d2, d3) -> (d1, d2)> - -func.func @non_square_vnni_gemm_1(%arg0: memref<4x16xbf16, strided<[64, 1], offset: ?>>, - %arg1: memref<8x64x2xbf16>, %arg2: memref<4x64xbf16, strided<[64, 1], offset: ?>>) { - linalg.generic { - indexing_maps = [#map, #map1, #map2], - iterator_types = ["reduction", "parallel", "parallel", "reduction"]} - ins(%arg0, %arg1 : memref<4x16xbf16, strided<[64, 1], offset: ?>>, memref<8x64x2xbf16>) - outs(%arg2 : memref<4x64xbf16, strided<[64, 1], offset: ?>>) { - ^bb0(%in: bf16, %in_2: bf16, %out: bf16): - %1 = arith.mulf %in, %in_2 : bf16 - %2 = arith.addf %out, %1 : bf16 - linalg.yield %2 : bf16 - } - return -} - -// CHECK-LABEL: non_square_vnni_gemm_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<4x16xbf16, strided<[64, 1], offset: ?>>, -// CHECK-SAME: %[[ARG1:.+]]: memref<8x64x2xbf16>, -// CHECK-SAME: %[[ARG2:.+]]: memref<4x64xbf16, strided<[64, 1], offset: ?>> -// CHECK: %[[DIS:.+]] = xsmm.gemm.dispatch [4, 64, 16, 64, 64, 64] flags = (vnni_b) data_type = bf16 -// CHECK: xsmm.gemm(data_type = bf16, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) diff --git a/test/Conversion/LinalgToXsmm/linalg-to-unary.mlir b/test/Conversion/LinalgToXsmm/linalg-to-unary.mlir deleted file mode 100644 index 217491ebe..000000000 --- a/test/Conversion/LinalgToXsmm/linalg-to-unary.mlir +++ /dev/null @@ -1,520 +0,0 @@ -// RUN: tpp-opt %s -convert-linalg-to-xsmm -split-input-file | FileCheck %s - -func.func @fill_op(%arg0: 
memref<32x32xf32>) { - %cst = arith.constant 0.0 : f32 - linalg.fill ins(%cst : f32) outs(%arg0 : memref<32x32xf32>) - return -} - -// CHECK-LABEL: fill_op -// CHECK-SAME: %[[ARG0:.+]]: memref<32x32xf32> -// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = f32 -// CHECK: xsmm.unary zero(data_type = f32, %[[DIS]], %[[CST]], %[[ARG0]]) - -// ----- - -func.func @fill_op(%arg0: tensor<32x32xf32>) -> tensor<32x32xf32> { - %cst = arith.constant 0.0 : f32 - %0 = linalg.fill ins(%cst : f32) outs(%arg0 : tensor<32x32xf32>) -> tensor<32x32xf32> - return %0 : tensor<32x32xf32> -} - -// CHECK-LABEL: fill_op -// CHECK: linalg.fill -// CHECK-NOT: xsmm.unary - -// ----- - -func.func @fill_op(%arg0: memref<32x32xf32, strided<[32, 2], offset: ?>>) { - %cst = arith.constant 0.0 : f32 - linalg.fill ins(%cst : f32) outs(%arg0 : memref<32x32xf32, strided<[32, 2], offset: ?>>) - return -} - -// CHECK-LABEL: fill_op -// CHECK: linalg.fill -// CHECK-NOT: xsmm.unary - -// ----- - -func.func @fill_op(%arg0: memref<32x32xf32>, %cst: f32) { - linalg.fill ins(%cst : f32) outs(%arg0 : memref<32x32xf32>) - return -} - -// CHECK-LABEL: fill_op -// CHECK: linalg.fill -// CHECK-NOT: xsmm.unary - -// ----- - -func.func @fill_op(%arg0: memref<32x32xbf16>) { - %cst = arith.constant 0.0 : bf16 - linalg.fill ins(%cst : bf16) outs(%arg0 : memref<32x32xbf16>) - return -} - -// CHECK-LABEL: fill_op -// CHECK-SAME: %[[ARG0:.+]]: memref<32x32xbf16> -// CHECK: %[[CST:.+]] = arith.constant 0.000000e+00 : bf16 -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = bf16 -// CHECK: xsmm.unary zero(data_type = bf16, %[[DIS]], %[[CST]], %[[ARG0]]) - -// ----- - -func.func @fill_op(%arg0: memref<32xf32>) { - %cst = arith.constant 0.0 : f32 - linalg.fill ins(%cst : f32) outs(%arg0 : memref<32xf32>) - return -} - -// CHECK-LABEL: fill_op -// CHECK: linalg.fill -// CHECK-NOT: xsmm.unary - -// ----- - -func.func @transpose_op(%arg0: memref<3x5xf32>, %arg1: memref<5x3xf32>) { - linalg.transpose ins(%arg0: memref<3x5xf32>) outs(%arg1: memref<5x3xf32>) permutation = [1, 0] - return -} - -// CHECK-LABEL: transpose_op -// CHECK-SAME: %[[ARG0:.+]]: memref<3x5xf32>, %[[ARG1:.+]]: memref<5x3xf32> -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch transpose [3, 5, 5, 3] flags = (none) data_type = f32 -// CHECK: xsmm.unary transpose(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]]) - -// ----- - -func.func @transpose_op(%arg0: memref<5x3x5xf32>, %arg1: memref<5x5x3xf32>) { - linalg.transpose ins(%arg0: memref<5x3x5xf32>) outs(%arg1: memref<5x5x3xf32>) permutation = [0, 2, 1] - return -} - -// CHECK-LABEL: transpose_op -// CHECK-NOT: xsmm.unary transpose -// CHECK: linalg.transpose - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> - -func.func @relu(%arg0: memref<4x3xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - linalg.generic { - indexing_maps = [#map, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : memref<4x3xf32>) outs(%arg0 : memref<4x3xf32>) { - ^bb0(%in: f32, %out: f32): - %0 = arith.maximumf %in, %cst : f32 - linalg.yield %0 : f32 - } - return -} - -// CHECK-LABEL: relu -// CHECK-SAME: %[[ARG0:.+]]: memref<4x3xf32> -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch relu [4, 3, 3, 3] flags = (none) data_type = f32 -// CHECK: xsmm.unary relu(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG0]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (0, d1)> 
- -func.func @relu_1(%arg0: memref<1x3xf32>, %arg1: memref<4x3xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - linalg.generic { - indexing_maps = [#map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : memref<1x3xf32>) outs(%arg1 : memref<4x3xf32>) { - ^bb0(%in: f32, %out: f32): - %0 = arith.maximumf %in, %cst : f32 - linalg.yield %0 : f32 - } - return -} - -// CHECK-LABEL: relu_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<1x3xf32>, %[[ARG1:.+]]: memref<4x3xf32> -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch relu [4, 3, 3, 3] flags = (bcast_col) data_type = f32 -// CHECK: xsmm.unary relu(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> (d0, 0)> - -func.func @relu_2(%arg0: memref<4x1xf32>, %arg1: memref<4x3xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - linalg.generic { - indexing_maps = [#map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : memref<4x1xf32>) outs(%arg1 : memref<4x3xf32>) { - ^bb0(%in: f32, %out: f32): - %0 = arith.maximumf %in, %cst : f32 - linalg.yield %0 : f32 - } - return -} - -// CHECK-LABEL: relu_2 -// CHECK-SAME: %[[ARG0:.+]]: memref<4x1xf32>, %[[ARG1:.+]]: memref<4x3xf32> -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch relu [4, 3, 1, 3] flags = (bcast_row) data_type = f32 -// CHECK: xsmm.unary relu(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> ()> - -func.func @relu_3(%arg0: f32, %arg1: memref<4x3xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - linalg.generic { - indexing_maps = [#map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : f32) outs(%arg1 : memref<4x3xf32>) { - ^bb0(%in: f32, %out: f32): - %0 = arith.maximumf %in, %cst : f32 - linalg.yield %0 : f32 - } - return -} - -// CHECK-LABEL: relu_3 -// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: memref<4x3xf32> -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch relu [4, 3, 1, 3] flags = (bcast_scalar) data_type = f32 -// CHECK: xsmm.unary relu(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]]) - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> -#map1 = affine_map<(d0, d1) -> ()> - -func.func @relu_4(%arg0: f32, %arg1: memref<4x3xf32, strided<[?, ?], offset: 0>>) { - %cst = arith.constant 0.000000e+00 : f32 - linalg.generic { - indexing_maps = [#map1, #map], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : f32) outs(%arg1 : memref<4x3xf32, strided<[?, ?], offset: 0>>) { - ^bb0(%in: f32, %out: f32): - %0 = arith.maximumf %in, %cst : f32 - linalg.yield %0 : f32 - } - return -} - -// CHECK-LABEL: relu_4 -// CHECK-NOT: xsmm.unary relu -// CHECK: linalg.generic - -// ----- - -#map = affine_map<(d0, d1) -> (d0, d1)> - -func.func @relu_5(%arg1: memref<4x3xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - linalg.generic { - indexing_maps = [#map], - iterator_types = ["parallel", "parallel"]} - outs(%arg1 : memref<4x3xf32>) { - ^bb0(%in: f32): - %0 = arith.maximumf %in, %cst : f32 - linalg.yield %0 : f32 - } - return -} - -// CHECK-LABEL: relu_5 -// CHECK-SAME: %[[ARG1:.+]]: memref<4x3xf32> -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch relu [4, 3, 3, 3] flags = (none) data_type = f32 -// CHECK: xsmm.unary relu(data_type = f32, %[[DIS]], %[[ARG1]], %[[ARG1]]) - -// ----- - -#map0 = affine_map<(d0, d1) -> (d1)> -#map1 = affine_map<(d0, d1) -> (d0, d1)> - -func.func @identity_1(%arg0: memref<512xf32>, %arg1: memref<128x512xf32>) { - linalg.generic { - indexing_maps = [#map0, #map1], - 
iterator_types = ["parallel", "parallel"]} - ins(%arg0 : memref<512xf32>) outs(%arg1 : memref<128x512xf32>) { - ^bb0(%arg9: f32, %arg10: f32): - linalg.yield %arg9 : f32 - } - return -} - -// CHECK-LABEL: identity_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<512xf32>, %[[ARG1:.+]]: memref<128x512xf32> -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch identity [128, 512, 512, 512] flags = (bcast_col) data_type = f32 -// CHECK: xsmm.unary identity(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]]) - -// ----- - -#map0 = affine_map<(d0, d1) -> (d0, d1)> - -func.func @identity_2(%arg0: memref<128x512xf32>, %arg1: memref<128x512xf32>) { - linalg.generic { - indexing_maps = [#map0, #map0], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : memref<128x512xf32>) outs(%arg1 : memref<128x512xf32>) { - ^bb0(%arg9: f32, %arg10: f32): - linalg.yield %arg9 : f32 - } - return -} - -// CHECK-LABEL: identity_2 -// CHECK-SAME: %[[ARG0:.+]]: memref<128x512xf32>, %[[ARG1:.+]]: memref<128x512xf32> -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch identity [128, 512, 512, 512] flags = (none) data_type = f32 -// CHECK: xsmm.unary identity(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]]) - -// ----- - -#map0 = affine_map<(d0, d1) -> (d0, 0)> -#map1 = affine_map<(d0, d1) -> (d0, d1)> - -func.func @identity_3(%arg0: memref<128x1xf32>, %arg1: memref<128x512xf32>) { - linalg.generic { - indexing_maps = [#map0, #map1], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : memref<128x1xf32>) outs(%arg1 : memref<128x512xf32>) { - ^bb0(%arg9: f32, %arg10: f32): - linalg.yield %arg9 : f32 - } - return -} - -// CHECK-LABEL: identity_3 -// CHECK-SAME: %[[ARG0:.+]]: memref<128x1xf32>, %[[ARG1:.+]]: memref<128x512xf32> -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch identity [128, 512, 1, 512] flags = (bcast_row) data_type = f32 -// CHECK: xsmm.unary identity(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]]) - -// ----- - -func.func @vnni_packing(%arg0 : memref<32x32xbf16, strided<[512, 1], offset: ?>>, - %arg1: memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) { - %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape[16, 2, 32] - : memref<32x32xbf16, strided<[512, 1], offset: ?>> - into memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>> - linalg.transpose ins(%expand_shape : memref<16x2x32xbf16, strided<[1024, 512, 1], offset: ?>>) - outs(%arg1 : memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) permutation = [0, 2, 1] - return -} - -// CHECK-LABEL: vnni_packing -// CHECK-SAME: %[[ARG0:.+]]: memref<32x32xbf16, strided<[512, 1], offset: ?>>, -// CHECK-SAME: %[[ARG1:.+]]: memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch vnni_2 [32, 32, 512, 32] flags = (none) data_type = bf16 -// CHECK: xsmm.unary vnni_2(data_type = bf16, %[[DIS]], %[[ARG0]], %[[ARG1]]) - -// ----- - -func.func @not_vnni_packing(%arg0 : memref<32x32xf32, strided<[512, 1], offset: ?>>, - %arg1: memref<16x32x2xf32, strided<[64, 2, 1], offset: ?>>) { - %expand_shape = memref.expand_shape %arg0 [[0, 1], [2]] output_shape[16, 2, 32] - : memref<32x32xf32, strided<[512, 1], offset: ?>> - into memref<16x2x32xf32, strided<[1024, 512, 1], offset: ?>> - linalg.transpose ins(%expand_shape : memref<16x2x32xf32, strided<[1024, 512, 1], offset: ?>>) - outs(%arg1 : memref<16x32x2xf32, strided<[64, 2, 1], offset: ?>>) permutation = [0, 2, 1] - return -} - -// CHECK-LABEL: not_vnni_packing -// CHECK-NOT: xsmm.unary vnni_2 - -// ----- - -#map = affine_map<(d0, d1) -> (d1)> -#map1 = affine_map<(d0, d1) -> (d0, 
d1)> - -func.func @identity_4(%arg0: memref<1024xbf16>, %arg1: memref<128x1024xbf16>) { - linalg.generic { - indexing_maps = [#map, #map1], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : memref<1024xbf16>) outs(%arg1 : memref<128x1024xbf16>) { - ^bb0(%arg2: bf16, %arg3: bf16): - linalg.yield %arg2 : bf16 - } - return -} - -// CHECK-LABEL: identity_4 -// CHECK-SAME: %[[ARG0:.+]]: memref<1024xbf16>, %[[ARG1:.+]]: memref<128x1024xbf16> -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch identity [128, 1024, 1024, 1024] flags = (bcast_col) data_type = bf16 -// CHECK: xsmm.unary identity(data_type = bf16, %[[DIS]], %[[ARG0]], %[[ARG1]]) - -// ----- - -#map = affine_map<(d0) -> (d0 * 32)> - -func.func @vnni_packing_1(%arg1: memref<128x128xbf16>, %arg2: memref<4x4x16x32x2xbf16>) { - scf.forall (%arg3, %arg4) in (4, 4) { - %0 = affine.apply #map(%arg4) - %1 = affine.apply #map(%arg3) - %subview = memref.subview %arg1[%0, %1] [32, 32] [1, 1] - : memref<128x128xbf16> to memref<32x32xbf16, strided<[128, 1], offset: ?>> - %subview_1 = memref.subview %arg2[%arg3, %arg4, 0, 0, 0] [1, 1, 16, 32, 2] [1, 1, 1, 1, 1] - : memref<4x4x16x32x2xbf16> to memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> - %expand_shape = memref.expand_shape %subview [[0, 1], [2]] output_shape[16, 2, 32] - : memref<32x32xbf16, strided<[128, 1], offset: ?>> into memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>> - linalg.transpose ins(%expand_shape : memref<16x2x32xbf16, strided<[256, 128, 1], offset: ?>>) - outs(%subview_1 : memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>>) - permutation = [0, 2, 1] - } - return -} - -// CHECK: #[[MAP:.+]] = affine_map<(d0) -> (d0 * 32)> - -// CHECK-LABEL: vnni_packing_1 -// CHECK-SAME: %[[ARG0:.+]]: memref<128x128xbf16>, %[[ARG1:.+]]: memref<4x4x16x32x2xbf16> -// CHECK: scf.forall (%[[ARG2:.+]], %[[ARG3:.+]]) in (4, 4) -// CHECK: %[[OFF:.+]] = affine.apply #[[MAP]](%[[ARG3]]) -// CHECK: %[[OFF_1:.+]] = affine.apply #[[MAP]](%[[ARG2]]) -// CHECK: %[[SUB:.+]] = memref.subview %[[ARG0]][%[[OFF]], %[[OFF_1]]] [32, 32] [1, 1] -// CHECK-SAME: : memref<128x128xbf16> to memref<32x32xbf16, strided<[128, 1], offset: ?>> -// CHECK: %[[SUB_0:.+]] = memref.subview %[[ARG1]][%[[ARG2]], %[[ARG3]], 0, 0, 0] [1, 1, 16, 32, 2] [1, 1, 1, 1, 1] -// CHECK-SAME: : memref<4x4x16x32x2xbf16> to memref<16x32x2xbf16, strided<[64, 2, 1], offset: ?>> -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch vnni_2 [32, 32, 128, 32] flags = (none) data_type = bf16 -// CHECK: xsmm.unary vnni_2(data_type = bf16, %[[DIS]], %[[SUB]], %[[SUB_0]]) - -// ----- - -#map3 = affine_map<(d0, d1) -> (d0, d1)> - -func.func @relu_no_input(%arg0: memref<10x10xf32>) { - %cst_1 = arith.constant 0.000000e+00 : f32 - linalg.generic { - indexing_maps = [#map3], iterator_types = ["parallel", "parallel"]} outs(%arg0: memref<10x10xf32>) { - ^bb0(%out: f32): - %13 = arith.maximumf %out, %cst_1 : f32 - linalg.yield %13 : f32 - } - return -} - -// CHECK-LABEL: relu_no_input -// CHECK-SAME: %[[ARG0:.+]]: memref<10x10xf32> -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch relu [10, 10, 10, 10] flags = (none) data_type = f32 -// CHECK: xsmm.unary relu(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG0]]) - -// ----- - -func.func @identity_5(%arg0 : memref<10xf32>, %arg1 : memref<10x10xf32>) { - linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : memref<10xf32>) outs(%arg1 : memref<10x10xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } - 
return -} - -// CHECK-LABEL: identity_5 -// CHECK-SAME: %[[ARG0:.+]]: memref<10xf32>, %[[ARG1:.+]]: memref<10x10xf32> -// CHECK: %[[EXP:.+]] = memref.expand_shape %[[ARG0]] {{\[}}[0, 1]] output_shape [10, 1] : memref<10xf32> into memref<10x1xf32> -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch identity [10, 10, 1, 10] flags = (bcast_row) data_type = f32 -// CHECK: xsmm.unary identity(data_type = f32, %[[DIS]], %[[EXP]], %[[ARG1]]) - -// ----- - -func.func @identity_bcast_row(%arg0 : memref<10x1xf32>, %arg1 : memref<10x10xf32>) { - linalg.generic { - indexing_maps = [affine_map<(d0, d1) -> (d0, 0)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : memref<10x1xf32>) outs(%arg1 : memref<10x10xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } - return -} - -// CHECK-LABEL: identity_bcast_row -// CHECK-SAME: %[[ARG0:.+]]: memref<10x1xf32>, %[[ARG1:.+]]: memref<10x10xf32> -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch identity [10, 10, 1, 10] flags = (bcast_row) data_type = f32 -// CHECK: xsmm.unary identity(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]]) - -// ----- - -func.func @identity_6(%arg0 : memref<10xf32>, %arg1 : memref<10xf32>) { - linalg.generic { - indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], - iterator_types = ["parallel"]} - ins(%arg0 : memref<10xf32>) outs(%arg1 : memref<10xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } - return -} - -// CHECK-LABEL: identity_6 -// CHECK-NOT: xsmm.unary identity - -// ----- - -#map1 = affine_map<(d0, d1) -> (d0, d1)> -#map2 = affine_map<(d0, d1) -> ()> - -// Rank zero is not matched. -func.func @identity7(%arg0: memref<f32>, %arg1: memref<6x9xf32>) { - linalg.generic { - indexing_maps = [#map2, #map1], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : memref<f32>) outs(%arg1 : memref<6x9xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } - return -} - -// CHECK-LABEL: identity7 -// CHECK: linalg.generic - -// ----- - -func.func @identity_8(%arg0: f32, %arg1: memref<6x9xf32>) { - linalg.fill ins(%arg0 : f32) outs(%arg1 : memref<6x9xf32>) - return -} - -// TODO: We would like to convert these fill ops too.
-// CHECK-LABEL: identity_8 -// CHECK: linalg.fill - -// ----- - -#map = affine_map<(d0, d1) -> (d0)> -#map1 = affine_map<(d0, d1) -> (d0, d1)> - -func.func @identity_9(%arg0: memref<6xf32, strided<[1]>>, %arg1: memref<6x9xf32>) { - linalg.generic { - indexing_maps = [#map, #map1], - iterator_types = ["parallel", "parallel"]} - ins(%arg0 : memref<6xf32, strided<[1]>>) outs(%arg1 : memref<6x9xf32>) { - ^bb0(%in: f32, %out: f32): - linalg.yield %in : f32 - } - return -} - -// CHECK-LABEL: identity_9 -// CHECK-SAME: %[[ARG0:.+]]: memref<6xf32, strided<[1]>>, %[[ARG1:.+]]: memref<6x9xf32> -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch identity [6, 9, 1, 9] flags = (bcast_row) data_type = f32 -// CHECK: xsmm.unary identity(data_type = f32, %[[DIS]], %{{.+}}, %[[ARG1]]) - -// ----- - -func.func @linalg_copy(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { - linalg.copy ins(%arg0 : memref<2x2xf32>) outs(%arg1 : memref<2x2xf32>) - return -} - -// CHECK-LABEL: linalg_copy -// CHECK-SAME: %[[ARG0:.+]]: memref<2x2xf32>, %[[ARG1:.+]]: memref<2x2xf32> -// CHECK: %[[DIS:.+]] = xsmm.unary.dispatch identity [2, 2, 2, 2] flags = (none) data_type = f32 -// CHECK: xsmm.unary identity(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]]) diff --git a/test/Dialect/Xsmm/xsmm-dispatch-invoke.mlir b/test/Dialect/Xsmm/xsmm-dispatch-invoke.mlir deleted file mode 100644 index 353408670..000000000 --- a/test/Dialect/Xsmm/xsmm-dispatch-invoke.mlir +++ /dev/null @@ -1,38 +0,0 @@ -// RUN: tpp-opt %s -verify-xsmm-calls -verify-diagnostics -split-input-file -// Make sure we do not emit any error here. - -func.func @identity(%arg0: memref<1x1xf32, strided<[8, 1], offset: ?>>, - %arg1: memref<1x1xf32, strided<[8, 1], offset: ?>>) { - %0 = xsmm.unary.dispatch identity [1, 1, 8, 8] flags = (none) data_type = f32 - xsmm.unary identity(data_type = f32, %0, %arg0, %arg1) - : (i64, memref<1x1xf32, strided<[8, 1], offset: ?>>, memref<1x1xf32, strided<[8, 1], offset: ?>>) -> () - return -} - -// ----- - -func.func @identity(%arg0: f32, - %arg1: memref<1x1xf32, strided<[8, 1], offset: ?>>) { - %0 = xsmm.unary.dispatch identity [1, 1, 8, 8] flags = (bcast_scalar) data_type = f32 - xsmm.unary identity(data_type = f32, %0, %arg0, %arg1) - : (i64, f32, memref<1x1xf32, strided<[8, 1], offset: ?>>) -> () - return -} - -// ----- - -func.func @identity(%arg0: f32, %arg1: memref<1x1xf32>) { - %0 = xsmm.unary.dispatch identity [1, 1, 1, 1] flags = (bcast_scalar) data_type = f32 - xsmm.unary identity(data_type = f32, %0, %arg0, %arg1) - : (i64, f32, memref<1x1xf32>) -> () - return -} - -// ----- - -func.func @gemm(%arg0: memref<3x6x2xbf16>, %arg1: memref<6x6xbf16>) { - %0 = xsmm.gemm.dispatch [6, 6, 6, 6, 6, 6] flags = (vnni_a) data_type = bf16 - xsmm.gemm(data_type = bf16, %0, %arg0, %arg0, %arg1) : - (i64, memref<3x6x2xbf16>, memref<3x6x2xbf16>, memref<6x6xbf16>) -> () - return -} diff --git a/test/Dialect/Xsmm/xsmm-invalid-dispatch-and-invoke.mlir b/test/Dialect/Xsmm/xsmm-invalid-dispatch-and-invoke.mlir deleted file mode 100644 index 3e5397529..000000000 --- a/test/Dialect/Xsmm/xsmm-invalid-dispatch-and-invoke.mlir +++ /dev/null @@ -1,244 +0,0 @@ -// RUN: tpp-opt %s -verify-xsmm-calls -verify-diagnostics -split-input-file - -func.func @gemm(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) { - %0 = xsmm.brgemm.dispatch [3, 3, 3, 3, 3, 3, 1, 1] flags = (none) data_type = f32 - // expected-error@+1 {{invalid dispatch operation}} - xsmm.gemm(data_type = f32, %0, %arg0, %arg0, %arg1) : - (i64, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - 
return -} - -// ----- - -func.func @gemm(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) { - %0 = xsmm.gemm.dispatch [3, 3, 3, 3, 3, 3] flags = (vnni_a) data_type = bf16 - // expected-error@+1 {{inconsistent data types}} - xsmm.gemm(data_type = f32, %0, %arg0, %arg0, %arg1) : - (i64, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @gemm(%arg0: memref<3x3xbf16>, %arg1: memref<3x3xbf16>) { - %0 = xsmm.gemm.dispatch [3, 3, 3, 3, 3, 3] flags = (vnni_a) data_type = bf16 - // expected-error@+1 {{expect VNNI layout for operand A or invalid VNNI_A flags}} - xsmm.gemm(data_type = bf16, %0, %arg0, %arg0, %arg1) : - (i64, memref<3x3xbf16>, memref<3x3xbf16>, memref<3x3xbf16>) -> () - return -} - -// ----- - -func.func @gemm(%arg0: memref<3x3xbf16>, %arg1: memref<3x3xbf16>) { - %0 = xsmm.gemm.dispatch [3, 3, 3, 3, 3, 3] flags = (vnni_b) data_type = bf16 - // expected-error@+1 {{expect VNNI layout for operand B or invalid VNNI_B flags}} - xsmm.gemm(data_type = bf16, %0, %arg0, %arg0, %arg1) : - (i64, memref<3x3xbf16>, memref<3x3xbf16>, memref<3x3xbf16>) -> () - return -} - -// ----- - -func.func @gemm(%arg0: memref<3x3xbf16>, %arg1: memref<3x3xbf16>) { - %0 = xsmm.gemm.dispatch [3, 3, 3, 3, 3, 3] flags = (vnni_c) data_type = bf16 - // expected-error@+1 {{expect VNNI layout for operand C or invalid VNNI_C flags}} - xsmm.gemm(data_type = bf16, %0, %arg0, %arg0, %arg1) : - (i64, memref<3x3xbf16>, memref<3x3xbf16>, memref<3x3xbf16>) -> () - return -} - -// ----- - -func.func @brgemm(%arg0: memref<1x3x3xf32>, %arg1: memref<3x3xf32>) { - %0 = xsmm.gemm.dispatch [3, 3, 3, 3, 3, 3] flags = (none) data_type = f32 - %1 = arith.constant 1 : i64 - // expected-error@+1 {{invalid dispatch operation}} - xsmm.brgemm(data_type = f32, %0, %arg0, %arg0, %arg1, %1) : - (i64, memref<1x3x3xf32>, memref<1x3x3xf32>, memref<3x3xf32>, i64) -> () - return -} - -// ----- - -func.func @brgemm(%arg0: memref<1x3x3xf32>, %arg1: memref<3x3xf32>) { - %0 = xsmm.brgemm.dispatch [3, 3, 3, 3, 3, 3, 1, 1] flags = (none) data_type = bf16 - // expected-error@+1 {{inconsistent data types}} - xsmm.brgemm(data_type = f32, %0, %arg0, %arg0, %arg1, %0) : - (i64, memref<1x3x3xf32>, memref<1x3x3xf32>, memref<3x3xf32>, i64) -> () - return -} - -// ----- - -func.func @brgemm(%arg0: memref<1x3x3xbf16>, %arg1: memref<3x3xbf16>, %batch : i64) { - %0 = xsmm.brgemm.dispatch [3, 3, 3, 3, 3, 3, 1, 1] flags = (vnni_a) data_type = bf16 - // expected-error@+1 {{expect VNNI layout for operand A or invalid VNNI_A flags}} - xsmm.brgemm(data_type = bf16, %0, %arg0, %arg0, %arg1, %batch) : - (i64, memref<1x3x3xbf16>, memref<1x3x3xbf16>, memref<3x3xbf16>, i64) -> () - return -} - -// ----- - -func.func @brgemm(%arg0: memref<1x3x3xbf16>, %arg1: memref<3x3xbf16>, %batch : i64) { - %0 = xsmm.brgemm.dispatch [3, 3, 3, 3, 3, 3, 1, 1] flags = (vnni_b) data_type = bf16 - // expected-error@+1 {{expect VNNI layout for operand B or invalid VNNI_B flags}} - xsmm.brgemm(data_type = bf16, %0, %arg0, %arg0, %arg1, %batch) : - (i64, memref<1x3x3xbf16>, memref<1x3x3xbf16>, memref<3x3xbf16>, i64) -> () - return -} - -// ----- - -func.func @brgemm(%arg0: memref<1x3x3xbf16>, %arg1: memref<3x3xbf16>, %batch : i64) { - %0 = xsmm.brgemm.dispatch [3, 3, 3, 3, 3, 3, 1, 1] flags = (vnni_c) data_type = bf16 - // expected-error@+1 {{expect VNNI layout for operand C or invalid VNNI_C flags}} - xsmm.brgemm(data_type = bf16, %0, %arg0, %arg0, %arg1, %batch) : - (i64, memref<1x3x3xbf16>, memref<1x3x3xbf16>, memref<3x3xbf16>, i64) -> () - return -} - -// 
----- - -func.func @fused_brgemm(%arg0: memref<1x3x3xf32>, %arg1: memref<3x3xf32>, %arg2: memref<3xf32>) { - %0 = xsmm.brgemm.dispatch [3, 3, 3, 3, 3, 3, 1, 1] flags = (none) data_type = f32 - // expected-error@+1 {{invalid dispatch operation}} - xsmm.fused_brgemm(data_type = f32, %0, %arg0, %arg0, %arg1, %arg2, %0) : - (i64, memref<1x3x3xf32>, memref<1x3x3xf32>, memref<3x3xf32>, memref<3xf32>, i64) -> () - return -} - -// ----- - -func.func @brgemm(%arg0: memref<1x3x3xf32>, %arg1: memref<3x3xf32>) { - %0 = xsmm.fused_brgemm.dispatch [3, 3, 3, 3, 3, 3, 1, 1] [add, relu] - flags = (none) binary_flags = (none) unary_flags = (none) data_type = bf16 - // expected-error@+1 {{inconsistent data types}} - xsmm.fused_brgemm(data_type = f32, %0, %arg0, %arg0, %arg1, %arg1, %0) : - (i64, memref<1x3x3xf32>, memref<1x3x3xf32>, memref<3x3xf32>, memref<3x3xf32>, i64) -> () - return -} - -// ----- - -func.func @brgemm(%arg0: memref<1x3x3xbf16>, %arg1: memref<3x3xbf16>) { - %0 = xsmm.fused_brgemm.dispatch [3, 3, 3, 3, 3, 3, 1, 1] [add, relu] - flags = (vnni_a) binary_flags = (none) unary_flags = (none) data_type = bf16 - // expected-error@+1 {{expect VNNI layout for operand A or invalid VNNI_A flags}} - xsmm.fused_brgemm(data_type = bf16, %0, %arg0, %arg0, %arg1, %arg1, %0) : - (i64, memref<1x3x3xbf16>, memref<1x3x3xbf16>, memref<3x3xbf16>, memref<3x3xbf16>, i64) -> () - return -} - -// ----- - -func.func @brgemm(%arg0: memref<1x3x3xbf16>, %arg1: memref<3x3xbf16>) { - %0 = xsmm.fused_brgemm.dispatch [3, 3, 3, 3, 3, 3, 1, 1] [add, relu] - flags = (vnni_b) binary_flags = (none) unary_flags = (none) data_type = bf16 - // expected-error@+1 {{expect VNNI layout for operand B or invalid VNNI_B flags}} - xsmm.fused_brgemm(data_type = bf16, %0, %arg0, %arg0, %arg1, %arg1, %0) : - (i64, memref<1x3x3xbf16>, memref<1x3x3xbf16>, memref<3x3xbf16>, memref<3x3xbf16>, i64) -> () - return -} - -// ----- - -func.func @brgemm(%arg0: memref<1x3x3xbf16>, %arg1: memref<3x3xbf16>) { - %0 = xsmm.fused_brgemm.dispatch [3, 3, 3, 3, 3, 3, 1, 1] [add, relu] - flags = (vnni_c) binary_flags = (none) unary_flags = (none) data_type = bf16 - // expected-error@+1 {{expect VNNI layout for operand C or invalid VNNI_C flags}} - xsmm.fused_brgemm(data_type = bf16, %0, %arg0, %arg0, %arg1, %arg1, %0) : - (i64, memref<1x3x3xbf16>, memref<1x3x3xbf16>, memref<3x3xbf16>, memref<3x3xbf16>, i64) -> () - return -} - -// ----- - -func.func @unary(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) { - %0 = xsmm.brgemm.dispatch [3, 3, 3, 3, 3, 3, 1, 1] flags = (none) data_type = f32 - // expected-error@+1 {{invalid dispatch operation}} - xsmm.unary relu(data_type = f32, %0, %arg0, %arg1) : - (i64, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @unary(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) { - %0 = xsmm.unary.dispatch relu [3, 3, 3, 3] flags = (none) data_type = bf16 - // expected-error@+1 {{inconsistent data types}} - xsmm.unary relu(data_type = f32, %0, %arg0, %arg1) : - (i64, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @unary(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) { - %0 = xsmm.unary.dispatch identity [3, 3, 3, 3] flags = (none) data_type = f32 - // expected-error@+1 {{inconsistent callee kind}} - xsmm.unary relu(data_type = f32, %0, %arg0, %arg1) : - (i64, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @unary(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) { - %0 = xsmm.unary.dispatch relu [3, 3, 3, 3] flags = (bcast_scalar) data_type = f32 - // 
expected-error@+1 {{invalid 'bcast_scalar' flag for input}} - xsmm.unary relu(data_type = f32, %0, %arg0, %arg1) : - (i64, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @binary(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) { - %0 = xsmm.brgemm.dispatch [3, 3, 3, 3, 3, 3, 1, 1] flags = (none) data_type = f32 - // expected-error@+1 {{invalid dispatch operation}} - xsmm.binary add(data_type = f32, %0, %arg0, %arg0, %arg1) : - (i64, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @binary(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) { - %0 = xsmm.binary.dispatch add [3, 3, 3, 3, 3] flags = (none) data_type = bf16 - // expected-error@+1 {{inconsistent data types}} - xsmm.binary add(data_type = f32, %0, %arg0, %arg0, %arg1) : - (i64, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @binary(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) { - %0 = xsmm.binary.dispatch sub [3, 3, 3, 3, 3] flags = (none) data_type = f32 - // expected-error@+1 {{inconsistent callee kind}} - xsmm.binary add(data_type = f32, %0, %arg0, %arg0, %arg1) : - (i64, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @binary(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) { - %0 = xsmm.binary.dispatch add [3, 3, 3, 3, 3] flags = (bcast_scalar_in0) data_type = f32 - // expected-error@+1 {{invalid 'bcast_scalar_in0' flag for lhs input}} - xsmm.binary add(data_type = f32, %0, %arg0, %arg0, %arg1) : - (i64, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @binary(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) { - %0 = xsmm.binary.dispatch add [3, 3, 3, 3, 3] flags = (bcast_scalar_in1) data_type = f32 - // expected-error@+1 {{invalid 'bcast_scalar_in1' flag for rhs input}} - xsmm.binary add(data_type = f32, %0, %arg0, %arg0, %arg1) : - (i64, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} diff --git a/test/Dialect/Xsmm/xsmm-invalid.mlir b/test/Dialect/Xsmm/xsmm-invalid.mlir deleted file mode 100644 index 58b7858ad..000000000 --- a/test/Dialect/Xsmm/xsmm-invalid.mlir +++ /dev/null @@ -1,426 +0,0 @@ -// RUN: tpp-opt -split-input-file -verify-diagnostics %s - -func.func @gemm_dispatch() -> i64 { - // m, n, k, lda, ldb, ldc - // expected-error@+1 {{expect lda to be >= of dimension k}} - %0 = xsmm.gemm.dispatch [1, 2, 3, 1, 5, 6] flags = (none) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @gemm_dispatch() -> i64 { - // m, n, k, lda, ldb, ldc - // expected-error@+1 {{expect ldb to be >= of dimension n}} - %0 = xsmm.gemm.dispatch [1, 2, 3, 4, 1, 6] flags = (none) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @gemm_dispatch() -> i64 { - // m, n, k, lda, ldb, ldc - // expected-error@+1 {{expect ldc to be >= of dimension n}} - %0 = xsmm.gemm.dispatch [1, 2, 3, 4, 5, 1] flags = (none) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @gemm_dispatch() -> i64 { - // expected-error@+1 {{VNNI flags but type is not bf16}} - %0 = xsmm.gemm.dispatch [1, 2, 3, 4, 5, 6] flags = (vnni_a) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @gemm_dispatch() -> i64 { - // expected-error@+1 {{VNNI flags but type is not bf16}} - %0 = xsmm.gemm.dispatch [1, 2, 3, 4, 5, 6] flags = (vnni_b) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @gemm_dispatch() -> i64 { - // expected-error@+1 {{VNNI flags but type is not bf16}} - %0 = xsmm.gemm.dispatch [1, 2, 3, 4, 
5, 6] flags = (vnni_c) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @gemm_dispatch() -> i64 { - // expected-error@+1 {{VNNI flags but type is not bf16}} - %0 = xsmm.gemm.dispatch [1, 2, 3, 4, 5, 6] flags = (vnni_a, vnni_c) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @gemm_dispatch() -> i64 { - // expected-error@+1 {{expected flags to be unique}} - %0 = xsmm.gemm.dispatch [1, 2, 3, 4, 5, 6] flags = (vnni_a, vnni_a) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @gemm_dispatch() -> i64 { - // expected-error@+1 {{'none' flags conflicts with others}} - %0 = xsmm.gemm.dispatch [1, 2, 3, 4, 5, 6] flags = (none, vnni_a) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @gemm_dispatch() -> i64 { - // expected-error@+1 {{expect 6 args but got: 5}} - %0 = xsmm.gemm.dispatch [1, 2, 3, 4, 5] flags = (none) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @gemm_dispatch() -> i64 { - // expected-error@+1 {{failed to satisfy constraint: i64 dense array attribute whose value is non-negative}} - %0 = xsmm.gemm.dispatch [-3, 2, 1, 3, 2] flags = (none) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @brgemm_dispatch() -> i64 { - // expected-error@+1 {{VNNI flags but type is not bf16}} - %0 = xsmm.brgemm.dispatch [1, 2, 3, 4, 5, 6, 1, 1] flags = (vnni_a) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @brgemm_dispatch() -> i64 { - // expected-error@+1 {{VNNI flags but type is not bf16}} - %0 = xsmm.brgemm.dispatch [1, 2, 3, 4, 5, 6, 1, 1] flags = (vnni_b) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @brgemm_dispatch() -> i64 { - // expected-error@+1 {{VNNI flags but type is not bf16}} - %0 = xsmm.brgemm.dispatch [1, 2, 3, 4, 5, 6, 1, 1] flags = (vnni_c) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @brgemm_dispatch() -> i64 { - // expected-error@+1 {{VNNI flags but type is not bf16}} - %0 = xsmm.brgemm.dispatch [1, 2, 3, 4, 5, 6, 1, 1] flags = (vnni_a, vnni_c) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @brgemm_dispatch() -> i64 { - // expected-error@+1 {{expected flags to be unique}} - %0 = xsmm.brgemm.dispatch [1, 2, 3, 4, 5, 6, 1, 1] flags = (vnni_a, vnni_a) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @brgemm_dispatch() -> i64 { - // expected-error@+1 {{'none' flags conflicts with others}} - %0 = xsmm.brgemm.dispatch [1, 2, 3, 4, 5, 6, 1, 1] flags = (none, vnni_a) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @brgemm_dispatch() -> i64 { - // expected-error@+1 {{expect 8 args but got: 5}} - %0 = xsmm.brgemm.dispatch [1, 2, 3, 4, 5] flags = (none) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @brgemm_dispatch() -> i64 { - // expected-error@+1 {{failed to satisfy constraint: i64 dense array attribute whose value is non-negative}} - %0 = xsmm.brgemm.dispatch [3, 2, -1, 3, 2] flags = (none) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @unary_dispatch() -> i64 { - // expected-error@+1 {{failed to satisfy constraint: i64 dense array attribute whose value is non-negative}} - %0 = xsmm.unary.dispatch relu [3, 2, 1, -3] flags = (none) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @unary_dispatch() -> i64 { - // expected-error@+1 {{op expect 4 args but got: 3}} - %0 = xsmm.unary.dispatch relu [3, 2, 1] flags = (none) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @binary_dispatch() -> i64 { - // expected-error@+1 {{failed to satisfy constraint: i64 dense array attribute 
whose value is non-negative}} - %0 = xsmm.binary.dispatch add [3, 2, 1, 3, -2] flags = (none) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @binary_dispatch() -> i64 { - // expected-error@+1 {{op expect 5 args but got: 3}} - %0 = xsmm.binary.dispatch add [3, 2, 1] flags = (none) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @fused_dispatch() -> i64 { - // expected-error@+1 {{op expect 8 args but got: 3}} - %0 = xsmm.fused_brgemm.dispatch [3, 2, 1] [add, relu] - flags = (none) binary_flags = (none) unary_flags = (none) data_type = f32 - return %0 : i64 -} - -// ----- - -func.func @fused_dispatch() -> i64 { - // expected-error@+1 {{op expected flags to be unique}} - %0 = xsmm.fused_brgemm.dispatch [1, 2, 3, 4, 5, 6, 1, 1] [add, relu] - flags = (vnni_a, vnni_a) binary_flags = (none) unary_flags = (none) data_type = bf16 - return %0 : i64 -} - -// ----- - -func.func @fused_dispatch() -> i64 { - // expected-error@+1 {{op expected binary_flags to be unique}} - %0 = xsmm.fused_brgemm.dispatch [1, 2, 3, 4, 5, 6, 1, 1] [add, relu] - flags = (vnni_a) binary_flags = (none, none) unary_flags = (none) data_type = bf16 - return %0 : i64 -} - -// ----- - -func.func @fused_dispatch() -> i64 { - // expected-error@+1 {{op expected unary_flags to be unique}} - %0 = xsmm.fused_brgemm.dispatch [1, 2, 3, 4, 5, 6, 1, 1] [add, relu] - flags = (vnni_a) binary_flags = (none) unary_flags = (none, none) data_type = bf16 - return %0 : i64 -} - -// ----- - -func.func @fused_brgemm_none_kind_with_flags() -> i64 { - // expected-error@+1 {{invalid binary flags for kind none}} - %0 = xsmm.fused_brgemm.dispatch [1, 2, 3, 4, 5, 6, 1, 1] [none, relu] - flags = (vnni_a) binary_flags = (bcast_col_in0) unary_flags = (none) data_type = bf16 - return %0 : i64 -} - -// ----- - -func.func @fused_brgemm_none_kind_with_flags() -> i64 { - // expected-error@+1 {{invalid unary flags for kind none}} - %0 = xsmm.fused_brgemm.dispatch [1, 2, 3, 4, 5, 6, 1, 1] [none, none] - flags = (vnni_a) binary_flags = (none) unary_flags = (bcast_scalar) data_type = bf16 - return %0 : i64 -} - -// ----- - -func.func @gemm_invoke(%arg0: i64, %arg1: memref<3x3xf32>, %arg2: memref<3x3xf32>, - %arg3: memref<3x3xf32>) { - // expected-error@+1 {{expect bf16 but got: 'f32' for operand at index: 1}} - xsmm.gemm(data_type = bf16, %arg0, %arg1, %arg2, %arg3) - : (i64, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @gemm_invoke(%arg0: i64, %arg1: memref<3x3xbf16>, %arg2: memref<3x3xbf16>, - %arg3: memref<3x3xbf16>) { - // expected-error@+1 {{expect f32 but got: 'bf16' for operand at index: 1}} - xsmm.gemm(data_type = f32, %arg0, %arg1, %arg2, %arg3) - : (i64, memref<3x3xbf16>, memref<3x3xbf16>, memref<3x3xbf16>) -> () - return -} - -// ----- - -func.func @gemm_invoke(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>, %arg2: memref<3x3xf32>, - %arg3: memref<3x3xf32>) { - // expected-error@+1 {{expect an i64 but got 'memref<3x3xf32>' for operand 0 (dispatch)}} - xsmm.gemm(data_type = f32, %arg0, %arg1, %arg2, %arg3) - : (memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @gemm_invoke(%arg0: f32, %arg1: memref<3x3xf32>, %arg2: memref<3x3xf32>, - %arg3: memref<3x3xf32>) { - // expected-error@+1 {{op operand #0 must be variadic of 2D/3D static memref of 32-bit float or bfloat16 type values or 64-bit signless integer, but got 'f32'}} - xsmm.gemm(data_type = f32, %arg0, %arg1, %arg2, %arg3) - : (f32, memref<3x3xf32>, memref<3x3xf32>, 
memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @gemm_invoke(%arg0: i64, %arg1: memref<1x1x3x3xf32>, %arg2: memref<3x3xf32>, - %arg3: memref<3x3xf32>) { - // expected-error@+1 {{op operand #1 must be variadic of 2D/3D static memref of 32-bit float or bfloat16 type values or 64-bit signless integer, but got 'memref<1x1x3x3xf32>'}} - xsmm.gemm(data_type = f32, %arg0, %arg1, %arg2, %arg3) - : (i64, memref<1x1x3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @gemm_invoke(%arg0: i64, %arg1: memref<1x3x3xf32>, %arg2: memref<3x3xf32>, - %arg3: memref<3x3xf32>) { - // expected-error@+1 {{expect VNNI layout for operand: 1}} - xsmm.gemm(data_type = f32, %arg0, %arg1, %arg2, %arg3) - : (i64, memref<1x3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @gemm_invoke(%arg0: i64, %arg1: memref<3x3xf32>, %arg2: memref<3x3xf32>) { - // expected-error@+1 {{expect 4 inputs but got 3}} - xsmm.gemm(data_type = f32, %arg0, %arg1, %arg2) - : (i64, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @brgemm_invoke(%arg0: i64, %arg1: memref<3x3xf32>, %arg2: memref<3x3xf32>) { - // expected-error@+1 {{expect 5 inputs but got 3}} - xsmm.brgemm(data_type = f32, %arg0, %arg1, %arg2) - : (i64, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @brgemm_invoke(%arg0: i64, %arg1: memref<3x3xf32>, %arg2: memref<2x3x3xf32>, - %arg3: memref<3x3xf32>, %arg4: memref<2xf32>) { - // expected-error@+1 {{operand #4 must be variadic of 2D/3D/4D static memref of 32-bit float or bfloat16 type values or 64-bit signless integer, but got 'memref<2xf32>'}} - xsmm.brgemm(data_type = f32, %arg0, %arg1, %arg2, %arg3, %arg4) - : (i64, memref<3x3xf32>, memref<2x3x3xf32>, memref<3x3xf32>, memref<2xf32>) -> () - return -} - -// ----- - -func.func @brgemm_invoke(%arg0: i64, %arg1: memref<2x3x3xf32>, %arg2: memref<2x3x3xf32>, %arg3: memref<2x3x3xf32>) { - // expected-error@+1 {{expect a 2d or 3d VNNI layout for operand: 3}} - xsmm.brgemm(data_type = f32, %arg0, %arg1, %arg2, %arg3, %arg0) - : (i64, memref<2x3x3xf32>, memref<2x3x3xf32>, memref<2x3x3xf32>, i64) -> () - return -} - -// ----- - -func.func @gemm_invoke(%arg0: i64, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) { - // expected-error@+1 {{operand #1 must be variadic of 2D/3D static memref of 32-bit float or bfloat16 type values or 64-bit signless integer, but got 'memref<?x?xf32>'}} - xsmm.gemm(data_type = f32, %arg0, %arg1, %arg2, %arg2) - : (i64, memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> () - return -} - -// ----- - -func.func @gemm_invoke(%arg0: i64, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>, %arg3: memref<?x?x?xf32>) { - // expected-error@+1 {{operand #1 must be variadic of 2D/3D/4D static memref of 32-bit float or bfloat16 type values or 64-bit signless integer, but got 'memref<?x?x?xf32>'}} - xsmm.brgemm(data_type = f32, %arg0, %arg1, %arg2, %arg3, %arg0) - : (i64, memref<?x?x?xf32>, memref<?x?x?xf32>, memref<?x?x?xf32>, i64) -> () - return -} - -// ----- - -func.func @unary_invoke(%arg0: memref<?x?xf32>, %arg1: memref<3x3xf32>, %disp: i64) { - // expected-error@+1 {{operand #1 must be variadic of 1D/2D/3D/4D static memref of 32-bit float or bfloat16 type values or 32-bit float or bfloat16 type or 64-bit signless integer, but got 'memref<?x?xf32>'}} - xsmm.unary relu(data_type = f32, %disp, %arg0, %arg1) : (i64, memref<?x?xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @unary_invoke(%arg0: memref<?x?xf32>, %arg1: memref<3x3xf32>, %disp: i64) { - // expected-error@+1 {{operand #1 must be variadic of 1D/2D/3D/4D static memref of 32-bit float or bfloat16 type values or
32-bit float or bfloat16 type or 64-bit signless integer, but got 'memref<?x?xf32>'}} - xsmm.binary add(data_type = f32, %disp, %arg0, %arg0, %arg1) - : (i64, memref<?x?xf32>, memref<?x?xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @brgemm_invoke(%arg0: i64, %arg1: memref<2x3x3xf32>, %arg2: memref<2x3x3xf32>, %arg3: memref<3x3xf32>) { - // expected-error@+1 {{expect an i64 but got 'memref<3x3xf32>' for last operand (batch)}} - xsmm.brgemm(data_type = f32, %arg0, %arg1, %arg2, %arg3, %arg3) - : (i64, memref<2x3x3xf32>, memref<2x3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @brgemm_invoke(%arg0: i64, %arg1: memref<2x3x3xf32>, %arg2: memref<2x3x3xf32>, %arg3: memref<3x3xf32>) { - // expected-error@+1 {{expect an i64 but got 'memref<3x3xf32>' for last operand (batch)}} - xsmm.fused_brgemm(data_type = f32, %arg0, %arg1, %arg2, %arg3, %arg3, %arg3) - : (i64, memref<2x3x3xf32>, memref<2x3x3xf32>, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @unary_invoke(%arg0: i64, %arg1: memref<3x3xf32>) { - // expected-error@+1 {{expect 3 inputs but got 5}} - xsmm.unary relu(data_type = f32, %arg0, %arg1, %arg1, %arg1, %arg1) - : (i64, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @unary_invoke(%arg1: memref<3x3xf32>) { - // expected-error@+1 {{expect an i64 but got 'memref<3x3xf32>' for operand 0 (dispatch)}} - xsmm.unary relu(data_type = f32, %arg1, %arg1, %arg1) - : (memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @binary_invoke(%arg1: memref<3x3xf32>) { - // expected-error@+1 {{expect an i64 but got 'memref<3x3xf32>' for operand 0 (dispatch)}} - xsmm.binary add(data_type = f32, %arg1, %arg1, %arg1, %arg1) - : (memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} - -// ----- - -func.func @binary_invoke(%arg0: i64, %arg1: memref<3x3xf32>) { - // expected-error@+1 {{operands present, but expected 5}} - xsmm.binary add(data_type = f32, %arg0, %arg1, %arg1, %arg1, %arg1, %arg1) - : (i64, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - return -} diff --git a/test/Dialect/Xsmm/xsmm-ops.mlir b/test/Dialect/Xsmm/xsmm-ops.mlir deleted file mode 100644 index 3689ffe1c..000000000 --- a/test/Dialect/Xsmm/xsmm-ops.mlir +++ /dev/null @@ -1,76 +0,0 @@ -// RUN: tpp-opt %s | tpp-opt | FileCheck %s - -// CHECK-LABEL: @xsmm_dialect -func.func @xsmm_dialect(%arg0: memref<2x2xf32>, - %arg1: memref<2x2xf32>, %arg2: memref<2x2xf32>, %arg3: memref<3x2x2xf32>) { - - %d = arith.constant 0 : i64 - // CHECK: xsmm.binary add - xsmm.binary add(data_type = f32, %d, %arg0, %arg1, %arg1) - : (i64, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () - - // CHECK: xsmm.binary sub - xsmm.binary sub(data_type = f32, %d, %arg0, %arg1, %arg1) - : (i64, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () - - // CHECK: xsmm.binary div - xsmm.binary div(data_type = f32, %d, %arg0, %arg1, %arg1) - : (i64, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () - - // CHECK: xsmm.unary relu - xsmm.unary relu(data_type = f32, %d, %arg0, %arg0) - : (i64, memref<2x2xf32>, memref<2x2xf32>) -> () - - // CHECK: xsmm.binary.dispatch add - %0 = xsmm.binary.dispatch add [3, 2, 1, 3, 2] flags = (none) data_type = f32 - - // CHECK: xsmm.unary.dispatch identity - %1 = xsmm.unary.dispatch identity [3, 2, 1, 3] flags = (bcast_row) data_type = f32 - - // CHECK: xsmm.gemm - xsmm.gemm (data_type = f32,
%d, %arg0, %arg1, %arg2) - : (i64, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () - - %b = arith.constant 1 : i64 - // CHECK: xsmm.fused_brgemm - xsmm.fused_brgemm (data_type = f32, %d, %arg3, %arg3, %arg2, %arg2, %b) - : (i64, memref<3x2x2xf32>, memref<3x2x2xf32>, memref<2x2xf32>, memref<2x2xf32>, i64) -> () - - // CHECK: xsmm.gemm.dispatch - %2 = xsmm.gemm.dispatch [1, 2, 3, 4, 5, 5] flags = (none) data_type = f32 - // CHECK-NEXT: xsmm.gemm.dispatch - %3 = xsmm.gemm.dispatch [1, 2, 3, 4, 5, 6] flags = (beta_0) data_type = f32 - // CHECK-NEXT: xsmm.gemm.dispatch - %4 = xsmm.gemm.dispatch [1, 2, 3, 4, 5, 6] flags = (beta_0) data_type = bf16 - // CHECK-NEXT: xsmm.gemm.dispatch - %5 = xsmm.gemm.dispatch [1, 2, 3, 4, 5, 6] flags = (vnni_a, vnni_b) data_type = bf16 - // CHECK-NEXT: xsmm.brgemm.dispatch - %6 = xsmm.brgemm.dispatch [1, 2, 3, 4, 5, 6, 1, 1] flags = (vnni_a, vnni_b) data_type = bf16 - // CHECK-NEXT: xsmm.brgemm.dispatch - %7 = xsmm.brgemm.dispatch [1, 2, 3, 4, 5, 6, 1, 1] flags = (beta_0) data_type = bf16 - // CHECK-NEXT: xsmm.brgemm.dispatch - %8 = xsmm.brgemm.dispatch [1, 2, 3, 4, 5, 6, 1, 1] flags = (beta_0) data_type = f32 - // CHECK-NEXT: xsmm.brgemm.dispatch - %9 = xsmm.brgemm.dispatch [1, 2, 3, 4, 5, 6, 1, 1] flags = (none) data_type = f32 - // CHECK: xsmm.gemm.dispatch {{.*}} {myAttr = "myattr"} - %10 = xsmm.gemm.dispatch [1, 2, 3, 4, 5, 6] flags = (none) data_type = f32 {myAttr = "myattr"} - - // CHECK: xsmm.unary.dispatch zero - %11 = xsmm.unary.dispatch zero [2, 2, 2, 2] flags = (none) data_type = f32 - - // CHECK: xsmm.fused_brgemm.dispatch - %12 = xsmm.fused_brgemm.dispatch [1, 2, 3, 4, 5, 6, 1, 1] [add, relu] - flags = (beta_0) binary_flags = (none) unary_flags = (none) data_type = f32 - - // CHECK: xsmm.unary zero - xsmm.unary zero(data_type = f32, %11, %arg0, %arg0) - : (i64, memref<2x2xf32>, memref<2x2xf32>) -> () - - // CHECK: xsmm.binary.dispatch sub - %13 = xsmm.binary.dispatch sub [3, 2, 1, 3, 2] flags = (none) data_type = f32 - - // CHECK: xsmm.binary.dispatch div - %14 = xsmm.binary.dispatch div [3, 2, 1, 3, 2] flags = (none) data_type = f32 - - return -} diff --git a/test/Passes/DefaultPipeline/linalg-to-xsmm.mlir b/test/Passes/DefaultPipeline/linalg-to-xsmm.mlir deleted file mode 100644 index 2b76d1ebf..000000000 --- a/test/Passes/DefaultPipeline/linalg-to-xsmm.mlir +++ /dev/null @@ -1,63 +0,0 @@ -// RUN: tpp-opt %s -default-tpp-passes -split-input-file | FileCheck %s - -func.func @fill_op(%arg0: memref<3x3xf32>) { - %cst = arith.constant 0.0 : f32 - linalg.fill ins(%cst : f32) outs(%arg0 : memref<3x3xf32>) - return -} - -// CHECK-LABEL: fill_op -// CHECK-SAME: %[[ARG0:.+]]: memref<3x3xf32> -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : i64 -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i64 -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : i64 -// CHECK-DAG: %[[C8:.+]] = arith.constant 8 : i64 -// CHECK-DAG: %[[CST:.+]] = arith.constant 0.000000e+00 : f32 -// CHECK: %[[DIS:.+]] = call @xsmm_unary_dispatch(%[[C2]], %[[C1]], %[[C3]], %[[C3]], %[[C1]], %[[C3]], %[[C8]]) -// CHECK: %[[PTR:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<3x3xf32> -> index -// CHECK: %[[PTR_TO_INT:.+]] = arith.index_cast %[[PTR]] : index to i64 -// CHECK: %[[LLVM_PTR:.+]] = llvm.inttoptr %[[PTR_TO_INT]] : i64 to !llvm.ptr -// CHECK: call @xsmm_unary_scalar_invoke(%[[C1]], %[[DIS]], %[[CST]], %[[LLVM_PTR]], %[[C0]]) - -// ----- - -func.func @fill_op_i32(%arg0: memref<3x3xi32>) { - %cst = arith.constant 0 
: i32 - linalg.fill ins(%cst : i32) outs(%arg0 : memref<3x3xi32>) - return -} - -// CHECK-LABEL: fill_op_i32 -// CHECK-NOT: xsmm -// CHECK: linalg.fill - -// ----- - -func.func @gemm_with_zero(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> { - %cst = arith.constant 0.0 : f32 - %0 = tensor.empty() : tensor<3x3xf32> - %fill = linalg.fill ins(%cst : f32) outs(%0 : tensor<3x3xf32>) -> tensor<3x3xf32> - %mul = linalg.matmul ins(%arg0, %arg1 : tensor<3x3xf32>, tensor<3x3xf32>) - outs(%fill: tensor<3x3xf32>) -> tensor<3x3xf32> - return %mul : tensor<3x3xf32> -} - -// CHECK-LABEL: gemm_with_zero -// CHECK-SAME: %[[ARG0:.+]]: memref<3x3xf32>, %[[ARG1:.+]]: memref<3x3xf32> -// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : i64 -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : i64 -// CHECK-DAG: %[[C4:.+]] = arith.constant 4 : i64 -// CHECK-NOT: xsmm_unary_dispatch -// CHECK: %[[DIS:.+]] = call @xsmm_gemm_dispatch(%[[C1]], %[[C3]], %[[C3]], %[[C3]], %[[C3]], %[[C3]], %[[C3]], %[[C4]]) -// CHECK: %[[INT_PTR_ARG0:.+]] = memref.extract_aligned_pointer_as_index -// CHECK: %[[CAST_ARG0:.+]] = arith.index_cast %[[INT_PTR_ARG0]] : index to i64 -// CHECK: %[[LLVM_PTR_ARG0:.+]] = llvm.inttoptr %[[CAST_ARG0]] : i64 to !llvm.ptr -// CHECK: %[[INT_PTR_ARG1:.+]] = memref.extract_aligned_pointer_as_index -// CHECK: %[[CAST_ARG1:.+]] = arith.index_cast %[[INT_PTR_ARG1]] : index to i64 -// CHECK: %[[LLVM_PTR_ARG1:.+]] = llvm.inttoptr %[[CAST_ARG1]] : i64 to !llvm.ptr -// CHECK: %[[INT_PTR_ALLOC:.+]] = memref.extract_aligned_pointer_as_index -// CHECK: %[[CAST_ALLOC:.+]] = arith.index_cast %[[INT_PTR_ALLOC]] : index to i64 -// CHECK: %[[LLVM_PTR_ALLOC:.+]] = llvm.inttoptr %[[CAST_ALLOC]] : i64 to !llvm.ptr -// CHECK: call @xsmm_gemm_invoke(%[[C1]], %[[DIS]], %[[LLVM_PTR_ARG0]], %[[C0]], %[[LLVM_PTR_ARG1]], %[[C0]], %[[LLVM_PTR_ALLOC]], %[[C0]]) diff --git a/test/Passes/DefaultPipeline/local-dialects-lowering.mlir b/test/Passes/DefaultPipeline/local-dialects-lowering.mlir deleted file mode 100644 index ee699964a..000000000 --- a/test/Passes/DefaultPipeline/local-dialects-lowering.mlir +++ /dev/null @@ -1,61 +0,0 @@ -// RUN: tpp-opt %s -bufferize -lower-local-dialects -convert-xsmm-to-func -split-input-file | FileCheck %s - -func.func @check_dialect() { - %b = arith.constant dense<[ - [ 1.1, 2.1, 3.1, 4.1 ], - [ 1.2, 2.2, 3.2, 4.2 ], - [ 1.3, 2.3, 3.3, 4.3 ], - [ 1.4, 2.4, 3.4, 4.4 ] - ]> : tensor<4x4xf32> - %c = arith.constant dense<[ - [ 1.1, 2.1, 3.1, 4.1 ], - [ 1.2, 2.2, 3.2, 4.2 ], - [ 1.3, 2.3, 3.3, 4.3 ], - [ 1.4, 2.4, 3.4, 4.35 ] - ]> : tensor<4x4xf32> - - %threshold = arith.constant 0.1: f32 - check.expect_almost_eq(%b, %c, %threshold):tensor<4x4xf32>, tensor<4x4xf32>, f32 - return -} - -// CHECK-LABEL: func.func @check_dialect( -// CHECK-NOT: check.expect_almost_eq -// CHECK: scf.for - -// ----- - -func.func @perf_dialect(%A: tensor<4x8xf32>, - %B: tensor<8x4xf32>, %C: tensor<4x4xf32>, %n: i64) -> (f64, i64) { - %output = arith.constant 0 : i64 - - %stats, %res = perf.bench (%n : i64) iter_args(%arg0 = %output) -> (f64, i64) { - %sum = arith.addi %n, %n : i64 - perf.yield %sum : i64 - } - - return %stats, %res : f64, i64 -} - -// CHECK-LABEL: func.func @perf_dialect( -// CHECK-NOT: perf.bench -// CHECK: {{[\w]*\.?}}call @perf_start_timer -// CHECK: scf.for -// CHECK: {{[\w]*\.?}}call @perf_stop_timer - -// ----- - -func.func @xsmm_dialect(%arg0: memref<32x256xf32>, %arg1: memref<1x8x32x32xf32>) -> i64 { - %0 = xsmm.unary.dispatch identity [5, 
6, 5, 6] flags = (bcast_row) data_type = f32 - %1 = xsmm.gemm.dispatch [3, 3, 3, 3, 3, 3] flags = (none) data_type = f32 - %2 = arith.addi %0, %1 : i64 - return %2: i64 -} - -// CHECK-DAG: func.func private -// CHECK-DAG: func.func private -// CHECK-LABEL: func.func @xsmm_dialect( -// CHECK-NOT: xsmm.unary.dispatch -// CHECK-NOT: xsmm.ternary.dispatch -// CHECK-DAG: {{[\w]*\.?}}call -// CHECK-DAG: {{[\w]*\.?}}call diff --git a/test/Passes/DefaultPipeline/xsmm.mlir b/test/Passes/DefaultPipeline/xsmm.mlir deleted file mode 100644 index bee500e22..000000000 --- a/test/Passes/DefaultPipeline/xsmm.mlir +++ /dev/null @@ -1,429 +0,0 @@ -// RUN: tpp-opt %s -default-tpp-passes -split-input-file | FileCheck %s - -// CHECK: func.func @add( -// CHECK-SAME: %[[ARG0:.+]]: memref<3x3xf32>, -// CHECK-SAME: %[[ARG1:.+]]: memref<3x3xf32> -func.func @add(%arg0: memref<3x3xf32>, %arg1: memref<3x3xf32>) { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: call @xsmm_binary_dispatch - - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_binary_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]] - %0 = xsmm.binary.dispatch add [3, 3, 3, 3, 3] flags = (none) data_type = f32 - xsmm.binary add(data_type = f32, %0, %arg0, %arg0, %arg1) - : (i64, memref<3x3xf32>, memref<3x3xf32>, memref<3x3xf32>) -> () - - return -} - -// ----- - -#map = affine_map<(d0, d1)[s0] -> (d0 * 10 + d1 + s0)> - -// CHECK: func.func @add_mapping( -func.func @add_mapping(%arg0: memref<1x10x10xf32>, %arg1: memref<1x10x10xf32>) { - // CHECK: %[[of:.*]] = arith.constant 0 : index - // CHECK: memref.subview - // CHECK-NOT: scf.parallel - // CHECK: call @xsmm_binary_dispatch - // CHECK: %[[ptr0:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr - // CHECK: %[[ptr1:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr - // CHECK: call @xsmm_binary_invoke({{.*}}%[[ptr0]], %[[of]], %[[ptr1]], %[[of]] - - %subview = memref.subview %arg0[0, 0, 0] [1, 10, 10] [1, 1, 1] : memref<1x10x10xf32> to memref<10x10xf32> - %subview_0 = memref.subview %arg1[0, 0, 0] [1, 10, 10] [1, 1, 1] : memref<1x10x10xf32> to memref<10x10xf32> - %0 = xsmm.binary.dispatch add [10, 10, 10, 10, 10] flags = (none) data_type = f32 - xsmm.binary add(data_type = f32, %0, %subview, %subview, %subview_0) - : (i64, memref<10x10xf32>, memref<10x10xf32>, memref<10x10xf32>) -> () - - return -} - -// ----- - -#map = affine_map<(d0, d1)[s0] -> (d0 * 10 + d1 + s0)> - -// CHECK-LABEL: @add_mapping_parallel -func.func @add_mapping_parallel(%arg0: memref<10x10x10xf32>, %arg1: memref<10x10x10xf32>) { - // CHECK: call @xsmm_binary_dispatch - // CHECK: scf.parallel - // CHECK: %[[ptr0:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr - // CHECK: %[[ptr1:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr - // CHECK: call @xsmm_binary_invoke({{.*}}%[[ptr0]], {{.*}}, %[[ptr1]], {{.+}} - %c0 = arith.constant 0 : index - %c10 = arith.constant 10 : index - %c1 = arith.constant 1 : index - scf.parallel (%arg2) = (%c0) to (%c10) step (%c1) { - %subview = memref.subview %arg0[%arg2, 0, 0] [1, 10, 10] [1, 1, 1] - : 
memref<10x10x10xf32> to memref<10x10xf32, #map> - %subview_0 = memref.subview %arg1[%arg2, 0, 0] [1, 10, 10] [1, 1, 1] - : memref<10x10x10xf32> to memref<10x10xf32, #map> - %0 = xsmm.binary.dispatch add [10, 10, 10, 10, 10] flags = (none) data_type = f32 - xsmm.binary add(data_type = f32, %0, %subview, %subview, %subview_0) - : (i64, memref<10x10xf32, #map>, memref<10x10xf32, #map>, memref<10x10xf32, #map>) -> () - scf.reduce - } - return -} - -// ----- - -// CHECK: func.func @identity( -// CHECK-SAME: %[[ARG0:.+]]: memref<3x3xf32>, -// CHECK-SAME: %[[ARG1:.+]]: memref<1x1xf32> -func.func @identity(%arg0: memref<3x3xf32>, %arg1: memref<1x1xf32>) { - // CHECK: %[[c0:.*]] = arith.constant 0 : index - // CHECK: call @xsmm_unary_dispatch - %0 = xsmm.unary.dispatch identity [3, 3, 1, 3] flags = (bcast_scalar) data_type = f32 - - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_unary_invoke({{.*}}%[[llvm_ptr0]], %{{.+}}, %[[llvm_ptr1]], %{{.+}} - xsmm.unary identity(data_type = f32, %0, %arg1, %arg0) : (i64, memref<1x1xf32>, memref<3x3xf32>) -> () - - return -} - -// ----- - -#map = affine_map<(d0, d1)[s0] -> (d0 * 64 + d1 + s0)> - -// CHECK-LABEL: @identity_mapping -func.func @identity_mapping(%arg0: memref<64xf32>) -> memref<12x56x56x64xf32> { - // CHECK: %[[C0:.+]] = arith.constant 0 : index - // CHECK: call @xsmm_unary_dispatch - // CHECK: scf.parallel - // CHECK: %[[ptr0:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr - // CHECK: %[[ptr1:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr - // CHECK: call @xsmm_unary_invoke({{.*}}%[[ptr0]], %{{.+}}, %[[ptr1]], %{{.+}} - %c0 = arith.constant 0 : index - %c12 = arith.constant 12 : index - %c1 = arith.constant 1 : index - %c56 = arith.constant 56 : index - %alloc = memref.alloc() {alignment = 128 : i64} : memref<12x56x56x64xf32> - scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c12, %c56) step (%c1, %c1) { - %subview = memref.subview %alloc[%arg1, %arg2, 0, 0] [1, 1, 56, 64] [1, 1, 1, 1] : memref<12x56x56x64xf32> to memref<56x64xf32, #map> - %0 = xsmm.unary.dispatch identity [56, 64, 64, 64] flags = (bcast_col) data_type = f32 - xsmm.unary identity(data_type = f32, %0, %arg0, %subview) : (i64, memref<64xf32>, memref<56x64xf32, #map>) -> () - scf.reduce - } - - return %alloc : memref<12x56x56x64xf32> -} - -// ----- - -// CHECK: func.func @zero( -// CHECK-SAME: %[[ARG0:.+]]: memref<3x3xf32> -func.func @zero(%arg0: memref<3x3xf32>) { - // CHECK: %[[C0:.+]] = arith.constant 0 : index - // CHECK: call @xsmm_unary_dispatch - // CHECK: %[[ptr0:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<3x3xf32> -> index - // CHECK: %[[ptr_cast0:.+]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK: %[[llvm_ptr0:.+]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - // CHECK: call @xsmm_unary_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr0]], %[[C0]] - %0 = xsmm.unary.dispatch zero [3, 3, 3, 3] flags = (none) data_type = f32 - xsmm.unary zero(data_type = f32, %0, %arg0, %arg0) : (i64, memref<3x3xf32>, memref<3x3xf32>) -> () - - return -} - -// ----- - -// CHECK: func.func @relu( -// 
CHECK-SAME: %[[ARG0:.+]]: memref<3x3xf32> -func.func @relu(%arg0: memref<3x3xf32>) { - // CHECK: %[[C0:.+]] = arith.constant 0 : index - // CHECK: call @xsmm_unary_dispatch - // CHECK: %[[ptr0:.+]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<3x3xf32> -> index - // CHECK: %[[ptr_cast0:.+]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK: %[[llvm_ptr0:.+]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - // CHECK: call @xsmm_unary_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr0]], %[[C0]] - %0 = xsmm.unary.dispatch relu [3, 3, 3, 3] flags = (none) data_type = f32 - xsmm.unary relu(data_type = f32, %0, %arg0, %arg0) : (i64, memref<3x3xf32>, memref<3x3xf32>) -> () - - return -} - -// ----- - -#map = affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)> - -// CHECK-LABEL: @relu_3d( -// CHECK-SAME: %[[arg:.*]]: memref<64x32x32xf32>) { -func.func @relu_3d(%arg0: memref<64x32x32xf32>) -> memref<64x32x32xf32> { - // CHECK: %[[C0:.+]] = arith.constant 0 : index - // CHECK: call @xsmm_unary_dispatch - // CHECK: scf.parallel - // CHECK: %[[ptr0:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr - // CHECK: call @xsmm_unary_invoke({{.*}}%[[ptr0]], %{{.+}}, %[[ptr0]], %{{.+}} - %c0 = arith.constant 0 : index - %c64 = arith.constant 64 : index - %c1 = arith.constant 1 : index - scf.parallel (%arg1) = (%c0) to (%c64) step (%c1) { - %subview = memref.subview %arg0[%arg1, 0, 0] [1, 32, 32] [1, 1, 1] : memref<64x32x32xf32> to memref<32x32xf32, #map> - %0 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = f32 - xsmm.unary relu(data_type = f32, %0, %subview, %subview) : (i64, memref<32x32xf32, #map>, memref<32x32xf32, #map>) -> () - scf.reduce - } - - return %arg0 : memref<64x32x32xf32> -} - -// ----- - -// CHECK: func.func @brgemm( -// CHECK-SAME: %[[ARG0:.+]]: memref<2x3x4xf32>, -// CHECK-SAME: %[[ARG1:.+]]: memref<2x4x3xf32>, -// CHECK-SAME: %[[ARG2:.+]]: memref<3x3xf32> -func.func @brgemm(%arg0: memref<2x3x4xf32>, %arg1: memref<2x4x3xf32>, %arg2: memref<3x3xf32>) { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: call @xsmm_brgemm_dispatch - - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<2x3x4xf32> -> index - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<2x4x3xf32> -> index - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] : memref<3x3xf32> -> index - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]], %[[llvm_ptr2]], %[[C0]] - %c2_i64 = arith.constant 2 : i64 - %0 = xsmm.brgemm.dispatch [3, 3, 4, 4, 3, 3, 12, 12] flags = (none) data_type = f32 - xsmm.brgemm(data_type = f32, %0, %arg0, %arg1, %arg2, %c2_i64) - : (i64, memref<2x3x4xf32>, memref<2x4x3xf32>, memref<3x3xf32>, i64) -> () - - return -} - -// ----- - -// CHECK-LABEL: func.func @brgemm_bf16 -// CHECK-SAME: %[[ARG0:.+]]: memref<64x4x4xbf16>, -// CHECK-SAME: %[[ARG1:.+]]: memref<64x2x4x2xbf16>, -// CHECK-SAME: %[[ARG2:.+]]: memref<4x4xbf16> -func.func 
@brgemm_bf16(%arg0: memref<64x4x4xbf16>, %arg1: memref<64x2x4x2xbf16>, - %arg2: memref<4x4xbf16>) { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: call @xsmm_brgemm_dispatch - - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] : memref<64x4x4xbf16> -> index - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] : memref<64x2x4x2xbf16> -> index - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] : memref<4x4xbf16> -> index - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]], %[[llvm_ptr2]], %[[C0]] - %c64_i64 = arith.constant 64 : i64 - %0 = xsmm.brgemm.dispatch [4, 4, 4, 4, 4, 4, 16, 16] flags = (vnni_b) data_type = bf16 - xsmm.brgemm(data_type = bf16, %0, %arg0, %arg1, %arg2, %c64_i64) - : (i64, memref<64x4x4xbf16>, memref<64x2x4x2xbf16>, memref<4x4xbf16>, i64) -> () - - return -} - -// ----- - -// CHECK: func.func @gemm( -// CHECK-SAME: %[[ARG0:.+]]: memref<4x8xf32>, -// CHECK-SAME: %[[ARG1:.+]]: memref<8x4xf32>, -// CHECK-SAME: %[[ARG2:.+]]: memref<4x4xf32>) -func.func @gemm(%A: memref<4x8xf32>, - %B: memref<8x4xf32>, %C: memref<4x4xf32>) { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: call @xsmm_gemm_dispatch - - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_gemm_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]], %[[llvm_ptr2]], %[[C0]] - %0 = xsmm.gemm.dispatch [4, 4, 8, 8, 4, 4] flags = (none) data_type = f32 - xsmm.gemm(data_type = f32, %0, %A, %B, %C) : (i64, memref<4x8xf32>, memref<8x4xf32>, memref<4x4xf32>) -> () - - return -} - -// ----- - -// CHECK-LABEL: func.func @gemm_bf16 -// CHECK-SAME: %[[ARG0:.+]]: memref<6x10xbf16>, -// CHECK-SAME: %[[ARG1:.+]]: memref<5x6x2xbf16>, -// CHECK-SAME: %[[ARG2:.+]]: memref<6x6xbf16> -func.func @gemm_bf16(%arg0: memref<6x10xbf16>, %arg1: memref<5x6x2xbf16>, - %arg2: memref<6x6xbf16>) { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - // CHECK: call @xsmm_gemm_dispatch - - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] - // CHECK-NEXT: 
%[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_gemm_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]], %[[llvm_ptr2]], %[[C0]] - %0 = xsmm.gemm.dispatch [6, 6, 10, 10, 6, 6] flags = (vnni_b) data_type = bf16 - xsmm.gemm(data_type = bf16, %0, %arg0, %arg1, %arg2) : (i64, memref<6x10xbf16>, memref<5x6x2xbf16>, memref<6x6xbf16>) -> () - - return -} - -// ----- - -// CHECK-LABEL: func.func @blocked_matmul( -// CHECK-SAME: %[[ARG0:.+]]: memref<4x16x32x32xf32>, -// CHECK-SAME: %[[ARG1:.+]]: memref<8x16x32x32xf32>, -// CHECK-SAME: %[[ARG2:.+]]: memref<4x8x32x32xf32>) -func.func @blocked_matmul(%arg0: memref<4x16x32x32xf32>, %arg1: memref<8x16x32x32xf32>, %arg2: memref<4x8x32x32xf32>) { - // CHECK: call @xsmm_brgemm_dispatch - // CHECK: scf.parallel - // CHECK: %[[ptr0:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr - // CHECK: %[[ptr1:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr - // CHECK: %[[ptr2:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr - // CHECK: call @xsmm_brgemm_invoke({{.*}}%[[ptr0]], %{{.+}}, %[[ptr1]], %{{.+}}, %[[ptr2]], %{{.+}} - %c16_i64 = arith.constant 16 : i64 - %c0 = arith.constant 0 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %c8 = arith.constant 8 : index - scf.parallel (%arg3, %arg4) = (%c0, %c0) to (%c4, %c8) step (%c1, %c1) { - %subview = memref.subview %arg0[%arg3, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<4x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %subview_0 = memref.subview %arg1[%arg4, 0, 0, 0] [1, 16, 32, 32] [1, 1, 1, 1] : memref<8x16x32x32xf32> to memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>> - %subview_1 = memref.subview %arg2[%arg3, %arg4, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<4x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - %0 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (none) data_type = f32 - xsmm.brgemm(data_type = f32, %0, %subview, %subview_0, %subview_1, %c16_i64) - : (i64, memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, - memref<16x32x32xf32, strided<[1024, 32, 1], offset: ?>>, - memref<32x32xf32, strided<[32, 1], offset: ?>>, i64) -> () - scf.reduce - } - - return -} - -// ----- - -// Conv2D weights -memref.global "private" constant @__constant_2048x512xf32 : memref<2048x512xf32> = dense<0.00332225906> {alignment = 128 : i64} - -// CHECK-LABEL: @conv2d_1x1( -// CHECK-SAME: %[[arg:.*]]: memref<1x7x7x2048xf32>) -> memref<1x7x7x512xf32> { -func.func @conv2d_1x1(%arg0: memref<1x7x7x2048xf32>) -> memref<1x7x7x512xf32> { - %cst = arith.constant 0.000000e+00 : f32 - %c7 = arith.constant 7 : index - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %0 = memref.get_global @__constant_2048x512xf32 : memref<2048x512xf32> - - // 1x1 Conv2D - // CHECK: call @xsmm_gemm_dispatch - // CHECK: %[[ptr0:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr - // CHECK: %[[ptr1:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr - // CHECK: %[[ptr2:.*]] = llvm.inttoptr %{{.+}} : i64 to !llvm.ptr - // CHECK: call @xsmm_gemm_invoke({{.*}}%[[ptr0]], %{{.+}}, %[[ptr1]], %{{.+}}, %[[ptr2]], %{{.+}} - %alloc = memref.alloc() {alignment = 128 : i64} : 
memref<1x7x7x512xf32> - linalg.fill ins(%cst : f32) outs(%alloc : memref<1x7x7x512xf32>) - scf.for %arg1 = %c0 to %c7 step %c1 { - %subview = memref.subview %arg0[0, %arg1, 0, 0] [1, 1, 7, 2048] [1, 1, 1, 1] : memref<1x7x7x2048xf32> to memref<7x2048xf32, strided<[2048, 1], offset: ?>> - %subview_0 = memref.subview %alloc[0, %arg1, 0, 0] [1, 1, 7, 512] [1, 1, 1, 1] : memref<1x7x7x512xf32> to memref<7x512xf32, strided<[512, 1], offset: ?>> - %1 = xsmm.gemm.dispatch [7, 512, 2048, 2048, 512, 512] flags = (none) data_type = f32 - xsmm.gemm(data_type = f32, %1, %subview, %0, %subview_0) : (i64, memref<7x2048xf32, strided<[2048, 1], offset: ?>>, memref<2048x512xf32>, memref<7x512xf32, strided<[512, 1], offset: ?>>) -> () - } - - return %alloc : memref<1x7x7x512xf32> -} - -// ----- - -#map = affine_map<(d0, d1) -> (d1)> -#map1 = affine_map<(d0, d1) -> (d0, d1)> -#map2 = affine_map<(d0, d1, d2) -> (d0, d2)> -#map3 = affine_map<(d0, d1, d2) -> (d2, d1)> -#map4 = affine_map<(d0, d1, d2) -> (d0, d1)> - -// CHECK: func.func @mlp( -// CHECK-SAME: %[[ARG0:.+]]: memref<128x256xf32>, -// CHECK-SAME: %[[ARG1:.+]]: memref<256x512xf32>, -// CHECK-SAME: %[[ARG2:.+]]: memref<512xf32>, -// CHECK-SAME: %[[ARG3:.+]]: memref<128x512xf32>) -module @predict_function { - func.func @mlp(%arg0: memref<128x256xf32>, %arg1: memref<256x512xf32>, - %arg2: memref<512xf32>, %arg3: memref<128x512xf32>) { - // CHECK: %[[C0:.*]] = arith.constant 0 : index - - // Identity - // CHECK: call @xsmm_unary_dispatch - - // CHECK: %[[ptr0:.*]] = memref.extract_aligned_pointer_as_index %[[ARG2]] - // CHECK-NEXT: %[[ptr_cast0:.*]] = arith.index_cast %[[ptr0]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr0:.*]] = llvm.inttoptr %[[ptr_cast0]] : i64 to !llvm.ptr - - // CHECK: %[[ptr1:.*]] = memref.extract_aligned_pointer_as_index %[[ARG3]] - // CHECK-NEXT: %[[ptr_cast1:.*]] = arith.index_cast %[[ptr1]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr1:.*]] = llvm.inttoptr %[[ptr_cast1]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_unary_invoke({{.*}}%[[llvm_ptr0]], %[[C0]], %[[llvm_ptr1]], %[[C0]] - %0 = xsmm.unary.dispatch identity [128, 512, 512, 512] flags = (bcast_col) data_type = f32 - xsmm.unary identity(data_type = f32, %0, %arg2, %arg3) : (i64, memref<512xf32>, memref<128x512xf32>) -> () - - // Gemm - // CHECK: call @xsmm_gemm_dispatch - // CHECK: %[[ptr2:.*]] = memref.extract_aligned_pointer_as_index %[[ARG0]] - // CHECK-NEXT: %[[ptr_cast2:.*]] = arith.index_cast %[[ptr2]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr2:.*]] = llvm.inttoptr %[[ptr_cast2]] : i64 to !llvm.ptr - - // CHECK: %[[ptr3:.*]] = memref.extract_aligned_pointer_as_index %[[ARG1]] - // CHECK-NEXT: %[[ptr_cast3:.*]] = arith.index_cast %[[ptr3]] : index to i64 - // CHECK-NEXT: %[[llvm_ptr3:.*]] = llvm.inttoptr %[[ptr_cast3]] : i64 to !llvm.ptr - - // CHECK: call @xsmm_gemm_invoke({{.*}}%[[llvm_ptr2]], %[[C0]], %[[llvm_ptr3]], %[[C0]], %[[llvm_ptr1]], %[[C0]] - %1 = xsmm.gemm.dispatch [128, 512, 256, 256, 512, 512] flags = (none) data_type = f32 - xsmm.gemm(data_type = f32, %1, %arg0, %arg1, %arg3) : (i64, memref<128x256xf32>, memref<256x512xf32>, memref<128x512xf32>) -> () - - // Relu - // CHECK: call @xsmm_unary_dispatch - // CHECK: call @xsmm_unary_invoke({{.*}}%[[llvm_ptr1]], %[[C0]], %[[llvm_ptr1]], %[[C0]] - %2 = xsmm.unary.dispatch relu [128, 512, 512, 512] flags = (none) data_type = f32 - xsmm.unary relu(data_type = f32, %2, %arg3, %arg3) : (i64, memref<128x512xf32>, memref<128x512xf32>) -> () - - return - } -} diff --git a/test/Passes/fold-xsmm-flags.mlir 
b/test/Passes/fold-xsmm-flags.mlir deleted file mode 100644 index 659384a65..000000000 --- a/test/Passes/fold-xsmm-flags.mlir +++ /dev/null @@ -1,333 +0,0 @@ -// RUN: tpp-opt %s -fold-xsmm-flags -split-input-file | FileCheck %s
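These tests covered beta_0 folding. A libxsmm GEMM computes C = beta * C + A * B, so when the accumulator was just zero-filled by an xsmm.unary zero, the pass can set beta = 0 on the consuming dispatch and make the explicit zeroing redundant. Roughly (a sketch of the rewrite, not a pattern any single CHECK asserts verbatim):
//   xsmm.unary zero(..., %cst, %alloc)                  // %alloc = 0
//   %d = xsmm.gemm.dispatch [...] flags = (none) ...    // %alloc += A * B
// becomes
//   %d = xsmm.gemm.dispatch [...] flags = (beta_0) ...  // %alloc = A * B
The negative tests pin down when the fold is unsafe: the zeroed buffer must flow into the GEMM output unchanged, with no intervening writes, no aliasing subviews escaping into calls, and no deallocation in between.
-func.func @zero_flag_gemm(%arg0: memref<32x512xf32, strided<[512, 1], offset: ?>>, - %arg1: memref<512x64xf32, strided<[512, 1], offset: ?>>) { - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32> - %0 = xsmm.unary.dispatch zero [32, 64, 1, 64] flags = (bcast_scalar) data_type = f32 - xsmm.unary zero(data_type = f32, %0, %cst, %alloc) : (i64, f32, memref<32x64xf32>) -> () - %1 = xsmm.gemm.dispatch [32, 64, 512, 512, 64, 64] flags = (none) data_type = f32 - xsmm.gemm(data_type = f32, %1, %arg0, %arg1, %alloc) : (i64, memref<32x512xf32, strided<[512, 1], offset: ?>>, memref<512x64xf32, strided<[512, 1], offset: ?>>, memref<32x64xf32>) -> () - return -} - -// CHECK-LABEL: zero_flag_gemm -// CHECK-SAME: %[[ARG0:.+]]: memref<32x512xf32, strided<[512, 1], offset: ?>> -// CHECK-SAME: %[[ARG1:.+]]: memref<512x64xf32, strided<[512, 1], offset: ?>> -// CHECK: %[[ALLOC:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32> -// CHECK: %[[DIS:.+]] = xsmm.gemm.dispatch [32, 64, 512, 512, 64, 64] flags = (beta_0) data_type = f32 -// CHECK: xsmm.gemm(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ALLOC]]) - -// ----- - -func.func @non_zero_flag(%arg0: memref<32x512xf32, strided<[512, 1], offset: ?>>, - %arg1: memref<512x64xf32, strided<[512, 1], offset: ?>>) { - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32> - %0 = xsmm.unary.dispatch identity [32, 64, 1, 64] flags = (bcast_scalar) data_type = f32 - xsmm.unary identity(data_type = f32, %0, %cst, %alloc) : (i64, f32, memref<32x64xf32>) -> () - %1 = xsmm.gemm.dispatch [32, 64, 512, 512, 64, 64] flags = (none) data_type = f32 - xsmm.gemm(data_type = f32, %1, %arg0, %arg1, %alloc) : (i64, memref<32x512xf32, strided<[512, 1], offset: ?>>, memref<512x64xf32, strided<[512, 1], offset: ?>>, memref<32x64xf32>) -> () - return -} - -// CHECK-LABEL: non_zero_flag -// CHECK-NOT: xsmm.gemm.dispatch [32, 64, 512, 512, 64, 64] flags = (beta_0) data_type = f32 -// CHECK: %{{.+}} = xsmm.gemm.dispatch [32, 64, 512, 512, 64, 64] flags = (none) data_type = f32 - -// ----- - -func.func @zero_flag_bb_arg(%arg0: memref<32x512xf32, strided<[512, 1], offset: ?>>, - %arg1: memref<512x64xf32, strided<[512, 1], offset: ?>>, - %arg2: memref<32x64xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - %0 = xsmm.unary.dispatch zero [32, 64, 1, 64] flags = (bcast_scalar) data_type = f32 - xsmm.unary zero(data_type = f32, %0, %cst, %arg2) : (i64, f32, memref<32x64xf32>) -> () - %1 = xsmm.gemm.dispatch [32, 64, 512, 512, 64, 64] flags = (none) data_type = f32 - xsmm.gemm(data_type = f32, %1, %arg0, %arg1, %arg2) : (i64, memref<32x512xf32, strided<[512, 1], offset: ?>>, memref<512x64xf32, strided<[512, 1], offset: ?>>, memref<32x64xf32>) -> () - return -} - -// CHECK-LABEL: zero_flag_bb_arg -// CHECK: %{{.+}} = xsmm.gemm.dispatch [32, 64, 512, 512, 64, 64] flags = (beta_0) data_type = f32 - -// ----- - -func.func @zero_subview(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>) { - %alloc = memref.alloc() {alignment = 64 : i64} : memref<5x32x64xf32> - %cst = arith.constant 0.000000e+00 : f32 - scf.forall (%iv) in (5) { - %sub = memref.subview %alloc[%iv, 0, 0] [1, 32, 64] [1, 1, 1] : memref<5x32x64xf32> to memref<32x64xf32,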
strided<[64, 1], offset: ?>> - %0 = xsmm.unary.dispatch zero [32, 64, 1, 64] flags = (bcast_scalar) data_type = f32 - xsmm.unary zero(data_type = f32, %0, %cst, %sub) : (i64, f32, memref<32x64xf32, strided<[64, 1], offset: ?>>) -> () - %1 = xsmm.gemm.dispatch [32, 32, 64, 64, 32, 64] flags = (none) data_type = f32 - xsmm.gemm(data_type = f32, %1, %arg0, %arg1, %sub) : (i64, memref<32x32xf32>, memref<32x32xf32>, memref<32x64xf32, strided<[64, 1], offset: ?>>) -> () - } - return -} - -// CHECK-LABEL: zero_subview -// CHECK: %{{.+}} = xsmm.gemm.dispatch [32, 32, 64, 64, 32, 64] flags = (beta_0) data_type = f32 - -// ----- - -// Copy prevents folding. -func.func @zero_with_copy(%arg0: memref<32x512xf32, strided<[512, 1], offset: ?>>, - %arg1: memref<512x64xf32, strided<[512, 1], offset: ?>>, - %arg2: memref<32x64xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32> - %0 = xsmm.unary.dispatch zero [32, 64, 1, 64] flags = (bcast_scalar) data_type = f32 - xsmm.unary zero(data_type = f32, %0, %cst, %alloc) : (i64, f32, memref<32x64xf32>) -> () - memref.copy %alloc, %arg2 : memref<32x64xf32> to memref<32x64xf32> - %1 = xsmm.gemm.dispatch [32, 64, 512, 512, 512, 64] flags = (none) data_type = f32 - xsmm.gemm(data_type = f32, %1, %arg0, %arg1, %arg2) : (i64, memref<32x512xf32, strided<[512, 1], offset: ?>>, memref<512x64xf32, strided<[512, 1], offset: ?>>, memref<32x64xf32>) -> () - return -} - -// CHECK-LABEL: zero_with_copy -// CHECK: xsmm.unary.dispatch zero -// CHECK: %{{.+}} = xsmm.gemm.dispatch [32, 64, 512, 512, 512, 64] flags = (none) data_type = f32 - -// ----- - -func.func @multiple_users(%arg0: memref<32x512xf32, strided<[512, 1], offset: ?>>, - %arg1: memref<512x64xf32, strided<[512, 1], offset: ?>>) { - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x64xf32> - - %0 = xsmm.unary.dispatch zero [32, 64, 1, 64] flags = (bcast_scalar) data_type = f32 - xsmm.unary zero(data_type = f32, %0, %cst, %alloc) : (i64, f32, memref<32x64xf32>) -> () - %1 = xsmm.gemm.dispatch [32, 64, 512, 512, 512, 64] flags = (none) data_type = f32 - xsmm.gemm(data_type = f32, %1, %arg0, %arg1, %alloc) - : (i64, memref<32x512xf32, strided<[512, 1], offset: ?>>, - memref<512x64xf32, strided<[512, 1], offset: ?>>, memref<32x64xf32>) -> () - - %2 = xsmm.unary.dispatch zero [32, 64, 1, 64] flags = (bcast_scalar) data_type = f32 - xsmm.unary zero(data_type = f32, %2, %cst, %alloc) : (i64, f32, memref<32x64xf32>) -> () - xsmm.gemm(data_type = f32, %1, %arg0, %arg1, %alloc) - : (i64, memref<32x512xf32, strided<[512, 1], offset: ?>>, - memref<512x64xf32, strided<[512, 1], offset: ?>>, memref<32x64xf32>) -> () - return -} - -// CHECK-LABEL: multiple_users -// CHECK: %[[FIRST_GEMM:.+]] = xsmm.gemm.dispatch [32, 64, 512, 512, 512, 64] flags = (beta_0) data_type = f32 -// CHECK: %[[SECOND_GEMM:.+]] = xsmm.gemm.dispatch [32, 64, 512, 512, 512, 64] flags = (beta_0) data_type = f32 -// CHECK: xsmm.gemm(data_type = f32, %[[FIRST_GEMM]], %{{.+}}, %{{.+}}, %{{.+}}) -// CHECK: xsmm.gemm(data_type = f32, %[[SECOND_GEMM]], %{{.+}}, %{{.+}}, %{{.+}}) - - -// ----- - -func.func @multiple_users_1(%arg0: memref<512x512xf32>, - %arg1: memref<512x512xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<512x512xf32> - - %0 = xsmm.unary.dispatch zero [512, 512, 1, 512] flags = (bcast_scalar) data_type = f32 - xsmm.unary zero(data_type = f32, %0, %cst, %alloc) : (i64, f32,
memref<512x512xf32>) -> () - %1 = xsmm.gemm.dispatch [512, 512, 512, 512, 512, 512] flags = (none) data_type = f32 - xsmm.gemm(data_type = f32, %1, %arg0, %arg1, %alloc) - : (i64, memref<512x512xf32>, memref<512x512xf32>, memref<512x512xf32>) -> () - - xsmm.gemm(data_type = f32, %1, %arg0, %alloc, %alloc) - : (i64, memref<512x512xf32>, memref<512x512xf32>, memref<512x512xf32>) -> () - return -} - -// CHECK-LABEL: multiple_users_1 -// CHECK: %[[FIRST_GEMM:.+]] = xsmm.gemm.dispatch [512, 512, 512, 512, 512, 512] flags = (beta_0) data_type = f32 -// CHECK: %[[SECOND_GEMM:.+]] = xsmm.gemm.dispatch [512, 512, 512, 512, 512, 512] flags = (none) data_type = f32 -// CHECK: xsmm.gemm(data_type = f32, %[[FIRST_GEMM]], %{{.+}}, %{{.+}}, %{{.+}}) -// CHECK: xsmm.gemm(data_type = f32, %[[SECOND_GEMM]], %{{.+}}, %{{.+}}, %{{.+}}) - -// ----- - -func.func @zero_flag(%arg0: memref<1x32x32xf32>, %arg1: memref<1x32x32xf32>) { - %c1_i64 = arith.constant 1 : i64 - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %0 = xsmm.unary.dispatch zero [32, 512, 1, 512] flags = (bcast_scalar) data_type = f32 - xsmm.unary zero(data_type = f32, %0, %cst, %alloc) : (i64, f32, memref<32x32xf32>) -> () - %1 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (none) data_type = f32 - xsmm.brgemm(data_type = f32, %1, %arg0, %arg1, %alloc, %c1_i64) : (i64, memref<1x32x32xf32>, memref<1x32x32xf32>, memref<32x32xf32>, i64) -> () - return -} - -// CHECK-LABEL: zero_flag -// CHECK-SAME: %[[ARG0:.+]]: memref<1x32x32xf32>, %[[ARG1:.+]]: memref<1x32x32xf32> -// CHECK: %[[C1:.+]] = arith.constant 1 : i64 -// CHECK: %[[ALLOC:.+]] = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> -// CHECK: %[[DIS:.+]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (beta_0) data_type = f32 -// CHECK: xsmm.brgemm(data_type = f32, %[[DIS]], %[[ARG0]], %[[ARG1]], %[[ALLOC]], %[[C1]]) - -// ----- - -func.func @must_be_user_of_gemm(%arg0: memref<1x32x32xf32>, %arg1: memref<1x32x32xf32>, %arg2: memref<32x32xf32>) { - %c1_i64 = arith.constant 1 : i64 - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %0 = xsmm.unary.dispatch zero [32, 512, 1, 512] flags = (bcast_scalar) data_type = f32 - xsmm.unary zero(data_type = f32, %0, %cst, %alloc) : (i64, f32, memref<32x32xf32>) -> () - %1 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (none) data_type = f32 - xsmm.brgemm(data_type = f32, %1, %arg0, %arg1, %arg2, %c1_i64) : (i64, memref<1x32x32xf32>, memref<1x32x32xf32>, memref<32x32xf32>, i64) -> () - return -} - -// CHECK-LABEL: must_be_user_of_gemm -// CHECK-NOT: beta_0 -// CHECK: %{{.+}} = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (none) data_type = f32 - -// ----- - -func.func @memory_effect_but_do_not_touch_alloc(%arg0: memref<1x32x32xf32>, %arg1: memref<1x32x32xf32>, %arg2: memref<32x32xf32>) { - %c1_i64 = arith.constant 1 : i64 - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %0 = xsmm.unary.dispatch zero [32, 512, 1, 512] flags = (bcast_scalar) data_type = f32 - xsmm.unary zero(data_type = f32, %0, %cst, %alloc) : (i64, f32, memref<32x32xf32>) -> () - %1 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (none) data_type = f32 - memref.copy %arg2, %arg2 : memref<32x32xf32> to memref<32x32xf32> - xsmm.brgemm(data_type = f32, %1, %arg0, %arg1, %alloc, %c1_i64) : (i64, memref<1x32x32xf32>, 
memref<1x32x32xf32>, memref<32x32xf32>, i64) -> () - return -} - -// CHECK-LABEL: memory_effect_but_do_not_touch_alloc -// CHECK: %{{.+}} = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (beta_0) data_type = f32 - -// ----- - -func.func @sub_view_aliasing(%arg0: memref<1x32x32xf32>, %arg1: memref<1x32x32xf32>, %arg2: memref<32x32xf32>) { - %c1_i64 = arith.constant 1 : i64 - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %0 = xsmm.unary.dispatch zero [32, 512, 1, 512] flags = (bcast_scalar) data_type = f32 - xsmm.unary zero(data_type = f32, %0, %cst, %alloc) : (i64, f32, memref<32x32xf32>) -> () - %sub = memref.subview %alloc[0, 0] [16, 16] [1, 1] : memref<32x32xf32> to memref<16x16xf32, strided<[32, 1]>> - call @test(%sub) : (memref<16x16xf32, strided<[32, 1]>>) -> () - %1 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (none) data_type = f32 - xsmm.brgemm(data_type = f32, %1, %arg0, %arg1, %alloc, %c1_i64) : (i64, memref<1x32x32xf32>, memref<1x32x32xf32>, memref<32x32xf32>, i64) -> () - return -} - -func.func private @test(%arg0 : memref<16x16xf32, strided<[32, 1]>>) - -// CHECK-LABEL: sub_view_aliasing -// CHECK-NOT: beta_0 -// CHECK: %{{.+}} = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (none) data_type = f32 - -// ----- - -func.func @sub_view_aliasing_1(%arg0: memref<1x32x32xf32>, %arg1: memref<1x32x32xf32>, %arg2: memref<32x32xf32>) { - %c1_i64 = arith.constant 1 : i64 - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %sub = memref.subview %alloc[0, 0] [16, 16] [1, 1] : memref<32x32xf32> to memref<16x16xf32, strided<[32, 1]>> - %0 = xsmm.unary.dispatch zero [32, 512, 1, 512] flags = (bcast_scalar) data_type = f32 - xsmm.unary zero(data_type = f32, %0, %cst, %alloc) : (i64, f32, memref<32x32xf32>) -> () - call @test(%sub) : (memref<16x16xf32, strided<[32, 1]>>) -> () - %1 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (none) data_type = f32 - xsmm.brgemm(data_type = f32, %1, %arg0, %arg1, %alloc, %c1_i64) : (i64, memref<1x32x32xf32>, memref<1x32x32xf32>, memref<32x32xf32>, i64) -> () - return -} - -func.func private @test(%arg0 : memref<16x16xf32, strided<[32, 1]>>) - -// CHECK-LABEL: sub_view_aliasing_1 -// CHECK-NOT: beta_0 -// CHECK: %{{.+}} = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (none) data_type = f32 - -// ----- - -func.func @may_have_mem_effects(%arg0: memref<1x32x32xf32>, %arg1: memref<1x32x32xf32>, %arg2: memref<32x32xf32>) { - %c1_i64 = arith.constant 1 : i64 - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %0 = xsmm.unary.dispatch zero [32, 512, 1, 512] flags = (bcast_scalar) data_type = f32 - xsmm.unary zero(data_type = f32, %0, %cst, %alloc) : (i64, f32, memref<32x32xf32>) -> () - call @test(%alloc) : (memref<32x32xf32>) -> () - %1 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (none) data_type = f32 - xsmm.brgemm(data_type = f32, %1, %arg0, %arg1, %alloc, %c1_i64) : (i64, memref<1x32x32xf32>, memref<1x32x32xf32>, memref<32x32xf32>, i64) -> () - return -} - -func.func private @test(%arg0 : memref<32x32xf32>) - -// CHECK-LABEL: may_have_mem_effects -// CHECK-NOT: beta_0 -// CHECK: %{{.+}} = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (none) data_type = f32 - -// ----- - -func.func @only_read_effect(%arg0: memref<1x32x32xf32>, %arg1: memref<1x32x32xf32>, %arg2: 
memref<32x32xf32>) { - %c1_i64 = arith.constant 1 : i64 - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %0 = xsmm.unary.dispatch zero [32, 512, 1, 512] flags = (bcast_scalar) data_type = f32 - xsmm.unary zero(data_type = f32, %0, %cst, %alloc) : (i64, f32, memref<32x32xf32>) -> () - memref.copy %alloc, %arg2 : memref<32x32xf32> to memref<32x32xf32> - %1 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (none) data_type = f32 - xsmm.brgemm(data_type = f32, %1, %arg0, %arg1, %alloc, %c1_i64) : (i64, memref<1x32x32xf32>, memref<1x32x32xf32>, memref<32x32xf32>, i64) -> () - return -} - -// CHECK-LABEL: only_read_effect -// CHECK: %{{.+}} = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (beta_0) data_type = f32 - -// ----- - -func.func @read_write_effect(%arg0: memref<1x32x32xf32>, %arg1: memref<1x32x32xf32>, %arg2: memref<32x32xf32>) { - %c1_i64 = arith.constant 1 : i64 - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %0 = xsmm.unary.dispatch zero [32, 512, 1, 512] flags = (bcast_scalar) data_type = f32 - xsmm.unary zero(data_type = f32, %0, %cst, %alloc) : (i64, f32, memref<32x32xf32>) -> () - memref.copy %alloc, %alloc : memref<32x32xf32> to memref<32x32xf32> - %1 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (none) data_type = f32 - xsmm.brgemm(data_type = f32, %1, %arg0, %arg1, %alloc, %c1_i64) : (i64, memref<1x32x32xf32>, memref<1x32x32xf32>, memref<32x32xf32>, i64) -> () - return -} - -// CHECK-LABEL: read_write_effect -// CHECK-NOT: beta_0 -// CHECK: %{{.+}} = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (none) data_type = f32 - -// ----- - -func.func @free_effect(%arg0: memref<1x32x32xf32>, %arg1: memref<1x32x32xf32>, %arg2: memref<32x32xf32>) { - %c1_i64 = arith.constant 1 : i64 - %cst = arith.constant 0.000000e+00 : f32 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %0 = xsmm.unary.dispatch zero [32, 512, 1, 512] flags = (bcast_scalar) data_type = f32 - xsmm.unary zero(data_type = f32, %0, %cst, %alloc) : (i64, f32, memref<32x32xf32>) -> () - memref.dealloc %alloc : memref<32x32xf32> - %1 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (none) data_type = f32 - xsmm.brgemm(data_type = f32, %1, %arg0, %arg1, %arg2, %c1_i64) : (i64, memref<1x32x32xf32>, memref<1x32x32xf32>, memref<32x32xf32>, i64) -> () - return -} - -// CHECK-LABEL: free_effect -// CHECK-NOT: beta_0 -// CHECK: %{{.+}} = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 1] flags = (none) data_type = f32 - -// ----- - -func.func @zero_flag_fused_brgemm(%arg0: memref<1x32x32xf32>, %arg1: memref<1x32x32xf32>, %arg2: memref<32x32xf32>) { - %cst = arith.constant 0.000000e+00 : f32 - %c32_i64 = arith.constant 32 : i64 - %alloc = memref.alloc() {alignment = 64 : i64} : memref<32x32xf32> - %0 = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = f32 - xsmm.unary zero(data_type = f32, %0, %cst, %alloc) : (i64, f32, memref<32x32xf32>) -> () - - %1 = xsmm.fused_brgemm.dispatch [32, 32, 32, 32, 32, 32, 32, 32] [add, relu] - flags = (none) binary_flags = (none) unary_flags = (none) data_type = f32 - xsmm.fused_brgemm(data_type = f32, %1, %arg0, %arg1, %alloc, %arg2, %c32_i64) : - (i64, memref<1x32x32xf32>, memref<1x32x32xf32>, memref<32x32xf32>, memref<32x32xf32>, i64) -> () - return -} - -// CHECK-LABEL: zero_flag_fused_brgemm -// CHECK: %[[DIS:.+]] = xsmm.fused_brgemm.dispatch [32, 
32, 32, 32, 32, 32, 32, 32][add,relu] -// CHECK-SAME: flags = (beta_0) binary_flags = (none) unary_flags = (none) data_type = f32 -// CHECK: xsmm.fused_brgemm(data_type = f32, %[[DIS]], %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}}, %{{.+}})
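Next, the AMX tile-config tests. A tile kernel can only run after the AMX tile registers have been configured, so every brgemm here is bracketed by "xsmm.IntelAMXtileConfig" setup/reset calls operating on a memref<64xi8> scratch buffer. Reprogramming the tile registers inside the innermost loop is costly; the hoisting pass lifts the setup/reset pair and its alloca out of the scf.for nest into the surrounding scf.parallel body, so each parallel iteration configures the tiles once, runs all its brgemms, and resets once. The CHECK lines below verify exactly that movement.
diff --git a/test/Passes/pass-tileconfig-hoisting-pass.mlir b/test/Passes/pass-tileconfig-hoisting-pass.mlir deleted file mode 100644 index 5e4786476..000000000 --- a/test/Passes/pass-tileconfig-hoisting-pass.mlir +++ /dev/null @@ -1,132 +0,0 @@ -// RUN: tpp-opt %s --intel-amx-tile-config-hoisting-pass | FileCheck %s - -module{ - -memref.global "private" constant @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64} - -func.func @entry(%arg0: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> { - %c2 = arith.constant 2 : index - %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index - %c8 = arith.constant 8 : index - %c0 = arith.constant 0 : index - %c32_i64 = arith.constant 32 : i64 - %0 = memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> - %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16> - %1 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16 - %2 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16 - %3 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16 - scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%c2, %c8) { - scf.for %arg3 = %c0 to %c2 step %c1 { - %10 = arith.addi %arg3, %arg1 : index - %subview = memref.subview %arg0[%10, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - scf.for %arg4 = %c0 to %c8 step %c1 { - %11 = arith.addi %arg4, %arg2 : index - %subview_1 = memref.subview %alloc[%10, %11, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> - %alloca = memref.alloca() : memref<64xi8> - "xsmm.IntelAMXtileConfig"(%1, %alloca) : (i64, memref<64xi8>) -> () - xsmm.brgemm(data_type = bf16, %3, %subview, %0, %subview_1, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> () - "xsmm.IntelAMXtileConfig"(%2, %alloca) : (i64, memref<64xi8>) -> () - } - } - scf.reduce - } - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16> - %4 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16 - %5 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16 - %6 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16 - scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%c2, %c8) { - scf.for %arg3 = %c0 to %c2 step %c1 { - %10 = arith.addi %arg3, %arg1 : index - %subview = memref.subview %alloc[%10, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - scf.for %arg4 = %c0 to %c8 step %c1 { - %11 = arith.addi %arg4, %arg2 : index - %subview_1 = memref.subview %alloc_0[%10, %11, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] :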
memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> - %alloca = memref.alloca() : memref<64xi8> - "xsmm.IntelAMXtileConfig"(%4, %alloca) : (i64, memref<64xi8>) -> () - xsmm.brgemm(data_type = bf16, %6, %subview, %0, %subview_1, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> () - "xsmm.IntelAMXtileConfig"(%5, %alloca) : (i64, memref<64xi8>) -> () - } - } - scf.reduce - } - %7 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16 - %8 = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16 - %9 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16 - scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%c2, %c8) { - scf.for %arg3 = %c0 to %c2 step %c1 { - %10 = arith.addi %arg3, %arg1 : index - %subview = memref.subview %alloc_0[%10, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - scf.for %arg4 = %c0 to %c8 step %c1 { - %11 = arith.addi %arg4, %arg2 : index - %subview_1 = memref.subview %alloc[%10, %11, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> - %alloca = memref.alloca() : memref<64xi8> - "xsmm.IntelAMXtileConfig"(%7, %alloca) : (i64, memref<64xi8>) -> () - xsmm.brgemm(data_type = bf16, %9, %subview, %0, %subview_1, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> () - "xsmm.IntelAMXtileConfig"(%8, %alloca) : (i64, memref<64xi8>) -> () - } - } - scf.reduce - } - memref.dealloc %alloc_0 : memref<8x32x32x32xbf16> - return %alloc : memref<8x32x32x32xbf16> -} -} - -// CHECK-LABEL: func.func @entry( -// CHECK: %[[ARG0:.*]]: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> { -// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index -// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index -// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[c32_i64:.*]] = arith.constant 32 : i64 -// CHECK: %[[dispatch1:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16 -// CHECK: %[[dispatch2:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16 -// CHECK: %[[brgemmdispatch:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16 -// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) { -// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8> -// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch1]], %[[alloca]]) : (i64, memref<64xi8>) -> () -// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] { -// CHECK: %[[temp10:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index -// CHECK: scf.for %[[ARG4:.*]] = %c0 to %c8 step %c1 { -// CHECK: %[[temp11:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index -// CHECK: xsmm.brgemm(data_type = bf16, 
%[[brgemmdispatch]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]]) -// CHECK: } -// CHECK: } -// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch2]], %[[alloca]]) : (i64, memref<64xi8>) -> () -// CHECK: scf.reduce -// CHECK: } -// CHECK: %[[dispatch3:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16 -// CHECK: %[[dispatch4:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16 -// CHECK: %[[brgemmdispatch2:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16 -// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) { -// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8> -// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch3]], %[[alloca]]) : (i64, memref<64xi8>) -> () -// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] { -// CHECK: %[[temp10:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index -// CHECK: scf.for %[[ARG4:.*]] = %c0 to %c8 step %c1 { -// CHECK: %[[temp11:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index -// CHECK: xsmm.brgemm(data_type = bf16, %[[brgemmdispatch2]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]]) -// CHECK: } -// CHECK: } -// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch4]], %[[alloca]]) : (i64, memref<64xi8>) -> () -// CHECK: scf.reduce -// CHECK: } -// CHECK: %[[dispatch5:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, vnni_b, beta_0) data_type = bf16 -// CHECK: %[[dispatch6:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_setup_tileconfig, vnni_b, beta_0) data_type = bf16 -// CHECK: %[[brgemmdispatch3:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (no_reset_tileconfig, no_setup_tileconfig, vnni_b, beta_0) data_type = bf16 -// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) { -// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8> -// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch5]], %[[alloca]]) : (i64, memref<64xi8>) -> () -// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] { -// CHECK: %[[temp10:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index -// CHECK: scf.for %[[ARG4:.*]] = %c0 to %c8 step %c1 { -// CHECK: %[[temp11:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index -// CHECK: xsmm.brgemm(data_type = bf16, %[[brgemmdispatch3]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]]) -// CHECK: } -// CHECK: } -// CHECK: "xsmm.IntelAMXtileConfig"(%[[dispatch6]], %[[alloca]]) : (i64, memref<64xi8>) -> () -// CHECK: scf.reduce -// CHECK: } -
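The insertion pass is the producer of the IR the hoisting test above consumes. Starting from a plain AMX brgemm dispatch, it derives three dispatches, a setup-only tile config (no_reset_tileconfig), a reset-only one (no_setup_tileconfig), and the compute kernel carrying both no_* flags, then wraps every brgemm invocation in the matching "xsmm.IntelAMXtileConfig" calls on a memref<64xi8> scratch buffer. At this stage the config ops still sit in the innermost loop; moving them out is deliberately left to the separate hoisting pass above.
diff --git a/test/Passes/pass-tileconfig-insertion.mlir b/test/Passes/pass-tileconfig-insertion.mlir deleted file mode 100644 index b95c31bb7..000000000 --- a/test/Passes/pass-tileconfig-insertion.mlir +++ /dev/null @@ -1,113 +0,0 @@ -// RUN: tpp-opt %s --intel-amx-tile-config-insertion-pass | FileCheck %s - -module { - memref.global "private" constant @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> = dense<1.000000e+00> {alignment = 64 : i64} - func.func @entry(%arg0: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> { - %c1 = arith.constant 1 : index - %c32 = arith.constant 32 : index - %c8 = arith.constant 8 : index - %c0 = arith.constant 0 : index - %c32_i64 = arith.constant 32 : i64 - %0 =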
memref.get_global @__constant_32x16x32x2xbf16 : memref<32x16x32x2xbf16> - %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16> - %1 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16 - %c0_0 = arith.constant 0 : index - %c2 = arith.constant 2 : index - %c8_1 = arith.constant 8 : index - %2 = arith.muli %c1, %c2 : index - %3 = arith.muli %c1, %c8_1 : index - scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%2, %3) { - scf.for %arg3 = %c0_0 to %2 step %c1 { - scf.for %arg4 = %c0_0 to %3 step %c1 { - %8 = arith.addi %arg3, %arg1 : index - %9 = arith.addi %arg4, %arg2 : index - %subview = memref.subview %alloc[%8, %9, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> - %subview_9 = memref.subview %arg0[%8, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - xsmm.brgemm(data_type = bf16, %1, %subview_9, %0, %subview, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> () - } - } - scf.reduce - } - %alloc_2 = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xbf16> - %c0_3 = arith.constant 0 : index - %c2_4 = arith.constant 2 : index - %c8_5 = arith.constant 8 : index - %4 = arith.muli %c1, %c2_4 : index - %5 = arith.muli %c1, %c8_5 : index - scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%4, %5) { - scf.for %arg3 = %c0_3 to %4 step %c1 { - scf.for %arg4 = %c0_3 to %5 step %c1 { - %8 = arith.addi %arg3, %arg1 : index - %9 = arith.addi %arg4, %arg2 : index - %subview = memref.subview %alloc_2[%8, %9, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> - %subview_9 = memref.subview %alloc[%8, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - xsmm.brgemm(data_type = bf16, %1, %subview_9, %0, %subview, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> () - } - } - scf.reduce - } - %c0_6 = arith.constant 0 : index - %c2_7 = arith.constant 2 : index - %c8_8 = arith.constant 8 : index - %6 = arith.muli %c1, %c2_7 : index - %7 = arith.muli %c1, %c8_8 : index - scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c32) step (%6, %7) { - scf.for %arg3 = %c0_6 to %6 step %c1 { - scf.for %arg4 = %c0_6 to %7 step %c1 { - %8 = arith.addi %arg3, %arg1 : index - %9 = arith.addi %arg4, %arg2 : index - %subview = memref.subview %alloc[%8, %9, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>> - %subview_9 = memref.subview %alloc_2[%8, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xbf16> to memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>> - xsmm.brgemm(data_type = bf16, %1, %subview_9, %0, %subview, %c32_i64) : (i64, memref<32x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<32x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> () - } - } - scf.reduce - } - memref.dealloc %alloc_2 : memref<8x32x32x32xbf16> - return %alloc : memref<8x32x32x32xbf16> - } -} - -// CHECK:func.func @entry(%[[ARG0:.*]]: memref<8x32x32x32xbf16>) -> memref<8x32x32x32xbf16> { -// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[c1:.*]] = 
arith.constant 1 : index -// CHECK-DAG: %[[c32:.*]] = arith.constant 32 : index -// CHECK-DAG: %[[c8:.*]] = arith.constant 8 : index -// CHECK-DAG: %[[c0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[c32_i64:.*]] = arith.constant 32 : i64 -// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) { -// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] { -// CHECK: scf.for %[[ARG4:.*]] = %[[c0]] to %[[c8]] step %[[c1]] { -// CHECK: %[[temp1:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index -// CHECK: %[[temp2:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index -// CHECK: %[[temp3:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0, no_reset_tileconfig) data_type = bf16 -// CHECK: %[[temp4:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0, no_setup_tileconfig) data_type = bf16 -// CHECK: %[[temp5:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0, no_reset_tileconfig, no_setup_tileconfig) data_type = bf16 -// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8> -// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp3]], %[[alloca]]) : (i64, memref<64xi8>) -> () -// CHECK: xsmm.brgemm(data_type = bf16, %[[temp5]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]]) -// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp4]], %[[alloca]]) : (i64, memref<64xi8>) -> () -// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) { -// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] { -// CHECK: scf.for %[[ARG4:.*]] = %[[c0]] to %[[c8]] step %[[c1]] { -// CHECK: %[[temp1:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index -// CHECK: %[[temp2:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index -// CHECK: %[[temp3:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0, no_reset_tileconfig) data_type = bf16 -// CHECK: %[[temp4:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0, no_setup_tileconfig) data_type = bf16 -// CHECK: %[[temp5:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0, no_reset_tileconfig, no_setup_tileconfig) data_type = bf16 -// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8> -// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp3]], %[[alloca]]) : (i64, memref<64xi8>) -> () -// CHECK: xsmm.brgemm(data_type = bf16, %[[temp5]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]]) -// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp4]], %[[alloca]]) : (i64, memref<64xi8>) -> () -// CHECK: scf.parallel (%[[ARG1:.*]], %[[ARG2:.*]]) = (%[[c0]], %[[c0]]) to (%[[c8]], %[[c32]]) step (%[[c2]], %[[c8]]) { -// CHECK: scf.for %[[ARG3:.*]] = %[[c0]] to %[[c2]] step %[[c1]] { -// CHECK: scf.for %[[ARG4:.*]] = %[[c0]] to %[[c8]] step %[[c1]] { -// CHECK: %[[temp1:.*]] = arith.addi %[[ARG3]], %[[ARG1]] : index -// CHECK: %[[temp2:.*]] = arith.addi %[[ARG4]], %[[ARG2]] : index -// CHECK: %[[temp3:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0, no_reset_tileconfig) data_type = bf16 -// CHECK: %[[temp4:.*]] = xsmm.IntelAMXtileConfig.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0, no_setup_tileconfig) data_type = bf16 -// CHECK: %[[temp5:.*]] = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0, no_reset_tileconfig, no_setup_tileconfig) data_type 
= bf16 -// CHECK: %[[alloca:.*]] = memref.alloca() : memref<64xi8> -// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp3]], %[[alloca]]) : (i64, memref<64xi8>) -> () -// CHECK: xsmm.brgemm(data_type = bf16, %[[temp5]], %{{.*}}, %{{.*}}, %{{.*}}, %[[c32_i64]]) -// CHECK: "xsmm.IntelAMXtileConfig"(%[[temp4]], %[[alloca]]) : (i64, memref<64xi8>) -> ()
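Finally, operation fusion: when a brgemm, a binary add (the bias) and a unary relu all update the same accumulator subview, --combine-xsmm-op-optimization collapses the three dispatch/invoke pairs into a single xsmm.fused_brgemm whose dispatch records the fused op kinds ([add,relu]) together with their binary and unary flags. The fused libxsmm kernel only supports certain bias layouts, so the tests check that a column-broadcast bias on input 0 (bcast_col_in0) is fused, in both f32 and bf16, while the bcast_col_in1 and full-matrix (none) variants are left unfused.
diff --git a/test/Passes/xsmm-combine.mlir b/test/Passes/xsmm-combine.mlir deleted file mode 100644 index 8edf7a2bc..000000000 --- a/test/Passes/xsmm-combine.mlir +++ /dev/null @@ -1,308 +0,0 @@ -//RUN: tpp-opt -verify-xsmm-calls --combine-xsmm-op-optimization -verify-xsmm-calls %s --split-input-file | FileCheck %s - -memref.global "private" constant @__constant_4x32x32xf32 : memref<4x32x32xf32> = dense<1.000000e+00> {alignment = 128 : i64} -memref.global "private" constant @__constant_8x32x32xf32 : memref<8x32x32xf32> = dense<1.000000e+00> {alignment = 128 : i64} -memref.global "private" constant @__constant_32xf32: memref<32xf32, strided<[32], offset:?>> = dense<1.000000e+00> {alignment = 128 : i64} - -func.func @bcast_col_in0_on_binary_add(%arg0: memref<256x128xf32>) -> memref<256x512xf32> { - %c0 = arith.constant 0 : index - %c8 = arith.constant 8 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c4_i64 = arith.constant 4 : i64 - %c8_i64 = arith.constant 8 : i64 - %cst = arith.constant 0.000000e+00 : f32 - %0 = memref.get_global @__constant_4x32x32xf32 : memref<4x32x32xf32> - %1 = memref.get_global @__constant_8x32x32xf32 : memref<8x32x32xf32> - %2 = memref.get_global @__constant_32xf32 : memref<32xf32, strided<[32], offset:?>> - %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xf32> - %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = f32 - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xf32> - %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0) data_type = f32 - %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in0) data_type = f32 - %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = f32 - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xf32> - scf.parallel (%arg3, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) { - %subview = memref.subview %alloc_0[%arg3, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - %subview_2 = memref.subview %alloc[%arg3, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xf32> to memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>> - xsmm.brgemm(data_type = f32, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<4x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64) -> () - xsmm.binary add(data_type = f32, %5, %2, %subview, %subview) : (i64, memref<32xf32, strided<[32], offset:?>>, memref<32x32xf32, strided<[32,1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> () - xsmm.unary relu(data_type = f32, %6, %subview, %subview) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> () - scf.reduce - } - return %alloc_1 : memref<256x512xf32> - } - -// CHECK-LABEL: func.func @bcast_col_in0_on_binary_add( -// CHECK: %[[ARG0:.*]]: memref<256x128xf32>) -> memref<256x512xf32> { -// CHECK: %[[BIAS:.*]] = memref.get_global @__constant_32xf32 : memref<32xf32, strided<[32], offset: ?>> -// CHECK-NOT: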
xsmm.brgemm.dispatch -// CHECK-NOT: xsmm.unary.dispatch -// CHECK-NOT: xsmm.binary.dispatch -// CHECK: %[[DISPATCH:.*]] = xsmm.fused_brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024][add,relu] flags = (beta_0) binary_flags = (bcast_col_in0) unary_flags = (none) data_type = f32 -// CHECK-NOT: xsmm.brgemm( -// CHECK-NOT: xsmm.binary add -// CHECK-NOT: xsmm.unary relu -// CHECK: xsmm.fused_brgemm(data_type = f32, %[[DISPATCH]], %{{.*}}, %{{.*}}, %{{.*}}, %[[BIAS]], %{{.*}}) - - -// ----- - -memref.global "private" constant @__constant_4x32x32xf32 : memref<4x32x32xf32> = dense<1.000000e+00> {alignment = 128 : i64} -memref.global "private" constant @__constant_8x32x32xf32 : memref<8x32x32xf32> = dense<1.000000e+00> {alignment = 128 : i64} -memref.global "private" constant @__constant_32xf32: memref<32xf32, strided<[32], offset:?>> = dense<1.000000e+00> {alignment = 128 : i64} - -func.func @bcast_col_in1_on_binary_add(%arg0: memref<256x128xf32>) -> memref<256x512xf32> { - %c0 = arith.constant 0 : index - %c8 = arith.constant 8 : index - %c4 = arith.constant 4 : index - %c1 = arith.constant 1 : index - %c16 = arith.constant 16 : index - %c4_i64 = arith.constant 4 : i64 - %c8_i64 = arith.constant 8 : i64 - %cst = arith.constant 0.000000e+00 : f32 - %0 = memref.get_global @__constant_4x32x32xf32 : memref<4x32x32xf32> - %1 = memref.get_global @__constant_8x32x32xf32 : memref<8x32x32xf32> - %2 = memref.get_global @__constant_32xf32 : memref<32xf32, strided<[32], offset:?>> - %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xf32> - %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = f32 - %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xf32> - %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0) data_type = f32 - %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in1) data_type = f32 - %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = f32 - %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xf32> - scf.parallel (%arg3, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) { - %subview = memref.subview %alloc_0[%arg3, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>> - %subview_2 = memref.subview %alloc[%arg3, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xf32> to memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>> - xsmm.brgemm(data_type = f32, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<4x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64) -> () - xsmm.binary add(data_type = f32, %5, %subview, %2, %subview) : (i64, memref<32x32xf32, strided<[32,1], offset: ?>>, memref<32xf32, strided<[32], offset:?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> () - xsmm.unary relu(data_type = f32, %6, %subview, %subview) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> () - scf.reduce - } - return %alloc_1 : memref<256x512xf32> - } - -// CHECK-LABEL: func.func @bcast_col_in1_on_binary_add( -// CHECK: %[[ARG0:.*]]: memref<256x128xf32>) -> memref<256x512xf32> { -// CHECK: %[[BIAS:.*]] = memref.get_global @__constant_32xf32 : memref<32xf32, strided<[32], offset: ?>> -// CHECK-NOT: %[[DISPATCH:.*]] = xsmm.fused_brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024][add,relu] flags = (beta_0) binary_flags = (bcast_col_in1) unary_flags = (none) data_type = f32 -// 
-
-// -----
-
-memref.global "private" constant @__constant_4x32x32xf32 : memref<4x32x32xf32> = dense<1.000000e+00> {alignment = 128 : i64}
-memref.global "private" constant @__constant_8x32x32xf32 : memref<8x32x32xf32> = dense<1.000000e+00> {alignment = 128 : i64}
-memref.global "private" constant @__constant_32x32xf32 : memref<32x32xf32> = dense<1.000000e+00> {alignment = 128 : i64}
-
-func.func @none_on_binary_add(%arg0: memref<256x128xf32>) -> memref<256x512xf32> {
-  %c0 = arith.constant 0 : index
-  %c8 = arith.constant 8 : index
-  %c4 = arith.constant 4 : index
-  %c1 = arith.constant 1 : index
-  %c16 = arith.constant 16 : index
-  %c4_i64 = arith.constant 4 : i64
-  %c8_i64 = arith.constant 8 : i64
-  %cst = arith.constant 0.000000e+00 : f32
-  %0 = memref.get_global @__constant_4x32x32xf32 : memref<4x32x32xf32>
-  %1 = memref.get_global @__constant_8x32x32xf32 : memref<8x32x32xf32>
-  %2 = memref.get_global @__constant_32x32xf32 : memref<32x32xf32>
-  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xf32>
-  %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = f32
-  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xf32>
-  %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (beta_0) data_type = f32
-  %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (none) data_type = f32
-  %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = f32
-  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xf32>
-  scf.parallel (%arg3, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) {
-    %subview = memref.subview %alloc_0[%arg3, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
-    %subview_2 = memref.subview %alloc[%arg3, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xf32> to memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>
-    xsmm.brgemm(data_type = f32, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<4x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64) -> ()
-    xsmm.binary add(data_type = f32, %5, %subview, %2, %subview) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
-    xsmm.unary relu(data_type = f32, %6, %subview, %subview) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
-    scf.reduce
-  }
-  return %alloc_1 : memref<256x512xf32>
-}
-
-// CHECK-LABEL: func.func @none_on_binary_add(
-// CHECK: %[[ARG0:.*]]: memref<256x128xf32>) -> memref<256x512xf32> {
-// CHECK: %[[BIAS:.*]] = memref.get_global @__constant_32x32xf32 : memref<32x32xf32>
-// CHECK-NOT: %[[DISPATCH:.*]] = xsmm.fused_brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024][add,relu] flags = (beta_0) binary_flags = (none) unary_flags = (none) data_type = f32
-// CHECK-NOT: xsmm.fused_brgemm(data_type = f32, %[[DISPATCH]], %{{.*}}, %{{.*}}, %{{.*}}, %[[BIAS]], %{{.*}})
-
-// -----
-
-memref.global "private" constant @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64}
-memref.global "private" constant @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64}
-memref.global "private" constant @__constant_32xbf16 : memref<32xbf16, strided<[32], offset: ?>> = dense<1.000000e+00> {alignment = 128 : i64}
-
-// Bcast_col_in0 flag set on binary add
-func.func @bcast_col_in0_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memref<256x512xbf16> {
-  %c0 = arith.constant 0 : index
-  %c8 = arith.constant 8 : index
-  %c4 = arith.constant 4 : index
-  %c1 = arith.constant 1 : index
-  %c16 = arith.constant 16 : index
-  %c4_i64 = arith.constant 4 : i64
-  %c8_i64 = arith.constant 8 : i64
-  %cst = arith.constant 0.000000e+00 : bf16
-  %0 = memref.get_global @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16>
-  %1 = memref.get_global @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16>
-  %2 = memref.get_global @__constant_32xbf16 : memref<32xbf16, strided<[32], offset: ?>>
-  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xbf16>
-  %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = bf16
-  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xbf16>
-  %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16
-  %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in0) data_type = bf16
-  %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = bf16
-  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xbf16>
-  scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) {
-    %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
-    %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xbf16> to memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
-    xsmm.brgemm(data_type = bf16, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<4x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
-    xsmm.binary add(data_type = bf16, %5, %2, %subview, %subview) : (i64, memref<32xbf16, strided<[32], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> ()
-    xsmm.unary relu(data_type = bf16, %6, %subview, %subview) : (i64, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> ()
-    scf.reduce
-  }
-  return %alloc_1 : memref<256x512xbf16>
-}
-
-// CHECK-LABEL: func.func @bcast_col_in0_on_binary_add_bf16(
-// CHECK: %[[ARG0:.*]]: memref<256x128xbf16>) -> memref<256x512xbf16> {
-// CHECK: %[[BIAS:.*]] = memref.get_global @__constant_32xbf16 : memref<32xbf16, strided<[32], offset: ?>>
-// CHECK: %[[DISPATCH:.*]] = xsmm.fused_brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024][add,relu] flags = (vnni_b, beta_0) binary_flags = (bcast_col_in0) unary_flags = (none) data_type = bf16
-// CHECK: xsmm.fused_brgemm(data_type = bf16, %[[DISPATCH]], %{{.*}}, %{{.*}}, %{{.*}}, %[[BIAS]], %{{.*}})
-
-// -----
-
-memref.global "private" constant @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64}
-memref.global "private" constant @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64}
-memref.global "private" constant @__constant_32xbf16 : memref<32xbf16, strided<[32], offset: ?>> = dense<1.000000e+00> {alignment = 128 : i64}
-
-// Bcast_col_in1 flag set on binary add
-func.func @bcast_col_in1_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memref<256x512xbf16> {
-  %c0 = arith.constant 0 : index
-  %c8 = arith.constant 8 : index
-  %c4 = arith.constant 4 : index
-  %c1 = arith.constant 1 : index
-  %c16 = arith.constant 16 : index
-  %c4_i64 = arith.constant 4 : i64
-  %c8_i64 = arith.constant 8 : i64
-  %cst = arith.constant 0.000000e+00 : bf16
-  %0 = memref.get_global @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16>
-  %1 = memref.get_global @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16>
-  %2 = memref.get_global @__constant_32xbf16 : memref<32xbf16, strided<[32], offset: ?>>
-  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xbf16>
-  %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = bf16
-  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xbf16>
-  %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16
-  %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in1) data_type = bf16
-  %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = bf16
-  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xbf16>
-  scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) {
-    %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
-    %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xbf16> to memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
-    xsmm.brgemm(data_type = bf16, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<4x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
-    xsmm.binary add(data_type = bf16, %5, %subview, %2, %subview) : (i64, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32xbf16, strided<[32], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> ()
-    xsmm.unary relu(data_type = bf16, %6, %subview, %subview) : (i64, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> ()
-    scf.reduce
-  }
-  return %alloc_1 : memref<256x512xbf16>
-}
-
-// CHECK-LABEL: func.func @bcast_col_in1_on_binary_add_bf16(
-// CHECK: %[[ARG0:.*]]: memref<256x128xbf16>) -> memref<256x512xbf16> {
-// CHECK: %[[BIAS:.*]] = memref.get_global @__constant_32xbf16 : memref<32xbf16, strided<[32], offset: ?>>
-// CHECK-NOT: %[[DISPATCH:.*]] = xsmm.fused_brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024][add,relu] flags = (vnni_b, beta_0) binary_flags = (bcast_col_in1) unary_flags = (none) data_type = bf16
-// CHECK-NOT: xsmm.fused_brgemm(data_type = bf16, %[[DISPATCH]], %{{.*}}, %{{.*}}, %{{.*}}, %[[BIAS]], %{{.*}})
-
-// -----
-
-memref.global "private" constant @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64}
-memref.global "private" constant @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16> = dense<1.000000e+00> {alignment = 128 : i64}
-memref.global "private" constant @__constant_32x32xbf16 : memref<32x32xbf16> = dense<1.000000e+00> {alignment = 128 : i64}
-
-// None flag set on binary add
-func.func @none_on_binary_add_bf16(%arg0: memref<256x128xbf16>) -> memref<256x512xbf16> {
-  %c0 = arith.constant 0 : index
-  %c8 = arith.constant 8 : index
-  %c4 = arith.constant 4 : index
-  %c1 = arith.constant 1 : index
-  %c16 = arith.constant 16 : index
-  %c4_i64 = arith.constant 4 : i64
-  %c8_i64 = arith.constant 8 : i64
-  %cst = arith.constant 0.000000e+00 : bf16
-  %0 = memref.get_global @__constant_4x16x32x2xbf16 : memref<4x16x32x2xbf16>
-  %1 = memref.get_global @__constant_8x16x32x2xbf16 : memref<8x16x32x2xbf16>
-  %2 = memref.get_global @__constant_32x32xbf16 : memref<32x32xbf16>
-  %alloc = memref.alloc() {alignment = 64 : i64} : memref<8x4x32x32xbf16>
-  %3 = xsmm.unary.dispatch identity [32, 32, 128, 32] flags = (none) data_type = bf16
-  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x8x32x32xbf16>
-  %4 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (vnni_b, beta_0) data_type = bf16
-  %5 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (none) data_type = bf16
-  %6 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = bf16
-  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<256x512xbf16>
-  scf.parallel (%arg1, %arg2) = (%c0, %c0) to (%c8, %c8) step (%c1, %c1) {
-    %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x8x32x32xbf16> to memref<32x32xbf16, strided<[32, 1], offset: ?>>
-    %subview_2 = memref.subview %alloc[%arg1, 0, 0, 0] [1, 4, 32, 32] [1, 1, 1, 1] : memref<8x4x32x32xbf16> to memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>
-    xsmm.brgemm(data_type = bf16, %4, %subview_2, %0, %subview, %c4_i64) : (i64, memref<4x32x32xbf16, strided<[1024, 32, 1], offset: ?>>, memref<4x16x32x2xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>, i64) -> ()
-    xsmm.binary add(data_type = bf16, %5, %subview, %2, %subview) : (i64, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> ()
-    xsmm.unary relu(data_type = bf16, %6, %subview, %subview) : (i64, memref<32x32xbf16, strided<[32, 1], offset: ?>>, memref<32x32xbf16, strided<[32, 1], offset: ?>>) -> ()
-    scf.reduce
-  }
-  return %alloc_1 : memref<256x512xbf16>
-}
-
-// CHECK-LABEL: func.func @none_on_binary_add_bf16(
-// CHECK: %[[ARG0:.*]]: memref<256x128xbf16>) -> memref<256x512xbf16> {
-// CHECK: %[[BIAS:.*]] = memref.get_global @__constant_32x32xbf16 : memref<32x32xbf16>
-// CHECK-NOT: %[[DISPATCH:.*]] = xsmm.fused_brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024][add,relu] flags = (vnni_b, beta_0) binary_flags = (none) unary_flags = (none) data_type = bf16
-// CHECK-NOT: xsmm.fused_brgemm(data_type = bf16, %[[DISPATCH]], %{{.*}}, %{{.*}}, %{{.*}}, %[[BIAS]], %{{.*}})
-
-// -----
-
-memref.global "private" constant @__constant_32x32x32xf32_1 : memref<32x32x32xf32> = dense<1.600000e+00> {alignment = 64 : i64}
-memref.global "private" constant @__constant_32xf32_1 : memref<32xf32> = dense<1.300000e+00> {alignment = 64 : i64}
-memref.global "private" constant @__constant_32x32x32xf32_0 : memref<32x32x32xf32> = dense<1.500000e+00> {alignment = 64 : i64}
-memref.global "private" constant @__constant_32xf32_0 : memref<32xf32> = dense<1.200000e+00> {alignment = 64 : i64}
-memref.global "private" constant @__constant_32x32x32xf32 : memref<32x32x32xf32> = dense<1.400000e+00> {alignment = 64 : i64}
-memref.global "private" constant @__constant_32xf32 : memref<32xf32> = dense<1.100000e+00> {alignment = 64 : i64}
-
-func.func @forward(%arg0: memref<256x1024xf32>) -> memref<256x1024xf32> {
-  %c32_i64 = arith.constant 32 : i64
-  %cst = arith.constant 0.000000e+00 : f32
-  %0 = memref.get_global @__constant_32xf32 : memref<32xf32>
-  %1 = memref.get_global @__constant_32x32x32xf32 : memref<32x32x32xf32>
-  %2 = memref.get_global @__constant_32x32x32xf32_0 : memref<32x32x32xf32>
-  %3 = memref.get_global @__constant_32xf32_0 : memref<32xf32>
-  %4 = memref.get_global @__constant_32x32x32xf32_1 : memref<32x32x32xf32>
-  %alloc = memref.alloc() {alignment = 64 : i64} : memref<256x1024xf32>
-  %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xf32>
-  %alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<8x32x32x32xf32>
-  scf.forall (%arg1, %arg2) in (8, 32) {
-    %subview = memref.subview %alloc_0[%arg1, %arg2, 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
-    %5 = xsmm.unary.dispatch zero [32, 32, 1, 32] flags = (bcast_scalar) data_type = f32
-    xsmm.unary zero(data_type = f32, %5, %cst, %subview) : (i64, f32, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
-    %subview_3 = memref.subview %alloc_1[%arg1, 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>
-    %6 = xsmm.brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024] flags = (none) data_type = f32
-    xsmm.brgemm(data_type = f32, %6, %subview_3, %4, %subview, %c32_i64) : (i64, memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, i64) -> ()
-    %7 = xsmm.binary.dispatch add [32, 32, 32, 32, 32] flags = (bcast_col_in0) data_type = f32
-    xsmm.binary add(data_type = f32, %7, %3, %subview, %subview) : (i64, memref<32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
-    %8 = xsmm.unary.dispatch relu [32, 32, 32, 32] flags = (none) data_type = f32
-    xsmm.unary relu(data_type = f32, %8, %subview, %subview) : (i64, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32x32xf32, strided<[32, 1], offset: ?>>) -> ()
-  }
-  return %alloc : memref<256x1024xf32>
-}
-
-// CHECK-LABEL: func.func @forward(
-// CHECK: %[[ARG0:.*]]: memref<256x1024xf32>) -> memref<256x1024xf32> {
-// CHECK-DAG: %[[c32_i64:.*]] = arith.constant 32 : i64
-// CHECK: scf.forall (%[[arg1:.*]], %[[arg2:.*]]) in (8, 32) {
-// CHECK: %[[subview:.*]] = memref.subview %{{.*}}[%[[arg1]], %[[arg2]], 0, 0] [1, 1, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32xf32, strided<[32, 1], offset: ?>>
-// CHECK: %[[subview_2:.*]] = memref.subview %{{.*}}[%[[arg1]], 0, 0, 0] [1, 32, 32, 32] [1, 1, 1, 1] : memref<8x32x32x32xf32> to memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>
-// CHECK: %[[temp2:.*]] = xsmm.fused_brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024][add,relu] flags = (beta_0) binary_flags = (bcast_col_in0) unary_flags = (none) data_type = f32
-// CHECK: xsmm.fused_brgemm(data_type = f32, %[[temp2]], %[[subview_2]], %{{.*}}, %[[subview]], %{{.*}}, %[[c32_i64]]) : (i64, memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32x32xf32>, memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32xf32>, i64) -> ()
-// CHECK: }
-// CHECK: return %{{.*}} : memref<256x1024xf32>
-
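For reference, a minimal sketch of the rewrite these tests pin down; the SSA names %d, %a, %b and %c below are illustrative, not taken from the tests. The pass collapses a brgemm / binary-add / unary-relu dispatch-and-invoke chain into one fused call only when the add broadcasts a rank-1 bias over its first input (binary_flags = (bcast_col_in0)); the bcast_col_in1 and (none) variants above are required to stay unfused via CHECK-NOT, and a separate zero-initialization (xsmm.unary zero plus brgemm flags = (none)) is absorbed into the fused dispatch as beta_0:

  %d = xsmm.fused_brgemm.dispatch [32, 32, 32, 32, 32, 32, 1024, 1024][add,relu]
         flags = (beta_0) binary_flags = (bcast_col_in0) unary_flags = (none) data_type = f32
  xsmm.fused_brgemm(data_type = f32, %d, %a, %b, %c, %bias, %c32_i64)
    : (i64, memref<32x32x32xf32, strided<[1024, 32, 1], offset: ?>>, memref<32x32x32xf32>,
       memref<32x32xf32, strided<[32, 1], offset: ?>>, memref<32xf32>, i64) -> ()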