Revert "Other Acceleration tricks (#93)"
This reverts commit eacf501.
cloneofsimo authored Dec 29, 2022
1 parent eacf501 commit 61778c7
Showing 4 changed files with 16 additions and 128 deletions.
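In summary, the reverted commit had added a masked loss weighting and truncated-timestep sampling (t_mutliplier, sic) to loss_step, wrapped loss_step in torch.autocast, and introduced mediapipe-based face-segmentation masks in the dataset; reverting it also restores the center_crop option and the in-pipeline RandomHorizontalFlip.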
1 change: 0 additions & 1 deletion lora_diffusion/__init__.py
@@ -1,3 +1,2 @@
 from .lora import *
 from .dataset import *
-from .utils import *
52 changes: 7 additions & 45 deletions lora_diffusion/cli_lora_pti.py
@@ -142,26 +142,21 @@ def collate_fn(examples):
             "input_ids": input_ids,
             "pixel_values": pixel_values,
         }

-        if examples[0].get("mask", None) is not None:
-            batch["mask"] = torch.stack([example["mask"] for example in examples])
-
         return batch

     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
         batch_size=train_batch_size,
         shuffle=True,
         collate_fn=collate_fn,
         num_workers=2,
     )

     return train_dataloader


-@torch.autocast("cuda")
-def loss_step(
-    batch, unet, vae, text_encoder, scheduler, weight_dtype, t_mutliplier=1.0
-):
+def loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype):
     latents = vae.encode(
         batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
     ).latent_dist.sample()
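For context, the reverted @torch.autocast("cuda") decorator ran the whole loss computation under CUDA mixed precision. A minimal standalone sketch of the pattern (toy function, not from the repo):

import torch

@torch.autocast("cuda")  # the reverted decorator form; eligible ops run in float16
def toy_loss(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # matmul is autocast-eligible; .float() upcasts before the reduction
    return (a @ b).float().pow(2).mean()

if torch.cuda.is_available():
    x = torch.randn(8, 8, device="cuda")
    print(toy_loss(x, x).dtype)  # torch.float32 after the explicit upcast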
@@ -172,7 +167,7 @@ def loss_step(

     timesteps = torch.randint(
         0,
-        int(scheduler.config.num_train_timesteps * t_mutliplier),
+        scheduler.config.num_train_timesteps,
         (bsz,),
         device=latents.device,
     )
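The reverted t_mutliplier trick sampled timesteps only from the lowest fraction of the noise schedule, biasing training toward low-noise steps. A minimal sketch, assuming the usual 1000-step schedule and the t_mutliplier=0.8 that perform_tuning passed:

import torch

num_train_timesteps = 1000  # assumed scheduler.config.num_train_timesteps
t_mutliplier = 0.8          # the reverted default passed from perform_tuning
bsz = 4                     # illustrative batch size

# reverted behavior: the upper bound shrinks from 1000 to 800
timesteps = torch.randint(0, int(num_train_timesteps * t_mutliplier), (bsz,))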
@@ -191,31 +186,6 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")

-    if batch.get("mask", None) is not None:
-
-        mask = (
-            batch["mask"]
-            .to(model_pred.device)
-            .reshape(
-                model_pred.shape[0], 1, model_pred.shape[2] * 8, model_pred.shape[3] * 8
-            )
-        )
-        # resize to match model_pred
-        mask = (
-            F.interpolate(
-                mask.float(),
-                size=model_pred.shape[-2:],
-                mode="nearest",
-            )
-            + 0.1
-        )
-
-        mask = mask / mask.mean()
-
-        model_pred = model_pred * mask
-
-        target = target * mask
-
     loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
     return loss
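For reference, the removed weighting interpolated the pixel-space mask down to the latent resolution, added 0.1 so unmasked regions still contribute, renormalized by the mean so the overall loss scale is preserved, and scaled both prediction and target before the MSE. A self-contained sketch with assumed shapes:

import torch
import torch.nn.functional as F

model_pred = torch.randn(2, 4, 64, 64)   # latent-space prediction (B, C, H, W)
target = torch.randn(2, 4, 64, 64)
pixel_mask = torch.rand(2, 1, 512, 512)  # pixel-space mask, 8x the latent size

mask = F.interpolate(pixel_mask.float(), size=model_pred.shape[-2:], mode="nearest") + 0.1
mask = mask / mask.mean()                # keep the expected loss magnitude unchanged
loss = F.mse_loss((model_pred * mask).float(), (target * mask).float(), reduction="mean")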

@@ -303,15 +273,7 @@ def perform_tuning(
         for batch in dataloader:
             optimizer.zero_grad()

-            loss = loss_step(
-                batch,
-                unet,
-                vae,
-                text_encoder,
-                scheduler,
-                weight_dtype,
-                t_mutliplier=0.8,
-            )
+            loss = loss_step(batch, unet, vae, text_encoder, scheduler, weight_dtype)
             loss.backward()
             torch.nn.utils.clip_grad_norm_(
                 itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0
@@ -360,7 +322,7 @@ def train(
     class_data_dir: Optional[str] = None,
     stochastic_attribute: Optional[str] = None,
     perform_inversion: bool = True,
-    use_template: Literal[None, "object", "style"] = None,
+    use_template: Optional[str] = Literal[None, "object", "style"],
     placeholder_tokens: str = "<s>",
     placeholder_token_at_data: Optional[str] = None,
     initializer_tokens: str = "dog",
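Note: the restored line assigns Literal[None, "object", "style"] as the default value of use_template rather than as its annotation; the reverted spelling (use_template: Literal[None, "object", "style"] = None) is the type-correct form, so the revert likely reintroduces a latent annotation bug here.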
@@ -370,6 +332,7 @@ def train(
     num_class_images: int = 100,
     seed: int = 42,
     resolution: int = 512,
+    center_crop: bool = False,
     color_jitter: bool = True,
     train_batch_size: int = 1,
     sample_batch_size: int = 1,
@@ -387,7 +350,6 @@ def train(
     learning_rate_ti: float = 5e-4,
     continue_inversion: bool = True,
     continue_inversion_lr: Optional[float] = None,
-    use_face_segmentation_condition: bool = False,
     scale_lr: bool = False,
     lr_scheduler: str = "constant",
     lr_warmup_steps: int = 100,
@@ -451,8 +413,8 @@ def train(
         class_prompt=class_prompt,
         tokenizer=tokenizer,
         size=resolution,
+        center_crop=center_crop,
         color_jitter=color_jitter,
-        use_face_segmentation_condition=use_face_segmentation_condition,
     )

     train_dataloader = text2img_dataloader(
79 changes: 9 additions & 70 deletions lora_diffusion/dataset.py
@@ -1,12 +1,11 @@
 from torch.utils.data import Dataset

 from typing import List, Tuple, Dict, Union, Optional
-from PIL import Image, ImageFilter
+from PIL import Image
 from torchvision import transforms
 from pathlib import Path
-import cv2

 import random
 import numpy as np

 OBJECT_TEMPLATE = [
     "a photo of a {}",
@@ -91,12 +90,12 @@ def __init__(
         class_prompt=None,
         size=512,
         h_flip=True,
+        center_crop=False,
         color_jitter=False,
         resize=True,
-        use_face_segmentation_condition=False,
-        blur_amount: int = 70,
     ):
         self.size = size
+        self.center_crop = center_crop
         self.tokenizer = tokenizer
         self.resize = resize

@@ -122,32 +121,25 @@ def __init__(
             self.class_prompt = class_prompt
         else:
             self.class_data_root = None
-        self.h_flip = h_flip

         self.image_transforms = transforms.Compose(
             [
                 transforms.Resize(
                     size, interpolation=transforms.InterpolationMode.BILINEAR
                 )
                 if resize
                 else transforms.Lambda(lambda x: x),
-                transforms.ColorJitter(0.1, 0.1)
+                transforms.ColorJitter(0.2, 0.1)
                 if color_jitter
                 else transforms.Lambda(lambda x: x),
+                transforms.RandomHorizontalFlip()
+                if h_flip
+                else transforms.Lambda(lambda x: x),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
             ]
         )

-        self.use_face_segmentation_condition = use_face_segmentation_condition
-        if self.use_face_segmentation_condition:
-            import mediapipe as mp
-
-            mp_face_detection = mp.solutions.face_detection
-            self.face_detection = mp_face_detection.FaceDetection(
-                model_selection=1, min_detection_confidence=0.5
-            )
-        self.blur_amount = blur_amount

     def __len__(self):
         return self._length
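The restored pipeline toggles each optional augmentation by swapping in an identity Lambda. A minimal usage sketch with assumed flag values:

from PIL import Image
from torchvision import transforms

size, resize, color_jitter, h_flip = 512, True, True, True  # assumed settings

image_transforms = transforms.Compose(
    [
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
        if resize
        else transforms.Lambda(lambda x: x),
        transforms.ColorJitter(0.2, 0.1)  # brightness=0.2, contrast=0.1
        if color_jitter
        else transforms.Lambda(lambda x: x),
        transforms.RandomHorizontalFlip()
        if h_flip
        else transforms.Lambda(lambda x: x),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # map [0, 1] to [-1, 1]
    ]
)

tensor = image_transforms(Image.new("RGB", (768, 768)))  # -> (3, 512, 512) in [-1, 1]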

@@ -171,59 +163,6 @@ def __getitem__(self, index):
             for token, value in self.token_map.items():
                 text = text.replace(token, value)

-        if self.use_face_segmentation_condition:
-            image = cv2.imread(
-                str(self.instance_images_path[index % self.num_instance_images])
-            )
-            results = self.face_detection.process(
-                cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-            )
-            black_image = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
-
-            if results.detections:
-
-                for detection in results.detections:
-
-                    x_min = int(
-                        detection.location_data.relative_bounding_box.xmin
-                        * image.shape[1]
-                    )
-                    y_min = int(
-                        detection.location_data.relative_bounding_box.ymin
-                        * image.shape[0]
-                    )
-                    width = int(
-                        detection.location_data.relative_bounding_box.width
-                        * image.shape[1]
-                    )
-                    height = int(
-                        detection.location_data.relative_bounding_box.height
-                        * image.shape[0]
-                    )
-
-                    # draw the colored rectangle
-                    black_image[y_min : y_min + height, x_min : x_min + width] = 255
-
-            # blur the image
-            black_image = Image.fromarray(black_image, mode="L").filter(
-                ImageFilter.GaussianBlur(radius=self.blur_amount)
-            )
-            # to tensor
-            black_image = transforms.ToTensor()(black_image)
-            # resize as the instance image
-            black_image = transforms.Resize(
-                self.size, interpolation=transforms.InterpolationMode.BILINEAR
-            )(black_image)
-
-            example["mask"] = black_image
-
-        if self.h_flip and random.random() > 0.5:
-            hflip = transforms.RandomHorizontalFlip(p=1)
-
-            example["instance_images"] = hflip(example["instance_images"])
-            if self.use_face_segmentation_condition:
-                example["mask"] = hflip(example["mask"])
-
         example["instance_prompt_ids"] = self.tokenizer(
             text,
             padding="do_not_pad",
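For context, a self-contained sketch of the reverted face-mask generation (mediapipe and opencv-python assumed installed; the image path is illustrative): detected face boxes are painted white on a black canvas, then Gaussian-blurred so the loss weighting decays smoothly, matching the removed code's blur_amount default of 70.

import cv2
import mediapipe as mp
import numpy as np
from PIL import Image, ImageFilter

face_detection = mp.solutions.face_detection.FaceDetection(
    model_selection=1, min_detection_confidence=0.5
)
image = cv2.imread("instance_image.jpg")  # hypothetical path
results = face_detection.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

mask = np.zeros(image.shape[:2], dtype=np.uint8)  # black canvas, H x W
if results.detections:
    for detection in results.detections:
        box = detection.location_data.relative_bounding_box
        x, y = int(box.xmin * image.shape[1]), int(box.ymin * image.shape[0])
        w, h = int(box.width * image.shape[1]), int(box.height * image.shape[0])
        mask[y : y + h, x : x + w] = 255  # white rectangle over each detected face

# soften edges so the mask weighting falls off gradually (blur_amount default: 70)
soft_mask = Image.fromarray(mask, mode="L").filter(ImageFilter.GaussianBlur(radius=70))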
12 changes: 0 additions & 12 deletions lora_diffusion/utils.py

This file was deleted.
