diff --git a/README.md b/README.md
index 1615155..99647ad 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
 
 An implementation of [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) (Karras et al., 2022) for PyTorch.
 
-**This repo is a work in progress** (models may break on later versions, script options may change). Also Config F is not currently implemented, this repo currently implements Config E.
+**This repo is a work in progress** (models may break on later versions, script options may change).
 
 Multi-GPU and multi-node training is supported with [Hugging Face Accelerate](https://huggingface.co/docs/accelerate/index). You can configure Accelerate by running:
diff --git a/k_diffusion/__init__.py b/k_diffusion/__init__.py
index ecb976c..af29e52 100644
--- a/k_diffusion/__init__.py
+++ b/k_diffusion/__init__.py
@@ -1,2 +1,2 @@
-from . import evaluation, gns, layers, models, sampling, utils
+from . import augmentation, evaluation, gns, layers, models, sampling, utils
 from .layers import Denoiser
diff --git a/k_diffusion/augmentation.py b/k_diffusion/augmentation.py
new file mode 100644
index 0000000..a6b7d29
--- /dev/null
+++ b/k_diffusion/augmentation.py
@@ -0,0 +1,95 @@
+from functools import reduce
+import math
+import operator
+
+import numpy as np
+from skimage import transform
+import torch
+from torch import nn
+
+
+def translate2d(tx, ty):
+    mat = [[1, 0, tx],
+           [0, 1, ty],
+           [0, 0, 1]]
+    return torch.tensor(mat, dtype=torch.float32)
+
+
+def scale2d(sx, sy):
+    mat = [[sx, 0, 0],
+           [ 0, sy, 0],
+           [ 0,  0, 1]]
+    return torch.tensor(mat, dtype=torch.float32)
+
+
+def rotate2d(theta):
+    mat = [[torch.cos(theta), torch.sin(-theta), 0],
+           [torch.sin(theta), torch.cos(theta), 0],
+           [0, 0, 1]]
+    return torch.tensor(mat, dtype=torch.float32)
+
+
+class KarrasAugmentationPipeline:
+    def __init__(self, a_prob=0.12, a_scale=2**0.2, a_aniso=2**0.2, a_trans=1/8):
+        self.a_prob = a_prob
+        self.a_scale = a_scale
+        self.a_aniso = a_aniso
+        self.a_trans = a_trans
+
+    def __call__(self, image):
+        w, h = image.size
+        mats = []
+
+        # x-flip
+        a0 = torch.randint(2, []).float()
+        mats.append(scale2d(1 - 2 * a0, 1))
+        # y-flip
+        do = (torch.rand([]) < self.a_prob).float()
+        a1 = torch.randint(2, []).float() * do
+        mats.append(scale2d(1, 1 - 2 * a1))
+        # scaling
+        do = (torch.rand([]) < self.a_prob).float()
+        a2 = torch.randn([]) * do
+        mats.append(scale2d(self.a_scale ** a2, self.a_scale ** a2))
+        # rotation
+        do = (torch.rand([]) < self.a_prob).float()
+        a3 = (torch.rand([]) * 2 * math.pi - math.pi) * do
+        mats.append(rotate2d(-a3))
+        # anisotropy
+        do = (torch.rand([]) < self.a_prob).float()
+        a4 = (torch.rand([]) * 2 * math.pi - math.pi) * do
+        a5 = torch.randn([]) * do
+        mats.append(rotate2d(a4))
+        mats.append(scale2d(self.a_aniso ** a5, self.a_aniso ** -a5))
+        mats.append(rotate2d(-a4))
+        # translation
+        do = (torch.rand([]) < self.a_prob).float()
+        a6 = torch.randn([]) * do
+        a7 = torch.randn([]) * do
+        mats.append(translate2d(self.a_trans * w * a6, self.a_trans * h * a7))
+
+        # form the transformation matrix and conditioning vector
+        mat = reduce(operator.matmul, mats)
+        cond = torch.stack([a0, a1, a2, a3.cos() - 1, a3.sin(), a5 * a4.cos(), a5 * a4.sin(), a6, a7])
+
+        # apply the transformation
+        image = np.array(image, dtype=np.float32) / 255
+        tf = transform.AffineTransform(mat.numpy())
+        image = transform.warp(image, tf, order=3, mode='reflect', cval=0.5, clip=False, preserve_range=True)
+        image = torch.as_tensor(image).movedim(2, 0) * 2 - 1
+        return image, cond
+
+
+class KarrasAugmentWrapper(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.inner_model = model
+
+    def forward(self, input, sigma, aug_cond=None, mapping_cond=None, **kwargs):
+        if aug_cond is None:
+            aug_cond = input.new_zeros([input.shape[0], 9])
+        if mapping_cond is None:
+            mapping_cond = aug_cond
+        else:
+            mapping_cond = torch.cat([aug_cond, mapping_cond], dim=1)
+        return self.inner_model(input, sigma, mapping_cond=mapping_cond, **kwargs)
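The new pipeline is a torchvision-style transform that returns both the augmented image and the 9-element conditioning vector that KarrasAugmentWrapper feeds to the model, so the two classes above are meant to be used together. A minimal usage sketch, assuming a hypothetical RGB image file and a 32x32 target size (neither is part of the diff):

    from PIL import Image
    from torchvision import transforms

    import k_diffusion as K

    tf = transforms.Compose([
        transforms.Resize(32, interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.CenterCrop(32),
        K.augmentation.KarrasAugmentationPipeline(a_prob=0.12),
    ])
    # 'example.png' is a placeholder path; warp operates on an H x W x 3
    # array, hence the convert('RGB')
    image, aug_cond = tf(Image.open('example.png').convert('RGB'))
    # image: float tensor of shape [3, 32, 32], roughly in [-1, 1]
    # aug_cond: shape [9], encodes which augmentations were applied and how

Because the pipeline returns a tuple, a datasets.ImageFolder built on it yields ((image, aug_cond), label) items, which is why the training loop later in this diff unpacks batch[0] into two values.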
diff --git a/k_diffusion/models.py b/k_diffusion/models.py
index dcad14d..9b6bf7c 100644
--- a/k_diffusion/models.py
+++ b/k_diffusion/models.py
@@ -68,11 +68,13 @@ def __init__(self, feats_in, feats_out, n_layers=2):
 
 
 class ImageDenoiserModel(nn.Module):
-    def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, dropout_rate=0.):
+    def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, mapping_cond_dim=0, unet_cond_dim=0, dropout_rate=0.):
         super().__init__()
-        self.timestep_embed = layers.FourierFeatures(1, 256)
-        self.mapping = MappingNet(256, feats_in)
-        self.proj_in = nn.Conv2d(c_in, channels[0], 1)
+        self.timestep_embed = layers.FourierFeatures(1, feats_in)
+        if mapping_cond_dim > 0:
+            self.mapping_cond = nn.Linear(mapping_cond_dim, feats_in, bias=False)
+        self.mapping = MappingNet(feats_in, feats_in)
+        self.proj_in = nn.Conv2d(c_in + unet_cond_dim, channels[0], 1)
         self.proj_out = nn.Conv2d(channels[0], c_in, 1)
         nn.init.zeros_(self.proj_out.weight)
         nn.init.zeros_(self.proj_out.bias)
@@ -86,11 +88,14 @@ def __init__(self, c_in, feats_in, depths, channels, self_attn_depths, dropout_r
             u_blocks.append(UBlock(depths[i], feats_in, my_c_in, channels[i], my_c_out, upsample=i > 0, self_attn=self_attn_depths[i], dropout_rate=dropout_rate))
         self.u_net = layers.UNet(d_blocks, reversed(u_blocks))
 
-    def forward(self, input, sigma):
+    def forward(self, input, sigma, mapping_cond=None, unet_cond=None):
         c_noise = sigma.log() / 4
         timestep_embed = self.timestep_embed(utils.append_dims(c_noise, 2))
-        mapping_out = self.mapping(timestep_embed)
+        mapping_cond_embed = torch.zeros_like(timestep_embed) if mapping_cond is None else self.mapping_cond(mapping_cond)
+        mapping_out = self.mapping(timestep_embed + mapping_cond_embed)
         cond = {'cond': mapping_out}
+        if unet_cond is not None:
+            input = torch.cat([input, unet_cond], dim=1)
         input = self.proj_in(input)
         input = self.u_net(input, cond)
         input = self.proj_out(input)
diff --git a/model_configs/model_config_32x32_small.json b/model_configs/model_config_32x32_small.json
index 34f2cc8..82bdb8f 100644
--- a/model_configs/model_config_32x32_small.json
+++ b/model_configs/model_config_32x32_small.json
@@ -6,6 +6,8 @@
     "depths": [2, 4, 4],
     "channels": [128, 256, 512],
     "self_attn_depths": [false, true, true],
+    "dropout_rate": 0.0,
+    "augment_prob": 0.12,
     "sigma_data": 0.5,
     "sigma_min": 1e-2,
     "sigma_max": 80,
diff --git a/requirements.txt b/requirements.txt
index a2c5d7d..7203400 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -2,6 +2,7 @@ accelerate
 einops
 Pillow
 resize-right
+scikit-image
 scipy
 torch
 torchvision
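The two new ImageDenoiserModel arguments above mirror the two ways conditioning can enter the denoiser: mapping_cond is projected and added to the timestep embedding before the mapping network, while unet_cond is concatenated onto the input channels. A sketch of a direct forward call with the new arguments; the sizes follow model_config_32x32_small.json where the diff shows them, and feats_in=256 is an assumption:

    import torch

    import k_diffusion as K

    model = K.models.ImageDenoiserModel(
        3, 256,                          # c_in, feats_in (256 is assumed)
        depths=[2, 4, 4],
        channels=[128, 256, 512],
        self_attn_depths=[False, True, True],
        mapping_cond_dim=9,              # matches the 9-element augmentation vector
    )
    x = torch.randn(4, 3, 32, 32)
    sigma = torch.full([4], 1.0)
    aug_cond = torch.zeros(4, 9)         # "no augmentation" conditioning
    out = model(x, sigma, mapping_cond=aug_cond)  # same shape as x

When mapping_cond_dim is 0 (the default), self.mapping_cond is never created and passing mapping_cond=None reproduces the old behavior, so existing unconditional configs keep working.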
diff --git a/train.py b/train.py
index 3aac04d..964f4c8 100755
--- a/train.py
+++ b/train.py
@@ -43,6 +43,9 @@ def main():
                    help='the checkpoint to resume from')
     p.add_argument('--save-every', type=int, default=10000,
                    help='save every this many steps')
+    p.add_argument('--start-method', type=str, default='spawn',
+                   choices=['fork', 'forkserver', 'spawn'],
+                   help='the multiprocessing start method')
     p.add_argument('--train-set', type=str, required=True,
                    help='the training set location')
     p.add_argument('--wandb-entity', type=str,
@@ -53,9 +56,10 @@ def main():
                    help='the wandb project name (specify this to enable wandb)')
     p.add_argument('--wandb-save-model', action='store_true',
                    help='save model to wandb')
     args = p.parse_args()
 
+    mp.set_start_method(args.start_method)
+
     model_config = json.load(open(args.model_config))
     # TODO: allow non-square input sizes
     assert len(model_config['input_size']) == 2 and model_config['input_size'][0] == model_config['input_size'][1]
@@ -71,8 +75,11 @@ def main():
         model_config['mapping_out'],
         model_config['depths'],
         model_config['channels'],
-        model_config['self_attn_depths']
+        model_config['self_attn_depths'],
+        dropout_rate=model_config['dropout_rate'],
+        mapping_cond_dim=9,
     )
+    inner_model = K.augmentation.KarrasAugmentWrapper(inner_model)
     accelerator.print('Parameters:', K.utils.n_params(inner_model))
 
     # If logging to wandb, initialize the run
@@ -88,17 +95,26 @@ def main():
     sched = K.utils.InverseLR(opt, inv_gamma=50000, power=1/2, warmup=0.99)
     ema_sched = K.utils.EMAWarmup(power=2/3, max_value=0.9999)
 
-    tf = transforms.Compose([
+    tf_no_aug = transforms.Compose([
         transforms.Resize(size[0], interpolation=transforms.InterpolationMode.LANCZOS),
         transforms.CenterCrop(size[0]),
         transforms.ToTensor(),
         transforms.Normalize([0.5], [0.5])
     ])
+    train_set_no_aug = datasets.ImageFolder(args.train_set, transform=tf_no_aug)
+    train_dl_no_aug = data.DataLoader(train_set_no_aug, args.batch_size, shuffle=True,
+                                      num_workers=args.num_workers)
+
+    tf = transforms.Compose([
+        transforms.Resize(size[0], interpolation=transforms.InterpolationMode.LANCZOS),
+        transforms.CenterCrop(size[0]),
+        K.augmentation.KarrasAugmentationPipeline(model_config['augment_prob']),
+    ])
     train_set = datasets.ImageFolder(args.train_set, transform=tf)
     train_dl = data.DataLoader(train_set, args.batch_size, shuffle=True, drop_last=True,
                                num_workers=args.num_workers, persistent_workers=True)
-    inner_model, opt, train_dl = accelerator.prepare(inner_model, opt, train_dl)
+    inner_model, opt, train_dl, train_dl_no_aug = accelerator.prepare(inner_model, opt, train_dl, train_dl_no_aug)
     if use_wandb:
         wandb.watch(inner_model)
     if args.gns:
@@ -126,9 +142,9 @@ def main():
     step = 0
 
     extractor = K.evaluation.InceptionV3FeatureExtractor(device=device)
-    train_iter = iter(train_dl)
+    train_iter_no_aug = iter(train_dl_no_aug)
     accelerator.print('Computing features for reals...')
-    reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter)[0], extractor, args.evaluate_n, args.batch_size)
+    reals_features = K.evaluation.compute_features(accelerator, lambda x: next(train_iter_no_aug)[0], extractor, args.evaluate_n, args.batch_size)
     if accelerator.is_main_process:
         metrics_log_filepath = Path(f'{args.name}_metrics.csv')
         if metrics_log_filepath.exists():
@@ -202,10 +218,10 @@ def save():
     while True:
         for batch in tqdm(train_dl, disable=not accelerator.is_local_main_process):
             opt.zero_grad()
-            reals = batch[0].to(device)
+            reals, aug_cond = batch[0]
             noise = torch.randn_like(reals)
             sigma = torch.distributions.LogNormal(sigma_mean, sigma_std).sample([reals.shape[0]]).to(device)
-            loss = model.loss(reals, noise, sigma).mean()
+            loss = model.loss(reals, noise, sigma, aug_cond=aug_cond).mean()
             accelerator.backward(loss)
             if args.gns:
                 sq_norm_small_batch, sq_norm_large_batch = accelerator.reduce(gns_stats_hook.get_stats(), 'mean').tolist()
@@ -250,5 +266,4 @@ def save():
 
 
 if __name__ == '__main__':
-    mp.set_start_method('spawn')
     main()
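Taken together, the training change reduces to threading each sample's augmentation descriptor through the wrapper into the mapping network. A condensed sketch of the resulting data flow, using a synthetic batch in place of real data and -1.2/1.2 in place of the script's sigma_mean/sigma_std (both assumptions, as is feats_in=256):

    import torch

    import k_diffusion as K

    inner_model = K.models.ImageDenoiserModel(
        3, 256, [2, 4, 4], [128, 256, 512], [False, True, True],
        mapping_cond_dim=9,
    )
    inner_model = K.augmentation.KarrasAugmentWrapper(inner_model)
    model = K.Denoiser(inner_model, sigma_data=0.5)  # sigma_data from the model config

    reals = torch.randn(4, 3, 32, 32)    # stands in for a batch of augmented images
    aug_cond = torch.zeros(4, 9)         # per-sample augmentation descriptors
    noise = torch.randn_like(reals)
    sigma = torch.distributions.LogNormal(-1.2, 1.2).sample([4])
    loss = model.loss(reals, noise, sigma, aug_cond=aug_cond).mean()
    loss.backward()

The wrapper converts aug_cond into mapping_cond before calling the inner model, so samplers and other callers that never pass aug_cond get the all-zeros "no augmentation" embedding; conditioning the model on the applied transforms and then disabling them at inference is what makes the augmentation non-leaky in the sense of the Karras et al. paper.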