Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[from_pretrained] Simpler code for peft #25726

Merged
Merged 9 commits on Aug 24, 2023
38 changes: 14 additions & 24 deletions src/transformers/modeling_utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2368,28 +2368,18 @@ def from_pretrained(
" ignored."
)

if is_peft_available() and _adapter_model_path is None:
maybe_adapter_model_path = find_adapter_config_file(
pretrained_model_name_or_path,
revision=revision,
subfolder=subfolder,
token=token,
commit_hash=commit_hash,
)
elif is_peft_available() and _adapter_model_path is not None:
maybe_adapter_model_path = _adapter_model_path
else:
maybe_adapter_model_path = None

has_adapter_config = maybe_adapter_model_path is not None

if has_adapter_config:
if _adapter_model_path is not None:
adapter_model_id = _adapter_model_path
else:
with open(maybe_adapter_model_path, "r", encoding="utf-8") as f:
adapter_model_id = pretrained_model_name_or_path
pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
if is_peft_available():
if _adapter_model_path is None:
_adapter_model_path = find_adapter_config_file(
pretrained_model_name_or_path,
revision=revision,
subfolder=subfolder,
token=token,
commit_hash=commit_hash,
)
if os.path.isfile(_adapter_model_path):
with open(_adapter_model_path, "r", encoding="utf-8") as f:
_adapter_model_path = json.load(f)["base_model_name_or_path"]

# change device_map into a map if we passed an int, a str or a torch.device
if isinstance(device_map, torch.device):
Expand Down Expand Up @@ -3221,9 +3211,9 @@ def from_pretrained(
if quantization_method_from_config == QuantizationMethod.GPTQ:
model = quantizer.post_init_model(model)

if has_adapter_config:
if _adapter_model_path is not None:
model.load_adapter(
adapter_model_id,
_adapter_model_path,
adapter_name=adapter_name,
revision=revision,
token=token,
Expand Down
Loading