Other acceleration trick, reopened (#95)
* feat : face segmentation mask

* feat : conditioning dataset
cloneofsimo authored Dec 29, 2022
1 parent 431df66 commit 5240d32
Showing 5 changed files with 161 additions and 16 deletions.
1 change: 1 addition & 0 deletions lora_diffusion/__init__.py
@@ -1,2 +1,3 @@
from .lora import *
from .dataset import *
from .utils import *
52 changes: 45 additions & 7 deletions lora_diffusion/cli_lora_pti.py
@@ -142,21 +142,26 @@ def collate_fn(examples):
"input_ids": input_ids,
"pixel_values": pixel_values,
}

if examples[0].get("mask", None) is not None:
batch["mask"] = torch.stack([example["mask"] for example in examples])

return batch

train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=train_batch_size,
shuffle=True,
collate_fn=collate_fn,
num_workers=2,
)

return train_dataloader
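
Note: when the dataset provides a "mask" key, collate_fn above stacks the per-example masks into one tensor alongside pixel_values. A minimal sketch of the resulting shapes, assuming 512x512 training images (illustrative, not part of the commit):

import torch

masks = [torch.rand(1, 512, 512) for _ in range(2)]  # one blurred face mask per example
stacked = torch.stack(masks)  # shape (2, 1, 512, 512); batch dim matches pixel_values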


@torch.autocast("cuda")
def loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype):
def loss_step(
    batch, unet, vae, text_encoder, scheduler, weight_dtype, t_multiplier=1.0
):
    latents = vae.encode(
        batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
    ).latent_dist.sample()
@@ -167,7 +172,7 @@ def loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype):

    timesteps = torch.randint(
        0,
        scheduler.config.num_train_timesteps,
        int(scheduler.config.num_train_timesteps * t_multiplier),
        (bsz,),
        device=latents.device,
    )
@@ -186,6 +191,31 @@ def loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype):
    else:
        raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")

    if batch.get("mask", None) is not None:

        mask = (
            batch["mask"]
            .to(model_pred.device)
            .reshape(
                model_pred.shape[0], 1, model_pred.shape[2] * 8, model_pred.shape[3] * 8
            )
        )
        # resize to match model_pred
        mask = (
            F.interpolate(
                mask.float(),
                size=model_pred.shape[-2:],
                mode="nearest",
            )
            + 0.1
        )

        mask = mask / mask.mean()

        model_pred = model_pred * mask

        target = target * mask

    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    return loss
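
For intuition, a standalone sketch of the mask weighting above, assuming 512x512 images and the usual 8x VAE downsampling factor (shapes and names are illustrative, not part of the commit):

import torch
import torch.nn.functional as F

model_pred = torch.randn(2, 4, 64, 64)  # latent-space noise prediction (512 / 8 = 64)
mask = torch.rand(2, 1, 512, 512)       # pixel-space face mask from the dataset

mask = F.interpolate(mask, size=model_pred.shape[-2:], mode="nearest")  # down to 64x64
mask = mask + 0.1                       # keep a small nonzero weight on the background
mask = mask / mask.mean()               # renormalize so the overall loss scale is preserved
weighted_pred = model_pred * mask       # face regions now dominate the MSE

Note also that perform_tuning passes t_multiplier=0.8 (below), which restricts the sampled timesteps to the lower 80% of the noise schedule, so LoRA tuning only ever sees the less-noisy steps.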

@@ -273,7 +303,15 @@ def perform_tuning(
        for batch in dataloader:
            optimizer.zero_grad()

            loss = loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype)
            loss = loss_step(
                batch,
                unet,
                vae,
                text_encoder,
                scheduler,
                weight_dtype,
                t_multiplier=0.8,
            )
            loss.backward()
            torch.nn.utils.clip_grad_norm_(
                itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
@@ -322,7 +360,7 @@ def train(
    class_data_dir: Optional[str] = None,
    stochastic_attribute: Optional[str] = None,
    perform_inversion: bool = True,
    use_template: Optional[str] = Literal[None, "object", "style"],
    use_template: Literal[None, "object", "style"] = None,
    placeholder_tokens: str = "<s>",
    placeholder_token_at_data: Optional[str] = None,
    initializer_tokens: str = "dog",
@@ -332,7 +370,6 @@ def train(
    num_class_images: int = 100,
    seed: int = 42,
    resolution: int = 512,
    center_crop: bool = False,
    color_jitter: bool = True,
    train_batch_size: int = 1,
    sample_batch_size: int = 1,
@@ -350,6 +387,7 @@ def train(
    learning_rate_ti: float = 5e-4,
    continue_inversion: bool = True,
    continue_inversion_lr: Optional[float] = None,
    use_face_segmentation_condition: bool = False,
    scale_lr: bool = False,
    lr_scheduler: str = "constant",
    lr_warmup_steps: int = 100,
@@ -413,8 +451,8 @@ def train(
        class_prompt=class_prompt,
        tokenizer=tokenizer,
        size=resolution,
        center_crop=center_crop,
        color_jitter=color_jitter,
        use_face_segmentation_condition=use_face_segmentation_condition,
    )

    train_dataloader = text2img_dataloader(
81 changes: 72 additions & 9 deletions lora_diffusion/dataset.py
@@ -1,11 +1,12 @@
from torch.utils.data import Dataset

from typing import List, Tuple, Dict, Union, Optional
from PIL import Image
from PIL import Image, ImageFilter
from torchvision import transforms
from pathlib import Path

import cv2
import random
import numpy as np

OBJECT_TEMPLATE = [
"a photo of a {}",
@@ -90,12 +91,12 @@ def __init__(
        class_prompt=None,
        size=512,
        h_flip=True,
        center_crop=False,
        color_jitter=False,
        resize=True,
        use_face_segmentation_condition=False,
        blur_amount: int = 70,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer
        self.resize = resize

@@ -121,25 +122,32 @@ def __init__(
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.h_flip = h_flip
        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(
                    size, interpolation=transforms.InterpolationMode.BILINEAR
                )
                if resize
                else transforms.Lambda(lambda x: x),
                transforms.ColorJitter(0.2, 0.1)
                transforms.ColorJitter(0.1, 0.1)
                if color_jitter
                else transforms.Lambda(lambda x: x),
                transforms.RandomHorizontalFlip()
                if h_flip
                else transforms.Lambda(lambda x: x),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

        self.use_face_segmentation_condition = use_face_segmentation_condition
        if self.use_face_segmentation_condition:
            import mediapipe as mp

            mp_face_detection = mp.solutions.face_detection
            self.face_detection = mp_face_detection.FaceDetection(
                model_selection=1, min_detection_confidence=0.5
            )
        self.blur_amount = blur_amount

    def __len__(self):
        return self._length

@@ -163,6 +171,61 @@ def __getitem__(self, index):
        for token, value in self.token_map.items():
            text = text.replace(token, value)

        print(text)  # debug: log the resolved prompt

        if self.use_face_segmentation_condition:
            image = cv2.imread(
                str(self.instance_images_path[index % self.num_instance_images])
            )
            results = self.face_detection.process(
                cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            )
            black_image = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)

            if results.detections:

                for detection in results.detections:

                    x_min = int(
                        detection.location_data.relative_bounding_box.xmin
                        * image.shape[1]
                    )
                    y_min = int(
                        detection.location_data.relative_bounding_box.ymin
                        * image.shape[0]
                    )
                    width = int(
                        detection.location_data.relative_bounding_box.width
                        * image.shape[1]
                    )
                    height = int(
                        detection.location_data.relative_bounding_box.height
                        * image.shape[0]
                    )

                    # fill the face bounding box with white in the mask
                    black_image[y_min : y_min + height, x_min : x_min + width] = 255

            # blur the mask to soften the box edges
            black_image = Image.fromarray(black_image, mode="L").filter(
                ImageFilter.GaussianBlur(radius=self.blur_amount)
            )
            # to tensor
            black_image = transforms.ToTensor()(black_image)
            # resize to match the instance image
            black_image = transforms.Resize(
                self.size, interpolation=transforms.InterpolationMode.BILINEAR
            )(black_image)

            example["mask"] = black_image

        if self.h_flip and random.random() > 0.5:
            hflip = transforms.RandomHorizontalFlip(p=1)

            # flip the image and its mask together so they stay aligned
            example["instance_images"] = hflip(example["instance_images"])
            if self.use_face_segmentation_condition:
                example["mask"] = hflip(example["mask"])

        example["instance_prompt_ids"] = self.tokenizer(
            text,
            padding="do_not_pad",
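
For reference, the face-mask construction added above can be exercised on its own. A sketch assuming mediapipe and opencv-python are installed (the helper name and signature are illustrative, not part of the commit):

import cv2
import numpy as np
import mediapipe as mp
from PIL import Image, ImageFilter

def face_mask(path: str, blur_amount: int = 70) -> Image.Image:
    # detect faces the same way the dataset does
    detector = mp.solutions.face_detection.FaceDetection(
        model_selection=1, min_detection_confidence=0.5
    )
    image = cv2.imread(path)
    results = detector.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    # white bounding boxes on a black background
    h, w = image.shape[:2]
    mask = np.zeros((h, w), dtype=np.uint8)
    if results.detections:
        for det in results.detections:
            box = det.location_data.relative_bounding_box
            x, y = int(box.xmin * w), int(box.ymin * h)
            mask[y : y + int(box.height * h), x : x + int(box.width * w)] = 255

    # soften the hard box edges with the same Gaussian blur
    return Image.fromarray(mask, mode="L").filter(
        ImageFilter.GaussianBlur(radius=blur_amount)
    )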
12 changes: 12 additions & 0 deletions lora_diffusion/utils.py
@@ -0,0 +1,12 @@
from PIL import Image


def image_grid(_imgs, rows, cols):

    w, h = _imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(_imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
31 changes: 31 additions & 0 deletions scripts/use_face_conditioning_example.sh
@@ -0,0 +1,31 @@
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="./data_example_small"
export OUTPUT_DIR="./exps/output_example_enid_w_mask"

lora_pti \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --train_text_encoder \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate_unet=3e-4 \
  --learning_rate_text=3e-4 \
  --learning_rate_ti=1e-3 \
  --color_jitter \
  --lr_scheduler="constant" \
  --lr_warmup_steps=100 \
  --placeholder_tokens="<s1>|<s2>|<s3>" \
  --placeholder_token_at_data="<s>|<s1><s2><s3>" \
  --initializer_tokens="girl|<rand>|<rand>" \
  --save_steps=100 \
  --max_train_steps_ti=500 \
  --max_train_steps_tuning=1000 \
  --perform_inversion=True \
  --use_template="object" \
  --weight_decay_ti=0.1 \
  --weight_decay_lora=0.001 \
  --continue_inversion_lr=1e-4 \
  --device="cuda:0" \
  --use_face_segmentation_condition
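
Note: --use_face_segmentation_condition makes the dataset import mediapipe at construction time, so it must be installed in the environment first (e.g. pip install mediapipe).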
