new swapper
harisreedhar committed Dec 9, 2024
1 parent 5fa53da commit 5f361d2
Showing 14 changed files with 550 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .flake8
@@ -4,4 +4,4 @@ plugins = flake8-import-order
application_import_names = arcface_converter
import-order-style = pycharm
per-file-ignores = preparing.py:E402

exclude = LivePortrait
3 changes: 3 additions & 0 deletions .gitmodules
@@ -0,0 +1,3 @@
[submodule "face_swapper/LivePortrait"]
path = face_swapper/LivePortrait
url = https://github.com/KwaiVGI/LivePortrait
3 changes: 3 additions & 0 deletions face_swapper/LICENSE.md
@@ -0,0 +1,3 @@
Non-Commercial license

Copyright (c) 2024 Henry Ruhs
1 change: 1 addition & 0 deletions face_swapper/LivePortrait
Submodule LivePortrait added at 632da7
61 changes: 61 additions & 0 deletions face_swapper/config.ini
@@ -0,0 +1,61 @@
[preparing.dataset]
dataset_path =

[preparing.dataloader]
same_person_probability = 0.2

[preparing.augmentation]
expression_augmentation = false

[training.loader]
batch_size = 6
num_workers = 8

[training.generator]
num_blocks = 2
id_channels = 512

[training.discriminator]
input_channels = 3
num_filters = 64
num_layers = 5
num_discriminators = 3

[auxiliary_models.paths]
arcface_path =
landmarker_path =
motion_extractor_path = /home/hari/Documents/Github/Phantom/assets/pretrained_models/liveportrait_motion_extractor.pth
feature_extractor_path =
warping_network_path =
spade_generator_path =

[training.losses]
weight_adversarial = 1
weight_identity = 20
weight_attribute = 10
weight_reconstruction = 10
weight_tsr = 0
weight_expression = 0

[training.optimizers]
scheduler_step = 5000
scheduler_gamma = 0.2
generator_learning_rate = 0.0004
discriminator_learning_rate = 0.0004

[training.trainer]
epochs = 50
disable_discriminator = false

[training.output]
directory_path =
file_pattern =

[exporting]
directory_path =
source_path =
target_path =
opset_version =

[execution]
providers =
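
For orientation, here is a minimal sketch (not part of the commit) of how these sections can be read with Python's configparser, matching the CONFIG pattern used in data_loader.py below; the variable names are illustrative:

    import configparser

    CONFIG = configparser.ConfigParser()
    CONFIG.read('config.ini')

    # typed accessors for the sections above
    batch_size = CONFIG.getint('training.loader', 'batch_size')  # 6
    num_workers = CONFIG.getint('training.loader', 'num_workers')  # 8
    weight_identity = CONFIG.getfloat('training.losses', 'weight_identity')  # 20.0
    disable_discriminator = CONFIG.getboolean('training.trainer', 'disable_discriminator')  # False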
27 changes: 27 additions & 0 deletions face_swapper/src/augmentations.py
@@ -0,0 +1,27 @@
import torch
from torch import Tensor


def apply_random_motion_blur(tensor_image : Tensor) -> Tensor:
    kernel_size = 9
    # draw a line kernel along a random direction
    kernel = torch.zeros((kernel_size, kernel_size), dtype = torch.float32)
    random_angle = torch.empty(1).uniform_(-2 * torch.pi, 2 * torch.pi)
    dx = torch.cos(random_angle)
    dy = torch.sin(random_angle)
    center = kernel_size // 2

    for i in range(kernel_size):
        x = int(center + (i - center) * dx)
        y = int(center + (i - center) * dy)

        if 0 <= x < kernel_size and 0 <= y < kernel_size:
            kernel[y, x] = 1
    # normalize the kernel and shape it to (1, 1, K, K) for conv2d
    kernel /= kernel.sum()
    kernel = kernel.unsqueeze(0).unsqueeze(0)
    blurred_channels = []

    # convolve each channel separately with the same line kernel
    for channel in tensor_image:
        channel = channel.unsqueeze(0).unsqueeze(0)
        channel = torch.nn.functional.conv2d(channel, kernel, padding = kernel_size // 2)
        channel = channel.squeeze(0).squeeze(0)
        blurred_channels.append(channel)
    return torch.stack(blurred_channels)
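
A quick usage sketch (assuming the module is importable as face_swapper.src.augmentations and the input is a CHW float tensor, as transforms.ToTensor produces):

    import torch

    from face_swapper.src.augmentations import apply_random_motion_blur

    image = torch.rand(3, 256, 256)  # CHW float tensor in [0, 1]
    blurred = apply_random_motion_blur(image)
    print(blurred.shape)  # torch.Size([3, 256, 256])

In this commit the function is applied stochastically via transforms.RandomApply inside the complex augmentation pipeline of data_loader.py below.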
82 changes: 82 additions & 0 deletions face_swapper/src/data_loader.py
@@ -0,0 +1,82 @@
import configparser
import glob
import random

import cv2
import torchvision.transforms as transforms
import tqdm
from PIL import Image
from torch.utils.data import Dataset

from .augmentations import apply_random_motion_blur
from .sub_typing import Batch

CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')


def read_image(image_path : str) -> Image.Image:
    # OpenCV loads BGR; convert to RGB before handing off to PIL
    image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(image)
    return pil_image


class DataLoaderVGG(Dataset):
    def __init__(self, dataset_path : str) -> None:
        self.same_person_probability = CONFIG.getfloat('preparing.dataloader', 'same_person_probability')
        self.image_paths = glob.glob('{}/*/*.*g'.format(dataset_path))
        self.folder_paths = glob.glob('{}/*'.format(dataset_path))
        self.image_path_dict = {}

        # index every identity folder so same-person pairs can be sampled quickly
        for folder_path in tqdm.tqdm(self.folder_paths):
            image_paths = glob.glob('{}/*'.format(folder_path))
            self.image_path_dict[folder_path] = image_paths
        self.dataset_total = len(self.image_paths)
        self.transforms_basic = transforms.Compose(
        [
            transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.transforms_moderate = transforms.Compose(
        [
            transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
            transforms.RandomAffine(4, translate = (0.01, 0.01), scale = (0.98, 1.02), shear = (1, 1), fill = 0),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.transforms_complex = transforms.Compose(
        [
            transforms.Resize((256, 256), interpolation = transforms.InterpolationMode.BICUBIC),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p = 0.5),
            transforms.RandomApply([ apply_random_motion_blur ], p = 0.3),
            transforms.ColorJitter(brightness = 0.2, contrast = 0.2, saturation = 0.2, hue = 0.1),
            transforms.RandomAffine(8, translate = (0.02, 0.02), scale = (0.98, 1.02), shear = (1, 1), fill = 0),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def __getitem__(self, item : int) -> Batch:
        source_image_path = self.image_paths[item]
        source = read_image(source_image_path)

        if random.random() > self.same_person_probability:
            # cross-identity pair: pick a random target from the whole dataset
            is_same_person = 0
            target_image_path = random.choice(self.image_paths)
            target = read_image(target_image_path)
            source_transform = self.transforms_moderate(source)
            target_transform = self.transforms_complex(target)
        else:
            # same-identity pair: pick the target from the source identity folder
            is_same_person = 1
            source_folder_path = '/'.join(source_image_path.split('/')[:-1])
            target_image_path = random.choice(self.image_path_dict[source_folder_path])
            target = read_image(target_image_path)
            source_transform = self.transforms_basic(source)
            target_transform = self.transforms_basic(target)
        return source_transform, target_transform, is_same_person

    def __len__(self) -> int:
        return self.dataset_total
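
A usage sketch for the loader (the dataset path is hypothetical; batch_size and num_workers mirror [training.loader] in config.ini, and config.ini must be readable from the working directory since the module reads it at import time):

    from torch.utils.data import DataLoader

    from face_swapper.src.data_loader import DataLoaderVGG

    dataset = DataLoaderVGG('/path/to/vggface2')  # hypothetical dataset path
    loader = DataLoader(dataset, batch_size = 6, num_workers = 8, shuffle = True, drop_last = True)

    for source, target, is_same_person in loader:
        # source, target: (6, 3, 256, 256) tensors; is_same_person: (6,) tensor of 0/1
        break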
85 changes: 85 additions & 0 deletions face_swapper/src/discriminator.py
@@ -0,0 +1,85 @@
from typing import List

import numpy
import torch.nn as nn

from .sub_typing import Tensor


class NLayerDiscriminator(nn.Module):
    def __init__(self, input_channels : int, num_filters : int, num_layers : int) -> None:
        super(NLayerDiscriminator, self).__init__()
        self.num_layers = num_layers
        kernel_size = 4
        padding_size = int(numpy.ceil((kernel_size - 1.0) / 2))
        model_layers = [
        [
            nn.Conv2d(input_channels, num_filters, kernel_size = kernel_size, stride = 2, padding = padding_size),
            nn.LeakyReLU(0.2, True)
        ]]
        current_filters = num_filters

        # stride-2 blocks, doubling the filter count up to a cap of 512
        for layer_index in range(1, num_layers):
            previous_filters = current_filters
            current_filters = min(current_filters * 2, 512)
            model_layers += [
            [
                nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 2, padding = padding_size),
                nn.InstanceNorm2d(current_filters),
                nn.LeakyReLU(0.2, True)
            ]]
        previous_filters = current_filters
        current_filters = min(current_filters * 2, 512)
        model_layers += [
        [
            nn.Conv2d(previous_filters, current_filters, kernel_size = kernel_size, stride = 1, padding = padding_size),
            nn.InstanceNorm2d(current_filters),
            nn.LeakyReLU(0.2, True)
        ]]
        # final 1-channel patch map
        model_layers += [
        [
            nn.Conv2d(current_filters, 1, kernel_size = kernel_size, stride = 1, padding = padding_size)
        ]]
        combined_layers = []

        for layer in model_layers:
            combined_layers += layer
        self.model = nn.Sequential(*combined_layers)

    def forward(self, input_tensor : Tensor) -> Tensor:
        return self.model(input_tensor)


# input_channels = 3, num_filters = 64, num_layers = 5, num_discriminators = 3
class MultiscaleDiscriminator(nn.Module):
    def __init__(self, input_channels : int, num_filters : int, num_layers : int, num_discriminators : int) -> None:
        super(MultiscaleDiscriminator, self).__init__()
        self.num_discriminators = num_discriminators
        self.num_layers = num_layers
        # expose per-layer feature maps, e.g. for a feature matching loss
        self.return_intermediate_features = True

        for discriminator_index in range(num_discriminators):
            single_discriminator = NLayerDiscriminator(input_channels, num_filters, num_layers)
            setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model)
        self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = 1, count_include_pad = False)

    def single_discriminator_forward(self, model_layers : nn.Sequential, input_tensor : Tensor) -> List[Tensor]:
        if self.return_intermediate_features:
            feature_maps = [ input_tensor ]

            for layer in model_layers:
                feature_maps.append(layer(feature_maps[-1]))
            return feature_maps[1:]
        return [ model_layers(input_tensor) ]

    def forward(self, input_tensor : Tensor) -> List[List[Tensor]]:
        discriminator_outputs = []
        downsampled_input = input_tensor

        for discriminator_index in range(self.num_discriminators):
            model_layers = getattr(self, 'discriminator_layer_{}'.format(self.num_discriminators - 1 - discriminator_index))
            discriminator_outputs.append(self.single_discriminator_forward(model_layers, downsampled_input))

            if discriminator_index != (self.num_discriminators - 1):
                downsampled_input = self.downsample(downsampled_input)
        return discriminator_outputs
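
And a forward-pass sketch for the multiscale discriminator, instantiated with the values from [training.discriminator] in config.ini above (the import path is assumed):

    import torch

    from face_swapper.src.discriminator import MultiscaleDiscriminator

    discriminator = MultiscaleDiscriminator(input_channels = 3, num_filters = 64, num_layers = 5, num_discriminators = 3)
    batch = torch.rand(2, 3, 256, 256)
    outputs = discriminator(batch)

    print(len(outputs))  # 3, one entry per image scale
    print(outputs[0][-1].shape)  # 1-channel patch map of the first (full-resolution) scale

Each entry of outputs is itself a list of per-layer feature maps, since return_intermediate_features is enabled; the input is average-pooled by a factor of two between scales.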