diff --git a/backends/arm/test/common.py b/backends/arm/test/common.py index 6a8f4b526d..2a6e10e161 100644 --- a/backends/arm/test/common.py +++ b/backends/arm/test/common.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import os import shutil import tempfile @@ -23,6 +24,8 @@ def get_tosa_compile_spec(permute_memory_to_nhwc=False, custom_path=None): Default compile spec for TOSA tests. """ intermediate_path = custom_path or tempfile.mkdtemp(prefix="arm_tosa_") + if not os.path.exists(intermediate_path): + os.makedirs(intermediate_path, exist_ok=True) compile_spec = ( ArmCompileSpecBuilder() .tosa_compile_spec() diff --git a/backends/arm/test/misc/test_debug_feats.py b/backends/arm/test/misc/test_debug_feats.py index df55e5253e..cc3a556363 100644 --- a/backends/arm/test/misc/test_debug_feats.py +++ b/backends/arm/test/misc/test_debug_feats.py @@ -26,7 +26,7 @@ def __init__( bias: bool = True, ): super().__init__() - self.inputs = (torch.ones(5, 10, 25, in_features),) + self.inputs = (torch.randn(5, 10, 25, in_features),) self.fc = torch.nn.Linear( in_features=in_features, out_features=out_features, @@ -93,3 +93,35 @@ def test_BI_artifact(self): if self._is_tosa_marker_in_file(tmp_file): return # Implicit pass test self.fail("File does not contain TOSA dump!") + + +class TestNumericalDiffPrints(unittest.TestCase): + def test_numerical_diff_prints(self): + model = Linear(20, 30) + tester = ( + ArmTester( + model, + inputs=model.get_inputs(), + compile_spec=common.get_tosa_compile_spec(), + ) + .quantize() + .export() + .to_edge() + .partition() + .to_executorch() + ) + if common.TOSA_REF_MODEL_INSTALLED: + # We expect an assertion error here. Any other issues will cause the + # test to fail. Likewise the test will fail if the assertion error is + # not present. + try: + # Tolerate 0 difference => we want to trigger a numerical diff + tester.run_method_and_compare_outputs(atol=0, rtol=0, qtol=0) + except AssertionError: + pass # Implicit pass test + else: + self.fail() + else: + logger.warning( + "TOSA ref model tool not installed, skip numerical correctness tests" + ) diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py index 62cf2482b0..841978b539 100644 --- a/backends/arm/test/tester/arm_tester.py +++ b/backends/arm/test/tester/arm_tester.py @@ -3,10 +3,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import logging from typing import List, Optional, Tuple, Union import numpy as np +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + import torch from executorch.backends.arm.arm_backend import ( @@ -237,6 +242,9 @@ def run_method_and_compare_outputs( export_stage.artifact, is_quantized ) + self.qp_input = qp_input + self.qp_output = qp_output + # Calculate the reference output using the original module or the quant # module. quantization_scale = None @@ -302,3 +310,39 @@ def _calculate_reference_output( """ return module.forward(*inputs) + + def _compare_outputs( + self, + reference_output, + stage_output, + quantization_scale=None, + atol=1e-03, + rtol=1e-03, + qtol=0, + ): + try: + super()._compare_outputs( + reference_output, stage_output, quantization_scale, atol, rtol, qtol + ) + except AssertionError as e: + # Capture assertion error and print more info + banner = "=" * 40 + "TOSA debug info" + "=" * 40 + logger.error(banner) + path_to_tosa_files = self.tosa_test_util.get_tosa_artifact_path() + logger.error(f"{self.qp_input=}") + logger.error(f"{self.qp_output=}") + logger.error(f"{path_to_tosa_files=}") + import os + + torch.save( + stage_output, + os.path.join(path_to_tosa_files, "torch_tosa_output.pt"), + ) + + torch.save( + reference_output, + os.path.join(path_to_tosa_files, "torch_ref_output.pt"), + ) + logger.error(f"{atol=}, {rtol=}, {qtol=}") + + raise e diff --git a/backends/arm/test/tosautil/tosa_test_utils.py b/backends/arm/test/tosautil/tosa_test_utils.py index a79d73abf9..0ab693e3db 100644 --- a/backends/arm/test/tosautil/tosa_test_utils.py +++ b/backends/arm/test/tosautil/tosa_test_utils.py @@ -28,6 +28,9 @@ def __init__(self, node_name: str, zp: int, scale: float): self.zp = zp self.scale = scale + def __repr__(self): + return f"QuantizationParams(node_name={self.node_name}, zp={self.zp}, scale={self.scale})" + """ This class is used to work with TOSA artifacts. @@ -213,6 +216,9 @@ def run_tosa_ref_model( return tosa_ref_output + def get_tosa_artifact_path(self): + return self.intermediate_path + @staticmethod def _run_cmd(cmd: List[str]) -> None: """