diff --git a/lib_layerdiffusion/attention_sharing.py b/lib_layerdiffusion/attention_sharing.py
index ae424fc..d363e42 100644
--- a/lib_layerdiffusion/attention_sharing.py
+++ b/lib_layerdiffusion/attention_sharing.py
@@ -324,10 +324,43 @@ def __init__(self, layer_list):
         self.layers = torch.nn.ModuleList(layer_list)
 
 
+def unload_model_clones(model, unload_weights_only=True, force_unload=True):
+    current_loaded_models = model_management.current_loaded_models
+
+    to_unload = []
+    for i, m in enumerate(current_loaded_models):
+        if model.is_clone(m.model):
+            to_unload = [i] + to_unload
+
+    if len(to_unload) == 0:
+        return True
+
+    same_weights = 0
+    for i in to_unload:
+        if model.clone_has_same_weights(current_loaded_models[i].model):
+            same_weights += 1
+
+    if same_weights == len(to_unload):
+        unload_weight = False
+    else:
+        unload_weight = True
+
+    if not force_unload:
+        if unload_weights_only and unload_weight is False:
+            return None
+        else:
+            unload_weight = True
+
+    for i in to_unload:
+        current_loaded_models.pop(i).model_unload(unpatch_weights=unload_weight)
+
+    return unload_weight
+
+
 class AttentionSharingPatcher(torch.nn.Module):
     def __init__(self, unet, frames=2, use_control=True, rank=256):
         super().__init__()
-        model_management.unload_model_clones(unet)
+        unload_model_clones(unet)
         units = []
         for i in range(32):
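
Note (not part of the patch): the helper added above appears to vendor the unload_model_clones routine that older ComfyUI builds exposed via model_management, so AttentionSharingPatcher no longer depends on that function existing upstream. Below is a minimal, illustrative sketch of the helper's return semantics using stand-in objects; StubModel and StubLoaded are hypothetical stand-ins for ComfyUI's ModelPatcher and LoadedModel, and the sketch assumes the patched lib_layerdiffusion package (and its backend) is importable.

    # Illustrative sketch only. The stub classes exist solely to exercise
    # the vendored helper's return values.
    from lib_layerdiffusion.attention_sharing import model_management, unload_model_clones


    class StubLoaded:
        # Mimics an entry of model_management.current_loaded_models.
        def __init__(self, model):
            self.model = model

        def model_unload(self, unpatch_weights=True):
            print(f"unloaded (unpatch_weights={unpatch_weights})")


    class StubModel:
        # Mimics a ModelPatcher just enough for the clone checks.
        def __init__(self, clones=(), same_weights=True):
            self.clones = set(clones)
            self.same_weights = same_weights

        def is_clone(self, other):
            return other in self.clones

        def clone_has_same_weights(self, other):
            return self.same_weights


    clone = StubModel()
    unet = StubModel(clones={clone})

    # No clones of `unet` loaded: nothing is unloaded and the helper returns True.
    model_management.current_loaded_models = []
    assert unload_model_clones(unet) is True

    # A clone with identical weights is loaded: it is popped from the list, but
    # model_unload() is told not to unpatch weights, and the helper returns False.
    model_management.current_loaded_models = [StubLoaded(clone)]
    assert unload_model_clones(unet) is False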