Skip to content

Commit

Permalink
[Fix] Fix the logic of skipping loss in smplify (#146)
Browse files Browse the repository at this point in the history
* Revise default weight to None & add _skip_loss

* Convert None to none in mse_loss

* Modify default weight to None

* Add more details in comments

* Apply abs in computing relative change & add docs

* Revise comments
  • Loading branch information
yl-1993 authored Apr 1, 2022
1 parent 10233de commit fb5ad0f
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 55 deletions.
2 changes: 2 additions & 0 deletions mmhuman3d/models/losses/mse_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class MSELoss(nn.Module):
def __init__(self, reduction='mean', loss_weight=1.0):
super().__init__()
assert reduction in (None, 'none', 'mean', 'sum')
reduction = 'none' if reduction is None else reduction
self.reduction = reduction
self.loss_weight = loss_weight

Expand Down Expand Up @@ -86,6 +87,7 @@ class KeypointMSELoss(nn.Module):
def __init__(self, reduction='mean', loss_weight=1.0, sigma=1.0):
super().__init__()
assert reduction in (None, 'none', 'mean', 'sum')
reduction = 'none' if reduction is None else reduction
self.reduction = reduction
self.loss_weight = loss_weight
self.sigma = sigma
Expand Down
125 changes: 80 additions & 45 deletions mmhuman3d/models/registrants/smplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,16 +285,16 @@ def _optimize_stage(self,
fit_betas: bool = True,
keypoints2d: torch.Tensor = None,
keypoints2d_conf: torch.Tensor = None,
keypoints2d_weight: float = 0.,
keypoints2d_weight: float = None,
keypoints3d: torch.Tensor = None,
keypoints3d_conf: torch.Tensor = None,
keypoints3d_weight: float = 0.,
shape_prior_weight: float = 0.,
joint_prior_weight: float = 0.,
smooth_loss_weight: float = 0.,
pose_prior_weight: float = 0.,
pose_reg_weight: float = 0.,
limb_length_weight: float = 0.,
keypoints3d_weight: float = None,
shape_prior_weight: float = None,
joint_prior_weight: float = None,
smooth_loss_weight: float = None,
pose_prior_weight: float = None,
pose_reg_weight: float = None,
limb_length_weight: float = None,
joint_weights: dict = {},
num_iter: int = 1,
ftol: float = 1e-4,
Expand Down Expand Up @@ -373,7 +373,7 @@ def closure():

loss = optimizer.step(closure)
if iter_idx > 0 and pre_loss is not None and ftol > 0:
loss_rel_change = self._compute_relative_loss_change(
loss_rel_change = self._compute_relative_change(
pre_loss, loss.item())
if loss_rel_change < ftol:
print(f'[ftol={ftol}] Early stop at {iter_idx} iter!')
Expand All @@ -388,16 +388,16 @@ def evaluate(
transl: torch.Tensor = None,
keypoints2d: torch.Tensor = None,
keypoints2d_conf: torch.Tensor = None,
keypoints2d_weight: float = 0.,
keypoints2d_weight: float = None,
keypoints3d: torch.Tensor = None,
keypoints3d_conf: torch.Tensor = None,
keypoints3d_weight: float = 0.,
shape_prior_weight: float = 0.,
joint_prior_weight: float = 0.,
smooth_loss_weight: float = 0.,
pose_prior_weight: float = 0.,
pose_reg_weight: float = 0.,
limb_length_weight: float = 0.,
keypoints3d_weight: float = None,
shape_prior_weight: float = None,
joint_prior_weight: float = None,
smooth_loss_weight: float = None,
pose_prior_weight: float = None,
pose_reg_weight: float = None,
limb_length_weight: float = None,
joint_weights: dict = {},
return_verts: bool = False,
return_full_pose: bool = False,
Expand Down Expand Up @@ -491,13 +491,13 @@ def _compute_loss(self,
keypoints2d_weight: float = None,
keypoints3d: torch.Tensor = None,
keypoints3d_conf: torch.Tensor = None,
keypoints3d_weight: float = 0.,
shape_prior_weight: float = 0.,
joint_prior_weight: float = 0.,
smooth_loss_weight: float = 0.,
pose_prior_weight: float = 0.,
pose_reg_weight: float = 0.,
limb_length_weight: float = 0.,
keypoints3d_weight: float = None,
shape_prior_weight: float = None,
joint_prior_weight: float = None,
smooth_loss_weight: float = None,
pose_prior_weight: float = None,
pose_reg_weight: float = None,
limb_length_weight: float = None,
joint_weights: dict = {},
reduction_override: str = None,
global_orient: torch.Tensor = None,
Expand Down Expand Up @@ -538,8 +538,8 @@ def _compute_loss(self,
weight = self._get_weight(**joint_weights)

# 2D keypoint loss
if keypoints2d is not None and (self.keypoints2d_mse_loss.loss_weight >
0 or keypoints2d_weight > 0):
if keypoints2d is not None and not self._skip_loss(
self.keypoints2d_mse_loss, keypoints2d_weight):
# bs = model_joints.shape[0]
# projected_joints = perspective_projection(
# model_joints,
Expand All @@ -566,8 +566,8 @@ def _compute_loss(self,
losses['keypoint2d_loss'] = keypoint2d_loss

# 3D keypoint loss
if keypoints3d is not None and (self.keypoints3d_mse_loss.loss_weight >
0 or keypoints3d_weight > 0):
if keypoints3d is not None and not self._skip_loss(
self.keypoints3d_mse_loss, keypoints3d_weight):
keypoints3d_loss = self.keypoints3d_mse_loss(
pred=model_joints,
pred_conf=model_joint_conf,
Expand All @@ -579,56 +579,47 @@ def _compute_loss(self,
losses['keypoints3d_loss'] = keypoints3d_loss

# regularizer to prevent betas from taking large values
if self.shape_prior_loss is not None and (
self.shape_prior_loss.loss_weight > 0
or shape_prior_weight > 0):
if not self._skip_loss(self.shape_prior_loss, shape_prior_weight):
shape_prior_loss = self.shape_prior_loss(
betas=betas,
loss_weight_override=shape_prior_weight,
reduction_override=reduction_override)
losses['shape_prior_loss'] = shape_prior_loss

# joint prior loss
if self.joint_prior_loss is not None and (
self.joint_prior_loss.loss_weight > 0
or joint_prior_weight > 0):
if not self._skip_loss(self.joint_prior_loss, joint_prior_weight):
joint_prior_loss = self.joint_prior_loss(
body_pose=body_pose,
loss_weight_override=joint_prior_weight,
reduction_override=reduction_override)
losses['joint_prior_loss'] = joint_prior_loss

# smooth body loss
if self.smooth_loss is not None and (self.smooth_loss.loss_weight > 0
or smooth_loss_weight > 0):
if not self._skip_loss(self.smooth_loss, smooth_loss_weight):
smooth_loss = self.smooth_loss(
body_pose=body_pose,
loss_weight_override=smooth_loss_weight,
reduction_override=reduction_override)
losses['smooth_loss'] = smooth_loss

# pose prior loss
if self.pose_prior_loss is not None and (
self.pose_prior_loss.loss_weight > 0 or pose_prior_weight > 0):
if not self._skip_loss(self.pose_prior_loss, pose_prior_weight):
pose_prior_loss = self.pose_prior_loss(
body_pose=body_pose,
loss_weight_override=pose_prior_weight,
reduction_override=reduction_override)
losses['pose_prior_loss'] = pose_prior_loss

# pose reg loss
if self.pose_reg_loss is not None and (
self.pose_reg_loss.loss_weight > 0 or pose_reg_weight > 0):
if not self._skip_loss(self.pose_reg_loss, pose_reg_weight):
pose_reg_loss = self.pose_reg_loss(
body_pose=body_pose,
loss_weight_override=pose_reg_weight,
reduction_override=reduction_override)
losses['pose_reg_loss'] = pose_reg_loss

# limb length loss
if self.limb_length_loss is not None and (
self.limb_length_loss.loss_weight > 0
or limb_length_weight > 0):
if not self._skip_loss(self.limb_length_loss, limb_length_weight):
limb_length_loss = self.limb_length_loss(
pred=model_joints,
pred_conf=model_joint_conf,
Expand Down Expand Up @@ -783,5 +774,49 @@ def _expand_betas(self, batch_size, betas):
return betas_video

@staticmethod
def _compute_relative_loss_change(pre_v, cur_v):
return (pre_v - cur_v) / max([np.abs(pre_v), np.abs(cur_v), 1])
def _compute_relative_change(pre_v, cur_v):
"""Compute relative loss change. If relative change is small enough, we
can apply early stop to accelerate the optimization. (1) When one of
the value is larger than 1, we calculate the relative change by diving
their max value. (2) When both values are smaller than 1, it degrades
to absolute change. Intuitively, if two values are small and close,
dividing the difference by the max value may yield a large value.
Args:
pre_v: previous value
cur_v: current value
Returns:
float: relative change
"""
return np.abs(pre_v - cur_v) / max([np.abs(pre_v), np.abs(cur_v), 1])

@staticmethod
def _skip_loss(loss, loss_weight_override):
"""Whether to skip loss computation. If loss is None, it will directly
skip the loss to avoid RuntimeError. If loss is not None, the table
below shows the return value. If the return value is True, it means the
computation of loss can be skipped. As the result is 0 even if it is
calculated, we can skip it to save computational cost.
| loss.loss_weight | loss_weight_override | returns |
| ---------------- | -------------------- | ------- |
| == 0 | None | True |
| != 0 | None | False |
| == 0 | == 0 | True |
| != 0 | == 0 | True |
| == 0 | != 0 | False |
| != 0 | != 0 | False |
Args:
loss: loss is an object that has attribute loss_weight.
loss.loss_weight is assigned when loss is initialized.
loss_weight_override: loss_weight used to override loss.loss_weight
Returns:
bool: True means skipping loss computation, and vice versa
"""
if (loss is None) or (loss.loss_weight == 0 and loss_weight_override is
None) or (loss_weight_override == 0):
return True
return False
20 changes: 10 additions & 10 deletions mmhuman3d/models/registrants/smplifyx.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,11 +170,11 @@ def _optimize_stage(self,
keypoints2d_weight: float = None,
keypoints3d: torch.Tensor = None,
keypoints3d_conf: torch.Tensor = None,
keypoints3d_weight: float = 0.,
shape_prior_weight: float = 0.,
joint_prior_weight: float = 0.,
smooth_loss_weight: float = 0.,
pose_prior_weight: float = 0.,
keypoints3d_weight: float = None,
shape_prior_weight: float = None,
joint_prior_weight: float = None,
smooth_loss_weight: float = None,
pose_prior_weight: float = None,
joint_weights: dict = {},
num_iter: int = 1) -> None:
"""Optimize a stage of body model parameters according to
Expand Down Expand Up @@ -286,11 +286,11 @@ def evaluate(
keypoints2d_weight=None,
keypoints3d=None,
keypoints3d_conf=None,
keypoints3d_weight=0.,
shape_prior_weight=0.,
joint_prior_weight=0.,
smooth_loss_weight=0.,
pose_prior_weight=0.,
keypoints3d_weight=None,
shape_prior_weight=None,
joint_prior_weight=None,
smooth_loss_weight=None,
pose_prior_weight=None,
joint_weights={},
return_verts=False,
return_full_pose=False,
Expand Down
12 changes: 12 additions & 0 deletions tests/test_models/test_losses/test_mse_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,18 @@ def test_keypoint_mse_loss():
target = torch.zeros(1, 3, 2)
assert torch.allclose(loss(pred, target), torch.tensor(3.))

# test None reduction
loss_cfg = dict(type='KeypointMSELoss', reduction=None)
loss = build_loss(loss_cfg)
pred = torch.zeros(1, 3, 2)
target = torch.zeros(1, 3, 2)
assert torch.allclose(loss(pred, target), pred)

pred = torch.ones(1, 3, 2)
target = torch.zeros(1, 3, 2)
result = torch.ones(1, 3, 2) * 0.5
assert torch.allclose(loss(pred, target), result)

# test None reduction
loss_cfg = dict(type='KeypointMSELoss', reduction='none')
loss = build_loss(loss_cfg)
Expand Down

0 comments on commit fb5ad0f

Please sign in to comment.