Fix typing-related syntax errors in Python < 3.10 introduced in recent refactor (#79)
hafriedlander authored Dec 25, 2022
1 parent d590799 commit 97b8897
Showing 1 changed file with 26 additions and 22 deletions.
48 changes: 26 additions & 22 deletions lora_diffusion/lora.py
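The breakage is at import time: the previous refactor annotated parameters with types.UnionType and the X | Y union operator (both new in Python 3.10) and with lowercase builtin generics such as set[str] and dict[...] (subscriptable only since 3.9), so merely importing lora.py fails on older interpreters. The commit rewrites the annotations with typing.Union, Optional, List, Dict, Set, Tuple and Type, and makes search_class a plain list everywhere. A minimal sketch of the two spellings follows; the function name f is illustrative, not code from the repository.

# Illustrative only: why the old annotations break before Python 3.10,
# and the portable spelling the commit switches to.
from typing import List, Type, Union

import torch.nn as nn

# Python 3.10+ only (fails on 3.9 and earlier):
#   from types import UnionType                                # ImportError
#   def f(search_class: type[nn.Module] | UnionType = nn.Linear): ...
#   r: int | list[int] = 4                                      # TypeError when evaluated

# Portable spelling used by the fix:
def f(search_class: List[Type[nn.Module]] = [nn.Linear],
      r: Union[int, List[int]] = 4):
    return search_class, r

f()                                 # ([nn.Linear], 4)
f([nn.Linear, nn.Conv2d], [4, 8])   # callers now always pass lists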
@@ -1,7 +1,6 @@
import math
from itertools import groupby
-from types import UnionType
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union

import numpy as np
import PIL
@@ -39,7 +38,7 @@ def forward(self, input):

def _find_children(
model,
-search_class: type[nn.Module] | UnionType = nn.Linear,
+search_class: List[Type[nn.Module]] = [nn.Linear],
):
"""
Find all modules of a certain class (or union of classes).
@@ -50,15 +49,15 @@ def _find_children(
# For each target find every linear_class module that isn't a child of a LoraInjectedLinear
for parent in model.modules():
for name, module in parent.named_children():
-if isinstance(module, search_class):
+if any([isinstance(module, _class) for _class in search_class]):
yield parent, name, module


def _find_modules(
model,
-ancestor_class: set[str] = DEFAULT_TARGET_REPLACE,
-search_class: type[nn.Module] | UnionType = nn.Linear,
-exclude_children_of: type[nn.Module] | UnionType = LoraInjectedLinear,
+ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
+search_class: List[Type[nn.Module]] = [nn.Linear],
+exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
):
"""
Find all modules of a certain class (or union of classes) that are direct or
@@ -78,22 +77,24 @@ def _find_modules(
# For each target find every linear_class module that isn't a child of a LoraInjectedLinear
for ancestor in ancestors:
for fullname, module in ancestor.named_modules():
-if isinstance(module, search_class):
+if any([isinstance(module, _class) for _class in search_class]):
# Find the direct parent if this is a descendant, not a child, of target
*path, name = fullname.split(".")
parent = ancestor
while path:
parent = parent.get_submodule(path.pop(0))
# Skip this linear if it's a child of a LoraInjectedLinear
-if exclude_children_of and isinstance(parent, exclude_children_of):
+if exclude_children_of and any(
+    [isinstance(parent, _class) for _class in exclude_children_of]
+):
continue
# Otherwise, yield it
yield parent, name, module


def inject_trainable_lora(
model: nn.Module,
-target_replace_module: set[str] = DEFAULT_TARGET_REPLACE,
+target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
r: int = 4,
loras=None, # path to lora .pt
):
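With search_class now a list, matching is an explicit any() over isinstance checks instead of relying on a union type. Below is a self-contained sketch of the same pattern; the two-layer model is a stand-in, not code from the repository.

# Stand-in model to show the list-based matching used in _find_children
# and _find_modules above.
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Conv2d(3, 3, 1))
search_class = [nn.Linear, nn.Conv2d]

for name, module in model.named_modules():
    if any(isinstance(module, _class) for _class in search_class):
        print(name, type(module).__name__)   # prints the Linear and the Conv2d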
@@ -108,7 +109,7 @@ def inject_trainable_lora(
loras = torch.load(loras)

for _module, name, _child_module in _find_modules(
-model, target_replace_module, search_class=nn.Linear
+model, target_replace_module, search_class=[nn.Linear]
):
weight = _child_module.weight
bias = _child_module.bias
@@ -144,7 +145,7 @@ def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
loras = []

for _m, _n, _child_module in _find_modules(
-model, target_replace_module, search_class=LoraInjectedLinear
+model, target_replace_module, search_class=[LoraInjectedLinear]
):
loras.append((_child_module.lora_up, _child_module.lora_down))

@@ -182,7 +183,7 @@ def save_lora_as_json(model, path="./lora.json"):


def save_safeloras(
-modelmap: dict[str, tuple[nn.Module, set[str]]] = {},
+modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
outpath="./lora.safetensors",
):
"""
@@ -214,7 +215,7 @@ def save_safeloras(


def convert_loras_to_safeloras(
-modelmap: dict[str, tuple[str, set[str], int]] = {},
+modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
outpath="./lora.safetensors",
):
"""
@@ -252,7 +253,7 @@

def parse_safeloras(
safeloras,
-) -> dict[str, tuple[list[nn.parameter.Parameter], list[int], list[str]]]:
+) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
"""
Converts a loaded safetensor file that contains a set of module Loras
into Parameters and other information
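The rewritten return annotation reads as a mapping from model name to a (parameters, ranks, target module names) triple. A hedged sketch of a consumer written against that annotation follows; the summarize helper is illustrative, not part of lora.py.

# Illustrative consumer of the annotated return type; not repository code.
from typing import Dict, List, Tuple

import torch.nn as nn

ParsedLoras = Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]

def summarize(parsed: ParsedLoras) -> None:
    for model_name, (params, ranks, targets) in parsed.items():
        print(model_name, len(params), ranks, targets)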
@@ -315,7 +316,7 @@ def weight_apply_lora(
):

for _m, _n, _child_module in _find_modules(
-model, target_replace_module, search_class=nn.Linear
+model, target_replace_module, search_class=[nn.Linear]
):
weight = _child_module.weight

@@ -331,7 +332,7 @@ def monkeypatch_lora(
model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4
):
for _module, name, _child_module in _find_modules(
-model, target_replace_module, search_class=nn.Linear
+model, target_replace_module, search_class=[nn.Linear]
):
weight = _child_module.weight
bias = _child_module.bias
@@ -366,7 +367,7 @@ def monkeypatch_replace_lora(
model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4
):
for _module, name, _child_module in _find_modules(
-model, target_replace_module, search_class=LoraInjectedLinear
+model, target_replace_module, search_class=[LoraInjectedLinear]
):
weight = _child_module.linear.weight
bias = _child_module.linear.bias
@@ -398,10 +399,13 @@ def monkeypatch_replace_lora(


def monkeypatch_or_replace_lora(
-model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int | list[int] = 4
+model,
+loras,
+target_replace_module=DEFAULT_TARGET_REPLACE,
+r: Union[int, List[int]] = 4,
):
for _module, name, _child_module in _find_modules(
-model, target_replace_module, search_class=nn.Linear | LoraInjectedLinear
+model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
):
_source = (
_child_module.linear
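monkeypatch_or_replace_lora now accepts either one rank or a list of ranks. How the list form is consumed lies outside this hunk; the helper below is only an assumed illustration of normalizing such a Union[int, List[int]] argument, not the repository's logic.

# Assumed helper for illustration only; lora.py may resolve per-module
# ranks differently.
from typing import List, Union

def rank_for(r: Union[int, List[int]], index: int) -> int:
    return r if isinstance(r, int) else r[index]

rank_for(4, 0)       # 4 for every matched module
rank_for([4, 8], 1)  # 8 for the second matched module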
@@ -453,7 +457,7 @@ def monkeypatch_or_replace_safeloras(models, safeloras):

def monkeypatch_remove_lora(model):
for _module, name, _child_module in _find_children(
-model, search_class=LoraInjectedLinear
+model, search_class=[LoraInjectedLinear]
):
_source = _child_module.linear
weight, bias = _source.weight, _source.bias
@@ -475,7 +479,7 @@ def monkeypatch_add_lora(
beta: float = 1.0,
):
for _module, name, _child_module in _find_modules(
-model, target_replace_module, search_class=LoraInjectedLinear
+model, target_replace_module, search_class=[LoraInjectedLinear]
):
weight = _child_module.linear.weight

