diff --git a/python/tvm/autotvm/task/task.py b/python/tvm/autotvm/task/task.py
index 7c587fe39783..a0c992b07347 100644
--- a/python/tvm/autotvm/task/task.py
+++ b/python/tvm/autotvm/task/task.py
@@ -338,7 +338,7 @@ def _count_flop(exp):
                             expr.Max, expr.Min,
                             expr.EQ, expr.NE, expr.LT, expr.LE, expr.GT, expr.GE,
                             expr.And, expr.Or, expr.Not)):
-            base = 1 if "float" in exp.a.dtype else 0
+            base = 1

             if isinstance(exp, expr.Not):  # unary
                 return base + _count_flop(exp.a)
@@ -348,6 +348,10 @@ def _count_flop(exp):
             return _count_flop(exp.condition) + max(_count_flop(exp.true_value),
                                                     _count_flop(exp.false_value))
         if isinstance(exp, expr.Call):
+            if exp.call_type == expr.Call.Halide:
+                # Ignore flops from indexing expressions.
+                return 0
+
             return sum([_count_flop(x) for x in exp.args])

         raise FlopCalculationError("Found unsupported operator in the compute expr")
diff --git a/tests/python/unittest/test_autotvm_flop_calculator.py b/tests/python/unittest/test_autotvm_flop_calculator.py
index 27bd49fe14df..c5c046894f0c 100644
--- a/tests/python/unittest/test_autotvm_flop_calculator.py
+++ b/tests/python/unittest/test_autotvm_flop_calculator.py
@@ -5,11 +5,17 @@

 from tvm.autotvm.task.task import compute_flop

+def random_dtypes():
+    """Return pair of (input, accumulator) dtypes"""
+    candidates = [("float32", "float32"), ("float16", "float32"), ("int8", "int32")]
+    return candidates[np.random.choice(len(candidates))]
+
 def test_conv():
     for i in range(5):
         N, H, W, CO, CI, KH, KW = [np.random.randint(10, 32) for _ in range(7)]
-        D = tvm.placeholder((N, CI, H, W))
-        K = tvm.placeholder((CO, CI, KH, KW))
+        (input_dtype, acc_dtype) = random_dtypes()
+        D = tvm.placeholder((N, CI, H, W), dtype=input_dtype)
+        K = tvm.placeholder((CO, CI, KH, KW), dtype=input_dtype)

         KH = min(H, KH)
         KW = min(W, KW)
@@ -22,7 +28,8 @@ def test_conv():
         OW = (W - KW) + 1

         C = tvm.compute((N, CO, OH, OW), lambda n, co, h, w:
-            tvm.sum(D[n][ci][h][w] * K[co][ci][h][w], axis=[ci, kh, kw]))
+            tvm.sum(D[n][ci][h][w].astype(acc_dtype) * K[co][ci][h][w].astype(acc_dtype),
+                    axis=[ci, kh, kw]))

         s = tvm.create_schedule([C.op])

@@ -31,15 +38,16 @@ def test_pack_gemm():
     for i in range(5):
         N, L, M = [np.random.randint(10, 128) * 4 for _ in range(3)]
-        A = tvm.placeholder((N, L))
-        B = tvm.placeholder((M, L))
+        (input_dtype, acc_dtype) = random_dtypes()
+        A = tvm.placeholder((N, L), dtype=input_dtype)
+        B = tvm.placeholder((M, L), dtype=input_dtype)
         k = tvm.reduce_axis((0, L))

         bn = 4
         A_pack = tvm.compute((N // bn, L, bn), lambda i, j, k: A[i * bn + k][j])
         B_pack = tvm.compute((M // bn, L, bn), lambda i, j, k: B[i * bn + k][j])
         C_pack = tvm.compute((N // bn, M // bn, bn, bn), lambda i, j, ii, jj:
-            tvm.sum(A_pack[i, k, ii] * B_pack[j, k, jj], axis=[k]))
+            tvm.sum(A_pack[i, k, ii].astype(acc_dtype) * B_pack[j, k, jj].astype(acc_dtype), axis=[k]))
         C = tvm.compute((N, M), lambda i, j: C_pack[i // bn][j // bn][i % bn][j % bn])

         s = tvm.create_schedule([C.op])
@@ -48,14 +56,61 @@ def test_outer_dot():
     for i in range(5):
         N, M = [np.random.randint(10, 128) * 4 for _ in range(2)]
-        A = tvm.placeholder((N,))
-        B = tvm.placeholder((M,))
+        (input_dtype, acc_dtype) = random_dtypes()
+        A = tvm.placeholder((N,), dtype=input_dtype)
+        B = tvm.placeholder((M,), dtype=input_dtype)

-        C = tvm.compute((N, M), lambda i, j: A[i] * B[j])
+        C = tvm.compute((N, M), lambda i, j: A[i].astype(acc_dtype) * B[j].astype(acc_dtype))

         s = tvm.create_schedule([C.op])

         assert compute_flop(s) == N * M

+def test_max_pool():
+    for i in range(5):
+        N, H, W, CO, CI, KH, KW = [np.random.randint(10, 32) for _ in range(7)]
+        (input_dtype, _) = random_dtypes()
+        D = tvm.placeholder((N, CI, H, W), dtype=input_dtype)
+
+        KH = min(H, KH)
+        KW = min(W, KW)
+
+        kh = tvm.reduce_axis((0, KH))
+        kw = tvm.reduce_axis((0, KW))
+
+        OH = (H - KH) + 1
+        OW = (W - KW) + 1
+
+        C = tvm.compute(
+            (N, CO, OH, OW),
+            lambda n, co, h, w: tvm.max(D[n][co][h + kh][w + kw], axis=[kh, kw]))
+
+        s = tvm.create_schedule([C.op])
+
+        assert compute_flop(s) == N * CO * OH * OW * KH * KW
+
+def test_average_pool():
+    for i in range(5):
+        N, H, W, CO, CI, KH, KW = [np.random.randint(10, 32) for _ in range(7)]
+        (input_dtype, acc_dtype) = random_dtypes()
+        D = tvm.placeholder((N, CI, H, W), dtype=input_dtype)
+
+        KH = min(H, KH)
+        KW = min(W, KW)
+
+        kh = tvm.reduce_axis((0, KH))
+        kw = tvm.reduce_axis((0, KW))
+
+        OH = (H - KH) + 1
+        OW = (W - KW) + 1
+
+        C = tvm.compute(
+            (N, CO, OH, OW),
+            lambda n, co, h, w: tvm.sum(D[n][co][h + kh][w + kw].astype(acc_dtype) / (KW * KH), axis=[kh, kw]))
+
+        s = tvm.create_schedule([C.op])
+
+        assert compute_flop(s) == 2 * N * CO * OH * OW * KH * KW
+
 def test_move():
     """No float number operation in simple move. So the estimator should raise an error """
     N = 1024