Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multi IP-Adapter for Flux pipelines #10867

Merged
merged 14 commits into from
Feb 25, 2025
4 changes: 2 additions & 2 deletions examples/community/pipeline_flux_semantic_guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,9 +537,9 @@ def prepare_ip_adapter_image_embeds(
if not isinstance(ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image]

if len(ip_adapter_image) != self.transformer.encoder_hid_proj.num_ip_adapters:
if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
raise ValueError(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
)

for single_ip_adapter_image, image_proj_layer in zip(
Expand Down
29 changes: 17 additions & 12 deletions src/diffusers/pipelines/flux/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,18 +410,23 @@ def prepare_ip_adapter_image_embeds(
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
)

for single_ip_adapter_image, image_proj_layer in zip(
ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
):
for single_ip_adapter_image in ip_adapter_image:
single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)

image_embeds.append(single_image_embeds[None, :])
else:
if not isinstance(ip_adapter_image_embeds, list):
ip_adapter_image_embeds = [ip_adapter_image_embeds]

if len(ip_adapter_image_embeds) != self.transformer.encoder_hid_proj.num_ip_adapters:
raise ValueError(
f"`ip_adapter_image_embeds` must have same length as the number of IP Adapters. Got {len(ip_adapter_image_embeds)} image embeds and {self.transformer.encoder_hid_proj.num_ip_adapters} IP Adapters."
)

for single_image_embeds in ip_adapter_image_embeds:
image_embeds.append(single_image_embeds)

ip_adapter_image_embeds = []
for i, single_image_embeds in enumerate(image_embeds):
for single_image_embeds in image_embeds:
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
single_image_embeds = single_image_embeds.to(device=device)
ip_adapter_image_embeds.append(single_image_embeds)
Expand Down Expand Up @@ -868,19 +873,19 @@ def __call__(
else:
guidance = None

# TODO: Clarify this section
if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
):
negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
if isinstance(ip_adapter_image, list):
negative_ip_adapter_image = [negative_ip_adapter_image] * len(ip_adapter_image)
zeros_image = np.zeros((width, height, 3), dtype=np.uint8)
negative_ip_adapter_image_embeds = self.encode_image(zeros_image, device, 1)[None, :]
negative_ip_adapter_image_embeds = [negative_ip_adapter_image_embeds] * self.transformer.encoder_hid_proj.num_ip_adapters

elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
):
ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
if isinstance(negative_ip_adapter_image, list):
ip_adapter_image = [ip_adapter_image] * len(negative_ip_adapter_image)
zeros_image = np.zeros((width, height, 3), dtype=np.uint8)
ip_adapter_image_embeds = self.encode_image(zeros_image, device, 1)[None, :]
ip_adapter_image_embeds = [ip_adapter_image_embeds] * self.transformer.encoder_hid_proj.num_ip_adapters

if self.joint_attention_kwargs is None:
self._joint_attention_kwargs = {}
Expand Down