Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP][DO NOT REVIEW] i8->i8 Phoenix|Strix + Ukernel|Vectorization #1084

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Template for matmul followed by integer truncation of the accumulator.
// Placeholders (${M}, ${K}, ${N}, ${TYPE1}, ${TYPE2}, ${ZERO}) are filled in
// by the test harness before compilation.
// input ${M}x${K}x${TYPE1}
// input ${K}x${N}x${TYPE1}

// Computes trunci(matmul(%arg0, %arg1)): the accumulator is zero-initialized
// ${TYPE2}, and the final result is truncated back down to ${TYPE1}.
func.func @matmul_trunci(%arg0: tensor<${M}x${K}x${TYPE1}>, %arg1: tensor<${K}x${N}x${TYPE1}>) -> tensor<${M}x${N}x${TYPE1}>
{
  // Zero constant in the (wider) accumulator element type.
  %cst = arith.constant ${ZERO} : ${TYPE2}
  %0 = tensor.empty() : tensor<${M}x${N}x${TYPE2}>
  %1 = linalg.fill ins(%cst : ${TYPE2}) outs(%0 : tensor<${M}x${N}x${TYPE2}>) -> tensor<${M}x${N}x${TYPE2}>
  %2 = linalg.matmul ins(%arg0, %arg1 : tensor<${M}x${K}x${TYPE1}>, tensor<${K}x${N}x${TYPE1}>)
    outs(%1: tensor<${M}x${N}x${TYPE2}>) -> tensor<${M}x${N}x${TYPE2}>
  // Narrow the accumulated result back to the input element type.
  %3 = arith.trunci %2 : tensor<${M}x${N}x${TYPE2}> to tensor<${M}x${N}x${TYPE1}>
  return %3: tensor<${M}x${N}x${TYPE1}>
}
143 changes: 142 additions & 1 deletion build_tools/ci/cpu_comparison/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,10 +634,11 @@ def __init__(
expected_out,
run_on_target=["npu1_4col"],
tile_pipeline="pack-peel",
aie_compilation_flags=None,
):
super().__init__(
run_on_target=run_on_target,
aie_compilation_flags=None,
aie_compilation_flags=aie_compilation_flags,
M=M,
N=M,
K=K,
Expand Down Expand Up @@ -694,6 +695,88 @@ def _execute(self, config):
return True


class MatmulTrunci(BaseMatmul):
    """
    A test of the form matmul(A,B) + trunci(C) where A:MxK, B:KxM and C:MxM.

    The matmul accumulates in `acc_type` and the result is truncated
    (arith.trunci) back to `input_type` before comparison against
    `expected_out`.
    """

    def __init__(
        self,
        M,
        K,
        input_type,
        acc_type,
        lhs,
        rhs,
        expected_out,
        run_on_target=["npu1_4col"],
        tile_pipeline="pack-peel",
        aie_compilation_flags=None,
        use_ukernel=False,
        n_repeats=1,
        use_chess=False,
    ):
        super().__init__(
            run_on_target=run_on_target,
            aie_compilation_flags=aie_compilation_flags,
            M=M,
            # Square output: N is fixed to M for this test.
            N=M,
            K=K,
            input_type=input_type,
            acc_type=acc_type,
            tile_pipeline=tile_pipeline,
            n_repeats=n_repeats,
            use_ukernel=use_ukernel,
            use_chess=use_chess,
        )
        self.labels.append("MatmulTrunci")

        # Assertions on shapes: Check that lhs is MxK, rhs is KxM, and expected_out is MxM
        assert lhs.shape == (M, K)
        assert rhs.shape == (K, M)
        assert expected_out.shape == (M, M)

        self.name = f"matmul_trunci_{M}_{K}_{input_type}_{acc_type}"
        if tile_pipeline == "pack-peel-4-level-tiling":
            self.name += "_4_level_tiling"
        self.lhs = lhs
        self.rhs = rhs
        self.expected_out = expected_out

    def _execute(self, config):
        # Instantiate the MLIR template with this test's shapes/types.
        matmul_template_dir = config.file_dir / "matmul_template"
        template_name = matmul_template_dir / "matmul_trunci_MxK_KxN.mlir"
        self.generate(config, template_name)
        filename = self.get_filename(config)
        input_args = generate_inputs(
            filename, self.get_dir(config), 1, {1: self.lhs, 2: self.rhs}
        )
        # Currently without function outlining, we run out of program memory.
        # (A bare string literal here would be a discarded expression, not a
        # docstring, so this is a comment.)
        self.add_aie_compilation_flags(
            ["--iree-amdaie-enable-function-outlining=balanced"]
        )
        aie_vs_baseline(
            config=config,
            aie_compilation_flags=self.aie_compilation_flags,
            test_file=filename,
            input_args=input_args,
            baseline_value=self.expected_out,
            use_ukernel=self.use_ukernel,
            tile_pipeline=self.tile_pipeline,
            function_name=None,
            seed=1,
            # Integer outputs: require an exact match.
            rtol=0,
            atol=0,
            lower_to_aie_pipeline=self.lower_to_aie_pipeline,
            n_repeats=self.n_repeats,
            output_type=get_output_type(filename),
        )

        return True


def find_executable(install_dir: Path, executable_name):
"""
Search for an executable in the given directory and its subdirectories
Expand Down Expand Up @@ -1475,6 +1558,64 @@ def __init__(self):
self.existing_names = []
self.tests = []

# Tests Matmul + Trunci.
# All three variants compute trunci(A @ B) with A:256x32 and B:32x256 filled
# with i8 ones; accumulating over K=32 yields 32 in i32, which truncates
# losslessly to i8 (hence the expected 32 * ones output).
# Phoenix : Ukernel + Peano.
self.register(
MatmulTrunci(
256,
32,
"i8",
"i32",
1 * np.ones([256, 32], dtype=np.int8),
1 * np.ones([32, 256], dtype=np.int8),
32 * np.ones([256, 256], dtype=np.int8),
tile_pipeline="pack-peel-4-level-tiling",
run_on_target=["npu1_4col"],
aie_compilation_flags=[
"--iree-amdaie-num-rows=4",
"--iree-amdaie-num-cols=4",
],
use_ukernel=True,
)
)
# Phoenix : Vectorization + Peano.
# Same shapes/flags as above but without the ukernel path.
self.register(
MatmulTrunci(
256,
32,
"i8",
"i32",
1 * np.ones([256, 32], dtype=np.int8),
1 * np.ones([32, 256], dtype=np.int8),
32 * np.ones([256, 256], dtype=np.int8),
tile_pipeline="pack-peel-4-level-tiling",
run_on_target=["npu1_4col"],
aie_compilation_flags=[
"--iree-amdaie-num-rows=4",
"--iree-amdaie-num-cols=4",
],
)
)
# Strix : Ukernel + Chess.
# npu4 target uses a 4x8 array (8 columns) and the chess compiler.
self.register(
MatmulTrunci(
256,
32,
"i8",
"i32",
1 * np.ones([256, 32], dtype=np.int8),
1 * np.ones([32, 256], dtype=np.int8),
32 * np.ones([256, 256], dtype=np.int8),
tile_pipeline="pack-peel-4-level-tiling",
run_on_target=["npu4"],
aie_compilation_flags=[
"--iree-amdaie-num-rows=4",
"--iree-amdaie-num-cols=8",
],
use_chess=True,
use_ukernel=True,
)
)
# Matmul with truncf test(s):
for tile_pipeline in ["pack-peel", "pack-peel-4-level-tiling"]:
self.register(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -802,10 +802,11 @@ struct ToMinorIdentityTransferReadPattern
/// %1 = arith.truncf %0 : vector<6xf32> to vector<6xbf16>
/// %2 = vector.shape_cast %1 : vector<6xbf16> to vector<2x3xbf16>
// clang-format on
struct FlattenArithTruncFOpPattern : public OpRewritePattern<arith::TruncFOp> {
using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;
template <typename TruncOpTy>
struct FlattenArithTruncOpPattern : public OpRewritePattern<TruncOpTy> {
using OpRewritePattern<TruncOpTy>::OpRewritePattern;

LogicalResult matchAndRewrite(arith::TruncFOp op,
LogicalResult matchAndRewrite(TruncOpTy op,
PatternRewriter &rewriter) const override {
// Get old shape type.
auto oldShapedType = dyn_cast<VectorType>(op.getType());
Expand All @@ -826,7 +827,7 @@ struct FlattenArithTruncFOpPattern : public OpRewritePattern<arith::TruncFOp> {
Value newInputVector = rewriter.create<vector::ShapeCastOp>(
op.getLoc(), newVectorTypeForInput, origInputOfTruncFOp);
// Create new base operation with the linearized input/output.
Value newTruncFOp = rewriter.create<arith::TruncFOp>(
Value newTruncFOp = rewriter.create<TruncOpTy>(
op.getLoc(), newVectorTypeForOutput, newInputVector);
// Delinearize the output back to the original type.
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(op, op.getType(),
Expand Down Expand Up @@ -1054,11 +1055,12 @@ struct CanonicalizeVectorForAIEVecPass

{
RewritePatternSet patterns(context);
patterns
.add<ExtractTransposeFromContractionOp, FlattenArithTruncFOpPattern,
ToMinorIdentityTransferReadPattern,
ToMinorIdentityTransferWritePattern,
ConvertLeadingUnitDimInsertToReshapePattern>(context);
patterns.add<ExtractTransposeFromContractionOp,
FlattenArithTruncOpPattern<arith::TruncFOp>,
FlattenArithTruncOpPattern<arith::TruncIOp>,
ToMinorIdentityTransferReadPattern,
ToMinorIdentityTransferWritePattern,
ConvertLeadingUnitDimInsertToReshapePattern>(context);
patterns.add<ConvertSplatTransferReadToBroadcastPattern>(context);
patterns
.add<copied_from_mlir::FlattenContiguousRowMajorTransferReadPattern,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,19 @@ func.func @arith_truncf(%inp: vector<2x3xf32>) -> vector<2x3xbf16> {

// -----

// Verifies that a 2-D arith.trunci is linearized: flattened to 1-D with a
// shape_cast, truncated, then cast back to the original 2-D shape.
// CHECK-LABEL: @arith_trunci(
// CHECK-SAME: %[[INP:.*]]: vector<2x3xi32>)
func.func @arith_trunci(%inp: vector<2x3xi32>) -> vector<2x3xi8> {
// CHECK: %[[LINEARIZE:.*]] = vector.shape_cast %[[INP]] : vector<2x3xi32> to vector<6xi32>
// CHECK: %[[TRUNCI:.*]] = arith.trunci %[[LINEARIZE]] : vector<6xi32> to vector<6xi8>
// CHECK: %[[DELINEARIZE:.*]] = vector.shape_cast %[[TRUNCI]] : vector<6xi8> to vector<2x3xi8>
// CHECK: return %[[DELINEARIZE]]
%0 = arith.trunci %inp : vector<2x3xi32> to vector<2x3xi8>
return %0 : vector<2x3xi8>
}

// -----

// CHECK: #map = affine_map<()[s0] -> (s0 * 256 + 96)>
// CHECK-LABEL: @trivial_read_access
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x8x4x8xbf16, strided<[256, 32, 8, 1]>>,
Expand Down
Loading
Loading