Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi IP-Adapter for Flux pipelines #10867

Merged
merged 14 commits into from
Feb 25, 2025
Merged
49 changes: 29 additions & 20 deletions src/diffusers/loaders/ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_detailed_type,
_get_model_file,
_is_valid_type,
is_accelerate_available,
is_torch_version,
is_transformers_available,
Expand Down Expand Up @@ -577,29 +579,36 @@ def LinearStrengthModel(start, finish, size):
pipeline.set_ip_adapter_scale(ip_strengths)
```
"""
transformer = self.transformer
if not isinstance(scale, list):
scale = [[scale] * transformer.config.num_layers]
elif isinstance(scale, list) and isinstance(scale[0], int) or isinstance(scale[0], float):
if len(scale) != transformer.config.num_layers:
raise ValueError(f"Expected list of {transformer.config.num_layers} scales, got {len(scale)}.")

scale_type = Union[int, float]
num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
num_layers = self.transformer.config.num_layers

# Single value for all layers of all IP-Adapters
if isinstance(scale, scale_type):
scale = [scale for _ in range(num_ip_adapters)]
# List of per-layer scales for a single IP-Adapter
elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
scale = [scale]
# Invalid scale type
elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")

scale_configs = scale
if len(scale) != num_ip_adapters:
raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")

key_id = 0
for attn_name, attn_processor in transformer.attn_processors.items():
if isinstance(attn_processor, (FluxIPAdapterJointAttnProcessor2_0)):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to "
f"{len(attn_processor.scale)} IP-Adapter."
)
elif len(scale_configs) == 1:
scale_configs = scale_configs * len(attn_processor.scale)
for i, scale_config in enumerate(scale_configs):
attn_processor.scale[i] = scale_config[key_id]
key_id += 1
if any(len(s) != num_layers for s in scale if isinstance(s, list)):
invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}
raise ValueError(
f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."
)

# Scalars are transformed to lists with length num_layers
scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]

# Set scales. zip over scale_configs prevents going into single transformer layers
for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):
attn_processor.scale = scale

def unload_ip_adapter(self):
"""
Expand Down
15 changes: 8 additions & 7 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2780,9 +2780,8 @@ def __call__(

# IP-adapter
ip_query = hidden_states_query_proj
ip_attn_output = None
# for ip-adapter
# TODO: support for multiple adapters
ip_attn_output = torch.zeros_like(hidden_states)

for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
):
Expand All @@ -2793,12 +2792,14 @@ def __call__(
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_attn_output = F.scaled_dot_product_attention(
current_ip_hidden_states = F.scaled_dot_product_attention(
ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_attn_output = ip_attn_output.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_attn_output = scale * ip_attn_output
ip_attn_output = ip_attn_output.to(ip_query.dtype)
current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
batch_size, -1, attn.heads * head_dim
)
current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
ip_attn_output += scale * current_ip_hidden_states

return hidden_states, encoder_hidden_states, ip_attn_output
else:
Expand Down
5 changes: 5 additions & 0 deletions src/diffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2583,6 +2583,11 @@ def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[
super().__init__()
self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)

@property
def num_ip_adapters(self) -> int:
"""Number of IP-Adapters loaded."""
return len(self.image_projection_layers)

def forward(self, image_embeds: List[torch.Tensor]):
projected_image_embeds = []

Expand Down
22 changes: 15 additions & 7 deletions src/diffusers/pipelines/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,23 +405,28 @@ def prepare_ip_adapter_image_embeds(
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]

if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
)

for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
):
for single_ip_adapter_image in ip_adapter_image:
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)

image_embeds.append(single_image_embeds[None, :])
else:
if not isinstance(ip_adapter_image_embeds, list):
ip_adapter_image_embeds = [ip_adapter_image_embeds]

if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
raise ValueError(
f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
)

for single_image_embeds in ip_adapter_image_embeds:
image_embeds.append(single_image_embeds)

ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
for single_image_embeds in image_embeds:
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
Expand Down Expand Up @@ -872,10 +877,13 @@ def __call__(
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
):
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
negative_ip_adapter_image = [negative_ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters

elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
):
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
ip_adapter_image = [ip_adapter_image] * self.transformer.encoder_hid_proj.num_ip_adapters

if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
Expand Down
75 changes: 1 addition & 74 deletions src/diffusers/pipelines/pipeline_loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import re
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union, get_args, get_origin
from typing import Any, Callable, Dict, List, Optional, Union

import requests
import torch
Expand Down Expand Up @@ -1059,76 +1059,3 @@ def _maybe_raise_error_for_incorrect_transformers(config_dict):
break
if has_transformers_component and not is_transformers_version(">", "4.47.1"):
raise ValueError("Please upgrade your `transformers` installation to the latest version to use DDUF.")


def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
"""
Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of
the correct type as well.
"""
if not isinstance(class_or_tuple, tuple):
class_or_tuple = (class_or_tuple,)

# Unpack unions
unpacked_class_or_tuple = []
for t in class_or_tuple:
if get_origin(t) is Union:
unpacked_class_or_tuple.extend(get_args(t))
else:
unpacked_class_or_tuple.append(t)
class_or_tuple = tuple(unpacked_class_or_tuple)

if Any in class_or_tuple:
return True

obj_type = type(obj)
# Classes with obj's type
class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}

# Singular types (e.g. int, ControlNet, ...)
# Untyped collections (e.g. List, but not List[int])
elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
if () in elem_class_or_tuple:
return True
# Typed lists or sets
elif obj_type in (list, set):
return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
# Typed tuples
elif obj_type is tuple:
return any(
# Tuples with any length and single type (e.g. Tuple[int, ...])
(len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj))
or
# Tuples with fixed length and any types (e.g. Tuple[int, str])
(len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t)))
for t in elem_class_or_tuple
)
# Typed dicts
elif obj_type is dict:
return any(
all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items())
for kt, vt in elem_class_or_tuple
)

else:
return False


def _get_detailed_type(obj: Any) -> Type:
"""
Gets a detailed type for an object, including nested types for collections.
"""
obj_type = type(obj)

if obj_type in (list, set):
obj_origin_type = List if obj_type is list else Set
elems_type = Union[tuple({_get_detailed_type(x) for x in obj})]
return obj_origin_type[elems_type]
elif obj_type is tuple:
return Tuple[tuple(_get_detailed_type(x) for x in obj)]
elif obj_type is dict:
keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})]
values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})]
return Dict[keys_type, values_type]
else:
return obj_type
4 changes: 2 additions & 2 deletions src/diffusers/pipelines/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
DEPRECATED_REVISION_ARGS,
BaseOutput,
PushToHubMixin,
_get_detailed_type,
_is_valid_type,
is_accelerate_available,
is_accelerate_version,
is_torch_npu_available,
Expand All @@ -78,12 +80,10 @@
_fetch_class_library_tuple,
_get_custom_components_and_folders,
_get_custom_pipeline_class,
_get_detailed_type,
_get_final_device_map,
_get_ignore_patterns,
_get_pipeline_class,
_identify_model_variants,
_is_valid_type,
_maybe_raise_error_for_incorrect_transformers,
_maybe_raise_warning_for_inpainting,
_resolve_custom_pipeline_and_cls,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@
convert_state_dict_to_peft,
convert_unet_state_dict_to_peft,
)
from .typing_utils import _get_detailed_type, _is_valid_type


logger = get_logger(__name__)
Expand Down
91 changes: 91 additions & 0 deletions src/diffusers/utils/typing_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Typing utilities: Utilities related to type checking and validation
"""

from typing import Any, Dict, List, Set, Tuple, Type, Union, get_args, get_origin


def _is_valid_type(obj: Any, class_or_tuple: Union[Type, Tuple[Type, ...]]) -> bool:
"""
Checks if an object is an instance of any of the provided types. For collections, it checks if every element is of
the correct type as well.
"""
if not isinstance(class_or_tuple, tuple):
class_or_tuple = (class_or_tuple,)

# Unpack unions
unpacked_class_or_tuple = []
for t in class_or_tuple:
if get_origin(t) is Union:
unpacked_class_or_tuple.extend(get_args(t))
else:
unpacked_class_or_tuple.append(t)
class_or_tuple = tuple(unpacked_class_or_tuple)

if Any in class_or_tuple:
return True

obj_type = type(obj)
# Classes with obj's type
class_or_tuple = {t for t in class_or_tuple if isinstance(obj, get_origin(t) or t)}

# Singular types (e.g. int, ControlNet, ...)
# Untyped collections (e.g. List, but not List[int])
elem_class_or_tuple = {get_args(t) for t in class_or_tuple}
if () in elem_class_or_tuple:
return True
# Typed lists or sets
elif obj_type in (list, set):
return any(all(_is_valid_type(x, t) for x in obj) for t in elem_class_or_tuple)
# Typed tuples
elif obj_type is tuple:
return any(
# Tuples with any length and single type (e.g. Tuple[int, ...])
(len(t) == 2 and t[-1] is Ellipsis and all(_is_valid_type(x, t[0]) for x in obj))
or
# Tuples with fixed length and any types (e.g. Tuple[int, str])
(len(obj) == len(t) and all(_is_valid_type(x, tt) for x, tt in zip(obj, t)))
for t in elem_class_or_tuple
)
# Typed dicts
elif obj_type is dict:
return any(
all(_is_valid_type(k, kt) and _is_valid_type(v, vt) for k, v in obj.items())
for kt, vt in elem_class_or_tuple
)

else:
return False


def _get_detailed_type(obj: Any) -> Type:
"""
Gets a detailed type for an object, including nested types for collections.
"""
obj_type = type(obj)

if obj_type in (list, set):
obj_origin_type = List if obj_type is list else Set
elems_type = Union[tuple({_get_detailed_type(x) for x in obj})]
return obj_origin_type[elems_type]
elif obj_type is tuple:
return Tuple[tuple(_get_detailed_type(x) for x in obj)]
elif obj_type is dict:
keys_type = Union[tuple({_get_detailed_type(k) for k in obj.keys()})]
values_type = Union[tuple({_get_detailed_type(k) for k in obj.values()})]
return Dict[keys_type, values_type]
else:
return obj_type
Loading
Loading