From 79f486209e54ac23194c0941a9550fd426c9b564 Mon Sep 17 00:00:00 2001 From: Sarunya Pumma Date: Thu, 1 Feb 2024 13:12:20 -0800 Subject: [PATCH] Refactor generate_cache_tbes in unit test (#2304) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2304 Removed unused args from `generate_cache_tbes` Reviewed By: q10 Differential Revision: D53305015 fbshipit-source-id: 12572d794e912f272e6c16dba6646d6e420ad6ec --- fbgemm_gpu/test/tbe/cache_common.py | 2 -- fbgemm_gpu/test/tbe/cache_overflow_test.py | 2 -- fbgemm_gpu/test/tbe/cache_test.py | 10 ++++++++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/fbgemm_gpu/test/tbe/cache_common.py b/fbgemm_gpu/test/tbe/cache_common.py index b282a9e6be..2fdcef1fa8 100644 --- a/fbgemm_gpu/test/tbe/cache_common.py +++ b/fbgemm_gpu/test/tbe/cache_common.py @@ -40,9 +40,7 @@ def generate_cache_tbes( T: int, D: int, - B: int, log_E: int, - L: int, mixed: bool, cache_algorithm: CacheAlgorithm = CacheAlgorithm.LRU, prefetch_pipeline: bool = False, diff --git a/fbgemm_gpu/test/tbe/cache_overflow_test.py b/fbgemm_gpu/test/tbe/cache_overflow_test.py index 413918a7af..f1032b2001 100644 --- a/fbgemm_gpu/test/tbe/cache_overflow_test.py +++ b/fbgemm_gpu/test/tbe/cache_overflow_test.py @@ -45,9 +45,7 @@ def test_cache_int32_overflow(self) -> None: cc, cc_ref, _, _ = generate_cache_tbes( T=1, D=D_fac, - B=128, log_E=1, - L=1, mixed=False, prefetch_pipeline=True, cache_sets=cache_sets, diff --git a/fbgemm_gpu/test/tbe/cache_test.py b/fbgemm_gpu/test/tbe/cache_test.py index bd00d4222e..94cc8f2cb4 100644 --- a/fbgemm_gpu/test/tbe/cache_test.py +++ b/fbgemm_gpu/test/tbe/cache_test.py @@ -60,7 +60,7 @@ def test_cache_pipeline( cache_algorithm: CacheAlgorithm, ) -> None: cc, cc_ref, min_Es, sum_Ds = generate_cache_tbes( - T, D, B, log_E, L, mixed, cache_algorithm + T, D, log_E, mixed, cache_algorithm ) iters = 3 requests = generate_requests(iters, B, T, L, min_Es, reuse=0.1) @@ -101,7 +101,13 @@ def _test_cache_prefetch_pipeline( # noqa C901 assert prefetch_location in ["before_fwd", "between_fwd_bwd"] cc, cc_ref, min_Es, sum_Ds = generate_cache_tbes( - T, D, B, log_E, L, mixed, CacheAlgorithm.LRU, True, True + T, + D, + log_E, + mixed, + CacheAlgorithm.LRU, + prefetch_pipeline=True, + use_int_weight=True, ) iters = 5 requests = generate_requests(iters, B, T, L, min_Es, reuse=0.1)