diff --git a/tuning/test_candidate_gen.py b/tuning/test_candidate_gen.py
index 5adca21..ad9b97e 100644
--- a/tuning/test_candidate_gen.py
+++ b/tuning/test_candidate_gen.py
@@ -13,26 +13,50 @@ def test_get_shaped_type_element_bitwidth():
-    assert candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth == 8
-    assert candidate_gen.ShapedType([2048], candidate_gen.ElementType.i32).bitwidth == 32
-    assert candidate_gen.ShapedType([2048, 512, 384], candidate_gen.ElementType.f8).bitwidth == 8
-    assert candidate_gen.ShapedType([1, 1], candidate_gen.ElementType.f16).bitwidth == 16
+    assert (
+        candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8).bitwidth
+        == 8
+    )
+    assert (
+        candidate_gen.ShapedType([2048], candidate_gen.ElementType.i32).bitwidth == 32
+    )
+    assert (
+        candidate_gen.ShapedType(
+            [2048, 512, 384], candidate_gen.ElementType.f8
+        ).bitwidth
+        == 8
+    )
+    assert (
+        candidate_gen.ShapedType([1, 1], candidate_gen.ElementType.f16).bitwidth == 16
+    )
 
 
 def test_get_shaped_type_to_str():
-    assert str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8)) == "1024x2048xi8"
-    assert str(candidate_gen.ShapedType([1024], candidate_gen.ElementType.f32)) == "1024xf32"
-    assert str(candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f16)) == "1x2x3xf16"
-    assert str(candidate_gen.ShapedType([-1, 2, 3], candidate_gen.ElementType.f16)) == "?x2x3xf16"
+    assert (
+        str(candidate_gen.ShapedType([1024, 2048], candidate_gen.ElementType.i8))
+        == "1024x2048xi8"
+    )
+    assert (
+        str(candidate_gen.ShapedType([1024], candidate_gen.ElementType.f32))
+        == "1024xf32"
+    )
+    assert (
+        str(candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f16))
+        == "1x2x3xf16"
+    )
+    assert (
+        str(candidate_gen.ShapedType([-1, 2, 3], candidate_gen.ElementType.f16))
+        == "?x2x3xf16"
+    )
 
 
 def test_parse_tensor_type():
-    assert candidate_gen.parse_tensor_type("tensor<1x2x3xf32>") == candidate_gen.ShapedType(
-        [1, 2, 3], candidate_gen.ElementType.f32
-    )
-    assert candidate_gen.parse_tensor_type("tensor<123xi8>") == candidate_gen.ShapedType(
-        [123], candidate_gen.ElementType.i8
-    )
+    assert candidate_gen.parse_tensor_type(
+        "tensor<1x2x3xf32>"
+    ) == candidate_gen.ShapedType([1, 2, 3], candidate_gen.ElementType.f32)
+    assert candidate_gen.parse_tensor_type(
+        "tensor<123xi8>"
+    ) == candidate_gen.ShapedType([123], candidate_gen.ElementType.i8)
 
 
 def test_get_mmt_tile_sizes():
@@ -74,7 +98,11 @@ def test_get_contract_tile_sizes():
     assert candidate_gen.get_contract_tile_sizes(config, ["m", "n", "k"]) == [4, 8, 16]
     assert candidate_gen.get_contract_tile_sizes(config, ["n", "m", "k"]) == [8, 4, 16]
     assert candidate_gen.get_contract_tile_sizes(config, ["k", "n", "m"]) == [16, 8, 4]
-    assert candidate_gen.get_contract_tile_sizes(config, ["k", "k", "k"]) == [16, 16, 16]
+    assert candidate_gen.get_contract_tile_sizes(config, ["k", "k", "k"]) == [
+        16,
+        16,
+        16,
+    ]
 
 
 def test_get_pipeline_config():
@@ -141,7 +169,9 @@ def test_get_shapes_contract():
         r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%13, %14 : tensor<2048x1280xf16>, tensor<1280x1280xf16>) outs(%19 : tensor<2048x1280xf32>) attrs = {lowering_config = #iree_codegen.lowering_config} {',
         r"^bb0(%in: f16, %in_0: f16, %out: f32):",
     ]
-    assert candidate_gen.get_shapes_contract(template, "mk", "nk") == candidate_gen.ProblemSize(
+    assert candidate_gen.get_shapes_contract(
+        template, "mk", "nk"
+    ) == candidate_gen.ProblemSize(
         candidate_gen.MatmulSize(2048, 1280, 1280),
         candidate_gen.ShapedType([2048, 1280], candidate_gen.ElementType.f16),
         candidate_gen.ShapedType([1280, 1280], candidate_gen.ElementType.f16),
@@ -156,7 +186,9 @@ def test_get_shapes_batch_matmul():
         "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>",
         "flow.dispatch.tensor.store %11, %2, offsets = [%arg0, %arg1, %arg2], sizes = [1, 32, 32], strides = [1, 1, 1] : tensor<1x32x32xf32> -> !flow.dispatch.tensor>",
     ]
-    assert candidate_gen.get_shapes_batch_matmul(template, "bmk", "bkn") == candidate_gen.ProblemSize(
+    assert candidate_gen.get_shapes_batch_matmul(
+        template, "bmk", "bkn"
+    ) == candidate_gen.ProblemSize(
         candidate_gen.MatmulSize(32, 32, 1024, 1),
         candidate_gen.ShapedType([1, 32, 1024], candidate_gen.ElementType.f32),
         candidate_gen.ShapedType([1, 1024, 32], candidate_gen.ElementType.f32),
@@ -181,8 +213,14 @@ def test_get_shapes_batch_mmt():
 
 
 def test_mfma_intrinsic_to_str():
-    assert str(candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32()) == "MFMA_F16_16x16x16_F32"
-    assert str(candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32()) == "MFMA_I8_32x32x16_I32"
+    assert (
+        str(candidate_gen.MfmaIntrinsic.mfma_f16_16x16x16_f32())
+        == "MFMA_F16_16x16x16_F32"
+    )
+    assert (
+        str(candidate_gen.MfmaIntrinsic.mfma_i8_32x32x16_i32())
+        == "MFMA_I8_32x32x16_I32"
+    )
 
 
 def test_get_compatible_mfma_intrinsics():
@@ -256,7 +294,8 @@ def test_calculate_shared_memory_usage_in_bytes():
         matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
     )
     assert (
-        candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128) == 81920
+        candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 512, 64, 128)
+        == 81920
     )
 
     rhs_type = candidate_gen.ShapedType([1024, 1024], candidate_gen.ElementType.i32)
@@ -264,7 +303,8 @@
         matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
     )
     assert (
-        candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32) == 12288
+        candidate_gen.calculate_shared_memory_usage_in_bytes(problem_size, 128, 64, 32)
+        == 12288
     )
 
 
@@ -277,11 +317,19 @@ def test_generate_constraints_valid_input():
         matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
     )
     # Define input parameters as z3 Ints
-    m, n, k = candidate_gen.z3.Int("m"), candidate_gen.z3.Int("n"), candidate_gen.z3.Int("k")
+    m, n, k = (
+        candidate_gen.z3.Int("m"),
+        candidate_gen.z3.Int("n"),
+        candidate_gen.z3.Int("k"),
+    )
     subgroup_size = candidate_gen.z3.Int("subgroup_size")
     intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn")
     intrinsic_k = candidate_gen.z3.Int("intrinsic_k")
-    wg_x, wg_y, wg_z = candidate_gen.z3.Int("wg_x"), candidate_gen.z3.Int("wg_y"), candidate_gen.z3.Int("wg_z")
+    wg_x, wg_y, wg_z = (
+        candidate_gen.z3.Int("wg_x"),
+        candidate_gen.z3.Int("wg_y"),
+        candidate_gen.z3.Int("wg_z"),
+    )
     sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt")
     sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt")
     waves_per_eu = candidate_gen.z3.Int("waves_per_eu")
@@ -314,11 +362,19 @@ def test_generate_constraints_invalid_input():
     problem_size = candidate_gen.ProblemSize(
         matmul_size, lhs_type, rhs_type, res_type, candidate_gen.DispatchKind.mmt
     )
-    m, n, k = candidate_gen.z3.Int("m"), candidate_gen.z3.Int("n"), candidate_gen.z3.Int("k")
+    m, n, k = (
+        candidate_gen.z3.Int("m"),
+        candidate_gen.z3.Int("n"),
+        candidate_gen.z3.Int("k"),
+    )
     subgroup_size = candidate_gen.z3.Int("subgroup_size")
     intrinsic_mn = candidate_gen.z3.Int("intrinsic_mn")
     intrinsic_k = candidate_gen.z3.Int("intrinsic_k")
-    wg_x, wg_y, wg_z = candidate_gen.z3.Int("wg_x"), candidate_gen.z3.Int("wg_y"), candidate_gen.z3.Int("wg_z")
+    wg_x, wg_y, wg_z = (
+        candidate_gen.z3.Int("wg_x"),
+        candidate_gen.z3.Int("wg_y"),
+        candidate_gen.z3.Int("wg_z"),
+    )
     sg_m_cnt = candidate_gen.z3.Int("sg_m_cnt")
     sg_n_cnt = candidate_gen.z3.Int("sg_n_cnt")
     waves_per_eu = candidate_gen.z3.Int("waves_per_eu")
@@ -370,7 +426,9 @@ def test_apply_params_mmt():
         candidate_gen.ShapedType([M, N], candidate_gen.ElementType.f32),
         candidate_gen.DispatchKind.mmt,
     )
-    modified, embeddable = candidate_gen.apply_params_mmt(problem_size, mlir_template, config)
+    modified, embeddable = candidate_gen.apply_params_mmt(
+        problem_size, mlir_template, config
+    )
 
     assert modified
     assert embeddable
@@ -408,12 +466,16 @@ def test_apply_params_conv():
 
     problem_size = candidate_gen.ProblemSize(
         candidate_gen.MatmulSize(oh * ow, oc, fh * fw * ic),
-        candidate_gen.ShapedType([n, oh + 2, ow + 2, oc], candidate_gen.ElementType.f16),
+        candidate_gen.ShapedType(
+            [n, oh + 2, ow + 2, oc], candidate_gen.ElementType.f16
+        ),
         candidate_gen.ShapedType([fh, fw, ic, oc], candidate_gen.ElementType.f16),
         candidate_gen.ShapedType([n, oh, ow, oc], candidate_gen.ElementType.f32),
         candidate_gen.DispatchKind.conv,
     )
-    modified, embeddable = candidate_gen.apply_params_conv(problem_size, mlir_template, config)
+    modified, embeddable = candidate_gen.apply_params_conv(
+        problem_size, mlir_template, config
+    )
 
     assert modified
     assert embeddable
@@ -717,4 +779,6 @@ def test_parse_mlir():
     mlir_module = candidate_gen.parse_mlir(mlir_str)
     assert mlir_module != None
     assert isinstance(mlir_module, candidate_gen.ireec._mlir_libs._mlir.ir.Module)
-    assert isinstance(mlir_module.body.operations[0], candidate_gen.ireec.dialects.func.FuncOp)
+    assert isinstance(
+        mlir_module.body.operations[0], candidate_gen.ireec.dialects.func.FuncOp
+    )