
Commit

cleaning
harisreedhar committed Jan 17, 2025
1 parent 2fb5ad6 commit c4cea27
Showing 2 changed files with 10 additions and 10 deletions.
14 changes: 7 additions & 7 deletions face_swapper/src/training.py
@@ -17,7 +17,7 @@
from .discriminator import MultiscaleDiscriminator
from .generator import AdaptiveEmbeddingIntegrationNetwork
from .helper import hinge_fake_loss, hinge_real_loss
-from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, IdEmbedding, Loss, Padding, SourceEmbedding, TargetAttributes, VisionTensor
+from .typing import Batch, DiscriminatorLossSet, DiscriminatorOutputs, FaceLandmark203, GeneratorLossSet, IdEmbedding, LossTensor, Padding, SourceEmbedding, TargetAttributes, VisionTensor

CONFIG = configparser.ConfigParser()
CONFIG.read('config.ini')
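The hinge_fake_loss and hinge_real_loss helpers imported in the hunk above live in .helper, which this commit does not touch. As a point of reference only, a minimal sketch assuming the standard hinge GAN formulation (an assumption about what those helpers implement, not code from the repository):

import torch
from torch import Tensor

def hinge_real_loss(discriminator_output : Tensor) -> Tensor:
    # Standard hinge term for real samples: penalize discriminator scores below +1.
    return torch.relu(1 - discriminator_output)

def hinge_fake_loss(discriminator_output : Tensor) -> Tensor:
    # Standard hinge term for fake samples: penalize discriminator scores above -1.
    return torch.relu(1 + discriminator_output)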
@@ -88,15 +88,15 @@ def training_step(self, batch : Batch, batch_index : int) -> Tensor:
        self.log('l_REC', generator_losses.get('loss_reconstruction'), prog_bar = True)
        return generator_losses.get('loss_generator')

-    def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> Loss:
+    def calc_adversarial_loss(self, discriminator_outputs : DiscriminatorOutputs) -> LossTensor:
        loss_adversarial = torch.Tensor(0)

        for discriminator_output in discriminator_outputs:
            loss_adversarial += hinge_real_loss(discriminator_output[0]).mean(dim = [ 1, 2, 3 ])
        loss_adversarial = torch.mean(loss_adversarial)
        return loss_adversarial

-    def calc_attribute_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes) -> Loss:
+    def calc_attribute_loss(self, swap_tensor : VisionTensor, target_attributes : TargetAttributes) -> LossTensor:
        loss_attribute = torch.Tensor(0)
        swap_attributes = self.generator.get_attributes(swap_tensor)

@@ -105,19 +105,19 @@ def calc_attribute_loss(self, swap_tensor : VisionTensor, target_attributes : Ta
        loss_attribute *= 0.5
        return loss_attribute

-    def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> Loss:
+    def calc_reconstruction_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor, is_same_person : Tensor) -> LossTensor:
        loss_reconstruction = torch.sum(0.5 * torch.mean(torch.pow(swap_tensor - target_tensor, 2).reshape(self.batch_size, -1), dim = 1) * is_same_person) / (is_same_person.sum() + 1e-4)
        loss_ssim = 1 - ssim(swap_tensor, target_tensor, data_range = float(torch.max(swap_tensor) - torch.min(swap_tensor))).mean()
        loss_reconstruction = (loss_reconstruction + loss_ssim) * 0.5
        return loss_reconstruction

-    def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> Loss:
+    def calc_id_loss(self, source_tensor : VisionTensor, swap_tensor : VisionTensor) -> LossTensor:
        swap_embedding = self.get_id_embedding(swap_tensor, (30, 0, 10, 10))
        source_embedding = self.get_id_embedding(source_tensor, (30, 0, 10, 10))
        loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean()
        return loss_id

-    def calc_tsr_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Loss:
+    def calc_tsr_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
        swap_motion_features = self.get_pose_features(swap_tensor)
        target_motion_features = self.get_pose_features(target_tensor)
        loss_tsr = torch.tensor(0).to(swap_tensor.device).to(swap_tensor.dtype)
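For intuition on calc_id_loss in the hunk above: the loss is the mean cosine distance between source and swap identity embeddings, so matching identities drive it toward 0 and orthogonal embeddings toward 1. A standalone sketch; the 512-dimensional embeddings and random data are illustrative assumptions, not values from the repository:

import torch

source_embedding = torch.nn.functional.normalize(torch.randn(4, 512), dim = 1)
swap_embedding = source_embedding.clone()

# Identical embeddings -> cosine similarity of 1 -> loss of 0.
loss_id = (1 - torch.cosine_similarity(source_embedding, swap_embedding, dim = 1)).mean()
print(float(loss_id))  # ~0.0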
@@ -126,7 +126,7 @@ def calc_tsr_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor
            loss_tsr += self.mse_loss(swap_motion_feature, target_motion_feature)
        return loss_tsr

-    def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> Loss:
+    def calc_gaze_loss(self, swap_tensor : VisionTensor, target_tensor : VisionTensor) -> LossTensor:
        swap_landmark = self.get_face_landmarks(swap_tensor)
        target_landmark = self.get_face_landmarks(target_tensor)
        left_gaze_loss = self.mse_loss(swap_landmark[:, 198], target_landmark[:, 198])
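calc_gaze_loss compares selected rows of the 203-point landmark tensors (the FaceLandmark203 alias) with an MSE criterion. A small sketch of that per-landmark comparison; index 198 comes from the diff, while the (batch, 203, 2) landmark shape and random data are assumptions for illustration:

import torch

mse_loss = torch.nn.MSELoss()

# Hypothetical batched landmarks of shape (batch, 203, 2).
swap_landmark = torch.rand(8, 203, 2)
target_landmark = torch.rand(8, 203, 2)

# MSE restricted to a single landmark index, mirroring left_gaze_loss above.
left_gaze_loss = mse_loss(swap_landmark[:, 198], target_landmark[:, 198])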
6 changes: 3 additions & 3 deletions face_swapper/src/typing.py
@@ -14,6 +14,6 @@
Padding = Tuple[int, int, int, int]
FaceLandmark203 = Tensor
VisionTensor = Tensor
-Loss = Tensor
-GeneratorLossSet = Dict[str, Loss]
-DiscriminatorLossSet = Dict[str, Loss]
+LossTensor = Tensor
+GeneratorLossSet = Dict[str, LossTensor]
+DiscriminatorLossSet = Dict[str, LossTensor]
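With the rename, every loss term is annotated as a LossTensor and the per-network loss sets stay plain string-keyed dicts. A hedged usage sketch; collect_generator_losses and its keys are hypothetical, chosen only to mirror the 'loss_generator' and 'loss_reconstruction' keys logged in training.py:

from typing import Dict
from torch import Tensor

LossTensor = Tensor
GeneratorLossSet = Dict[str, LossTensor]

def collect_generator_losses(loss_generator : LossTensor, loss_reconstruction : LossTensor) -> GeneratorLossSet:
    # Hypothetical helper: bundle individual loss terms under string keys.
    return {
        'loss_generator': loss_generator,
        'loss_reconstruction': loss_reconstruction
    }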
