Port back unload_model_clones
huchenlei committed Jan 2, 2025
1 parent 66eeb4e commit aaee30d
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions lib_layerdiffusion/attention_sharing.py
@@ -324,9 +324,42 @@ def __init__(self, layer_list):
        self.layers = torch.nn.ModuleList(layer_list)


def unload_model_clones(model, unload_weights_only=True, force_unload=True):
    current_loaded_models = model_management.current_loaded_models

    # Collect indices of loaded models that are clones of `model`. Prepending
    # keeps the indices in descending order, which makes the pop() loop below
    # safe.
    to_unload = []
    for i, m in enumerate(current_loaded_models):
        if model.is_clone(m.model):
            to_unload = [i] + to_unload

    if len(to_unload) == 0:
        return True

    # Weights only need unpatching if at least one clone carries weights
    # different from `model`'s.
    same_weights = 0
    for i in to_unload:
        if model.clone_has_same_weights(current_loaded_models[i].model):
            same_weights += 1

    if same_weights == len(to_unload):
        unload_weight = False
    else:
        unload_weight = True

    if not force_unload:
        # Caller only asked for weights to be unloaded, but every clone
        # already shares this model's weights: nothing to do.
        if unload_weights_only and not unload_weight:
            return None
    else:
        unload_weight = True

    # Descending indices: each pop() leaves the not-yet-popped indices valid.
    for i in to_unload:
        current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)

    return unload_weight

class AttentionSharingPatcher(torch.nn.Module):
    def __init__(self, unet, frames=2, use_control=True, rank=256):
        super().__init__()
        # Unload any still-resident clones of this model before patching its
        # attention layers.
        unload_model_clones(unet)

        units = []
        for i in range(32):
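Two things about the ported helper are worth noting. The commit title suggests it restores the `unload_model_clones` function that earlier ComfyUI versions exposed in `comfy.model_management` (an inference from "port back", not stated in the diff). And `to_unload` is built by prepending, so its indices come out in descending order and the `pop(i)` loop never shifts an index it has yet to visit. A minimal standalone sketch of that pattern:

    items = ["a", "b", "c", "d"]
    to_remove = []
    for i, item in enumerate(items):
        if item in ("b", "d"):           # matches land at indices 1 and 3
            to_remove = [i] + to_remove  # prepend -> to_remove == [3, 1]

    for i in to_remove:
        items.pop(i)                     # pop 3 first, then 1; no index shifts

    assert items == ["a", "c"]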

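The return value is effectively tri-state: `True` when no clones of `model` were loaded at all, `None` when `force_unload=False` and (with the default `unload_weights_only=True`) every clone shares the model's weights, and otherwise the boolean that was passed to `model_unload(unpatch_weights=...)`. A hypothetical caller reading it (the names here are illustrative, not from the commit):

    result = unload_model_clones(model, unload_weights_only=True, force_unload=False)
    if result is True:
        pass  # no clones of `model` were loaded
    elif result is None:
        pass  # clones exist but all share weights; nothing was unloaded
    else:
        pass  # clones were evicted; `result` says whether weights were unpatched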