Commit

fix 37425fb
Things to understand:

- subscripted generic basic types (e.g. `list[int]`) are `types.GenericAlias`;
- subscripted generic classes are `typing._GenericAlias`;
- neither can be used with `isinstance()`;
- `get_origin` is the cleanest way to check for this (see the sketch below).
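A minimal standalone sketch of these four points (the `Item` class is hypothetical, purely for illustration):

from typing import Any, Generic, TypeVar, get_origin
import types
import typing

T = TypeVar("T")

class Item(Generic[T]):  # hypothetical user-defined generic class
    pass

# Subscripted generic basic types are types.GenericAlias.
assert type(list[int]) is types.GenericAlias

# Subscripted generic classes are typing._GenericAlias.
assert type(Item[Any]) is typing._GenericAlias

# Neither can be used as the second argument of isinstance().
try:
    isinstance([], list[int])
except TypeError as e:
    print(e)  # isinstance() argument 2 cannot be a parameterized generic

# get_origin() detects the subscript cleanly and recovers the plain class.
assert get_origin(list[int]) is list
assert get_origin(Item[Any]) is Item
assert get_origin(list) is None  # a plain class is not subscripted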
catwell committed Feb 6, 2024
1 parent f9305aa commit 98fce82
Showing 3 changed files with 18 additions and 7 deletions.
src/refiners/fluxion/adapters/lora.py (13 changes: 9 additions & 4 deletions)
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Generic, TypeVar, cast
+from typing import Any, Generic, Iterator, TypeVar, cast
 
 from torch import Tensor, device as Device, dtype as DType
 from torch.nn import Parameter as TorchParameter
@@ -385,20 +385,25 @@ def __init__(self, target: fl.WeightedModule, /, *loras: Lora[Any]) -> None:
         with self.setup_adapter(target):
             super().__init__(target, *loras)
 
+    @property
+    def lora_layers(self) -> Iterator[Lora[Any]]:
+        """The LoRA layers."""
+        return cast(Iterator[Lora[Any]], self.layers(Lora))
+
     @property
     def names(self) -> list[str]:
         """The names of the LoRA layers."""
-        return [lora.name for lora in self.layers(Lora[Any])]
+        return [lora.name for lora in self.lora_layers]
 
     @property
     def loras(self) -> dict[str, Lora[Any]]:
         """The LoRA layers indexed by name."""
-        return {lora.name: lora for lora in self.layers(Lora[Any])}
+        return {lora.name: lora for lora in self.lora_layers}
 
     @property
     def scales(self) -> dict[str, float]:
         """The scales of the LoRA layers indexed by names."""
-        return {lora.name: lora.scale for lora in self.layers(Lora[Any])}
+        return {lora.name: lora.scale for lora in self.lora_layers}
 
     @scales.setter
     def scale(self, values: dict[str, float]) -> None:
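For context on the change above (not part of the diff): at runtime `self.layers(Lora)` already yields only `Lora` instances, so the `cast` exists purely to narrow the static type to `Iterator[Lora[Any]]` now that the subscripted `Lora[Any]` can no longer be passed to `layers()`. A self-contained sketch of that pattern, using a hypothetical `Box` class rather than refiners code:

from typing import Any, Generic, Iterator, TypeVar, cast

T = TypeVar("T")

class Box(Generic[T]):
    def __init__(self, value: T) -> None:
        self.value = value

def box_layers(modules: list[object]) -> Iterator[Box[Any]]:
    # Filter with the plain class (isinstance cannot take Box[Any]),
    # then cast so static checkers see Iterator[Box[Any]].
    return cast(Iterator[Box[Any]], (m for m in modules if isinstance(m, Box)))

mixed: list[object] = [Box(1), "not a box", Box("x")]
assert [b.value for b in box_layers(mixed)] == [1, "x"]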
src/refiners/fluxion/layers/chain.py (6 changes: 5 additions & 1 deletion)
@@ -3,7 +3,7 @@
 import sys
 import traceback
 from collections import defaultdict
-from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, overload
+from typing import Any, Callable, Iterable, Iterator, Sequence, TypeVar, cast, get_origin, overload
 
 import torch
 from torch import Tensor, cat, device as Device, dtype as DType
@@ -349,6 +349,10 @@ def walk(
         Yields:
             Each module that matches the predicate.
         """
+
+        if get_origin(predicate) is not None:
+            raise ValueError(f"subscripted generics cannot be used as predicates")
+
         if isinstance(predicate, type):
             # if the predicate is a Module type
             # build a predicate function that matches the type
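The guard above makes the failure mode explicit: without it, a subscripted predicate would only blow up later inside `isinstance()`. A rough standalone sketch of the behavior, using stand-in classes rather than the actual refiners API:

from typing import Generic, TypeVar, get_origin

T = TypeVar("T")

class Module:  # stand-in for a fluxion Module
    pass

class Lora(Generic[T], Module):  # stand-in for the generic Lora layer
    pass

def walk(modules: list[Module], predicate: type) -> list[Module]:
    # Reject subscripted generics such as Lora[Any] up front ...
    if get_origin(predicate) is not None:
        raise ValueError("subscripted generics cannot be used as predicates")
    # ... because isinstance() below would raise an opaque TypeError otherwise.
    return [m for m in modules if isinstance(m, predicate)]

print(walk([Lora(), Module()], Lora))  # keeps only the Lora instance
# walk([Lora()], Lora[Any])  # with the guard, this raises ValueError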
src/refiners/foundationals/latent_diffusion/lora.py (6 changes: 4 additions & 2 deletions)
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Iterator, cast
 from warnings import warn
 
 from torch import Tensor
@@ -193,7 +193,9 @@ def update_scales(self, scales: dict[str, float], /) -> None:
     @property
     def loras(self) -> list[Lora[Any]]:
         """List of all the LoRA layers managed by the SDLoraManager."""
-        return list(self.unet.layers(Lora[Any])) + list(self.clip_text_encoder.layers(Lora[Any]))
+        unet_layers = cast(Iterator[Lora[Any]], self.unet.layers(Lora))
+        text_encoder_layers = cast(Iterator[Lora[Any]], self.clip_text_encoder.layers(Lora))
+        return [*unet_layers, *text_encoder_layers]
 
     @property
     def names(self) -> list[str]:
