Skip to content

Commit

Permalink
Arm backend: Provide more debug info for numerical diff issues (#3596)
Browse files Browse the repository at this point in the history
Summary:
* Save torch reference output
* Print useful debug info
* Print path to debug artifacts

Pull Request resolved: #3596

Reviewed By: manuelcandales

Differential Revision: D57618585

Pulled By: digantdesai

fbshipit-source-id: 80eebc25e02ef9f9aab1a0367cb0eab4a5c73eae
  • Loading branch information
freddan80 authored and facebook-github-bot committed May 21, 2024
1 parent a707550 commit bc5ba99
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 1 deletion.
3 changes: 3 additions & 0 deletions backends/arm/test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import shutil
import tempfile

Expand All @@ -23,6 +24,8 @@ def get_tosa_compile_spec(permute_memory_to_nhwc=False, custom_path=None):
Default compile spec for TOSA tests.
"""
intermediate_path = custom_path or tempfile.mkdtemp(prefix="arm_tosa_")
if not os.path.exists(intermediate_path):
os.makedirs(intermediate_path, exist_ok=True)
compile_spec = (
ArmCompileSpecBuilder()
.tosa_compile_spec()
Expand Down
34 changes: 33 additions & 1 deletion backends/arm/test/misc/test_debug_feats.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(
bias: bool = True,
):
super().__init__()
self.inputs = (torch.ones(5, 10, 25, in_features),)
self.inputs = (torch.randn(5, 10, 25, in_features),)
self.fc = torch.nn.Linear(
in_features=in_features,
out_features=out_features,
Expand Down Expand Up @@ -93,3 +93,35 @@ def test_BI_artifact(self):
if self._is_tosa_marker_in_file(tmp_file):
return # Implicit pass test
self.fail("File does not contain TOSA dump!")


class TestNumericalDiffPrints(unittest.TestCase):
def test_numerical_diff_prints(self):
model = Linear(20, 30)
tester = (
ArmTester(
model,
inputs=model.get_inputs(),
compile_spec=common.get_tosa_compile_spec(),
)
.quantize()
.export()
.to_edge()
.partition()
.to_executorch()
)
if common.TOSA_REF_MODEL_INSTALLED:
# We expect an assertion error here. Any other issues will cause the
# test to fail. Likewise the test will fail if the assertion error is
# not present.
try:
# Tolerate 0 difference => we want to trigger a numerical diff
tester.run_method_and_compare_outputs(atol=0, rtol=0, qtol=0)
except AssertionError:
pass # Implicit pass test
else:
self.fail()
else:
logger.warning(
"TOSA ref model tool not installed, skip numerical correctness tests"
)
44 changes: 44 additions & 0 deletions backends/arm/test/tester/arm_tester.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import List, Optional, Tuple, Union

import numpy as np

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


import torch

from executorch.backends.arm.arm_backend import (
Expand Down Expand Up @@ -237,6 +242,9 @@ def run_method_and_compare_outputs(
export_stage.artifact, is_quantized
)

self.qp_input = qp_input
self.qp_output = qp_output

# Calculate the reference output using the original module or the quant
# module.
quantization_scale = None
Expand Down Expand Up @@ -302,3 +310,39 @@ def _calculate_reference_output(
"""

return module.forward(*inputs)

def _compare_outputs(
self,
reference_output,
stage_output,
quantization_scale=None,
atol=1e-03,
rtol=1e-03,
qtol=0,
):
try:
super()._compare_outputs(
reference_output, stage_output, quantization_scale, atol, rtol, qtol
)
except AssertionError as e:
# Capture assertion error and print more info
banner = "=" * 40 + "TOSA debug info" + "=" * 40
logger.error(banner)
path_to_tosa_files = self.tosa_test_util.get_tosa_artifact_path()
logger.error(f"{self.qp_input=}")
logger.error(f"{self.qp_output=}")
logger.error(f"{path_to_tosa_files=}")
import os

torch.save(
stage_output,
os.path.join(path_to_tosa_files, "torch_tosa_output.pt"),
)

torch.save(
reference_output,
os.path.join(path_to_tosa_files, "torch_ref_output.pt"),
)
logger.error(f"{atol=}, {rtol=}, {qtol=}")

raise e
6 changes: 6 additions & 0 deletions backends/arm/test/tosautil/tosa_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ def __init__(self, node_name: str, zp: int, scale: float):
self.zp = zp
self.scale = scale

def __repr__(self):
return f"QuantizationParams(node_name={self.node_name}, zp={self.zp}, scale={self.scale})"


"""
This class is used to work with TOSA artifacts.
Expand Down Expand Up @@ -213,6 +216,9 @@ def run_tosa_ref_model(

return tosa_ref_output

def get_tosa_artifact_path(self):
return self.intermediate_path

@staticmethod
def _run_cmd(cmd: List[str]) -> None:
"""
Expand Down

0 comments on commit bc5ba99

Please sign in to comment.