diff --git a/.flake8 b/.flake8
index bd09e59..8103c45 100644
--- a/.flake8
+++ b/.flake8
@@ -4,4 +4,4 @@ plugins = flake8-import-order
 application_import_names = arcface_converter
 import-order-style = pycharm
 per-file-ignores = preparing.py:E402
-exclude = LivePortrait
+exclude = face_swapper
diff --git a/face_swapper/config.ini b/face_swapper/config.ini
index 0651a02..b94f925 100644
--- a/face_swapper/config.ini
+++ b/face_swapper/config.ini
@@ -1,55 +1,65 @@
 [preparing.dataset]
-dataset_path =
+dataset_path = /assets/VGGface2_None_norm_512_true_bygfpgan

 [preparing.dataloader]
 same_person_probability = 0.2

 [preparing.augmentation]
-expression_augmentation = false
+expression = false

 [training.loader]
-batch_size = 6
+batch_size = 4
 num_workers = 8

 [training.generator]
 num_blocks = 2
 id_channels = 512
+learning_rate = 0.0004

 [training.discriminator]
 input_channels = 3
 num_filters = 64
 num_layers = 5
 num_discriminators = 3
+learning_rate = 0.0004
+disable = false

 [auxiliary_models.paths]
-arcface_path =
-landmarker_path =
-motion_extractor_path = /home/hari/Documents/Github/Phantom/assets/pretrained_models/liveportrait_motion_extractor.pth
-feature_extractor_path =
-warping_netwrk_path =
-spade_generator_path =
+arcface_path = /assets/pretrained_models/arcface_w600k_r50.pt
+landmarker_path = /assets/pretrained_models/landmark_203.pt
+motion_extractor_path = /assets/pretrained_models/liveportrait_motion_extractor.pth
+feature_extractor_path = /assets/pretrained_models/liveportrait_feature_extractor.pth
+warping_network_path = /assets/pretrained_models/liveportrait_warping_model.pth
+spade_generator_path = /assets/pretrained_models/liveportrait_spade_generator.pth

 [training.losses]
 weight_adversarial = 1
 weight_identity = 20
 weight_attribute = 10
 weight_reconstruction = 10
-weight_tsr = 0
-weight_expression = 0
+weight_tsr = 100
+weight_eye_gaze = 5
+weight_eye_open = 5
+weight_lip_open = 5

-[training.optimizers]
-scheduler_step = 5000
-scheduler_gamma = 0.2
-generator_learning_rate = 0.0004
-discriminator_learning_rate = 0.0004
+[training.schedulers]
+step = 5000
+gamma = 0.2

 [training.trainer]
-epochs = 50
+max_epochs = 50
 disable_discriminator = false

 [training.output]
-directory_path =
-file_pattern =
+checkpoint_path = checkpoints/last.ckpt
+directory_path = checkpoints
+file_pattern = 'checkpoint-{epoch}-{step}-{l_G:.4f}-{l_D:.4f}'
+preview_frequency = 250
+validation_frequency = 1000
+
+[training.validation]
+sources = assets/test/front/sources
+targets = assets/test/front/targets

 [exporting]
 directory_path =
diff --git a/face_swapper/src/data_loader.py b/face_swapper/src/data_loader.py
index 6019747..412a970 100644
--- a/face_swapper/src/data_loader.py
+++ b/face_swapper/src/data_loader.py
@@ -9,7 +9,7 @@ from torch.utils.data import TensorDataset

 from .augmentations import apply_random_motion_blur
-from .sub_typing import Batch
+from .typing import Batch

 CONFIG = configparser.ConfigParser()
 CONFIG.read('config.ini')
@@ -27,6 +27,7 @@ def __init__(self, dataset_path : str) -> None:
        self.image_paths = glob.glob('{}/*/*.*g'.format(dataset_path))
        self.folder_paths = glob.glob('{}/*'.format(dataset_path))
        self.image_path_dict = {}
+       self._current_index = 0

        for folder_path in tqdm.tqdm(self.folder_paths):
            image_paths = glob.glob('{}/*'.format(folder_path))
@@ -50,12 +51,12 @@ def __init__(self, dataset_path : str) -> None:
        [
            transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
-           transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
            transforms.RandomHorizontalFlip(p = 0.5),
            transforms.RandomApply([ apply_random_motion_blur ], p = 0.3),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation = 0.2, hue = 0.1),
            transforms.RandomAffine(8, translate = (0.02, 0.02), scale = (0.98, 1.02), shear = (1, 1), fill = 0),
-           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+           transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
        ])

    def __getitem__(self, item : int) -> Batch:
@@ -80,3 +81,11 @@ def __getitem__(self, item : int) -> Batch:

    def __len__(self) -> int:
        return self.dataset_total
+
+
+   def state_dict(self):
+       return {'current_index': self._current_index}
+
+
+   def load_state_dict(self, state_dict):
+       self._current_index = state_dict['current_index']
diff --git a/face_swapper/src/discriminator.py b/face_swapper/src/discriminator.py
index b1b3938..b38ca9e 100644
--- a/face_swapper/src/discriminator.py
+++ b/face_swapper/src/discriminator.py
@@ -3,7 +3,7 @@
 import numpy
 import torch.nn as nn

-from .sub_typing import Tensor
+from .typing import Tensor, DiscriminatorOutputs


 class NLayerDiscriminator(nn.Module):
@@ -49,7 +49,6 @@ def forward(self, input_tensor : Tensor) -> Tensor:
        return self.model(input_tensor)


-# input_channels=3, num_filters=64, num_layers=5, num_discriminators=3
 class MultiscaleDiscriminator(nn.Module):
    def __init__(self, input_channels : int, num_filters : int, num_layers : int, num_discriminators : int):
        super(MultiscaleDiscriminator, self).__init__()
@@ -61,18 +60,12 @@ def __init__(self, input_channels : int, num_filters : int, num_layers : int, num_discriminators : int):
            setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model)
        self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type]

-   def single_discriminator_forward(self, model_layers : nn.Sequential, input_tensor : Tensor) -> List[Tensor]:
-       if self.return_intermediate_features:
-           feature_maps = [ input_tensor ]
+   def single_discriminator_forward(self, model_layers : nn.Sequential, input_tensor : Tensor) -> List[Tensor]:
+       return [ model_layers(input_tensor) ]

-           for layer in model_layers:
-               feature_maps.append(layer(feature_maps[-1]))
-           return feature_maps[1:]
-       else:
-           return [ model_layers(input_tensor) ]
-
-   def forward(self, input_tensor : Tensor) -> List[Tensor]:
+   def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs:
        discriminator_outputs = []
        downsampled_input = input_tensor
diff --git a/face_swapper/src/generator.py b/face_swapper/src/generator.py
index 8b8bff2..3dd18c9 100644
--- a/face_swapper/src/generator.py
+++ b/face_swapper/src/generator.py
@@ -1,55 +1,52 @@
-from typing import Tuple
-
 import torch
 import torch.nn as nn
-
-from .sub_typing import Tensor, UNetAttributes
+import torch.nn.functional as F


 class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
-   def __init__(self, id_channels : int, num_blocks : int) -> None:
+   def __init__(self, id_channels=512, num_blocks=2):
        super(AdaptiveEmbeddingIntegrationNetwork, self).__init__()
        self.encoder = UNet()
-       self.generator = AdaptiveAttentionalDenorm_Generator(id_channels, num_blocks)
+       self.generator = AADGenerator(id_channels, num_blocks)

-   def forward(self, target : Tensor, source_embedding : Tensor) -> Tuple[Tensor, UNetAttributes]:
+   def forward(self, target, source_embedding):
        target_attributes = self.get_attributes(target)
        swap = self.generator(target_attributes, source_embedding)
        return swap, target_attributes

-   def get_attributes(self, target : Tensor) -> UNetAttributes:
+   def get_attributes(self, target):
        return self.encoder(target)


-class AdaptiveAttentionalDenorm_Generator(nn.Module):
-   def __init__(self, id_channels : int, num_blocks : int) -> None:
-       super(AdaptiveAttentionalDenorm_Generator, self).__init__()
+class AADGenerator(nn.Module):
+   def __init__(self, id_channels=512, num_blocks=2):
+       super(AADGenerator, self).__init__()
        self.upsample = Upsample(id_channels, 1024 * 4)
-       self.block_1 = AdaptiveAttentionalDenorm_ResBlock(1024, 1024, 1024, id_channels, num_blocks)
-       self.block_2 = AdaptiveAttentionalDenorm_ResBlock(1024, 1024, 2048, id_channels, num_blocks)
-       self.block_3 = AdaptiveAttentionalDenorm_ResBlock(1024, 1024, 1024, id_channels, num_blocks)
-       self.block_4 = AdaptiveAttentionalDenorm_ResBlock(1024, 512, 512, id_channels, num_blocks)
-       self.block_5 = AdaptiveAttentionalDenorm_ResBlock(512, 256, 256, id_channels, num_blocks)
-       self.block_6 = AdaptiveAttentionalDenorm_ResBlock(256, 128, 128, id_channels, num_blocks)
-       self.block_7 = AdaptiveAttentionalDenorm_ResBlock(128, 64, 64, id_channels, num_blocks)
-       self.block_7 = AdaptiveAttentionalDenorm_ResBlock(64, 3, 64, id_channels, num_blocks)
+       self.AADBlk1 = AAD_ResBlk(1024, 1024, 1024, id_channels, num_blocks)
+       self.AADBlk2 = AAD_ResBlk(1024, 1024, 2048, id_channels, num_blocks)
+       self.AADBlk3 = AAD_ResBlk(1024, 1024, 1024, id_channels, num_blocks)
+       self.AADBlk4 = AAD_ResBlk(1024, 512, 512, id_channels, num_blocks)
+       self.AADBlk5 = AAD_ResBlk(512, 256, 256, id_channels, num_blocks)
+       self.AADBlk6 = AAD_ResBlk(256, 128, 128, id_channels, num_blocks)
+       self.AADBlk7 = AAD_ResBlk(128, 64, 64, id_channels, num_blocks)
+       self.AADBlk8 = AAD_ResBlk(64, 3, 64, id_channels, num_blocks)
        self.apply(initialize_weight)

-   def forward(self, target_attributes : UNetAttributes, source_embedding : Tensor) -> Tensor:
+   def forward(self, target_attributes, source_embedding):
        feature_map = self.upsample(source_embedding)
-       feature_map_1 = nn.functional.interpolate(self.block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
-       feature_map_2 = nn.functional.interpolate(self.block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
-       feature_map_3 = nn.functional.interpolate(self.block_3(feature_map_2, target_attributes[2], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
-       feature_map_4 = nn.functional.interpolate(self.block_4(feature_map_3, target_attributes[3], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
-       feature_map_5 = nn.functional.interpolate(self.block_5(feature_map_4, target_attributes[4], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
-       feature_map_6 = nn.functional.interpolate(self.block_6(feature_map_5, target_attributes[5], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
-       feature_map_7 = nn.functional.interpolate(self.block_7(feature_map_6, target_attributes[6], source_embedding), scale_factor = 2, mode ='bilinear', align_corners = False)
-       output = self.block_7(feature_map_7, target_attributes[7], source_embedding)
+       feature_map_1 = F.interpolate(self.AADBlk1(feature_map, target_attributes[0], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
+       feature_map_2 = F.interpolate(self.AADBlk2(feature_map_1, target_attributes[1], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
+       feature_map_3 = F.interpolate(self.AADBlk3(feature_map_2, target_attributes[2], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
+       feature_map_4 = F.interpolate(self.AADBlk4(feature_map_3, target_attributes[3], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
+       feature_map_5 = F.interpolate(self.AADBlk5(feature_map_4, target_attributes[4], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
+       feature_map_6 = F.interpolate(self.AADBlk6(feature_map_5, target_attributes[5], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
+       feature_map_7 = F.interpolate(self.AADBlk7(feature_map_6, target_attributes[6], source_embedding), scale_factor=2, mode='bilinear', align_corners=False)
+       output = self.AADBlk8(feature_map_7, target_attributes[7], source_embedding)
        return torch.tanh(output)


 class UNet(nn.Module):
-   def __init__(self) -> None:
+   def __init__(self):
        super(UNet, self).__init__()
        self.downsampler_1 = Conv4x4(3, 32)
        self.downsampler_2 = Conv4x4(32, 64)
@@ -57,7 +54,9 @@ def __init__(self) -> None:
        self.downsampler_4 = Conv4x4(128, 256)
        self.downsampler_5 = Conv4x4(256, 512)
        self.downsampler_6 = Conv4x4(512, 1024)
+       self.bottleneck = Conv4x4(1024, 1024)
+
        self.upsampler_1 = DeConv4x4(1024, 1024)
        self.upsampler_2 = DeConv4x4(2048, 512)
        self.upsampler_3 = DeConv4x4(1024, 256)
@@ -66,53 +65,64 @@ def __init__(self) -> None:
        self.upsampler_6 = DeConv4x4(128, 32)
        self.apply(initialize_weight)

-   def forward(self, input_tensor : Tensor) -> UNetAttributes:
+   def forward(self, input_tensor):
        downsample_feature_1 = self.downsampler_1(input_tensor)
        downsample_feature_2 = self.downsampler_2(downsample_feature_1)
        downsample_feature_3 = self.downsampler_3(downsample_feature_2)
        downsample_feature_4 = self.downsampler_4(downsample_feature_3)
        downsample_feature_5 = self.downsampler_5(downsample_feature_4)
        downsample_feature_6 = self.downsampler_6(downsample_feature_5)
+       bottleneck_output = self.bottleneck(downsample_feature_6)
+
        upsample_feature_1 = self.upsampler_1(bottleneck_output, downsample_feature_6)
        upsample_feature_2 = self.upsampler_2(upsample_feature_1, downsample_feature_5)
        upsample_feature_3 = self.upsampler_3(upsample_feature_2, downsample_feature_4)
        upsample_feature_4 = self.upsampler_4(upsample_feature_3, downsample_feature_3)
        upsample_feature_5 = self.upsampler_5(upsample_feature_4, downsample_feature_2)
        upsample_feature_6 = self.upsampler_6(upsample_feature_5, downsample_feature_1)
-       output = nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False)
+
+       output = F.interpolate(upsample_feature_6, scale_factor=2, mode='bilinear', align_corners=False)
+
        return bottleneck_output, upsample_feature_1, upsample_feature_2, upsample_feature_3, upsample_feature_4, upsample_feature_5, upsample_feature_6, output


-class AdaptiveAttentionalDenorm_Layer(nn.Module):
-   def __init__(self, input_channels : int, attr_channels : int, id_channels : int) -> None:
-       super(AdaptiveAttentionalDenorm_Layer, self).__init__()
+class AADLayer(nn.Module):
+   def __init__(self, input_channels, attr_channels, id_channels):
+       super(AADLayer, self).__init__()
        self.attr_channels = attr_channels
        self.id_channels = id_channels
        self.input_channels = input_channels
-       self.conv_gamma = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True)
-       self.conv_beta = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True)
+
+       self.conv_gamma = nn.Conv2d(attr_channels, input_channels, kernel_size=1, stride=1, padding=0, bias=True)
+       self.conv_beta = nn.Conv2d(attr_channels, input_channels, kernel_size=1, stride=1, padding=0, bias=True)
        self.fc_gamma = nn.Linear(id_channels, input_channels)
        self.fc_beta = nn.Linear(id_channels, input_channels)
        self.instance_norm = nn.InstanceNorm2d(input_channels, affine=False)
-       self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1, stride = 1, padding = 0, bias = True)
-   def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : Tensor) -> Tensor:
+       self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size=1, stride=1, padding=0, bias=True)
+
+   def forward(self, feature_map, attr_embedding, id_embedding):
        feature_map = self.instance_norm(feature_map)
+
        attr_gamma = self.conv_gamma(attr_embedding)
        attr_beta = self.conv_beta(attr_embedding)
        attr_modulation = attr_gamma * feature_map + attr_beta
-       id_gamma = self.fc_gamma(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
-       id_beta = self.fc_beta(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
+
+       id_gamma = self.fc_gamma(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(
+           feature_map)
+       id_beta = self.fc_beta(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(
+           feature_map)
        id_modulation = id_gamma * feature_map + id_beta
+
        feature_mask = torch.sigmoid(self.conv_mask(feature_map))
        feature_blend = (1 - feature_mask) * attr_modulation + feature_mask * id_modulation
        return feature_blend


 class AddBlocksSequential(nn.Sequential):
-   def forward(self, *inputs : Tuple[Tensor, ...]) -> Tensor:
-       feature_map, attr_embedding, id_embedding = inputs
+   def forward(self, *inputs):
+       h, attr_embedding, id_embedding = inputs

        for index, module in enumerate(self._modules.values()):
            if index % 3 == 0 and index > 0:
@@ -124,9 +134,9 @@ def forward(self, *inputs : Tuple[Tensor, ...]) -> Tensor:
        return inputs


-class AdaptiveAttentionalDenorm_ResBlock(nn.Module):
-   def __init__(self, in_channels : int, out_channels : int, attr_channels : int, id_channels : int, num_blocks : int) -> None:
-       super(AdaptiveAttentionalDenorm_ResBlock, self).__init__()
+class AAD_ResBlk(nn.Module):
+   def __init__(self, in_channels, out_channels, attr_channels, id_channels, num_blocks):
+       super(AAD_ResBlk, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        primary_add_blocks = []
@@ -135,22 +145,22 @@ def __init__(self, in_channels : int, out_channels : int, attr_channels : int, id_channels : int, num_blocks : int) -> None:
            intermediate_channels = in_channels if i < (num_blocks - 1) else out_channels
            primary_add_blocks.extend(
            [
-               AdaptiveAttentionalDenorm_Layer(in_channels, attr_channels, id_channels),
-               nn.ReLU(inplace = True),
-               nn.Conv2d(in_channels, intermediate_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
+               AADLayer(in_channels, attr_channels, id_channels),
+               nn.ReLU(inplace=True),
+               nn.Conv2d(in_channels, intermediate_channels, kernel_size=3, stride=1, padding=1, bias=False)
            ])
        self.primary_add_blocks = AddBlocksSequential(*primary_add_blocks)

        if in_channels != out_channels:
            auxiliary_add_blocks = \
            [
-               AdaptiveAttentionalDenorm_Layer(in_channels, attr_channels, id_channels),
-               nn.ReLU(inplace = True),
-               nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
+               AADLayer(in_channels, attr_channels, id_channels),
+               nn.ReLU(inplace=True),
+               nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
            ]
            self.auxiliary_add_blocks = AddBlocksSequential(*auxiliary_add_blocks)

-   def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : Tensor) -> Tensor:
+   def forward(self, feature_map, attr_embedding, id_embedding):
        primary_feature = self.primary_add_blocks(feature_map, attr_embedding, id_embedding)

        if self.in_channels != self.out_channels:
@@ -160,47 +170,49 @@ def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : Tensor) -> Tensor:


 class Conv4x4(nn.Module):
-   def __init__(self, in_channels : int, out_channels : int) -> None:
+   def __init__(self, in_channels, out_channels):
        super(Conv4x4, self).__init__()
-       self.conv = nn.Conv2d(in_channels=in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
+       self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1,
+                             bias=False)
        self.batch_norm = nn.BatchNorm2d(out_channels)
-       self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
+       self.leaky_relu = nn.LeakyReLU(0.1, inplace=True)

-   def forward(self, input : Tensor) -> Tensor:
-       output = self.conv(input)
-       output = self.batch_norm(output)
-       output = self.leaky_relu(output)
-       return output
+   def forward(self, x):
+       x = self.conv(x)
+       x = self.batch_norm(x)
+       x = self.leaky_relu(x)
+       return x


 class DeConv4x4(nn.Module):
-   def __init__(self, in_channels : int, out_channels : int) -> None:
+   def __init__(self, in_channels, out_channels):
        super(DeConv4x4, self).__init__()
-       self.deconv = nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
+       self.deconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2,
+                                        padding=1, bias=False)
        self.batch_norm = nn.BatchNorm2d(out_channels)
-       self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
+       self.leaky_relu = nn.LeakyReLU(0.1, inplace=True)

-   def forward(self, input : Tensor, skip_connection : Tensor) -> Tensor:
-       output = self.deconv(input)
-       output = self.batch_norm(output)
-       output = self.leaky_relu(output)
-       output = torch.cat((output, skip_connection), dim = 1)
-       return output
+   def forward(self, x, skip):
+       x = self.deconv(x)
+       x = self.batch_norm(x)
+       x = self.leaky_relu(x)
+       return torch.cat((x, skip), dim=1)


 class Upsample(nn.Module):
-   def __init__(self, in_channels : int, out_channels : int):
+   def __init__(self, in_channels, out_channels):
        super(Upsample, self).__init__()
-       self.initial_conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1)
-       self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2)
+       self.initial_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1,
+                                     padding=1)
+       self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2)

-   def forward(self, input : Tensor) -> Tensor:
-       output = self.initial_conv(input.view(input.shape[0], -1, 1, 1))
-       output = self.pixel_shuffle(output)
-       return output
+   def forward(self, x):
+       x = self.initial_conv(x.view(x.shape[0], -1, 1, 1))
+       x = self.pixel_shuffle(x)
+       return x


-def initialize_weight(module : nn.Module) -> None:
+def initialize_weight(module):
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(0, 0.001)
        module.bias.data.zero_()
diff --git a/face_swapper/src/helper.py b/face_swapper/src/helper.py
new file mode 100644
index 0000000..4a6b2bd
--- /dev/null
+++ b/face_swapper/src/helper.py
@@ -0,0 +1,128 @@
+import configparser
+from typing import Tuple
+
+import numpy
+import torch
+import torch.nn.functional as F
+
+from .typing import Tensor
+
+CONFIG = configparser.ConfigParser()
+CONFIG.read('config.ini')
+
+if CONFIG.getboolean('preparing.augmentation', 'expression'):
+   from LivePortrait.src.utils.camera import headpose_pred_to_degree, get_rotation_matrix
+
+L2_loss = torch.nn.MSELoss()
+EXPRESSION_MIN = numpy.array(
+[
+   [
+       [-2.88067125e-02, -8.12731311e-02, -1.70541159e-03],
+       [-4.88598682e-02, -3.32196616e-02, -1.67431499e-04],
+       [-6.75425082e-02, -4.28681746e-02, -1.98950816e-04],
+       [-7.23103955e-02, -3.28503326e-02, -7.31324719e-04],
+       [-3.87073644e-02, -6.01546466e-02, -5.50269964e-04],
+       [-6.38048723e-02, -2.23840728e-01, -7.13261834e-04],
+       [-3.02710701e-02, -3.93195450e-02, -8.24086510e-06],
+       [-2.95799859e-02, -5.39318882e-02, -1.74219604e-04],
+       [-2.92359516e-02, -1.53050944e-02, -6.30460854e-05],
+       [-5.56493877e-03, -2.34344602e-02, -1.26858242e-04],
+       [-4.37593013e-02, -2.77768299e-02, -2.70503685e-02],
+       [-1.76926646e-02, -1.91676542e-02, -1.15090821e-04],
+       [-8.34268332e-03, -3.99775570e-03, -3.27481248e-05],
+       [-3.40162888e-02, -2.81868968e-02, -1.96679524e-04],
+       [-2.91855410e-02, -3.97511162e-02, -2.81230678e-05],
+       [-1.50395725e-02, -2.49494594e-02, -9.42573533e-05],
+       [-1.67938769e-02, -2.00953931e-02, -4.00750607e-04],
+       [-1.86435618e-02, -2.48535164e-02, -2.74416432e-02],
+       [-4.61211195e-03, -1.21660791e-02, -2.93173041e-04],
+       [-4.10017073e-02, -7.43824020e-02, -4.42762971e-02],
+       [-1.90370996e-02, -3.74363363e-02, -1.34740388e-02]
+   ]
+]).astype(numpy.float32)
+EXPRESSION_MAX = numpy.array(
+[
+   [
+       [4.46682945e-02, 7.08772913e-02, 4.08344204e-04],
+       [2.14308221e-02, 6.15894832e-02, 4.85319615e-05],
+       [3.02363783e-02, 4.45043296e-02, 1.28298725e-05],
+       [3.05869691e-02, 3.79812494e-02, 6.57040102e-04],
+       [4.45670523e-02, 3.97259220e-02, 7.10966764e-04],
+       [9.43699256e-02, 9.85926315e-02, 2.02551950e-04],
+       [1.61131397e-02, 2.92906128e-02, 3.44733417e-06],
+       [5.23825921e-02, 1.07065082e-01, 6.61510974e-04],
+       [2.85718683e-03, 8.32320191e-03, 2.39314613e-04],
+       [2.57947259e-02, 1.60935968e-02, 2.41853559e-05],
+       [4.90833223e-02, 3.43903080e-02, 3.22353356e-02],
+       [1.44766076e-02, 3.39248963e-02, 1.42291479e-04],
+       [8.75749043e-04, 6.82212645e-03, 2.76097053e-05],
+       [1.86958015e-02, 3.84016186e-02, 7.33085908e-05],
+       [2.01714113e-02, 4.90544215e-02, 2.34028921e-05],
+       [2.46518422e-02, 3.29151377e-02, 3.48571630e-05],
+       [2.22457591e-02, 1.21796541e-02, 1.56396593e-04],
+       [1.72109623e-02, 3.01626958e-02, 1.36556877e-02],
+       [1.83460284e-02, 1.61141958e-02, 2.87440169e-04],
+       [3.57594155e-02, 1.80554688e-01, 2.75554154e-02],
+       [2.17450950e-02, 8.66811201e-02, 3.34241726e-02]
+   ]
+]).astype(numpy.float32)
+
+
+def randomize_expression(face_tensor, feature_extractor, motion_extractor, warping_network, spade_generator):
+   with torch.no_grad():
+       face_tensor_norm = (face_tensor + 1) * 0.5
+       input_device = face_tensor.device
+       feature_volume = feature_extractor(face_tensor_norm)
+       motion_extractor_dict = motion_extractor(face_tensor_norm)
+
+       translation = motion_extractor_dict.get('t')
+       expression = motion_extractor_dict.get('exp')
+       scale = motion_extractor_dict.get('scale')
+       points = motion_extractor_dict.get('kp')
+
+       pitch = headpose_pred_to_degree(motion_extractor_dict.get('pitch'))[:, None]
+       yaw = headpose_pred_to_degree(motion_extractor_dict.get('yaw'))[:, None]
+       roll = headpose_pred_to_degree(motion_extractor_dict.get('roll'))[:, None]
+       rotation_matrix = get_rotation_matrix(pitch, yaw, roll)
+       random_expression = get_random_expression_blend(expression)
+
+       points_transformed = transform_points(points, rotation_matrix, expression, scale, translation)
+       points_driv = transform_points(points, rotation_matrix, random_expression, scale, translation)
+
+       data = warping_network(feature_volume, points_driv, points_transformed).get('out')
+       output = spade_generator(data)
+       output = output.to(input_device)
+       output = F.interpolate(output.clamp(0, 1), [256, 256], mode='bilinear', align_corners=False)
+       output = (output - 0.5) * 2
+   return output
+
+
+def get_random_expression_blend(expression : Tensor) -> Tensor:
+   blend = 0.35
+   expression = expression.view(-1, 21, 3)
+   min_array = torch.from_numpy(EXPRESSION_MIN).to(expression.device).to(expression.dtype).expand(expression.shape[0], -1, -1)
+   max_array = torch.from_numpy(EXPRESSION_MAX).to(expression.device).to(expression.dtype).expand(expression.shape[0], -1, -1)
+   random_batch = torch.rand_like(min_array).to(expression.device) * (max_array - min_array) + min_array
+   random_batch[:, [0, 1, 8, 6, 9, 4, 5, 10]] = expression[:, [0, 1, 8, 6, 9, 4, 5, 10]]
+   random_batch[:, [3, 7]] = random_batch[:, [13, 16]] * 0.1 + expression[:, [13, 16]] * 0.9
+   random_batch[:, [3, 7]] = random_batch[:, [3, 7]] * 0.5 + expression[:, [3, 7]] * 0.5
+   return random_batch * 0.8 * blend + expression * (1 - blend)
+
+
+def transform_points(points : Tensor, rotation_matrix : Tensor, expression : Tensor, scale : Tensor, translation : Tensor):
+   points_transformed = points.view(-1, 21, 3) @ rotation_matrix + expression.view(-1, 21, 3)
+   points_transformed *= scale[..., None]
+   points_transformed[:, :, 0:2] += translation[:, None, 0:2]
+   return points_transformed
+
+
+def hinge_loss(tensor : Tensor, is_positive : bool) -> Tensor:
+   if is_positive:
+       return torch.relu(1 - tensor)
+   else:
+       return torch.relu(tensor + 1)
+
+
+def calc_distance_ratio(landmarks : Tensor, indices : Tuple[int, int, int, int]) -> Tensor:
+   distance_horizontal = torch.norm(landmarks[:, indices[0]] - landmarks[:, indices[1]], p = 2, dim = 1, keepdim = True)
+   distance_vertical = torch.norm(landmarks[:, indices[2]] - landmarks[:, indices[3]], p = 2, dim = 1, keepdim = True)
+   return distance_horizontal / (distance_vertical + 1e-4)
diff --git a/face_swapper/src/model_loader.py b/face_swapper/src/model_loader.py
deleted file mode 100644
index db83a1a..0000000
--- a/face_swapper/src/model_loader.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import configparser
-
-import torch
-import torch.nn as nn
-
-from .discriminator import MultiscaleDiscriminator
-from .generator import AdaptiveEmbeddingIntegrationNetwork
-
-CONFIG = configparser.ConfigParser()
-CONFIG.read('config.ini')
-
-
-def load_generator() -> nn.Module:
-   id_channels = CONFIG.getint('training.generator', 'id_channels')
-   num_blocks = CONFIG.getint('training.generator', 'num_blocks')
-   generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks)
-   return generator
-
-
-def load_discriminator() -> nn.Module:
-   input_channels = CONFIG.getint('training.discriminator', 'input_channels')
-   num_filters = CONFIG.getint('training.discriminator', 'num_filters')
-   num_layers = CONFIG.getint('training.discriminator', 'num_layers')
-   num_discriminators = CONFIG.getint('training.discriminator', 'num_discriminators')
-   discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators)
-   return discriminator
-
-
-def load_arcface() -> nn.Module:
-   model_path = CONFIG.get('auxiliary_models.paths', 'arcface_path')
-   arcface = torch.load(model_path, map_location = 'cpu', weights_only = False)
-   arcface.eval()
-   return arcface
-
-
-def load_landmarker() -> nn.Module:
-   model_path = CONFIG.get('auxiliary_models.paths', 'landmarker_path')
-   landmarker = torch.load(model_path, map_location = 'cpu', weights_only = False)
-   landmarker.eval()
-   return landmarker
-
-
-def load_motion_extractor() -> nn.Module:
-   from LivePortrait.src.modules.motion_extractor import MotionExtractor
-
-   model_path = CONFIG.get('auxiliary_models.paths', 'motion_extractor_path')
-   motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny')
-   motion_extractor.load_state_dict(torch.load(model_path, map_location = 'cpu', weights_only = True))
-   motion_extractor.eval()
-   return motion_extractor
diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py
index 1482236..182fb58 100644
--- a/face_swapper/src/training.py
+++ b/face_swapper/src/training.py
@@ -1,5 +1,389 @@
-from .model_loader import load_motion_extractor
+import configparser
+import random
+
+import numpy
+
+from typing import Tuple
+import os
+import cv2
+import torchvision
+
+import pytorch_lightning
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import ModelCheckpoint
+from torch.utils.data import DataLoader
+from pytorch_lightning.utilities.types import OptimizerLRScheduler
+import torch
+
+from .discriminator import MultiscaleDiscriminator
+from .generator import AdaptiveEmbeddingIntegrationNetwork
+from .data_loader import DataLoaderVGG, read_image
+
+from .typing import Tensor, LossDict, TargetAttributes, DiscriminatorOutputs, Batch
+from .helper import hinge_loss, calc_distance_ratio, L2_loss, randomize_expression
+from pytorch_msssim import ssim
+
+CONFIG = configparser.ConfigParser()
+CONFIG.read('config.ini')
+
+
+def load_models():
+   id_channels = CONFIG.getint('training.generator', 'id_channels')
+   num_blocks = CONFIG.getint('training.generator', 'num_blocks')
+   generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks)
+
+   input_channels = CONFIG.getint('training.discriminator', 'input_channels')
+   num_filters = CONFIG.getint('training.discriminator', 'num_filters')
+   num_layers = CONFIG.getint('training.discriminator', 'num_layers')
+   num_discriminators = CONFIG.getint('training.discriminator', 'num_discriminators')
+   discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators)
+
+   model_path = CONFIG.get('auxiliary_models.paths', 'arcface_path')
+   arcface = torch.load(model_path, map_location = 'cpu', weights_only = False)
+   arcface.eval()
+
+   if CONFIG.getfloat('training.losses', 'weight_eye_gaze') > 0 or CONFIG.getfloat('training.losses', 'weight_eye_open') > 0 or CONFIG.getfloat('training.losses', 'weight_lip_open') > 0:
+       model_path = CONFIG.get('auxiliary_models.paths', 'landmarker_path')
+       landmarker = torch.load(model_path, map_location = 'cpu', weights_only = False)
+       landmarker.eval()
+   else:
+       landmarker = None
+
+   if CONFIG.getfloat('training.losses', 'weight_tsr') > 0 or CONFIG.getboolean('preparing.augmentation', 'expression'):
+       from LivePortrait.src.modules.motion_extractor import MotionExtractor
+
+       model_path = CONFIG.get('auxiliary_models.paths', 'motion_extractor_path')
+       motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny')
+       motion_extractor.load_state_dict(torch.load(model_path, map_location = 'cpu', weights_only = True))
+       motion_extractor.eval()
+   else:
+       motion_extractor = None
+
+   if CONFIG.getboolean('preparing.augmentation', 'expression'):
+       from LivePortrait.src.modules.appearance_feature_extractor import AppearanceFeatureExtractor
+       from LivePortrait.src.modules.warping_network import WarpingNetwork
+       from LivePortrait.src.modules.spade_generator import SPADEDecoder
+
+       feature_extractor_path = CONFIG.get('auxiliary_models.paths', 'feature_extractor_path')
+       feature_extractor = AppearanceFeatureExtractor(3, 64, 2, 512, 32, 16, 6)
+       feature_extractor.load_state_dict(torch.load(feature_extractor_path, map_location = 'cpu', weights_only = True))
+       feature_extractor.eval()
+
+       warping_network_path = CONFIG.get('auxiliary_models.paths', 'warping_network_path')
+       dense_motion_params = {
+           'block_expansion': 32,
+           'max_features': 1024,
+           'num_blocks': 5,
+           'reshape_depth': 16,
+           'compress': 4
+       }
+       warping_network = WarpingNetwork(num_kp = 21, block_expansion = 64, max_features = 512, num_down_blocks = 2, reshape_channel = 32, estimate_occlusion_map = True, dense_motion_params = dense_motion_params)
+       warping_network.load_state_dict(torch.load(warping_network_path, map_location='cpu', weights_only=True))
+       warping_network.eval()
+
+       spade_generator_path = CONFIG.get('auxiliary_models.paths', 'spade_generator_path')
+       spade_generator = SPADEDecoder(upscale = 2, block_expansion = 64, max_features = 512, num_down_blocks = 2)
+       spade_generator.load_state_dict(torch.load(spade_generator_path, map_location = 'cpu', weights_only = True))
+       spade_generator.eval()
+   else:
+       feature_extractor = None
+       warping_network = None
+       spade_generator = None
+   return generator, discriminator, arcface, landmarker, motion_extractor, feature_extractor, warping_network, spade_generator
+
+
+def create_trainer() -> Trainer:
+   trainer_max_epochs = CONFIG.getint('training.trainer', 'max_epochs')
+   output_directory_path = CONFIG.get('training.output', 'directory_path')
+   output_file_pattern = CONFIG.get('training.output', 'file_pattern')
+   os.makedirs(output_directory_path, exist_ok = True)
+
+   return Trainer(
+       max_epochs = trainer_max_epochs,
+       precision = '16-mixed',
+       callbacks =
+       [
+           ModelCheckpoint(
+               monitor = 'l_G',
+               dirpath = output_directory_path,
+               filename = output_file_pattern,
+               # every_n_epochs = 1,
+               every_n_train_steps = 1000,
+               save_top_k = 5,
+               mode = 'min',
+               save_last = True
+           )
+       ],
+       log_every_n_steps = 10,
+       accumulate_grad_batches = 1,
+   )


 def train():
-   return print(load_motion_extractor())
+   batch_size = CONFIG.getint('training.loader', 'batch_size')
+   num_workers = CONFIG.getint('training.loader', 'num_workers')
+   checkpoint_path = CONFIG.get('training.output', 'checkpoint_path')
+   dataset = DataLoaderVGG(CONFIG.get('preparing.dataset', 'dataset_path'))
+
+   if not (checkpoint_path and os.path.exists(checkpoint_path)):
+       checkpoint_path = None
+   data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = True, num_workers = num_workers, drop_last = True, pin_memory = True, persistent_workers = True)
+   face_swap_model = FaceSwapper(*load_models())
+   trainer = create_trainer()
+   trainer.fit(face_swap_model, data_loader, ckpt_path = checkpoint_path)
+
+
+class FaceSwapper(pytorch_lightning.LightningModule):
+   def __init__(self, generator, discriminator, arcface, landmarker, motion_extractor, feature_extractor, warping_network, spade_generator) -> None:
+       super().__init__()
+       self.generator = generator
+       self.discriminator = discriminator
+       self.arcface = arcface
+       self.landmarker = landmarker
+       self.motion_extractor = motion_extractor
+       self.feature_extractor = feature_extractor
+       self.warping_network = warping_network
+       self.spade_generator = spade_generator
+
+       self.loss_adversarial_accumulated = 20
+       self.automatic_optimization = False
+       self.batch_size = CONFIG.getint('training.loader', 'batch_size')
+
+
+   def forward(self, target_tensor : Tensor, source_embedding : Tensor) -> Tuple[Tensor, TargetAttributes]:
+       output = self.generator(target_tensor, source_embedding)
+       return output
+
+
+   def state_dict(self, *args, **kwargs):
+       return {
+           "generator": self.generator.state_dict(),
+           "discriminator": self.discriminator.state_dict(),
+       }
+
+   def load_state_dict(self, state_dict, strict: bool = True):
+       if "generator" in state_dict:
+           self.generator.load_state_dict(state_dict["generator"], strict = strict)
+       if "discriminator" in state_dict:
+           self.discriminator.load_state_dict(state_dict["discriminator"], strict = strict)
+
+
+   def configure_optimizers(self) -> OptimizerLRScheduler:
+       generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr = CONFIG.getfloat('training.generator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
+       discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr = CONFIG.getfloat('training.discriminator', 'learning_rate'), betas = (0.0, 0.999), weight_decay = 1e-4)
+       generator_scheduler = torch.optim.lr_scheduler.StepLR(generator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
+       discriminator_scheduler = torch.optim.lr_scheduler.StepLR(discriminator_optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))
+       return (
+       {
+           "optimizer": generator_optimizer,
+           "lr_scheduler": generator_scheduler
+       },
+       {
+           "optimizer": discriminator_optimizer,
+           "lr_scheduler": discriminator_scheduler
+       })
+
+
+   def training_step(self, batch : Batch, batch_index : int) -> Tensor:
+       source_tensor, target_tensor, is_same_person = batch
+       generator_optimizer, discriminator_optimizer = self.optimizers()
+       source_embedding = self.get_arcface_embedding(source_tensor, (0, 0, 0, 0))
+
+       if random.random() > 0.5 and CONFIG.getboolean('preparing.augmentation', 'expression'):
+           target_tensor = randomize_expression(target_tensor, self.feature_extractor, self.motion_extractor, self.warping_network, self.spade_generator)
+
+       swap_tensor, target_attributes = self(target_tensor, source_embedding)
+       discriminator_outputs = self.discriminator(swap_tensor)
+
+       generator_losses = self.calc_generator_loss(swap_tensor, target_attributes, discriminator_outputs, batch)
+       generator_optimizer.zero_grad()
+       self.manual_backward(generator_losses.get('loss_generator'))
+       generator_optimizer.step()
+
+       discriminator_losses = self.calc_discriminator_loss(swap_tensor, source_tensor)
+       discriminator_optimizer.zero_grad()
+       self.manual_backward(discriminator_losses.get('loss_discriminator'))
+
+       if not CONFIG.getboolean('training.discriminator', 'disable') or self.loss_adversarial_accumulated < 0.4:
+           discriminator_optimizer.step()
+
+       if self.global_step % CONFIG.getint('training.output', 'preview_frequency') == 0:
+           self.log_generator_preview(source_tensor, target_tensor, swap_tensor)
+
+       if self.global_step % CONFIG.getint('training.output', 'validation_frequency') == 0:
+           self.log_validation_preview()
+       self.log('l_G', generator_losses.get('loss_generator'), prog_bar = True)
+       self.log('l_D', discriminator_losses.get('loss_discriminator'), prog_bar = True)
+       self.log('l_ADV_A', self.loss_adversarial_accumulated, prog_bar = True)
+       self.log('l_ADV', generator_losses.get('loss_adversarial'), prog_bar = False)
+       self.log('l_id', generator_losses.get('loss_identity'), prog_bar = True)
+       self.log('l_attr', generator_losses.get('loss_attribute'), prog_bar = True)
+       self.log('l_rec', generator_losses.get('loss_reconstruction'), prog_bar = True)
+       return generator_losses.get('loss_generator')
+
+
+   def calc_generator_loss(self, swap_tensor : Tensor, target_attributes : TargetAttributes, discriminator_outputs : DiscriminatorOutputs, batch : Batch) -> LossDict:
+       source_tensor, target_tensor, is_same_person = batch
+       generator_losses = {}
+       # adversarial loss
+       loss_adversarial = 0
+
+       for discriminator_output in discriminator_outputs:
+           loss_adversarial += hinge_loss(discriminator_output[0], True).mean(dim = [ 1, 2, 3 ])
+       loss_adversarial = torch.mean(loss_adversarial)
+       generator_losses['loss_adversarial'] = loss_adversarial
+       generator_losses['loss_generator'] = loss_adversarial * CONFIG.getfloat('training.losses', 'weight_adversarial')
+       self.loss_adversarial_accumulated = self.loss_adversarial_accumulated * 0.98 + loss_adversarial.item() * 0.02
+
+       # identity loss
+       swap_embedding = self.get_arcface_embedding(swap_tensor, (30, 0, 10, 10))
+       source_embedding = self.get_arcface_embedding(source_tensor, (30, 0, 10, 10))
+       loss_identity = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean()
+       generator_losses['loss_identity'] = loss_identity
+       generator_losses['loss_generator'] += loss_identity * CONFIG.getfloat('training.losses', 'weight_identity')
+
+       # attribute loss
+       loss_attribute = 0
+       swap_attributes = self.generator.get_attributes(swap_tensor)
+
+       for swap_attribute, target_attribute in zip(swap_attributes, target_attributes):
+           loss_attribute += torch.mean(torch.pow(swap_attribute - target_attribute, 2).reshape(self.batch_size, -1), dim = 1).mean()
+       loss_attribute *= 0.5
+       generator_losses['loss_attribute'] = loss_attribute
+       generator_losses['loss_generator'] += loss_attribute * CONFIG.getfloat('training.losses', 'weight_attribute')
+
+       # reconstruction loss
+       loss_reconstruction = torch.sum(0.5 * torch.mean(torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1), dim = 1) * is_same_person) / (is_same_person.sum() + 1e-4)
+       loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
+       loss_reconstruction = loss_reconstruction * 0.3 + loss_ssim * 0.7
+       generator_losses['loss_reconstruction'] = loss_reconstruction
+       generator_losses['loss_generator'] += loss_reconstruction * CONFIG.getfloat('training.losses', 'weight_reconstruction')
+
+       if CONFIG.getfloat('training.losses', 'weight_tsr') > 0:
+           # tsr loss
+           swap_motion_features = self.get_motion_features(swap_tensor)
+           target_motion_features = self.get_motion_features(target_tensor)
+           loss_tsr = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
+
+           for swap_motion_feature, target_motion_feature in zip(swap_motion_features, target_motion_features):
+               loss_tsr += L2_loss(swap_motion_feature, target_motion_feature)
+           generator_losses['loss_tsr'] = loss_tsr
+           generator_losses['loss_generator'] += loss_tsr * CONFIG.getfloat('training.losses', 'weight_tsr')
+
+       if CONFIG.getfloat('training.losses', 'weight_eye_gaze') > 0 or CONFIG.getfloat('training.losses', 'weight_eye_open') > 0 or CONFIG.getfloat('training.losses', 'weight_lip_open') > 0:
+           swap_landmark_features = self.get_landmark_features(swap_tensor)
+           target_landmark_features = self.get_landmark_features(target_tensor)
+
+           # eye gaze loss
+           loss_left_eye_gaze = L2_loss(swap_landmark_features[3], target_landmark_features[3])
+           loss_right_eye_gaze = L2_loss(swap_landmark_features[4], target_landmark_features[4])
+           loss_eye_gaze = loss_left_eye_gaze + loss_right_eye_gaze
+           generator_losses['loss_eye_gaze'] = loss_eye_gaze
+           generator_losses['loss_generator'] += loss_eye_gaze * CONFIG.getfloat('training.losses', 'weight_eye_gaze')
+
+           # eye open loss
+           loss_left_eye_open = L2_loss(swap_landmark_features[0], target_landmark_features[0])
+           loss_right_eye_open = L2_loss(swap_landmark_features[1], target_landmark_features[1])
+           loss_eye_open = loss_left_eye_open + loss_right_eye_open
+           generator_losses['loss_eye_open'] = loss_eye_open
+           generator_losses['loss_generator'] += loss_eye_open * CONFIG.getfloat('training.losses', 'weight_eye_open')
+
+           # lip open loss
+           loss_lip_open = L2_loss(swap_landmark_features[2], target_landmark_features[2])
+           generator_losses['loss_lip_open'] = loss_lip_open
+           generator_losses['loss_generator'] += loss_lip_open * CONFIG.getfloat('training.losses', 'weight_lip_open')
+       return generator_losses
+
+
+   def calc_discriminator_loss(self, swap_tensor : Tensor, source_tensor : Tensor) -> LossDict:
+       discriminator_losses = {}
+       fake_discriminator_outputs = self.discriminator(swap_tensor.detach())
+       loss_fake = 0
+
+       for fake_discriminator_output in fake_discriminator_outputs:
+           loss_fake += torch.mean(hinge_loss(fake_discriminator_output[0], False).mean(dim=[1, 2, 3]))
+       true_discriminator_outputs = self.discriminator(source_tensor)
+       loss_true = 0
+
+       for true_discriminator_output in true_discriminator_outputs:
+           loss_true += torch.mean(hinge_loss(true_discriminator_output[0], True).mean(dim=[1, 2, 3]))
+       discriminator_losses['loss_discriminator'] = (loss_true.mean() + loss_fake.mean()) * 0.5
+       return discriminator_losses
+
+
+   def get_arcface_embedding(self, vision_tensor : Tensor, padding : Tuple[int, int, int, int]) -> Tensor:
+       _, _, height, width = vision_tensor.shape
+       # crop a small border before resizing to the 112x112 ArcFace input
+       crop_height = int(height * 0.0586)
+       crop_width = int(width * 0.0586)
+       crop_vision_tensor = vision_tensor[:, :, crop_height : height - crop_height, crop_width : width - crop_width]
+       crop_vision_tensor = torch.nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'bilinear')
+       # only mask when a padding value is given, otherwise a zero-length slice like [-0:] would blank the whole tensor
+       if padding[0] > 0:
+           crop_vision_tensor[:, :, :padding[0], :] = 0
+       if padding[1] > 0:
+           crop_vision_tensor[:, :, -padding[1]:, :] = 0
+       if padding[2] > 0:
+           crop_vision_tensor[:, :, :, :padding[2]] = 0
+       if padding[3] > 0:
+           crop_vision_tensor[:, :, :, -padding[3]:] = 0
+       embedding = self.arcface(crop_vision_tensor)
+       embedding = torch.nn.functional.normalize(embedding, p = 2, dim = 1)
+       return embedding
+
+
+   def get_landmark_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+       vision_tensor_norm = (vision_tensor + 1) * 0.5
+       vision_tensor_norm = torch.nn.functional.interpolate(vision_tensor_norm, size = (224, 224), mode = 'bilinear')
+       landmarks = self.landmarker(vision_tensor_norm)[2]
+       landmarks = landmarks.view(-1, 203, 2) * 256
+       left_eye_open_ratio = calc_distance_ratio(landmarks, (6, 18, 0, 12))
+       right_eye_open_ratio = calc_distance_ratio(landmarks, (30, 42, 24, 36))
+       lip_open_ratio = calc_distance_ratio(landmarks, (90, 102, 48, 66))
+       left_eye_gaze = landmarks[:, 198]
+       right_eye_gaze = landmarks[:, 197]
+       return left_eye_open_ratio, right_eye_open_ratio, lip_open_ratio, left_eye_gaze, right_eye_gaze
+
+
+   def get_motion_features(self, vision_tensor : Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+       vision_tensor_norm = (vision_tensor + 1) * 0.5
+       motion_dict = self.motion_extractor(vision_tensor_norm)
+       translation = motion_dict.get('t')
+       scale = motion_dict.get('scale')
+       rotation = torch.cat([ motion_dict.get('pitch'), motion_dict.get('yaw'), motion_dict.get('roll') ], dim = 1)
+       return translation, scale, rotation
+
+
+   def log_generator_preview(self, source_tensor, target_tensor, swap_tensor):
+       max_preview = 8
+       source_tensor = source_tensor[:max_preview]
+       target_tensor = target_tensor[:max_preview]
+       swap_tensor = swap_tensor[:max_preview]
+       rows = [torch.cat([src, tgt, swp], dim=2) for src, tgt, swp in zip(source_tensor, target_tensor, swap_tensor)]
+       grid = torchvision.utils.make_grid(torch.cat(rows, dim=1).unsqueeze(0), nrow=1, normalize=True, scale_each=True)
+       os.makedirs("previews", exist_ok=True)
+       torchvision.utils.save_image(grid, f"previews/step_{self.global_step}.jpg")
+       self.logger.experiment.add_image("Generator Preview", grid, self.global_step)
+
+   def log_validation_preview(self):
+       validation_source_path = CONFIG.get('training.validation', 'sources')
+       validation_target_path = CONFIG.get('training.validation', 'targets')
+       sources = [read_image(os.path.join(validation_source_path, f)) for f in os.listdir(validation_source_path) if f.lower().endswith('.jpg') or f.lower().endswith('.png')]
+       targets = [read_image(os.path.join(validation_target_path, f)) for f in os.listdir(validation_target_path) if f.lower().endswith('.jpg') or f.lower().endswith('.png')]
+       transforms = torchvision.transforms.Compose(
+       [
+           torchvision.transforms.Resize((256, 256), interpolation = torchvision.transforms.InterpolationMode.BICUBIC),
+           torchvision.transforms.ToTensor(),
+           torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+       ])
+       to_numpy = lambda x: (x.cpu().detach().numpy()[0].transpose(1, 2, 0).clip(-1, 1)[:,:,::-1] + 1) * 127.5
+       self.generator.eval()
+       results = []
+
+       for source, target in zip(sources, targets):
+           source_tensor = transforms(source).unsqueeze(0).to(self.device).half()
+           target_tensor = transforms(target).unsqueeze(0).to(self.device).half()
+           source_embedding = self.get_arcface_embedding(source_tensor, (0, 0, 0, 0))
+
+           with torch.no_grad():
+               output, _ = self.generator(target_tensor, source_embedding)
+           results.append(numpy.hstack([to_numpy(source_tensor), to_numpy(target_tensor), to_numpy(output)]))
+       preview = numpy.vstack(results)
+       os.makedirs("validation_previews", exist_ok=True)
+       cv2.imwrite(f"validation_previews/step_{self.global_step}.jpg", preview)
+       self.generator.train()
diff --git a/face_swapper/src/sub_typing.py b/face_swapper/src/typing.py
similarity index 50%
rename from face_swapper/src/sub_typing.py
rename to face_swapper/src/typing.py
index 4297b40..529cb5f 100644
--- a/face_swapper/src/sub_typing.py
+++ b/face_swapper/src/typing.py
@@ -1,12 +1,14 @@
-from typing import Any, Tuple
+from typing import Any, Tuple, List, Dict, Optional

 from numpy.typing import NDArray
 from torch import Tensor
 from torch.utils.data import DataLoader

-Batch = Tuple[Tensor, Tensor, int]
+Batch = Tuple[Any, Any, Any]
 Loader = DataLoader[Tuple[Tensor, ...]]
-UNetAttributes = Tuple[Tensor, ...]
+TargetAttributes = Tuple[Tensor, ...]
+DiscriminatorOutputs = List[List[Tensor]]
+LossDict = Dict[str, Tensor]
 Embedding = NDArray[Any]
 VisionFrame = NDArray[Any]
diff --git a/mypy.ini b/mypy.ini
index 081e7bf..75ef9f3 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -6,4 +6,4 @@ disallow_untyped_defs = True
 ignore_missing_imports = True
 strict_optional = False
 explicit_package_bases = True
-exclude = face_swapper/LivePortrait
+exclude = face_swapper