
Commit

Bias backup (AUTOMATIC1111#7)
* Prevent unnecessary bias backup

* Fix LoRA bias error

---------

Co-authored-by: AUTOMATIC1111 <16777216c@gmail.com>
huchenlei and AUTOMATIC1111 authored May 16, 2024
1 parent 5b49881 commit b66dfb5
Showing 1 changed file with 6 additions and 1 deletion.
extensions-builtin/Lora/networks.py (7 changes: 6 additions & 1 deletion)
@@ -378,13 +378,18 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
         self.network_weights_backup = weights_backup
 
     bias_backup = getattr(self, "network_bias_backup", None)
-    if bias_backup is None:
+    if bias_backup is None and wanted_names != ():
         if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
             bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
         elif getattr(self, 'bias', None) is not None:
             bias_backup = self.bias.to(devices.cpu, copy=True)
         else:
             bias_backup = None
+
+        # Unlike weight which always has value, some modules don't have bias.
+        # Only report if bias is not None and current bias are not unchanged.
+        if bias_backup is not None and current_names != ():
+            raise RuntimeError("no backup bias found and current bias are not unchanged")
         self.network_bias_backup = bias_backup
 
     if current_names != wanted_names:
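For readers skimming the hunk above, here is a minimal, self-contained sketch of the guarded bias-backup pattern it implements. The helper name maybe_backup_bias, the current_names/wanted_names tuples passed in as arguments, and the use of torch.device("cpu") in place of the webui's devices.cpu are illustrative assumptions, not the actual networks.py API:

# Minimal sketch of the guarded bias backup (illustrative names; the real code lives in
# extensions-builtin/Lora/networks.py and uses devices.cpu rather than torch.device("cpu")).
import torch


def maybe_backup_bias(module: torch.nn.Module, current_names: tuple, wanted_names: tuple):
    bias_backup = getattr(module, "network_bias_backup", None)

    # The guard added by this commit: only pay for a CPU copy when a network is actually wanted.
    if bias_backup is None and wanted_names != ():
        if isinstance(module, torch.nn.MultiheadAttention) and module.out_proj.bias is not None:
            bias_backup = module.out_proj.bias.to(torch.device("cpu"), copy=True)
        elif getattr(module, "bias", None) is not None:
            bias_backup = module.bias.to(torch.device("cpu"), copy=True)
        else:
            bias_backup = None  # some modules simply have no bias to back up

        # If a bias exists but networks were already applied (current_names != ()),
        # the live bias is no longer pristine, so a first-time backup would be wrong.
        if bias_backup is not None and current_names != ():
            raise RuntimeError("no backup bias found and current bias are not unchanged")

        module.network_bias_backup = bias_backup


# Usage sketch: backing up a clean Linear layer before a hypothetical LoRA is applied.
linear = torch.nn.Linear(8, 8)
maybe_backup_bias(linear, current_names=(), wanted_names=("hypothetical_lora",))
print(linear.network_bias_backup.shape)  # torch.Size([8]): a CPU snapshot of the original bias

The stored network_bias_backup is what allows the original bias to be restored later, when the set of applied networks changes.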
