Skip to content

Commit

Permalink
add IP-Adapter-FaceID-Portrait
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaohu2015 committed Jan 19, 2024
1 parent d536fa0 commit 2397f7c
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 5 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ we present IP-Adapter, an effective and lightweight adapter to achieve image pro
![arch](assets/figs/fig1.png)

## Release
- [2024/01/19] 🔥 Add IP-Adapter-FaceID-Portrait, more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID).
- [2024/01/17] 🔥 Add an experimental version of IP-Adapter-FaceID-PlusV2 for SDXL, more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID).
- [2024/01/04] 🔥 Add an experimental version of IP-Adapter-FaceID for SDXL, more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID).
- [2023/12/29] 🔥 Add an experimental version of IP-Adapter-FaceID-PlusV2, more information can be found [here](https://huggingface.co/h94/IP-Adapter-FaceID).
Expand Down
21 changes: 16 additions & 5 deletions ip_adapter/ip_adapter_faceid_separate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from .attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
from .utils import is_torch2_available

USE_DAFAULT_ATTN = False # should be True for visualization_attnmap
Expand Down Expand Up @@ -118,10 +117,11 @@ def forward(self, id_embeds, clip_embeds, shortcut=False, scale=1.0):


class IPAdapterFaceID:
def __init__(self, sd_pipe, ip_ckpt, device, num_tokens=4, torch_dtype=torch.float16):
def __init__(self, sd_pipe, ip_ckpt, device, num_tokens=4, n_cond=1, torch_dtype=torch.float16):
self.device = device
self.ip_ckpt = ip_ckpt
self.num_tokens = num_tokens
self.n_cond = n_cond
self.torch_dtype = torch_dtype

self.pipe = sd_pipe.to(self.device)
Expand Down Expand Up @@ -157,7 +157,7 @@ def set_ip_adapter(self):
attn_procs[name] = AttnProcessor()
else:
attn_procs[name] = IPAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=self.num_tokens,
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=self.num_tokens*self.n_cond,
).to(self.device, dtype=self.torch_dtype)
unet.set_attn_processor(attn_procs)

Expand All @@ -178,15 +178,26 @@ def load_ip_adapter(self):

@torch.inference_mode()
def get_image_embeds(self, faceid_embeds):


multi_face = False
if faceid_embeds.dim() == 3:
multi_face = True
b, n, c = faceid_embeds.shape
faceid_embeds = faceid_embeds.reshape(b*n, c)

faceid_embeds = faceid_embeds.to(self.device, dtype=self.torch_dtype)
image_prompt_embeds = self.image_proj_model(faceid_embeds)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(faceid_embeds))
if multi_face:
c = image_prompt_embeds.size(-1)
image_prompt_embeds = image_prompt_embeds.reshape(b, -1, c)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.reshape(b, -1, c)

return image_prompt_embeds, uncond_image_prompt_embeds

def set_scale(self, scale):
for attn_processor in self.pipe.unet.attn_processors.values():
if isinstance(attn_processor, LoRAIPAttnProcessor):
if isinstance(attn_processor, IPAttnProcessor):
attn_processor.scale = scale

def generate(
Expand Down

0 comments on commit 2397f7c

Please sign in to comment.