From 5f361d24e704b3da243403c11cf006c33cdf31ce Mon Sep 17 00:00:00 2001
From: harisreedhar
Date: Mon, 9 Dec 2024 21:42:46 +0530
Subject: [PATCH] new swapper

---
 .flake8                           |   2 +-
 .gitmodules                       |   3 +
 face_swapper/LICENSE.md           |   3 +
 face_swapper/LivePortrait         |   1 +
 face_swapper/config.ini           |  61 ++++++++
 face_swapper/src/augmentations.py |  27 ++++
 face_swapper/src/data_loader.py   |  81 ++++++++++
 face_swapper/src/discriminator.py |  86 ++++++++++
 face_swapper/src/generator.py     | 212 ++++++++++++++++++++++++++
 face_swapper/src/model_loader.py  |  50 ++++++
 face_swapper/src/sub_typing.py    |  12 ++
 face_swapper/src/training.py      |   5 +
 face_swapper/train.py             |   7 +
 mypy.ini                          |   1 +
 14 files changed, 550 insertions(+), 1 deletion(-)
 create mode 100644 .gitmodules
 create mode 100644 face_swapper/LICENSE.md
 create mode 160000 face_swapper/LivePortrait
 create mode 100644 face_swapper/config.ini
 create mode 100644 face_swapper/src/augmentations.py
 create mode 100644 face_swapper/src/data_loader.py
 create mode 100644 face_swapper/src/discriminator.py
 create mode 100644 face_swapper/src/generator.py
 create mode 100644 face_swapper/src/model_loader.py
 create mode 100644 face_swapper/src/sub_typing.py
 create mode 100644 face_swapper/src/training.py
 create mode 100644 face_swapper/train.py

diff --git a/.flake8 b/.flake8
index a840286..bd09e59 100644
--- a/.flake8
+++ b/.flake8
@@ -4,4 +4,4 @@ plugins = flake8-import-order
 application_import_names = arcface_converter
 import-order-style = pycharm
 per-file-ignores = preparing.py:E402
-
+exclude = LivePortrait
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000..79022b3
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "face_swapper/LivePortrait"]
+	path = face_swapper/LivePortrait
+	url = https://github.com/KwaiVGI/LivePortrait
diff --git a/face_swapper/LICENSE.md b/face_swapper/LICENSE.md
new file mode 100644
index 0000000..158be07
--- /dev/null
+++ b/face_swapper/LICENSE.md
@@ -0,0 +1,3 @@
+Non-Commercial license
+
+Copyright (c) 2024 Henry Ruhs
diff --git a/face_swapper/LivePortrait b/face_swapper/LivePortrait
new file mode 160000
index 0000000..632da74
--- /dev/null
+++ b/face_swapper/LivePortrait
@@ -0,0 +1 @@
+Subproject commit 632da7486d2c3fb86663fc44190a09aca4e1a8de
diff --git a/face_swapper/config.ini b/face_swapper/config.ini
new file mode 100644
index 0000000..0651a02
--- /dev/null
+++ b/face_swapper/config.ini
@@ -0,0 +1,61 @@
+[preparing.dataset]
+dataset_path =
+
+[preparing.dataloader]
+same_person_probability = 0.2
+
+[preparing.augmentation]
+expression_augmentation = false
+
+[training.loader]
+batch_size = 6
+num_workers = 8
+
+[training.generator]
+num_blocks = 2
+id_channels = 512
+
+[training.discriminator]
+input_channels = 3
+num_filters = 64
+num_layers = 5
+num_discriminators = 3
+
+[auxiliary_models.paths]
+arcface_path =
+landmarker_path =
+motion_extractor_path =
+feature_extractor_path =
+warping_network_path =
+spade_generator_path =
+
+[training.losses]
+weight_adversarial = 1
+weight_identity = 20
+weight_attribute = 10
+weight_reconstruction = 10
+weight_tsr = 0
+weight_expression = 0
+
+[training.optimizers]
+scheduler_step = 5000
+scheduler_gamma = 0.2
+generator_learning_rate = 0.0004
+discriminator_learning_rate = 0.0004
+
+[training.trainer]
+epochs = 50
+disable_discriminator = false
+
+[training.output]
+directory_path =
+file_pattern =
+
+[exporting]
+directory_path =
+source_path =
+target_path =
+opset_version =
+
+[execution]
+providers =
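Note: these sections are consumed with Python's standard configparser, exactly as model_loader.py and data_loader.py below do. A minimal sketch of reading the typed values (section and option names as defined in config.ini above):

    import configparser

    CONFIG = configparser.ConfigParser()
    CONFIG.read('config.ini')

    batch_size = CONFIG.getint('training.loader', 'batch_size')
    same_person_probability = CONFIG.getfloat('preparing.dataloader', 'same_person_probability')
    disable_discriminator = CONFIG.getboolean('training.trainer', 'disable_discriminator')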
diff --git a/face_swapper/src/augmentations.py b/face_swapper/src/augmentations.py
new file mode 100644
index 0000000..5ed35d2
--- /dev/null
+++ b/face_swapper/src/augmentations.py
@@ -0,0 +1,27 @@
+import torch
+from torch import Tensor
+
+
+def apply_random_motion_blur(tensor_image : Tensor) -> Tensor:
+	kernel_size = 9
+	kernel = torch.zeros((kernel_size, kernel_size), dtype = torch.float32)
+	random_angle = torch.empty(1).uniform_(-2 * torch.pi, 2 * torch.pi)
+	dx = torch.cos(random_angle)
+	dy = torch.sin(random_angle)
+	center = kernel_size // 2
+
+	for i in range(kernel_size):
+		x = int(center + (i - center) * dx)
+		y = int(center + (i - center) * dy)
+		if 0 <= x < kernel_size and 0 <= y < kernel_size:
+			kernel[y, x] = 1
+	kernel /= kernel.sum()
+	kernel = kernel.unsqueeze(0).unsqueeze(0)
+	blurred_channels = []
+
+	for channel in tensor_image:
+		channel = channel.unsqueeze(0).unsqueeze(0)
+		channel = torch.nn.functional.conv2d(channel, kernel, padding = kernel_size // 2)
+		channel = channel.squeeze(0).squeeze(0)
+		blurred_channels.append(channel)
+	return torch.stack(blurred_channels)
diff --git a/face_swapper/src/data_loader.py b/face_swapper/src/data_loader.py
new file mode 100644
index 0000000..6019747
--- /dev/null
+++ b/face_swapper/src/data_loader.py
@@ -0,0 +1,81 @@
+import configparser
+import glob
+import random
+
+import cv2
+import torchvision.transforms as transforms
+import tqdm
+from PIL import Image
+from torch.utils.data import TensorDataset
+
+from .augmentations import apply_random_motion_blur
+from .sub_typing import Batch
+
+CONFIG = configparser.ConfigParser()
+CONFIG.read('config.ini')
+
+
+def read_image(image_path : str) -> Image.Image:
+	image = cv2.imread(image_path)[:, :, ::-1]
+	pil_image = Image.fromarray(image)
+	return pil_image
+
+
+class DataLoaderVGG(TensorDataset):
+	def __init__(self, dataset_path : str) -> None:
+		self.same_person_probability = CONFIG.getfloat('preparing.dataloader', 'same_person_probability')
+		self.image_paths = glob.glob('{}/*/*.*g'.format(dataset_path))
+		self.folder_paths = glob.glob('{}/*'.format(dataset_path))
+		self.image_path_dict = {}
+
+		for folder_path in tqdm.tqdm(self.folder_paths):
+			image_paths = glob.glob('{}/*'.format(folder_path))
+			self.image_path_dict[folder_path] = image_paths
+		self.dataset_total = len(self.image_paths)
+		self.transforms_basic = transforms.Compose(
+		[
+			transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
+			transforms.ToTensor(),
+			transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+		])
+		self.transforms_moderate = transforms.Compose(
+		[
+			transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
+			transforms.ToTensor(),
+			transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
+			transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1), fill = 0),
+			transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+		])
+		self.transforms_complex = transforms.Compose(
+		[
+			transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
+			transforms.ToTensor(),
+			transforms.RandomHorizontalFlip(p = 0.5),
+			transforms.RandomApply([ apply_random_motion_blur ], p = 0.3),
+			transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
+			transforms.RandomAffine(8, translate = (0.02, 0.02), scale = (0.98, 1.02), shear = (1, 1), fill = 0),
+			transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+		])
+
+	def __getitem__(self, item : int) -> Batch:
+		source_image_path = self.image_paths[item]
+		source = read_image(source_image_path)
+
+		if random.random() > self.same_person_probability:
+			is_same_person = 0
+			target_image_path = random.choice(self.image_paths)
+			target = read_image(target_image_path)
+			source_transform = self.transforms_moderate(source)
+			target_transform = self.transforms_complex(target)
+		else:
+			is_same_person = 1
+			source_folder_path = '/'.join(source_image_path.split('/')[:-1])
+			target_image_path = random.choice(self.image_path_dict[source_folder_path])
+			target = read_image(target_image_path)
+			source_transform = self.transforms_basic(source)
+			target_transform = self.transforms_basic(target)
+
+		return source_transform, target_transform, is_same_person
+
+	def __len__(self) -> int:
+		return self.dataset_total
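DataLoaderVGG is a map-style dataset, so it plugs directly into torch.utils.data.DataLoader with the [training.loader] values from config.ini. A minimal sketch (the dataset path is a placeholder — any root folder with one subfolder per identity matches the glob patterns above; shuffle and drop_last are assumptions, not part of this patch):

    from torch.utils.data import DataLoader

    from src.data_loader import DataLoaderVGG

    dataset = DataLoaderVGG('/path/to/vggface2')
    loader = DataLoader(dataset, batch_size = 6, num_workers = 8, shuffle = True, drop_last = True)
    source, target, is_same_person = next(iter(loader))
    # source, target: (6, 3, 256, 256) in [-1, 1]; is_same_person: (6,) with 1 marking same-identity pairs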
diff --git a/face_swapper/src/discriminator.py b/face_swapper/src/discriminator.py
new file mode 100644
index 0000000..b1b3938
--- /dev/null
+++ b/face_swapper/src/discriminator.py
@@ -0,0 +1,86 @@
+from typing import List
+
+import numpy
+import torch.nn as nn
+
+from .sub_typing import Tensor
+
+
+class NLayerDiscriminator(nn.Module):
+	def __init__(self, input_channels : int, num_filters : int, num_layers : int) -> None:
+		super(NLayerDiscriminator, self).__init__()
+		self.num_layers = num_layers
+		kernel_size = 4
+		padding_size = int(numpy.ceil((kernel_size - 1.0) / 2))
+		model_layers = [
+			[
+				nn.Conv2d(input_channels, num_filters, kernel_size = kernel_size, stride = 2, padding = padding_size),
+				nn.LeakyReLU(0.2, True)
+			]]
+		current_filters = num_filters
+
+		for layer_index in range(1, num_layers):
+			previous_filters = current_filters
+			current_filters = min(current_filters * 2, 512)
+			model_layers += [
+				[
+					nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 2, padding = padding_size),
+					nn.InstanceNorm2d(current_filters),
+					nn.LeakyReLU(0.2, True)
+				]]
+		previous_filters = current_filters
+		current_filters = min(current_filters * 2, 512)
+		model_layers += [
+			[
+				nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 1, padding = padding_size),
+				nn.InstanceNorm2d(current_filters),
+				nn.LeakyReLU(0.2, True)
+			]]
+		model_layers += [
+			[
+				nn.Conv2d(current_filters, 1, kernel_size = kernel_size, stride = 1, padding = padding_size)
+			]]
+		combined_layers = []
+
+		for layer in model_layers:
+			combined_layers += layer
+		self.model = nn.Sequential(*combined_layers)
+
+	def forward(self, input_tensor : Tensor) -> Tensor:
+		return self.model(input_tensor)
+
+
+# defaults from config.ini: input_channels = 3, num_filters = 64, num_layers = 5, num_discriminators = 3
+class MultiscaleDiscriminator(nn.Module):
+	def __init__(self, input_channels : int, num_filters : int, num_layers : int, num_discriminators : int) -> None:
+		super(MultiscaleDiscriminator, self).__init__()
+		self.num_discriminators = num_discriminators
+		self.num_layers = num_layers
+		self.return_intermediate_features = False # enable to expose per-layer feature maps for feature matching losses
+
+		for discriminator_index in range(num_discriminators):
+			single_discriminator = NLayerDiscriminator(input_channels, num_filters, num_layers)
+			setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model)
+		self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = 1, count_include_pad = False)
+
+	def single_discriminator_forward(self, model_layers : nn.Sequential, input_tensor : Tensor) -> List[Tensor]:
+		if self.return_intermediate_features:
+			feature_maps = [ input_tensor ]
+
+			for layer in model_layers:
+				feature_maps.append(layer(feature_maps[-1]))
+			return feature_maps[1:]
+		else:
+			return [ model_layers(input_tensor) ]
+
+	def forward(self, input_tensor : Tensor) -> List[Tensor]:
+		discriminator_outputs = []
+		downsampled_input = input_tensor
+
+		for discriminator_index in range(self.num_discriminators):
+			model_layers = getattr(self, 'discriminator_layer_{}'.format(self.num_discriminators - 1 - discriminator_index))
+			discriminator_outputs.append(self.single_discriminator_forward(model_layers, downsampled_input))
+
+			if discriminator_index != (self.num_discriminators - 1):
+				downsampled_input = self.downsample(downsampled_input)
+		return discriminator_outputs
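With the config defaults (input_channels = 3, num_filters = 64, num_layers = 5, num_discriminators = 3), forward() evaluates the input at full, half and quarter resolution and returns one prediction list per scale. A shape-level smoke test as a sketch (the 256x256 input matches the data loader output):

    import torch

    from src.discriminator import MultiscaleDiscriminator

    discriminator = MultiscaleDiscriminator(3, 64, 5, 3)
    outputs = discriminator(torch.randn(1, 3, 256, 256))

    for output in outputs:
        print(output[-1].shape)
    # torch.Size([1, 1, 11, 11]), torch.Size([1, 1, 7, 7]), torch.Size([1, 1, 5, 5])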
diff --git a/face_swapper/src/generator.py b/face_swapper/src/generator.py
new file mode 100644
index 0000000..8b8bff2
--- /dev/null
+++ b/face_swapper/src/generator.py
@@ -0,0 +1,212 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+from .sub_typing import Tensor, UNetAttributes
+
+
+class AdaptiveEmbeddingIntegrationNetwork(nn.Module):
+	def __init__(self, id_channels : int, num_blocks : int) -> None:
+		super(AdaptiveEmbeddingIntegrationNetwork, self).__init__()
+		self.encoder = UNet()
+		self.generator = AdaptiveAttentionalDenorm_Generator(id_channels, num_blocks)
+
+	def forward(self, target : Tensor, source_embedding : Tensor) -> Tuple[Tensor, UNetAttributes]:
+		target_attributes = self.get_attributes(target)
+		swap = self.generator(target_attributes, source_embedding)
+		return swap, target_attributes
+
+	def get_attributes(self, target : Tensor) -> UNetAttributes:
+		return self.encoder(target)
+
+
+class AdaptiveAttentionalDenorm_Generator(nn.Module):
+	def __init__(self, id_channels : int, num_blocks : int) -> None:
+		super(AdaptiveAttentionalDenorm_Generator, self).__init__()
+		self.upsample = Upsample(id_channels, 1024 * 4)
+		self.block_1 = AdaptiveAttentionalDenorm_ResBlock(1024, 1024, 1024, id_channels, num_blocks)
+		self.block_2 = AdaptiveAttentionalDenorm_ResBlock(1024, 1024, 2048, id_channels, num_blocks)
+		self.block_3 = AdaptiveAttentionalDenorm_ResBlock(1024, 1024, 1024, id_channels, num_blocks)
+		self.block_4 = AdaptiveAttentionalDenorm_ResBlock(1024, 512, 512, id_channels, num_blocks)
+		self.block_5 = AdaptiveAttentionalDenorm_ResBlock(512, 256, 256, id_channels, num_blocks)
+		self.block_6 = AdaptiveAttentionalDenorm_ResBlock(256, 128, 128, id_channels, num_blocks)
+		self.block_7 = AdaptiveAttentionalDenorm_ResBlock(128, 64, 64, id_channels, num_blocks)
+		self.block_8 = AdaptiveAttentionalDenorm_ResBlock(64, 3, 64, id_channels, num_blocks)
+		self.apply(initialize_weight)
+
+	def forward(self, target_attributes : UNetAttributes, source_embedding : Tensor) -> Tensor:
+		feature_map = self.upsample(source_embedding)
+		feature_map_1 = nn.functional.interpolate(self.block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
+		feature_map_2 = nn.functional.interpolate(self.block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
+		feature_map_3 = nn.functional.interpolate(self.block_3(feature_map_2, target_attributes[2], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
+		feature_map_4 = nn.functional.interpolate(self.block_4(feature_map_3, target_attributes[3], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
+		feature_map_5 = nn.functional.interpolate(self.block_5(feature_map_4, target_attributes[4], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
+		feature_map_6 = nn.functional.interpolate(self.block_6(feature_map_5, target_attributes[5], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
+		feature_map_7 = nn.functional.interpolate(self.block_7(feature_map_6, target_attributes[6], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False)
+		output = self.block_8(feature_map_7, target_attributes[7], source_embedding)
+		return torch.tanh(output)
+
+
+class UNet(nn.Module):
+	def __init__(self) -> None:
+		super(UNet, self).__init__()
+		self.downsampler_1 = Conv4x4(3, 32)
+		self.downsampler_2 = Conv4x4(32, 64)
+		self.downsampler_3 = Conv4x4(64, 128)
+		self.downsampler_4 = Conv4x4(128, 256)
+		self.downsampler_5 = Conv4x4(256, 512)
+		self.downsampler_6 = Conv4x4(512, 1024)
+		self.bottleneck = Conv4x4(1024, 1024)
+		self.upsampler_1 = DeConv4x4(1024, 1024)
+		self.upsampler_2 = DeConv4x4(2048, 512)
+		self.upsampler_3 = DeConv4x4(1024, 256)
+		self.upsampler_4 = DeConv4x4(512, 128)
+		self.upsampler_5 = DeConv4x4(256, 64)
+		self.upsampler_6 = DeConv4x4(128, 32)
+		self.apply(initialize_weight)
+
+	def forward(self, input_tensor : Tensor) -> UNetAttributes:
+		downsample_feature_1 = self.downsampler_1(input_tensor)
+		downsample_feature_2 = self.downsampler_2(downsample_feature_1)
+		downsample_feature_3 = self.downsampler_3(downsample_feature_2)
+		downsample_feature_4 = self.downsampler_4(downsample_feature_3)
+		downsample_feature_5 = self.downsampler_5(downsample_feature_4)
+		downsample_feature_6 = self.downsampler_6(downsample_feature_5)
+		bottleneck_output = self.bottleneck(downsample_feature_6)
+		upsample_feature_1 = self.upsampler_1(bottleneck_output, downsample_feature_6)
+		upsample_feature_2 = self.upsampler_2(upsample_feature_1, downsample_feature_5)
+		upsample_feature_3 = self.upsampler_3(upsample_feature_2, downsample_feature_4)
+		upsample_feature_4 = self.upsampler_4(upsample_feature_3, downsample_feature_3)
+		upsample_feature_5 = self.upsampler_5(upsample_feature_4, downsample_feature_2)
+		upsample_feature_6 = self.upsampler_6(upsample_feature_5, downsample_feature_1)
+		output = nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False)
+		return bottleneck_output, upsample_feature_1, upsample_feature_2, upsample_feature_3, upsample_feature_4, upsample_feature_5, upsample_feature_6, output
+
+
+class AdaptiveAttentionalDenorm_Layer(nn.Module):
+	def __init__(self, input_channels : int, attr_channels : int, id_channels : int) -> None:
+		super(AdaptiveAttentionalDenorm_Layer, self).__init__()
+		self.attr_channels = attr_channels
+		self.id_channels = id_channels
+		self.input_channels = input_channels
+		self.conv_gamma = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True)
+		self.conv_beta = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True)
+		self.fc_gamma = nn.Linear(id_channels, input_channels)
+		self.fc_beta = nn.Linear(id_channels, input_channels)
+		self.instance_norm = nn.InstanceNorm2d(input_channels, affine = False)
+		self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1, stride = 1, padding = 0, bias = True)
+
+	def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : Tensor) -> Tensor:
+		feature_map = self.instance_norm(feature_map)
+		attr_gamma = self.conv_gamma(attr_embedding)
+		attr_beta = self.conv_beta(attr_embedding)
+		attr_modulation = attr_gamma * feature_map + attr_beta
+		id_gamma = self.fc_gamma(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
+		id_beta = self.fc_beta(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map)
+		id_modulation = id_gamma * feature_map + id_beta
+		feature_mask = torch.sigmoid(self.conv_mask(feature_map))
+		feature_blend = (1 - feature_mask) * attr_modulation + feature_mask * id_modulation
+		return feature_blend
+
+
+class AddBlocksSequential(nn.Sequential):
+	def forward(self, *inputs : Tuple[Tensor, ...]) -> Tensor:
+		feature_map, attr_embedding, id_embedding = inputs
+
+		for index, module in enumerate(self._modules.values()):
+			if index % 3 == 0 and index > 0:
+				inputs = (inputs, attr_embedding, id_embedding)
+			if isinstance(inputs, tuple):
+				inputs = module(*inputs)
+			else:
+				inputs = module(inputs)
+		return inputs
+
+
+class AdaptiveAttentionalDenorm_ResBlock(nn.Module):
+	def __init__(self, in_channels : int, out_channels : int, attr_channels : int, id_channels : int, num_blocks : int) -> None:
+		super(AdaptiveAttentionalDenorm_ResBlock, self).__init__()
+		self.in_channels = in_channels
+		self.out_channels = out_channels
+		primary_add_blocks = []
+
+		for i in range(num_blocks):
+			intermediate_channels = in_channels if i < (num_blocks - 1) else out_channels
+			primary_add_blocks.extend(
+			[
+				AdaptiveAttentionalDenorm_Layer(in_channels, attr_channels, id_channels),
+				nn.ReLU(inplace = True),
+				nn.Conv2d(in_channels, intermediate_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
+			])
+		self.primary_add_blocks = AddBlocksSequential(*primary_add_blocks)
+
+		if in_channels != out_channels:
+			auxiliary_add_blocks =\
+			[
+				AdaptiveAttentionalDenorm_Layer(in_channels, attr_channels, id_channels),
+				nn.ReLU(inplace = True),
+				nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = False)
+			]
+			self.auxiliary_add_blocks = AddBlocksSequential(*auxiliary_add_blocks)
+
+	def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : Tensor) -> Tensor:
+		primary_feature = self.primary_add_blocks(feature_map, attr_embedding, id_embedding)
+
+		if self.in_channels != self.out_channels:
+			feature_map = self.auxiliary_add_blocks(feature_map, attr_embedding, id_embedding)
+		output_feature = primary_feature + feature_map
+		return output_feature
+
+
+class Conv4x4(nn.Module):
+	def __init__(self, in_channels : int, out_channels : int) -> None:
+		super(Conv4x4, self).__init__()
+		self.conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
+		self.batch_norm = nn.BatchNorm2d(out_channels)
+		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
+
+	def forward(self, input : Tensor) -> Tensor:
+		output = self.conv(input)
+		output = self.batch_norm(output)
+		output = self.leaky_relu(output)
+		return output
+
+
+class DeConv4x4(nn.Module):
+	def __init__(self, in_channels : int, out_channels : int) -> None:
+		super(DeConv4x4, self).__init__()
+		self.deconv = nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False)
+		self.batch_norm = nn.BatchNorm2d(out_channels)
+		self.leaky_relu = nn.LeakyReLU(0.1, inplace = True)
+
+	def forward(self, input : Tensor, skip_connection : Tensor) -> Tensor:
+		output = self.deconv(input)
+		output = self.batch_norm(output)
+		output = self.leaky_relu(output)
+		output = torch.cat((output, skip_connection), dim = 1)
+		return output
+
+
+class Upsample(nn.Module):
+	def __init__(self, in_channels : int, out_channels : int) -> None:
+		super(Upsample, self).__init__()
+		self.initial_conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1)
+		self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2)
+
+	def forward(self, input : Tensor) -> Tensor:
+		output = self.initial_conv(input.view(input.shape[0], -1, 1, 1))
+		output = self.pixel_shuffle(output)
+		return output
+
+
+def initialize_weight(module : nn.Module) -> None:
+	if isinstance(module, nn.Linear):
+		module.weight.data.normal_(0, 0.001)
+		module.bias.data.zero_()
+
+	if isinstance(module, nn.Conv2d):
+		nn.init.xavier_normal_(module.weight.data)
+
+	if isinstance(module, nn.ConvTranspose2d):
+		nn.init.xavier_normal_(module.weight.data)
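Putting the generator together: UNet encodes the target into eight attribute maps (bottleneck plus skip features), and the AAD generator re-injects the identity embedding at every scale before the final tanh. A shape-level smoke test as a sketch (id_channels = 512 and num_blocks = 2 are the config defaults; 512 matches the dimensionality of an ArcFace embedding):

    import torch

    from src.generator import AdaptiveEmbeddingIntegrationNetwork

    network = AdaptiveEmbeddingIntegrationNetwork(512, 2)
    target = torch.randn(1, 3, 256, 256)
    source_embedding = torch.randn(1, 512)
    swap, target_attributes = network(target, source_embedding)
    print(swap.shape)              # torch.Size([1, 3, 256, 256])
    print(len(target_attributes))  # 8 attribute maps, coarse (2x2) to fine (256x256)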
diff --git a/face_swapper/src/model_loader.py b/face_swapper/src/model_loader.py
new file mode 100644
index 0000000..db83a1a
--- /dev/null
+++ b/face_swapper/src/model_loader.py
@@ -0,0 +1,50 @@
+import configparser
+
+import torch
+import torch.nn as nn
+
+from .discriminator import MultiscaleDiscriminator
+from .generator import AdaptiveEmbeddingIntegrationNetwork
+
+CONFIG = configparser.ConfigParser()
+CONFIG.read('config.ini')
+
+
+def load_generator() -> nn.Module:
+	id_channels = CONFIG.getint('training.generator', 'id_channels')
+	num_blocks = CONFIG.getint('training.generator', 'num_blocks')
+	generator = AdaptiveEmbeddingIntegrationNetwork(id_channels, num_blocks)
+	return generator
+
+
+def load_discriminator() -> nn.Module:
+	input_channels = CONFIG.getint('training.discriminator', 'input_channels')
+	num_filters = CONFIG.getint('training.discriminator', 'num_filters')
+	num_layers = CONFIG.getint('training.discriminator', 'num_layers')
+	num_discriminators = CONFIG.getint('training.discriminator', 'num_discriminators')
+	discriminator = MultiscaleDiscriminator(input_channels, num_filters, num_layers, num_discriminators)
+	return discriminator
+
+
+def load_arcface() -> nn.Module:
+	model_path = CONFIG.get('auxiliary_models.paths', 'arcface_path')
+	arcface = torch.load(model_path, map_location = 'cpu', weights_only = False)
+	arcface.eval()
+	return arcface
+
+
+def load_landmarker() -> nn.Module:
+	model_path = CONFIG.get('auxiliary_models.paths', 'landmarker_path')
+	landmarker = torch.load(model_path, map_location = 'cpu', weights_only = False)
+	landmarker.eval()
+	return landmarker
+
+
+def load_motion_extractor() -> nn.Module:
+	from LivePortrait.src.modules.motion_extractor import MotionExtractor
+
+	model_path = CONFIG.get('auxiliary_models.paths', 'motion_extractor_path')
+	motion_extractor = MotionExtractor(num_kp = 21, backbone = 'convnextv2_tiny')
+	motion_extractor.load_state_dict(torch.load(model_path, map_location = 'cpu', weights_only = True))
+	motion_extractor.eval()
+	return motion_extractor
diff --git a/face_swapper/src/sub_typing.py b/face_swapper/src/sub_typing.py
new file mode 100644
index 0000000..4297b40
--- /dev/null
+++ b/face_swapper/src/sub_typing.py
@@ -0,0 +1,12 @@
+from typing import Any, Tuple
+
+from numpy.typing import NDArray
+from torch import Tensor
+from torch.utils.data import DataLoader
+
+Batch = Tuple[Tensor, Tensor, int]
+Loader = DataLoader[Tuple[Tensor, ...]]
+UNetAttributes = Tuple[Tensor, ...]
+
+Embedding = NDArray[Any]
+VisionFrame = NDArray[Any]
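Note that load_motion_extractor() imports from the LivePortrait submodule registered in .gitmodules above, so the submodule must be initialized (git submodule update --init) and the process must run from face_swapper/ for both that import and CONFIG.read('config.ini') to resolve. The generator and discriminator loaders need no checkpoint files and can be smoke-tested directly:

    from src.model_loader import load_discriminator, load_generator

    generator = load_generator()          # AdaptiveEmbeddingIntegrationNetwork(512, 2) with the config defaults
    discriminator = load_discriminator()  # MultiscaleDiscriminator(3, 64, 5, 3) with the config defaults
    print(sum(parameter.numel() for parameter in generator.parameters()))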
diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py
new file mode 100644
index 0000000..1482236
--- /dev/null
+++ b/face_swapper/src/training.py
@@ -0,0 +1,5 @@
+from .model_loader import load_motion_extractor
+
+
+def train() -> None:
+	print(load_motion_extractor())
diff --git a/face_swapper/train.py b/face_swapper/train.py
new file mode 100644
index 0000000..bbff847
--- /dev/null
+++ b/face_swapper/train.py
@@ -0,0 +1,7 @@
+#!/usr/bin/env python3
+
+from src.training import train
+
+
+if __name__ == '__main__':
+	train()
diff --git a/mypy.ini b/mypy.ini
index 64218bc..d144335 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -5,3 +5,4 @@ disallow_untyped_calls = True
 disallow_untyped_defs = True
 ignore_missing_imports = True
 strict_optional = False
+exclude = ^LivePortrait
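training.py is still a stub: train() only verifies that the LivePortrait motion extractor loads. For orientation, a sketch of the setup the [training.optimizers] and [training.losses] sections point toward — Adam and the commented loss mix are assumptions here, not part of this patch:

    import torch

    from src.model_loader import load_discriminator, load_generator

    generator = load_generator()
    discriminator = load_discriminator()
    # learning rates and scheduler values from [training.optimizers]
    generator_optimizer = torch.optim.Adam(generator.parameters(), lr = 0.0004)
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), lr = 0.0004)
    generator_scheduler = torch.optim.lr_scheduler.StepLR(generator_optimizer, step_size = 5000, gamma = 0.2)
    # the generator objective would mix the [training.losses] weights:
    # loss = 1 * adversarial + 20 * identity + 10 * attribute + 10 * reconstruction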