From 97b8897f6709e6fac333ad31e1eea19a3cbe4553 Mon Sep 17 00:00:00 2001
From: Hamish Friedlander
Date: Sun, 25 Dec 2022 14:35:37 +1300
Subject: [PATCH] Fix typing-related syntax errors in Python < 3.10 introduced in recent refactor (#79)

---
 lora_diffusion/lora.py | 48 +++++++++++++++++++++++-------------------
 1 file changed, 26 insertions(+), 22 deletions(-)

diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py
index 8355839..7419086 100644
--- a/lora_diffusion/lora.py
+++ b/lora_diffusion/lora.py
@@ -1,7 +1,6 @@
 import math
 from itertools import groupby
-from types import UnionType
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union
 
 import numpy as np
 import PIL
@@ -39,7 +38,7 @@ def forward(self, input):
 
 def _find_children(
     model,
-    search_class: type[nn.Module] | UnionType = nn.Linear,
+    search_class: List[Type[nn.Module]] = [nn.Linear],
 ):
     """
     Find all modules of a certain class (or union of classes).
@@ -50,15 +49,15 @@ def _find_children(
     # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
     for parent in model.modules():
         for name, module in parent.named_children():
-            if isinstance(module, search_class):
+            if any([isinstance(module, _class) for _class in search_class]):
                 yield parent, name, module
 
 
 def _find_modules(
     model,
-    ancestor_class: set[str] = DEFAULT_TARGET_REPLACE,
-    search_class: type[nn.Module] | UnionType = nn.Linear,
-    exclude_children_of: type[nn.Module] | UnionType = LoraInjectedLinear,
+    ancestor_class: Set[str] = DEFAULT_TARGET_REPLACE,
+    search_class: List[Type[nn.Module]] = [nn.Linear],
+    exclude_children_of: Optional[List[Type[nn.Module]]] = [LoraInjectedLinear],
 ):
     """
     Find all modules of a certain class (or union of classes) that are direct or
@@ -78,14 +77,16 @@ def _find_modules(
     # For each target find every linear_class module that isn't a child of a LoraInjectedLinear
     for ancestor in ancestors:
         for fullname, module in ancestor.named_modules():
-            if isinstance(module, search_class):
+            if any([isinstance(module, _class) for _class in search_class]):
                 # Find the direct parent if this is a descendant, not a child, of target
                 *path, name = fullname.split(".")
                 parent = ancestor
                 while path:
                     parent = parent.get_submodule(path.pop(0))
                 # Skip this linear if it's a child of a LoraInjectedLinear
-                if exclude_children_of and isinstance(parent, exclude_children_of):
+                if exclude_children_of and any(
+                    [isinstance(parent, _class) for _class in exclude_children_of]
+                ):
                     continue
                 # Otherwise, yield it
                 yield parent, name, module
@@ -93,7 +94,7 @@
 
 def inject_trainable_lora(
     model: nn.Module,
-    target_replace_module: set[str] = DEFAULT_TARGET_REPLACE,
+    target_replace_module: Set[str] = DEFAULT_TARGET_REPLACE,
     r: int = 4,
     loras=None,  # path to lora .pt
 ):
@@ -108,7 +109,7 @@ def inject_trainable_lora(
         loras = torch.load(loras)
 
     for _module, name, _child_module in _find_modules(
-        model, target_replace_module, search_class=nn.Linear
+        model, target_replace_module, search_class=[nn.Linear]
     ):
         weight = _child_module.weight
         bias = _child_module.bias
@@ -144,7 +145,7 @@ def extract_lora_ups_down(model, target_replace_module=DEFAULT_TARGET_REPLACE):
     loras = []
 
     for _m, _n, _child_module in _find_modules(
-        model, target_replace_module, search_class=LoraInjectedLinear
+        model, target_replace_module, search_class=[LoraInjectedLinear]
    ):
         loras.append((_child_module.lora_up, _child_module.lora_down))
 
@@ -182,7 +183,7 @@ def save_lora_as_json(model, path="./lora.json"):
 
 
 def save_safeloras(
-    modelmap: dict[str, tuple[nn.Module, set[str]]] = {},
+    modelmap: Dict[str, Tuple[nn.Module, Set[str]]] = {},
     outpath="./lora.safetensors",
 ):
     """
@@ -214,7 +215,7 @@ def save_safeloras(
 
 
 def convert_loras_to_safeloras(
-    modelmap: dict[str, tuple[str, set[str], int]] = {},
+    modelmap: Dict[str, Tuple[str, Set[str], int]] = {},
     outpath="./lora.safetensors",
 ):
     """
@@ -252,7 +253,7 @@ def convert_loras_to_safeloras(
 
 def parse_safeloras(
     safeloras,
-) -> dict[str, tuple[list[nn.parameter.Parameter], list[int], list[str]]]:
+) -> Dict[str, Tuple[List[nn.parameter.Parameter], List[int], List[str]]]:
     """
     Converts a loaded safetensor file that contains a set of module
     Loras into Parameters and other information
@@ -315,7 +316,7 @@ def weight_apply_lora(
 ):
 
     for _m, _n, _child_module in _find_modules(
-        model, target_replace_module, search_class=nn.Linear
+        model, target_replace_module, search_class=[nn.Linear]
     ):
         weight = _child_module.weight
 
@@ -331,7 +332,7 @@ def monkeypatch_lora(
     model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4
 ):
     for _module, name, _child_module in _find_modules(
-        model, target_replace_module, search_class=nn.Linear
+        model, target_replace_module, search_class=[nn.Linear]
     ):
         weight = _child_module.weight
         bias = _child_module.bias
@@ -366,7 +367,7 @@ def monkeypatch_replace_lora(
     model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int = 4
 ):
     for _module, name, _child_module in _find_modules(
-        model, target_replace_module, search_class=LoraInjectedLinear
+        model, target_replace_module, search_class=[LoraInjectedLinear]
     ):
         weight = _child_module.linear.weight
         bias = _child_module.linear.bias
@@ -398,10 +399,13 @@ def monkeypatch_replace_lora(
 
 
 def monkeypatch_or_replace_lora(
-    model, loras, target_replace_module=DEFAULT_TARGET_REPLACE, r: int | list[int] = 4
+    model,
+    loras,
+    target_replace_module=DEFAULT_TARGET_REPLACE,
+    r: Union[int, List[int]] = 4,
 ):
     for _module, name, _child_module in _find_modules(
-        model, target_replace_module, search_class=nn.Linear | LoraInjectedLinear
+        model, target_replace_module, search_class=[nn.Linear, LoraInjectedLinear]
     ):
         _source = (
             _child_module.linear
@@ -453,7 +457,7 @@ def monkeypatch_or_replace_safeloras(models, safeloras):
 
 def monkeypatch_remove_lora(model):
     for _module, name, _child_module in _find_children(
-        model, search_class=LoraInjectedLinear
+        model, search_class=[LoraInjectedLinear]
     ):
         _source = _child_module.linear
         weight, bias = _source.weight, _source.bias
@@ -475,7 +479,7 @@ def monkeypatch_add_lora(
     beta: float = 1.0,
 ):
     for _module, name, _child_module in _find_modules(
-        model, target_replace_module, search_class=LoraInjectedLinear
+        model, target_replace_module, search_class=[LoraInjectedLinear]
     ):
         weight = _child_module.linear.weight
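
Note for reviewers (illustrative sketch, not part of the patch): the snippet below mirrors the pattern the fix switches to, a List[Type[nn.Module]] parameter checked with any()/isinstance() instead of a `type[nn.Module] | UnionType` annotation. The replaced form needs Python 3.10 ("from types import UnionType" alone raises ImportError on earlier versions, and the `|` unions are evaluated at import time), while the typing-based form imports cleanly on 3.7+. The find_linear_like name and the demo model are made up for illustration and are not part of the lora_diffusion API.

from typing import List, Type

import torch.nn as nn


def find_linear_like(
    model: nn.Module,
    search_class: List[Type[nn.Module]] = [nn.Linear],  # plain list of classes, no `|` union
):
    # Yield (parent, name, module) for every child whose type matches any entry
    # in search_class -- the same any()/isinstance() check _find_children now uses.
    for parent in model.modules():
        for name, module in parent.named_children():
            if any(isinstance(module, _class) for _class in search_class):
                yield parent, name, module


# Usage: match two module classes at once, analogous to the patched call sites
# that pass search_class=[nn.Linear, LoraInjectedLinear].
net = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Conv1d(1, 1, 1))
for _parent, name, module in find_linear_like(net, [nn.Linear, nn.Conv1d]):
    print(name, type(module).__name__)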