diff --git a/build_tools/ci/cpu_comparison/matmul_template/matmul_trunci_MxK_KxN.mlir b/build_tools/ci/cpu_comparison/matmul_template/matmul_trunci_MxK_KxN.mlir new file mode 100644 index 000000000..b6fe8e361 --- /dev/null +++ b/build_tools/ci/cpu_comparison/matmul_template/matmul_trunci_MxK_KxN.mlir @@ -0,0 +1,13 @@ +// input ${M}x${K}x${TYPE1} +// input ${K}x${N}x${TYPE1} + +func.func @matmul_trunci(%arg0: tensor<${M}x${K}x${TYPE1}>, %arg1: tensor<${K}x${N}x${TYPE1}>) -> tensor<${M}x${N}x${TYPE1}> +{ + %cst = arith.constant ${ZERO} : ${TYPE2} + %0 = tensor.empty() : tensor<${M}x${N}x${TYPE2}> + %1 = linalg.fill ins(%cst : ${TYPE2}) outs(%0 : tensor<${M}x${N}x${TYPE2}>) -> tensor<${M}x${N}x${TYPE2}> + %2 = linalg.matmul ins(%arg0, %arg1 : tensor<${M}x${K}x${TYPE1}>, tensor<${K}x${N}x${TYPE1}>) + outs(%1: tensor<${M}x${N}x${TYPE2}>) -> tensor<${M}x${N}x${TYPE2}> + %3 = arith.trunci %2 : tensor<${M}x${N}x${TYPE2}> to tensor<${M}x${N}x${TYPE1}> + return %3: tensor<${M}x${N}x${TYPE1}> +} diff --git a/build_tools/ci/cpu_comparison/run.py b/build_tools/ci/cpu_comparison/run.py index ff0d64bf6..48b1de540 100755 --- a/build_tools/ci/cpu_comparison/run.py +++ b/build_tools/ci/cpu_comparison/run.py @@ -171,6 +171,8 @@ def run(self, config): # does not). if self.use_chess and not config.vitis_dir: return False + if self.use_ukernel and not config.vitis_dir: + return False # If use_chess=0, and config has not provided a valid # path to peano, then bail: a path to peano must be provided. @@ -655,11 +657,70 @@ def _execute(self, config): input_args = generate_inputs( filename, self.get_dir(config), 1, {1: self.lhs, 2: self.rhs} ) - """ - Currently without function outlining, we run out of program memory. - """ - self.add_aie_compilation_flags( - ["--iree-amdaie-enable-function-outlining=balanced"] + aie_vs_baseline( + config=config, + aie_compilation_flags=self.aie_compilation_flags, + test_file=self.get_filename(config), + input_args=input_args, + baseline_value=self.expected_out, + use_ukernel=self.use_ukernel, + tile_pipeline=self.tile_pipeline, + function_name=None, + seed=1, + rtol=0, + atol=0, + lower_to_aie_pipeline=self.lower_to_aie_pipeline, + n_repeats=self.n_repeats, + output_type=get_output_type(self.get_filename(config)), + ) + + return True + + +class MatmulTrunci(BaseMatmul): + """ + A test of the form matmul(A,B) + trunci(C) where A:MxK, B:KxN and C:MxN + """ + + def __init__( + self, + M, + N, + K, + input_type, + acc_type, + lhs, + rhs, + expected_out, + test_params=None, + ): + super().__init__( + name=f"matmul_trunci_{M}_{N}_{K}_{input_type}_{acc_type}", + test_params=test_params, + M=M, + N=N, + K=K, + input_type=input_type, + acc_type=acc_type, + ) + self.labels.append("MatmulTrunci") + + # Assertions on shapes: Check that lhs is MxK, rhs is KxN, and expected_out is MxN + assert lhs.shape == (M, K) + assert rhs.shape == (K, N) + assert expected_out.shape == (M, N) + + self.lhs = lhs + self.rhs = rhs + self.expected_out = expected_out + + def _execute(self, config): + matmul_template_dir = config.file_dir / "matmul_template" + template_name = matmul_template_dir / "matmul_trunci_MxK_KxN.mlir" + self.generate(config, template_name) + filename = self.get_filename(config) + input_args = generate_inputs( + filename, self.get_dir(config), 1, {1: self.lhs, 2: self.rhs} ) aie_vs_baseline( config=config, @@ -1462,6 +1523,73 @@ def __init__(self): self.existing_names = [] self.tests = [] + # Tests Matmul + Trunci. + # Phoenix : Ukernel + Peano. 
+        self.register(
+            MatmulTrunci(
+                256,
+                128,
+                32,
+                "i8",
+                "i32",
+                1 * np.ones([256, 32], dtype=np.int8),
+                1 * np.ones([32, 128], dtype=np.int8),
+                32 * np.ones([256, 128], dtype=np.int8),
+                test_params=TestParams(
+                    tile_pipeline="pack-peel-4-level-tiling",
+                    run_on_target=["npu1_4col"],
+                    aie_compilation_flags=[
+                        "--iree-amdaie-num-rows=4",
+                        "--iree-amdaie-num-cols=4",
+                    ],
+                    use_ukernel=True,
+                ),
+            )
+        )
+        # Phoenix : Vectorization + Peano.
+        self.register(
+            MatmulTrunci(
+                256,
+                128,
+                32,
+                "i8",
+                "i32",
+                1 * np.ones([256, 32], dtype=np.int8),
+                1 * np.ones([32, 128], dtype=np.int8),
+                32 * np.ones([256, 128], dtype=np.int8),
+                test_params=TestParams(
+                    tile_pipeline="pack-peel-4-level-tiling",
+                    run_on_target=["npu1_4col"],
+                    aie_compilation_flags=[
+                        "--iree-amdaie-num-rows=4",
+                        "--iree-amdaie-num-cols=4",
+                    ],
+                ),
+            )
+        )
+        # Strix : Ukernel + Chess.
+        self.register(
+            MatmulTrunci(
+                256,
+                128,
+                32,
+                "i8",
+                "i32",
+                1 * np.ones([256, 32], dtype=np.int8),
+                1 * np.ones([32, 128], dtype=np.int8),
+                32 * np.ones([256, 128], dtype=np.int8),
+                test_params=TestParams(
+                    tile_pipeline="pack-peel-4-level-tiling",
+                    run_on_target=["npu4"],
+                    aie_compilation_flags=[
+                        "--iree-amdaie-num-rows=4",
+                        "--iree-amdaie-num-cols=8",
+                    ],
+                    use_chess=True,
+                    use_ukernel=True,
+                ),
+            )
+        )
         # Matmul with truncf test(s):
         for tile_pipeline in ["pack-peel", "pack-peel-4-level-tiling"]:
             self.register(
diff --git a/compiler/plugins/target/AMD-AIE/aievec/VectorToVectorConversions.cpp b/compiler/plugins/target/AMD-AIE/aievec/VectorToVectorConversions.cpp
index b7e8cf744..2cde4b91e 100644
--- a/compiler/plugins/target/AMD-AIE/aievec/VectorToVectorConversions.cpp
+++ b/compiler/plugins/target/AMD-AIE/aievec/VectorToVectorConversions.cpp
@@ -802,10 +802,11 @@ struct ToMinorIdentityTransferReadPattern
 ///   %1 = arith.truncf %0 : vector<6xf32> to vector<6xbf16>
 ///   %2 = vector.shape_cast %1 : vector<6xbf16> to vector<2x3xbf16>
 // clang-format on
-struct FlattenArithTruncFOpPattern : public OpRewritePattern<arith::TruncFOp> {
-  using OpRewritePattern::OpRewritePattern;
+template <typename TruncOpTy>
+struct FlattenArithTruncOpPattern : public OpRewritePattern<TruncOpTy> {
+  using OpRewritePattern<TruncOpTy>::OpRewritePattern;
 
-  LogicalResult matchAndRewrite(arith::TruncFOp op,
+  LogicalResult matchAndRewrite(TruncOpTy op,
                                 PatternRewriter &rewriter) const override {
     // Get old shape type.
     auto oldShapedType = dyn_cast<ShapedType>(op.getType());
@@ -826,7 +827,7 @@ struct FlattenArithTruncFOpPattern : public OpRewritePattern<arith::TruncFOp> {
     Value newInputVector = rewriter.create<vector::ShapeCastOp>(
         op.getLoc(), newVectorTypeForInput, origInputOfTruncFOp);
     // Create new base operation with the linearized input/output.
-    Value newTruncFOp = rewriter.create<arith::TruncFOp>(
+    Value newTruncFOp = rewriter.create<TruncOpTy>(
         op.getLoc(), newVectorTypeForOutput, newInputVector);
     // Delinearize the output back to the original type.
rewriter.replaceOpWithNewOp(op, op.getType(), @@ -1054,11 +1055,12 @@ struct CanonicalizeVectorForAIEVecPass { RewritePatternSet patterns(context); - patterns - .add(context); + patterns.add, + FlattenArithTruncOpPattern, + ToMinorIdentityTransferReadPattern, + ToMinorIdentityTransferWritePattern, + ConvertLeadingUnitDimInsertToReshapePattern>(context); patterns.add(context); patterns .add) -> vector<2x3xbf16> { // ----- +// CHECK-LABEL: @arith_trunci( +// CHECK-SAME: %[[INP:.*]]: vector<2x3xi32>) +func.func @arith_trunci(%inp: vector<2x3xi32>) -> vector<2x3xi8> { + // CHECK: %[[LINEARIZE:.*]] = vector.shape_cast %[[INP]] : vector<2x3xi32> to vector<6xi32> + // CHECK: %[[TRUNCI:.*]] = arith.trunci %[[LINEARIZE]] : vector<6xi32> to vector<6xi8> + // CHECK: %[[DELINEARIZE:.*]] = vector.shape_cast %[[TRUNCI]] : vector<6xi8> to vector<2x3xi8> + // CHECK: return %[[DELINEARIZE]] + %0 = arith.trunci %inp : vector<2x3xi32> to vector<2x3xi8> + return %0 : vector<2x3xi8> +} + +// ----- + // CHECK: #map = affine_map<()[s0] -> (s0 * 256 + 96)> // CHECK-LABEL: @trivial_read_access // CHECK-SAME: (%[[ARG0:.*]]: memref<4x8x4x8xbf16, strided<[256, 32, 8, 1]>>, diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/mm_npu1.cc b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/mm_npu1.cc index c50012951..c4c69e744 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/mm_npu1.cc +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/mm_npu1.cc @@ -262,6 +262,137 @@ void matmul_vectorized(const T_in *__restrict pA, unsigned offsetA, event1(); } +template +static inline void matmul_vectorized_4x2(const T_in *__restrict pA, + unsigned offsetA, + const T_in *__restrict pB, + unsigned offsetB, + T_out *__restrict pC, + unsigned offsetC) { + + using MMUL = aie::mmul; + + event0(); + + for (unsigned z = 0; z < rowA; z += 4) + chess_prepare_for_pipelining chess_loop_range(4, ) { + T_out *__restrict pC1 = pC + offsetC + (z * colB + 0) * MMUL::size_C; + T_out *__restrict pC2 = pC + offsetC + ((z + 1) * colB + 0) * MMUL::size_C; + T_out *__restrict pC3 = pC + offsetC + ((z + 2) * colB + 0) * MMUL::size_C; + T_out *__restrict pC4 = pC + offsetC + ((z + 3) * colB + 0) * MMUL::size_C; + + for (unsigned j = 0; j < colB; j += 2) +#ifdef OPT_PERF_ENABLED + chess_flatten_loop +#endif + { + const T_in *__restrict pA1 = pA + offsetA + (z * colA + 0) * MMUL::size_A; + const T_in *__restrict pA2 = pA + offsetA + ((z + 1) * colA + 0) * MMUL::size_A; + const T_in *__restrict pA3 = pA + offsetA + ((z + 2) * colA + 0) * MMUL::size_A; + const T_in *__restrict pA4 = pA + offsetA + ((z + 3) * colA + 0) * MMUL::size_A; + + const T_in *__restrict pB1 = pB + offsetB + (0 * colB + j) * MMUL::size_B; + const T_in *__restrict pB2 = pB + offsetB + (0 * colB + (j + 1)) * MMUL::size_B; + + aie::vector A01 = aie::load_v(pA1); + pA1 += MMUL::size_A; + aie::vector A11 = aie::load_v(pA2); + pA2 += MMUL::size_A; + aie::vector A21 = aie::load_v(pA3); + pA3 += MMUL::size_A; + aie::vector A31 = aie::load_v(pA4); + pA4 += MMUL::size_A; + aie::vector B01 = aie::load_v(pB1); + pB1 += (MMUL::size_B * colB); + aie::vector B11 = aie::load_v(pB2); + pB2 += (MMUL::size_B * colB); + + aie::vector acc_C00 = + aie::load_v(pC1); + aie::vector acc_C01 = + aie::load_v(pC1 + MMUL::size_C); + aie::vector acc_C10 = + aie::load_v(pC2); + aie::vector acc_C11 = + aie::load_v(pC2 + MMUL::size_C); + aie::vector acc_C20 = + aie::load_v(pC3); + aie::vector acc_C21 = + aie::load_v(pC3 + MMUL::size_C); + aie::vector acc_C30 = + aie::load_v(pC4); + 
aie::vector acc_C31 = + aie::load_v(pC4 + MMUL::size_C); + + MMUL C00(acc_C00); + MMUL C01(acc_C01); + MMUL C10(acc_C10); + MMUL C11(acc_C11); + MMUL C20(acc_C20); + MMUL C21(acc_C21); + MMUL C30(acc_C30); + MMUL C31(acc_C31); + + C00.mac(A01, B01); + C01.mac(A01, B11); + C10.mac(A11, B01); + C11.mac(A11, B11); + C20.mac(A21, B01); + C21.mac(A21, B11); + C30.mac(A31, B01); + C31.mac(A31, B11); + + for (unsigned i = 1; i < colA; i += 1) +#ifdef OPT_PERF_ENABLED + chess_flatten_loop +#endif + { + A01 = aie::load_v(pA1); + pA1 += MMUL::size_A; + A11 = aie::load_v(pA2); + pA2 += MMUL::size_A; + A21 = aie::load_v(pA3); + pA3 += MMUL::size_A; + A31 = aie::load_v(pA4); + pA4 += MMUL::size_A; + B01 = aie::load_v(pB1); + pB1 += (MMUL::size_B * colB); + B11 = aie::load_v(pB2); + pB2 += (MMUL::size_B * colB); + + C00.mac(A01, B01); + C01.mac(A01, B11); + C10.mac(A11, B01); + C11.mac(A11, B11); + C20.mac(A21, B01); + C21.mac(A21, B11); + C30.mac(A31, B01); + C31.mac(A31, B11); + } + + aie::store_v(pC1, C00.template to_vector()); + pC1 += MMUL::size_C; + aie::store_v(pC1, C01.template to_vector()); + pC1 += MMUL::size_C; + aie::store_v(pC2, C10.template to_vector()); + pC2 += MMUL::size_C; + aie::store_v(pC2, C11.template to_vector()); + pC2 += MMUL::size_C; + aie::store_v(pC3, C20.template to_vector()); + pC3 += MMUL::size_C; + aie::store_v(pC3, C21.template to_vector()); + pC3 += MMUL::size_C; + aie::store_v(pC4, C30.template to_vector()); + pC4 += MMUL::size_C; + aie::store_v(pC4, C31.template to_vector()); + pC4 += MMUL::size_C; + } + } + + event1(); +} + template void matmul_vectorized_4x8x4_bf16_bf16_bf16(const bfloat16 *__restrict pA, unsigned offsetA, @@ -295,15 +426,35 @@ void matmul_vectorized_4x8x4_bf16_bf16_f32(const bfloat16 *__restrict pA, pA, offsetA, pB, offsetB, pC, offsetC); } +template +void matmul_vectorized_4x8x8_i8_i8_i32(const int8 *__restrict pA, + unsigned offsetA, + const int8 *__restrict pB, + unsigned offsetB, int32 *__restrict pC, + unsigned offsetC) { + constexpr int r = 4; + constexpr int s = 8; + constexpr int t = 8; + static_assert(m % (4 * r) == 0); // 'm' dimension + static_assert(k % s == 0); // 'k' dimension + static_assert(n % (2 * t) == 0); // 'n' dimension + return matmul_vectorized_4x2( + pA, offsetA, pB, offsetB, pC, offsetC); +} + extern "C" { #define matmul_combos(X, M, N, K) \ X(bfloat16, bf16, bfloat16, bf16, bfloat16, bf16, M, N, K, 4, 8, 4) \ X(bfloat16, bf16, bfloat16, bf16, float, f32, M, N, K, 4, 8, 4) +#define matmul_combos_i8(X, M, N, K) \ + X(int8, i8, int8, i8, int32, i32, M, N, K, 4, 8, 8) + #define zero_fill_combos(X, M, N) \ X(bfloat16, bf16, M, N, N/2) \ - X(float, f32, M, N, N/2) + X(float, f32, M, N, N/2) \ + X(int32, i32, M, N, N/2) #define matmul_vectorized_c_func(lhs_ctype_in, lhs_mlir_type_in, \ rhs_ctype_in, rhs_mlir_type_in, \ @@ -324,6 +475,12 @@ matmul_combos(matmul_vectorized_c_func, 16, 16, 32) matmul_combos(matmul_vectorized_c_func, 32, 32, 32) matmul_combos(matmul_vectorized_c_func, 64, 64, 64) matmul_combos(matmul_vectorized_c_func, 32, 32, 64) +matmul_combos_i8(matmul_vectorized_c_func, 16, 16, 32) +matmul_combos_i8(matmul_vectorized_c_func, 32, 32, 8) +matmul_combos_i8(matmul_vectorized_c_func, 32, 32, 16) +matmul_combos_i8(matmul_vectorized_c_func, 32, 32, 32) +matmul_combos_i8(matmul_vectorized_c_func, 32, 32, 64) +matmul_combos_i8(matmul_vectorized_c_func, 64, 64, 64) zero_fill_combos(zero_vectorized_c_func, 16, 16) zero_fill_combos(zero_vectorized_c_func, 32, 32) diff --git 
a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/mm_npu4.cc b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/mm_npu4.cc index cfd7c55cc..400f7e86c 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/mm_npu4.cc +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/mm_npu4.cc @@ -300,12 +300,16 @@ matmul_combos(matmul_vectorized_c_func, 32, 32, 64) matmul_combos(matmul_vectorized_c_func, 64, 64, 64) matmul_combos_i8(matmul_vectorized_c_func, 16, 16, 32) +matmul_combos_i8(matmul_vectorized_c_func, 32, 16, 32) +matmul_combos_i8(matmul_vectorized_c_func, 32, 16, 64) +matmul_combos_i8(matmul_vectorized_c_func, 32, 32, 8) matmul_combos_i8(matmul_vectorized_c_func, 32, 32, 32) matmul_combos_i8(matmul_vectorized_c_func, 32, 32, 64) matmul_combos_i8(matmul_vectorized_c_func, 64, 64, 64) zero_fill_combos(zero_vectorized_c_func, 16, 8) zero_fill_combos(zero_vectorized_c_func, 16, 16) +zero_fill_combos(zero_vectorized_c_func, 32, 16) zero_fill_combos(zero_vectorized_c_func, 32, 32) zero_fill_combos(zero_vectorized_c_func, 64, 64) diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertCores.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertCores.cpp index 46ce86187..39c7ef03a 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertCores.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEInsertCores.cpp @@ -39,8 +39,8 @@ namespace { static bool isCoreComputeOp(Operation *op) { return isa( - op); + arith::TruncFOp, arith::TruncIOp, vector::TransferReadOp, + vector::TransferWriteOp>(op); } /// Utility to map the parallel mapping attributes to the corresponding diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEVectorization.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEVectorization.cpp index be3c47e42..dafea8ee1 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEVectorization.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEVectorization.cpp @@ -89,7 +89,7 @@ void AMDAIEVectorizationPass::runOnOperation() { // gap between this pass and vector-to-aievec. 
for (Operation &innerOps : cast(op).getBody()->getOperations()) { - if (!isa(innerOps)) { + if (!isa(innerOps)) { op->emitRemark() << "not vectorizing linalg elementwise op"; return; } diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp index 7c04e74f5..64a09c9bd 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp @@ -369,6 +369,7 @@ void addPackPeel4LevelTilingBasedPassPipeline( AMDAIETileAndFuseOptions tileFuseOptions; tileFuseOptions.tilingLevel = 1; tileFuseOptions.useSCFFor = false; + tileFuseOptions.tileElementwise = false; funcPassManager.addPass(createAMDAIETileAndFusePass(tileFuseOptions)); } funcPassManager.addPass(createAMDAIECleanupPass()); diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/insert_cores.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/insert_cores.mlir index 808f6d188..09352b518 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/insert_cores.mlir +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/insert_cores.mlir @@ -325,3 +325,28 @@ module { return } } + +// ----- + +// CHECK-LABEL: @insert_trunci_within_core +// CHECK: scf.forall +// CHECK: amdaie.tile +// CHECK: amdaie.core +// CHECK: vector.transfer_read +// CHECK: arith.trunci +// CHECK: vector.transfer_write +// CHECK: amdaie.end +module { + func.func @insert_trunci_within_core(%arg0: memref<10x10xi32, 2 : i32>, %arg1: memref<10x10xi8, 2 : i32>) { + %cst = arith.constant 0 : i32 + %c1 = arith.constant 1 : index + %c3 = arith.constant 3 : index + %c0 = arith.constant 0 : index + scf.forall (%arg3, %arg4) in (2, 2) { + %read = vector.transfer_read %arg0[%c0, %c1], %cst {in_bounds = [true, true]} : memref<10x10xi32, 2 : i32>, vector<1x1xi32> + %trunci = arith.trunci %read : vector<1x1xi32> to vector<1x1xi8> + vector.transfer_write %trunci, %arg1[%c0, %c1] {in_bounds = [true, true]} : vector<1x1xi8>, memref<10x10xi8, 2 : i32> + } {mapping = [#gpu.thread, #gpu.thread]} + return + } +} diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/vectorization.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/vectorization.mlir index 4f6c95dfe..72d0be752 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/vectorization.mlir +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/vectorization.mlir @@ -64,9 +64,9 @@ func.func @fillAndCopy() -> tensor<8xbf16> { } -// CHECK-LABEL: @matmul_elementwise +// CHECK-LABEL: @matmul_elementwise_truncf // CHECK-SAME: (%[[ARG0:.*]]: tensor<4240x160xf32>, %[[ARG1:.*]]: tensor<4240x160xbf16>) -func.func @matmul_elementwise(%arg0: tensor<4240x160xf32>, %arg1: tensor<4240x160xbf16>) -> tensor<4240x160xbf16> { +func.func @matmul_elementwise_truncf(%arg0: tensor<4240x160xf32>, %arg1: tensor<4240x160xbf16>) -> tensor<4240x160xbf16> { %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg0: tensor<4240x160xf32>) outs(%arg1 : tensor<4240x160xbf16>) { ^bb0(%in: f32, %out: bf16): %1 = arith.truncf %in : f32 to bf16 @@ -77,3 +77,17 @@ func.func @matmul_elementwise(%arg0: tensor<4240x160xf32>, %arg1: tensor<4240x16 // CHECK: %[[VEC_OPERAND_0:.*]] = vector.transfer_read %[[ARG0]]{{.*}} vector<4240x160xf32> // CHECK: %[[TRUNCF:.*]] = arith.truncf 
%[[VEC_OPERAND_0]] // CHECK: vector.transfer_write %[[TRUNCF]], %[[ARG1]] + +// CHECK-LABEL: @matmul_elementwise_trunci +// CHECK-SAME: (%[[ARG0:.*]]: tensor<4240x160xi32>, %[[ARG1:.*]]: tensor<4240x160xi8>) +func.func @matmul_elementwise_trunci(%arg0: tensor<4240x160xi32>, %arg1: tensor<4240x160xi8>) -> tensor<4240x160xi8> { + %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%arg0: tensor<4240x160xi32>) outs(%arg1 : tensor<4240x160xi8>) { + ^bb0(%in: i32, %out: i8): + %1 = arith.trunci %in : i32 to i8 + linalg.yield %1 : i8 + } -> tensor<4240x160xi8> + return %0 : tensor<4240x160xi8> +} +// CHECK: %[[VEC_OPERAND_0:.*]] = vector.transfer_read %[[ARG0]]{{.*}} vector<4240x160xi32> +// CHECK: %[[TRUNCI:.*]] = arith.trunci %[[VEC_OPERAND_0]] +// CHECK: vector.transfer_write %[[TRUNCI]], %[[ARG1]]
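
For reference, a sketch of what the new matmul_trunci_MxK_KxN.mlir template expands to for the i8 tests registered above (M=256, N=128, K=32), assuming the test harness substitutes ${TYPE1}=i8, ${TYPE2}=i32, and ${ZERO}=0:

func.func @matmul_trunci(%arg0: tensor<256x32xi8>, %arg1: tensor<32x128xi8>) -> tensor<256x128xi8>
{
  // Accumulate the product in i32, then truncate the result back to i8.
  %cst = arith.constant 0 : i32
  %0 = tensor.empty() : tensor<256x128xi32>
  %1 = linalg.fill ins(%cst : i32) outs(%0 : tensor<256x128xi32>) -> tensor<256x128xi32>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<256x32xi8>, tensor<32x128xi8>)
                     outs(%1 : tensor<256x128xi32>) -> tensor<256x128xi32>
  %3 = arith.trunci %2 : tensor<256x128xi32> to tensor<256x128xi8>
  return %3 : tensor<256x128xi8>
}

With the all-ones i8 inputs used by the tests, every entry of the i32 accumulator equals K = 32, which still fits in i8 after arith.trunci, matching the expected output of 32 * np.ones([256, 128], dtype=np.int8).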