diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index 614960b71b..57281ea8f8 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -89,6 +89,29 @@ def is_option_enabled(option: str, fail_if_not_enabled: bool = False) -> bool: return False +def maybe_get_tosa_collate_path() -> str | None: + """ + Checks the environment variable TOSA_TESTCASES_BASE_PATH and returns the + path to the where to store the current tests if it is set. + """ + tosa_test_base = os.environ.get("TOSA_TESTCASES_BASE_PATH") + if tosa_test_base: + current_test = os.environ.get("PYTEST_CURRENT_TEST") + #'backends/arm/test/ops/test_mean_dim.py::TestMeanDim::test_meandim_tosa_BI_0_zeros (call)' + test_class = current_test.split("::")[1] + test_name = current_test.split("::")[-1].split(" ")[0] + if "BI" in test_name: + tosa_test_base = os.path.join(tosa_test_base, "tosa-bi") + elif "MI" in test_name: + tosa_test_base = os.path.join(tosa_test_base, "tosa-mi") + else: + tosa_test_base = os.path.join(tosa_test_base, "other") + + return os.path.join(tosa_test_base, test_class, test_name) + + return None + + def get_tosa_compile_spec( permute_memory_to_nhwc=True, custom_path=None ) -> list[CompileSpec]: @@ -104,7 +127,13 @@ def get_tosa_compile_spec_unbuilt( """Get the ArmCompileSpecBuilder for the default TOSA tests, to modify the compile spec before calling .build() to finalize it. """ - intermediate_path = custom_path or tempfile.mkdtemp(prefix="arm_tosa_") + if not custom_path: + intermediate_path = maybe_get_tosa_collate_path() or tempfile.mkdtemp( + prefix="arm_tosa_" + ) + else: + intermediate_path = custom_path + if not os.path.exists(intermediate_path): os.makedirs(intermediate_path, exist_ok=True) compile_spec_builder = (