From e02faecb87c7b235ba39d0d1dbc742fb9fd513bd Mon Sep 17 00:00:00 2001 From: Jack Zhang Date: Fri, 11 Oct 2024 12:05:13 -0700 Subject: [PATCH] Add possibility to collect all TOSA tests to a specified path (#5028) (#6174) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Done in order to collect test vectors for backend compilers. Signed-off-by: Per Åstrand Change-Id: I0fc6e4d6bfcccd6aae18847a9a33f76d3d19fe5f Pull Request resolved: https://github.com/pytorch/executorch/pull/5028 Reviewed By: cccclai Differential Revision: D62242846 Pulled By: digantdesai fbshipit-source-id: 9ecfb7be3c5ed432a2cc36c2ea1eac7157ef6673 Co-authored-by: Per Åstrand --- backends/arm/test/common.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index 614960b71b..57281ea8f8 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -89,6 +89,29 @@ def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool: return False +def maybe_get_tosa_collate_path() -> str | None: + """ + Checks the environment variable TOSA_TESTCASES_BASE_PATH and returns the + path to the where to store the current tests if it is set. + """ + tosa_test_base = os.environ.get("TOSA_TESTCASES_BASE_PATH") + if tosa_test_base: + current_test = os.environ.get("PYTEST_CURRENT_TEST") + #'backends/arm/test/ops/test_mean_dim.py::TestMeanDim::test_meandim_tosa_BI_0_zeros (call)' + test_class = current_test.split("::")[1] + test_name = current_test.split("::")[-1].split(" ")[0] + if "BI" in test_name: + tosa_test_base = os.path.join(tosa_test_base, "tosa-bi") + elif "MI" in test_name: + tosa_test_base = os.path.join(tosa_test_base, "tosa-mi") + else: + tosa_test_base = os.path.join(tosa_test_base, "other") + + return os.path.join(tosa_test_base, test_class, test_name) + + return None + + def get_tosa_compile_spec( permute_memory_to_nhwc=True, custom_path=None ) -> list[CompileSpec]: @@ -104,7 +127,13 @@ def get_tosa_compile_spec_unbuilt( """Get the ArmCompileSpecBuilder for the default TOSA tests, to modify the compile spec before calling .build() to finalize it. """ - intermediate_path = custom_path or tempfile.mkdtemp(prefix="arm_tosa_") + if not custom_path: + intermediate_path = maybe_get_tosa_collate_path() or tempfile.mkdtemp( + prefix="arm_tosa_" + ) + else: + intermediate_path = custom_path + if not os.path.exists(intermediate_path): os.makedirs(intermediate_path, exist_ok=True) compile_spec_builder = (