diff --git a/lora_diffusion/__init__.py b/lora_diffusion/__init__.py
index 0efbb94..6434b23 100644
--- a/lora_diffusion/__init__.py
+++ b/lora_diffusion/__init__.py
@@ -1,2 +1,3 @@
 from .lora import *
 from .dataset import *
+from .utils import *
diff --git a/lora_diffusion/cli_lora_pti.py b/lora_diffusion/cli_lora_pti.py
index 8bb91dc..8e9aa98 100644
--- a/lora_diffusion/cli_lora_pti.py
+++ b/lora_diffusion/cli_lora_pti.py
@@ -142,6 +142,10 @@ def collate_fn(examples):
             "input_ids": input_ids,
             "pixel_values": pixel_values,
         }
+
+        if examples[0].get("mask", None) is not None:
+            batch["mask"] = torch.stack([example["mask"] for example in examples])
+
         return batch
 
     train_dataloader = torch.utils.data.DataLoader(
@@ -149,14 +153,15 @@ def collate_fn(examples):
         batch_size=train_batch_size,
         shuffle=True,
         collate_fn=collate_fn,
-        num_workers=2,
     )
 
     return train_dataloader
 
 
 @torch.autocast("cuda")
-def loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype):
+def loss_step(
+    batch, unet, vae, text_encoder, scheduler, weight_dtype, t_multiplier=1.0
+):
     latents = vae.encode(
         batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
     ).latent_dist.sample()
@@ -167,7 +172,7 @@ def loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype):
 
     timesteps = torch.randint(
         0,
-        scheduler.config.num_train_timesteps,
+        int(scheduler.config.num_train_timesteps * t_multiplier),
         (bsz,),
         device=latents.device,
     )
@@ -186,6 +191,31 @@ def loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype):
     else:
         raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
 
+    if batch.get("mask", None) is not None:
+
+        mask = (
+            batch["mask"]
+            .to(model_pred.device)
+            .reshape(
+                model_pred.shape[0], 1, model_pred.shape[2] * 8, model_pred.shape[3] * 8
+            )
+        )
+        # resize to match model_pred
+        mask = (
+            F.interpolate(
+                mask.float(),
+                size=model_pred.shape[-2:],
+                mode="nearest",
+            )
+            + 0.1
+        )
+
+        mask = mask / mask.mean()
+
+        model_pred = model_pred * mask
+
+        target = target * mask
+
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
 
     return loss
@@ -273,7 +303,15 @@ def perform_tuning(
         for batch in dataloader:
             optimizer.zero_grad()
 
-            loss = loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype)
+            loss = loss_step(
+                batch,
+                unet,
+                vae,
+                text_encoder,
+                scheduler,
+                weight_dtype,
+                t_multiplier=0.8,
+            )
             loss.backward()
             torch.nn.utils.clip_grad_norm_(
                 itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
@@ -322,7 +360,7 @@ def train(
     class_data_dir: Optional[str] = None,
     stochastic_attribute: Optional[str] = None,
     perform_inversion: bool = True,
-    use_template: Optional[str] = Literal[None, "object", "style"],
+    use_template: Literal[None, "object", "style"] = None,
     placeholder_tokens: str = "",
     placeholder_token_at_data: Optional[str] = None,
     initializer_tokens: str = "dog",
@@ -332,7 +370,6 @@ def train(
     num_class_images: int = 100,
     seed: int = 42,
     resolution: int = 512,
-    center_crop: bool = False,
     color_jitter: bool = True,
     train_batch_size: int = 1,
     sample_batch_size: int = 1,
@@ -350,6 +387,7 @@ def train(
     learning_rate_ti: float = 5e-4,
     continue_inversion: bool = True,
     continue_inversion_lr: Optional[float] = None,
+    use_face_segmentation_condition: bool = False,
     scale_lr: bool = False,
     lr_scheduler: str = "constant",
     lr_warmup_steps: int = 100,
@@ -413,8 +451,8 @@ def train(
         class_prompt=class_prompt,
         tokenizer=tokenizer,
         size=resolution,
-        center_crop=center_crop,
         color_jitter=color_jitter,
+        use_face_segmentation_condition=use_face_segmentation_condition,
     )
 
     train_dataloader = text2img_dataloader(
diff --git a/lora_diffusion/dataset.py b/lora_diffusion/dataset.py
index 6d2d712..ab23ddf 100644
--- a/lora_diffusion/dataset.py
+++ b/lora_diffusion/dataset.py
@@ -1,11 +1,12 @@
 from torch.utils.data import Dataset
 
 from typing import List, Tuple, Dict, Union, Optional
-from PIL import Image
+from PIL import Image, ImageFilter
 from torchvision import transforms
 from pathlib import Path
-
+import cv2
 import random
+import numpy as np
 
 OBJECT_TEMPLATE = [
     "a photo of a {}",
@@ -90,12 +91,12 @@ def __init__(
         class_prompt=None,
         size=512,
         h_flip=True,
-        center_crop=False,
         color_jitter=False,
         resize=True,
+        use_face_segmentation_condition=False,
+        blur_amount: int = 70,
     ):
         self.size = size
-        self.center_crop = center_crop
         self.tokenizer = tokenizer
         self.resize = resize
 
@@ -121,7 +122,7 @@ def __init__(
             self.class_prompt = class_prompt
         else:
             self.class_data_root = None
-
+        self.h_flip = h_flip
         self.image_transforms = transforms.Compose(
             [
                 transforms.Resize(
@@ -129,17 +130,24 @@ def __init__(
                 )
                 if resize
                 else transforms.Lambda(lambda x: x),
-                transforms.ColorJitter(0.2, 0.1)
+                transforms.ColorJitter(0.1, 0.1)
                 if color_jitter
                 else transforms.Lambda(lambda x: x),
-                transforms.RandomHorizontalFlip()
-                if h_flip
-                else transforms.Lambda(lambda x: x),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
             ]
         )
 
+        self.use_face_segmentation_condition = use_face_segmentation_condition
+        if self.use_face_segmentation_condition:
+            import mediapipe as mp
+
+            mp_face_detection = mp.solutions.face_detection
+            self.face_detection = mp_face_detection.FaceDetection(
+                model_selection=1, min_detection_confidence=0.5
+            )
+        self.blur_amount = blur_amount
+
     def __len__(self):
         return self._length
 
@@ -163,6 +171,61 @@ def __getitem__(self, index):
                 for token, value in self.token_map.items():
                     text = text.replace(token, value)
 
+        print(text)
+
+        if self.use_face_segmentation_condition:
+            image = cv2.imread(
+                str(self.instance_images_path[index % self.num_instance_images])
+            )
+            results = self.face_detection.process(
+                cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            )
+            black_image = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
+
+            if results.detections:
+
+                for detection in results.detections:
+
+                    x_min = int(
+                        detection.location_data.relative_bounding_box.xmin
+                        * image.shape[1]
+                    )
+                    y_min = int(
+                        detection.location_data.relative_bounding_box.ymin
+                        * image.shape[0]
+                    )
+                    width = int(
+                        detection.location_data.relative_bounding_box.width
+                        * image.shape[1]
+                    )
+                    height = int(
+                        detection.location_data.relative_bounding_box.height
+                        * image.shape[0]
+                    )
+
+                    # fill the detected face bounding box on the mask
+                    black_image[y_min : y_min + height, x_min : x_min + width] = 255
+
+            # blur the mask to soften its edges
+            black_image = Image.fromarray(black_image, mode="L").filter(
+                ImageFilter.GaussianBlur(radius=self.blur_amount)
+            )
+            # to tensor
+            black_image = transforms.ToTensor()(black_image)
+            # resize to match the instance image
+            black_image = transforms.Resize(
+                self.size, interpolation=transforms.InterpolationMode.BILINEAR
+            )(black_image)
+
+            example["mask"] = black_image
+
+        if self.h_flip and random.random() > 0.5:
+            hflip = transforms.RandomHorizontalFlip(p=1)
+
+            example["instance_images"] = hflip(example["instance_images"])
+            if self.use_face_segmentation_condition:
+                example["mask"] = hflip(example["mask"])
+
         example["instance_prompt_ids"] = self.tokenizer(
             text,
             padding="do_not_pad",
diff --git a/lora_diffusion/utils.py b/lora_diffusion/utils.py
new file mode 100644
index 0000000..cde4ba4
--- /dev/null
+++ b/lora_diffusion/utils.py
@@ -0,0 +1,12 @@
+from PIL import Image
+
+
+def image_grid(_imgs, rows, cols):
+
+    w, h = _imgs[0].size
+    grid = Image.new("RGB", size=(cols * w, rows * h))
+    grid_w, grid_h = grid.size
+
+    for i, img in enumerate(_imgs):
+        grid.paste(img, box=(i % cols * w, i // cols * h))
+    return grid
diff --git a/scripts/use_face_conditioning_example.sh b/scripts/use_face_conditioning_example.sh
new file mode 100644
index 0000000..2456b1a
--- /dev/null
+++ b/scripts/use_face_conditioning_example.sh
@@ -0,0 +1,31 @@
+export MODEL_NAME="runwayml/stable-diffusion-v1-5"
+export INSTANCE_DIR="./data_example_small"
+export OUTPUT_DIR="./exps/output_example_enid_w_mask"
+
+lora_pti \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --instance_data_dir=$INSTANCE_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --train_text_encoder \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=1 \
+  --learning_rate_unet=3e-4 \
+  --learning_rate_text=3e-4 \
+  --learning_rate_ti=1e-3 \
+  --color_jitter \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=100 \
+  --placeholder_tokens="||" \
+  --placeholder_token_at_data="|" \
+  --initializer_tokens="girl||" \
+  --save_steps=100 \
+  --max_train_steps_ti=500 \
+  --max_train_steps_tuning=1000 \
+  --perform_inversion=True \
+  --use_template="object" \
+  --weight_decay_ti=0.1 \
+  --weight_decay_lora=0.001 \
+  --continue_inversion_lr=1e-4 \
+  --device="cuda:0" \
+  --use_face_segmentation_condition
\ No newline at end of file