diff --git a/lora_diffusion/__init__.py b/lora_diffusion/__init__.py
index 6434b23..0efbb94 100644
--- a/lora_diffusion/__init__.py
+++ b/lora_diffusion/__init__.py
@@ -1,3 +1,2 @@
 from .lora import *
 from .dataset import *
-from .utils import *
diff --git a/lora_diffusion/cli_lora_pti.py b/lora_diffusion/cli_lora_pti.py
index 8e9aa98..8bb91dc 100644
--- a/lora_diffusion/cli_lora_pti.py
+++ b/lora_diffusion/cli_lora_pti.py
@@ -142,10 +142,6 @@ def collate_fn(examples):
             "input_ids": input_ids,
             "pixel_values": pixel_values,
         }
-
-        if examples[0].get("mask", None) is not None:
-            batch["mask"] = torch.stack([example["mask"] for example in examples])
-
         return batch

     train_dataloader = torch.utils.data.DataLoader(
@@ -153,15 +149,14 @@ def collate_fn(examples):
         batch_size=train_batch_size,
         shuffle=True,
         collate_fn=collate_fn,
+        num_workers=2,
     )

     return train_dataloader


@torch.autocast("cuda")
-def loss_step(
-    batch, unet, vae, text_encoder, scheduler, weight_dtype, t_mutliplier=1.0
-):
+def loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype):
     latents = vae.encode(
         batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
     ).latent_dist.sample()
@@ -172,7 +167,7 @@ def loss_step(

     timesteps = torch.randint(
         0,
-        int(scheduler.config.num_train_timesteps * t_mutliplier),
+        scheduler.config.num_train_timesteps,
         (bsz,),
         device=latents.device,
     )
@@ -191,31 +186,6 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")

-    if batch.get("mask", None) is not None:
-
-        mask = (
-            batch["mask"]
-            .to(model_pred.device)
-            .reshape(
-                model_pred.shape[0], 1, model_pred.shape[2] * 8, model_pred.shape[3] * 8
-            )
-        )
-        # resize to match model_pred
-        mask = (
-            F.interpolate(
-                mask.float(),
-                size=model_pred.shape[-2:],
-                mode="nearest",
-            )
-            + 0.1
-        )
-
-        mask = mask / mask.mean()
-
-        model_pred = model_pred * mask
-
-        target = target * mask
-
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

     return loss
@@ -303,15 +273,7 @@ def perform_tuning(
         for batch in dataloader:
             optimizer.zero_grad()

-            loss = loss_step(
-                batch,
-                unet,
-                vae,
-                text_encoder,
-                scheduler,
-                weight_dtype,
-                t_mutliplier=0.8,
-            )
+            loss = loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype)
             loss.backward()
             torch.nn.utils.clip_grad_norm_(
                 itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
@@ -360,7 +322,7 @@ def train(
     class_data_dir: Optional[str] = None,
     stochastic_attribute: Optional[str] = None,
     perform_inversion: bool = True,
-    use_template: Literal[None, "object", "style"] = None,
+    use_template: Optional[str] = None,
     placeholder_tokens: str = "",
     placeholder_token_at_data: Optional[str] = None,
     initializer_tokens: str = "dog",
@@ -370,6 +332,7 @@ def train(
     num_class_images: int = 100,
     seed: int = 42,
     resolution: int = 512,
+    center_crop: bool = False,
     color_jitter: bool = True,
     train_batch_size: int = 1,
     sample_batch_size: int = 1,
@@ -387,7 +350,6 @@ def train(
     learning_rate_ti: float = 5e-4,
     continue_inversion: bool = True,
     continue_inversion_lr: Optional[float] = None,
-    use_face_segmentation_condition: bool = False,
     scale_lr: bool = False,
     lr_scheduler: str = "constant",
     lr_warmup_steps: int = 100,
@@ -451,8 +413,8 @@ def train(
         class_prompt=class_prompt,
         tokenizer=tokenizer,
         size=resolution,
+        center_crop=center_crop,
         color_jitter=color_jitter,
-        use_face_segmentation_condition=use_face_segmentation_condition,
     )

     train_dataloader = text2img_dataloader(
diff --git a/lora_diffusion/dataset.py b/lora_diffusion/dataset.py
index 37de046..6d2d712 100644
--- a/lora_diffusion/dataset.py
+++ b/lora_diffusion/dataset.py
@@ -1,12 +1,11 @@
 from torch.utils.data import Dataset

 from typing import List, Tuple, Dict, Union, Optional
-from PIL import Image, ImageFilter
+from PIL import Image
 from torchvision import transforms
 from pathlib import Path
-import cv2
+
 import random
-import numpy as np

 OBJECT_TEMPLATE = [
     "a photo of a {}",
@@ -91,12 +90,12 @@ def __init__(
         class_prompt=None,
         size=512,
         h_flip=True,
+        center_crop=False,
         color_jitter=False,
         resize=True,
-        use_face_segmentation_condition=False,
-        blur_amount: int = 70,
     ):
         self.size = size
+        self.center_crop = center_crop
         self.tokenizer = tokenizer
         self.resize = resize

@@ -122,7 +121,7 @@ def __init__(
             self.class_prompt = class_prompt
         else:
             self.class_data_root = None
-        self.h_flip = h_flip
+
         self.image_transforms = transforms.Compose(
             [
                 transforms.Resize(
@@ -130,24 +129,17 @@ def __init__(
                 )
                 if resize
                 else transforms.Lambda(lambda x: x),
-                transforms.ColorJitter(0.1, 0.1)
+                transforms.ColorJitter(0.2, 0.1)
                 if color_jitter
                 else transforms.Lambda(lambda x: x),
+                transforms.RandomHorizontalFlip()
+                if h_flip
+                else transforms.Lambda(lambda x: x),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
             ]
         )

-        self.use_face_segmentation_condition = use_face_segmentation_condition
-        if self.use_face_segmentation_condition:
-            import mediapipe as mp
-
-            mp_face_detection = mp.solutions.face_detection
-            self.face_detection = mp_face_detection.FaceDetection(
-                model_selection=1, min_detection_confidence=0.5
-            )
-            self.blur_amount = blur_amount
-
     def __len__(self):
         return self._length

@@ -171,59 +163,6 @@ def __getitem__(self, index):
         for token, value in self.token_map.items():
             text = text.replace(token, value)

-        if self.use_face_segmentation_condition:
-            image = cv2.imread(
-                str(self.instance_images_path[index % self.num_instance_images])
-            )
-            results = self.face_detection.process(
-                cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-            )
-            black_image = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
-
-            if results.detections:
-
-                for detection in results.detections:
-
-                    x_min = int(
-                        detection.location_data.relative_bounding_box.xmin
-                        * image.shape[1]
-                    )
-                    y_min = int(
-                        detection.location_data.relative_bounding_box.ymin
-                        * image.shape[0]
-                    )
-                    width = int(
-                        detection.location_data.relative_bounding_box.width
-                        * image.shape[1]
-                    )
-                    height = int(
-                        detection.location_data.relative_bounding_box.height
-                        * image.shape[0]
-                    )
-
-                    # draw the colored rectangle
-                    black_image[y_min : y_min + height, x_min : x_min + width] = 255
-
-            # blur the image
-            black_image = Image.fromarray(black_image, mode="L").filter(
-                ImageFilter.GaussianBlur(radius=self.blur_amount)
-            )
-            # to tensor
-            black_image = transforms.ToTensor()(black_image)
-            # resize as the instance image
-            black_image = transforms.Resize(
-                self.size, interpolation=transforms.InterpolationMode.BILINEAR
-            )(black_image)
-
-            example["mask"] = black_image
-
-        if self.h_flip and random.random() > 0.5:
-            hflip = transforms.RandomHorizontalFlip(p=1)
-
-            example["instance_images"] = hflip(example["instance_images"])
-            if self.use_face_segmentation_condition:
-                example["mask"] = hflip(example["mask"])
-
         example["instance_prompt_ids"] = self.tokenizer(
             text,
             padding="do_not_pad",
diff --git a/lora_diffusion/utils.py b/lora_diffusion/utils.py
deleted file mode 100644
index cde4ba4..0000000
--- a/lora_diffusion/utils.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from PIL import Image
-
-
-def image_grid(_imgs, rows, cols):
-
-    w, h = _imgs[0].size
-    grid = Image.new("RGB", size=(cols * w, rows * h))
-    grid_w, grid_h = grid.size
-
-    for i, img in enumerate(_imgs):
-        grid.paste(img, box=(i % cols * w, i // cols * h))
-    return grid
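
Note for downstream users: `lora_diffusion/utils.py` is deleted and its `image_grid` helper is no longer re-exported from `lora_diffusion/__init__.py`. Any notebook or script that imported it can inline an equivalent; a minimal sketch of the removed helper (cleaned up only by dropping the unused `grid_w, grid_h` locals):

```python
from PIL import Image


def image_grid(imgs, rows, cols):
    # Paste images left-to-right, top-to-bottom onto a rows x cols canvas.
    # All images are assumed to share the size of the first one.
    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
```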
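
One behavioral change worth flagging for review: `perform_tuning` previously called `loss_step` with `t_mutliplier=0.8`, which clamped the sampled timesteps to the first 80% of the noise schedule, whereas after this diff the full range is used. A quick illustration with made-up batch size, assuming the stock 1000-step schedule:

```python
import torch

num_train_timesteps = 1000  # scheduler.config.num_train_timesteps by default

# Before this diff (tuning passed t_mutliplier=0.8): draws from [0, 800).
old_timesteps = torch.randint(0, int(num_train_timesteps * 0.8), (4,))

# After this diff: draws from the full [0, 1000) range.
new_timesteps = torch.randint(0, num_train_timesteps, (4,))
```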