diff --git a/src/peft/tuners/lora/awq.py b/src/peft/tuners/lora/awq.py index 7245200b74..86989d9000 100644 --- a/src/peft/tuners/lora/awq.py +++ b/src/peft/tuners/lora/awq.py @@ -22,10 +22,6 @@ from peft.tuners.tuners_utils import BaseTunerLayer -if is_auto_awq_available(): - from awq.modules.linear import WQLinear_GEMM - - class AwqLoraLinear(torch.nn.Module, LoraLayer): def __init__( self, @@ -105,18 +101,21 @@ def dispatch_awq( else: target_base_layer = target - if is_auto_awq_available() and isinstance(target_base_layer, WQLinear_GEMM): - # Raise the error only at the dispatch level - AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0") - version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq")) + if is_auto_awq_available(): + from awq.modules.linear import WQLinear_GEMM + + if isinstance(target_base_layer, WQLinear_GEMM): + # Raise the error only at the dispatch level + AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0") + version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq")) - if AUTOAWQ_MINIMUM_VERSION > version_autoawq: - raise ImportError( - f"Found an incompatible version of auto-awq. Found version {version_autoawq}, " - f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported for PEFT." - ) + if AUTOAWQ_MINIMUM_VERSION > version_autoawq: + raise ImportError( + f"Found an incompatible version of auto-awq. Found version {version_autoawq}, " + f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported for PEFT." + ) - new_module = AwqLoraLinear(target, adapter_name, **kwargs) - target.qweight = target_base_layer.qweight + new_module = AwqLoraLinear(target, adapter_name, **kwargs) + target.qweight = target_base_layer.qweight return new_module