From a52686310fd63adc92fa491bb1c041122c0c564e Mon Sep 17 00:00:00 2001 From: harisreedhar Date: Mon, 13 Jan 2025 12:25:02 +0530 Subject: [PATCH] clean generator, discriminator and typing --- .flake8 | 2 +- face_swapper/src/discriminator.py | 6 +- face_swapper/src/generator.py | 207 ++++++++++++++---------------- face_swapper/src/typing.py | 3 +- mypy.ini | 2 +- 5 files changed, 104 insertions(+), 116 deletions(-) diff --git a/.flake8 b/.flake8 index 8103c45..0ed7d05 100644 --- a/.flake8 +++ b/.flake8 @@ -4,4 +4,4 @@ plugins = flake8-import-order application_import_names = arcface_converter import-order-style = pycharm per-file-ignores = preparing.py:E402 -exclude = face_swapper +exclude = face_swapper/LivePortrait diff --git a/face_swapper/src/discriminator.py b/face_swapper/src/discriminator.py index b38ca9e..e73f458 100644 --- a/face_swapper/src/discriminator.py +++ b/face_swapper/src/discriminator.py @@ -1,9 +1,7 @@ -from typing import List - import numpy import torch.nn as nn -from .typing import Tensor, DiscriminatorOutputs +from .typing import DiscriminatorOutputs, List, Tensor class NLayerDiscriminator(nn.Module): @@ -60,11 +58,9 @@ def __init__(self, input_channels : int, num_filters : int, num_layers : int, nu setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model) self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type] - def single_discriminator_forward(self, model_layers : nn.Sequential, input_tensor : Tensor) -> List[Tensor]: return [ model_layers(input_tensor) ] - def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs: discriminator_outputs = [] downsampled_input = input_tensor diff --git a/face_swapper/src/generator.py b/face_swapper/src/generator.py index 3dd18c9..e93a668 100644 --- a/face_swapper/src/generator.py +++ b/face_swapper/src/generator.py @@ -1,132 +1,118 @@ import torch import torch.nn as nn -import torch.nn.functional as F + +from .typing import IDEmbedding, TargetAttributes, Tensor, Tuple class AdaptiveEmbeddingIntegrationNetwork(nn.Module): - def __init__(self, id_channels=512, num_blocks=2): + def __init__(self, id_channels : int, num_blocks : int) -> None: super(AdaptiveEmbeddingIntegrationNetwork, self).__init__() self.encoder = UNet() self.generator = AADGenerator(id_channels, num_blocks) - def forward(self, target, source_embedding): + def forward(self, target : Tensor, source_embedding : IDEmbedding) -> Tuple[Tensor, TargetAttributes]: target_attributes = self.get_attributes(target) swap = self.generator(target_attributes, source_embedding) return swap, target_attributes - def get_attributes(self, target): + def get_attributes(self, target : Tensor) -> TargetAttributes: return self.encoder(target) class AADGenerator(nn.Module): - def __init__(self, id_channels=512, num_blocks=2): + def __init__(self, id_channels : int, num_blocks : int) -> None: super(AADGenerator, self).__init__() - self.upsample = Upsample(id_channels, 1024 * 4) - self.AADBlk1 = AAD_ResBlk(1024, 1024, 1024, id_channels, num_blocks) - self.AADBlk2 = AAD_ResBlk(1024, 1024, 2048, id_channels, num_blocks) - self.AADBlk3 = AAD_ResBlk(1024, 1024, 1024, id_channels, num_blocks) - self.AADBlk4 = AAD_ResBlk(1024, 512, 512, id_channels, num_blocks) - self.AADBlk5 = AAD_ResBlk(512, 256, 256, id_channels, num_blocks) - self.AADBlk6 = AAD_ResBlk(256, 128, 128, id_channels, num_blocks) - self.AADBlk7 = AAD_ResBlk(128, 64, 64, id_channels, num_blocks) - self.AADBlk8 = AAD_ResBlk(64, 3, 64, id_channels, num_blocks) + self.upsample = PixelShuffleUpsample(id_channels, 1024 * 4) + self.res_block_1 = AADResBlock(1024, 1024, 1024, id_channels, num_blocks) + self.res_block_2 = AADResBlock(1024, 1024, 2048, id_channels, num_blocks) + self.res_block_3 = AADResBlock(1024, 1024, 1024, id_channels, num_blocks) + self.res_block_4 = AADResBlock(1024, 512, 512, id_channels, num_blocks) + self.res_block_5 = AADResBlock(512, 256, 256, id_channels, num_blocks) + self.res_block_6 = AADResBlock(256, 128, 128, id_channels, num_blocks) + self.res_block_7 = AADResBlock(128, 64, 64, id_channels, num_blocks) + self.res_block_8 = AADResBlock(64, 3, 64, id_channels, num_blocks) self.apply(initialize_weight) - def forward(self, target_attributes, source_embedding): + def forward(self, target_attributes : TargetAttributes, source_embedding : IDEmbedding) -> Tensor: feature_map = self.upsample(source_embedding) - feature_map_1 = F.interpolate(self.AADBlk1(feature_map, target_attributes[0], source_embedding), scale_factor=2, mode='bilinear', align_corners=False) - feature_map_2 = F.interpolate(self.AADBlk2(feature_map_1, target_attributes[1], source_embedding), scale_factor=2, mode='bilinear', align_corners=False) - feature_map_3 = F.interpolate(self.AADBlk3(feature_map_2, target_attributes[2], source_embedding), scale_factor=2, mode='bilinear', align_corners=False) - feature_map_4 = F.interpolate(self.AADBlk4(feature_map_3, target_attributes[3], source_embedding), scale_factor=2, mode='bilinear', align_corners=False) - feature_map_5 = F.interpolate(self.AADBlk5(feature_map_4, target_attributes[4], source_embedding), scale_factor=2, mode='bilinear', align_corners=False) - feature_map_6 = F.interpolate(self.AADBlk6(feature_map_5, target_attributes[5], source_embedding), scale_factor=2, mode='bilinear', align_corners=False) - feature_map_7 = F.interpolate(self.AADBlk7(feature_map_6, target_attributes[6], source_embedding), scale_factor=2, mode='bilinear', align_corners=False) - output = self.AADBlk8(feature_map_7, target_attributes[7], source_embedding) + feature_map_1 = torch.nn.functional.interpolate(self.res_block_1(feature_map, target_attributes[0], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) + feature_map_2 = torch.nn.functional.interpolate(self.res_block_2(feature_map_1, target_attributes[1], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) + feature_map_3 = torch.nn.functional.interpolate(self.res_block_3(feature_map_2, target_attributes[2], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) + feature_map_4 = torch.nn.functional.interpolate(self.res_block_4(feature_map_3, target_attributes[3], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) + feature_map_5 = torch.nn.functional.interpolate(self.res_block_5(feature_map_4, target_attributes[4], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) + feature_map_6 = torch.nn.functional.interpolate(self.res_block_6(feature_map_5, target_attributes[5], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) + feature_map_7 = torch.nn.functional.interpolate(self.res_block_7(feature_map_6, target_attributes[6], source_embedding), scale_factor = 2, mode = 'bilinear', align_corners = False) + output = self.res_block_8(feature_map_7, target_attributes[7], source_embedding) return torch.tanh(output) class UNet(nn.Module): - def __init__(self): + def __init__(self) -> None: super(UNet, self).__init__() - self.downsampler_1 = Conv4x4(3, 32) - self.downsampler_2 = Conv4x4(32, 64) - self.downsampler_3 = Conv4x4(64, 128) - self.downsampler_4 = Conv4x4(128, 256) - self.downsampler_5 = Conv4x4(256, 512) - self.downsampler_6 = Conv4x4(512, 1024) - - self.bottleneck = Conv4x4(1024, 1024) - - self.upsampler_1 = DeConv4x4(1024, 1024) - self.upsampler_2 = DeConv4x4(2048, 512) - self.upsampler_3 = DeConv4x4(1024, 256) - self.upsampler_4 = DeConv4x4(512, 128) - self.upsampler_5 = DeConv4x4(256, 64) - self.upsampler_6 = DeConv4x4(128, 32) + self.downsampler_1 = DownSample(3, 32) + self.downsampler_2 = DownSample(32, 64) + self.downsampler_3 = DownSample(64, 128) + self.downsampler_4 = DownSample(128, 256) + self.downsampler_5 = DownSample(256, 512) + self.downsampler_6 = DownSample(512, 1024) + self.bottleneck = DownSample(1024, 1024) + self.upsampler_1 = Upsample(1024, 1024) + self.upsampler_2 = Upsample(2048, 512) + self.upsampler_3 = Upsample(1024, 256) + self.upsampler_4 = Upsample(512, 128) + self.upsampler_5 = Upsample(256, 64) + self.upsampler_6 = Upsample(128, 32) self.apply(initialize_weight) - def forward(self, input_tensor): - downsample_feature_1 = self.downsampler_1(input_tensor) + def forward(self, target : Tensor) -> TargetAttributes: + downsample_feature_1 = self.downsampler_1(target) downsample_feature_2 = self.downsampler_2(downsample_feature_1) downsample_feature_3 = self.downsampler_3(downsample_feature_2) downsample_feature_4 = self.downsampler_4(downsample_feature_3) downsample_feature_5 = self.downsampler_5(downsample_feature_4) downsample_feature_6 = self.downsampler_6(downsample_feature_5) - bottleneck_output = self.bottleneck(downsample_feature_6) - upsample_feature_1 = self.upsampler_1(bottleneck_output, downsample_feature_6) upsample_feature_2 = self.upsampler_2(upsample_feature_1, downsample_feature_5) upsample_feature_3 = self.upsampler_3(upsample_feature_2, downsample_feature_4) upsample_feature_4 = self.upsampler_4(upsample_feature_3, downsample_feature_3) upsample_feature_5 = self.upsampler_5(upsample_feature_4, downsample_feature_2) upsample_feature_6 = self.upsampler_6(upsample_feature_5, downsample_feature_1) - - output = F.interpolate(upsample_feature_6, scale_factor=2, mode='bilinear', align_corners=False) - + output = torch.nn.functional.interpolate(upsample_feature_6, scale_factor = 2, mode = 'bilinear', align_corners = False) return bottleneck_output, upsample_feature_1, upsample_feature_2, upsample_feature_3, upsample_feature_4, upsample_feature_5, upsample_feature_6, output class AADLayer(nn.Module): - def __init__(self, input_channels, attr_channels, id_channels): + def __init__(self, input_channels : int, attr_channels : int, id_channels : int) -> None: super(AADLayer, self).__init__() - self.attr_channels = attr_channels - self.id_channels = id_channels self.input_channels = input_channels - - self.conv_gamma = nn.Conv2d(attr_channels, input_channels, kernel_size=1, stride=1, padding=0, bias=True) - self.conv_beta = nn.Conv2d(attr_channels, input_channels, kernel_size=1, stride=1, padding=0, bias=True) - self.fc_gamma = nn.Linear(id_channels, input_channels) + self.conv_beta = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True) + self.conv_gamma = nn.Conv2d(attr_channels, input_channels, kernel_size = 1, stride = 1, padding = 0, bias = True) self.fc_beta = nn.Linear(id_channels, input_channels) - self.instance_norm = nn.InstanceNorm2d(input_channels, affine=False) - - self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size=1, stride=1, padding=0, bias=True) + self.fc_gamma = nn.Linear(id_channels, input_channels) + self.instance_norm = nn.InstanceNorm2d(input_channels, affine = False) + self.conv_mask = nn.Conv2d(input_channels, 1, kernel_size = 1, stride = 1, padding = 0, bias = True) - def forward(self, feature_map, attr_embedding, id_embedding): + def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : IDEmbedding) -> Tensor: feature_map = self.instance_norm(feature_map) - attr_gamma = self.conv_gamma(attr_embedding) attr_beta = self.conv_beta(attr_embedding) attr_modulation = attr_gamma * feature_map + attr_beta - - id_gamma = self.fc_gamma(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as( - feature_map) - id_beta = self.fc_beta(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as( - feature_map) + id_gamma = self.fc_gamma(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map) + id_beta = self.fc_beta(id_embedding).reshape(feature_map.shape[0], self.input_channels, 1, 1).expand_as(feature_map) id_modulation = id_gamma * feature_map + id_beta - feature_mask = torch.sigmoid(self.conv_mask(feature_map)) feature_blend = (1 - feature_mask) * attr_modulation + feature_mask * id_modulation return feature_blend class AddBlocksSequential(nn.Sequential): - def forward(self, *inputs): - h, attr_embedding, id_embedding = inputs + def forward(self, *inputs : Tuple[Tensor, Tensor, IDEmbedding]) -> Tuple[Tuple[Tensor, Tensor, IDEmbedding], ...]: + _, attr_embedding, id_embedding = inputs for index, module in enumerate(self._modules.values()): if index % 3 == 0 and index > 0: - inputs = (inputs, attr_embedding, id_embedding) + inputs = (inputs, attr_embedding, id_embedding) # type:ignore[assignment] if type(inputs) == tuple: inputs = module(*inputs) else: @@ -134,9 +120,9 @@ def forward(self, *inputs): return inputs -class AAD_ResBlk(nn.Module): - def __init__(self, in_channels, out_channels, attr_channels, id_channels, num_blocks): - super(AAD_ResBlk, self).__init__() +class AADResBlock(nn.Module): + def __init__(self, in_channels : int, out_channels : int, attr_channels : int, id_channels : int, num_blocks : int) -> None: + super(AADResBlock, self).__init__() self.in_channels = in_channels self.out_channels = out_channels primary_add_blocks = [] @@ -146,8 +132,8 @@ def __init__(self, in_channels, out_channels, attr_channels, id_channels, num_bl primary_add_blocks.extend( [ AADLayer(in_channels, attr_channels, id_channels), - nn.ReLU(inplace=True), - nn.Conv2d(in_channels, intermediate_channels, kernel_size=3, stride=1, padding=1, bias=False) + nn.ReLU(inplace = True), + nn.Conv2d(in_channels, intermediate_channels, kernel_size = 3, stride = 1, padding = 1, bias = False) ]) self.primary_add_blocks = AddBlocksSequential(*primary_add_blocks) @@ -155,12 +141,12 @@ def __init__(self, in_channels, out_channels, attr_channels, id_channels, num_bl auxiliary_add_blocks = \ [ AADLayer(in_channels, attr_channels, id_channels), - nn.ReLU(inplace=True), - nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False) + nn.ReLU(inplace = True), + nn.Conv2d(in_channels, out_channels, kernel_size = 3, stride = 1, padding = 1, bias = False) ] self.auxiliary_add_blocks = AddBlocksSequential(*auxiliary_add_blocks) - def forward(self, feature_map, attr_embedding, id_embedding): + def forward(self, feature_map : Tensor, attr_embedding : Tensor, id_embedding : IDEmbedding) -> Tensor: primary_feature = self.primary_add_blocks(feature_map, attr_embedding, id_embedding) if self.in_channels != self.out_channels: @@ -169,50 +155,47 @@ def forward(self, feature_map, attr_embedding, id_embedding): return output_feature -class Conv4x4(nn.Module): - def __init__(self, in_channels, out_channels): - super(Conv4x4, self).__init__() - self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1, - bias=False) +class DownSample(nn.Module): + def __init__(self, in_channels : int, out_channels : int) -> None: + super(DownSample, self).__init__() + self.conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False) self.batch_norm = nn.BatchNorm2d(out_channels) - self.leaky_relu = nn.LeakyReLU(0.1, inplace=True) + self.leaky_relu = nn.LeakyReLU(0.1, inplace = True) - def forward(self, x): - x = self.conv(x) - x = self.batch_norm(x) - x = self.leaky_relu(x) - return x + def forward(self, temp : Tensor) -> Tensor: + temp = self.conv(temp) + temp = self.batch_norm(temp) + temp = self.leaky_relu(temp) + return temp -class DeConv4x4(nn.Module): - def __init__(self, in_channels, out_channels): - super(DeConv4x4, self).__init__() - self.deconv = nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, - padding=1, bias=False) +class Upsample(nn.Module): + def __init__(self, in_channels : int, out_channels : int) -> None: + super(Upsample, self).__init__() + self.deconv = nn.ConvTranspose2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 4, stride = 2, padding = 1, bias = False) self.batch_norm = nn.BatchNorm2d(out_channels) - self.leaky_relu = nn.LeakyReLU(0.1, inplace=True) + self.leaky_relu = nn.LeakyReLU(0.1, inplace = True) - def forward(self, x, skip): - x = self.deconv(x) - x = self.batch_norm(x) - x = self.leaky_relu(x) - return torch.cat((x, skip), dim=1) + def forward(self, temp : Tensor, skip_tensor : Tensor) -> Tensor: + temp = self.deconv(temp) + temp = self.batch_norm(temp) + temp = self.leaky_relu(temp) + return torch.cat((temp, skip_tensor), dim = 1) -class Upsample(nn.Module): - def __init__(self, in_channels, out_channels): - super(Upsample, self).__init__() - self.initial_conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1, - padding=1) - self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2) +class PixelShuffleUpsample(nn.Module): + def __init__(self, in_channels : int, out_channels : int) -> None: + super(PixelShuffleUpsample, self).__init__() + self.conv = nn.Conv2d(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1) + self.pixel_shuffle = nn.PixelShuffle(upscale_factor = 2) - def forward(self, x): - x = self.initial_conv(x.view(x.shape[0], -1, 1, 1)) - x = self.pixel_shuffle(x) - return x + def forward(self, temp : Tensor) -> Tensor: + temp = self.conv(temp.view(temp.shape[0], -1, 1, 1)) + temp = self.pixel_shuffle(temp) + return temp -def initialize_weight(module): +def initialize_weight(module : nn.Module) -> None: if isinstance(module, nn.Linear): module.weight.data.normal_(0, 0.001) module.bias.data.zero_() @@ -222,3 +205,11 @@ def initialize_weight(module): if isinstance(module, nn.ConvTranspose2d): nn.init.xavier_normal_(module.weight.data) + + +if __name__ == '__main__': + model = AdaptiveEmbeddingIntegrationNetwork(512, 2) + src = torch.randn(1, 512) + trg = torch.randn(1, 3, 256, 256) + out = model(trg, src) + print(out[0].shape) diff --git a/face_swapper/src/typing.py b/face_swapper/src/typing.py index 529cb5f..b99e58b 100644 --- a/face_swapper/src/typing.py +++ b/face_swapper/src/typing.py @@ -1,4 +1,4 @@ -from typing import Any, Tuple, List, Dict, Optional +from typing import Any, Dict, List, Tuple from numpy.typing import NDArray from torch import Tensor @@ -9,6 +9,7 @@ TargetAttributes = Tuple[Tensor, ...] DiscriminatorOutputs = List[List[Tensor]] LossDict = Dict[str, Tensor] +IDEmbedding = Tensor Embedding = NDArray[Any] VisionFrame = NDArray[Any] diff --git a/mypy.ini b/mypy.ini index 75ef9f3..081e7bf 100644 --- a/mypy.ini +++ b/mypy.ini @@ -6,4 +6,4 @@ disallow_untyped_defs = True ignore_missing_imports = True strict_optional = False explicit_package_bases = True -exclude = face_swapper +exclude = face_swapper/LivePortrait