Skip to content

Commit

Permalink
FIX Prevent CUDA context initialization due to AWQ (#2230)
Browse files Browse the repository at this point in the history
Importing from AWQ triggers CUDA context initialization, which can be
problematic in some circumstances (see #1877). This PR moves the import
so that it's local, preventing this issue.
  • Loading branch information
BenjaminBossan authored Dec 5, 2024
1 parent f86522e commit 15712db
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions src/peft/tuners/lora/awq.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@
from peft.tuners.tuners_utils import BaseTunerLayer


if is_auto_awq_available():
from awq.modules.linear import WQLinear_GEMM


class AwqLoraLinear(torch.nn.Module, LoraLayer):
def __init__(
self,
Expand Down Expand Up @@ -105,18 +101,21 @@ def dispatch_awq(
else:
target_base_layer = target

if is_auto_awq_available() and isinstance(target_base_layer, WQLinear_GEMM):
# Raise the error only at the dispatch level
AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0")
version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq"))
if is_auto_awq_available():
from awq.modules.linear import WQLinear_GEMM

if isinstance(target_base_layer, WQLinear_GEMM):
# Raise the error only at the dispatch level
AUTOAWQ_MINIMUM_VERSION = packaging.version.parse("0.2.0")
version_autoawq = packaging.version.parse(importlib_metadata.version("autoawq"))

if AUTOAWQ_MINIMUM_VERSION > version_autoawq:
raise ImportError(
f"Found an incompatible version of auto-awq. Found version {version_autoawq}, "
f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported for PEFT."
)
if AUTOAWQ_MINIMUM_VERSION > version_autoawq:
raise ImportError(
f"Found an incompatible version of auto-awq. Found version {version_autoawq}, "
f"but only versions above {AUTOAWQ_MINIMUM_VERSION} are supported for PEFT."
)

new_module = AwqLoraLinear(target, adapter_name, **kwargs)
target.qweight = target_base_layer.qweight
new_module = AwqLoraLinear(target, adapter_name, **kwargs)
target.qweight = target_base_layer.qweight

return new_module

0 comments on commit 15712db

Please sign in to comment.