Skip the verified config (#1775)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
yiliu30 and pre-commit-ci[bot] authored May 8, 2024
1 parent 1a45090 commit 7b8aec0
Showing 6 changed files with 41 additions and 10 deletions.
7 changes: 7 additions & 0 deletions neural_compressor/common/base_config.py
@@ -436,6 +436,13 @@ def _is_op_type(name: str) -> bool:
     def get_config_set_for_tuning(cls):
         raise NotImplementedError
 
+    def __eq__(self, other: BaseConfig) -> bool:
+        if not isinstance(other, type(self)):
+            return False
+        return self.params_list == other.params_list and all(
+            getattr(self, str(attr)) == getattr(other, str(attr)) for attr in self.params_list
+        )
+
 
 class ComposableConfig(BaseConfig):
     name = COMPOSABLE_CONFIG
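With this change, config equality is structural rather than identity-based: two configs of the same type with the same params_list and matching attribute values compare equal. A minimal sketch of the new semantics, using the public RTNConfig subclass and its bits parameter for illustration (not part of this diff):

```python
from neural_compressor.torch.quantization import RTNConfig

# Distinct objects with identical parameters now compare equal.
assert RTNConfig(bits=4) == RTNConfig(bits=4)
# Same type with different parameter values stays unequal.
assert RTNConfig(bits=4) != RTNConfig(bits=8)
# A non-config (or a different config type) fails the isinstance check.
assert RTNConfig(bits=4) != "rtn"
```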
20 changes: 18 additions & 2 deletions neural_compressor/common/base_tuning.py
@@ -231,13 +231,29 @@ def __len__(self) -> int:
 
 
 class ConfigLoader:
-    def __init__(self, config_set: ConfigSet, sampler: Sampler = default_sampler) -> None:
+    def __init__(
+        self, config_set: ConfigSet, sampler: Sampler = default_sampler, skip_verified_config: bool = True
+    ) -> None:
         self.config_set = ConfigSet.from_fwk_configs(config_set)
         self._sampler = sampler(self.config_set)
+        self.skip_verified_config = skip_verified_config
+        self.verify_config_list = list()
+
+    def is_verified_config(self, config):
+        for verified_config in self.verify_config_list:
+            if config == verified_config:
+                return True
+        return False
 
     def __iter__(self) -> Generator[BaseConfig, Any, None]:
         for index in self._sampler:
-            yield self.config_set[index]
+            new_config = self.config_set[index]
+            if self.skip_verified_config and self.is_verified_config(new_config):
+                logger.warning("Skip the verified config:")
+                logger.warning(new_config.to_dict())
+                continue
+            self.verify_config_list.append(new_config)
+            yield new_config
 
 
 class TuningConfig:
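ConfigLoader now deduplicates as it iterates: each yielded config is appended to verify_config_list, and any later candidate that compares equal under the new BaseConfig.__eq__ is logged and skipped. A usage sketch mirroring the new unit test below, with the real RTNConfig standing in for the test-only FakeAlgoConfig:

```python
from neural_compressor.common.base_tuning import ConfigLoader
from neural_compressor.torch.quantization import RTNConfig

# bits=[4, 8] expands into two candidate configs; RTNConfig(bits=8)
# duplicates one of them, so the loader yields 2 configs instead of 3.
loader = ConfigLoader(config_set=[RTNConfig(bits=[4, 8]), RTNConfig(bits=8)])
for config in loader:
    print(config.to_dict())

# Pass skip_verified_config=False to keep the old behavior and
# receive every sampled config, duplicates included.
```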
3 changes: 3 additions & 0 deletions neural_compressor/common/tuning_param.py
@@ -98,3 +98,6 @@ def is_tunable(self, value: Any) -> bool:
         except Exception as e:
             logger.debug(f"Failed to validate the input_args: {e}")
             return False
+
+    def __str__(self) -> str:
+        return self.name
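This pairs with the new BaseConfig.__eq__ above: params_list entries may be plain strings or TuningParam objects, and giving TuningParam a __str__ that returns its name lets str(attr) resolve either form to the attribute name that getattr looks up.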
6 changes: 3 additions & 3 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -41,7 +41,7 @@ def rtn_entry(
     configs_mapping: Dict[Tuple[str, callable], RTNConfig],
     mode: Mode = Mode.QUANTIZE,
     *args,
-    **kwargs
+    **kwargs,
 ) -> torch.nn.Module:
     """The main entry to apply rtn quantization."""
     from neural_compressor.torch.algorithms.weight_only.rtn import RTNQuantizer
@@ -258,7 +258,7 @@ def awq_quantize_entry(
     configs_mapping: Dict[Tuple[str, callable], AWQConfig],
     mode: Mode = Mode.QUANTIZE,
     *args,
-    **kwargs
+    **kwargs,
 ) -> torch.nn.Module:
     logger.info("Quantize model with the AWQ algorithm.")
     from neural_compressor.torch.algorithms.weight_only.awq import AWQQuantizer
@@ -455,7 +455,7 @@ def hqq_entry(
     configs_mapping: Dict[Tuple[str, Callable], HQQConfig],
     mode: Mode = Mode.QUANTIZE,
     *args,
-    **kwargs
+    **kwargs,
 ) -> torch.nn.Module:
     from neural_compressor.torch.algorithms.weight_only.hqq import HQQuantizer
 
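All three hunks above are behavior-neutral trailing-comma fixes after **kwargs, in line with the pre-commit-ci[bot] co-authorship of this commit.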
8 changes: 8 additions & 0 deletions test/3x/common/test_common.py
@@ -336,6 +336,14 @@ def test_config_loader(self) -> None:
         for i, config in enumerate(self.loader):
             self.assertEqual(config, self.config_set[i])
 
+    def test_config_loader_skip_verified_config(self) -> None:
+        config_set = [FakeAlgoConfig(weight_bits=[4, 8]), FakeAlgoConfig(weight_bits=8)]
+        config_loader = ConfigLoader(config_set)
+        config_count = 0
+        for i, config in enumerate(config_loader):
+            config_count += 1
+        self.assertEqual(config_count, 2)
+
 
 if __name__ == "__main__":
     unittest.main()
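The expected count is 2 rather than 3: FakeAlgoConfig(weight_bits=[4, 8]) expands into candidates with weight_bits=4 and weight_bits=8, the explicit FakeAlgoConfig(weight_bits=8) duplicates the latter, and skip_verified_config defaults to True, so the duplicate is skipped during iteration.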
7 changes: 2 additions & 5 deletions test/3x/torch/quantization/weight_only/test_mixed_algos.py
@@ -10,11 +10,8 @@
 
 
 def run_fn(model):
-    # GPTQ uses ValueError to reduce computation when collecting input data of the first block
-    # It's special for UTs, no need to add this wrapper in examples.
-    with pytest.raises(ValueError):
-        model(torch.tensor([[10, 20, 30]], dtype=torch.long))
-        model(torch.tensor([[40, 50, 60]], dtype=torch.long))
+    model(torch.tensor([[10, 20, 30]], dtype=torch.long))
+    model(torch.tensor([[40, 50, 60]], dtype=torch.long))
 
 
 class TestMixedTwoAlgo:
