Skip to content

Commit

Permalink
Fix lint err
Browse files Browse the repository at this point in the history
  • Loading branch information
RattataKing committed Aug 19, 2024
1 parent 41f8b7d commit 65c3731
Showing 1 changed file with 93 additions and 29 deletions.
122 changes: 93 additions & 29 deletions tuning/test_candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,50 @@


def test_get_shaped_type_element_bitwidth():
assert candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth == 8
assert candidate_gen.ShapedType([2048], candidate_gen.ElementType.i32).bitwidth == 32
assert candidate_gen.ShapedType([2048, 512, 384], candidate_gen.ElementType.f8).bitwidth == 8
assert candidate_gen.ShapedType([1, 1], candidate_gen.ElementType.f16).bitwidth == 16
assert (
candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth
== 8
)
assert (
candidate_gen.ShapedType([2048], candidate_gen.ElementType.i32).bitwidth == 32
)
assert (
candidate_gen.ShapedType(
[2048, 512, 384], candidate_gen.ElementType.f8
).bitwidth
== 8
)
assert (
candidate_gen.ShapedType([1, 1], candidate_gen.ElementType.f16).bitwidth == 16
)


def test_get_shaped_type_to_str():
assert str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8)) == "1024x2048xi8"
assert str(candidate_gen.ShapedType([1024], candidate_gen.ElementType.f32)) == "1024xf32"
assert str(candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f16)) == "1x2x3xf16"
assert str(candidate_gen.ShapedType([-1, 2, 3], candidate_gen.ElementType.f16)) == "?x2x3xf16"
assert (
str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8))
== "1024x2048xi8"
)
assert (
str(candidate_gen.ShapedType([1024], candidate_gen.ElementType.f32))
== "1024xf32"
)
assert (
str(candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f16))
== "1x2x3xf16"
)
assert (
str(candidate_gen.ShapedType([-1, 2, 3], candidate_gen.ElementType.f16))
== "?x2x3xf16"
)


def test_parse_tensor_type():
assert candidate_gen.parse_tensor_type("tensor<1x2x3xf32>") == candidate_gen.ShapedType(
[1, 2, 3], candidate_gen.ElementType.f32
)
assert candidate_gen.parse_tensor_type("tensor<123xi8>") == candidate_gen.ShapedType(
[123], candidate_gen.ElementType.i8
)
assert candidate_gen.parse_tensor_type(
"tensor<1x2x3xf32>"
) == candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f32)
assert candidate_gen.parse_tensor_type(
"tensor<123xi8>"
) == candidate_gen.ShapedType([123], candidate_gen.ElementType.i8)


def test_get_mmt_tile_sizes():
Expand Down Expand Up @@ -74,7 +98,11 @@ def test_get_contract_tile_sizes():
assert candidate_gen.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16]
assert candidate_gen.get_contract_tile_sizes(config, ["n", "m", "k"]) == [8, 4, 16]
assert candidate_gen.get_contract_tile_sizes(config, ["k", "n", "m"]) == [16, 8, 4]
assert candidate_gen.get_contract_tile_sizes(config, ["k", "k", "k"]) == [16, 16, 16]
assert candidate_gen.get_contract_tile_sizes(config, ["k", "k", "k"]) == [
16,
16,
16,
]


def test_get_pipeline_config():
Expand Down Expand Up @@ -141,7 +169,9 @@ def test_get_shapes_contract():
r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[64, 128, 64]]>} {',
r"^bb0(%in: f16, %in_0: f16, %out: f32):",
]
assert candidate_gen.get_shapes_contract(template, "mk", "nk") == candidate_gen.ProblemSize(
assert candidate_gen.get_shapes_contract(
template, "mk", "nk"
) == candidate_gen.ProblemSize(
candidate_gen.MatmulSize(2048, 1280, 1280),
candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16),
candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16),
Expand All @@ -156,7 +186,9 @@ def test_get_shapes_batch_matmul():
"%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>",
"flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor<writeonly:tensor<4x32x64xf32>>",
]
assert candidate_gen.get_shapes_batch_matmul(template, "bmk", "bkn") == candidate_gen.ProblemSize(
assert candidate_gen.get_shapes_batch_matmul(
template, "bmk", "bkn"
) == candidate_gen.ProblemSize(
candidate_gen.MatmulSize(32, 32, 1024, 1),
candidate_gen.ShapedType([1, 32, 1024], candidate_gen.ElementType.f32),
candidate_gen.ShapedType([1, 1024, 32], candidate_gen.ElementType.f32),
Expand All @@ -181,8 +213,14 @@ def test_get_shapes_batch_mmt():


def test_mfma_intrinsic_to_str():
assert str(candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32()) == "MFMA_F16_16x16x16_F32"
assert str(candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32()) == "MFMA_I8_32x32x16_I32"
assert (
str(candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32())
== "MFMA_F16_16x16x16_F32"
)
assert (
str(candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32())
== "MFMA_I8_32x32x16_I32"
)


def test_get_compatible_mfma_intrinsics():
Expand Down Expand Up @@ -256,15 +294,17 @@ def test_calculate_shared_memory_usage_in_bytes():
matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
)
assert (
candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) == 81920
candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128)
== 81920
)

rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i32)
problem_size = candidate_gen.ProblemSize(
matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
)
assert (
candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32) == 12288
candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32)
== 12288
)


Expand All @@ -277,11 +317,19 @@ def test_generate_constraints_valid_input():
matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
)
# Define input parameters as z3 Ints
m, n, k = candidate_gen.z3.Int("m"), candidate_gen.z3.Int("n"), candidate_gen.z3.Int("k")
m, n, k = (
candidate_gen.z3.Int("m"),
candidate_gen.z3.Int("n"),
candidate_gen.z3.Int("k"),
)
subgroup_size = candidate_gen.z3.Int("subgroup_size")
intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn")
intrinsic_k = candidate_gen.z3.Int("intrinsic_k")
wg_x, wg_y, wg_z = candidate_gen.z3.Int("wg_x"), candidate_gen.z3.Int("wg_y"), candidate_gen.z3.Int("wg_z")
wg_x, wg_y, wg_z = (
candidate_gen.z3.Int("wg_x"),
candidate_gen.z3.Int("wg_y"),
candidate_gen.z3.Int("wg_z"),
)
sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt")
sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt")
waves_per_eu = candidate_gen.z3.Int("waves_per_eu")
Expand Down Expand Up @@ -314,11 +362,19 @@ def test_generate_constraints_invalid_input():
problem_size = candidate_gen.ProblemSize(
matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
)
m, n, k = candidate_gen.z3.Int("m"), candidate_gen.z3.Int("n"), candidate_gen.z3.Int("k")
m, n, k = (
candidate_gen.z3.Int("m"),
candidate_gen.z3.Int("n"),
candidate_gen.z3.Int("k"),
)
subgroup_size = candidate_gen.z3.Int("subgroup_size")
intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn")
intrinsic_k = candidate_gen.z3.Int("intrinsic_k")
wg_x, wg_y, wg_z = candidate_gen.z3.Int("wg_x"), candidate_gen.z3.Int("wg_y"), candidate_gen.z3.Int("wg_z")
wg_x, wg_y, wg_z = (
candidate_gen.z3.Int("wg_x"),
candidate_gen.z3.Int("wg_y"),
candidate_gen.z3.Int("wg_z"),
)
sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt")
sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt")
waves_per_eu = candidate_gen.z3.Int("waves_per_eu")
Expand Down Expand Up @@ -370,7 +426,9 @@ def test_apply_params_mmt():
candidate_gen.ShapedType([M, N], candidate_gen.ElementType.f32),
candidate_gen.DispatchKind.mmt,
)
modified, embeddable = candidate_gen.apply_params_mmt(problem_size, mlir_template, config)
modified, embeddable = candidate_gen.apply_params_mmt(
problem_size, mlir_template, config
)

assert modified
assert embeddable
Expand Down Expand Up @@ -408,12 +466,16 @@ def test_apply_params_conv():

problem_size = candidate_gen.ProblemSize(
candidate_gen.MatmulSize(oh * ow, oc, fh * fw * ic),
candidate_gen.ShapedType([n, oh + 2, ow + 2, oc], candidate_gen.ElementType.f16),
candidate_gen.ShapedType(
[n, oh + 2, ow + 2, oc], candidate_gen.ElementType.f16
),
candidate_gen.ShapedType([fh, fw, ic, oc], candidate_gen.ElementType.f16),
candidate_gen.ShapedType([n, oh, ow, oc], candidate_gen.ElementType.f32),
candidate_gen.DispatchKind.conv,
)
modified, embeddable = candidate_gen.apply_params_conv(problem_size, mlir_template, config)
modified, embeddable = candidate_gen.apply_params_conv(
problem_size, mlir_template, config
)

assert modified
assert embeddable
Expand Down Expand Up @@ -717,4 +779,6 @@ def test_parse_mlir():
mlir_module = candidate_gen.parse_mlir(mlir_str)
assert mlir_module != None
assert isinstance(mlir_module, candidate_gen.ireec._mlir_libs._mlir.ir.Module)
assert isinstance(mlir_module.body.operations[0], candidate_gen.ireec.dialects.func.FuncOp)
assert isinstance(
mlir_module.body.operations[0], candidate_gen.ireec.dialects.func.FuncOp
)

0 comments on commit 65c3731

Please sign in to comment.