Skip to content

Commit

Permalink
Fix the acceptance test (#3660)
Browse files Browse the repository at this point in the history
Signed-off-by: Hitarth Mehta <quic_hitameht@quicinc.com>
  • Loading branch information
quic-hitameht authored Dec 18, 2024
1 parent e213b1f commit d8c67af
Showing 1 changed file with 2 additions and 22 deletions.
24 changes: 2 additions & 22 deletions NightlyTests/torch/test_bias_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,13 @@
import numpy as np
import torch
import torch.nn as nn
from contextlib import contextmanager

from aimet_common.defs import QuantScheme
import aimet_torch.bias_correction
import aimet_torch.layer_selector
from aimet_torch import bias_correction
from aimet_torch.v1.quantsim import QuantParams
from aimet_torch import batch_norm_fold
from aimet_torch import bias_correction as bc
from models.mobilenet import MobileNetV2
from models.imagenet_dataloader import ImageNetDataLoader

Expand All @@ -64,28 +62,10 @@ def evaluate(model, early_stopping_iterations, use_cuda):
return model(random_input)


@contextmanager
def _use_python_impl(flag: bool):
orig_flag = bc.USE_PYTHON_IMPL
try:
bc.USE_PYTHON_IMPL = flag
yield
finally:
bc.USE_PYTHON_IMPL = orig_flag


@pytest.fixture(params=[True, False])
def use_python_impl(request):
param: bool = request.param

with _use_python_impl(param):
yield


class TestBiasCorrection:

@pytest.mark.cuda
def test_bias_correction_empirical(self, use_python_impl):
def test_bias_correction_empirical(self):

torch.manual_seed(10)
model = MobileNetV2().to(torch.device('cpu'))
Expand Down Expand Up @@ -115,7 +95,7 @@ def test_bias_correction_empirical(self, use_python_impl):
assert isinstance(model.features[11].conv[0], nn.Conv2d)

@pytest.mark.cuda
def test_bias_correction_hybrid(self, use_python_impl):
def test_bias_correction_hybrid(self):
torch.manual_seed(10)

model = MobileNetV2().to(torch.device('cpu'))
Expand Down

0 comments on commit d8c67af

Please sign in to comment.