From fb5ad0f7706da449f8b4fdc1db8314f44255e745 Mon Sep 17 00:00:00 2001 From: Lei Yang Date: Fri, 1 Apr 2022 20:08:09 +0800 Subject: [PATCH] [Fix] Fix the logic of skipping loss in smplify (#146) * Revise default weight to None & add _skip_loss * Convert None to none in mse_loss * Modify default weight to None * Add more details in comments * Apply abs in computing relative change & add docs * Revise comments --- mmhuman3d/models/losses/mse_loss.py | 2 + mmhuman3d/models/registrants/smplify.py | 125 +++++++++++------- mmhuman3d/models/registrants/smplifyx.py | 20 +-- .../test_models/test_losses/test_mse_loss.py | 12 ++ 4 files changed, 104 insertions(+), 55 deletions(-) diff --git a/mmhuman3d/models/losses/mse_loss.py b/mmhuman3d/models/losses/mse_loss.py index 4e78438f..319cc94e 100644 --- a/mmhuman3d/models/losses/mse_loss.py +++ b/mmhuman3d/models/losses/mse_loss.py @@ -39,6 +39,7 @@ class MSELoss(nn.Module): def __init__(self, reduction='mean', loss_weight=1.0): super().__init__() assert reduction in (None, 'none', 'mean', 'sum') + reduction = 'none' if reduction is None else reduction self.reduction = reduction self.loss_weight = loss_weight @@ -86,6 +87,7 @@ class KeypointMSELoss(nn.Module): def __init__(self, reduction='mean', loss_weight=1.0, sigma=1.0): super().__init__() assert reduction in (None, 'none', 'mean', 'sum') + reduction = 'none' if reduction is None else reduction self.reduction = reduction self.loss_weight = loss_weight self.sigma = sigma diff --git a/mmhuman3d/models/registrants/smplify.py b/mmhuman3d/models/registrants/smplify.py index b570cfca..53c3d908 100644 --- a/mmhuman3d/models/registrants/smplify.py +++ b/mmhuman3d/models/registrants/smplify.py @@ -285,16 +285,16 @@ def _optimize_stage(self, fit_betas: bool = True, keypoints2d: torch.Tensor = None, keypoints2d_conf: torch.Tensor = None, - keypoints2d_weight: float = 0., + keypoints2d_weight: float = None, keypoints3d: torch.Tensor = None, keypoints3d_conf: torch.Tensor = None, - keypoints3d_weight: float = 0., - shape_prior_weight: float = 0., - joint_prior_weight: float = 0., - smooth_loss_weight: float = 0., - pose_prior_weight: float = 0., - pose_reg_weight: float = 0., - limb_length_weight: float = 0., + keypoints3d_weight: float = None, + shape_prior_weight: float = None, + joint_prior_weight: float = None, + smooth_loss_weight: float = None, + pose_prior_weight: float = None, + pose_reg_weight: float = None, + limb_length_weight: float = None, joint_weights: dict = {}, num_iter: int = 1, ftol: float = 1e-4, @@ -373,7 +373,7 @@ def closure(): loss = optimizer.step(closure) if iter_idx > 0 and pre_loss is not None and ftol > 0: - loss_rel_change = self._compute_relative_loss_change( + loss_rel_change = self._compute_relative_change( pre_loss, loss.item()) if loss_rel_change < ftol: print(f'[ftol={ftol}] Early stop at {iter_idx} iter!') @@ -388,16 +388,16 @@ def evaluate( transl: torch.Tensor = None, keypoints2d: torch.Tensor = None, keypoints2d_conf: torch.Tensor = None, - keypoints2d_weight: float = 0., + keypoints2d_weight: float = None, keypoints3d: torch.Tensor = None, keypoints3d_conf: torch.Tensor = None, - keypoints3d_weight: float = 0., - shape_prior_weight: float = 0., - joint_prior_weight: float = 0., - smooth_loss_weight: float = 0., - pose_prior_weight: float = 0., - pose_reg_weight: float = 0., - limb_length_weight: float = 0., + keypoints3d_weight: float = None, + shape_prior_weight: float = None, + joint_prior_weight: float = None, + smooth_loss_weight: float = None, + pose_prior_weight: float = None, + pose_reg_weight: float = None, + limb_length_weight: float = None, joint_weights: dict = {}, return_verts: bool = False, return_full_pose: bool = False, @@ -491,13 +491,13 @@ def _compute_loss(self, keypoints2d_weight: float = None, keypoints3d: torch.Tensor = None, keypoints3d_conf: torch.Tensor = None, - keypoints3d_weight: float = 0., - shape_prior_weight: float = 0., - joint_prior_weight: float = 0., - smooth_loss_weight: float = 0., - pose_prior_weight: float = 0., - pose_reg_weight: float = 0., - limb_length_weight: float = 0., + keypoints3d_weight: float = None, + shape_prior_weight: float = None, + joint_prior_weight: float = None, + smooth_loss_weight: float = None, + pose_prior_weight: float = None, + pose_reg_weight: float = None, + limb_length_weight: float = None, joint_weights: dict = {}, reduction_override: str = None, global_orient: torch.Tensor = None, @@ -538,8 +538,8 @@ def _compute_loss(self, weight = self._get_weight(**joint_weights) # 2D keypoint loss - if keypoints2d is not None and (self.keypoints2d_mse_loss.loss_weight > - 0 or keypoints2d_weight > 0): + if keypoints2d is not None and not self._skip_loss( + self.keypoints2d_mse_loss, keypoints2d_weight): # bs = model_joints.shape[0] # projected_joints = perspective_projection( # model_joints, @@ -566,8 +566,8 @@ def _compute_loss(self, losses['keypoint2d_loss'] = keypoint2d_loss # 3D keypoint loss - if keypoints3d is not None and (self.keypoints3d_mse_loss.loss_weight > - 0 or keypoints3d_weight > 0): + if keypoints3d is not None and not self._skip_loss( + self.keypoints3d_mse_loss, keypoints3d_weight): keypoints3d_loss = self.keypoints3d_mse_loss( pred=model_joints, pred_conf=model_joint_conf, @@ -579,9 +579,7 @@ def _compute_loss(self, losses['keypoints3d_loss'] = keypoints3d_loss # regularizer to prevent betas from taking large values - if self.shape_prior_loss is not None and ( - self.shape_prior_loss.loss_weight > 0 - or shape_prior_weight > 0): + if not self._skip_loss(self.shape_prior_loss, shape_prior_weight): shape_prior_loss = self.shape_prior_loss( betas=betas, loss_weight_override=shape_prior_weight, @@ -589,9 +587,7 @@ def _compute_loss(self, losses['shape_prior_loss'] = shape_prior_loss # joint prior loss - if self.joint_prior_loss is not None and ( - self.joint_prior_loss.loss_weight > 0 - or joint_prior_weight > 0): + if not self._skip_loss(self.joint_prior_loss, joint_prior_weight): joint_prior_loss = self.joint_prior_loss( body_pose=body_pose, loss_weight_override=joint_prior_weight, @@ -599,8 +595,7 @@ def _compute_loss(self, losses['joint_prior_loss'] = joint_prior_loss # smooth body loss - if self.smooth_loss is not None and (self.smooth_loss.loss_weight > 0 - or smooth_loss_weight > 0): + if not self._skip_loss(self.smooth_loss, smooth_loss_weight): smooth_loss = self.smooth_loss( body_pose=body_pose, loss_weight_override=smooth_loss_weight, @@ -608,8 +603,7 @@ def _compute_loss(self, losses['smooth_loss'] = smooth_loss # pose prior loss - if self.pose_prior_loss is not None and ( - self.pose_prior_loss.loss_weight > 0 or pose_prior_weight > 0): + if not self._skip_loss(self.pose_prior_loss, pose_prior_weight): pose_prior_loss = self.pose_prior_loss( body_pose=body_pose, loss_weight_override=pose_prior_weight, @@ -617,8 +611,7 @@ def _compute_loss(self, losses['pose_prior_loss'] = pose_prior_loss # pose reg loss - if self.pose_reg_loss is not None and ( - self.pose_reg_loss.loss_weight > 0 or pose_reg_weight > 0): + if not self._skip_loss(self.pose_reg_loss, pose_reg_weight): pose_reg_loss = self.pose_reg_loss( body_pose=body_pose, loss_weight_override=pose_reg_weight, @@ -626,9 +619,7 @@ def _compute_loss(self, losses['pose_reg_loss'] = pose_reg_loss # limb length loss - if self.limb_length_loss is not None and ( - self.limb_length_loss.loss_weight > 0 - or limb_length_weight > 0): + if not self._skip_loss(self.limb_length_loss, limb_length_weight): limb_length_loss = self.limb_length_loss( pred=model_joints, pred_conf=model_joint_conf, @@ -783,5 +774,49 @@ def _expand_betas(self, batch_size, betas): return betas_video @staticmethod - def _compute_relative_loss_change(pre_v, cur_v): - return (pre_v - cur_v) / max([np.abs(pre_v), np.abs(cur_v), 1]) + def _compute_relative_change(pre_v, cur_v): + """Compute relative loss change. If relative change is small enough, we + can apply early stop to accelerate the optimization. (1) When one of + the value is larger than 1, we calculate the relative change by diving + their max value. (2) When both values are smaller than 1, it degrades + to absolute change. Intuitively, if two values are small and close, + dividing the difference by the max value may yield a large value. + + Args: + pre_v: previous value + cur_v: current value + + Returns: + float: relative change + """ + return np.abs(pre_v - cur_v) / max([np.abs(pre_v), np.abs(cur_v), 1]) + + @staticmethod + def _skip_loss(loss, loss_weight_override): + """Whether to skip loss computation. If loss is None, it will directly + skip the loss to avoid RuntimeError. If loss is not None, the table + below shows the return value. If the return value is True, it means the + computation of loss can be skipped. As the result is 0 even if it is + calculated, we can skip it to save computational cost. + + | loss.loss_weight | loss_weight_override | returns | + | ---------------- | -------------------- | ------- | + | == 0 | None | True | + | != 0 | None | False | + | == 0 | == 0 | True | + | != 0 | == 0 | True | + | == 0 | != 0 | False | + | != 0 | != 0 | False | + + Args: + loss: loss is an object that has attribute loss_weight. + loss.loss_weight is assigned when loss is initialized. + loss_weight_override: loss_weight used to override loss.loss_weight + + Returns: + bool: True means skipping loss computation, and vice versa + """ + if (loss is None) or (loss.loss_weight == 0 and loss_weight_override is + None) or (loss_weight_override == 0): + return True + return False diff --git a/mmhuman3d/models/registrants/smplifyx.py b/mmhuman3d/models/registrants/smplifyx.py index 457c6732..5e811680 100644 --- a/mmhuman3d/models/registrants/smplifyx.py +++ b/mmhuman3d/models/registrants/smplifyx.py @@ -170,11 +170,11 @@ def _optimize_stage(self, keypoints2d_weight: float = None, keypoints3d: torch.Tensor = None, keypoints3d_conf: torch.Tensor = None, - keypoints3d_weight: float = 0., - shape_prior_weight: float = 0., - joint_prior_weight: float = 0., - smooth_loss_weight: float = 0., - pose_prior_weight: float = 0., + keypoints3d_weight: float = None, + shape_prior_weight: float = None, + joint_prior_weight: float = None, + smooth_loss_weight: float = None, + pose_prior_weight: float = None, joint_weights: dict = {}, num_iter: int = 1) -> None: """Optimize a stage of body model parameters according to @@ -286,11 +286,11 @@ def evaluate( keypoints2d_weight=None, keypoints3d=None, keypoints3d_conf=None, - keypoints3d_weight=0., - shape_prior_weight=0., - joint_prior_weight=0., - smooth_loss_weight=0., - pose_prior_weight=0., + keypoints3d_weight=None, + shape_prior_weight=None, + joint_prior_weight=None, + smooth_loss_weight=None, + pose_prior_weight=None, joint_weights={}, return_verts=False, return_full_pose=False, diff --git a/tests/test_models/test_losses/test_mse_loss.py b/tests/test_models/test_losses/test_mse_loss.py index 336a60d9..77ef90b2 100644 --- a/tests/test_models/test_losses/test_mse_loss.py +++ b/tests/test_models/test_losses/test_mse_loss.py @@ -25,6 +25,18 @@ def test_keypoint_mse_loss(): target = torch.zeros(1, 3, 2) assert torch.allclose(loss(pred, target), torch.tensor(3.)) + # test None reduction + loss_cfg = dict(type='KeypointMSELoss', reduction=None) + loss = build_loss(loss_cfg) + pred = torch.zeros(1, 3, 2) + target = torch.zeros(1, 3, 2) + assert torch.allclose(loss(pred, target), pred) + + pred = torch.ones(1, 3, 2) + target = torch.zeros(1, 3, 2) + result = torch.ones(1, 3, 2) * 0.5 + assert torch.allclose(loss(pred, target), result) + # test None reduction loss_cfg = dict(type='KeypointMSELoss', reduction='none') loss = build_loss(loss_cfg)