
Commit ef0abb5
ugly training code
harisreedhar committed Dec 10, 2024
1 parent 78f5c9b commit ef0abb5
Showing 10 changed files with 655 additions and 167 deletions.
2 changes: 1 addition & 1 deletion .flake8
@@ -4,4 +4,4 @@ plugins = flake8-import-order
 application_import_names = arcface_converter
 import-order-style = pycharm
 per-file-ignores = preparing.py:E402
-exclude = LivePortrait
+exclude = face_swapper
48 changes: 29 additions & 19 deletions face_swapper/config.ini
@@ -1,55 +1,65 @@
 [preparing.dataset]
-dataset_path =
+dataset_path = /assets/VGGface2_None_norm_512_true_bygfpgan
 
 [preparing.dataloader]
 same_person_probability = 0.2
 
 [preparing.augmentation]
-expression_augmentation = false
+expression = false
 
 [training.loader]
-batch_size = 6
+batch_size = 4
 num_workers = 8
 
 [training.generator]
 num_blocks = 2
 id_channels = 512
+learning_rate = 0.0004
 
 [training.discriminator]
 input_channels = 3
 num_filters = 64
 num_layers = 5
 num_discriminators = 3
+learning_rate = 0.0004
+disable = false
 
 [auxiliary_models.paths]
-arcface_path =
-landmarker_path =
-motion_extractor_path = /home/hari/Documents/Github/Phantom/assets/pretrained_models/liveportrait_motion_extractor.pth
-feature_extractor_path =
-warping_netwrk_path =
-spade_generator_path =
+arcface_path = /assets/pretrained_models/arcface_w600k_r50.pt
+landmarker_path = /assets/pretrained_models/landmark_203.pt
+motion_extractor_path = /assets/pretrained_models/liveportrait_motion_extractor.pth
+feature_extractor_path = /assets/pretrained_models/liveportrait_feature_extractor.pth
+warping_network_path = /assets/pretrained_models/liveportrait_warping_model.pth
+spade_generator_path = /assets/pretrained_models/liveportrait_spade_generator.pth
 
 [training.losses]
 weight_adversarial = 1
 weight_identity = 20
 weight_attribute = 10
 weight_reconstruction = 10
-weight_tsr = 0
 weight_expression = 0
+weight_tsr = 100
+weight_eye_gaze = 5
+weight_eye_open = 5
+weight_lip_open = 5
 
-[training.optimizers]
-scheduler_step = 5000
-scheduler_gamma = 0.2
-generator_learning_rate = 0.0004
-discriminator_learning_rate = 0.0004
+[training.schedulers]
+step = 5000
+gamma = 0.2
 
 [training.trainer]
-epochs = 50
+max_epochs = 50
+disable_discriminator = false
 
 [training.output]
-directory_path =
-file_pattern =
+checkpoint_path = checkpoints/last.ckpt
+directory_path = checkpoints
+file_pattern = 'checkpoint-{epoch}-{step}-{l_G:.4f}-{l_D:.4f}'
+preview_frequency = 250
+validation_frequency = 1000
+
+[training.validation]
+sources = assets/test/front/sources
+targets = assets/test/front/targets
 
 [exporting]
 directory_path =
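The renamed [training.schedulers] block reads naturally as the arguments of a PyTorch StepLR. A minimal sketch of the wiring, assuming Adam as the optimizer; the trainer that actually consumes these keys is not among the files shown in this commit.

import configparser

import torch
from torch.optim.lr_scheduler import StepLR

CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')

# Stand-in module; the actual generator is defined elsewhere in the repo.
generator = torch.nn.Linear(512, 512)

# learning_rate now lives in [training.generator] rather than [training.optimizers].
optimizer = torch.optim.Adam(generator.parameters(), lr = CONFIG.getfloat('training.generator', 'learning_rate'))

# step = 5000, gamma = 0.2: multiply the learning rate by 0.2 every 5000 scheduler steps.
scheduler = StepLR(optimizer, step_size = CONFIG.getint('training.schedulers', 'step'), gamma = CONFIG.getfloat('training.schedulers', 'gamma'))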
15 changes: 12 additions & 3 deletions face_swapper/src/data_loader.py
@@ -9,7 +9,7 @@
 from torch.utils.data import TensorDataset
 
 from .augmentations import apply_random_motion_blur
-from .sub_typing import Batch
+from .typing import Batch
 
 CONFIG = configparser.ConfigParser()
 CONFIG.read('config.ini')
@@ -27,6 +27,7 @@ def __init__(self, dataset_path : str) -> None:
 		self.image_paths = glob.glob('{}/*/*.*g'.format(dataset_path))
 		self.folder_paths = glob.glob('{}/*'.format(dataset_path))
 		self.image_path_dict = {}
+		self._current_index = 0
 
 		for folder_path in tqdm.tqdm(self.folder_paths):
 			image_paths = glob.glob('{}/*'.format(folder_path))
@@ -50,12 +51,12 @@ def __init__(self, dataset_path : str) -> None:
 		[
 			transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
 			transforms.ToTensor(),
-			transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
 			transforms.RandomHorizontalFlip(p = 0.5),
 			transforms.RandomApply([ apply_random_motion_blur ], p = 0.3),
 			transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation = 0.2, hue = 0.1),
 			transforms.RandomAffine(8, translate = (0.02, 0.02), scale = (0.98, 1.02), shear = (1, 1), fill = 0),
-			transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+			transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+			transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BICUBIC),
 		])
@@ -80,3 +81,11 @@ def __getitem__(self, item : int) -> Batch:
 
 	def __len__(self) -> int:
 		return self.dataset_total
+
+
+	def state_dict(self):
+		return {'current_index': self._current_index}
+
+
+	def load_state_dict(self, state_dict):
+		self._current_index = state_dict['current_index']
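The new state_dict()/load_state_dict() pair makes the dataset's read position checkpointable alongside checkpoint_path = checkpoints/last.ckpt. A rough sketch of the save/resume round trip; the class below is a stand-in, since the diff shows only the added methods, not the dataset's class name.

import torch


# Minimal stand-in mirroring the methods added in this commit.
class StatefulDataset:
	def __init__(self):
		self._current_index = 0

	def state_dict(self):
		return {'current_index': self._current_index}

	def load_state_dict(self, state_dict):
		self._current_index = state_dict['current_index']


dataset = StatefulDataset()
dataset._current_index = 1234 # pretend training stopped mid-epoch

# Persist the position alongside the model checkpoint.
torch.save({ 'dataset': dataset.state_dict() }, 'dataset_state.pt')

# On resume, restore the position before rebuilding the DataLoader.
resumed = StatefulDataset()
resumed.load_state_dict(torch.load('dataset_state.pt')['dataset'])
assert resumed._current_index == 1234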
15 changes: 4 additions & 11 deletions face_swapper/src/discriminator.py
@@ -3,7 +3,7 @@
 import numpy
 import torch.nn as nn
 
-from .sub_typing import Tensor
+from .typing import Tensor, DiscriminatorOutputs
 
 
 class NLayerDiscriminator(nn.Module):
@@ -49,7 +49,6 @@ def forward(self, input_tensor : Tensor) -> Tensor:
 		return self.model(input_tensor)
 
 
-# input_channels=3, num_filters=64, num_layers=5, num_discriminators=3
 class MultiscaleDiscriminator(nn.Module):
 	def __init__(self, input_channels : int, num_filters : int, num_layers : int, num_discriminators : int):
 		super(MultiscaleDiscriminator, self).__init__()
@@ -61,18 +60,12 @@ def __init__(self, input_channels : int, num_filters : int, num_layers : int, num_discriminators : int):
 			setattr(self, 'discriminator_layer_{}'.format(discriminator_index), single_discriminator.model)
 		self.downsample = nn.AvgPool2d(kernel_size = 3, stride = 2, padding = [ 1, 1 ], count_include_pad = False) # type:ignore[arg-type]
 
-	def single_discriminator_forward(self, model_layers : nn.Sequential, input_tensor : Tensor) -> List[Tensor]:
-		if self.return_intermediate_features:
-			feature_maps = [ input_tensor ]
-
-			for layer in model_layers:
-				feature_maps.append(layer(feature_maps[-1]))
-			return feature_maps[1:]
-		else:
-			return [ model_layers(input_tensor) ]
+	def single_discriminator_forward(self, model_layers : nn.Sequential, input_tensor : Tensor) -> List[Tensor]:
+		return [ model_layers(input_tensor) ]
 
-	def forward(self, input_tensor : Tensor) -> List[Tensor]:
+	def forward(self, input_tensor : Tensor) -> DiscriminatorOutputs:
 		discriminator_outputs = []
 		downsampled_input = input_tensor
 
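After this simplification, single_discriminator_forward always returns a single prediction per scale, and forward is annotated with the new DiscriminatorOutputs alias. A quick smoke test under those assumptions; the import path is inferred from the file layout face_swapper/src/discriminator.py and the hidden parts of __init__ are assumed to match what the visible lines suggest.

import torch

from face_swapper.src.discriminator import MultiscaleDiscriminator

# Values mirror [training.discriminator] in config.ini.
discriminator = MultiscaleDiscriminator(input_channels = 3, num_filters = 64, num_layers = 5, num_discriminators = 3)

# A batch of 4 RGB crops at the 256x256 training resolution.
input_tensor = torch.randn(4, 3, 256, 256)
discriminator_outputs = discriminator(input_tensor)

# One entry per scale; each entry now holds a single prediction map
# rather than every intermediate feature map.
assert len(discriminator_outputs) == 3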
(Diffs for the remaining 6 changed files are not shown here.)
