diff --git a/analog/lora/lora.py b/analog/lora/lora.py
index 77732513..6fa18aab 100644
--- a/analog/lora/lora.py
+++ b/analog/lora/lora.py
@@ -18,6 +18,17 @@ def find_parameter_sharing_group(
     return found_groups[0]
 
 
+def _get_submodules(model, key):
+    """
+    Helper function to replace a module with transformers model
+    https://github.com/huggingface/peft/blob/c0dd27bc974e4a62c6072142146887b75bb2de6c/src/peft/utils/other.py#L251
+    """
+    parent = model.get_submodule(".".join(key.split(".")[:-1]))
+    target_name = key.split(".")[-1]
+    target = model.get_submodule(key)
+    return parent, target, target_name
+
+
 class LoRAHandler:
     """
     Transforms a model into a Lora model.
@@ -94,4 +105,5 @@ def add_lora(
             lora_module.init_weight(self.init_strategy, hessian_state[name])
             lora_module.to(device)
 
-            setattr(model, name, lora_module)
+            parent, target, target_name = _get_submodules(model, name)
+            setattr(parent, target_name, lora_module)