diff --git a/face_swapper/src/training.py b/face_swapper/src/training.py index c2d512b..f7d1c37 100644 --- a/face_swapper/src/training.py +++ b/face_swapper/src/training.py @@ -318,9 +318,9 @@ def get_arcface_embedding(self, vision_tensor : Tensor, padding : Tuple[int, int crop_vision_tensor = vision_tensor[:, :, crop_height : height - crop_height, crop_width : width - crop_width] crop_vision_tensor = torch.nn.functional.interpolate(crop_vision_tensor, size = (112, 112), mode = 'bilinear') crop_vision_tensor[:, :, :padding[0], :] = 0 - crop_vision_tensor[:, :, -padding[1]:, :] = 0 + crop_vision_tensor[:, :, 112 - padding[1]:, :] = 0 crop_vision_tensor[:, :, :, :padding[2]] = 0 - crop_vision_tensor[:, :, :, -padding[3]:] = 0 + crop_vision_tensor[:, :, :, 112 - padding[3]:] = 0 embedding = self.arcface(crop_vision_tensor) embedding = torch.nn.functional.normalize(embedding, p = 2, dim = 1) return embedding