Skip to content

Commit

Permalink
set[type] also doesnt work on older Pythons, replace with the more co…
Browse files Browse the repository at this point in the history
…mpatible Set[type]
  • Loading branch information
hafriedlander committed Dec 24, 2022
1 parent 580941a commit ea0238d
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions lora_diffusion/lora.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import math
from itertools import groupby
from typing import Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Set, Tuple, Union

import numpy as np
import PIL
Expand Down Expand Up @@ -55,7 +55,7 @@ def _find_children(

def _find_modules(
model,
ancestor_class: set[str] = DEFAULT_TARGET_REPLACE,
ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
search_class: list[type[nn.Module]] = [nn.Linear],
exclude_children_of: Optional[list[type[nn.Module]]] = [LoraInjectedLinear],
):
Expand Down Expand Up @@ -94,7 +94,7 @@ def _find_modules(

def inject_trainable_lora(
model: nn.Module,
target_replace_module: set[str] = DEFAULT_TARGET_REPLACE,
target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
r: int = 4,
loras=None, # path to lora .pt
):
Expand Down Expand Up @@ -183,7 +183,7 @@ def save_lora_as_json(model, path="./lora.json"):


def save_safeloras(
modelmap: dict[str, tuple[nn.Module, set[str]]] = {},
modelmap: dict[str, tuple[nn.Module, Set[str]]] = {},
outpath="./lora.safetensors",
):
"""
Expand Down Expand Up @@ -215,7 +215,7 @@ def save_safeloras(


def convert_loras_to_safeloras(
modelmap: dict[str, tuple[str, set[str], int]] = {},
modelmap: dict[str, tuple[str, Set[str], int]] = {},
outpath="./lora.safetensors",
):
"""
Expand Down

0 comments on commit ea0238d

Please sign in to comment.