diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index 87e151d..9416543 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -17,7 +17,7 @@ from .discriminator import MultiscaleDiscriminator from .generator import AdaptiveEmbeddingIntegrationNetwork from .helper import hinge_fake_loss, hinge_real_loss -from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, IdEmbedding, Loss, Padding, SourceEmbedding, TargetAttributes, VisionTensor +from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, IdEmbedding, LossTensor, Padding, SourceEmbedding, TargetAttributes, VisionTensor CONFIG = configparser.ConfigParser() CONFIG.read('config.ini') @@ -88,7 +88,7 @@ def training_step(self, batch : Batch, batch_index : int) -> Tensor: self.log('l_REC', generator_losses.get('loss_reconstruction'), prog_bar = True) return generator_losses.get('loss_generator') - def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> Loss: + def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor: loss_adversarial = torch.Tensor(0) for discriminator_output in discriminator_outputs: @@ -96,7 +96,7 @@ def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> loss_adversarial = torch.mean(loss_adversarial) return loss_adversarial - def calc_attribute_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes) -> Loss: + def calc_attribute_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes) -> LossTensor: loss_attribute = torch.Tensor(0) swap_attributes = self.generator.get_attributes(swap_tensor) @@ -105,19 +105,19 @@ def calc_attribute_loss(self, swap_tensor : VisionTensor, target_attributes : Ta loss_attribute *= 0.5 return loss_attribute - def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> Loss: + def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor: loss_reconstruction = torch.sum(0.5 * torch.mean(torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1), dim = 1) * is_same_person) / (is_same_person.sum() + 1e-4) loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean() loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5 return loss_reconstruction - def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> Loss: + def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor: swap_embedding = self.get_id_embedding(swap_tensor, (30, 0, 10, 10)) source_embedding = self.get_id_embedding(source_tensor, (30, 0, 10, 10)) loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean() return loss_id - def calc_tsr_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Loss: + def calc_tsr_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor: swap_motion_features = self.get_pose_features(swap_tensor) target_motion_features = self.get_pose_features(target_tensor) loss_tsr = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype) @@ -126,7 +126,7 @@ def calc_tsr_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor loss_tsr += self.mse_loss(swap_motion_feature, target_motion_feature) return loss_tsr - def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Loss: + def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor: swap_landmark = self.get_face_landmarks(swap_tensor) target_landmark = self.get_face_landmarks(target_tensor) left_gaze_loss = self.mse_loss(swap_landmark[:, 198], target_landmark[:, 198]) diff --git a/face_swapper/src/typing.py b/face_swapper/src/typing.py index f1366b5..da89355 100644 --- a/face_swapper/src/typing.py +++ b/face_swapper/src/typing.py @@ -14,6 +14,6 @@ Padding = Tuple[int, int, int, int] FaceLandmark203 = Tensor VisionTensor = Tensor -Loss = Tensor -GeneratorLossSet = Dict[str, Loss] -DiscriminatorLossSet = Dict[str, Loss] +LossTensor = Tensor +GeneratorLossSet = Dict[str, LossTensor] +DiscriminatorLossSet = Dict[str, LossTensor]